
ENH: add diag_indices, tril_indices, triu_indices #686

@bruAristimunha

Description

These three index-generating functions exist across most array backends but are missing from both the Array API standard and array-api-extra.

Backend availability

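| Function           | numpy | torch                       | JAX | CuPy | dask | sparse |
|--------------------|-------|-----------------------------|-----|------|------|--------|
| diag_indices(n)    | numpy | ✗ (manual via arange)       | jax | cupy |      |        |
| tril_indices(n, k) | numpy | torch (different signature) | jax | cupy | dask |        |
| triu_indices(n, k) | numpy | torch (different signature) | jax | cupy | dask |        |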

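For torch, tril_indices/triu_indices exist but with a different signature (torch.tril_indices(row, col, offset=0, *, dtype=torch.long, device='cpu', layout=torch.strided)) and return a single (2, N) tensor rather than a tuple of index arrays; diag_indices does not exist at all. This API divergence is exactly the kind of gap array-api-extra is well suited to unify.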

Downstream usage

scikit-learn has a private _add_to_diagonal in sklearn.utils._array_api that works around the lack of diag_indices for array-API-compatible covariance estimators (see main branch).
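A minimal NumPy illustration of the kind of workaround the missing function forces. Both helper names are hypothetical and neither is scikit-learn's actual code: add_to_diagonal_flat reaches the diagonal through a strided slice of a flattened copy (flat positions 0, n + 1, 2(n + 1), ...), which only covers a single square 2-D matrix, while add_to_diagonal_idx is what a portable diag_indices would allow.

```python
import numpy as np


def add_to_diagonal_flat(X, value):
    """Workaround without diag_indices (hypothetical helper).

    In the row-major flattening of an (n, n) matrix, the diagonal sits at
    positions 0, n + 1, 2 * (n + 1), ..., i.e. a slice with step n + 1.
    """
    n = X.shape[0]
    flat = np.reshape(X, (-1,)).copy()  # copy so the input is left untouched
    flat[:: n + 1] += value
    return np.reshape(flat, X.shape)


def add_to_diagonal_idx(X, value):
    """Same update expressed with diag_indices (hypothetical helper)."""
    out = X.copy()
    rows, cols = np.diag_indices(X.shape[0])  # one index array per axis
    out[rows, cols] += value
    return out


X = np.arange(9.0).reshape(3, 3)
flat_result = add_to_diagonal_flat(X, 1.0)
idx_result = add_to_diagonal_idx(X, 1.0)
```

The index-based version also extends to leading batch axes (X[..., rows, cols]), which the flat-stride trick does not handle directly.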

These functions are commonly needed in any library that works with symmetric or triangular matrices across backends, for example pyriemann (cc @qbarthelemy and @agramfort).
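As an example of the pattern, a tangent-space style vectorization of a stack of symmetric matrices keeps only the upper triangle and scales the off-diagonal entries by sqrt(2), so that the Euclidean norm of each vector equals the Frobenius norm of its matrix. The sketch uses NumPy and the hypothetical name upper_vectorize; it illustrates the indexing pattern and is not pyriemann's implementation.

```python
import numpy as np


def upper_vectorize(C):
    """Map (..., n, n) symmetric matrices to (..., n * (n + 1) // 2) vectors."""
    n = C.shape[-1]
    rows, cols = np.triu_indices(n)
    # Weight 1 on the diagonal, sqrt(2) off the diagonal: each stored
    # off-diagonal entry stands for two equal entries of the matrix.
    weights = np.where(rows == cols, 1.0, np.sqrt(2.0))
    # The Ellipsis keeps the batch axes; rows/cols index the last two axes.
    return weights * C[..., rows, cols]


rng = np.random.default_rng(0)
A = rng.standard_normal((5, 4, 4))
C = A + np.swapaxes(A, -1, -2)  # batch of 5 symmetric 4 x 4 matrices
v = upper_vectorize(C)
print(v.shape)  # (5, 10)
```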

Relationship to fill_diagonal (#500) and xpx.at

fill_diagonal (PR #500) overwrites the diagonal with given values, which is complementary to, but different from, what index functions provide. The index functions are needed for:

  • Reading subsets (e.g., extracting lower-triangle elements for vectorization)
  • Additive updates (e.g., xpx.at(X)[..., idx[0], idx[1]].add(reg) for regularization)
  • Non-diagonal patterns (tril_indices, triu_indices)
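In NumPy, each of these three uses maps directly onto fancy indexing with the returned index tuples. The sketch below is NumPy-only and writes in place, which immutable backends such as JAX do not allow (the case xpx.at covers); the Ellipsis keeps leading batch axes, as in the [..., idx[0], idx[1]] pattern.

```python
import numpy as np

n, reg = 4, 1e-3
X = np.zeros((2, n, n))  # a batch of two n x n matrices
X[..., 2, 0] = 5.0  # one strictly-lower entry, set in both matrices

# Reading a subset: the strictly lower triangle of every matrix, shape (2, 6).
l_rows, l_cols = np.tril_indices(n, k=-1)  # k=-1 excludes the diagonal
lower = X[..., l_rows, l_cols]  # fancy indexing returns a copy

# Additive update: add reg to the diagonal of every matrix in the batch.
d_rows, d_cols = np.diag_indices(n)
X[..., d_rows, d_cols] += reg

# Non-diagonal pattern: mirror the strictly lower triangle into the upper one
# by swapping the row and column index arrays.
X[..., l_cols, l_rows] = lower
```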

The existing xpx.at helper handles the JAX immutability concern for write operations using these indices.

Proposed signatures

Following numpy's API (which JAX and CuPy already match):

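```python
from types import ModuleType
from typing import Any

Array = Any  # placeholder for the array type

def diag_indices(n: int, ndim: int = 2, *, xp: ModuleType | None = None) -> tuple[Array, ...]:
    ...

def tril_indices(n: int, k: int = 0, m: int | None = None, *, xp: ModuleType | None = None) -> tuple[Array, Array]:
    ...

def triu_indices(n: int, k: int = 0, m: int | None = None, *, xp: ModuleType | None = None) -> tuple[Array, Array]:
    ...
```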
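As a reference point, these signatures could be backed by a fallback built only from array API standard functions (arange, reshape, nonzero). This is a sketch under simplifying assumptions: the namespace is passed explicitly and defaults to NumPy, device and dtype handling are omitted, and nonzero has a data-dependent output shape, which some backends (for example JAX under jit) may not support without extra arguments.

```python
from __future__ import annotations

from types import ModuleType

import numpy as np


def diag_indices(n: int, ndim: int = 2, *, xp: ModuleType = np):
    # The same arange is reused for every axis, as numpy.diag_indices does.
    idx = xp.arange(n)
    return (idx,) * ndim


def _tri_mask(n: int, k: int, m: int | None, upper: bool, xp: ModuleType):
    m = n if m is None else m
    rows = xp.reshape(xp.arange(n), (n, 1))
    cols = xp.reshape(xp.arange(m), (1, m))
    # numpy convention: the lower triangle keeps j <= i + k,
    # the upper triangle keeps j >= i + k.
    return cols >= rows + k if upper else cols <= rows + k


def tril_indices(n: int, k: int = 0, m: int | None = None, *, xp: ModuleType = np):
    # nonzero on a 2-D mask yields (rows, cols) in row-major order.
    return xp.nonzero(_tri_mask(n, k, m, upper=False, xp=xp))


def triu_indices(n: int, k: int = 0, m: int | None = None, *, xp: ModuleType = np):
    return xp.nonzero(_tri_mask(n, k, m, upper=True, xp=xp))
```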

Delegation would be straightforward: numpy/jax/cupy delegate directly; torch needs argument translation for tril_indices/triu_indices and a manual arange for diag_indices. dask has tril_indices/triu_indices but not diag_indices.
