Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changes/2929.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix equality comparison of `ArrayV2Metadata` and `ArrayV3Metadata` objects with a
`NaN` fill value. Such objects are now compared and hashed by their JSON-serialized
form, so two otherwise-identical metadata objects with a `NaN` (or infinite) fill value
compare equal and produce the same hash.
15 changes: 14 additions & 1 deletion src/zarr/core/metadata/v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
import warnings
from collections.abc import Iterable, Sequence
from functools import cached_property
Expand All @@ -25,7 +26,6 @@
ZDType,
)

import json
from dataclasses import dataclass, field, fields, replace

import numpy as np
Expand Down Expand Up @@ -239,6 +239,19 @@ def to_dict(self) -> dict[str, JSON]:

return zarray_dict

def __eq__(self, other: object) -> bool:
    """Compare two v2 array metadata objects through their ``to_dict()`` output.

    The field-by-field comparison a dataclass generates by default is wrong when
    ``fill_value`` is NaN, because NaN != NaN under IEEE 754. Comparing the
    JSON-serialized form instead treats matching NaN (and inf) fill values as
    equal. See issue #2929.

    Returns ``NotImplemented`` when ``other`` is not an ``ArrayV2Metadata``, so
    Python can fall back to the reflected comparison.
    """
    if isinstance(other, ArrayV2Metadata):
        return self.to_dict() == other.to_dict()
    return NotImplemented

def __hash__(self) -> int:
    """Hash the metadata through its key-sorted JSON serialization.

    ``__eq__`` compares ``to_dict()`` output, so the hash is derived from the same
    data: equal metadata must hash equally, which a field-based hash violates for
    a NaN ``fill_value``.
    """
    # sort_keys makes the serialized text independent of dict insertion order.
    serialized = json.dumps(self.to_dict(), sort_keys=True)
    return hash(serialized)

def get_chunk_spec(
self, _chunk_coords: tuple[int, ...], array_config: ArrayConfig, prototype: BufferPrototype
) -> ArraySpec:
Expand Down
13 changes: 13 additions & 0 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,19 @@ def to_dict(self) -> dict[str, JSON]:
out_dict["data_type"] = dtype_meta.to_json(zarr_format=3) # type: ignore[unreachable]
return out_dict

def __eq__(self, other: object) -> bool:
    """Compare two v3 array metadata objects through their ``to_dict()`` output.

    The field-by-field comparison a dataclass generates by default is wrong when
    ``fill_value`` is NaN, because NaN != NaN under IEEE 754. Comparing the
    JSON-serialized form instead treats matching NaN (and inf) fill values as
    equal. See issue #2929.

    Returns ``NotImplemented`` when ``other`` is not an ``ArrayV3Metadata``, so
    Python can fall back to the reflected comparison.
    """
    if isinstance(other, ArrayV3Metadata):
        return self.to_dict() == other.to_dict()
    return NotImplemented

def __hash__(self) -> int:
    """Hash the metadata through its key-sorted JSON serialization.

    ``__eq__`` compares ``to_dict()`` output, so the hash is derived from the same
    data: equal metadata must hash equally, which a field-based hash violates for
    a NaN ``fill_value``.
    """
    # sort_keys makes the serialized text independent of dict insertion order.
    serialized = json.dumps(self.to_dict(), sort_keys=True)
    return hash(serialized)

def update_shape(self, shape: tuple[int, ...]) -> Self:
chunk_grid = self.chunk_grid
if isinstance(chunk_grid, RectilinearChunkGridMetadata):
Expand Down
54 changes: 54 additions & 0 deletions tests/test_metadata/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,60 @@ def test_from_dict_extra_fields() -> None:
assert result == expected


def test_eq_nan_fill_value() -> None:
    """Independently built v2 metadata with a NaN fill_value must compare equal.

    Under IEEE 754 NaN is never equal to itself, so the default dataclass
    __eq__ would report these otherwise-identical objects as different.
    Metadata equality has to treat matching NaN fill values as equal
    (see issue #2929).
    """
    # Each pass through the generator creates a fresh metadata object with its
    # own NaN scalar, so the comparison cannot succeed through object identity.
    first, second = (
        ArrayV2Metadata(
            shape=(8,), dtype=Float64(), chunks=(8,), fill_value=np.float64("nan"), order="C"
        )
        for _ in range(2)
    )
    assert first == second


def test_eq_distinct_fill_value() -> None:
    """v2 metadata whose only difference is the fill_value must not compare equal."""

    def build(fill_value: float) -> ArrayV2Metadata:
        # Everything except the fill value is held constant.
        return ArrayV2Metadata(
            shape=(8,), dtype=Float64(), chunks=(8,), fill_value=fill_value, order="C"
        )

    assert build(0.0) != build(1.0)


@pytest.mark.parametrize("fill_value", [np.float64("inf"), np.float64("-inf")])
def test_eq_inf_fill_value(fill_value: np.float64) -> None:
    """Independently built v2 metadata sharing an infinite fill_value compare equal."""
    first, second = (
        ArrayV2Metadata(
            shape=(8,), dtype=Float64(), chunks=(8,), fill_value=fill_value, order="C"
        )
        for _ in range(2)
    )
    assert first == second


