diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 30e2f1ef..a5cd808f 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -77,10 +77,11 @@ def _check_ns_shape_dtype( if check_shape: msg = f"shapes do not match: {actual_shape} != f{desired_shape}" assert actual_shape == desired_shape, msg - else: + elif desired.ndim > 0: # Ignore shape, but check flattened size. This is normally done by # np.testing.assert_array_equal etc even when strict=False, but not for # non-materializable arrays. + # This check excludes 0d arrays as they are special-cased in NumPy. actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType] desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType] msg = f"sizes do not match: {actual_size} != f{desired_size}"