jax.random.shuffle(key, x, axis=0)
Shuffle the elements of an array uniformly at random along an axis.
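To make "uniformly at random along an axis" concrete, here is a minimal sketch of one common technique: draw a random sort key for every element and argsort along the axis. Each 1-D slice along `axis` is then reordered independently. The helper name `shuffle_along_axis` is hypothetical, and this is not claimed to be JAX's internal implementation. It uses float keys from `jax.random.uniform` and ignores the small bias that tied keys can introduce in very large arrays.

```python
import jax
import jax.numpy as jnp


def shuffle_along_axis(key, x, axis=0):
    """Hypothetical helper: independently shuffle every 1-D slice of x along `axis`."""
    x = jnp.asarray(x)
    # One random sort key per element; sorting by these keys gives a random order
    # for each slice (ties between float keys are ignored in this sketch).
    sort_keys = jax.random.uniform(key, x.shape)
    order = jnp.argsort(sort_keys, axis=axis)
    # Gather the elements of x in the random order along the chosen axis.
    return jnp.take_along_axis(x, order, axis=axis)


key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)
y = shuffle_along_axis(key, x, axis=0)  # each column is shuffled on its own
print(y)
```

Because the sort keys differ per column, each column receives its own permutation. The values in a column are reordered but never mixed with values from other columns.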
Parameters
key (Union[Any, PRNGKeyArray]) – a PRNG key used as the random key.
x (Any) – the array to be shuffled.
axis (int) – optional, an int axis along which to shuffle (default 0).
Return type: ndarray
Returns: A shuffled version of x.
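A short usage sketch follows. Newer JAX releases deprecate `shuffle` and point to `jax.random.permutation(..., independent=True)` instead, so the example calls that function. This assumes a JAX version recent enough to support the `independent` argument. With `independent=True`, each vector along `axis` is shuffled on its own, which is the behavior described above. With the default `independent=False`, a single permutation is applied to whole slices.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
x = jnp.arange(6).reshape(2, 3)

# Shuffle each row independently along axis 1.
shuffled = jax.random.permutation(key, x, axis=1, independent=True)

# Reusing the same key reproduces the same result: JAX randomness is a pure
# function of the key, so split the key when you need fresh randomness.
again = jax.random.permutation(key, x, axis=1, independent=True)

key, subkey = jax.random.split(key)
fresh = jax.random.permutation(subkey, x, axis=1, independent=True)
print(shuffled, fresh)
```

Shuffling only reorders entries, so each row of the result holds the same values as the matching row of `x`.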