diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 46639559..31ace576 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -81,7 +81,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array return _funcs.atleast_nd(x, ndim=ndim, xp=xp) -def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: +def cov( + m: Array, + /, + *, + axis: int = -1, + correction: int | float = 1, + frequency_weights: Array | None = None, + weights: Array | None = None, + xp: ModuleType | None = None, +) -> Array: """ Estimate a covariance matrix (or a stack of covariance matrices). @@ -92,16 +101,37 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: :math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance of :math:`x_i`. - With the exception of supporting batch input, this provides a subset of - the functionality of ``numpy.cov``. + Extends ``numpy.cov`` with support for batch input and array-api + backends. Naming follows the array-api conventions used elsewhere in + this library (``axis``, ``correction``) rather than the numpy spellings + (``rowvar``, ``bias``, ``ddof``); see Notes for the mapping. Parameters ---------- m : array An array of shape ``(..., N, M)`` whose innermost two dimensions - contain *M* observations of *N* variables. That is, - each row of `m` represents a variable, and each column a single - observation of all those variables. + contain *M* observations of *N* variables by default. The axis of + observations is controlled by `axis`. + axis : int, optional + Axis of `m` containing the observations. Default: ``-1`` (the last + axis), matching the array-api convention. Use ``axis=-2`` (or ``0`` + for 2-D input) to treat each column as a variable, which + corresponds to ``rowvar=False`` in ``numpy.cov``. + correction : int or float, optional + Degrees of freedom correction: normalization divides by + ``N - correction`` (for unweighted input). Default: ``1``, which + gives the unbiased estimate (matches ``numpy.cov`` default of + ``bias=False``). Set to ``0`` for the biased estimate (``N`` + normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to + ``correction`` in ``numpy.var``/``std`` and ``torch.cov``. + frequency_weights : array, optional + 1-D array of integer frequency weights: the number of times each + observation is repeated. Corresponds to ``fweights`` in + ``numpy.cov``/``torch.cov``. + weights : array, optional + 1-D array of observation-vector weights (analytic weights). Larger + values mark more important observations. Corresponds to + ``aweights`` in ``numpy.cov``/``torch.cov``. xp : array_namespace, optional The standard-compatible namespace for `m`. Default: infer. @@ -111,6 +141,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: An array having shape (..., N, N) whose innermost two dimensions represent the covariance matrix of the variables. + Notes + ----- + Mapping from ``numpy.cov`` to this function:: + + numpy.cov(m, rowvar=True) -> cov(m, axis=-1) # default + numpy.cov(m, rowvar=False) -> cov(m, axis=-2) + numpy.cov(m, bias=True) -> cov(m, correction=0) + numpy.cov(m, ddof=k) -> cov(m, correction=k) + numpy.cov(m, fweights=f) -> cov(m, frequency_weights=f) + numpy.cov(m, aweights=a) -> cov(m, weights=a) + + Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective + degrees of freedom is only emitted on the unweighted path. The + weighted path omits the check so that lazy backends (e.g. Dask) can + stay lazy end-to-end; choose ``correction`` and weights such that the + effective normalizer is positive. + Examples -------- >>> import array_api_strict as xp @@ -164,16 +211,57 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: if xp is None: xp = array_namespace(m) - if ( - is_numpy_namespace(xp) - or is_cupy_namespace(xp) - or is_torch_namespace(xp) - or is_dask_namespace(xp) - or is_jax_namespace(xp) - ) and m.ndim <= 2: - return xp.cov(m) - - return _funcs.cov(m, xp=xp) + # Validate axis against m.ndim. + ndim = max(m.ndim, 1) + if not -ndim <= axis < ndim: + msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}" + raise IndexError(msg) + + # Normalize: observations on the last axis. After this, every backend + # sees the same convention and we never need to deal with `rowvar`. + if m.ndim >= 2 and axis not in (-1, m.ndim - 1): + m = xp.moveaxis(m, axis, -1) + + # `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov` + # requires integer `correction`. For non-integer-valued `correction`, + # fall through to the generic implementation. + integer_correction = isinstance(correction, int) or correction.is_integer() + has_weights = frequency_weights is not None or weights is not None + + if m.ndim <= 2 and integer_correction: + if is_torch_namespace(xp): + device = get_device(m) + fw = ( + None + if frequency_weights is None + else xp.asarray(frequency_weights, device=device) + ) + aw = None if weights is None else xp.asarray(weights, device=device) + return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw) + # `dask.array.cov` forces `.compute()` whenever weights are given: + # its internal `if fact <= 0` check on a lazy 0-D scalar triggers + # materialization. Route to the generic impl, which is fully lazy + # because it only does sum/matmul and skips that scalar check. + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or (is_dask_namespace(xp) and not has_weights) + ): + return xp.cov( + m, + ddof=int(correction), + fweights=frequency_weights, + aweights=weights, + ) + + return _funcs.cov( + m, + correction=correction, + frequency_weights=frequency_weights, + weights=weights, + xp=xp, + ) def create_diagonal( diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 97904ddb..4f9309ec 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -281,9 +281,17 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ... return tuple(out) -def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 +def cov( + m: Array, + /, + *, + correction: int | float = 1, + frequency_weights: Array | None = None, + weights: Array | None = None, + xp: ModuleType, +) -> Array: # numpydoc ignore=PR01,RT01 """See docstring in array_api_extra._delegation.""" - m = xp.asarray(m, copy=True) + m = xp.asarray(m) dtype = ( xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64) ) @@ -291,21 +299,49 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01 m = atleast_nd(m, ndim=2, xp=xp) m = xp.astype(m, dtype) - avg = xp.mean(m, axis=-1, keepdims=True) + device = _compat.device(m) + fw = ( + None + if frequency_weights is None + else xp.astype(xp.asarray(frequency_weights, device=device), dtype) + ) + aw = ( + None + if weights is None + else xp.astype(xp.asarray(weights, device=device), dtype) + ) + if fw is None and aw is None: + w = None + elif fw is None: + w = aw + elif aw is None: + w = fw + else: + w = fw * aw m_shape = eager_shape(m) - fact = m_shape[-1] - 1 - - if fact <= 0: - warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2) - fact = 0 - - m -= avg - m_transpose = xp.matrix_transpose(m) - if xp.isdtype(m_transpose.dtype, "complex floating"): - m_transpose = xp.conj(m_transpose) - c = xp.matmul(m, m_transpose) - c /= fact + if w is None: + avg = xp.mean(m, axis=-1, keepdims=True) + fact = m_shape[-1] - correction + if fact <= 0: + warnings.warn( + "Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2 + ) + fact = 0 + else: + v1 = xp.sum(w, axis=-1) + avg = xp.sum(m * w, axis=-1, keepdims=True) / v1 + if aw is None: + fact = v1 - correction + else: + fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1 + + m_c = m - avg + m_w = m_c if w is None else m_c * w + m_cT = xp.matrix_transpose(m_c) + if xp.isdtype(m_cT.dtype, "complex floating"): + m_cT = xp.conj(m_cT) + c = xp.matmul(m_w, m_cT) / fact axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1) return xp.squeeze(c, axis=axes) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6a11e059..a3cf59a2 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -608,6 +608,97 @@ def test_batch(self, xp: ModuleType): ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) xp_assert_close(res, xp.asarray(ref)) + def test_correction(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 20)) + for correction in (0, 1, 2): + ref = np.cov(m, ddof=correction) + res = cov(xp.asarray(m), correction=correction) + xp_assert_close(res, xp.asarray(ref)) + + def test_correction_float(self, xp: ModuleType): + # Float correction: reference computed by hand (numpy.cov rejects + # non-integer ddof; our generic path supports it). + rng = np.random.default_rng(20260417) + m = rng.random((3, 20)) + n = m.shape[-1] + centered = m - m.mean(axis=-1, keepdims=True) + ref = centered @ centered.T / (n - 1.5) + res = cov(xp.asarray(m), correction=1.5) + xp_assert_close(res, xp.asarray(ref)) + + def test_axis(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((20, 3)) # observations on axis 0 + ref = np.cov(m, rowvar=False) + res = cov(xp.asarray(m), axis=0) + xp_assert_close(res, xp.asarray(ref)) + res_neg = cov(xp.asarray(m), axis=-2) + xp_assert_close(res_neg, xp.asarray(ref)) + + def test_frequency_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 10)) + fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64) + ref = np.cov(m, fweights=fw) + res = cov(xp.asarray(m), frequency_weights=xp.asarray(fw)) + xp_assert_close(res, xp.asarray(ref)) + + def test_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 10)) + aw = rng.random(10) + ref = np.cov(m, aweights=aw) + res = cov(xp.asarray(m), weights=xp.asarray(aw)) + xp_assert_close(res, xp.asarray(ref)) + + def test_both_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + m = rng.random((3, 10)) + fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64) + aw = rng.random(10) + for correction in (0, 1, 2): + ref = np.cov(m, ddof=correction, fweights=fw, aweights=aw) + res = cov( + xp.asarray(m), + correction=correction, + frequency_weights=xp.asarray(fw), + weights=xp.asarray(aw), + ) + xp_assert_close(res, xp.asarray(ref)) + + def test_batch_with_weights(self, xp: ModuleType): + rng = np.random.default_rng(20260417) + batch_shape = (2, 3) + n_var, n_obs = 3, 15 + m = rng.random((*batch_shape, n_var, n_obs)) + aw = rng.random(n_obs) + res = cov(xp.asarray(m), weights=xp.asarray(aw)) + ref_list = [np.cov(m_, aweights=aw) for m_ in np.reshape(m, (-1, n_var, n_obs))] + ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var)) + xp_assert_close(res, xp.asarray(ref)) + + def test_axis_with_weights(self, xp: ModuleType): + # axis=-2 (observations on first of 2D) combined with weights: + # verifies that moveaxis and weight alignment cooperate. + rng = np.random.default_rng(20260417) + m = rng.random((15, 3)) # observations on axis 0 + aw = rng.random(15) + fw = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1], dtype=np.int64) + ref = np.cov(m, rowvar=False, fweights=fw, aweights=aw) + res = cov( + xp.asarray(m), + axis=-2, + frequency_weights=xp.asarray(fw), + weights=xp.asarray(aw), + ) + xp_assert_close(res, xp.asarray(ref)) + + def test_axis_out_of_bounds(self, xp: ModuleType): + m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + with pytest.raises(IndexError): + _ = cov(m, axis=5) + @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False) class TestOneHot: