jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)
Elementwise round.
Rounds values to the nearest integer.
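A minimal sketch of basic elementwise use, assuming a recent JAX install. The input values and the names `x` and `rounded` are illustrative only. With no `rounding_method` argument, the default method applies.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative float32 input; each element is rounded independently.
x = jnp.array([-1.7, -0.2, 0.4, 2.6], dtype=jnp.float32)

# Uses the default rounding_method.
rounded = lax.round(x)
print(rounded)
```

Note that `-0.2` rounds to a negative zero, which compares equal to `0.0`.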
Parameters:
x (Any) – an array or scalar value to round.
rounding_method (RoundingMethod) – the method to use when rounding halfway values (e.g., 0.5). See lax.RoundingMethod for the possible values: AWAY_FROM_ZERO (the default) and TO_NEAREST_EVEN.
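The two rounding methods differ only on exact halfway values. This sketch compares them on such inputs, assuming a recent JAX release where `lax.RoundingMethod` exposes `AWAY_FROM_ZERO` and `TO_NEAREST_EVEN`. The array name `halves` and the result names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

# Exact halfway values, all representable in float32.
halves = jnp.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=jnp.float32)

# Ties move away from zero: -2.5 -> -3, 0.5 -> 1.
away = lax.round(halves, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO)

# Ties go to the nearest even integer ("banker's rounding"): 2.5 -> 2, 1.5 -> 2.
even = lax.round(halves, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)

print(away)
print(even)
```

`jax.numpy.round` follows NumPy's round-half-to-even convention, so for these inputs it matches the TO_NEAREST_EVEN result rather than the lax.round default.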
Returns:
An array containing the elementwise rounding of x, with the same shape and dtype as x.
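Because the rounded values keep the input's floating-point dtype, code that needs integer indices or counts must cast explicitly. A minimal sketch, assuming float32 input; the names `r` and `ints` are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.4, 1.6, -2.5], dtype=jnp.float32)

# The result is still float32; only the values are rounded.
r = lax.round(x)
print(r.dtype)

# Cast explicitly when an integer array is required.
ints = r.astype(jnp.int32)
print(ints)
```

With the default AWAY_FROM_ZERO method, `-2.5` becomes `-3` before the cast.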