Public API: jax package
Subpackages
Just-in-time compilation (jit)

jax.jit | Sets up fun for just-in-time compilation with XLA.
jax.disable_jit | Context manager that disables jit() behavior under its dynamic context.
jax.xla_computation | Creates a function that produces its XLA computation given example args.
jax.make_jaxpr | Creates a function that produces its jaxpr given example args.
jax.eval_shape | Compute the shape/dtype of fun without any FLOPs.
jax.device_put | Transfers x to device.
jax.device_put_replicated | Transfer array(s) to each specified device and form ShardedDeviceArray(s).
jax.device_put_sharded | Transfer array shards to specified devices and form ShardedDeviceArray(s).
jax.device_get | Transfer x to host.
jax.default_backend | Returns the platform name of the default XLA backend.
jax.named_call | Adds a user specified name to a function when staging out JAX computations.
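Automatic differentiation

jax.grad | Creates a function that evaluates the gradient of fun.
jax.value_and_grad | Create a function that evaluates both fun and the gradient of fun.
jax.jacfwd | Jacobian of fun evaluated column-by-column using forward-mode AD.
jax.jacrev | Jacobian of fun evaluated row-by-row using reverse-mode AD.
jax.hessian | Hessian of fun as a dense array.
jax.jvp | Computes a (forward-mode) Jacobian-vector product of fun.
jax.linearize | Produces a linear approximation to fun using jvp() and partial eval.
jax.linear_transpose | Transpose a function that is promised to be linear.
jax.vjp | Compute a (reverse-mode) vector-Jacobian product of fun.
jax.custom_jvp | Set up a JAX-transformable function for a custom JVP rule definition.
jax.custom_vjp | Set up a JAX-transformable function for a custom VJP rule definition.
jax.closure_convert | Closure conversion utility, for use with higher-order custom derivatives.
jax.checkpoint | Make fun recompute internal linearization points when differentiated.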
Vectorization (vmap)

jax.vmap | Vectorizing map.
jax.numpy.vectorize | Define a vectorized function with broadcasting.
Parallelization (pmap)

jax.pmap | Parallel map with support for collective operations.
jax.devices | Returns a list of all devices for a given backend.
jax.local_devices | Like jax.devices(), but only returns devices local to a given process.
jax.process_index | Returns the integer process index of this process.
jax.device_count | Returns the total number of devices.
jax.local_device_count | Returns the number of devices addressable by this process.
jax.process_count | Returns the number of JAX processes associated with the backend.
jax.jit(fun, *, static_argnums=None, static_argnames=None, device=None, backend=None, donate_argnums=(), inline=False)
Sets up fun for just-in-time compilation with XLA.

- Parameters
  fun (~F) – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
  static_argnums (Union[int, Iterable[int], None]) – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.
    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.
    If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.
  static_argnames (Union[str, Iterable[str], None]) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  device (Optional[Device]) – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
  backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
  donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no arguments are donated.
  inline (bool) – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.
- Return type
  ~F
- Returns
  A wrapped version of fun, set up for just-in-time compilation.

In the following example, selu can be compiled into a single fused kernel by XLA:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def selu(x, alpha=1.67, lmbda=1.05):
    ...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
    >>>
    >>> key = jax.random.PRNGKey(0)
    >>> x = jax.random.normal(key, (10,))
    >>> print(selu(x))
    [-0.54485 0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232 0.76827 0.59566 ]
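The static-argument rules above can be illustrated with a small sketch. The function tile_sum is a hypothetical example: its n argument determines an output shape, so it has to be a concrete Python value at trace time and is marked static by name with static_argnames. Because only static_argnames is given, JAX should also treat n as static when it is passed positionally, via the inspect.signature lookup described above.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example: `n` controls the output shape of jnp.tile, so it must
# be a concrete Python int during tracing and is marked static by name.
@partial(jax.jit, static_argnames=("n",))
def tile_sum(x, n):
    return jnp.tile(x, n).sum()


x = jnp.arange(3.0)  # [0., 1., 2.]

by_keyword = tile_sum(x, n=2)   # compiles a version specialised to n=2
by_position = tile_sum(x, 2)    # n is still static, matched by signature
different_n = tile_sum(x, n=3)  # a new static value triggers recompilation

print(by_keyword, by_position, different_n)  # 6.0 6.0 9.0
```

Each distinct value of n becomes part of the compilation cache key, so this pattern suits arguments that take only a few distinct values.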
jax.disable_jit()
Context manager that disables jit() behavior under its dynamic context.

For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context.

Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:

    >>> import jax
    >>>
    >>> @jax.jit
    ... def f(x):
    ...   y = x * 2
    ...   print("Value of y is", y)
    ...   return y + 3
    ...
    >>> print(f(jax.numpy.array([1, 2, 3])))
    Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
    [5 7 9]

Here y has been abstracted by jit() to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. The value of y is also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit() context manager:

    >>> import jax
    >>>
    >>> with jax.disable_jit():
    ...   print(f(jax.numpy.array([1, 2, 3])))
    ...
    Value of y is [2 4 6]
    [5 7 9]
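Abstract values also explain a second debugging situation: Python control flow that depends on an argument’s value fails under jit(), but runs under disable_jit(). The sketch below uses a hypothetical relu_py function to show this; the exact exception class raised under jit() varies between JAX versions, so the sketch catches Exception broadly.

```python
import jax


@jax.jit
def relu_py(x):
    # Python-level branching on a value: fine for concrete values, but under
    # jit `x > 0` is an abstract tracer that cannot be converted to a bool.
    if x > 0:
        return x
    return 0.0 * x


try:
    relu_py(2.0)
    jit_failed = False
except Exception:  # JAX raises a concretization-style error here
    jit_failed = True

with jax.disable_jit():
    eager_result = relu_py(2.0)  # runs as ordinary Python, so the `if` works

print(jit_failed, eager_result)
```

For production code the branch would normally be rewritten with jax.numpy.where or jax.lax.cond; disable_jit() is only a debugging aid.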
jax.xla_computation(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())
Creates a function that produces its XLA computation given example args.

- Parameters
  fun (Callable) – Function from which to form XLA computations.
  static_argnums (Union[int, Iterable[int]]) – See the jax.jit() docstring.
  axis_env (Optional[Sequence[Tuple[Any, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap(). See the examples below.
  in_parts – Optional, how each argument to fun should be partitioned or replicated. This is used to specify partitioned XLA computations; see sharded_jit for more info.
  out_parts – Optional, how each output of fun should be partitioned or replicated. This is used to specify partitioned XLA computations; see sharded_jit for more info.
  backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
  tuple_args (bool) – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments. If None, tupling will be enabled when there are more than 100 arguments, since some platforms have limits on argument arity.
  instantiate_const_outputs (Optional[bool]) – Deprecated argument, does nothing.
  return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape, dtype, and named_shape attributes representing the corresponding types of the output leaves.
  donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
- Returns
  A wrapped version of fun that when applied to example arguments returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like as_hlo_text, as_serialized_hlo_module_proto, and as_hlo_dot_graph. If the argument return_shape is True, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, shapes, dtypes, and named shapes of the output of fun.

Concrete example arguments are not always necessary. For those arguments not indicated by static_argnums, any object with shape and dtype attributes is acceptable (excepting namedtuples, which are treated as Python containers).

For example:

    >>> import jax
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> c = jax.xla_computation(f)(3.)
    >>> print(c.as_hlo_text())
    HloModule xla_computation_f.6
    ENTRY xla_computation_f.6 {
      constant.2 = pred[] constant(false)
      parameter.1 = f32[] parameter(0)
      cosine.3 = f32[] cosine(parameter.1)
      sine.4 = f32[] sine(cosine.3)
      ROOT tuple.5 = (f32[]) tuple(sine.4)
    }

Alternatively, the assignment to c above could be written:

    >>> import types
    >>> import numpy as np
    >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
    >>> c = jax.xla_computation(f)(scalar)

Here’s an example that involves a parallel collective and axis name:

    >>> def f(x): return x - jax.lax.psum(x, 'i')
    >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
    >>> print(c.as_hlo_text())
    HloModule jaxpr_computation.9
    primitive_computation.3 {
      parameter.4 = s32[] parameter(0)
      parameter.5 = s32[] parameter(1)
      ROOT add.6 = s32[] add(parameter.4, parameter.5)
    }
    ENTRY jaxpr_computation.9 {
      tuple.1 = () tuple()
      parameter.2 = s32[] parameter(0)
      all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
      ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
    }

Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:

    >>> from jax import lax
    >>> def g(x):
    ...   rowsum = lax.psum(x, 'i')
    ...   colsum = lax.psum(x, 'j')
    ...   allsum = lax.psum(x, ('i', 'j'))
    ...   return rowsum, colsum, allsum
    ...
    >>> axis_env = [('i', 4), ('j', 2)]
    >>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
    >>> print(c.as_hlo_text())
    HloModule jaxpr_computation__1.19
    [removed uninteresting text here]
    ENTRY jaxpr_computation__1.19 {
      tuple.1 = () tuple()
      parameter.2 = f32[] parameter(0)
      all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
      all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
      all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
      ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
    }
jax.make_jaxpr(fun, static_argnums=(), axis_env=None, return_shape=False)
Creates a function that produces its jaxpr given example args.

- Parameters
  fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
  static_argnums (Union[int, Iterable[int]]) – See the jax.jit() docstring.
  axis_env (Optional[Sequence[Tuple[Any, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap().
  return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape, dtype, and named_shape attributes representing the corresponding types of the output leaves.
- Return type
  Callable[..., ClosedJaxpr]
- Returns
  A wrapped version of fun that when applied to example arguments returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shapes, dtypes, and named shapes of the output of fun.

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

    >>> import jax
    >>>
    >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
    >>> print(f(3.0))
    -0.83602
    >>> jax.make_jaxpr(f)(3.0)
    { lambda ; a.
      let b = cos a
          c = sin b
      in (c,) }
    >>> jax.make_jaxpr(jax.grad(f))(3.0)
    { lambda ; a.
      let b = cos a
          c = sin a
          _ = sin b
          d = cos b
          e = mul 1.0 d
          f = neg e
          g = mul f c
      in (g,) }
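A jaxpr can also be inspected programmatically rather than only printed. The sketch below assumes the ClosedJaxpr exposes its underlying jaxpr as .jaxpr, whose eqns each carry a primitive with a name. It uses jax.lax.sin and jax.lax.cos directly (instead of the jax.numpy wrappers in the example above) so that each call binds exactly one primitive, and it passes return_shape=True to get the output shape/dtype alongside the jaxpr.

```python
import jax
from jax import lax


def g(x):
    # lax functions bind primitives directly, one equation each.
    return lax.sin(lax.cos(x))


# return_shape=True yields (ClosedJaxpr, pytree of shape/dtype structs).
closed_jaxpr, out_shape = jax.make_jaxpr(g, return_shape=True)(3.0)

# Primitive names of the jaxpr equations, in evaluation order.
prims = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prims)                             # ['cos', 'sin']
print(out_shape.shape, out_shape.dtype)  # () float32
```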
jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of fun without any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:

    def eval_shape(fun, *args, **kwargs):
      out = fun(*args, **kwargs)
      return jax.tree_util.tree_map(shape_dtype_struct, out)

    def shape_dtype_struct(x):
      return ShapeDtypeStruct(x.shape, x.dtype)

    class ShapeDtypeStruct:
      __slots__ = ["shape", "dtype"]
      def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.

But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using eval_shape() can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).

- Parameters
  fun (Callable) – The function whose output shape should be evaluated.
  *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.
  **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.

For example:

    >>> import jax
    >>> import jax.numpy as jnp
    >>>
    >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
    >>> class MyArgArray(object):
    ...   def __init__(self, shape, dtype):
    ...     self.shape = shape
    ...     self.dtype = jnp.dtype(dtype)
    ...
    >>> A = MyArgArray((2000, 3000), jnp.float32)
    >>> x = MyArgArray((3000, 1000), jnp.float32)
    >>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
    >>> print(out.shape)
    (2000, 1000)
    >>> print(out.dtype)
    float32
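Instead of a hand-written duck-typed class, jax.ShapeDtypeStruct can serve as the shape/dtype placeholder. The sketch below reuses the same f and also shows eval_shape surfacing a shape error without running the computation; the mismatched shape (4000, 1000) is chosen purely for illustration, and the exception class is caught as either TypeError or ValueError as a hedge across JAX versions.

```python
import jax
import jax.numpy as jnp


def f(A, x):
    return jnp.tanh(jnp.dot(A, x))


# jax.ShapeDtypeStruct carries only a shape and a dtype: no data is allocated.
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
out = jax.eval_shape(f, A, x)  # no FLOPs performed
print(out.shape, out.dtype)    # (2000, 1000) float32

# Mismatched contracting dimensions (3000 vs 4000) are reported at trace time.
bad_x = jax.ShapeDtypeStruct((4000, 1000), jnp.float32)
try:
    jax.eval_shape(f, A, bad_x)
    caught = False
except (TypeError, ValueError):
    caught = True
print(caught)  # True
```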
jax.device_put(x, device=None)
Transfers x to device.

- Parameters
  x – An array, scalar, or (nested) standard Python container thereof.
  device (Optional[Device]) – The (optional) Device to which x should be transferred. If given, then the result is committed to the device.

If the device parameter is None, then this operation behaves like the identity function if the operand is on any device already; otherwise it transfers the data to the default device, uncommitted.

For more details on data placement see the FAQ on data placement.

- Returns
  A copy of x that resides on device.
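A minimal sketch of the two placement modes described above, assuming at least one device is visible through jax.devices(): without a device the data goes to the default device uncommitted, and with an explicit device the result is committed there. Because x may be a nested container, a dict is also transferred as a whole.

```python
import numpy as np
import jax

x = np.arange(4, dtype=np.float32)  # host-side NumPy array

# device=None: transferred to the default device, uncommitted.
y = jax.device_put(x)

# Explicit device: the result is committed to that device.
dev = jax.devices()[0]
z = jax.device_put(x, dev)

# Pytrees are accepted too; the container structure is preserved.
tree = jax.device_put({"w": x, "b": np.float32(1.0)})

print(np.asarray(y), np.asarray(z), sorted(tree))
```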
jax.device_put_replicated(x, devices)
Transfer array(s) to each specified device and form ShardedDeviceArray(s).

- Returns
  A ShardedDeviceArray or (nested) Python container thereof representing the value of x broadcasted along a new leading axis of size len(devices), with each slice along that new leading axis backed by memory on the device specified by the corresponding entry in devices.

Examples

Passing an array:

    >>> import jax
    >>> import numpy as np
    >>> devices = jax.local_devices()
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> y = jax.device_put_replicated(x, devices)
    >>> np.allclose(y, jax.numpy.stack([x for _ in devices]))
    True

See also: device_put, device_put_sharded
jax.device_put_sharded(shards, devices)
Transfer array shards to specified devices and form ShardedDeviceArray(s).

- Parameters
  shards (Sequence[Any]) – A sequence of arrays, scalars, or (nested) standard Python containers thereof representing the shards to be stacked together to form the output. The length of shards must equal the length of devices.
  devices (Sequence[Device]) – A sequence of Device instances representing the devices to which corresponding shards in shards will be transferred.
- Returns
  A ShardedDeviceArray or (nested) Python container thereof representing the elements of shards stacked together, with each shard backed by physical device memory specified by the corresponding entry in devices.

Examples

Passing a list of arrays for shards results in a sharded array containing a stacked version of the inputs:

    >>> import jax
    >>> import numpy as np
    >>> devices = jax.local_devices()
    >>> x = [jax.numpy.ones(5) for device in devices]
    >>> y = jax.device_put_sharded(x, devices)
    >>> np.allclose(y, jax.numpy.stack(x))
    True

Passing a list of nested container objects with arrays at the leaves for shards corresponds to stacking the shards at each leaf. This requires all entries in the list to have the same tree structure:

    >>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))]
    >>> y = jax.device_put_sharded(x, devices)
    >>> type(y)
    <class 'tuple'>
    >>> y0 = jax.device_put_sharded([a for a, b in x], devices)
    >>> y1 = jax.device_put_sharded([b for a, b in x], devices)
    >>> np.allclose(y[0], y0)
    True
    >>> np.allclose(y[1], y1)
    True

See also: device_put, device_put_replicated
jax.device_get(x)
Transfer x to host.

- Parameters
  x (Any) – An array, scalar, DeviceArray or (nested) standard Python container thereof representing the array to be transferred to host.
- Returns
  An array or (nested) Python container thereof representing the value of x.

Examples

Passing a DeviceArray:

    >>> import jax
    >>> x = jax.numpy.array([1., 2., 3.])
    >>> jax.device_get(x)
    array([1., 2., 3.], dtype=float32)

Passing a scalar (has no effect):

    >>> jax.device_get(1)
    1

See also: device_put, device_put_sharded, device_put_replicated
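Since x may be a nested container, device_get can fetch a whole pytree of results at once. In the sketch below, the dictionary state and its keys are hypothetical; array leaves should come back as NumPy arrays, while a plain Python int leaf is returned unchanged, consistent with the scalar example above.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical training state mixing device arrays and a plain Python int.
state = {"params": jnp.ones((2, 2)), "step": jnp.array(3), "epoch": 7}

# device_get maps over the pytree: array leaves become NumPy arrays on the
# host, and non-array leaves are passed through as-is.
host_state = jax.device_get(state)
print(type(host_state["params"]), host_state["epoch"])
```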
jax.named_call(fun, *, name=None)
Adds a user specified name to a function when staging out JAX computations.

When staging out computations for just-in-time compilation to XLA (or other backends such as TensorFlow), JAX runs your Python program but by default does not preserve any of the function names or other metadata associated with it. This can make debugging the staged-out (and/or compiled) representation of your program complicated because there is limited context information for each operation being executed.

named_call tells JAX to stage the given function out as a subcomputation with a specific name. When the staged-out program is compiled with XLA, these named subcomputations are preserved and show up in debugging utilities like the TensorFlow Profiler in TensorBoard. Names are also preserved when staging out JAX programs to TensorFlow using experimental.jax2tf.convert().

- Returns
  A version of fun that is wrapped in a name_scope.
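A minimal usage sketch: scaled_add and the name "scaled_add_block" are hypothetical. Wrapping only attaches a name to the staged-out subcomputation, so numerical results should be unchanged. Where the name surfaces (profiler traces, HLO metadata) depends on the JAX version, so the sketch checks only the computed value.

```python
import jax
import jax.numpy as jnp


def scaled_add(x, y):
    return 2.0 * x + y


# Hypothetical name; it labels the staged-out subcomputation for debugging.
named = jax.named_call(scaled_add, name="scaled_add_block")


@jax.jit
def model(x, y):
    return jnp.sum(named(x, y))


result = model(jnp.ones(3), jnp.arange(3.0))  # sum([2., 3., 4.])
print(result)  # 9.0
```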
jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
Creates a function that evaluates the gradient of fun.

- Parameters
  fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  reduce_axes (Sequence[Any]) – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while grad(f) will create one that computes the per-example gradient.
- Returns
  A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.

For example:

    >>> import jax
    >>>
    >>> grad_tanh = jax.grad(jax.numpy.tanh)
    >>> print(grad_tanh(0.2))
    0.961043
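The argnums and has_aux options combine as described in Returns: with a tuple of argnums the gradient is a tuple, and with has_aux=True the result is a (gradient, auxiliary_data) pair. The squared-error loss below is a hypothetical example; for w=2, b=0.5 and x=[1, 2] the residuals are [1.5, 3.5], giving dL/dw = 2(1.5·1 + 3.5·2) = 17 and dL/db = 2(1.5 + 3.5) = 10.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x):
    pred = w * x + b
    err = pred - 1.0
    # (scalar to differentiate, auxiliary data) because has_aux=True below.
    return jnp.sum(err ** 2), pred


x = jnp.array([1.0, 2.0])
# Differentiate w.r.t. w (position 0) and b (position 1), not x.
(grad_w, grad_b), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(2.0, 0.5, x)
print(grad_w, grad_b, pred)  # 17.0 10.0 [2.5 4.5]
```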
jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())
Create a function that evaluates both fun and the gradient of fun.

- Parameters
  fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  reduce_axes (Sequence[Any]) – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch' is a named batch axis, value_and_grad(f, reduce_axes=('batch',)) will create a function that computes the total gradient while value_and_grad(f) will create one that computes the per-example gradient.
- Returns
  A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
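A short sketch of both calling conventions, using a hypothetical sum-of-squares function. Without auxiliary data the result is (value, gradient). With has_aux=True the value slot itself becomes a pair, so the result is nested as ((value, aux), gradient).

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


def f_aux(x):
    # Second element is auxiliary data and is not differentiated.
    return jnp.sum(x ** 2), x.mean()


x = jnp.array([1.0, 2.0, 3.0])

value, gradient = jax.value_and_grad(f)(x)
print(value, gradient)  # 14.0 [2. 4. 6.]

(value2, aux), gradient2 = jax.value_and_grad(f_aux, has_aux=True)(x)
print(value2, aux, gradient2)  # 14.0 2.0 [2. 4. 6.]
```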
jax.jacfwd(fun, argnums=0, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.

- Parameters
  fun (Callable) – Function whose Jacobian is to be computed.
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Returns
  A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation.

    >>> import jax
    >>> import jax.numpy as jnp
    >>>
    >>> def f(x):
    ...   return jnp.asarray(
    ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
    ...
    >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
    [[ 1.       0.       0.     ]
     [ 0.       0.       5.     ]
     [ 0.      16.      -2.     ]
     [ 1.6209   0.       0.84147]]
jax.jacrev(fun, argnums=0, holomorphic=False, allow_int=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.

- Parameters
  fun (Callable) – Function whose Jacobian is to be computed.
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
  allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
- Returns
  A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation.

    >>> import jax
    >>> import jax.numpy as jnp
    >>>
    >>> def f(x):
    ...   return jnp.asarray(
    ...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
    ...
    >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
    [[ 1.       0.       0.     ]
     [ 0.       0.       5.     ]
     [ 0.      16.      -2.     ]
     [ 1.6209   0.       0.84147]]
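jacfwd() and jacrev() compute the same matrix by different routes: one forward-mode pass per input (a column each) versus one reverse-mode pass per output (a row each). As a rule of thumb, jacfwd tends to be cheaper for "tall" Jacobians (more outputs than inputs) and jacrev for "wide" ones. The sketch below reuses the f from the examples above and simply checks that the two agree.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.asarray([x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])


x = jnp.array([1.0, 2.0, 3.0])
J_fwd = jax.jacfwd(f)(x)  # assembled from one JVP per input (column)
J_rev = jax.jacrev(f)(x)  # assembled from one VJP per output (row)

# 4 outputs x 3 inputs: a "tall" Jacobian, where jacfwd is usually preferred.
print(J_fwd.shape, bool(jnp.allclose(J_fwd, J_rev)))  # (4, 3) True
```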
jax.hessian(fun, argnums=0, holomorphic=False)
Hessian of fun as a dense array.

- Parameters
  fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.
  argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
- Returns
  A function with the same arguments as fun, that evaluates the Hessian of fun.

    >>> import jax
    >>>
    >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
    >>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
    [[   6.   -2.]
     [  -2. -480.]]

hessian() is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

    >>> import jax.numpy as jnp
    >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
    >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
    {'c': {'a': {'a': DeviceArray([[[ 2., 0.], [ 0., 0.]],
                                   [[ 0., 0.], [ 0., 12.]]], dtype=float32),
                 'b': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
                                   [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)},
           'b': {'a': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]],
                                   [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32),
                 'b': DeviceArray([[[0. , 0. ], [0. , 0. ]],
                                   [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}

Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree().
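To get one dense Hessian matrix instead of nested blocks, the pytree input can be flattened with jax.flatten_util.ravel_pytree, which returns the flat vector plus a function that rebuilds the original structure. The sketch below uses a hypothetical scalar variant of the dict-based f above (it sums the outputs). It also checks the result against jacfwd(jacrev(...)), the composition jax.hessian is built from.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"a": jnp.array([1.0, 2.0]), "b": jnp.array([2.0, 3.0])}

# Flatten the dict into one 1D vector (leaves concatenated in pytree order,
# here a0, a1, b0, b1) and keep the function that rebuilds the dict.
flat, unravel = ravel_pytree(params)


def f_flat(v):
    # Hypothetical scalar objective: sum of a ** b over both entries.
    p = unravel(v)
    return jnp.sum(jnp.power(p["a"], p["b"]))


H = jax.hessian(f_flat)(flat)  # a single dense (4, 4) matrix
print(H.shape)  # (4, 4)
```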
jax.jvp(fun, primals, tangents)
Computes a (forward-mode) Jacobian-vector product of fun.

- Parameters
  fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
  primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
  tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
- Returns
  A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.

For example:

    >>> import jax
    >>>
    >>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
    >>> print(y)
    0.09983342
    >>> print(v)
    0.19900084
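One way to read the Jacobian-vector product: with a unit tangent vector along one input coordinate, tangents_out is the matching column of the Jacobian. The sketch below uses a hypothetical two-input, two-output f, where the column for x[0] is [x[1], 0] = [0.5, 0.0] at the chosen point.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[1])])


x = jnp.array([2.0, 0.5])
e0 = jnp.array([1.0, 0.0])  # unit tangent along the first input coordinate

y, col0 = jax.jvp(f, (x,), (e0,))
# col0 is d f / d x[0] = [x[1], 0.] = [0.5, 0.]
print(y, col0)
```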
-
jax.linearize(fun, *primals)
Produces a linear approximation to fun using jvp() and partial evaluation.
- Parameters
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of fun.
- Returns
A pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work.
In terms of values computed, linearize() behaves much like a curried jvp(), where these two code blocks compute the same values:
    y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

    y, f_jvp = jax.linearize(f, x)
    out_tangent = f_jvp(in_tangent)
However, the difference is that linearize() uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize() has a similar signature to vjp()!)
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover, if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap(), as in:
    pushfwd = partial(jvp, f, (x,))
    y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
By using vmap() and jvp() together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize() and vjp().
Here’s a more complete example of using linearize():
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
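The two pushforward strategies described above can be compared in a small runnable sketch. The test function f is illustrative. Pushing forward every standard basis vector, either through a single linearize() call or through vmap over a partially applied jvp(), should reproduce the transposed Jacobian:

```python
from functools import partial

import jax
import jax.numpy as jnp

def f(x):
    # Illustrative map from R^3 to R^2.
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])

x = jnp.array([0.5, 1.0, 2.0])
in_tangents = jnp.eye(3)  # one input tangent vector per row

# Strategy 1: linearize once, then reuse f_jvp for every tangent.
y, f_jvp = jax.linearize(f, x)
out_linearize = jnp.stack([f_jvp(t) for t in in_tangents])

# Strategy 2: vmap a partially applied jvp over the stacked tangents.
pushfwd = partial(jax.jvp, f, (x,))
y2, out_vmap = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

# Pushing forward the identity basis gives the Jacobian, one column per row.
J = jax.jacfwd(f)(x)
print(jnp.allclose(out_linearize, J.T), jnp.allclose(out_vmap, J.T))  # True True
```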
-
jax.linear_transpose(fun, *primals, reduce_axes=())
Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
- Parameters
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if 'batch' is a named batch axis, linear_transpose(f, *args, reduce_axes=('batch',)) will create a transpose function that sums over the batch while linear_transpose(f, *args) will create a per-example transpose.
- Returns
A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
>>> import jax
>>> import types
>>> import numpy as np
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32))
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))
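For a matrix-vector product x ↦ A @ x, the transpose should be ct ↦ A.T @ ct. The sketch below (the matrix A and the function matvec are illustrative) checks this and compares the result with vjp, which additionally evaluates the forward pass:

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def matvec(x):
    # Linear map from R^3 to R^2.
    return A @ x

# Only the shape and dtype of the example primal are used; zeros suffice.
x_example = jnp.zeros(3, dtype=jnp.float32)
matvec_T = jax.linear_transpose(matvec, x_example)

ct = jnp.array([1.0, -1.0])
(x_bar,) = matvec_T(ct)  # the output is a tuple matching the primals
print(x_bar)  # equal to A.T @ ct, i.e. all entries -3

# vjp computes the same cotangent but also runs the forward pass.
_, matvec_vjp = jax.vjp(matvec, x_example)
print(jnp.allclose(x_bar, matvec_vjp(ct)[0]))  # True
```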
-
jax.vjp(fun: Callable[[…], T], *primals: Any, has_aux: Literal[False] = False, reduce_axes: Sequence[Any] = ()) → Tuple[T, Callable]
jax.vjp(fun: Callable[[…], Tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[Any] = ()) → Tuple[T, Callable, U]
jax.vjp(fun: Callable[[…], T], *primals: Any) → Tuple[T, Callable]
jax.vjp(fun: Callable[[…], Any], *primals: Any, has_aux: bool, reduce_axes: Sequence[Any] = ()) → Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]]
Compute a (reverse-mode) vector-Jacobian product of fun. grad() is implemented as a special case of vjp().
- Parameters
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be an array, a scalar, or a (nested) standard Python container thereof.
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
reduce_axes – Optional, tuple of axis names. If an axis is listed here, and fun implicitly broadcasts a value over that axis, the backward pass will perform a psum of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch' is a named batch axis, vjp(f, *args, reduce_axes=('batch',)) will create a VJP function that sums over the batch while vjp(f, *args) will create a per-example VJP.
- Returns
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.
>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
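A sketch of has_aux=True, with an illustrative function loss_with_aux that returns a scalar loss plus a dict of intermediate values. The auxiliary output is passed through undifferentiated, and seeding the VJP with a cotangent of 1.0 on the scalar output recovers the ordinary gradient:

```python
import jax
import jax.numpy as jnp

def loss_with_aux(w, x):
    pred = jnp.tanh(w * x)
    loss = jnp.mean(pred ** 2)
    # The second element is auxiliary data; it is returned but not differentiated.
    return loss, {"pred": pred}

w = jnp.array(0.3)
x = jnp.array([1.0, 2.0, 3.0])

loss, f_vjp, aux = jax.vjp(loss_with_aux, w, x, has_aux=True)
w_bar, x_bar = f_vjp(jnp.ones_like(loss))  # one cotangent per primal input

# For a scalar output, a cotangent of 1.0 yields the gradient.
g_w = jax.grad(lambda w: loss_with_aux(w, x)[0])(w)
print(jnp.allclose(w_bar, g_w), aux["pred"].shape)  # True (3,)
```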
-
class jax.custom_jvp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp() or jax.grad()) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation.
There are two instance methods available for defining the custom JVP rule: defjvp() for defining a single custom JVP rule for all the function’s inputs, and for convenience defjvps(), which wraps defjvp(), and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.
For example:
    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(x, y):
      return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
      x, y = primals
      x_dot, y_dot = tangents
      primal_out = f(x, y)
      tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
      return primal_out, tangent_out
For a more detailed introduction, see the tutorial.
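A common reason to supply a custom JVP is numerical stability. The sketch below uses log(1 + exp(x)), a standard illustration: plain autodiff produces exp(x) / (1 + exp(x)), which becomes inf / inf = nan for large x in float32, while the hand-written rule 1 - 1 / (1 + exp(x)) stays finite. The name log1pexp is illustrative:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) rewritten so that it stays finite when e^x overflows.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

print(jax.grad(log1pexp)(100.0))  # 1.0 instead of nan
print(jax.grad(log1pexp)(0.0))    # 0.5
```

Because the rule is linear in x_dot, jax.grad can transpose it, so the same definition serves both forward and reverse mode.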
-
defjvp(jvp)
Define a custom JVP rule for the function represented by this instance.
- Parameters
jvp (Callable[…, Tuple[~ReturnValue, ~ReturnValue]]) – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of the custom_jvp function. The jvp function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.
- Returns
None.
Example:
    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(x, y):
      return jnp.sin(x) * y

    @f.defjvp
    def f_jvp(primals, tangents):
      x, y = primals
      x_dot, y_dot = tangents
      primal_out = f(x, y)
      tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
      return primal_out, tangent_out
-
defjvps(*jvps)
Convenience wrapper for defining JVPs for each argument separately.
This convenience wrapper cannot be used together with nondiff_argnums.
- Parameters
*jvps – a sequence of functions, one for each positional argument of the custom_jvp function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below.
- Returns
None.
Example:
    import jax
    import jax.numpy as jnp

    @jax.custom_jvp
    def f(x, y):
      return jnp.sin(x) * y

    f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
              lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
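To check that the per-argument rules combine correctly, the sketch below differentiates the same f with respect to both arguments and compares against the analytic partial derivatives cos(x) * y and sin(x). The evaluation point is arbitrary:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

# One rule per positional argument, each called as (tangent, primal_out, *primals).
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

x, y = 0.7, 2.0
dfdx, dfdy = jax.grad(f, argnums=(0, 1))(x, y)
print(jnp.allclose(dfdx, jnp.cos(x) * y), jnp.allclose(dfdy, jnp.sin(x)))  # True True
```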
-
class jax.custom_vjp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad()) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp(), which may be used to define the custom VJP rule.
This decorator precludes the use of forward-mode automatic differentiation.
For example:
    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x, y):
      return jnp.sin(x) * y

    def f_fwd(x, y):
      return f(x, y), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(res, g):
      cos_x, sin_x, y = res
      return (cos_x * g * y, sin_x * g)

    f.defvjp(f_fwd, f_bwd)
For a more detailed introduction, see the tutorial.
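One use of a custom VJP that forward-mode rules cannot express is modifying cotangents on the backward pass. The sketch below, with the illustrative name clip_gradient, is the identity on the forward pass but clips the incoming cotangent to [lo, hi] on the backward pass. It assumes a JAX version in which bwd may return None to signal a zero cotangent for an argument:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x, lo, hi):
    # Identity on the forward pass.
    return x

def clip_gradient_fwd(x, lo, hi):
    # Store the bounds as residuals for the backward pass.
    return x, (lo, hi)

def clip_gradient_bwd(res, g):
    lo, hi = res
    # One entry per primal argument; None marks a zero cotangent for the bounds.
    return (jnp.clip(g, lo, hi), None, None)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def loss(x):
    return 10.0 * clip_gradient(x, -1.0, 1.0)

print(loss(2.0))            # 20.0: the forward value is unchanged
print(jax.grad(loss)(2.0))  # 1.0: the upstream cotangent 10.0 is clipped
```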
-
defvjp(fwd, bwd)
Define a custom VJP rule for the function represented by this instance.
- Parameters
fwd (Callable[…, Tuple[~ReturnValue, Any]]) – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.
bwd (Callable[…, Tuple[Any, …]]) – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments, where the first is the “residual” values produced on the forward pass by fwd, and the second is the output cotangent with the same structure as the primal function output. The output of bwd must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.
- Returns
None.
Example:
    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x, y):
      return jnp.sin(x) * y

    def f_fwd(x, y):
      return f(x, y), (jnp.cos(x), jnp.sin(x), y)

    def f_bwd(res, g):
      cos_x, sin_x, y = res
      return (cos_x * g * y, sin_x * g)

    f.defvjp(f_fwd, f_bwd)
-
jax.closure_convert(fun, *example_args)
Closure conversion utility, for use with higher-order custom derivatives.
To define custom derivatives such as with jax.custom_vjp(f), the target function f must take, as formal arguments, all values involved in differentiation. If f is a higher-order function, in that it accepts as an argument a Python function g, then values stored away in g’s closure will not be visible to the custom derivative rules, and attempts at AD involving these values will fail. One way around this is to convert the closure by extracting these values, and to pass them as explicit formal arguments across the custom derivative boundary. This utility carries out that conversion. More precisely, it closure-converts the function fun specialized to the types of the arguments given in example_args.
When we refer here to “values in the closure” of fun, we do not mean the values that are captured by Python directly when fun is defined (e.g. the Python objects in fun.__closure__, if the attribute exists). Rather, we mean values encountered during the execution of fun on example_args that determine its output. This may include, for instance, arrays captured transitively in Python closures, i.e. in the Python closure of functions called by fun, the closures of the functions that they call, and so forth.
The function fun must be a pure function.
Example usage:
    from functools import partial

    from jax import closure_convert, custom_vjp

    def minimize(objective_fn, x0):
      converted_fn, aux_args = closure_convert(objective_fn, x0)
      return _minimize(converted_fn, x0, *aux_args)

    @partial(custom_vjp, nondiff_argnums=(0,))
    def _minimize(objective_fn, x0, *args):
      z = objective_fn(x0, *args)
      # ... find minimizer x_opt ...
      return x_opt

    def fwd(objective_fn, x0, *args):
      y = _minimize(objective_fn, x0, *args)
      return y, (y, args)

    def rev(objective_fn, res, g):
      y, args = res
      y_bar = g
      # ... custom reverse-mode AD ...
      return x0_bar, *args_bars

    _minimize.defvjp(fwd, rev)
- Parameters
fun – Python callable to be converted. Must be a pure function.
example_args – Arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) thereof, used to determine the types of the formal arguments to fun. This type-specialized form of fun is the function that will be closure converted.
- Returns
A pair comprising (i) a Python callable, accepting the same arguments as fun followed by arguments corresponding to the values hoisted from its closure, and (ii) a list of values hoisted from the closure.
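The calling convention of the returned pair can be sketched as follows. The objective and the captured array scale are illustrative. Which values actually end up in the hoisted list depends on JAX internals and on whether a transformation is active, so the sketch relies only on the documented convention: the converted function takes the original arguments followed by the hoisted values and computes the same result:

```python
import jax
import jax.numpy as jnp

scale = jnp.array([1.0, 2.0, 3.0])

def objective(x):
    # scale is captured from the enclosing scope, not passed as an argument.
    return jnp.sum(scale * x ** 2)

x0 = jnp.ones(3)
converted_fn, hoisted = jax.closure_convert(objective, x0)

# Original arguments first, then the hoisted closure values.
print(jnp.allclose(converted_fn(x0, *hoisted), objective(x0)))  # True
```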
-
jax.checkpoint(fun, concrete=False, prevent_cse=True, policy=None)
Make fun recompute internal linearization points when differentiated.
The jax.checkpoint() decorator, aliased to jax.remat, provides a way to trade off computation time and memory cost in the context of automatic differentiation, especially with reverse-mode autodiff like jax.grad() and jax.vjp() but also with jax.linearize().
When differentiating a function in reverse-mode, by default all the linearization points (e.g. inputs to elementwise nonlinear primitive operations) are stored when evaluating the forward pass so that they can be reused on the backward pass. This evaluation strategy can lead to a high memory cost, or even to poor performance on hardware accelerators where memory access is much more expensive than FLOPs.
An alternative evaluation strategy is for some of the linearization points to be recomputed (i.e. rematerialized) rather than stored. This approach can reduce memory usage at the cost of increased computation.
This function decorator produces a new version of fun which follows the rematerialization strategy rather than the default store-everything strategy. That is, it returns a new version of fun which, when differentiated, doesn’t store any of its intermediate linearization points. Instead, these linearization points are recomputed from the function’s saved inputs.
See the examples below.
- Parameters
fun (Callable) – Function for which the autodiff evaluation strategy is to be changed from the default of storing all intermediate linearization points to recomputing them. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
concrete (bool) – Optional, boolean indicating whether fun may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.
prevent_cse (bool) – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under a jit or pmap, CSE can defeat the purpose of this decorator. But in some settings, like when used inside a scan, this CSE prevention mechanism is unnecessary, in which case prevent_cse can be set to False.
policy (Optional[Callable[…, bool]]) – This is an experimental feature and the API is likely to change. Optional callable, one of the attributes of jax.checkpoint_policies, which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) can be saved as a residual (or, if not, instead must be recomputed in the (co)tangent computation).
- Returns
A function (callable) with the same input/output behavior as fun but which, when differentiated using e.g. jax.grad(), jax.vjp(), or jax.linearize(), recomputes rather than stores intermediate linearization points, thus potentially saving memory at the cost of extra computation.
Here is a simple example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.checkpoint
... def g(x):
...   y = jnp.sin(x)
...   z = jnp.sin(y)
...   return z
...
>>> jax.grad(g)(2.0)
DeviceArray(-0.25563914, dtype=float32)
Here, the same value is produced whether or not the jax.checkpoint() decorator is present. But when using jax.checkpoint(), the value jnp.sin(2.0) is computed twice: once on the forward pass, and once on the backward pass. The values jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)) are also computed twice. Without using the decorator, both jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)) would be stored and reused.
The jax.checkpoint() decorator can be applied recursively to express sophisticated autodiff rematerialization strategies. For example:
>>> def recursive_checkpoint(funs):
...   if len(funs) == 1:
...     return funs[0]
...   elif len(funs) == 2:
...     f1, f2 = funs
...     return lambda x: f1(f2(x))
...   else:
...     f1 = recursive_checkpoint(funs[:len(funs)//2])
...     f2 = recursive_checkpoint(funs[len(funs)//2:])
...     return lambda x: f1(jax.checkpoint(f2)(x))
...
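A usage sketch for the recursive_checkpoint helper: compose eight copies of jnp.sin and confirm that the rematerialized composition has the same gradient as a plain loop. The helper applies later entries of funs first; since all entries are identical here, the order does not matter:

```python
import jax
import jax.numpy as jnp

def recursive_checkpoint(funs):
    # Same helper as in the doctest: compose funs, checkpointing the
    # second half at every level of the recursion.
    if len(funs) == 1:
        return funs[0]
    elif len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    else:
        f1 = recursive_checkpoint(funs[:len(funs) // 2])
        f2 = recursive_checkpoint(funs[len(funs) // 2:])
        return lambda x: f1(jax.checkpoint(f2)(x))

funs = [jnp.sin] * 8
composed = recursive_checkpoint(funs)

def plain(x):
    # Reference composition without any rematerialization.
    for f in funs:
        x = f(x)
    return x

# Checkpointing changes the memory/compute trade-off, not the result.
print(jnp.allclose(jax.grad(composed)(1.0), jax.grad(plain)(1.0)))  # True
```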
-
jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None)
Vectorizing map. Creates a function which maps fun over argument axes.
- Parameters
fun (~F) – Function to be mapped over additional axes.
in_axes –
An integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over.
If each positional argument to fun is an array, then in_axes can be an integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions (axes) of the corresponding input array.
If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.
At least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.
Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).
See below for examples.
out_axes – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification. Axis integers must be in the range [-ndim, ndim) for each output array, where ndim is the number of dimensions (axes) of the array returned by the vmap()-ed function, which is one more than the number of dimensions (axes) of the corresponding array returned by fun.
- Return type
~F
- Returns
Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.
For example, we can implement a matrix-matrix product using a vector dot product:
>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      # ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      # ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:
>>> mv1 = vmap(vv, (0, 0), 0)   # ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   # ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  # ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
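As a quick sanity check of the shape annotations, the sketch below rebuilds mm from vv and mv and compares it against an ordinary matrix product on concrete inputs (the arrays are arbitrary):

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)   # ([a], [a]) -> []
mv = vmap(vv, (0, None), 0)        # ([b,a], [a]) -> [b]
mm = vmap(mv, (None, 1), 1)        # ([b,a], [a,c]) -> [b,c]

x = jnp.arange(6.0).reshape(2, 3)   # shape [b,a] = [2,3]
y = jnp.arange(12.0).reshape(3, 4)  # shape [a,c] = [3,4]

out = mm(x, y)
print(out.shape)                 # (2, 4)
print(jnp.allclose(out, x @ y))  # True
```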
Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:
>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
Here’s another example using container types in in_axes, this time a dictionary, to specify the elements of the container to map over:
>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...   return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]
The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only unmapped results can have out_axes set to None (to keep them unmapped):
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), 8.0)
If out_axes is specified for an unmapped result, the result is broadcast across the mapped axis:
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32, weak_type=True))
If out_axes is specified for a mapped result, the result is transposed accordingly.
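A small sketch of that last point, using an illustrative function double: mapping over axis 0 of a (3, 2) input and placing the mapped axis at position 1 of the output gives the transpose of the out_axes=0 result:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(3, 2)  # mapped axis 0 has size 3

# For each of the 3 mapped rows, double returns a length-2 vector.
double = lambda row: row * 2.0

out0 = jax.vmap(double, in_axes=0, out_axes=0)(x)
out1 = jax.vmap(double, in_axes=0, out_axes=1)(x)

print(out0.shape, out1.shape)      # (3, 2) (2, 3)
print(jnp.allclose(out1, out0.T))  # True
```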
-
jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)
Define a vectorized function with broadcasting.
vectorize() is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher dimensional inputs.
jax.numpy.vectorize() has the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap()) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays.
- Parameters
pyfunc – function to vectorize.
excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.
signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input and output.
- Returns
Vectorized version of the given function.
Here are a few examples of how one could write vectorized linear algebra routines using vectorize():
>>> import jax.numpy as jnp
>>> from functools import partial
>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])
>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary dimensional inputs with NumPy style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'm') on vectorized function with excluded=frozenset() and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
Note that this has different semantics than jnp.matmul:
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
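The excluded parameter is not used in the examples above. The sketch below combines it with a signature. normalize is a hypothetical helper that scales each length-n vector to unit L2 norm, and its second positional argument eps is excluded, so it is passed through unchanged instead of being broadcast:

```python
from functools import partial

import jax.numpy as jnp

# Hypothetical helper: normalize the last axis; argument 1 (eps) is not vectorized.
@partial(jnp.vectorize, signature='(n)->(n)', excluded=frozenset({1}))
def normalize(v, eps):
    return v / (jnp.linalg.norm(v) + eps)

batch = jnp.arange(1.0, 13.0).reshape(2, 2, 3)  # leading dims (2, 2), core dim n = 3
out = normalize(batch, 1e-6)

print(out.shape)                                                   # (2, 2, 3)
print(jnp.allclose(jnp.linalg.norm(out, axis=-1), 1.0, atol=1e-5))  # True
```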
-
jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)
Parallel map with support for collective operations.
The purpose of pmap() is to express single-program multiple-data (SPMD) programs. Applying pmap() to a function will compile the function with XLA (similarly to jit()), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap() because both transformations map a function over array axes, but where vmap() vectorizes functions by pushing the mapped axis down into primitive operations, pmap() instead replicates the function and executes each replica on its own XLA device in parallel.
The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap() calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.
Multi-process platforms: On multi-process platforms such as TPU pods, pmap() is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pmapped function in the same order. Each process should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other processes, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across processes, where each process “sees” only its local shard of the input and output. The SPMD model requires that the same multi-process pmaps must be run in the same order on all devices, but they can be interspersed with arbitrary operations running in a single process.
- Parameters
fun (~F) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_broadcasted_argnums can be anything at all, provided they are hashable and have an equality operation defined.
axis_name (Optional[Any]) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
in_axes – A non-negative integer, None, or nested Python container thereof that specifies which axes of positional arguments to map over. Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0). See vmap() for details.
out_axes – A non-negative integer, None, or nested Python container thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification (see vmap()).
static_broadcasted_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums then an error is raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().
devices (Optional[Sequence[Device]]) – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). Must be given identically for each process in multi-process settings (and will therefore include devices across processes). If specified, the size of the mapped axis must be equal to the number of devices in the sequence local to the given process. Nested pmap()s with devices specified in either the inner or outer pmap() are not yet supported.
backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
axis_size (Optional[int]) – Optional; the size of the mapped axis.
donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
global_arg_shapes (Optional[Tuple[Tuple[int, …], …]]) – Optional, must be set when using pmap(sharded_jit) and the partitioned values span multiple processes. The global cross-process per-replica shape of each argument, i.e. does not include the leading pmapped dimension. Can be None for replicated arguments. This API is likely to change in the future.
- Return type
~F
- Returns
A parallelized version of fun with arguments that correspond to those of fun but with extra array axes at positions indicated by in_axes and with output that has an additional leading array axis (with the same size).
For example, assuming 8 XLA devices are available,
pmap()can be used as a map along a leading array axis:>>> import jax.numpy as jnp >>> >>> out = pmap(lambda x: x ** 2)(jnp.arange(8)) >>> print(out) [0, 1, 4, 9, 16, 25, 36, 49]
When the leading dimension is smaller than the number of available devices JAX will simply run on a subset of devices:
>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) >>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2 >>> out = pmap(jnp.dot)(x, y) >>> print(out) [[[ 4. 9.] [ 12. 29.]] [[ 244. 345.] [ 348. 493.]] [[ 1412. 1737.] [ 1740. 2141.]]]
If your leading dimension is larger than the number of available devices, you will get an error:

    >>> pmap(lambda x: x ** 2)(jnp.arange(9))
    ValueError: ... requires 9 replicas, but only 8 XLA devices are available
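One common way to avoid that error is to fold a large batch into a leading axis that matches the device count. The helper shard_leading below is hypothetical, not part of JAX, and assumes the batch size divides evenly by the number of local devices.

```python
import jax
import jax.numpy as jnp


def shard_leading(x, n_devices):
    """Hypothetical helper: reshape (B, ...) into (n_devices, B // n_devices, ...)."""
    if x.shape[0] % n_devices != 0:
        raise ValueError(f"batch of {x.shape[0]} is not divisible by {n_devices} devices")
    return x.reshape((n_devices, x.shape[0] // n_devices) + x.shape[1:])


n = jax.local_device_count()
batch = jnp.arange(8.0 * n)              # 8 elements per device
sharded = shard_leading(batch, n)        # shape (n, 8): leading axis == device count
out = jax.pmap(lambda x: x ** 2)(sharded)
flat = out.reshape(-1)                   # back to the original length, 8 * n
```

Folding the batch this way keeps the pmapped axis equal to the device count no matter how large the batch is, at the cost of requiring an evenly divisible batch size.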
As with vmap(), using None in in_axes indicates that an argument doesn’t have an extra axis and should be broadcast, rather than mapped, across the replicas:

    >>> x, y = jnp.arange(2.), 4.
    >>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)
    >>> print(out)
    ([4., 5.], [8., 8.])
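To make the broadcast semantics concrete: mapping with in_axes=(0, None) should give the same result as explicitly tiling the unmapped argument along a new leading axis. A small check, sized to however many local devices are present:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(float(n))
y = 4.0


def f(a, b):
    return a + b, b * 2.0


# `y` is broadcast to every replica: no leading axis needed.
broadcast = jax.pmap(f, in_axes=(0, None))(x, y)
# Equivalent: give `y` an explicit leading axis of size n and map over it.
tiled = jax.pmap(f)(x, jnp.full((n,), y))
```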
Note that pmap() always returns values mapped over their leading axis, equivalent to using out_axes=0 in vmap().

In addition to expressing pure maps, pmap() can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

    >>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
    >>> out = pmap(f, axis_name='i')(jnp.arange(4.))
    >>> print(out)
    [ 0.          0.16666667  0.33333334  0.5       ]
    >>> print(out.sum())
    1.0
In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.

The axis_name argument to pmap() names the mapped axis so that collective operations, like jax.lax.psum(), can refer to it. Axis names are particularly important for nested pmap() functions, where collective operations can operate over distinct axes:

    >>> from functools import partial
    >>> import jax
    >>>
    >>> @partial(pmap, axis_name='rows')
    ... @partial(pmap, axis_name='cols')
    ... def normalize(x):
    ...     row_normed = x / jax.lax.psum(x, 'rows')
    ...     col_normed = x / jax.lax.psum(x, 'cols')
    ...     doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
    ...     return row_normed, col_normed, doubly_normed
    >>>
    >>> x = jnp.arange(8.).reshape((4, 2))
    >>> row_normed, col_normed, doubly_normed = normalize(x)
    >>> print(row_normed.sum(0))
    [ 1.  1.]
    >>> print(col_normed.sum(1))
    [ 1.  1.  1.  1.]
    >>> print(doubly_normed.sum((0, 1)))
    1.0
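Running the nested pmap() example requires 4 × 2 = 8 devices. Because vmap() also accepts an axis_name and supports collectives such as jax.lax.psum(), the same axis-name logic can be tried on a single device by swapping pmap for vmap. This is a sketch for experimentation, not a replacement for pmap: vmap vectorizes on one device rather than distributing the work.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, axis_name='rows')
@partial(jax.vmap, axis_name='cols')
def normalize(x):
    # Same body as the pmap example; only the outer transformations differ.
    row_normed = x / jax.lax.psum(x, 'rows')
    col_normed = x / jax.lax.psum(x, 'cols')
    doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
    return row_normed, col_normed, doubly_normed


x = jnp.arange(8.).reshape((4, 2))
row_normed, col_normed, doubly_normed = normalize(x)
print(row_normed.sum(0))           # each column sums to 1
print(col_normed.sum(1))           # each row sums to 1
print(doubly_normed.sum((0, 1)))   # all entries together sum to 1

# The flat SPMD normalization also agrees between vmap and pmap.
n = jax.local_device_count()
v = jnp.arange(1.0, n + 1.0)
g = lambda t: t / jax.lax.psum(t, axis_name='i')
same = jnp.allclose(jax.vmap(g, axis_name='i')(v), jax.pmap(g, axis_name='i')(v))
```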
On multi-process platforms, collective operations operate over all devices, including those on other processes. For example, assuming the following code runs on two processes with 4 XLA devices each:

    >>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
    >>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
    >>> out = pmap(f, axis_name='i')(data)
    >>> print(out)
    [28 29 30 31] # on process 0
    [32 33 34 35] # on process 1
Each process passes in a different length-4 array, corresponding to its 4 local devices, and jax.lax.psum() operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis named ‘i’. The pmap() call on each process then returns the corresponding length-4 output shard.
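The same pattern can be written without hard-coding the number of processes or devices. This is a sketch that assumes every process owns the same number of local devices; with two processes of four devices each it should reproduce the output above, and on a single process it reduces to an ordinary pmap.

```python
import jax
import jax.numpy as jnp

local_n = jax.local_device_count()
global_n = jax.device_count()

# This process's slice of the conceptual array jnp.arange(global_n).
# Assumes all processes have the same number of local devices.
start = jax.process_index() * local_n
data = jnp.arange(start, start + local_n)

f = lambda x: x + jax.lax.psum(x, axis_name='i')
out = jax.pmap(f, axis_name='i')(data)  # psum runs over all global_n devices
print(out)
```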
The devices argument can be used to specify exactly which devices run the parallel computation. For example, again assuming a single process with 8 devices, the following code defines two parallel computations, one that runs on the first six devices and one that runs on the remaining two:

    >>> from functools import partial
    >>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
    ... def f1(x):
    ...     return x / jax.lax.psum(x, axis_name='i')
    >>>
    >>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
    ... def f2(x):
    ...     return jax.lax.psum(x ** 2, axis_name='i')
    >>>
    >>> print(f1(jnp.arange(6.)))
    [0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
    >>> print(f2(jnp.array([2., 3.])))
    [ 13.  13.]
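The device slices above assume 8 devices. A sketch that adapts to whatever is available picks the first half of the local devices (at least one). The split itself is arbitrary; the hard requirement it illustrates is that the mapped axis must have exactly len(devices) entries.

```python
from functools import partial

import jax
import jax.numpy as jnp

local = jax.local_devices()
subset = local[: max(1, len(local) // 2)]  # first half, but at least one device


@partial(jax.pmap, axis_name='i', devices=subset)
def normalize(x):
    return x / jax.lax.psum(x, axis_name='i')


x = jnp.arange(1.0, len(subset) + 1.0)  # leading axis must equal len(subset)
out = normalize(x)
print(out, out.sum())
```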
-
jax.devices(backend=None)
Returns a list of all devices for a given backend.
Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
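A short sketch of listing devices and picking out the local ones by process index. In recent JAX releases, process_index, id, and platform are plain attributes of device objects (not methods); the filtering below should select the same devices that jax.local_devices() returns for the default backend.

```python
import jax

all_devs = jax.devices()   # default backend ('gpu'/'tpu' if available, else 'cpu')
me = jax.process_index()

# Keep only the devices owned by this process (attribute, not a method call).
mine = [d for d in all_devs if d.process_index == me]

for d in mine:
    print(d.id, d.platform)
```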
-
jax.local_devices(process_index=None, backend=None, host_id=None)
Like jax.devices(), but only returns devices local to a given process.

If process_index is None, returns devices local to this process.
- Parameters
- Return type
List[Device]
- Returns
List of Device subclasses.
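For example, the following calls should all describe this process's devices. Passing process_index explicitly is shown only to illustrate the parameter, and backend='cpu' is used because a CPU backend is normally always present alongside any accelerator backend.

```python
import jax

default_local = jax.local_devices()
explicit = jax.local_devices(process_index=jax.process_index())
cpu_local = jax.local_devices(backend='cpu')

print(len(default_local), jax.local_device_count())
print([d.platform for d in cpu_local])
```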
-
jax.process_index(backend=None)
Returns the integer process index of this process.
On most platforms this is always 0; on multi-process platforms, each process has a distinct index.
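A typical use is restricting side effects such as logging or checkpointing to a single process. The log_once helper is hypothetical, and treating process 0 as the coordinator is a convention rather than a requirement:

```python
import jax

idx = jax.process_index()


def log_once(msg: str) -> bool:
    """Hypothetical helper: print only on process 0; return whether it printed."""
    if idx == 0:
        print(f"[process {idx}] {msg}")
        return True
    return False


printed = log_once(f"{jax.process_count()} process(es), {jax.device_count()} device(s)")
```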
-
jax.device_count(backend=None)
Returns the total number of devices.
On most platforms, this is the same as jax.local_device_count(). However, on multi-process platforms where different devices are associated with different processes, this will return the total number of devices across all processes.
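One place the distinction matters is sizing a data-parallel batch: each process feeds pmap only its local devices, while the effective global batch spans all devices. A sketch with a hypothetical per-device batch size, assuming every process owns the same number of devices:

```python
import jax

per_device_batch = 32  # hypothetical per-device batch size

total = jax.device_count()        # devices across all processes
local = jax.local_device_count()  # devices this process feeds

local_batch = per_device_batch * local    # rows this process passes to pmap
global_batch = per_device_batch * total   # rows processed per step overall

# Holds under the equal-devices-per-process assumption stated above.
balanced = total == local * jax.process_count()
print(local_batch, global_batch, balanced)
```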