Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 104 additions & 16 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,16 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
def cov(
m: Array,
/,
*,
axis: int = -1,
correction: int | float = 1,
frequency_weights: Array | None = None,
weights: Array | None = None,
xp: ModuleType | None = None,
) -> Array:
"""
Estimate a covariance matrix (or a stack of covariance matrices).

Expand All @@ -92,16 +101,37 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
:math:`x_i` and :math:`x_j`. The element :math:`C_{ii}` is the variance
of :math:`x_i`.

With the exception of supporting batch input, this provides a subset of
the functionality of ``numpy.cov``.
Extends ``numpy.cov`` with support for batch input and array-api
backends. Naming follows the array-api conventions used elsewhere in
this library (``axis``, ``correction``) rather than the numpy spellings
(``rowvar``, ``bias``, ``ddof``); see Notes for the mapping.

Parameters
----------
m : array
An array of shape ``(..., N, M)`` whose innermost two dimensions
contain *M* observations of *N* variables. That is,
each row of `m` represents a variable, and each column a single
observation of all those variables.
contain *M* observations of *N* variables by default. The axis of
observations is controlled by `axis`.
axis : int, optional
Axis of `m` containing the observations. Default: ``-1`` (the last
axis), matching the array-api convention. Use ``axis=-2`` (or ``0``
for 2-D input) to treat each column as a variable, which
corresponds to ``rowvar=False`` in ``numpy.cov``.
correction : int or float, optional
Degrees of freedom correction: normalization divides by
``N - correction`` (for unweighted input). Default: ``1``, which
gives the unbiased estimate (matches ``numpy.cov`` default of
``bias=False``). Set to ``0`` for the biased estimate (``N``
normalization). Corresponds to ``ddof`` in ``numpy.cov`` and to
``correction`` in ``numpy.var``/``std`` and ``torch.cov``.
frequency_weights : array, optional
1-D array of integer frequency weights: the number of times each
observation is repeated. Corresponds to ``fweights`` in
``numpy.cov``/``torch.cov``.
weights : array, optional
1-D array of observation-vector weights (analytic weights). Larger
values mark more important observations. Corresponds to
``aweights`` in ``numpy.cov``/``torch.cov``.
xp : array_namespace, optional
The standard-compatible namespace for `m`. Default: infer.

Expand All @@ -111,6 +141,23 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
An array having shape (..., N, N) whose innermost two dimensions represent
the covariance matrix of the variables.

Notes
-----
Mapping from ``numpy.cov`` to this function::

numpy.cov(m, rowvar=True) -> cov(m, axis=-1) # default
numpy.cov(m, rowvar=False) -> cov(m, axis=-2)
numpy.cov(m, bias=True) -> cov(m, correction=0)
numpy.cov(m, ddof=k) -> cov(m, correction=k)
numpy.cov(m, fweights=f) -> cov(m, frequency_weights=f)
numpy.cov(m, aweights=a) -> cov(m, weights=a)

Unlike ``numpy.cov``, a ``RuntimeWarning`` for non-positive effective
degrees of freedom is only emitted on the unweighted path. The
weighted path omits the check so that lazy backends (e.g. Dask) can
stay lazy end-to-end; choose ``correction`` and weights such that the
effective normalizer is positive.

Examples
--------
>>> import array_api_strict as xp
Expand Down Expand Up @@ -164,16 +211,57 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
if xp is None:
xp = array_namespace(m)

if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_torch_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
) and m.ndim <= 2:
return xp.cov(m)

return _funcs.cov(m, xp=xp)
# Validate axis against m.ndim.
ndim = max(m.ndim, 1)
if not -ndim <= axis < ndim:
msg = f"axis {axis} is out of bounds for array of dimension {m.ndim}"
raise IndexError(msg)

# Normalize: observations on the last axis. After this, every backend
# sees the same convention and we never need to deal with `rowvar`.
if m.ndim >= 2 and axis not in (-1, m.ndim - 1):
m = xp.moveaxis(m, axis, -1)

# `numpy.cov` (and cupy/dask/jax) require integer `ddof`; `torch.cov`
# requires integer `correction`. For non-integer-valued `correction`,
# fall through to the generic implementation.
integer_correction = isinstance(correction, int) or correction.is_integer()
has_weights = frequency_weights is not None or weights is not None

if m.ndim <= 2 and integer_correction:
if is_torch_namespace(xp):
device = get_device(m)
fw = (
None
if frequency_weights is None
else xp.asarray(frequency_weights, device=device)
)
aw = None if weights is None else xp.asarray(weights, device=device)
return xp.cov(m, correction=int(correction), fweights=fw, aweights=aw)
# `dask.array.cov` forces `.compute()` whenever weights are given:
# its internal `if fact <= 0` check on a lazy 0-D scalar triggers
# materialization. Route to the generic impl, which is fully lazy
# because it only does sum/matmul and skips that scalar check.
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or (is_dask_namespace(xp) and not has_weights)
):
return xp.cov(
m,
ddof=int(correction),
fweights=frequency_weights,
aweights=weights,
)

return _funcs.cov(
m,
correction=correction,
frequency_weights=frequency_weights,
weights=weights,
xp=xp,
)


def create_diagonal(
Expand Down
66 changes: 51 additions & 15 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,31 +281,67 @@ def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...
return tuple(out)


def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
def cov(
m: Array,
/,
*,
correction: int | float = 1,
frequency_weights: Array | None = None,
weights: Array | None = None,
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
m = xp.asarray(m, copy=True)
m = xp.asarray(m)
dtype = (
xp.float64 if xp.isdtype(m.dtype, "integral") else xp.result_type(m, xp.float64)
)

m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = xp.mean(m, axis=-1, keepdims=True)
device = _compat.device(m)
fw = (
None
if frequency_weights is None
else xp.astype(xp.asarray(frequency_weights, device=device), dtype)
)
aw = (
None
if weights is None
else xp.astype(xp.asarray(weights, device=device), dtype)
)
if fw is None and aw is None:
w = None
elif fw is None:
w = aw
elif aw is None:
w = fw
else:
w = fw * aw

m_shape = eager_shape(m)
fact = m_shape[-1] - 1

if fact <= 0:
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
fact = 0

m -= avg
m_transpose = xp.matrix_transpose(m)
if xp.isdtype(m_transpose.dtype, "complex floating"):
m_transpose = xp.conj(m_transpose)
c = xp.matmul(m, m_transpose)
c /= fact
if w is None:
avg = xp.mean(m, axis=-1, keepdims=True)
fact = m_shape[-1] - correction
if fact <= 0:
warnings.warn(
"Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2
)
fact = 0
else:
v1 = xp.sum(w, axis=-1)
avg = xp.sum(m * w, axis=-1, keepdims=True) / v1
if aw is None:
fact = v1 - correction
else:
fact = v1 - correction * xp.sum(w * aw, axis=-1) / v1

m_c = m - avg
m_w = m_c if w is None else m_c * w
m_cT = xp.matrix_transpose(m_c)
if xp.isdtype(m_cT.dtype, "complex floating"):
m_cT = xp.conj(m_cT)
c = xp.matmul(m_w, m_cT) / fact
axes = tuple(axis for axis, length in enumerate(c.shape) if length == 1)
return xp.squeeze(c, axis=axes)

Expand Down
91 changes: 91 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,97 @@ def test_batch(self, xp: ModuleType):
ref = np.reshape(np.stack(ref_list), (*batch_shape, n_var, n_var))
xp_assert_close(res, xp.asarray(ref))

def test_correction(self, xp: ModuleType):
    """Integer ``correction`` values match numpy's ``ddof``."""
    rng = np.random.default_rng(20260417)
    sample = rng.random((3, 20))
    for ddof in range(3):
        expected = np.cov(sample, ddof=ddof)
        actual = cov(xp.asarray(sample), correction=ddof)
        xp_assert_close(actual, xp.asarray(expected))

