Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
def cov(m: Array, /, *, bias: bool = False, xp: ModuleType | None = None) -> Array:
"""
Estimate a covariance matrix (or a stack of covariance matrices).

Expand All @@ -102,6 +102,9 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
contain *M* observations of *N* variables. That is,
each row of `m` represents a variable, and each column a single
observation of all those variables.
bias : bool, optional
If ``False`` (default), normalize the covariance matrix by ``M - 1``
giving an unbiased estimate. If ``True``, normalize by ``M``.
xp : array_namespace, optional
The standard-compatible namespace for `m`. Default: infer.

Expand Down Expand Up @@ -164,16 +167,18 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
if xp is None:
xp = array_namespace(m)

if is_torch_namespace(xp) and m.ndim <= 2:
return xp.cov(m, correction=int(not bias))

if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_torch_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
) and m.ndim <= 2:
return xp.cov(m)
return xp.cov(m, bias=bias)

return _funcs.cov(m, xp=xp)
return _funcs.cov(m, bias=bias, xp=xp)


def create_diagonal(
Expand Down
6 changes: 4 additions & 2 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
return tuple(out)


def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
def cov( # numpydoc ignore=PR01,RT01
m: Array, /, *, bias: bool = False, xp: ModuleType
) -> Array:
"""See docstring in array_api_extra._delegation."""
m = xp.asarray(m, copy=True)
dtype = (
Expand All @@ -294,7 +296,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
avg = xp.mean(m, axis=-1, keepdims=True)

m_shape = eager_shape(m)
fact = m_shape[-1] - 1
fact = m_shape[-1] - (0 if bias else 1)

if fact <= 0:
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,29 @@ def test_batch(self, xp: ModuleType):
ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
xp_assert_close(res, xp.asarray(ref))

@pytest.mark.parametrize("bias", [True, False, 0, 1])
def test_bias(self, xp: ModuleType, bias: bool):
    """`bias` must reproduce np.cov's normalization on a fixed 2-variable input."""
    # Two variables with three observations each, stacked row-wise
    # (each row is a variable, each column an observation).
    data = np.stack(
        (np.array([-2.1, -1, 4.3]), np.array([3, 1.1, 0.12])),
        axis=0,
    )
    expected = np.cov(data, bias=bias)
    actual = cov(xp.asarray(data, dtype=xp.float64), bias=bias)
    xp_assert_close(actual, xp.asarray(expected, dtype=xp.float64), rtol=1e-6)

@pytest.mark.parametrize("bias", [True, False, 0, 1])
def test_bias_batch(self, xp: ModuleType, bias: bool):
    """Batched input: every (n_var, n_obs) slice agrees with np.cov under `bias`."""
    gen = np.random.default_rng(8847643423)
    batch_shape = (3, 4)
    n_var, n_obs = 3, 20
    samples = gen.random((*batch_shape, n_var, n_obs))
    result = cov(xp.asarray(samples), bias=bias)
    # Build the reference by applying np.cov to each flattened batch slice,
    # then restore the original batch shape.
    per_slice = []
    for block in np.reshape(samples, (-1, n_var, n_obs)):
        per_slice.append(np.cov(block, bias=bias))
    expected = np.reshape(np.stack(per_slice), (*batch_shape, n_var, n_var))
    xp_assert_close(result, xp.asarray(expected))


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
class TestOneHot:
Expand Down