def test_hash_consistent_with_eq_nan_fill_value() -> None:
    """v2 metadata that compare equal with a NaN fill_value also hash equal.

    NaN hashes by identity, so hashing the fields directly would break the
    ``a == b implies hash(a) == hash(b)`` invariant for objects that the
    to_dict-based __eq__ considers equal.
    """
    first, second = (
        ArrayV2Metadata(
            shape=(8,), dtype=Float64(), chunks=(8,), fill_value=np.float64("nan"), order="C"
        )
        for _ in range(2)
    )
    # Equality is checked first: the hash invariant only applies to equal objects.
    assert first == second
    assert hash(first) == hash(second)


def test_eq_non_metadata() -> None:
    """v2 metadata compares unequal to an arbitrary non-metadata object without raising."""
    unrelated = object()
    metadata = ArrayV2Metadata(
        shape=(8,), dtype=Float64(), chunks=(8,), fill_value=0.0, order="C"
    )
    assert metadata != unrelated


def test_zstd_checksum() -> None:
compressor_config: dict[str, JSON] = {"id": "zstd", "level": 5, "checksum": False}
arr = zarr.create_array(
Expand Down
71 changes: 70 additions & 1 deletion tests/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from zarr.core.buffer import default_buffer_prototype
from zarr.core.chunk_grids import is_regular_1d, is_regular_nd
from zarr.core.config import config
from zarr.core.dtype import UInt8
from zarr.core.dtype import Float64, UInt8
from zarr.core.group import GroupMetadata, parse_node_type
from zarr.core.metadata.v2 import ArrayV2Metadata
from zarr.core.metadata.v3 import (
ARRAY_METADATA_KEYS,
ArrayMetadataJSON_V3,
Expand Down Expand Up @@ -335,6 +336,74 @@ def test_init_extra_fields_collision() -> None:
)


# ---------------------------------------------------------------------------
# Equality
# ---------------------------------------------------------------------------


def test_eq_nan_fill_value() -> None:
    """Independently built v3 metadata with a NaN fill_value must compare equal.

    Under IEEE 754 NaN is never equal to itself, so the default dataclass
    __eq__ would report these otherwise-identical objects as different.
    Metadata equality has to treat matching NaN fill values as equal
    (see issue #2929).
    """
    # Each pass through the generator parses a fresh metadata document, so the
    # two objects share no state.
    first, second = (
        ArrayV3Metadata.from_dict(
            minimal_metadata_dict_v3(data_type="float64", fill_value="NaN")  # type: ignore[arg-type]
        )
        for _ in range(2)
    )
    assert first == second


def test_eq_distinct_fill_value() -> None:
    """v3 metadata whose only difference is the fill_value must not compare equal."""
    zero_filled, one_filled = (
        ArrayV3Metadata.from_dict(
            minimal_metadata_dict_v3(data_type="float64", fill_value=value)  # type: ignore[arg-type]
        )
        for value in (0.0, 1.0)
    )
    assert zero_filled != one_filled


@pytest.mark.parametrize("fill_value", ["Infinity", "-Infinity"])
def test_eq_inf_fill_value(fill_value: str) -> None:
    """Independently built v3 metadata sharing an infinite fill_value compare equal."""
    first, second = (
        ArrayV3Metadata.from_dict(
            minimal_metadata_dict_v3(data_type="float64", fill_value=fill_value)  # type: ignore[arg-type]
        )
        for _ in range(2)
    )
    assert first == second


def test_hash_consistent_with_eq_nan_fill_value() -> None:
    """v3 metadata that compare equal with a NaN fill_value also hash equal.

    NaN hashes by identity, so hashing the fields directly would break the
    ``a == b implies hash(a) == hash(b)`` invariant for objects that the
    to_dict-based __eq__ considers equal.
    """
    first, second = (
        ArrayV3Metadata.from_dict(
            minimal_metadata_dict_v3(data_type="float64", fill_value="NaN")  # type: ignore[arg-type]
        )
        for _ in range(2)
    )
    # Equality is checked first: the hash invariant only applies to equal objects.
    assert first == second
    assert hash(first) == hash(second)


def test_eq_non_metadata() -> None:
    """v3 metadata compares unequal to an arbitrary non-metadata object without raising."""
    unrelated = object()
    metadata = ArrayV3Metadata.from_dict(
        minimal_metadata_dict_v3(data_type="float64", fill_value=0.0)  # type: ignore[arg-type]
    )
    assert metadata != unrelated


def test_eq_across_zarr_formats() -> None:
    """v2 and v3 metadata for the same array never compare equal, in either order.

    Each __eq__ accepts only its own concrete metadata type and returns
    NotImplemented for anything else, so the two format versions stay unequal
    even when they describe the same array.
    """
    v2_meta = ArrayV2Metadata(
        shape=(4, 4), dtype=Float64(), chunks=(4, 4), fill_value=0.0, order="C"
    )
    v3_meta = ArrayV3Metadata.from_dict(
        minimal_metadata_dict_v3(data_type="float64", fill_value=0.0)  # type: ignore[arg-type]
    )
    # Check both operand orders so each class's __eq__ is the first one consulted.
    assert v2_meta != v3_meta
    assert v3_meta != v2_meta


# ---------------------------------------------------------------------------
# JSON indent
# ---------------------------------------------------------------------------
Expand Down
Loading