def test_correction_float(self, xp: ModuleType):
    """Non-integer ``correction`` is supported by the generic path."""
    # numpy.cov rejects non-integer ddof, so the reference covariance
    # is assembled by hand from the centered sample.
    rng = np.random.default_rng(20260417)
    sample = rng.random((3, 20))
    n_obs = sample.shape[-1]
    demeaned = sample - sample.mean(axis=-1, keepdims=True)
    expected = demeaned @ demeaned.T / (n_obs - 1.5)
    actual = cov(xp.asarray(sample), correction=1.5)
    xp_assert_close(actual, xp.asarray(expected))

def test_axis(self, xp: ModuleType):
    """``axis=0`` and ``axis=-2`` both match numpy's ``rowvar=False``."""
    rng = np.random.default_rng(20260417)
    sample = rng.random((20, 3))  # observations along the first axis
    expected = np.cov(sample, rowvar=False)
    for ax in (0, -2):  # positive and negative spellings of the same axis
        xp_assert_close(cov(xp.asarray(sample), axis=ax), xp.asarray(expected))

def test_frequency_weights(self, xp: ModuleType):
    """``frequency_weights`` matches numpy's ``fweights``."""
    rng = np.random.default_rng(20260417)
    sample = rng.random((3, 10))
    counts = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
    expected = np.cov(sample, fweights=counts)
    actual = cov(xp.asarray(sample), frequency_weights=xp.asarray(counts))
    xp_assert_close(actual, xp.asarray(expected))

