Skip to content

ENH: add diag_indices, tril_indices, triu_indices#692

Open
bruAristimunha wants to merge 1 commit intodata-apis:mainfrom
bruAristimunha:diag_tril_triu_indices
Open

ENH: add diag_indices, tril_indices, triu_indices#692
bruAristimunha wants to merge 1 commit intodata-apis:mainfrom
bruAristimunha:diag_tril_triu_indices

Conversation

@bruAristimunha
Copy link
Copy Markdown

Resolves #686.

Summary

  • Adds diag_indices, tril_indices, triu_indices — three index-generating functions that numpy, jax, and cupy all have but that are missing from the array-api standard and from this library.
  • Signatures follow array-api conventions: parameter offset (matching xp.linalg.diagonal) instead of numpy's k; keyword-only arguments for everything except n; xp is required (these functions have no input array to infer from, following the default_dtype precedent).

Numpy migration

numpy.diag_indices(n, ndim=k)           -> xpx.diag_indices(n, ndim=k, xp=xp)
numpy.tril_indices(n, k=k, m=m)         -> xpx.tril_indices(n, offset=k, m=m, xp=xp)
numpy.triu_indices(n, k=k, m=m)         -> xpx.triu_indices(n, offset=k, m=m, xp=xp)

Delegation

  • numpy / cupy / jax: forward directly (signatures match verbatim).
  • dask: has tril_indices/triu_indices but no diag_indices — the last one falls through to the generic impl.
  • torch: has tril_indices/triu_indices but with (row, col, *, offset) signature returning a 2×N tensor rather than a tuple; delegation translates. No torch.diag_indices exists.
  • sparse, array-api-strictest: fall through to generic; marked xfail on those backends (no nonzero / data-dependent shapes).

Validation (n >= 0, ndim >= 1, m >= 0) happens in the delegation layer via a shared _check_nonneg helper, so all backends emit consistent ValueErrors before any backend-specific code runs.

Generic implementation

  • diag_indices(xp.arange(n),) * ndim.
  • tril_indices/triu_indices → shared _tri_indices helper: xp.arange + broadcasting + xp.nonzero on the mask. Pure array-api, fully lazy on dask.

Also in this PR

Fixes a pre-existing bug in tests/conftest.py's NumPyReadOnly wrapper: type(o)(*gen) worked for namedtuples but failed for plain tuples of length ≥ 2. Exposed here because these are the first functions in the library that return a tuple of arrays.

Test plan

  • pytest tests/test_funcs.py::TestDiagIndices tests/test_funcs.py::TestTriIndices — 155 passed across numpy, torch, jax, dask, array-api-strict (+ xfail on sparse/strictest/dask-use_to_read where noted).
  • pytest tests/test_funcs.py full — all passing.
  • lefthook run pre-commit --all-files — ruff, numpydoc, mypy, pyright, blacken-docs, validate-pyproject, dprint, typos all green.
  • Dask laziness verified — lazy_xp_function(tril_indices)/lazy_xp_function(triu_indices) assert 0 .compute() calls, holds for both native and generic paths.

Resolves data-apis#686. Adds the three index-generating functions that numpy,
jax, and cupy all have but that are missing from the array-api
standard and (so far) from this library.

Signatures follow array-api conventions: parameter `offset` (matching
`xp.linalg.diagonal`) instead of numpy's `k`; keyword-only arguments
for everything except `n`; `xp` is required (these functions have no
input array to infer from, following the `default_dtype` precedent).

Delegation:
- numpy/cupy/jax: forward directly (signatures match verbatim).
- dask: has tril/triu_indices but no diag_indices.
- torch: has tril/triu_indices but with (row, col, *, offset) signature
  returning a 2xN tensor rather than a tuple; delegation translates.
  No torch.diag_indices exists; falls through to generic.
- sparse, array-api-strictest: fall through to generic; marked xfail
  on those backends (no nonzero / data-dependent shapes).

Generic implementation uses `xp.arange` + broadcasting + `xp.nonzero`
for the triangle variants. Validation (n >= 0, ndim >= 1, m >= 0)
happens in the delegation layer so all backends produce consistent
ValueErrors.

Also fixes a pre-existing bug in tests/conftest.py's NumPyReadOnly
wrapper: `type(o)(*gen)` worked for namedtuples but failed for plain
tuples of length >= 2. Exposed here because these are the first
functions in the library that return a tuple of arrays.
@bruAristimunha bruAristimunha force-pushed the diag_tril_triu_indices branch from e905c26 to 2654568 Compare April 17, 2026 11:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ENH: add diag_indices, tril_indices, triu_indices

1 participant