Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 138 additions & 43 deletions spark/src/main/scala/org/apache/comet/serde/datetime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package org.apache.comet.serde

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, FromUTCTimestamp, GetDateField, GetTimestamp, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.catalyst.expressions.{AddMonths, Attribute, ConvertTimezone, DateAdd, DateDiff, DateFormatClass, DateFromUnixDate, DateSub, DayOfMonth, DayOfWeek, DayOfYear, Days, Expression, FromUTCTimestamp, GetDateField, GetTimestamp, Hour, Hours, LastDay, Literal, MakeDate, MakeTimestamp, MicrosToTimestamp, MillisToTimestamp, Minute, Month, MonthsBetween, NextDay, Quarter, Second, SecondsToTimestamp, ToUnixTimestamp, ToUTCTimestamp, TruncDate, TruncTimestamp, UnixDate, UnixMicros, UnixMillis, UnixSeconds, UnixTimestamp, WeekDay, WeekOfYear, Year}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DateType, DoubleType, FloatType, IntegerType, LongType, StringType, TimestampNTZType, TimestampType}
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -32,6 +32,7 @@ import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.CometGetDateField.CometGetDateField
import org.apache.comet.serde.ExprOuterClass.Expr
import org.apache.comet.serde.QueryPlanSerde._
import org.apache.comet.shims.CometTypeShim

private object CometGetDateField extends Enumeration {
type CometGetDateField = Value
Expand Down Expand Up @@ -310,11 +311,28 @@ object CometSecond extends CometExpressionSerde[Second] with CodegenDispatchFall
}
}

/**
 * Shared collation helpers for the datetime expression serdes.
 *
 * Mixes in [[CometTypeShim]] to reach `hasNonDefaultStringCollation` and `hasCollationSupport`.
 */
private[serde] object DatetimeCollation extends CometTypeShim {

  /** Builds the incompatibility message for `functionName`, ending with the tracking issue link. */
  def reason(functionName: String): String = {
    val message = s"$functionName does not support non-UTF8_BINARY collations "
    val issueLink = "(https://github.com/apache/datafusion-comet/issues/4646)"
    message + issueLink
  }

  /**
   * The collation reason for `functionName` wrapped in a `Seq` when `hasCollationSupport` is
   * true; an empty `Seq` otherwise.
   */
  def incompatibleReasons(functionName: String): Seq[String] =
    if (!hasCollationSupport) Seq.empty
    else Seq(reason(functionName))

  /**
   * True when at least one direct child of `expr` has a data type flagged by
   * `hasNonDefaultStringCollation`. Only `expr.children` are inspected.
   */
  def hasNonDefaultCollation(expr: Expression): Boolean =
    expr.children.exists { child => hasNonDefaultStringCollation(child.dataType) }
}

