diff --git a/datafusion/spark/src/function/string/substring.rs b/datafusion/spark/src/function/string/substring.rs index 21b22423298b7..1c26564e03993 100644 --- a/datafusion/spark/src/function/string/substring.rs +++ b/datafusion/spark/src/function/string/substring.rs @@ -16,8 +16,9 @@ // under the License. use arrow::array::{ - Array, ArrayBuilder, ArrayRef, AsArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringArrayType, StringViewBuilder, + Array, ArrayAccessor, ArrayBuilder, ArrayRef, AsArray, BinaryViewBuilder, + GenericBinaryBuilder, GenericStringBuilder, Int64Array, OffsetSizeTrait, + StringViewBuilder, }; use arrow::datatypes::DataType; use datafusion_common::arrow::datatypes::{Field, FieldRef}; @@ -56,6 +57,7 @@ impl Default for SparkSubstring { impl SparkSubstring { pub fn new() -> Self { let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + let binary = Coercion::new_exact(TypeSignatureClass::Binary); let int64 = Coercion::new_implicit( TypeSignatureClass::Native(logical_int64()), vec![TypeSignatureClass::Native(logical_int32())], @@ -70,6 +72,12 @@ impl SparkSubstring { int64.clone(), int64.clone(), ]), + TypeSignature::Coercible(vec![binary.clone(), int64.clone()]), + TypeSignature::Coercible(vec![ + binary.clone(), + int64.clone(), + int64.clone(), + ]), ], Volatility::Immutable, ) @@ -128,26 +136,65 @@ fn spark_substring(args: &[ArrayRef]) -> Result { }; match args[0].data_type() { - DataType::Utf8 => spark_substring_impl( - &args[0].as_string::(), + DataType::Utf8 => { + let array = args[0].as_string::(); + let is_ascii = enable_ascii_fast_path(&array, start_array, length_array); + spark_substring_generic( + &array, + start_array, + length_array, + GenericStringBuilder::::new(), + is_ascii, + ) + } + DataType::LargeUtf8 => { + let array = args[0].as_string::(); + let is_ascii = enable_ascii_fast_path(&array, start_array, length_array); + spark_substring_generic( + &array, + start_array, + length_array, + GenericStringBuilder::::new(), + is_ascii, + ) + } + DataType::Utf8View => { + let array = args[0].as_string_view(); + let is_ascii = enable_ascii_fast_path(&array, start_array, length_array); + spark_substring_generic( + &array, + start_array, + length_array, + StringViewBuilder::new(), + is_ascii, + ) + } + // Binary paths always use byte-level indexing, so `is_ascii` is irrelevant + // and set to `true` (its value is ignored by the `[u8]` impl of + // `SubstringItem`). + DataType::Binary => spark_substring_generic( + &args[0].as_binary::(), start_array, length_array, - GenericStringBuilder::::new(), + GenericBinaryBuilder::::new(), + true, ), - DataType::LargeUtf8 => spark_substring_impl( - &args[0].as_string::(), + DataType::LargeBinary => spark_substring_generic( + &args[0].as_binary::(), start_array, length_array, - GenericStringBuilder::::new(), + GenericBinaryBuilder::::new(), + true, ), - DataType::Utf8View => spark_substring_impl( - &args[0].as_string_view(), + DataType::BinaryView => spark_substring_generic( + &args[0].as_binary_view(), start_array, length_array, - StringViewBuilder::new(), + BinaryViewBuilder::new(), + true, ), other => exec_err!( - "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8 or LargeUtf8." + "Unsupported data type {other:?} for function spark_substring, expected Utf8View, Utf8, LargeUtf8, Binary, LargeBinary or BinaryView." ), } } @@ -173,43 +220,156 @@ fn spark_start_to_datafusion_start(start: i64, len: usize) -> i64 { } } -trait StringArrayBuilder: ArrayBuilder { - fn append_value(&mut self, val: &str); +trait SubstringItem { + /// Length used for Spark's negative-position adjustment. + /// For `str` this is characters (or bytes in ASCII mode); for `[u8]` it is + /// always byte count. + fn positional_len(&self, is_ascii: bool) -> usize; + + /// Converts Spark's 1-indexed adjusted start + optional length into a + /// byte range clamped to `[0, byte_len]`. + fn byte_range( + &self, + adjusted_start: i64, + len: Option, + is_ascii: bool, + ) -> Result<(usize, usize)>; + + fn byte_slice(&self, start: usize, end: usize) -> &Self; +} + +impl SubstringItem for str { + fn positional_len(&self, is_ascii: bool) -> usize { + if is_ascii { + self.len() + } else { + self.chars().count() + } + } + + fn byte_range( + &self, + adjusted_start: i64, + len: Option, + is_ascii: bool, + ) -> Result<(usize, usize)> { + get_true_start_end(self, adjusted_start, len, is_ascii) + } + + fn byte_slice(&self, start: usize, end: usize) -> &Self { + &self[start..end] + } +} + +impl SubstringItem for [u8] { + fn positional_len(&self, _is_ascii: bool) -> usize { + self.len() + } + + fn byte_range( + &self, + adjusted_start: i64, + len: Option, + _is_ascii: bool, + ) -> Result<(usize, usize)> { + let byte_len = self.len(); + let start0 = adjusted_start.saturating_sub(1); + let end0 = match len { + Some(l) => start0.saturating_add(l), + None => byte_len as i64, + }; + let byte_len_i64 = byte_len as i64; + Ok(( + start0.clamp(0, byte_len_i64) as usize, + end0.clamp(0, byte_len_i64) as usize, + )) + } + + fn byte_slice(&self, start: usize, end: usize) -> &Self { + &self[start..end] + } +} + +trait SubstringBuilder: ArrayBuilder { + type Item: SubstringItem + ?Sized; + fn append_value(&mut self, val: &Self::Item); fn append_null(&mut self); + /// Spark's semantic "empty" for this builder's item type, used for the + /// negative-length short-circuit. + fn append_empty(&mut self); } -impl StringArrayBuilder for GenericStringBuilder { +impl SubstringBuilder for GenericStringBuilder { + type Item = str; fn append_value(&mut self, val: &str) { GenericStringBuilder::append_value(self, val); } fn append_null(&mut self) { GenericStringBuilder::append_null(self); } + fn append_empty(&mut self) { + GenericStringBuilder::append_value(self, ""); + } } -impl StringArrayBuilder for StringViewBuilder { +impl SubstringBuilder for StringViewBuilder { + type Item = str; fn append_value(&mut self, val: &str) { StringViewBuilder::append_value(self, val); } fn append_null(&mut self) { StringViewBuilder::append_null(self); } + fn append_empty(&mut self) { + StringViewBuilder::append_value(self, ""); + } } -fn spark_substring_impl<'a, V, B>( - string_array: &V, +impl SubstringBuilder for GenericBinaryBuilder { + type Item = [u8]; + fn append_value(&mut self, val: &[u8]) { + GenericBinaryBuilder::append_value(self, val); + } + fn append_null(&mut self) { + GenericBinaryBuilder::append_null(self); + } + fn append_empty(&mut self) { + GenericBinaryBuilder::append_value(self, &[]); + } +} + +impl SubstringBuilder for BinaryViewBuilder { + type Item = [u8]; + fn append_value(&mut self, val: &[u8]) { + BinaryViewBuilder::append_value(self, val); + } + fn append_null(&mut self) { + BinaryViewBuilder::append_null(self); + } + fn append_empty(&mut self) { + BinaryViewBuilder::append_value(self, []); + } +} + +/// Unified implementation of Spark's `substring`, generic over the source +/// array (`StringArrayType`/`BinaryArrayType` via `ArrayAccessor`) and its +/// corresponding builder. Per-row indexing semantics are delegated to +/// [`SubstringItem`], which differs between `str` (char-aware when +/// `is_ascii` is false) and `[u8]` (always byte-level). +fn spark_substring_generic<'a, Source, Item, Builder>( + array: &Source, start_array: &Int64Array, length_array: Option<&Int64Array>, - mut builder: B, + mut builder: Builder, + is_ascii: bool, ) -> Result where - V: StringArrayType<'a>, - B: StringArrayBuilder, + Source: ArrayAccessor, + Item: SubstringItem + ?Sized + 'a, + Builder: SubstringBuilder, { - let is_ascii = enable_ascii_fast_path(string_array, start_array, length_array); - - for i in 0..string_array.len() { - if string_array.is_null(i) || start_array.is_null(i) { + for i in 0..array.len() { + if array.is_null(i) || start_array.is_null(i) { builder.append_null(); continue; } @@ -221,30 +381,23 @@ where continue; } - let string = string_array.value(i); + let value = array.value(i); let start = start_array.value(i); let len_opt = length_array.map(|arr| arr.value(i)); - // Spark: negative length returns empty string + // Spark: negative length yields an empty value if let Some(len) = len_opt && len < 0 { - builder.append_value(""); + builder.append_empty(); continue; } - let string_len = if is_ascii { - string.len() - } else { - string.chars().count() - }; - - let adjusted_start = spark_start_to_datafusion_start(start, string_len); - + let positional_len = value.positional_len(is_ascii); + let adjusted_start = spark_start_to_datafusion_start(start, positional_len); let (byte_start, byte_end) = - get_true_start_end(string, adjusted_start, len_opt, is_ascii)?; - let substr = &string[byte_start..byte_end]; - builder.append_value(substr); + value.byte_range(adjusted_start, len_opt, is_ascii)?; + builder.append_value(value.byte_slice(byte_start, byte_end)); } Ok(builder.finish()) diff --git a/datafusion/sqllogictest/test_files/spark/string/substring.slt b/datafusion/sqllogictest/test_files/spark/string/substring.slt index 9189ce50c2ffe..1670de2c822b3 100644 --- a/datafusion/sqllogictest/test_files/spark/string/substring.slt +++ b/datafusion/sqllogictest/test_files/spark/string/substring.slt @@ -207,4 +207,164 @@ SELECT substr(column1, -10, 3) FROM VALUES ('abc'::string) ---- -(empty) \ No newline at end of file +(empty) + +# ---------------------------------------------------------------------- +# Binary / LargeBinary / BinaryView support +# ---------------------------------------------------------------------- + +# 'Spark SQL' bytes => 53 70 61 72 6b 20 53 51 4c +# byte 1..9 => 's' 'p' 'a' 'r' 'k' ' ' 'S' 'Q' 'L' + +# Binary: basic substring with only pos +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), 5::int); +---- +6b2053514c + +# Binary: substring with pos and length +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), 3::int, 1::int); +---- +61 + +# Binary: zero pos is treated as 1 +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), 0::int); +---- +537061726b2053514c + +# Binary: negative pos counts from the end +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), -3::int); +---- +53514c + +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), -3::int, 2::int); +---- +5351 + +# Binary: length exceeding input length +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), 2::int, 700::int); +---- +7061726b2053514c + +# Binary: pos beyond end returns empty binary +query I +SELECT character_length(substring(arrow_cast('Spark SQL', 'Binary'), 30::int)); +---- +0 + +# Binary: large negative pos returns the full value +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), -30::int); +---- +537061726b2053514c + +# Binary: negative length returns empty binary +query I +SELECT character_length(substring(arrow_cast('Spark SQL', 'Binary'), 3::int, -1::int)); +---- +0 + +# Binary: zero length returns empty binary +query I +SELECT character_length(substring(arrow_cast('Spark SQL', 'Binary'), 3::int, 0::int)); +---- +0 + +# Binary: NULL handling +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), NULL::int); +---- +NULL + +query ? +SELECT substring(arrow_cast(NULL, 'Binary'), 5::int); +---- +NULL + +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), NULL::int, 1::int); +---- +NULL + +query ? +SELECT substring(arrow_cast('Spark SQL', 'Binary'), 3::int, NULL::int); +---- +NULL + +# Binary: return type matches input type +query ?T +SELECT substring(arrow_cast('Spark SQL', 'Binary'), 3::int, 2::int), + arrow_typeof(substring(arrow_cast('Spark SQL', 'Binary'), 3::int, 2::int)); +---- +6172 Binary + +# LargeBinary: return type is preserved +query ?T +SELECT substring(arrow_cast('Spark SQL', 'LargeBinary'), 3::int, 2::int), + arrow_typeof(substring(arrow_cast('Spark SQL', 'LargeBinary'), 3::int, 2::int)); +---- +6172 LargeBinary + +# BinaryView: return type is preserved +query ?T +SELECT substring(arrow_cast('Spark SQL', 'BinaryView'), 3::int, 2::int), + arrow_typeof(substring(arrow_cast('Spark SQL', 'BinaryView'), 3::int, 2::int)); +---- +6172 BinaryView + +# Binary with non-UTF-8 bytes: byte-level slicing is preserved +# X'ff00aabb1122' has 6 bytes and is typed as Binary by default +query ? +SELECT substring(X'ff00aabb1122', 2::int, 3::int); +---- +00aabb + +query ? +SELECT substring(X'ff00aabb1122', -2::int, 2::int); +---- +1122 + +# substr alias works with binary input +query ? +SELECT substr(arrow_cast('Spark SQL', 'Binary'), 5::int); +---- +6b2053514c + +query ? +SELECT substr(arrow_cast('Spark SQL', 'LargeBinary'), 3::int, 1::int); +---- +61 + +# Array input (Binary column) produces per-row byte-level substrings. +# The second column is the byte length of the result so empty-binary +# rows remain unambiguous in the expected output. +query ?I +SELECT substring(column1, column2, column3), + character_length(substring(column1, column2, column3)) +FROM VALUES +(arrow_cast('Spark SQL', 'Binary'), -3::int, 2::int), +(arrow_cast('Spark SQL', 'Binary'), 3::int, 1::int), +(arrow_cast('Spark SQL', 'Binary'), 3::int, 700::int), +(arrow_cast('Spark SQL', 'Binary'), 3::int, -1::int), +(arrow_cast('Spark SQL', 'Binary'), 3::int, 0::int), +(arrow_cast('Spark SQL', 'Binary'), 300::int, 3::int), +(arrow_cast('Spark SQL', 'Binary'), -300::int, 3::int), +(arrow_cast(NULL, 'Binary'), 3::int, 1::int), +(arrow_cast('Spark SQL', 'Binary'), NULL::int, 1::int), +(arrow_cast('Spark SQL', 'Binary'), 3::int, NULL::int); +---- +5351 2 +61 1 +61726b2053514c 7 + 0 + 0 + 0 + 0 +NULL NULL +NULL NULL +NULL NULL