Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
atleast_nd,
cov,
create_diagonal,
diag_indices,
expand_dims,
isclose,
isin,
Expand All @@ -15,6 +16,8 @@
searchsorted,
setdiff1d,
sinc,
tril_indices,
triu_indices,
union1d,
)
from ._lib._at import at
Expand All @@ -40,6 +43,7 @@
"cov",
"create_diagonal",
"default_dtype",
"diag_indices",
"expand_dims",
"isclose",
"isin",
Expand All @@ -53,5 +57,7 @@
"searchsorted",
"setdiff1d",
"sinc",
"tril_indices",
"triu_indices",
"union1d",
]
162 changes: 162 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
"atleast_nd",
"cov",
"create_diagonal",
"diag_indices",
"expand_dims",
"isclose",
"nan_to_num",
"one_hot",
"pad",
"searchsorted",
"sinc",
"tril_indices",
"triu_indices",
]


Expand Down Expand Up @@ -238,6 +241,49 @@ def create_diagonal(
return _funcs.create_diagonal(x, offset=offset, xp=xp)


def diag_indices(n: int, /, *, ndim: int = 2, xp: ModuleType) -> tuple[Array, ...]:
"""
Return the indices to access the main diagonal of an array.

Equivalent to ``numpy.diag_indices``.

Parameters
----------
n : int
The size of each dimension of the (hyper-)cube ``(n, n, ..., n)``
that the returned indices index into.
ndim : int, optional
The number of dimensions. Default: ``2``.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
``ndim`` 1-D integer arrays of length ``n`` that together index
the main diagonal of an array of shape ``(n,) * ndim``.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.diag_indices(3, xp=xp)
>>> rows
Array([0, 1, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 1, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if ndim < 1:
msg = f"`ndim` must be >= 1, got {ndim}"
raise ValueError(msg)
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
return xp.diag_indices(n, ndim=ndim)
return _funcs.diag_indices(n, ndim=ndim, xp=xp)


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
Expand Down Expand Up @@ -1150,3 +1196,119 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.union1d(a, b)

return _funcs.union1d(a, b, xp=xp)


def tril_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]:
"""
Return the indices of the lower triangle of an ``(n, m)`` array.

Equivalent to ``numpy.tril_indices`` with parameter ``k`` renamed to
``offset`` to match ``xp.linalg.diagonal``'s naming.

Parameters
----------
n : int
The row dimension of the array.
offset : int, optional
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
to ``k`` in ``numpy.tril_indices``.
m : int, optional
The column dimension. If ``None`` (default), assumed equal to `n`.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
Row and column indices ``(rows, cols)`` of the lower triangle of
the ``(n, m)`` matrix, shifted by `offset`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.tril_indices(3, xp=xp)
>>> rows
Array([0, 1, 1, 2, 2, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 0, 1, 0, 1, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if m is not None and m < 0:
msg = f"`m` must be non-negative, got {m}"
raise ValueError(msg)
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
):
return xp.tril_indices(n, k=offset, m=m)
if is_torch_namespace(xp):
# `torch.tril_indices` returns a 2xN tensor, not a tuple, and
# takes (row, col) rather than (n, *, m=None).
cols = n if m is None else m
idx = xp.tril_indices(n, cols, offset=offset)
return (idx[0], idx[1])
return _funcs.tril_indices(n, offset=offset, m=m, xp=xp)


def triu_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]:
"""
Return the indices of the upper triangle of an ``(n, m)`` array.

Equivalent to ``numpy.triu_indices`` with parameter ``k`` renamed to
``offset`` to match ``xp.linalg.diagonal``'s naming.

Parameters
----------
n : int
The row dimension of the array.
offset : int, optional
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
to ``k`` in ``numpy.triu_indices``.
m : int, optional
The column dimension. If ``None`` (default), assumed equal to `n`.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
Row and column indices ``(rows, cols)`` of the upper triangle of
the ``(n, m)`` matrix, shifted by `offset`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.triu_indices(3, xp=xp)
>>> rows
Array([0, 0, 0, 1, 1, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 1, 2, 1, 2, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if m is not None and m < 0:
msg = f"`m` must be non-negative, got {m}"
raise ValueError(msg)
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
):
return xp.triu_indices(n, k=offset, m=m)
if is_torch_namespace(xp):
cols = n if m is None else m
idx = xp.triu_indices(n, cols, offset=offset)
return (idx[0], idx[1])
return _funcs.triu_indices(n, offset=offset, m=m, xp=xp)
38 changes: 38 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
"broadcast_shapes",
"cov",
"create_diagonal",
"diag_indices",
"expand_dims",
"kron",
"nunique",
"pad",
"searchsorted",
"setdiff1d",
"sinc",
"tril_indices",
"triu_indices",
]


Expand Down Expand Up @@ -346,6 +349,41 @@ def create_diagonal(
return xp.reshape(diag, (*batch_dims, n, n))


def diag_indices(
n: int, /, *, ndim: int = 2, xp: ModuleType
) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
idx = xp.arange(n)
return (idx,) * ndim


def _tri_indices(
n: int, *, offset: int, m: int | None, upper: bool, xp: ModuleType
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""Shared implementation for `tril_indices` and `triu_indices`."""
cols = n if m is None else m
rows = xp.arange(n)[:, None]
cols_a = xp.arange(cols)[None, :]
delta = cols_a - rows
mask = delta >= offset if upper else delta <= offset
r, c = xp.nonzero(mask)
return (r, c)


def tril_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
return _tri_indices(n, offset=offset, m=m, upper=False, xp=xp)


def triu_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
return _tri_indices(n, offset=offset, m=m, upper=True, xp=xp)


def default_dtype(
xp: ModuleType,
kind: Literal[
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01
# Cannot interpret as a data type
return o

# This works with namedtuples too
if isinstance(o, tuple | list):
return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
# namedtuple wants positional args; plain tuple/list wants an iterable.
items = (as_readonly(i) for i in o)
if hasattr(o, "_fields"):
return type(o)(*items) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
return type(o)(items) # type: ignore[return-value]

return o

Expand Down
Loading