jax.nn.one_hot
jax.nn.one_hot(x, num_classes, *, dtype=<class 'jax._src.numpy.lax_numpy.float64'>, axis=-1)
One-hot encodes the given indices.
Each index in the input x is encoded as a vector of zeros of length num_classes, with the element at that index set to one:
>>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
DeviceArray([[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]], dtype=float32)
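The encoding rule above amounts to comparing each index against every class id from 0 to num_classes - 1. The sketch below expresses that with broadcasting; `one_hot_sketch` is a hypothetical helper written for illustration, not part of JAX, and it is only an equivalent formulation, not necessarily how JAX implements `one_hot` internally.

```python
import jax
import jax.numpy as jnp

def one_hot_sketch(x, num_classes, dtype=jnp.float32):
    # Hypothetical helper (not a JAX API): compare each index against
    # every class id 0..num_classes-1.
    x = jnp.asarray(x)
    class_ids = jnp.arange(num_classes)
    # x[..., None] has shape (*x.shape, 1); the comparison broadcasts to
    # (*x.shape, num_classes) and is True exactly where class_id == index.
    return (x[..., None] == class_ids).astype(dtype)

idx = jnp.array([0, 1, 2])
print(one_hot_sketch(idx, 3))
# The sketch agrees with the library function on these indices.
print(jnp.array_equal(one_hot_sketch(idx, 3), jax.nn.one_hot(idx, 3)))  # True
```

Because no class id can equal an index outside [0, num_classes), this formulation also yields all-zero rows for such indices.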
Indices outside the range [0, num_classes) are encoded as all-zero vectors:
>>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
DeviceArray([[0., 0., 0.],
             [0., 0., 0.]], dtype=float32)
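One practical consequence of the out-of-range behavior is that the row sums of the encoding tell valid indices apart from invalid ones. The sketch below assumes, purely for illustration, that -1 marks a padded or ignored label; that convention is not defined by `jax.nn.one_hot` itself.

```python
import jax
import jax.numpy as jnp

# Assumption for illustration: -1 marks a padded/ignored label,
# and 5 is simply out of range for num_classes=3.
labels = jnp.array([2, -1, 0, 5])
encoded = jax.nn.one_hot(labels, 3)

# Each row sums to 1.0 for an index in [0, 3) and 0.0 otherwise,
# so the row sums can serve as a validity mask.
valid = encoded.sum(axis=-1)
print(valid)  # [1. 0. 1. 0.]
```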
- Parameters
  x (Any) - A tensor of indices.
  num_classes (int) - Number of classes in the one-hot dimension.
  dtype (Any) - Optional; a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
  axis (Union[int, Hashable]) - The output axis at which the one-hot dimension of length num_classes is placed (default -1, the last axis).
- Return type
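The dtype and axis parameters can be exercised with a short sketch. It assumes 64-bit mode is left at its default (disabled) and uses float16 only as an example of a float dtype; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

idx = jnp.array([0, 2])

# Default axis=-1: the one-hot dimension is appended last -> shape (2, 3).
by_row = jax.nn.one_hot(idx, 3)

# axis=0: the one-hot dimension is placed first -> shape (3, 2),
# so column i is the encoding of idx[i].
by_col = jax.nn.one_hot(idx, 3, axis=0)

# dtype selects the float type of the returned values.
half = jax.nn.one_hot(idx, 3, dtype=jnp.float16)

print(by_row.shape, by_col.shape, half.dtype)  # (2, 3) (3, 2) float16
print(jnp.array_equal(by_col, by_row.T))       # True
```

For a 1-D input, moving the one-hot dimension to axis 0 is the same as transposing the default result; for higher-rank inputs the new dimension is inserted at the given position instead.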