jax.numpy.tril
jax.numpy.tril(m, k=0)
Lower triangle of an array.
LAX-backend implementation of numpy.tril(). Original docstring below.
Return a copy of an array with elements above the k-th diagonal zeroed.
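The following minimal sketch shows the default behaviour (k = 0) on a small 3x3 matrix. It also compares the result against NumPy's own tril, on the assumption that the LAX-backend version should agree with NumPy for plain 2-D input. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# A 3x3 integer matrix: [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
m = jnp.arange(1, 10).reshape(3, 3)

# Default k=0: every element strictly above the main diagonal becomes 0.
lower = jnp.tril(m)
print(lower)
# [[1 0 0]
#  [4 5 0]
#  [7 8 9]]

# Compare with NumPy's tril applied to the same values.
expected = np.tril(np.arange(1, 10).reshape(3, 3))
print(np.array_equal(np.asarray(lower), expected))  # True
```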
- Parameters
m (array_like, shape (M, N)) – Input array.
k (int, optional) – Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above.
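Put another way, element (i, j) is kept when j - i <= k and zeroed otherwise. The sketch below applies three values of k to a 3x4 matrix of ones so the kept region is easy to see. The `manual` helper, built from jnp.indices and jnp.where, is a hypothetical illustration of that rule used only as a cross-check; it is not claimed to be how JAX implements tril internally.

```python
import jax.numpy as jnp

m = jnp.ones((3, 4), dtype=jnp.int32)

# k = 0: keep the main diagonal and everything below it.
main = jnp.tril(m, k=0)
# k = -1: the main diagonal is zeroed too (strictly lower triangle).
strict = jnp.tril(m, k=-1)
# k = 1: additionally keep the first diagonal above the main one.
wide = jnp.tril(m, k=1)

print(main.tolist())    # [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]]
print(strict.tolist())  # [[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 0, 0]]
print(wide.tolist())    # [[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]]


def manual(x, k):
    """Hypothetical reference: keep x[i, j] where j - i <= k, else 0."""
    i, j = jnp.indices(x.shape)
    return jnp.where(j - i <= k, x, 0)


# The index-based rule reproduces jnp.tril for each k tried above.
for k in (-1, 0, 1):
    assert (manual(m, k) == jnp.tril(m, k=k)).all()
```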
- Returns
tril – Lower triangle of m, with the same shape and data type as m.
- Return type
ndarray, shape (M, N)
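As a quick check of the return description, the sketch below uses a non-square float32 input and confirms that shape and dtype are preserved. Because this is a LAX-backend implementation, it should also be usable under jax.jit; here k is captured as a Python constant in the lambda (an assumption made so that k stays static during tracing). The name `strict_lower` is illustrative.

```python
import jax
import jax.numpy as jnp

# A non-square (3, 2) float32 input.
x = jnp.array([[1.5, 2.5], [3.5, 4.5], [5.5, 6.5]], dtype=jnp.float32)

y = jnp.tril(x)
print(y.shape, y.dtype)  # (3, 2) float32

# k is fixed inside the lambda, so jit only traces the array argument.
strict_lower = jax.jit(lambda a: jnp.tril(a, k=-1))
z = strict_lower(x)
print(z.tolist())  # [[0.0, 0.0], [3.5, 0.0], [5.5, 6.5]]
```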