jax.experimental package
jax.experimental.optix has been moved into its own Python package
(optax).
jax.experimental.enable_x64(new_val=True)
Experimental context manager to temporarily enable X64 mode.
Usage:
>>> import jax.numpy as jnp
>>> from jax.experimental import enable_x64
>>> with enable_x64():
...   print(jnp.arange(10.0).dtype)
...
float64
See also
jax.experimental.disable_x64: temporarily disable X64 mode.
Parameters
new_val (bool) – value to which X64 mode is set inside the context; defaults to True.
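The sketch below illustrates the scoping behaviour: the setting applies only inside the `with` block, the previous value (whatever it was) is restored on exit, and passing `new_val=False` turns X64 mode off for the block. Arrays created inside the block keep their dtype after it exits. The helper `default_float_dtype` is a hypothetical name used only for illustration.

```python
import jax.numpy as jnp
from jax.experimental import enable_x64


def default_float_dtype():
    """Hypothetical helper: dtype JAX picks for a Python float right now."""
    return jnp.arange(10.0).dtype


before = default_float_dtype()

with enable_x64():
    # X64 mode is on inside the block, so Python floats become float64.
    inside = default_float_dtype()
    x = jnp.arange(3.0)

# Leaving the block restores whatever setting was active before it.
after = default_float_dtype()

with enable_x64(False):
    # new_val=False sets X64 mode to off for the duration of the block.
    off = default_float_dtype()

# x was created inside the block and still carries its float64 dtype.
print(before, inside, after, off, x.dtype)
```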
jax.experimental.disable_x64()
Experimental context manager to temporarily disable X64 mode.
Usage:
>>> import jax.numpy as jnp
>>> from jax.experimental import disable_x64
>>> with disable_x64():
...   print(jnp.arange(10.0).dtype)
...
float32
See also
jax.experimental.enable_x64: temporarily enable X64 mode.
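Because both context managers restore the previous setting on exit, they can be nested. The sketch below (variable names are illustrative only) temporarily disables X64 mode inside an `enable_x64` block and checks the default float dtype at each level.

```python
import jax.numpy as jnp
from jax.experimental import disable_x64, enable_x64

with enable_x64():
    outer = jnp.arange(10.0).dtype      # X64 on: float64
    with disable_x64():
        inner = jnp.arange(10.0).dtype  # X64 off for this inner block: float32
    restored = jnp.arange(10.0).dtype   # back to the enclosing setting: float64

print(outer, inner, restored)
```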