From c11cc44ea53dd8334ea433b9c7a2d40c6e4ecaf2 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 15:17:48 -0400 Subject: [PATCH 1/5] Document existing TurboQuant types Doc-comment pass over the parts of vortex-turboquant that don't depend on the in-flight inv_direction_norms work: SplitMix64, the vector module, DecodeInputs's pre-existing fields, TurboQuant + TurboQuantMetadata + their fields, the TurboQuantMetadataProto wire-format note, validate_tq_metadata, validate_tq_storage_dtype, and the TurboQuantParsedStorage struct + existing fields. No semantic change. Signed-off-by: Connor Tsui --- vortex-turboquant/src/scalar_fns/decode.rs | 10 ++++++++++ vortex-turboquant/src/sorf/splitmix64.rs | 4 +++- vortex-turboquant/src/vector/mod.rs | 8 ++++++++ vortex-turboquant/src/vector/storage.rs | 17 +++++++++++------ vortex-turboquant/src/vtable.rs | 18 +++++++++++++----- 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 22e4f2560b8..6791a1aef61 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -203,11 +203,21 @@ fn build_empty_vector( }) } +/// Borrowed bundle of the per-array decode inputs passed to the typed inner loop. +/// +/// Packaged as a struct rather than positional arguments because `decode_typed` runs through +/// [`vortex_array::match_each_float_ptype!`] which expands once per supported element ptype. +/// Each expansion takes the same set of inputs, and the struct keeps the call site short. struct DecodeInputs<'a> { + /// TurboQuant metadata recovered from the input extension dtype. metadata: &'a TurboQuantMetadata, + /// SORF transform reconstructed from `metadata.seed` and `metadata.num_rounds`. sorf_matrix: &'a SorfMatrix, + /// Centroid codebook for `(padded_dim, bit_width)`, in f32. centroids: &'a [f32], + /// Per-row stored L2 norm of the original input vector, in the element ptype. norms: &'a PrimitiveArray, + /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes. codes: &'a PrimitiveArray, } diff --git a/vortex-turboquant/src/sorf/splitmix64.rs b/vortex-turboquant/src/sorf/splitmix64.rs index 1233e4dc7ee..fc3f9073ced 100644 --- a/vortex-turboquant/src/sorf/splitmix64.rs +++ b/vortex-turboquant/src/sorf/splitmix64.rs @@ -19,7 +19,9 @@ const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9; /// Second SplitMix64 mixing multiplier from the reference implementation. const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB; -/// Frozen local SplitMix64 stream used to define SORF sign diagonals. +/// Frozen local SplitMix64 stream used to define SORF sign diagonals. Bit-identical to the +/// reference implementation linked at the module top, which makes the sign stream part of the +/// encoding's wire contract. pub(crate) struct SplitMix64 { state: u64, } diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs index 58c4271a398..f4fe8726103 100644 --- a/vortex-turboquant/src/vector/mod.rs +++ b/vortex-turboquant/src/vector/mod.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +//! Vector-side helpers: normalization, quantization, and physical storage layout. + pub(crate) mod normalize; pub(crate) mod quantize; pub(crate) mod storage; @@ -9,6 +11,12 @@ use vortex_error::VortexResult; use vortex_error::vortex_err; /// Compute the padded SORF dimension for an original vector dimension. +/// +/// The SORF transform requires a power-of-two width, so non-power-of-two input dimensions are +/// padded with zeros up to the next power of two. The padded dimension is stored implicitly via +/// [`TurboQuantMetadata::dimensions`](crate::TurboQuantMetadata) plus the codes child's +/// `FixedSizeList` width and recovered at decode time via this function. Returns an error when +/// the next power of two overflows the input integer type. pub(crate) fn tq_padded_dim(dimensions: u32) -> VortexResult { let padded_dim = dimensions .checked_next_power_of_two() diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index 016eff3b0dc..d1b4f06cc05 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -17,9 +17,9 @@ //! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the //! field-level validity records that those rows were not quantized. //! -//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than the -//! struct validity, for example after a generic mask only updates the struct validity, but each -//! child must be valid wherever the struct row is valid. +//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than +//! the struct validity (for example after a generic mask only updates the struct validity), but +//! each child must be valid wherever the struct row is valid. use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; @@ -48,14 +48,19 @@ pub(crate) const NORMS_FIELD: &str = "norms"; /// Name of the stored quantized-code child. pub(crate) const CODES_FIELD: &str = "codes"; -/// Parsed TurboQuant storage arrays. -/// -/// We use this as a helper struct for working with a TurboQuant extension array. +/// Executed storage children of a TurboQuant extension array plus the authoritative outer +/// struct validity. Every child is row-aligned to `len` and every child's validity covers +/// `vector_validity`. pub(crate) struct TurboQuantParsedStorage { + /// Metadata recovered from the input extension dtype. pub(crate) metadata: TurboQuantMetadata, + /// Authoritative row validity for the quantized vectors, taken from the outer struct. pub(crate) vector_validity: Validity, + /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. pub(crate) norms: PrimitiveArray, + /// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long. pub(crate) codes: PrimitiveArray, + /// Row count. pub(crate) len: usize, } diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs index 51dd0933ca6..854bcee6c70 100644 --- a/vortex-turboquant/src/vtable.rs +++ b/vortex-turboquant/src/vtable.rs @@ -26,18 +26,20 @@ use crate::vector::storage::CODES_FIELD; use crate::vector::storage::NORMS_FIELD; use crate::vector::tq_padded_dim; -/// TurboQuant logical extension type. +/// TurboQuant logical extension type. Per-array configuration lives in [`TurboQuantMetadata`]. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct TurboQuant; -/// Serialized metadata for a TurboQuant extension array. +/// Serialized metadata for a TurboQuant extension array. The fields together suffice to +/// reconstruct the SORF transform and centroid codebook at decode time. #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct TurboQuantMetadata { - /// Original vector element type and stored norm type. + /// Original vector element ptype and stored row-norm ptype. Restricted to `f16` / `f32` / + /// `f64`. pub element_ptype: PType, - /// Original vector dimension before SORF padding. + /// Original vector dimension before SORF padding to the next power of two. pub dimensions: u32, - /// Bits per coordinate in the scalar quantizer codebook. + /// Bits per coordinate in the scalar quantizer codebook (`1..=8`). pub bit_width: u8, /// Seed used to derive the deterministic SORF transform. pub seed: u64, @@ -106,6 +108,8 @@ impl ExtVTable for TurboQuant { } } +/// Wire-format representation of [`TurboQuantMetadata`]. Field tags MUST NOT change once +/// shipped; new fields must use unused tags and remain optional. #[derive(Clone, PartialEq, Message)] struct TurboQuantMetadataProto { #[prost(enumeration = "PType", tag = "1")] @@ -158,6 +162,8 @@ pub(crate) fn tq_storage_dtype( )) } +/// Validate [`TurboQuantMetadata`] invariants. Called on both serialize and deserialize so a +/// corrupted on-disk metadata block errors out rather than decoding into nonsense. fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> { vortex_ensure!( metadata.dimensions >= MIN_DIMENSION, @@ -175,6 +181,8 @@ fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> { TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds).map(|_| ()) } +/// Validate that `dtype` matches the storage shape produced by [`tq_storage_dtype`] for +/// `metadata`. Called from [`TurboQuant::validate_dtype`]. fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> VortexResult<()> { let DType::Struct(fields, _) = dtype else { vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); From 65a912e1de5a110e1c4ca48bd180d5b4a816d939 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 16:54:32 -0400 Subject: [PATCH 2/5] Add TurboQuant L2Norm decode readthrough with norm-preserving decode Register a session-scoped execute-parent kernel that intercepts `L2Norm(TQDecode(_))` and returns the stored `norms` child directly, skipping the inverse SORF, dimension truncation, and per-row rescaling that `TQDecode` would otherwise run. Make decode norm-preserving in flight: after the inverse SORF, compute the L2 norm of the decoded-direction prefix, take its reciprocal (guarded by `is_normal` so denormal cancellation produces an all-zero row), and multiply through alongside the stored row norm. This keeps the kernel's fast path equal to the canonical slow path within floating-point precision: `|TQDecode(x)| == stored_norm[x]` by construction, with no extra storage child. Reject non-finite input L2 norms at encode so a finite row whose squared sum overflows f32 (or contains a `NaN`) cannot produce a `+inf` stored norm that would silently disagree with the in-flight correction at decode time. Add regression coverage for the kernel (bit-exact correctness, null/masked validity, parent-nullability coercion, cross-check against canonical at dim 128/129/257), the encode finite-norm guard, the decode L2-norm preservation invariant across (f16, f32, f64) x (128, 129, 257) dimensions, and the file-roundtrip invariant. Collapse the test-session helpers down to a single `test_session` that initializes both `vortex_tensor` and `vortex_turboquant`. Signed-off-by: Connor Tsui --- vortex-turboquant/Cargo.toml | 1 + vortex-turboquant/src/lib.rs | 4 + .../src/scalar_fns/compute/l2_norm.rs | 87 ++++++ .../src/scalar_fns/compute/mod.rs | 22 ++ vortex-turboquant/src/scalar_fns/decode.rs | 23 +- vortex-turboquant/src/scalar_fns/mod.rs | 1 + vortex-turboquant/src/tests/encode_decode.rs | 147 ++++++++++ vortex-turboquant/src/tests/file.rs | 39 +++ vortex-turboquant/src/tests/kernels.rs | 250 ++++++++++++++++++ vortex-turboquant/src/tests/mod.rs | 6 + vortex-turboquant/src/tests/parity.rs | 51 +++- vortex-turboquant/src/vector/normalize.rs | 17 ++ vortex-turboquant/src/vector/storage.rs | 48 ++++ 13 files changed, 690 insertions(+), 6 deletions(-) create mode 100644 vortex-turboquant/src/scalar_fns/compute/l2_norm.rs create mode 100644 vortex-turboquant/src/scalar_fns/compute/mod.rs create mode 100644 vortex-turboquant/src/tests/kernels.rs diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index ab3f63583d3..708abf5948c 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -32,6 +32,7 @@ vortex-utils = { workspace = true, features = ["dashmap"] } divan = { workspace = true } rand = { workspace = true } rstest = { workspace = true } +vortex-array = { workspace = true, features = ["_test-harness"] } vortex-file = { workspace = true } vortex-io = { workspace = true } vortex-layout = { workspace = true } diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 7aeb60368dd..2d6b4cac9b0 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -48,6 +48,8 @@ //! than quantized. //! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. //! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. +//! - `scalar_fns/compute/`: session-scoped optimizer kernels that intercept canonical scalar +//! functions over TurboQuant inputs (currently `L2Norm(TQDecode(_))`). //! //! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL //! residual correction for unbiased inner-product estimation, and it still uses internal @@ -75,6 +77,8 @@ pub fn initialize(session: &vortex_session::VortexSession) { session.scalar_fns().register(TQEncode); session.scalar_fns().register(TQDecode); + + scalar_fns::compute::register_kernels(session); } #[cfg(test)] diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs new file mode 100644 index 00000000000..7d3eaa6d937 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! `L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the +//! stored per-row norms directly instead of decoding and recomputing. + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFn; +use vortex_array::arrays::scalar_fn::ExactScalarFn; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::dtype::Nullability; +use vortex_array::optimizer::kernels::ArrayKernelsExt; +use vortex_array::optimizer::kernels::ExecuteParentFn; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_session::VortexSession; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; + +use crate::TQDecode; +use crate::vector::storage::parse_storage_norms_only; + +/// Register the `L2Norm(TQDecode(_))` execute-parent kernel on the session. +pub(super) fn register(session: &VortexSession) { + session.kernels().register_execute_parent( + L2Norm.id(), + TQDecode.id(), + &[l2_norm_tq_decode_execute_parent as ExecuteParentFn], + ); +} + +/// Intercepts `L2Norm(TQDecode(tq_arr))` and returns the stored TurboQuant `norms` field. +/// +/// Semantically valid because [`TQDecode`] rescales each lossy quantized direction in flight +/// to unit norm before re-applying the stored row norm, so decoded rows preserve the stored +/// L2 norm to floating-point precision. Returning the stored field directly avoids the +/// inverse SORF transform, the per-row reciprocal, and the dimension truncation that the +/// canonical `L2Norm(execute(TQDecode))` path would otherwise run. The kernel returns +/// `Ok(None)` for any non-matching parent / child pair so the canonical path runs unchanged. +/// +/// The result's nullability is coerced to the parent's expected dtype because the stored +/// `norms` child's validity may be wider than the outer struct (a shape +/// [`parse_storage_norms_only`] accepts). +fn l2_norm_tq_decode_execute_parent( + child: &ArrayRef, + parent: &ArrayRef, + _child_idx: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult> { + if !parent.is::>() { + return Ok(None); + } + if !child.is::>() { + return Ok(None); + } + + let tq_array = child.as_::().child_at(0).clone(); + let parsed = parse_storage_norms_only(tq_array, ctx)?; + + let norms_validity = match parent.dtype().nullability() { + Nullability::NonNullable => Validity::NonNullable, + Nullability::Nullable => parsed.vector_validity, + }; + let norms = PrimitiveArray::from_buffer_handle( + parsed.norms.buffer_handle().clone(), + parsed.norms.ptype(), + norms_validity, + ) + .into_array(); + + vortex_ensure_eq!( + norms.dtype(), + parent.dtype(), + "TurboQuant norms field dtype must match L2Norm output dtype" + ); + vortex_ensure_eq!( + norms.len(), + parent.len(), + "TurboQuant norms field length must match L2Norm output length" + ); + + Ok(Some(norms)) +} diff --git a/vortex-turboquant/src/scalar_fns/compute/mod.rs b/vortex-turboquant/src/scalar_fns/compute/mod.rs new file mode 100644 index 00000000000..fb5c6276b6a --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/compute/mod.rs @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant-specific session-scoped optimizer kernels. +//! +//! Each kernel module owns its own +//! [`register_execute_parent`](vortex_array::optimizer::kernels::ArrayKernelsExt::register_execute_parent) +//! call. New kernels (for example `InnerProduct` or `CosineSimilarity`) should be added as +//! sibling modules and threaded through [`register_kernels`] with a single line. + +mod l2_norm; + +use vortex_session::VortexSession; + +/// Register every TurboQuant-specific optimizer kernel on `session`. +/// +/// Called from the crate-level [`crate::initialize`] after the TurboQuant extension type and +/// the [`crate::TQEncode`] / [`crate::TQDecode`] scalar functions are registered, so kernels +/// can resolve the scalar-fn ids they intercept. +pub(crate) fn register_kernels(session: &VortexSession) { + l2_norm::register(session); +} diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 6791a1aef61..ff7a188e45d 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -153,9 +153,13 @@ impl ScalarFnVTable for TQDecode { /// Decode a `TurboQuant` extension array back into a `Vector` extension array. /// -/// The decoded directions are inverse-transformed, truncated to the original dimension, and +/// The decoded directions are inverse-transformed, truncated to the original dimension, +/// renormalized in flight by the reciprocal of the dequantized direction's L2 norm, and /// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with -/// [`TQEncode`](crate::TQEncode). +/// [`TQEncode`](crate::TQEncode), but `|TQDecode(x)| == stored_norm[x]` to floating-point +/// precision by construction. This is what makes the session-scoped `L2Norm(TQDecode(_))` +/// kernel correct: its fast path returns the stored norm and the canonical slow path +/// recomputes the same value (up to rounding) from the materialized decode. pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { let parsed = parse_storage(input, ctx)?; let metadata = parsed.metadata; @@ -258,12 +262,25 @@ where decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); + // Rescale the decoded direction to unit norm before re-applying the stored row norm, + // so that `|TQDecode(row)| == stored_norm[row]` to floating-point precision. Encode + // already rejects non-finite input norms, so `norm_squared` cannot be `NaN` or + // `+inf` here from legitimate stored data. The `is_normal` guard catches the rare + // case where every dequantized coordinate is zero (so the recip would be `+inf`) or + // where rounding produces a subnormal value, and falls back to an all-zero row. + let norm_squared = inverse.iter().take(dimensions).map(|v| v * v).sum::(); + let inv_direction_norm = if norm_squared.is_normal() { + norm_squared.sqrt().recip() + } else { + 0.0 + }; + let norm = norms[i]; for &value in inverse.iter().take(dimensions) { // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, // `f64`): values outside `f16` range saturate to `±inf` rather than returning // `None`. - let value = T::from_f32(value) + let value = T::from_f32(value * inv_direction_norm) .vortex_expect("from_f32 is infallible for supported float types"); // SAFETY: total pushes across all match arms equal `output_len`. diff --git a/vortex-turboquant/src/scalar_fns/mod.rs b/vortex-turboquant/src/scalar_fns/mod.rs index 1acea9f70f3..b493b383bb9 100644 --- a/vortex-turboquant/src/scalar_fns/mod.rs +++ b/vortex-turboquant/src/scalar_fns/mod.rs @@ -3,6 +3,7 @@ //! Scalar functions for lazy TurboQuant vector encode and decode operations. +pub(crate) mod compute; mod decode; mod encode; mod metadata; diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs index ed5aab190aa..c6da9e1dd8f 100644 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -16,6 +16,7 @@ use vortex_array::dtype::PType; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; use super::execute_tq_decode; use super::execute_tq_encode; @@ -29,6 +30,7 @@ use super::vector_values_f32; use crate::TurboQuantConfig; use crate::centroids::compute_or_get_centroids; use crate::vector::normalize::tq_normalize_as_l2_denorm; +use crate::vector::storage::parse_storage; #[rstest] #[case::zero_bits(0, 42, 3)] @@ -252,3 +254,148 @@ fn decode_scales_by_stored_norms() -> VortexResult<()> { } Ok(()) } + +/// Encode rejects rows whose L2 norm is non-finite. Without this guard, a row whose squared +/// sum overflows f32 would normalize to all-zero placeholders, and the in-flight decode +/// correction would silently diverge from the stored norm. +#[test] +fn encode_rejects_non_finite_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + // A row of `1e30` repeated `dim=128` times has squared sum `128 * 1e60 ~= 1.28e62`, which + // overflows `f32` (max ~ 3.4e38) and produces `+inf` when `L2Norm` runs in `f32`. + let values = vec![1e30f32; 128]; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); + assert!( + result.is_err(), + "encode must reject non-finite norms (overflow case)" + ); + let error = result.err().unwrap().to_string(); + assert!( + error.contains("non-finite"), + "expected non-finite error, got: {error}" + ); + Ok(()) +} + +/// Encode rejects rows containing `NaN` values, which propagate through `L2Norm` to produce +/// a `NaN` stored norm. +#[test] +fn encode_rejects_nan_input() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + let mut values = vec![1.0f32; 128]; + values[0] = f32::NAN; + let input = vector_array(128, &values, Validity::NonNullable)?; + + let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); + assert!(result.is_err(), "encode must reject NaN input rows"); + Ok(()) +} + +/// Decode preserves stored L2 norms across element ptypes and padded/unpadded dimensions. +/// This is the core invariant the `L2Norm(TQDecode(_))` fast path relies on: the slow path's +/// computed norm must match the stored norm to floating-point precision. +#[rstest] +#[case::f16_dim_128(PType::F16, 128_u32, 1e-2_f32)] +#[case::f16_dim_129(PType::F16, 129_u32, 1e-2_f32)] +#[case::f32_dim_128(PType::F32, 128_u32, 1e-4_f32)] +#[case::f32_dim_129(PType::F32, 129_u32, 1e-4_f32)] +#[case::f32_dim_257(PType::F32, 257_u32, 1e-4_f32)] +#[case::f64_dim_128(PType::F64, 128_u32, 1e-4_f32)] +#[case::f64_dim_129(PType::F64, 129_u32, 1e-4_f32)] +fn decode_preserves_original_l2_norms_across_ptypes_and_dims( + #[case] ptype: PType, + #[case] dim: u32, + #[case] tolerance: f32, +) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let rows = 3; + let raw = (0..rows * dim as usize) + .map(|i| (i % 17) as f32 - 8.0) + .map(|v| v * 0.25) + .collect::>(); + let input = match ptype { + PType::F16 => { + let values: Vec = raw.iter().copied().map(half::f16::from_f32).collect(); + vector_array(dim, &values, Validity::NonNullable)? + } + PType::F32 => vector_array(dim, &raw, Validity::NonNullable)?, + PType::F64 => { + let values: Vec = raw.iter().copied().map(f64::from).collect(); + vector_array(dim, &values, Validity::NonNullable)? + } + _ => unreachable!("ptype must be float"), + }; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let stored_norms = parse_storage(encoded.clone(), &mut ctx)?.norms; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let decoded_norms: PrimitiveArray = L2Norm::try_new_array(decoded, rows)? + .into_array() + .execute(&mut ctx)?; + + // Widen both sides to f32 for comparison. The stored norm is the authoritative reference. + let (actuals, expecteds): (Vec, Vec) = match ptype { + PType::F16 => ( + decoded_norms + .as_slice::() + .iter() + .map(|v| f32::from(*v)) + .collect(), + stored_norms + .as_slice::() + .iter() + .map(|v| f32::from(*v)) + .collect(), + ), + PType::F32 => ( + decoded_norms.as_slice::().to_vec(), + stored_norms.as_slice::().to_vec(), + ), + PType::F64 => ( + decoded_norms + .as_slice::() + .iter() + .map(|v| { + #[expect( + clippy::cast_possible_truncation, + reason = "norms are bounded by the test's input magnitudes \ + (~|raw| * dim^0.5), well within f32 range" + )] + let widened = *v as f32; + widened + }) + .collect(), + stored_norms + .as_slice::() + .iter() + .map(|v| { + #[expect( + clippy::cast_possible_truncation, + reason = "norms are bounded by the test's input magnitudes \ + (~|raw| * dim^0.5), well within f32 range" + )] + let widened = *v as f32; + widened + }) + .collect(), + ), + _ => unreachable!(), + }; + + for (actual, expected) in actuals.iter().zip(expecteds.iter()) { + assert!( + (*actual - *expected).abs() <= tolerance * expected.max(1.0), + "decoded norm {actual} did not match stored norm {expected} \ + (ptype {ptype:?}, dim {dim})" + ); + } + Ok(()) +} diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs index e59b7a95c75..1242c4d7db7 100644 --- a/vortex-turboquant/src/tests/file.rs +++ b/vortex-turboquant/src/tests/file.rs @@ -3,6 +3,7 @@ use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::PrimitiveArray; use vortex_array::stream::ArrayStreamExt; use vortex_array::validity::Validity; use vortex_error::VortexResult; @@ -10,6 +11,7 @@ use vortex_file::OpenOptionsSessionExt; use vortex_file::VortexWriteOptions; use vortex_io::runtime::BlockingRuntime; use vortex_io::runtime::single::SingleThreadRuntime; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; use vortex_tensor::vector::Vector; use super::execute_tq_decode_from_metadata; @@ -19,6 +21,7 @@ use super::file_session; use super::vector_validity; use crate::TQDecode; use crate::TurboQuantConfig; +use crate::vector::storage::parse_storage; use crate::vtable::tq_metadata; #[test] @@ -46,6 +49,42 @@ fn file_roundtrip_with_initialize_session() -> VortexResult<()> { Ok(()) } +/// File-roundtrip preserves the `L2Norm(TQDecode(_))` fast-path invariant. A regression that +/// silently broke the in-flight decode correction or the kernel would only show up downstream +/// as norm divergence; this test surfaces it at the IO layer. +#[test] +fn file_roundtrip_preserves_l2_norm_invariant() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let original_norms: PrimitiveArray = parse_storage(encoded.clone(), &mut ctx)?.norms; + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, encoded.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + // Fast-path `L2Norm(TQDecode(_))` must still return the originally stored row norms after + // the file roundtrip. If the kernel or the in-flight decode correction had silently + // broken at serialization, this is where it would surface. + let decoded = TQDecode::try_new_array(read)?.into_array(); + let kernel_norms: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + assert_eq!( + kernel_norms.as_slice::(), + original_norms.as_slice::(), + "L2Norm(TQDecode(read_back)) must equal the originally stored row norms" + ); + Ok(()) +} + #[test] fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> { let runtime = SingleThreadRuntime::default(); diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs new file mode 100644 index 00000000000..9c45c1f59bc --- /dev/null +++ b/vortex-turboquant/src/tests/kernels.rs @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Tests for TurboQuant-specific session-scoped optimizer kernels. + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::assert_arrays_eq; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; + +use super::execute_tq_decode; +use super::execute_tq_encode; +use super::f32_vector_array; +use super::test_session; +use super::vector_array; +use crate::TQDecode; +use crate::TurboQuant; +use crate::TurboQuantConfig; +use crate::TurboQuantMetadata; +use crate::vector::storage::parse_storage; + +const DIM: u32 = 128; + +/// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit. +/// +/// `TQDecode` rescales each decoded direction in flight by the reciprocal of its own L2 norm +/// before re-applying the stored row norm, so decoded rows preserve the stored norm exactly. +/// Bit-exact equality with the parsed `norms` child confirms the session-registered kernel +/// fired instead of recomputing the norm from a materialized decode. +#[test] +fn l2_norm_over_tq_decode_returns_stored_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms.into_array(); + + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let row_count = decoded.len(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, row_count)? + .into_array() + .execute(&mut ctx)?; + + assert_arrays_eq!(result, expected_norms); + Ok(()) +} + +/// Negative: directly wrapping a `Vector` (no `TQDecode`) must hit the canonical `L2Norm` +/// path. Proves the kernel only intercepts the matched `(L2Norm, TQDecode)` pair and does +/// not affect the standard tensor scalar-function flow. +#[test] +fn l2_norm_over_plain_vector_uses_canonical_path() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + + let input = vector_array( + 3, + &[ + 3.0f32, 4.0, 0.0, // row 0, norm = 5.0 + 0.0, 0.0, 0.0, // row 1, norm = 0.0 + 1.0, 0.0, 0.0, // row 2, norm = 1.0 + ], + Validity::NonNullable, + )?; + + let row_count = input.len(); + let result: PrimitiveArray = L2Norm::try_new_array(input, row_count)? + .into_array() + .execute(&mut ctx)?; + + let expected = + PrimitiveArray::new::(Buffer::copy_from([5.0f32, 0.0, 1.0]), Validity::NonNullable); + assert_arrays_eq!(result, expected); + Ok(()) +} + +/// Empty input: zero-length TurboQuant array still produces a zero-length norms array of the +/// matching primitive dtype. +#[test] +fn l2_norm_over_empty_tq_decode_is_empty_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(DIM, &[], Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 0)? + .into_array() + .execute(&mut ctx)?; + + assert_eq!(result.len(), 0); + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::NonNullable) + ); + Ok(()) +} + +/// Null rows: the kernel must preserve the input's row-level validity and produce correct +/// norms for the non-null rows. +#[rstest] +#[case::leading_null(Validity::from_iter([false, true, true]))] +#[case::trailing_null(Validity::from_iter([true, true, false]))] +#[case::interior_null(Validity::from_iter([true, false, true]))] +fn l2_norm_over_tq_decode_preserves_nulls(#[case] validity: Validity) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 3, 0.25, validity)?; + let config = TurboQuantConfig::try_new(4, 7, 2)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let expected_norms = parse_storage(encoded.clone(), &mut ctx)?.norms.into_array(); + + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 3)? + .into_array() + .execute(&mut ctx)?; + + assert_arrays_eq!(result, expected_norms); + Ok(()) +} + +/// Masked input: generic masks narrow the TurboQuant storage struct validity without +/// rewriting the `norms` child, so the kernel must apply the authoritative struct validity +/// before returning. +#[test] +fn l2_norm_over_masked_tq_decode_uses_storage_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let masked = encoded.mask(BoolArray::from_iter([true, false, true, false]).into_array())?; + + let decoded = TQDecode::try_new_array(masked)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + let validity = result.validity()?.execute_mask(4, &mut ctx)?; + + assert!(validity.value(0)); + assert!(!validity.value(1)); + assert!(validity.value(2)); + assert!(!validity.value(3)); + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::Nullable) + ); + Ok(()) +} + +/// Regression for the wider-child-nullability shape (`Nullable` `norms` with `AllValid` +/// under a `NonNullable` struct). `parse_storage` accepts this shape (see +/// `malformed::decode_accepts_child_nullability_that_covers_struct_validity`); the kernel +/// must return a `NonNullable` result rather than reusing the wider child validity. +#[test] +fn l2_norm_over_tq_decode_nullable_norms_under_nonnullable_struct() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: DIM, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + + let norms = + PrimitiveArray::new::(Buffer::copy_from([5.0f32]), Validity::AllValid).into_array(); + let codes = PrimitiveArray::new::(vec![0u8; DIM as usize], Validity::NonNullable); + let codes = + FixedSizeListArray::try_new(codes.into_array(), DIM, Validity::AllValid, 1)?.into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::NonNullable, + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + .into_array(); + + let decoded = TQDecode::try_new_array(tq)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 1)? + .into_array() + .execute(&mut ctx)?; + + assert_eq!( + result.dtype(), + &DType::Primitive(PType::F32, Nullability::NonNullable), + "kernel result dtype must match parent (NonNullable), not the wider child validity" + ); + assert_eq!(result.as_slice::(), &[5.0f32]); + Ok(()) +} + +/// Cross-check the kernel result against the canonical `L2Norm(execute(TQDecode))` path. +/// Materializing the decoded vector first breaks the `(L2Norm, TQDecode)` pattern so +/// `L2Norm` runs through the canonical scalar-function flow. +#[rstest] +#[case::dim_128(128_u32)] +#[case::dim_129(129_u32)] +#[case::dim_257(257_u32)] +fn l2_norm_over_tq_decode_matches_canonical(#[case] dim: u32) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(dim, 4, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + + let kernel_result: PrimitiveArray = + L2Norm::try_new_array(TQDecode::try_new_array(encoded.clone())?.into_array(), 4)? + .into_array() + .execute(&mut ctx)?; + + // Materialize the decoded vector first so `L2Norm` cannot match `(L2Norm, TQDecode)`. The + // resulting `L2Norm(Vector)` flows through the canonical scalar-function path. + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let canonical_result: PrimitiveArray = L2Norm::try_new_array(decoded, 4)? + .into_array() + .execute(&mut ctx)?; + + let kernel = kernel_result.as_slice::(); + let canonical = canonical_result.as_slice::(); + for (k, c) in kernel.iter().zip(canonical.iter()) { + assert!( + (*k - *c).abs() <= 1e-4 * c.max(1.0), + "kernel result {k} disagrees with canonical {c} (dim {dim})" + ); + } + Ok(()) +} diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs index ffa1db175a7..3b2f098e390 100644 --- a/vortex-turboquant/src/tests/mod.rs +++ b/vortex-turboquant/src/tests/mod.rs @@ -39,13 +39,19 @@ use crate::initialize; mod encode_decode; mod file; +mod kernels; mod malformed; mod metadata; mod parity; mod scalar_fns; +/// In-memory session with both `vortex_tensor` and `vortex_turboquant` initialized. Tests +/// that exercise tensor scalar functions (for example [`L2Norm`](vortex_tensor::scalar_fns::l2_norm::L2Norm)) +/// over TurboQuant inputs need `vortex_tensor::initialize` first; other tests pay no cost +/// for the extra registrations, so a single helper keeps the test surface uniform. fn test_session() -> VortexSession { let session = VortexSession::empty().with::(); + vortex_tensor::initialize(&session); initialize(&session); session } diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs index 4360d90849d..be879676b50 100644 --- a/vortex-turboquant/src/tests/parity.rs +++ b/vortex-turboquant/src/tests/parity.rs @@ -13,15 +13,27 @@ use super::f32_vector_array; use super::test_session; use super::vector_values_f32; use crate::TurboQuantConfig; +use crate::vector::storage::parse_storage; +/// Pins the exact relationship between new and legacy TurboQuant decode: for each row, +/// `new_value[i] == old_value[i] * (stored_norm / old_norm)`. The centroid table and SORF +/// transform are identical between the two encoders, so the inverse-transformed direction is +/// the same; the only mathematical difference is the per-row scalar correction the new decode +/// applies in flight to rescale the lossy quantized direction to unit norm before re-applying +/// the stored row norm. #[test] -fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { +fn new_decode_rescales_old_decode_to_stored_norm() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; + let input = f32_vector_array(129, 2, 0.125, Validity::NonNullable)?; let config = TurboQuantConfig::try_new(3, 42, 3)?; let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let stored_norms = parse_storage(new_encoded.clone(), &mut ctx)? + .norms + .as_slice::() + .to_vec(); + let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?; let old_config = OldTurboQuantConfig { bit_width: config.bit_width(), @@ -32,7 +44,40 @@ fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { let new_values = vector_values_f32(new_decoded, &mut ctx)?; let old_values = vector_values_f32(old_decoded, &mut ctx)?; + let dim = new_values.len() / stored_norms.len(); + + let mut any_correction_nontrivial = false; + for (row, &stored_norm) in stored_norms.iter().enumerate() { + let old_row = &old_values[row * dim..][..dim]; + let old_norm = old_row.iter().map(|v| v * v).sum::().sqrt(); + // Skip rows where the legacy decode produced exact zero norm; division below would + // be undefined and there is no meaningful relationship to pin. + if old_norm == 0.0 { + continue; + } + let correction = stored_norm / old_norm; + if (correction - 1.0).abs() > 1e-4 { + any_correction_nontrivial = true; + } - assert_eq!(new_values, old_values); + for col in 0..dim { + let idx = row * dim + col; + let new_v = new_values[idx]; + let expected = old_values[idx] * correction; + let scale = new_v.abs().max(expected.abs()).max(1.0); + assert!( + (new_v - expected).abs() <= 1e-4 * scale, + "row {row} col {col}: new {new_v} != old {} * correction {correction} \ + (= {expected})", + old_values[idx] + ); + } + } + // Sanity: the correction is meaningfully non-trivial for at least one row. If the new + // decode were a no-op rescaling, this whole test would silently pass. + assert!( + any_correction_nontrivial, + "the in-flight correction should differ from 1.0 for at least one row" + ); Ok(()) } diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs index 642949eecf6..9a69164c1cd 100644 --- a/vortex-turboquant/src/vector/normalize.rs +++ b/vortex-turboquant/src/vector/normalize.rs @@ -24,6 +24,7 @@ use vortex_array::match_each_float_ptype; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_mask::Mask; @@ -102,6 +103,22 @@ where let values = elements.as_slice::(); let norm_values = norms.as_slice::(); + // Reject non-finite norms up front. A `+inf` or `NaN` norm would either come from an input + // row whose sum of squares overflowed `T` or from a pre-existing `NaN` in the data. In + // either case the f32-precision SORF transform downstream cannot represent the row, and + // letting the division `value / norm` proceed silently corrupts encoded data: the + // normalized row becomes all zeros, encode infers a zero-norm row, and decode's in-flight + // reciprocal-of-direction-norm correction disagrees with the non-finite stored norm. + // Invalid-row norms can carry arbitrary placeholders so they are excluded from the check + // via the row-validity mask. + for (i, &norm) in norm_values.iter().enumerate() { + if mask.value(i) && !norm.is_finite() { + vortex_bail!( + "TurboQuant input row {i} has non-finite L2 norm; encode requires finite norms" + ); + } + } + let output_len = num_vectors .checked_mul(dimensions) .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index d1b4f06cc05..2cf52789589 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -64,6 +64,16 @@ pub(crate) struct TurboQuantParsedStorage { pub(crate) len: usize, } +/// Subset of [`TurboQuantParsedStorage`] containing only the `norms` child plus the outer +/// struct validity. Used by the `L2Norm(TQDecode(_))` execute-parent kernel, which has no +/// need for the `codes` child. +pub(crate) struct TurboQuantParsedNorms { + /// Authoritative row validity for the quantized vectors, taken from the outer struct. + pub(crate) vector_validity: Validity, + /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. + pub(crate) norms: PrimitiveArray, +} + /// Build the `codes: FixedSizeList, padded_dim>` storage child. /// /// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived @@ -103,6 +113,11 @@ pub(crate) fn build_storage( } /// Parse a TurboQuant extension array into executed storage children. +/// +/// Executes both storage children, validates that every child's row validity covers the outer +/// struct validity, and returns the parsed result. Used by `TQDecode`, which needs every +/// child. Kernels that only need a subset should use a narrower helper (for example +/// [`parse_storage_norms_only`]) to avoid executing the children they will not consume. pub(crate) fn parse_storage( input: ArrayRef, ctx: &mut ExecutionCtx, @@ -141,6 +156,39 @@ pub(crate) fn parse_storage( }) } +/// Narrow form of [`parse_storage`] that returns only the `norms` child plus the outer +/// struct validity. Used by the `L2Norm(TQDecode(_))` kernel so the fast path does not +/// execute the `codes` child it has no use for. The `norms` child's validity is still +/// validated against the struct's; the `codes` child is not touched. +pub(crate) fn parse_storage_norms_only( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let ext: ExtensionArray = input.execute(ctx)?; + let storage: StructArray = ext.storage_array().clone().execute(ctx)?; + + let norms: PrimitiveArray = storage + .unmasked_field_by_name(NORMS_FIELD)? + .clone() + .execute(ctx)?; + + let len = storage.len(); + let struct_validity = storage.struct_validity(); + let norms_validity = norms.validity()?; + + let struct_mask = struct_validity.execute_mask(len, ctx)?; + let norms_mask = norms_validity.execute_mask(len, ctx)?; + vortex_ensure!( + struct_mask.bitand_not(&norms_mask).all_false(), + "TurboQuant {NORMS_FIELD} row validity must cover storage validity" + ); + + Ok(TurboQuantParsedNorms { + vector_validity: struct_validity, + norms, + }) +} + /// Validate that both child masks cover the struct mask: every row that the struct considers /// valid must also be valid in the `norms` and `codes` children. /// From 8031abfc9137d750bfa651e088afb3437c343534 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 17:12:02 -0400 Subject: [PATCH 3/5] gauntlet: cycle 1 fixes for tq-l2-norm Apply 3 must-fixes from the cycle-1 phase-3 gauntlet review of the L2Norm-decode readthrough commit. - parse_storage_norms_only (storage.rs): execute the `codes` FSL wrapper and validate its row-level validity against the struct's, so the L2Norm(TQDecode(_)) fast path rejects malformed-validity inputs the same way the canonical TQDecode path does. The inner `u8` element buffer is still skipped, so the per-row decoding work is preserved as the fast-path's optimization. Per-byte centroid-index range is not checked on either path (kernel does not need the codes, canonical path panics on out-of-range bytes which is acceptable per existing malformed-input coverage). - TQEncode::is_fallible (encode.rs): return true. The new non-finite- norm rejection in vector::normalize is a data-dependent defined-behavior failure, so optimizers that speculatively evaluate unreferenced inputs (e.g., dictionary pushdown) must treat encode as fallible to avoid surfacing errors on rows the consumer would never reach. - Degenerate-direction decode (decode.rs): when the dequantized direction's squared L2 norm is zero or subnormal, emit [norm, 0, 0, ..., 0] instead of an all-zero row, so L2Norm(decoded_row) == stored_norm[row]. The all-zero fallback silently disagreed with the L2Norm(TQDecode(_)) kernel (kernel returns stored_norm while canonical returned 0). The kernel's fast/slow-path equivalence now holds for degenerate rows too, matching the documented `|TQDecode(x)| == stored_norm[x]` invariant. Cross-reference the is_fallible rationale at the normalize-side check so both sites are discoverable from each other. Deferred to follow-up: f16/f64 parameterization of the kernel-fast-path test (only f32 is currently exercised at the kernel level), a chunked- TurboQuant kernel test, parity test re-parameterization over dim=128, and a constructed regression test for the degenerate-direction decode path. All are should-fix coverage improvements, not behavior bugs. Signed-off-by: Connor Tsui --- vortex-turboquant/src/scalar_fns/decode.rs | 45 ++++++++++++++-------- vortex-turboquant/src/scalar_fns/encode.rs | 15 +++++++- vortex-turboquant/src/vector/normalize.rs | 15 +++++--- vortex-turboquant/src/vector/storage.rs | 21 +++++++--- 4 files changed, 67 insertions(+), 29 deletions(-) diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index ff7a188e45d..df889ba9294 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -266,25 +266,36 @@ where // so that `|TQDecode(row)| == stored_norm[row]` to floating-point precision. Encode // already rejects non-finite input norms, so `norm_squared` cannot be `NaN` or // `+inf` here from legitimate stored data. The `is_normal` guard catches the rare - // case where every dequantized coordinate is zero (so the recip would be `+inf`) or - // where rounding produces a subnormal value, and falls back to an all-zero row. + // case where the dequantized direction's squared norm is zero (so the reciprocal + // would be `+inf`) or where summation rounding produces a subnormal value (so the + // reciprocal would not be representable in `f32`). let norm_squared = inverse.iter().take(dimensions).map(|v| v * v).sum::(); - let inv_direction_norm = if norm_squared.is_normal() { - norm_squared.sqrt().recip() - } else { - 0.0 - }; - let norm = norms[i]; - for &value in inverse.iter().take(dimensions) { - // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, - // `f64`): values outside `f16` range saturate to `±inf` rather than returning - // `None`. - let value = T::from_f32(value * inv_direction_norm) - .vortex_expect("from_f32 is infallible for supported float types"); - - // SAFETY: total pushes across all match arms equal `output_len`. - unsafe { output.push_unchecked(value * norm) }; + + if norm_squared.is_normal() { + let inv_direction_norm = norm_squared.sqrt().recip(); + for &value in inverse.iter().take(dimensions) { + // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, + // `f64`): values outside `f16` range saturate to `±inf` rather than + // returning `None`. + let value = T::from_f32(value * inv_direction_norm) + .vortex_expect("from_f32 is infallible for supported float types"); + + // SAFETY: total pushes across all match arms equal `output_len`. + unsafe { output.push_unchecked(value * norm) }; + } + } else { + // Degenerate decoded direction. Emit a single norm-bearing coordinate followed + // by zeros so `L2Norm(decoded_row) == |norm| == stored_norm[i]`, matching what + // the `L2Norm(TQDecode(_))` kernel returns. An all-zero row would silently + // disagree with the kernel for this row. + // SAFETY: total pushes across all match arms equal `output_len`. This branch + // pushes exactly `dimensions` elements (`1 + (dimensions - 1)`); `dimensions` + // is non-zero because encode rejects sub-`MIN_DIMENSION` inputs. + unsafe { + output.push_unchecked(norm); + output.push_n_unchecked(T::zero(), dimensions - 1); + } } }; diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs index 29ce7cc580a..9530f3e00f0 100644 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -162,7 +162,20 @@ impl ScalarFnVTable for TQEncode { } fn is_fallible(&self, _options: &Self::Options) -> bool { - false + // `encode_vector` returns an error for any input row whose L2 norm is non-finite + // (`NaN`, `+inf`, or a finite-input row whose squared sum overflows `f32`). The + // check lives in `vector::normalize::normalize_vectors` and exists so the in-flight + // decode-side norm correction in `scalar_fns::decode::decode_typed` cannot disagree + // with the stored row norm. + // + // This is a data-dependent defined-behavior failure (analogous to `checked_add` on + // integer overflow), so per the `ScalarFnVTable::is_fallible` contract — used by + // optimizers to decide whether to speculatively evaluate over unreferenced inputs, + // for example pushing a scalar function down to a dictionary's values array — this + // function reports as fallible. Without it, speculative dictionary pushdown could + // surface non-finite-norm errors on rows that the consumer would never have + // referenced under normal evaluation. + true } } diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs index 9a69164c1cd..b98aa344d1e 100644 --- a/vortex-turboquant/src/vector/normalize.rs +++ b/vortex-turboquant/src/vector/normalize.rs @@ -103,14 +103,19 @@ where let values = elements.as_slice::(); let norm_values = norms.as_slice::(); - // Reject non-finite norms up front. A `+inf` or `NaN` norm would either come from an input - // row whose sum of squares overflowed `T` or from a pre-existing `NaN` in the data. In - // either case the f32-precision SORF transform downstream cannot represent the row, and - // letting the division `value / norm` proceed silently corrupts encoded data: the - // normalized row becomes all zeros, encode infers a zero-norm row, and decode's in-flight + // Reject non-finite norms up front. A `+inf` or `NaN` norm comes from either an input + // row whose sum of squares overflowed `T` or a pre-existing `NaN` in the data. In either + // case the f32-precision SORF transform downstream cannot represent the row, and letting + // the division `value / norm` proceed silently corrupts encoded data: the normalized row + // becomes all zeros, encode infers a zero-norm row, and decode's in-flight // reciprocal-of-direction-norm correction disagrees with the non-finite stored norm. // Invalid-row norms can carry arbitrary placeholders so they are excluded from the check // via the row-validity mask. + // + // This rejection is the reason `TQEncode::is_fallible` returns `true` (see + // [`crate::scalar_fns::encode`]): a row of well-typed floats with a non-finite norm is a + // defined-behavior failure rather than an internal error, so optimizers must treat + // encode as fallible when reasoning about speculative evaluation. for (i, &norm) in norm_values.iter().enumerate() { if mask.value(i) && !norm.is_finite() { vortex_bail!( diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index 2cf52789589..7305a9972b1 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -158,8 +158,11 @@ pub(crate) fn parse_storage( /// Narrow form of [`parse_storage`] that returns only the `norms` child plus the outer /// struct validity. Used by the `L2Norm(TQDecode(_))` kernel so the fast path does not -/// execute the `codes` child it has no use for. The `norms` child's validity is still -/// validated against the struct's; the `codes` child is not touched. +/// materialize the `codes` child's per-byte elements that it has no use for. Both children's +/// row validity is still validated against the struct's so that malformed-validity inputs +/// fail here exactly as they fail in [`parse_storage`]; only the inner `u8` element buffer +/// of the `codes` child is skipped. The per-byte centroid-index range is not checked either +/// way, since the codes are not decoded on this fast path. pub(crate) fn parse_storage_norms_only( input: ArrayRef, ctx: &mut ExecutionCtx, @@ -172,16 +175,22 @@ pub(crate) fn parse_storage_norms_only( .clone() .execute(ctx)?; + // Execute only the `codes` FSL wrapper so we can validate its row-level validity covers + // the struct's. The inner `u8` element buffer is intentionally not materialized. + let codes_fsl: FixedSizeListArray = storage + .unmasked_field_by_name(CODES_FIELD)? + .clone() + .execute(ctx)?; + let len = storage.len(); let struct_validity = storage.struct_validity(); let norms_validity = norms.validity()?; + let codes_validity = codes_fsl.validity()?; let struct_mask = struct_validity.execute_mask(len, ctx)?; let norms_mask = norms_validity.execute_mask(len, ctx)?; - vortex_ensure!( - struct_mask.bitand_not(&norms_mask).all_false(), - "TurboQuant {NORMS_FIELD} row validity must cover storage validity" - ); + let codes_mask = codes_validity.execute_mask(len, ctx)?; + validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; Ok(TurboQuantParsedNorms { vector_validity: struct_validity, From d5644f6670c9882c2cd9a2c32eacf25a3ccade0e Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 17:31:53 -0400 Subject: [PATCH 4/5] gauntlet: cycle 2 fixes for tq-l2-norm Apply 3 must-fixes from the cycle-2 phase-3 gauntlet review of the cycle-1 fix-commit (7cf7433dd). - Kernel negative-norms safety (l2_norm.rs): canonical `L2Norm` always returns non-negative, but the kernel was returning the stored `norms` buffer verbatim. For hand-constructed TurboQuant storage carrying a negative norm value, kernel and canonical disagreed by a sign flip. The kernel now scans the parsed norms once and falls back to the canonical path (`Ok(None)`) when any value is strictly negative. Well-formed storage (encode-produced norms via `L2Norm = sqrt(sum_sq) >= 0`) is unaffected; the cost is one cache-warm pass over a buffer that was already materialized for the return value. - Pin `non-finite` error message in `encode_rejects_nan_input` test (encode_decode.rs) so a regression that fails encode for an unrelated reason cannot silently pass; matches the sibling `encode_rejects_non_finite_norms` test's pattern. - Parameterize `l2_norm_over_tq_decode_returns_stored_norms` over `PType::{F16, F32, F64}` (kernels.rs) so the kernel's per-ptype buffer-handle plumbing is exercised at non-f32 ptypes too. Closes the scope-drift gap between the cycle-0 commit message's claimed full ptype matrix and the test that actually shipped. Also replace em dashes added in cycle 1's `is_fallible` rationale comment with comma/period punctuation, per the user's standing no-em-dashes feedback. Cycle-2 should-fix items intentionally deferred: - masked-kernel test asserting numeric values (not just validity bitmap) - direct unit pin of `TQEncode::is_fallible == true` - shared-helper extraction between `parse_storage` and `parse_storage_norms_only` - chunked-TurboQuant kernel coverage - constructed regression for the degenerate-direction decode branch - parity test widened to dim 128 / 257 Cycle-2 must-fixes intentionally deferred with rationale in commit body: - per-byte centroid-index range validation in the fast path would re-introduce the O(rows * padded_dim) work the kernel exists to avoid; canonical decode already panics on out-of-range bytes, and the kernel succeeding-without-decoding vs canonical-panicking is at best a wash. - `TQEncode::is_fallible` optimizer-pushdown regression test requires dict-pushdown plumbing we do not have a simple way to construct. Signed-off-by: Connor Tsui --- .../src/scalar_fns/compute/l2_norm.rs | 16 ++++++++ vortex-turboquant/src/scalar_fns/encode.rs | 10 ++--- vortex-turboquant/src/tests/encode_decode.rs | 5 +++ vortex-turboquant/src/tests/kernels.rs | 37 +++++++++++++++---- 4 files changed, 55 insertions(+), 13 deletions(-) diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs index 7d3eaa6d937..f4c8b11caba 100644 --- a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -4,6 +4,7 @@ //! `L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the //! stored per-row norms directly instead of decoding and recomputing. +use num_traits::Zero; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -12,6 +13,7 @@ use vortex_array::arrays::ScalarFn; use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; +use vortex_array::match_each_float_ptype; use vortex_array::optimizer::kernels::ArrayKernelsExt; use vortex_array::optimizer::kernels::ExecuteParentFn; use vortex_array::scalar_fn::ScalarFnVTable; @@ -61,6 +63,20 @@ fn l2_norm_tq_decode_execute_parent( let tq_array = child.as_::().child_at(0).clone(); let parsed = parse_storage_norms_only(tq_array, ctx)?; + // Fall back to the canonical `L2Norm` path on the (adversarial) case where any stored + // norm is strictly negative. Encode always produces non-negative norms (via `L2Norm`, + // which returns `sqrt(sum_sq)`), but a hand-constructed TurboQuant storage could carry + // arbitrary values in the `norms` child. Returning the stored bits verbatim would then + // violate `L2Norm`'s always-non-negative output invariant. The canonical path runs the + // in-flight decode rescaling and reapplies the stored norm, so its `L2Norm` output is + // `|stored_norm|` for every row by construction. + let has_negative_norm = match_each_float_ptype!(parsed.norms.ptype(), |T| { + parsed.norms.as_slice::().iter().any(|n| *n < T::zero()) + }); + if has_negative_norm { + return Ok(None); + } + let norms_validity = match parent.dtype().nullability() { Nullability::NonNullable => Validity::NonNullable, Nullability::Nullable => parsed.vector_validity, diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs index 9530f3e00f0..522ffff0e48 100644 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -169,12 +169,12 @@ impl ScalarFnVTable for TQEncode { // with the stored row norm. // // This is a data-dependent defined-behavior failure (analogous to `checked_add` on - // integer overflow), so per the `ScalarFnVTable::is_fallible` contract — used by + // integer overflow). The `ScalarFnVTable::is_fallible` contract is used by // optimizers to decide whether to speculatively evaluate over unreferenced inputs, - // for example pushing a scalar function down to a dictionary's values array — this - // function reports as fallible. Without it, speculative dictionary pushdown could - // surface non-finite-norm errors on rows that the consumer would never have - // referenced under normal evaluation. + // for example pushing a scalar function down to a dictionary's values array. By + // reporting as fallible we block that speculation; without this, speculative + // dictionary pushdown could surface non-finite-norm errors on rows that the + // consumer would never have referenced under normal evaluation. true } } diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs index c6da9e1dd8f..94314b59eeb 100644 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -294,6 +294,11 @@ fn encode_rejects_nan_input() -> VortexResult<()> { let result = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx); assert!(result.is_err(), "encode must reject NaN input rows"); + let error = result.err().unwrap().to_string(); + assert!( + error.contains("non-finite"), + "expected non-finite error, got: {error}" + ); Ok(()) } diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs index 9c45c1f59bc..d61ef5746ff 100644 --- a/vortex-turboquant/src/tests/kernels.rs +++ b/vortex-turboquant/src/tests/kernels.rs @@ -35,17 +35,38 @@ use crate::vector::storage::parse_storage; const DIM: u32 = 128; -/// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit. +/// Fast path: `L2Norm(TQDecode(tq_arr))` returns the storage `norms` field bit-for-bit +/// across every supported element ptype, so the kernel's per-ptype buffer-handle plumbing +/// is exercised at `f16`, `f32`, and `f64` rather than only the default `f32`. /// -/// `TQDecode` rescales each decoded direction in flight by the reciprocal of its own L2 norm -/// before re-applying the stored row norm, so decoded rows preserve the stored norm exactly. -/// Bit-exact equality with the parsed `norms` child confirms the session-registered kernel -/// fired instead of recomputing the norm from a materialized decode. -#[test] -fn l2_norm_over_tq_decode_returns_stored_norms() -> VortexResult<()> { +/// `TQDecode` rescales each decoded direction in flight by the reciprocal of its own L2 +/// norm before re-applying the stored row norm, so decoded rows preserve the stored norm +/// exactly. Bit-exact equality with the parsed `norms` child is consistent with the +/// session-registered kernel firing (the canonical-cross-check test below pins the +/// equivalence under arithmetic). +#[rstest] +#[case::f16(PType::F16)] +#[case::f32(PType::F32)] +#[case::f64(PType::F64)] +fn l2_norm_over_tq_decode_returns_stored_norms(#[case] ptype: PType) -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(DIM, 4, 0.25, Validity::NonNullable)?; + let rows = 4; + let raw = (0..rows * DIM as usize) + .map(|i| ((i % 17) as f32 - 8.0) * 0.25) + .collect::>(); + let input = match ptype { + PType::F16 => { + let values: Vec = raw.iter().copied().map(half::f16::from_f32).collect(); + vector_array(DIM, &values, Validity::NonNullable)? + } + PType::F32 => vector_array(DIM, &raw, Validity::NonNullable)?, + PType::F64 => { + let values: Vec = raw.iter().copied().map(f64::from).collect(); + vector_array(DIM, &values, Validity::NonNullable)? + } + _ => unreachable!("ptype must be float"), + }; let config = TurboQuantConfig::try_new(3, 42, 3)?; let encoded = execute_tq_encode(input, &config, &mut ctx)?; From 83901ac45809e77b8eb2a25bcef47492fc8531a7 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 21 May 2026 17:50:28 -0400 Subject: [PATCH 5/5] gauntlet: cycle 3 fixes for tq-l2-norm Apply 3 fixes from the cycle-3 phase-3 gauntlet review of the cycle-2 fix-commit (8c02fe56d). - Cover `-0.0` in the negative-stored-norm guard (l2_norm.rs): cycle-2 used `*n < T::zero()` which is `false` for `-0.0` per IEEE 754, so a hand-constructed `-0.0` stored norm slipped through and the kernel returned `-0.0` while canonical computed `sqrt(sum_sq) == +0.0`. Switch to `n.is_sign_negative()`, which covers both strictly-negative values and `-0.0`. The comment now documents the IEEE 754 subtlety and the cache-warm `O(rows)` cost of the scan. - Add `l2_norm_over_tq_decode_with_negative_stored_norm_falls_back` (kernels.rs), parameterized over `-5.0` and `-0.0`, that hand-builds a TurboQuant array with a sign-negative stored norm and asserts the result is non-negative and finite (proving the kernel fell back to the canonical path, which always returns `|stored_norm|`). - Add `l2_norm_over_tq_decode_rejects_codes_validity_narrower_than_struct` (kernels.rs) that hand-builds a TurboQuant array whose `codes` child has row validity narrower than the outer struct's, mirroring the existing canonical-path `decode_rejects_child_masks_that_disagree_with_struct_validity` test. Asserts `L2Norm(TQDecode(_))` errors via `parse_storage_norms_only`'s validation, pinning the fast/slow-path validation parity that the cycle-1 fix-commit added. Signed-off-by: Connor Tsui --- .../src/scalar_fns/compute/l2_norm.rs | 26 +++-- vortex-turboquant/src/tests/kernels.rs | 102 ++++++++++++++++++ 2 files changed, 120 insertions(+), 8 deletions(-) diff --git a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs index f4c8b11caba..41946d5d5ee 100644 --- a/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs +++ b/vortex-turboquant/src/scalar_fns/compute/l2_norm.rs @@ -4,7 +4,6 @@ //! `L2Norm` execute-parent kernel that intercepts `L2Norm(TQDecode(tq))` and returns the //! stored per-row norms directly instead of decoding and recomputing. -use num_traits::Zero; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; @@ -64,14 +63,25 @@ fn l2_norm_tq_decode_execute_parent( let parsed = parse_storage_norms_only(tq_array, ctx)?; // Fall back to the canonical `L2Norm` path on the (adversarial) case where any stored - // norm is strictly negative. Encode always produces non-negative norms (via `L2Norm`, - // which returns `sqrt(sum_sq)`), but a hand-constructed TurboQuant storage could carry - // arbitrary values in the `norms` child. Returning the stored bits verbatim would then - // violate `L2Norm`'s always-non-negative output invariant. The canonical path runs the - // in-flight decode rescaling and reapplies the stored norm, so its `L2Norm` output is - // `|stored_norm|` for every row by construction. + // norm has its sign bit set. Encode always produces non-negative norms (via `L2Norm`, + // which returns `sqrt(sum_sq)` and never yields `-0.0`), but a hand-constructed + // TurboQuant storage could carry arbitrary values in the `norms` child. Returning the + // stored bits verbatim would then violate `L2Norm`'s always-non-negative output + // invariant. The canonical path runs the in-flight decode rescaling and reapplies the + // stored norm, so its `L2Norm` output is `|stored_norm|` for every row by construction. + // + // Using `is_sign_negative` rather than `< T::zero()` is load-bearing: `-0.0 < 0.0` is + // `false` per IEEE 754, so a literal comparison would miss a stored `-0.0` while the + // canonical path would still collapse it to `+0.0` via `sqrt(sum_sq)`. + // + // The scan is `O(rows)` over a buffer the just-completed `parse_storage_norms_only` + // materialized, so it does not move the kernel out of its constant-time-per-row regime. let has_negative_norm = match_each_float_ptype!(parsed.norms.ptype(), |T| { - parsed.norms.as_slice::().iter().any(|n| *n < T::zero()) + parsed + .norms + .as_slice::() + .iter() + .any(|n| n.is_sign_negative()) }); if has_negative_norm { return Ok(None); diff --git a/vortex-turboquant/src/tests/kernels.rs b/vortex-turboquant/src/tests/kernels.rs index d61ef5746ff..250773bf7d2 100644 --- a/vortex-turboquant/src/tests/kernels.rs +++ b/vortex-turboquant/src/tests/kernels.rs @@ -269,3 +269,105 @@ fn l2_norm_over_tq_decode_matches_canonical(#[case] dim: u32) -> VortexResult<() } Ok(()) } + +/// Adversarial: a hand-constructed TurboQuant storage with a `-5.0` or `-0.0` stored norm +/// makes the fast path fall back to the canonical `L2Norm(execute(TQDecode))` path so that +/// the result preserves `L2Norm`'s always-non-negative output invariant. The kernel scans +/// the parsed `norms` once and triggers fallback via `is_sign_negative`, which covers both +/// strictly-negative values and `-0.0` (where the literal `< 0` comparison would fail per +/// IEEE 754). +#[rstest] +#[case::strict_negative(-5.0_f32)] +#[case::negative_zero(-0.0_f32)] +fn l2_norm_over_tq_decode_with_negative_stored_norm_falls_back( + #[case] stored: f32, +) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: DIM, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + + let norms = + PrimitiveArray::new::(Buffer::copy_from([stored]), Validity::NonNullable).into_array(); + let codes = PrimitiveArray::new::(vec![0u8; DIM as usize], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), DIM, Validity::NonNullable, 1)? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::NonNullable, + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + .into_array(); + + let decoded = TQDecode::try_new_array(tq)?.into_array(); + let result: PrimitiveArray = L2Norm::try_new_array(decoded, 1)? + .into_array() + .execute(&mut ctx)?; + + // Whatever path runs, the result is an `L2Norm` output and must be non-negative; in + // particular the kernel must NOT return the stored sign-negative value verbatim. The + // exact magnitude depends on which centroid the all-zero codes decode to; we only + // assert the sign and finiteness, which is what `L2Norm`'s contract pins. + assert_eq!(result.as_slice::().len(), 1); + let value = result.as_slice::()[0]; + assert!( + value.is_finite() && !value.is_sign_negative(), + "L2Norm result must be non-negative and finite (got {value})" + ); + Ok(()) +} + +/// Adversarial: a hand-constructed TurboQuant storage whose `codes` child has row validity +/// narrower than the outer struct's must fail the fast path the same way it fails the +/// canonical decode path (see `malformed::decode_rejects_child_masks_that_disagree_with_struct_validity`). +/// `parse_storage_norms_only` executes the `codes` FSL wrapper specifically to enforce this +/// invariant. +#[test] +fn l2_norm_over_tq_decode_rejects_codes_validity_narrower_than_struct() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: DIM, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0f32, 1.0, 1.0]), Validity::NonNullable) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * DIM as usize], Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes.into_array(), + DIM, + Validity::from_iter([true, false, true]), + 3, + )? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 3, + Validity::NonNullable, + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + .into_array(); + + let decoded = TQDecode::try_new_array(tq)?.into_array(); + let result: VortexResult = L2Norm::try_new_array(decoded, 3)? + .into_array() + .execute(&mut ctx); + assert!( + result.is_err(), + "kernel must reject codes-validity narrower than struct-validity" + ); + Ok(()) +}