diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 24d5083621..a09a7da687 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -393,6 +393,7 @@ pub(crate) fn cast_array( } (Utf8View, Utf8) => Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?), (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?), + (Map(_, _), Utf8) => Ok(cast_map_to_string(array.as_map(), cast_options)?), (Struct(_), Struct(_)) => Ok(cast_struct_to_struct( array.as_struct(), &from_type, @@ -728,6 +729,68 @@ fn casts_struct_to_string( Ok(Arc::new(builder.finish())) } +fn cast_map_to_string( + array: &MapArray, + spark_cast_options: &SparkCastOptions, +) -> DataFusionResult { + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut str = String::with_capacity(array.len() * 16); + + let casted_keys = cast_array( + Arc::clone(array.keys()), + &DataType::Utf8, + spark_cast_options, + )?; + let casted_values = cast_array( + Arc::clone(array.values()), + &DataType::Utf8, + spark_cast_options, + )?; + let key_values = casted_keys + .as_any() + .downcast_ref::() + .expect("Casted keys should be StringArray"); + let value_values = casted_values + .as_any() + .downcast_ref::() + .expect("Casted values should be StringArray"); + + let offsets = array.offsets(); + for row_index in 0..array.len() { + if array.is_null(row_index) { + builder.append_null(); + } else { + str.clear(); + let start = offsets[row_index] as usize; + let end = offsets[row_index + 1] as usize; + + str.push('{'); + let mut first = true; + for idx in start..end { + if !first { + str.push_str(", "); + } + if key_values.is_null(idx) { + str.push_str(&spark_cast_options.null_string); + } else { + str.push_str(key_values.value(idx)); + } + str.push_str(" -> "); + if value_values.is_null(idx) { + str.push_str(&spark_cast_options.null_string); + } else { + str.push_str(value_values.value(idx)); + } + first = false; + } + str.push('}'); + builder.append_value(&str); + } + } + + Ok(Arc::new(builder.finish())) +} + impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( @@ -868,7 +931,8 @@ fn cast_binary_formatter(value: &[u8]) -> String { #[cfg(test)] mod tests { use super::*; - use arrow::array::{ListArray, NullArray, StringArray}; + use arrow::array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow::array::{ListArray, MapFieldNames, NullArray, StringArray}; use arrow::buffer::OffsetBuffer; use arrow::datatypes::TimestampMicrosecondType; use arrow::datatypes::{Field, Fields}; @@ -1025,6 +1089,41 @@ mod tests { } } + #[test] + fn test_cast_map_to_utf8() { + let mut map_builder = MapBuilder::new( + Some(MapFieldNames { + entry: "entries".into(), + key: "key".into(), + value: "value".into(), + }), + StringBuilder::new(), + Int32Builder::new(), + ); + + map_builder.keys().append_value("a"); + map_builder.values().append_value(1); + map_builder.keys().append_value("b"); + map_builder.values().append_null(); + map_builder.append(true).unwrap(); + + map_builder.append(true).unwrap(); + map_builder.append(false).unwrap(); + + let map_array: ArrayRef = Arc::new(map_builder.finish()); + let string_array = cast_array( + map_array, + &DataType::Utf8, + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), + ) + .unwrap(); + let string_array = string_array.as_string::(); + assert_eq!(3, string_array.len()); + assert_eq!(r#"{a -> 1, b -> null}"#, string_array.value(0)); + assert_eq!(r#"{}"#, string_array.value(1)); + assert!(string_array.is_null(2)); + } + #[test] fn test_cast_string_array_to_string() { let values_array = diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala index 1c990835bb..045770e877 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala @@ -79,6 +79,7 @@ object CometBatchKernelCodegen extends Logging with CometExprTraitShim { * supported when their children are. */ def isSupportedDataType(dt: DataType): Boolean = dt match { + case NullType => true case BooleanType | ByteType | ShortType | IntegerType | LongType => true case FloatType | DoubleType => true case _: DecimalType => true diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala index 09bfc52bd4..6da25a1261 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala @@ -587,6 +587,7 @@ private[codegen] object CometBatchKernelCodegenInput { * element or struct field. `idx` is the index/ordinal token (e.g. `"__i"` or `"3"`). */ private def elementGetterCall(dt: DataType, idx: String): String = dt match { + case NullType => "null" case BooleanType => s"getBoolean($idx)" case ByteType => s"getByte($idx)" case ShortType => s"getShort($idx)" @@ -672,6 +673,8 @@ private[codegen] object CometBatchKernelCodegenInput { if (elementNullable) " if (isNullAt(i)) return null;\n" else "" elemType match { + case NullType => + "" case BooleanType => s""" @Override | public boolean getBoolean(int i) { diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala index a26e3d1796..a6806e77ef 100644 --- a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala @@ -158,6 +158,7 @@ private[codegen] object CometBatchKernelCodegenOutput { /** Concrete Arrow vector class name for the output type, used to cast `outRaw` once. */ private def outputVectorClass(dataType: DataType): String = dataType match { + case NullType => classOf[NullVector].getName case BooleanType => classOf[BitVector].getName case ByteType => classOf[TinyIntVector].getName case ShortType => classOf[SmallIntVector].getName @@ -204,6 +205,8 @@ private[codegen] object CometBatchKernelCodegenOutput { dataType: DataType, ctx: CodegenContext, nested: Boolean = false): OutputEmit = dataType match { + case NullType => + OutputEmit("", "") case BooleanType => val set = if (nested) "setSafe" else "set" OutputEmit("", s"$targetVec.$set($idx, $source ? 1 : 0);") @@ -389,6 +392,7 @@ private[codegen] object CometBatchKernelCodegenOutput { */ private def emitSpecializedGetterExpr(target: String, idx: String, elemType: DataType): String = elemType match { + case NullType => "null" case BooleanType => s"$target.getBoolean($idx)" case ByteType => s"$target.getByte($idx)" case ShortType => s"$target.getShort($idx)" diff --git a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala index 42da809206..842e2391e8 100644 --- a/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala +++ b/spark/src/main/scala/org/apache/comet/expressions/CometCast.scala @@ -294,6 +294,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { Compatible() case DataTypes.BinaryType => Compatible() + case DataTypes.NullType => + Compatible() case StructType(fields) => for (field <- fields) { isSupported(field.dataType, DataTypes.StringType, timeZoneId, evalMode) match { @@ -305,6 +307,13 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim { } } Compatible() + case MapType(keyType, valueType, _) => + isSupported(keyType, DataTypes.StringType, timeZoneId, evalMode) match { + case Compatible(_) => + isSupported(valueType, DataTypes.StringType, timeZoneId, evalMode) + case other => + other + } case _ => unsupported(fromType, DataTypes.StringType) } } diff --git a/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql b/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql index 8b1d989ae7..77919170df 100644 --- a/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql +++ b/spark/src/test/resources/sql-tests/expressions/cast/cast_complex_types_to_string.sql @@ -149,8 +149,8 @@ SELECT cast(named_struct('a', named_struct('b', named_struct('c', 1, 'd', 'leaf' query SELECT cast(named_struct('s1', '', 's2', ' ', 's3', cast(null as string)) as string) --- Map-valued field: not supported, falls back to Spark. -query expect_fallback(to StringType is not supported) +-- Map-valued field: supported via recursive map -> string casting. +query SELECT cast(named_struct('m', map('k', 1)) as string) -- ---------------------------------------------------------------------------- @@ -270,69 +270,69 @@ SELECT cast(array(cast(1.5 as double), cast('NaN' as double), cast('-Infinity' a query SELECT cast(array(array(array(1, 2), array(3)), array(array(cast(null as int)))) as string) --- Array of map: not supported, falls back to Spark. -query expect_fallback(to StringType is not supported) +-- Array of map: supported via recursive map -> string casting. +query SELECT cast(array(map('k', 1)) as string) -- ---------------------------------------------------------------------------- -- Map → string -- ---------------------------------------------------------------------------- --- Comet does not implement map-to-string casts, so every map → string falls back to Spark. +-- Comet now implements map-to-string casts, including nested maps. -- Note: maps materialized through parquet have nondeterministic entry order, so map column -- tests use literal maps directly rather than reading from a parquet table. -- Map with string keys, int values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', 1, 'b', 2, 'c', 3) as string) -- Map with NULL values rendered as "null". -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', 1, 'b', cast(null as int), 'c', 3) as string) -- Map with int keys, string values. -query expect_fallback(Cast from MapType) +query SELECT cast(map(1, 'one', 2, 'two', 3, 'three') as string) -- Map with boolean values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('t', true, 'f', false, 'n', cast(null as boolean)) as string) -- Map with bigint values at min/max. -query expect_fallback(Cast from MapType) +query SELECT cast(map('max', 9223372036854775807, 'min', -9223372036854775808, 'zero', cast(0 as bigint)) as string) -- Map with decimal values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('pos', cast('1.234567890123456789' as decimal(38, 18)), 'neg', cast('-1.234567890123456789' as decimal(38, 18)), 'null', cast(null as decimal(38, 18))) as string) -- Map with date and timestamp values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', date '2024-01-15', 'b', date '1970-01-01', 'c', cast(null as date)) as string) -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', timestamp '2024-01-15 10:30:45', 'b', cast(null as timestamp)) as string) -- Map with binary values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', X'616263', 'b', X'', 'c', cast(null as binary)) as string) -- Map with float / double values: NaN / ±0 / ±Infinity / NULL. -query expect_fallback(Cast from MapType) +query SELECT cast(map('nan', cast('NaN' as float), 'neg0', cast(-0.0 as float), 'null', cast(null as float)) as string) -query expect_fallback(Cast from MapType) +query SELECT cast(map('nan', cast('NaN' as double), 'inf', cast('Infinity' as double), 'ninf', cast('-Infinity' as double), 'null', cast(null as double)) as string) -- Map with struct values: each value rendered as `{f1, f2, ...}`. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', named_struct('x', 1, 'y', 'first'), 'b', cast(null as struct)) as string) -- Map with array values. -query expect_fallback(Cast from MapType) +query SELECT cast(map('a', array(1, 2, 3), 'b', array(cast(null as int)), 'c', cast(null as array)) as string) -- Empty map. -query expect_fallback(Cast from MapType) +query SELECT cast(map() as string) -- NULL map: Spark constant-folds this to a literal NULL, so the cast never reaches Comet @@ -341,5 +341,5 @@ query SELECT cast(cast(null as map) as string) -- Map of map. -query expect_fallback(Cast from MapType) +query SELECT cast(map('outer', map('inner', 1)) as string) diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 482183e914..8986dcd894 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -1737,19 +1737,24 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { Incompatible(Some("There can be rounding differences"))) } - test("cast MapType propagates Unsupported from nested value cast") { + test("cast MapType to StringType is Compatible") { + val fromType = MapType(IntegerType, IntegerType) + assert( + CometCast.isSupported(fromType, DataTypes.StringType, None, CometEvalMode.LEGACY) == + Compatible()) + } + + test("cast MapType propagates supported nested value cast") { // Map> → Map: the inner Map → String - // cast is Unsupported, and that must propagate through the outer Map - // arm rather than being silently swallowed. + // cast is now supported and must propagate through the outer Map arm. val innerFrom = MapType(IntegerType, IntegerType) - val expectedMessage = s"Cast from $innerFrom to ${DataTypes.StringType} is not supported" assert( CometCast.isSupported( MapType(IntegerType, innerFrom), MapType(IntegerType, StringType), None, CometEvalMode.LEGACY) == - Unsupported(Some(expectedMessage))) + Compatible()) } test("cast ArrayType(DateType) to unsupported ArrayType falls back") {