From 0ac11552990eb55248e1a6da6ee3f1aedc3bafeb Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 17 Apr 2026 12:07:07 +0200 Subject: [PATCH 1/2] Add `bias` keyword argument to cov --- src/array_api_extra/_delegation.py | 13 +++++++++---- src/array_api_extra/_lib/_funcs.py | 4 ++-- tests/test_funcs.py | 25 +++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 46639559..32af74f9 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -81,7 +81,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array return _funcs.atleast_nd(x, ndim=ndim, xp=xp) -def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: +def cov(m: Array, /, *, bias: bool = False, xp: ModuleType | None = None) -> Array: """ Estimate a covariance matrix (or a stack of covariance matrices). @@ -102,6 +102,9 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: contain *M* observations of *N* variables. That is, each row of `m` represents a variable, and each column a single observation of all those variables. + bias : bool, optional + If ``False`` (default), normalize the covariance matrix by ``M - 1`` + giving an unbiased estimate. If ``True``, normalize by ``M``. xp : array_namespace, optional The standard-compatible namespace for `m`. Default: infer. @@ -164,16 +167,18 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: if xp is None: xp = array_namespace(m) + if is_torch_namespace(xp) and m.ndim <= 2: + return xp.cov(m, correction=int(not bias)) + if ( is_numpy_namespace(xp) or is_cupy_namespace(xp) - or is_torch_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp) ) and m.ndim <= 2: - return xp.cov(m) + return xp.cov(m, bias=bias) - return _funcs.cov(m, xp=xp) + return _funcs.cov(m, bias=bias, xp=xp) def create_diagonal( diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..e52ac435 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -281,7 +281,7 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... return tuple(out) -def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 +def cov(m: Array, /, *, bias: bool = False, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" m = xp.asarray(m, copy=True) dtype = ( @@ -294,7 +294,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 avg = xp.mean(m, axis=-1, keepdims=True) m_shape = eager_shape(m) - fact = m_shape[-1] - 1 + fact = m_shape[-1] - (0 if bias else 1) if fact <= 0: warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..95768a2a 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -608,6 +608,31 @@ def test_batch(self, xp: ModuleType): ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) xp_assert_close(res, xp.asarray(ref)) + @pytest.mark.parametrize("bias", [True, False, 0, 1]) + def test_bias(self, xp: ModuleType, bias: bool): + x = np.array([-2.1, -1, 4.3]) + y = np.array([3, 1.1, 0.12]) + X = np.stack((x, y), axis=0) + ref = np.cov(X, bias=bias) + xp_assert_close( + cov(xp.asarray(X, dtype=xp.float64), bias=bias), + xp.asarray(ref, dtype=xp.float64), + rtol=1e-6, + ) + + @pytest.mark.parametrize("bias", [True, False, 0, 1]) + def test_bias_batch(self, xp: ModuleType, bias: bool): + rng = np.random.default_rng(8847643423) + batch_shape = (3, 4) + n_var, n_obs = 3, 20 + m = rng.random((*batch_shape, n_var, n_obs)) + res = cov(xp.asarray(m), bias=bias) + ref_list = [ + np.cov(m_, bias=bias) for m_ in np.reshape(m, (-1, n_var, n_obs)) + ] + ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) + xp_assert_close(res, xp.asarray(ref)) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) class TestOneHot: From 23077c44e72fc14cf1076ce2796658f9ca6eb6da Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 17 Apr 2026 13:23:20 +0200 Subject: [PATCH 2/2] Fix up lint --- src/array_api_extra/_lib/_funcs.py | 4 +++- tests/test_funcs.py | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index e52ac435..9716a19a 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -281,7 +281,9 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... return tuple(out) -def cov(m: Array, /, *, bias: bool = False, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 +def cov( # numpydoc ignore=PR01,RT01 + m: Array, /, *, bias: bool = False, xp: ModuleType +) -> Array: """See docstring in array_api_extra._delegation.""" m = xp.asarray(m, copy=True) dtype = ( diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 95768a2a..2025fc91 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -627,9 +627,7 @@ def test_bias_batch(self, xp: ModuleType, bias: bool): n_var, n_obs = 3, 20 m = rng.random((*batch_shape, n_var, n_obs)) res = cov(xp.asarray(m), bias=bias) - ref_list = [ - np.cov(m_, bias=bias) for m_ in np.reshape(m, (-1, n_var, n_obs)) - ] + ref_list = [np.cov(m_, bias=bias) for m_ in np.reshape(m, (-1, n_var, n_obs))] ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) xp_assert_close(res, xp.asarray(ref))