diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 2fcdcd8e..44b2dc47 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -5,6 +5,7 @@ atleast_nd, cov, create_diagonal, + diag_indices, expand_dims, isclose, isin, @@ -15,6 +16,8 @@ searchsorted, setdiff1d, sinc, + tril_indices, + triu_indices, union1d, ) from ._lib._at import at @@ -40,6 +43,7 @@ "cov", "create_diagonal", "default_dtype", + "diag_indices", "expand_dims", "isclose", "isin", @@ -53,5 +57,7 @@ "searchsorted", "setdiff1d", "sinc", + "tril_indices", + "triu_indices", "union1d", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 46639559..4bc7adbc 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -22,6 +22,7 @@ "atleast_nd", "cov", "create_diagonal", + "diag_indices", "expand_dims", "isclose", "nan_to_num", @@ -29,6 +30,8 @@ "pad", "searchsorted", "sinc", + "tril_indices", + "triu_indices", ] @@ -238,6 +241,49 @@ def create_diagonal( return _funcs.create_diagonal(x, offset=offset, xp=xp) +def diag_indices(n: int, /, *, ndim: int = 2, xp: ModuleType) -> tuple[Array, ...]: + """ + Return the indices to access the main diagonal of an array. + + Equivalent to ``numpy.diag_indices``. + + Parameters + ---------- + n : int + The size of each dimension of the (hyper-)cube ``(n, n, ..., n)`` + that the returned indices index into. + ndim : int, optional + The number of dimensions. Default: ``2``. + xp : array_namespace + The standard-compatible namespace to create the indices in. + + Returns + ------- + tuple of array + ``ndim`` 1-D integer arrays of length ``n`` that together index + the main diagonal of an array of shape ``(n,) * ndim``. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> rows, cols = xpx.diag_indices(3, xp=xp) + >>> rows + Array([0, 1, 2], dtype=array_api_strict.int64) + >>> cols + Array([0, 1, 2], dtype=array_api_strict.int64) + """ + if n < 0: + msg = f"`n` must be non-negative, got {n}" + raise ValueError(msg) + if ndim < 1: + msg = f"`ndim` must be >= 1, got {ndim}" + raise ValueError(msg) + if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp): + return xp.diag_indices(n, ndim=ndim) + return _funcs.diag_indices(n, ndim=ndim, xp=xp) + + def expand_dims( a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None ) -> Array: @@ -1150,3 +1196,119 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: return xp.union1d(a, b) return _funcs.union1d(a, b, xp=xp) + + +def tril_indices( + n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType +) -> tuple[Array, Array]: + """ + Return the indices of the lower triangle of an ``(n, m)`` array. + + Equivalent to ``numpy.tril_indices`` with parameter ``k`` renamed to + ``offset`` to match ``xp.linalg.diagonal``'s naming. + + Parameters + ---------- + n : int + The row dimension of the array. + offset : int, optional + Diagonal offset; ``0`` (default) is the main diagonal. Corresponds + to ``k`` in ``numpy.tril_indices``. + m : int, optional + The column dimension. If ``None`` (default), assumed equal to `n`. + xp : array_namespace + The standard-compatible namespace to create the indices in. + + Returns + ------- + tuple of array + Row and column indices ``(rows, cols)`` of the lower triangle of + the ``(n, m)`` matrix, shifted by `offset`. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> rows, cols = xpx.tril_indices(3, xp=xp) + >>> rows + Array([0, 1, 1, 2, 2, 2], dtype=array_api_strict.int64) + >>> cols + Array([0, 0, 1, 0, 1, 2], dtype=array_api_strict.int64) + """ + if n < 0: + msg = f"`n` must be non-negative, got {n}" + raise ValueError(msg) + if m is not None and m < 0: + msg = f"`m` must be non-negative, got {m}" + raise ValueError(msg) + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_dask_namespace(xp) + ): + return xp.tril_indices(n, k=offset, m=m) + if is_torch_namespace(xp): + # `torch.tril_indices` returns a 2xN tensor, not a tuple, and + # takes (row, col) rather than (n, *, m=None). + cols = n if m is None else m + idx = xp.tril_indices(n, cols, offset=offset) + return (idx[0], idx[1]) + return _funcs.tril_indices(n, offset=offset, m=m, xp=xp) + + +def triu_indices( + n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType +) -> tuple[Array, Array]: + """ + Return the indices of the upper triangle of an ``(n, m)`` array. + + Equivalent to ``numpy.triu_indices`` with parameter ``k`` renamed to + ``offset`` to match ``xp.linalg.diagonal``'s naming. + + Parameters + ---------- + n : int + The row dimension of the array. + offset : int, optional + Diagonal offset; ``0`` (default) is the main diagonal. Corresponds + to ``k`` in ``numpy.triu_indices``. + m : int, optional + The column dimension. If ``None`` (default), assumed equal to `n`. + xp : array_namespace + The standard-compatible namespace to create the indices in. + + Returns + ------- + tuple of array + Row and column indices ``(rows, cols)`` of the upper triangle of + the ``(n, m)`` matrix, shifted by `offset`. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> rows, cols = xpx.triu_indices(3, xp=xp) + >>> rows + Array([0, 0, 0, 1, 1, 2], dtype=array_api_strict.int64) + >>> cols + Array([0, 1, 2, 1, 2, 2], dtype=array_api_strict.int64) + """ + if n < 0: + msg = f"`n` must be non-negative, got {n}" + raise ValueError(msg) + if m is not None and m < 0: + msg = f"`m` must be non-negative, got {m}" + raise ValueError(msg) + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_dask_namespace(xp) + ): + return xp.triu_indices(n, k=offset, m=m) + if is_torch_namespace(xp): + cols = n if m is None else m + idx = xp.triu_indices(n, cols, offset=offset) + return (idx[0], idx[1]) + return _funcs.triu_indices(n, offset=offset, m=m, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..1b346ae2 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -28,6 +28,7 @@ "broadcast_shapes", "cov", "create_diagonal", + "diag_indices", "expand_dims", "kron", "nunique", @@ -35,6 +36,8 @@ "searchsorted", "setdiff1d", "sinc", + "tril_indices", + "triu_indices", ] @@ -346,6 +349,41 @@ def create_diagonal( return xp.reshape(diag, (*batch_dims, n, n)) +def diag_indices( + n: int, /, *, ndim: int = 2, xp: ModuleType +) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" + idx = xp.arange(n) + return (idx,) * ndim + + +def _tri_indices( + n: int, *, offset: int, m: int | None, upper: bool, xp: ModuleType +) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 + """Shared implementation for `tril_indices` and `triu_indices`.""" + cols = n if m is None else m + rows = xp.arange(n)[:, None] + cols_a = xp.arange(cols)[None, :] + delta = cols_a - rows + mask = delta >= offset if upper else delta <= offset + r, c = xp.nonzero(mask) + return (r, c) + + +def tril_indices( + n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType +) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" + return _tri_indices(n, offset=offset, m=m, upper=False, xp=xp) + + +def triu_indices( + n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType +) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" + return _tri_indices(n, offset=offset, m=m, upper=True, xp=xp) + + def default_dtype( xp: ModuleType, kind: Literal[ diff --git a/tests/conftest.py b/tests/conftest.py index df703b97..6c735a20 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,9 +96,12 @@ def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01 # Cannot interpret as a data type return o - # This works with namedtuples too if isinstance(o, tuple | list): - return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType] + # namedtuple wants positional args; plain tuple/list wants an iterable. + items = (as_readonly(i) for i in o) + if hasattr(o, "_fields"): + return type(o)(*items) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType] + return type(o)(items) # type: ignore[return-value] return o diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..89f09134 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -21,6 +21,7 @@ cov, create_diagonal, default_dtype, + diag_indices, expand_dims, isclose, isin, @@ -32,6 +33,8 @@ partition, setdiff1d, sinc, + tril_indices, + triu_indices, union1d, ) from array_api_extra import ( @@ -56,6 +59,7 @@ lazy_xp_function(cov) lazy_xp_function(create_diagonal) lazy_xp_function(default_dtype) +lazy_xp_function(diag_indices) lazy_xp_function(expand_dims) lazy_xp_function(isclose) lazy_xp_function(isin) @@ -68,6 +72,8 @@ # FIXME calls in1d which calls xp.unique_values without size lazy_xp_function(setdiff1d, jax_jit=False) lazy_xp_function(sinc) +lazy_xp_function(tril_indices) +lazy_xp_function(triu_indices) lazy_xp_function(union1d, jax_jit=False) lazy_xp_function(xpx_searchsorted) lazy_xp_function(_funcs_searchsorted) @@ -803,6 +809,108 @@ def test_torch(self, torch: ModuleType): assert default_dtype(xp, "complex floating") == xp.complex64 +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) +class TestDiagIndices: + def test_basic(self, xp: ModuleType): + rows, cols = diag_indices(5, xp=xp) + ref_rows, ref_cols = np.diag_indices(5) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + @pytest.mark.parametrize("ndim", [1, 2, 3, 4]) + def test_ndim(self, xp: ModuleType, ndim: int): + idx = diag_indices(4, ndim=ndim, xp=xp) + assert len(idx) == ndim + ref = np.diag_indices(4, ndim=ndim) + for got, expected in zip(idx, ref, strict=True): + xp_assert_equal(got, xp.asarray(expected)) + + def test_empty(self, xp: ModuleType): + rows, cols = diag_indices(0, xp=xp) + assert rows.shape == (0,) + assert cols.shape == (0,) + + def test_validation(self, xp: ModuleType): + with pytest.raises(ValueError, match="`n` must be non-negative"): + _ = diag_indices(-1, xp=xp) + with pytest.raises(ValueError, match="`ndim` must be >= 1"): + _ = diag_indices(3, ndim=0, xp=xp) + + +@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange/nonzero", strict=False) +@pytest.mark.xfail_xp_backend( + Backend.ARRAY_API_STRICTEST, + reason="generic path uses nonzero (data-dependent)", + strict=False, +) +@pytest.mark.parametrize( + ("xpx_fn", "np_fn"), + [(tril_indices, np.tril_indices), (triu_indices, np.triu_indices)], + ids=["tril", "triu"], +) +class TestTriIndices: + def test_basic( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + ): + rows, cols = xpx_fn(4, xp=xp) + ref_rows, ref_cols = np_fn(4) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + @pytest.mark.parametrize("offset", [-2, -1, 0, 1, 2]) + def test_offset( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + offset: int, + ): + rows, cols = xpx_fn(5, offset=offset, xp=xp) + ref_rows, ref_cols = np_fn(5, k=offset) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + def test_rectangular( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + ): + rows, cols = xpx_fn(3, m=5, xp=xp) + ref_rows, ref_cols = np_fn(3, m=5) + xp_assert_equal(rows, xp.asarray(ref_rows)) + xp_assert_equal(cols, xp.asarray(ref_cols)) + + @pytest.mark.xfail_xp_backend( + Backend.DASK, reason="dask: no 2D fancy indexing", strict=False + ) + def test_use_to_read( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], + ): + rng = np.random.default_rng(0) + a = rng.integers(0, 100, (4, 4)) + a_xp = xp.asarray(a) + rows, cols = xpx_fn(4, xp=xp) + xp_assert_equal(a_xp[rows, cols], xp.asarray(a[np_fn(4)])) + + def test_validation( + self, + xp: ModuleType, + xpx_fn: Callable[..., tuple[Array, Array]], + np_fn: Callable[..., tuple[Array, Array]], # noqa: ARG002 # pytest param + ): + with pytest.raises(ValueError, match="`n` must be non-negative"): + _ = xpx_fn(-1, xp=xp) + with pytest.raises(ValueError, match="`m` must be non-negative"): + _ = xpx_fn(3, m=-1, xp=xp) + + class TestExpandDims: def test_single_axis(self, xp: ModuleType): """Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""