object CometUnixTimestamp extends CometExpressionSerde[UnixTimestamp] {

private val collationReason = DatetimeCollation.reason("unix_timestamp")

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only `DateType`, `TimestampType`, and `TimestampNTZType` inputs are supported.")

override def getIncompatibleReasons(): Seq[String] =
DatetimeCollation.incompatibleReasons("unix_timestamp")

private def isSupportedInputType(expr: UnixTimestamp): Boolean = {
expr.children.head.dataType match {
case TimestampType | DateType => true
Expand All @@ -324,7 +342,9 @@ object CometUnixTimestamp extends CometExpressionSerde[UnixTimestamp] {
}

override def getSupportLevel(expr: UnixTimestamp): SupportLevel = {
if (isSupportedInputType(expr)) {
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else if (isSupportedInputType(expr)) {
Compatible()
} else {
val inputType = expr.children.head.dataType
Expand Down Expand Up @@ -420,11 +440,19 @@ object CometConvertTimezone
extends CometExpressionSerde[ConvertTimezone]
with CodegenDispatchFallback {

override def getSupportLevel(expr: ConvertTimezone): SupportLevel =
Incompatible(Some(UTCTimestampSerde.tzParseIncompatReason))
private val collationReason = DatetimeCollation.reason("convert_timezone")

/**
 * Always reports `Incompatible`; only the reason differs. A non-default collation on a child of
 * `expr` takes precedence over the timezone-parsing reason.
 */
override def getSupportLevel(expr: ConvertTimezone): SupportLevel = {
  val why =
    if (DatetimeCollation.hasNonDefaultCollation(expr)) collationReason
    else UTCTimestampSerde.tzParseIncompatReason
  Incompatible(Some(why))
}

override def getIncompatibleReasons(): Seq[String] =
Seq(UTCTimestampSerde.tzParseIncompatReason)
Seq(UTCTimestampSerde.tzParseIncompatReason) ++
DatetimeCollation.incompatibleReasons("convert_timezone")

override def convert(
expr: ConvertTimezone,
Expand All @@ -446,6 +474,18 @@ object CometNextDay extends CometExpressionSerde[NextDay] {
* `dayOfWeek` rather than returning NULL. The resolved flag is passed to native via the
* `ScalarFunc.fail_on_error` field.
*/
// NOTE(review): the Scaladoc block that ends just above mentions `ScalarFunc.fail_on_error`,
// which suggests it was written for `convert` below rather than for these collation members.
// Confirm, and consider moving these members above that Scaladoc.
private val collationReason = DatetimeCollation.reason("next_day")

// Collation-related incompatibility reasons for next_day, as supplied by DatetimeCollation.
override def getIncompatibleReasons(): Seq[String] =
  DatetimeCollation.incompatibleReasons("next_day")

// Incompatible when a child of `expr` has a non-default string collation; Compatible otherwise.
override def getSupportLevel(expr: NextDay): SupportLevel =
  if (!DatetimeCollation.hasNonDefaultCollation(expr)) Compatible()
  else Incompatible(Some(collationReason))

override def convert(expr: NextDay, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = {
val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding))
val optExpr = scalarFunctionExprToProtoWithReturnType(
Expand Down Expand Up @@ -527,28 +567,35 @@ object CometTruncDate extends CometExpressionSerde[TruncDate] with CodegenDispat
val supportedFormats: Seq[String] =
Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")

private val collationReason = DatetimeCollation.reason("trunc")

private val nonLiteralFormatIncompatReason: String =
"Non-literal format strings will throw an exception instead of returning NULL"

private def unsupportedFormatReason(fmt: Any): String =
s"Format $fmt is not supported. Only the following formats are supported: " +
supportedFormats.mkString(", ")

override def getIncompatibleReasons(): Seq[String] = Seq(nonLiteralFormatIncompatReason)
override def getIncompatibleReasons(): Seq[String] =
Seq(nonLiteralFormatIncompatReason) ++ DatetimeCollation.incompatibleReasons("trunc")

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only the following formats are supported: " + supportedFormats.mkString(", "))

override def getSupportLevel(expr: TruncDate): SupportLevel = {
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
Compatible()
} else {
Unsupported(Some(unsupportedFormatReason(fmt)))
}
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
Compatible()
} else {
Unsupported(Some(unsupportedFormatReason(fmt)))
}
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
}
}
}

Expand Down Expand Up @@ -591,6 +638,8 @@ object CometTruncTimestamp
"millisecond",
"microsecond")

private val collationReason = DatetimeCollation.reason("date_trunc")

private val nonUtcIncompatReason: String =
"Produces incorrect results when used with non-UTC timezones. Compatible when timezone is" +
" UTC. (https://github.com/apache/datafusion-comet/issues/2649)"
Expand All @@ -603,27 +652,32 @@ object CometTruncTimestamp
supportedFormats.mkString(", ")

override def getIncompatibleReasons(): Seq[String] =
Seq(nonUtcIncompatReason, nonLiteralFormatIncompatReason)
Seq(nonUtcIncompatReason, nonLiteralFormatIncompatReason) ++
DatetimeCollation.incompatibleReasons("date_trunc")

override def getUnsupportedReasons(): Seq[String] = Seq(
"Only the following formats are supported: " + supportedFormats.mkString(", "))

override def getSupportLevel(expr: TruncTimestamp): SupportLevel = {
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
if (isUtc) {
Compatible()
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
val timezone = expr.timeZoneId.getOrElse("UTC")
val isUtc = timezone == "UTC" || timezone == "Etc/UTC"
expr.format match {
case Literal(fmt: UTF8String, _) =>
if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
if (isUtc) {
Compatible()
} else {
Incompatible(Some(nonUtcIncompatReason))
}
} else {
Incompatible(Some(nonUtcIncompatReason))
Unsupported(Some(unsupportedFormatReason(fmt)))
}
} else {
Unsupported(Some(unsupportedFormatReason(fmt)))
}
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
case _ =>
Incompatible(Some(nonLiteralFormatIncompatReason))
}
}
}

Expand Down Expand Up @@ -667,10 +721,9 @@ object CometTruncTimestamp
* by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When that flag is disabled the operator
* falls back to Spark.
*/
object CometDateFormat extends CometExpressionSerde[DateFormatClass] with NativeOptInAvailable {

override def getIncompatibleReasons(): Seq[String] =
Seq("Non-UTC timezones may produce different results than Spark")
object CometDateFormat
extends CometExpressionSerde[DateFormatClass]
with CodegenDispatchFallback {

/**
* Mapping from Spark SimpleDateFormat patterns to strftime patterns. Only formats in this map
Expand Down Expand Up @@ -708,6 +761,12 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] with Native
// ISO formats
"yyyy-MM-dd'T'HH:mm:ss" -> "%Y-%m-%dT%H:%M:%S")

private val collationReason = DatetimeCollation.reason("date_format")

override def getIncompatibleReasons(): Seq[String] =
Seq("Non-UTC timezones may produce different results than Spark") ++
DatetimeCollation.incompatibleReasons("date_format")

// Returns true when the format literal is in the native-format whitelist.
private def nativeApplicable(expr: DateFormatClass): Boolean = expr.right match {
case Literal(fmt: UTF8String, _) => supportedFormats.contains(fmt.toString)
Expand All @@ -720,15 +779,19 @@ object CometDateFormat extends CometExpressionSerde[DateFormatClass] with Native
}

override def getSupportLevel(expr: DateFormatClass): SupportLevel = {
// Show the opt-in hint only when: native is applicable, the config is OFF, and native is not
// already running due to UTC timezone. When isUtc is true, native already runs regardless of
// the config, so the hint would be misleading.
val isExprAllowIncompat = CometConf.isExprAllowIncompat(getExprConfigName(expr))
if (nativeApplicable(expr) && !isUtc(expr) && !isExprAllowIncompat) {
Compatible(nativeOptIn =
Some(NativeOptIn(CometConf.getExprAllowIncompatConfigKey(getExprConfigName(expr)))))
if (DatetimeCollation.hasNonDefaultCollation(expr)) {
Incompatible(Some(collationReason))
} else {
Compatible()
// Show the opt-in hint only when: native is applicable, the config is OFF, and native is not
// already running due to UTC timezone. When isUtc is true, native already runs regardless of
// the config, so the hint would be misleading.
val isExprAllowIncompat = CometConf.isExprAllowIncompat(getExprConfigName(expr))
if (nativeApplicable(expr) && !isUtc(expr) && !isExprAllowIncompat) {
Compatible(nativeOptIn =
Some(NativeOptIn(CometConf.getExprAllowIncompatConfigKey(getExprConfigName(expr)))))
} else {
Compatible()
}
}
}

Expand Down Expand Up @@ -868,7 +931,23 @@ object CometAddMonths extends CometCodegenDispatch[AddMonths]

object CometMonthsBetween extends CometCodegenDispatch[MonthsBetween]

object CometMakeTimestamp extends CometCodegenDispatch[MakeTimestamp]
/**
 * Serde for `MakeTimestamp`. Reports `Incompatible` when a child of the expression has a
 * non-default string collation, and `Compatible` otherwise.
 */
object CometMakeTimestamp
    extends CometCodegenDispatch[MakeTimestamp]
    with CodegenDispatchFallback {

  private val collationReason = DatetimeCollation.reason("make_timestamp")

  /** Collation-related incompatibility reasons, as supplied by [[DatetimeCollation]]. */
  override def getIncompatibleReasons(): Seq[String] =
    DatetimeCollation.incompatibleReasons("make_timestamp")

  override def getSupportLevel(expr: MakeTimestamp): SupportLevel =
    if (!DatetimeCollation.hasNonDefaultCollation(expr)) Compatible()
    else Incompatible(Some(collationReason))
}

object CometMicrosToTimestamp extends CometCodegenDispatch[MicrosToTimestamp]

Expand All @@ -880,6 +959,22 @@ object CometUnixMillis extends CometCodegenDispatch[UnixMillis]

object CometUnixMicros extends CometCodegenDispatch[UnixMicros]

object CometToUnixTimestamp extends CometCodegenDispatch[ToUnixTimestamp]
/**
 * Serde for `ToUnixTimestamp`. Reports `Incompatible` when a child of the expression has a
 * non-default string collation, and `Compatible` otherwise.
 */
object CometToUnixTimestamp
    extends CometCodegenDispatch[ToUnixTimestamp]
    with CodegenDispatchFallback {

  private val collationReason = DatetimeCollation.reason("to_unix_timestamp")

  /** Collation-related incompatibility reasons, as supplied by [[DatetimeCollation]]. */
  override def getIncompatibleReasons(): Seq[String] =
    DatetimeCollation.incompatibleReasons("to_unix_timestamp")

  override def getSupportLevel(expr: ToUnixTimestamp): SupportLevel =
    if (!DatetimeCollation.hasNonDefaultCollation(expr)) Compatible()
    else Incompatible(Some(collationReason))
}

object CometGetTimestamp extends CometCodegenDispatch[GetTimestamp]
17 changes: 14 additions & 3 deletions spark/src/main/scala/org/apache/comet/serde/unixtime.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,23 @@ import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithFa
// https://github.com/apache/datafusion/issues/16594
object CometFromUnixTime extends CometExpressionSerde[FromUnixTime] with CodegenDispatchFallback {

override def getIncompatibleReasons(): Seq[String] = Seq(
private val collationReason = DatetimeCollation.reason("from_unixtime")

private val formatReason =
"Only supports the default datetime format pattern `yyyy-MM-dd HH:mm:ss`." +
" DataFusion's valid timestamp range differs from Spark" +
" (https://github.com/apache/datafusion/issues/16594)")
" (https://github.com/apache/datafusion/issues/16594)"

override def getIncompatibleReasons(): Seq[String] =
Seq(formatReason) ++ DatetimeCollation.incompatibleReasons("from_unixtime")

override def getSupportLevel(expr: FromUnixTime): SupportLevel = Incompatible(None)
/**
 * Always reports `Incompatible`; only the reason differs. A non-default collation on a child of
 * `expr` takes precedence over the format/timestamp-range reason.
 */
override def getSupportLevel(expr: FromUnixTime): SupportLevel = {
  val why =
    if (DatetimeCollation.hasNonDefaultCollation(expr)) collationReason
    else formatReason
  Incompatible(Some(why))
}

override def convert(
expr: FromUnixTime,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ trait CometTypeShim {
@nowarn // Spark 4 feature; stubbed to false in Spark 3.x for compatibility.
def hasNonDefaultStringCollation(dt: DataType): Boolean = false

@nowarn // Spark 4 feature; collation does not exist in Spark 3.x.
def hasCollationSupport: Boolean = false

@nowarn // Spark 4 feature; Variant shredding doesn't exist in Spark 3.x.
def isVariantStruct(s: StructType): Boolean = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,6 @@ trait CometTypeShim {

def isTimeType(dt: DataType): Boolean =
dt.getClass.getSimpleName.startsWith("TimeType")

def hasCollationSupport: Boolean = true
}
Loading
Loading