These three index-generating functions exist across most array backends but are missing from both the Array API standard and array-api-extra.
## Backend availability
For torch, `tril_indices`/`triu_indices` exist but with a different signature (`torch.tril_indices(row, col, offset=, device=)`), and `diag_indices` doesn't exist at all. This API divergence is exactly the kind of gap array-api-extra is well-suited to unify.
## Downstream usage
scikit-learn has a private `_add_to_diagonal` in `sklearn.utils._array_api` that works around the lack of `diag_indices` for array-API-compatible covariance estimators (see main branch).
These functions are commonly needed by any library that works with symmetric or triangular matrices across backends; pyriemann is one example (cc @qbarthelemy and @agramfort).
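As a minimal sketch of the pattern these estimators need (assuming NumPy as the backend; `add_to_diagonal` is a hypothetical helper, not scikit-learn's actual `_add_to_diagonal`), an additive diagonal update via index arrays looks like:

```python
import numpy as np

def add_to_diagonal(a, value):
    # Additively regularize the diagonal of a square matrix using index
    # arrays -- the pattern that diag_indices enables across backends.
    idx = np.diag_indices(a.shape[0])
    out = a.copy()
    out[idx] += value
    return out

cov = np.eye(3)
reg = add_to_diagonal(cov, 1e-3)  # diagonal becomes 1.001, off-diagonal stays 0
```

Without a portable `diag_indices`, each library ends up hand-rolling this index construction per backend.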
## Relationship to `fill_diagonal` (#500) and `xpx.at`
`fill_diagonal` (PR #500) sets diagonal values, which is complementary but different. The index functions are needed for:
- Reading subsets (e.g., extracting lower-triangle elements for vectorization)
- Additive updates (e.g., `at(X)[..., idx[0], idx[1]].add(reg)` for regularization)
- Non-diagonal patterns (`tril_indices`, `triu_indices`)
The existing `xpx.at` helper handles the JAX immutability concern for write operations using these indices.
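The read-subset use case can be sketched in plain NumPy (hypothetical helper name; pyriemann and similar libraries use this shape of code to vectorize symmetric matrices):

```python
import numpy as np

def sym_to_vec(x):
    # Flatten a (..., n, n) symmetric matrix into its lower-triangle
    # elements (including the diagonal), a common vectorization step.
    i, j = np.tril_indices(x.shape[-1])
    return x[..., i, j]

s = np.array([[1.0, 2.0],
              [2.0, 3.0]])
v = sym_to_vec(s)  # -> [1., 2., 3.]
```

A portable `tril_indices` would let this function work unchanged on any array-API backend.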
## Proposed signatures
Following numpy's API (which JAX and CuPy already match):
```python
def diag_indices(n: int, ndim: int = 2, *, xp=None) -> tuple[Array, ...]:
    ...

def tril_indices(n: int, k: int = 0, m: int | None = None, *, xp=None) -> tuple[Array, Array]:
    ...

def triu_indices(n: int, k: int = 0, m: int | None = None, *, xp=None) -> tuple[Array, Array]:
    ...
```
Delegation would be straightforward: numpy/jax/cupy delegate directly; torch needs argument translation for `tril_indices`/`triu_indices` and a manual `arange` for `diag_indices`; dask has `tril_indices`/`triu_indices` but not `diag_indices`.
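The manual-`arange` fallback for backends without `diag_indices` is a one-liner: the diagonal of an `ndim`-dimensional hypercube is indexed by the same `arange` repeated once per axis. A sketch (the `xp` parameter stands in for the resolved namespace; shown with NumPy):

```python
import numpy as np

def diag_indices_fallback(n, ndim=2, *, xp=np):
    # The diagonal of an (n, n, ..., n) array is the set of positions
    # where all indices are equal, so one arange serves every axis.
    idx = xp.arange(n)
    return (idx,) * ndim

i, j = diag_indices_fallback(3)  # each is [0, 1, 2]
```

This matches `np.diag_indices(n, ndim)` element-for-element, which is why delegation plus this small fallback covers torch and dask.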