def test_weights(self, xp: ModuleType):
    """``weights`` (analytic weights) matches numpy's ``aweights``."""
    rng = np.random.default_rng(20260417)
    sample = rng.random((3, 10))
    analytic = rng.random(10)
    expected = np.cov(sample, aweights=analytic)
    actual = cov(xp.asarray(sample), weights=xp.asarray(analytic))
    xp_assert_close(actual, xp.asarray(expected))

def test_both_weights(self, xp: ModuleType):
    """Frequency and analytic weights combine as in numpy, for several ddof."""
    rng = np.random.default_rng(20260417)
    sample = rng.random((3, 10))
    counts = np.asarray([1, 2, 1, 3, 1, 2, 1, 1, 2, 1], dtype=np.int64)
    analytic = rng.random(10)
    for ddof in range(3):
        expected = np.cov(sample, ddof=ddof, fweights=counts, aweights=analytic)
        actual = cov(
            xp.asarray(sample),
            correction=ddof,
            frequency_weights=xp.asarray(counts),
            weights=xp.asarray(analytic),
        )
        xp_assert_close(actual, xp.asarray(expected))

def test_batch_with_weights(self, xp: ModuleType):
    """Analytic weights apply per-matrix across leading batch dimensions."""
    rng = np.random.default_rng(20260417)
    batch_shape = (2, 3)
    n_var, n_obs = 3, 15
    sample = rng.random((*batch_shape, n_var, n_obs))
    analytic = rng.random(n_obs)
    actual = cov(xp.asarray(sample), weights=xp.asarray(analytic))
    # Reference: apply numpy.cov to each 2-D slice, then restore the batch.
    flat = np.reshape(sample, (-1, n_var, n_obs))
    expected = np.stack([np.cov(block, aweights=analytic) for block in flat])
    expected = np.reshape(expected, (*batch_shape, n_var, n_var))
    xp_assert_close(actual, xp.asarray(expected))

def test_axis_with_weights(self, xp: ModuleType):
    """``axis=-2`` (observations first) cooperates with both weight kinds.

    Verifies that the internal moveaxis and weight alignment agree with
    numpy's ``rowvar=False`` behavior.
    """
    rng = np.random.default_rng(20260417)
    sample = rng.random((15, 3))  # observations along the first axis
    analytic = rng.random(15)
    counts = np.asarray(
        [1, 2, 1, 3, 1, 2, 1, 1, 2, 1, 1, 1, 2, 1, 1], dtype=np.int64
    )
    expected = np.cov(sample, rowvar=False, fweights=counts, aweights=analytic)
    actual = cov(
        xp.asarray(sample),
        axis=-2,
        frequency_weights=xp.asarray(counts),
        weights=xp.asarray(analytic),
    )
    xp_assert_close(actual, xp.asarray(expected))

def test_axis_out_of_bounds(self, xp: ModuleType):
    """An out-of-range ``axis`` raises ``IndexError`` with the expected message.

    Pinning the message via ``match=`` ensures the failure comes from the
    explicit validation in the delegation layer, not from some unrelated
    ``IndexError`` raised deeper in a backend.
    """
    m = xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    with pytest.raises(IndexError, match="out of bounds"):
        _ = cov(m, axis=5)


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
class TestOneHot:
Expand Down