jax.numpy.ones
Return a new array of given shape and type, filled with ones.
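Because `jnp.ones` builds its result from JAX operations, it can be called inside a `jax.jit`-compiled function. The shape must be known when the function is traced. The sketch below assumes a standard JAX install, and the function name `add_one` is hypothetical. It uses the input's static `.shape` and `.dtype` to build a matching array of ones.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_one(x):
    # x.shape is a concrete Python tuple at trace time, so it is a valid
    # shape argument for jnp.ones inside the compiled function.
    ones = jnp.ones(x.shape, dtype=x.dtype)
    return x + ones


result = add_one(jnp.arange(3.0))
print(result)  # [1. 2. 3.]
```

Passing a traced value (for example, an element of `x`) as the shape would fail, because the output shape has to be fixed at compile time.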
LAX-backend implementation of numpy.ones().
Original docstring below.
Parameters
shape (int or sequence of ints) – Shape of the new array, e.g., (2, 3) or 2.
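The following minimal sketch shows how the two accepted forms of `shape` map to array dimensions. An int produces a 1-D array, and a tuple produces one dimension per entry. The empty-tuple case is included as an extra illustration.

```python
import jax.numpy as jnp

# An int gives a 1-D array of that length.
a = jnp.ones(2)
# A tuple gives one dimension per entry.
b = jnp.ones((2, 3))
# An empty tuple gives a 0-d (scalar) array.
c = jnp.ones(())

print(a.shape, b.shape, c.shape)  # (2,) (2, 3) ()
```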
dtype (data-type, optional) – The desired data-type for the array, e.g., numpy.int8. Default is numpy.float64 in NumPy. In JAX, the default is float32, or float64 when 64-bit mode (jax_enable_x64) is enabled.
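The sketch below illustrates the `dtype` argument, assuming JAX's default configuration. The array created without an explicit dtype uses JAX's default float type. Passing `numpy.int8`, as in the docstring's example, yields an integer array.

```python
import numpy as np
import jax.numpy as jnp

# No dtype given: JAX's default floating-point type is used.
default = jnp.ones(3)
# Explicit dtype, using the NumPy type named in the docstring.
ints = jnp.ones(3, dtype=np.int8)

print(default.dtype)  # float32, unless 64-bit mode is enabled
print(ints.dtype)     # int8
```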
Returns
out (ndarray) – Array of ones with the given shape, dtype, and order.
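The NumPy docstring describes the result as an `ndarray`. In JAX, `jnp.ones` returns a JAX array, which is immutable. The minimal sketch below shows the functional update pattern. In-place assignment such as `x[0, 0] = 5.0` raises an error on a JAX array, whereas `.at[...].set(...)` returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

x = jnp.ones((2, 2))

# JAX arrays cannot be modified in place; .at[...].set(...) builds a
# new array with the requested element replaced.
y = x.at[0, 0].set(5.0)

print(x[0, 0], y[0, 0])
```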