jax.nn.log_softmax
Log-Softmax function.
Computes the logarithm of the softmax function, which maps each element to the range \((-\infty, 0]\).
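The sketch below checks `log_softmax` against the identity \(\log \mathrm{softmax}(x)_i = x_i - \log \sum_j e^{x_j}\), using `jax.scipy.special.logsumexp`. It also contrasts it with the naive `jnp.log(nn.softmax(x))`, which can underflow to `-inf` for widely spread inputs. The input values are illustrative assumptions, not taken from this page.

```python
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logsumexp

x = jnp.array([1.0, 2.0, 3.0])

out = nn.log_softmax(x)
# Identity: log_softmax(x) = x - logsumexp(x)
ref = x - logsumexp(x)
# Taking the log of softmax gives the same values for moderate inputs
naive = jnp.log(nn.softmax(x))

# With a large gap between entries, exp(-1000) underflows in float32,
# so the naive version produces -inf while log_softmax stays finite.
big = jnp.array([1000.0, 0.0])
stable = nn.log_softmax(big)
unstable = jnp.log(nn.softmax(big))

print(out, ref, naive)
print(stable, unstable)
```

Every entry of `out` is non-positive. `stable` keeps the finite value for the second element, while `unstable` does not.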
x (Any) – input array
axis (Union[int, Tuple[int, …], None]) – the axis or axes along which the log_softmax should be computed. Either an integer or a tuple of integers.
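A minimal sketch of the `axis` argument on a 2-D array. It assumes the default `axis=-1` used by current JAX releases. After exponentiating, the result sums to approximately 1 along whichever axis or axes were normalized. A tuple normalizes over all listed axes jointly.

```python
import jax.numpy as jnp
from jax import nn

x = jnp.arange(6.0).reshape(2, 3)

# Default axis (-1): each row is normalized independently
rows = nn.log_softmax(x)
# axis=0: each column is normalized independently
cols = nn.log_softmax(x, axis=0)
# Tuple of axes: all six elements are normalized together
joint = nn.log_softmax(x, axis=(0, 1))

# exp(log_softmax) sums to ~1 over the normalized axis or axes
print(jnp.exp(rows).sum(axis=-1))
print(jnp.exp(cols).sum(axis=0))
print(jnp.exp(joint).sum())
```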