jax.nn.normalize(x, axis, mean, variance, epsilon)
Normalizes an array by subtracting the mean and dividing by the square root of the variance, with epsilon added to the variance before taking the square root.
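The formula in the description can be sketched directly with `jax.numpy`. This is a minimal illustration, not the library's source: the helper name `normalize_sketch` is hypothetical, and the defaults `axis=-1` and `epsilon=1e-5` are assumptions chosen for the example.

```python
import jax.numpy as jnp

def normalize_sketch(x, axis=-1, epsilon=1e-5):
    # Hypothetical helper mirroring the description above; the defaults are assumptions.
    # keepdims=True keeps the reduced axes so the statistics broadcast against x.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    variance = jnp.var(x, axis=axis, keepdims=True)
    # Subtract the mean, then divide by sqrt(variance + epsilon).
    return (x - mean) / jnp.sqrt(variance + epsilon)

# Two rows with very different scales normalize to (nearly) the same values.
x = jnp.array([[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]])
y = normalize_sketch(x)
```

After normalization each row has mean 0 and a variance just below 1; the small shortfall comes from epsilon in the denominator.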
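Parameters:
x (Any): the input array.
axis (Union[int, Tuple[int, ...], None]): the axis or axes along which the mean and variance are computed.
mean (Optional[Any]): an optional precomputed mean; if None, it is computed from x along axis.
variance (Optional[Any]): an optional precomputed variance; if None, it is computed from x along axis.
epsilon (Any): a small constant added to the variance for numerical stability.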
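The optional `mean` and `variance` parameters let callers reuse statistics computed elsewhere, for example running averages kept from training, while a tuple `axis` normalizes over several axes at once. The sketch below is a hypothetical re-implementation (`normalize_sketch` is not a JAX function), assuming an NHWC array and per-channel statistics as in batch normalization; it shows that supplying precomputed statistics gives the same result as computing them on the fly.

```python
import jax.numpy as jnp

def normalize_sketch(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
    # Hypothetical re-implementation for illustration only, not the library source.
    # Statistics are computed from x only when they are not supplied.
    if mean is None:
        mean = jnp.mean(x, axis=axis, keepdims=True)
    if variance is None:
        variance = jnp.var(x, axis=axis, keepdims=True)
    return (x - mean) / jnp.sqrt(variance + epsilon)

# Assumed layout: a batch of 4 images, 2x2 pixels, 3 channels (NHWC).
x = jnp.arange(48.0).reshape(4, 2, 2, 3)

# Per-channel normalization: reduce over batch and spatial axes.
per_channel = normalize_sketch(x, axis=(0, 1, 2))

# Reuse statistics computed ahead of time; axis is then unused by the sketch.
stored_mean = jnp.mean(x, axis=(0, 1, 2), keepdims=True)
stored_var = jnp.var(x, axis=(0, 1, 2), keepdims=True)
reused = normalize_sketch(x, mean=stored_mean, variance=stored_var)
```

Because the stored statistics have shape `(1, 1, 1, 3)`, they broadcast across the batch and spatial axes exactly as the freshly computed ones do.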