Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 105 additions & 139 deletions Cargo.lock

Large diffs are not rendered by default.

72 changes: 52 additions & 20 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! This module provides a builder for creating LogicalPlans

use datafusion_common::metadata::check_metadata_with_storage_equal;
use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -55,9 +56,8 @@ use datafusion_common::file_options::file_type::FileType;
use datafusion_common::metadata::FieldMetadata;
use datafusion_common::{
Column, Constraints, DFSchema, DFSchemaRef, NullEquality, Result, ScalarValue,
TableReference, ToDFSchema, UnnestOptions, exec_err,
get_target_functional_dependencies, internal_datafusion_err, plan_datafusion_err,
plan_err,
TableReference, ToDFSchema, UnnestOptions, get_target_functional_dependencies,
internal_datafusion_err, plan_datafusion_err, plan_err,
};
use datafusion_expr_common::type_coercion::binary::type_union_resolution;

Expand Down Expand Up @@ -272,23 +272,59 @@ impl LogicalPlanBuilder {
let n_cols = values[0].len();
let mut fields = ValuesFields::new();
for j in 0..n_cols {
let field_type = schema.field(j).data_type();
let field_nullable = schema.field(j).is_nullable();
let field = schema.field(j);
let field_type = field.data_type();
let field_metadata = field.metadata();
let is_target_ext = field_metadata
.contains_key(arrow_schema::extension::EXTENSION_TYPE_NAME_KEY);
let column_name = format!("column {j}");
for row in values.iter() {
let value = &row[j];
let data_type = value.get_type(schema)?;

if !data_type.equals_datatype(field_type)
&& !can_cast_types(&data_type, field_type)
{
return exec_err!(
"type mismatch and can't cast to got {} and {}",
data_type,
field_type
);
let value_type = value.get_type(schema)?;
if value_type == DataType::Null {
continue;
}

if is_target_ext {
let value_meta = value.metadata(schema)?;
let value_meta_map = value_meta.to_hashmap();
let is_value_ext = value_meta_map
.contains_key(arrow_schema::extension::EXTENSION_TYPE_NAME_KEY);

if is_value_ext && value_type == *field_type {
check_metadata_with_storage_equal(
(&value_type, Some(&value_meta_map)),
(field_type, Some(field_metadata)),
&column_name,
" in VALUES list",
)?;
} else if !can_cast_types(&value_type, field_type) {
return plan_err!(
"Cannot cast {} to extension type at {}",
value_type,
column_name
);
}
} else {
// Optimized path for standard types
if !value_type.equals_datatype(field_type)
&& !can_cast_types(&value_type, field_type)
{
return plan_err!(
"Cannot cast {} to {} for {}",
value_type,
field_type,
column_name
);
}
}
}
fields.push(field_type.to_owned(), field_nullable);

fields.push_with_metadata(
field_type.clone(),
field.is_nullable(),
Some(FieldMetadata::new_from_field(field)),
);
}

Self::infer_inner(values, fields, schema)
Expand Down Expand Up @@ -1567,10 +1603,6 @@ impl ValuesFields {
Self::default()
}

/// Appends a field with the given data type and nullability, without metadata.
///
/// Shorthand for `push_with_metadata(data_type, nullable, None)`.
pub fn push(&mut self, data_type: DataType, nullable: bool) {
self.push_with_metadata(data_type, nullable, None);
}

pub fn push_with_metadata(
&mut self,
data_type: DataType,
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ recursive_protection = ["dep:recursive"]
# They are used for testing purposes only, so they are in the dev-dependencies section.
[dependencies]
arrow = { workspace = true }
arrow-schema = { workspace = true }
bigdecimal = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, features = ["sql"] }
Expand Down
27 changes: 27 additions & 0 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ use std::vec;

use crate::utils::make_decimal_type;
use arrow::datatypes::*;
use arrow_schema::extension::{ExtensionType, Json, Uuid};
use datafusion_common::TableReference;
use datafusion_common::config::SqlParserOptions;
use datafusion_common::datatype::{DataTypeExt, FieldExt};
use datafusion_common::error::add_possible_columns_to_diag;
use datafusion_common::types::{DFExtensionType, DFJson, DFUuid};
use datafusion_common::{DFSchema, DataFusionError, Result, not_impl_err, plan_err};
use datafusion_common::{
DFSchemaRef, Diagnostic, SchemaError, field_not_found, internal_err,
Expand Down Expand Up @@ -658,6 +660,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

// If no type_planner can handle this type, use the default conversion
match sql_type {
// Canonical Arrow extension types
SQLDataType::Uuid => {
let data_type = DataType::FixedSizeBinary(16);
let df_uuid = DFUuid::try_new(&data_type, Default::default())?;

// Re-construct the Arrow Uuid using the validated data from df_uuid
let arrow_uuid = Uuid::try_new(&df_uuid.storage_type().clone(), ())
.map_err(|e| DataFusionError::External(Box::new(e)))?;

Ok(Arc::new(
Field::new("", data_type, true).with_extension_type(arrow_uuid),
))
}
SQLDataType::JSON => {
let data_type = DataType::Utf8;
let df_json = DFJson::try_new(&data_type, Default::default())?;

let arrow_json =
Json::try_new(&df_json.storage_type().clone(), Default::default())
.map_err(|e| DataFusionError::External(Box::new(e)))?;

Ok(Arc::new(
Field::new("", data_type, true).with_extension_type(arrow_json),
))
}
SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => {
// Arrays may be multi-dimensional.
Ok(self.convert_data_type_to_field(inner_sql_type)?.into_list())
Expand Down
55 changes: 47 additions & 8 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use crate::utils::normalize_ident;

use arrow::datatypes::{Field, FieldRef, Fields};
use datafusion_common::error::_plan_err;
use datafusion_common::metadata::check_metadata_with_storage_equal;
use datafusion_common::parsers::CompressionTypeVariant;
use datafusion_common::{
Column, Constraint, Constraints, DFSchema, DFSchemaRef, DataFusionError, Result,
Expand All @@ -52,7 +53,7 @@ use datafusion_expr::{
LogicalPlan, LogicalPlanBuilder, OperateFunctionArg, PlanType, Prepare,
ResetVariable, SetVariable, SortExpr, Statement as PlanStatement, ToStringifiedPlan,
TransactionAccessMode, TransactionConclusion, TransactionEnd,
TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, cast, col,
TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, col,
};
use sqlparser::ast::{
self, BeginTransactionKind, CheckConstraint, ForeignKeyConstraint, IndexColumn,
Expand Down Expand Up @@ -529,14 +530,52 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.fields()
.iter()
.zip(input_fields)
.map(|(field, input_field)| {
cast(
col(input_field.name()),
field.data_type().clone(),
)
.alias(field.name())
.map(|(field, input_field)| -> Result<Expr> {
let input_meta = input_field.metadata();
let target_meta = field.metadata();

let is_target_ext = target_meta.contains_key(
arrow_schema::extension::EXTENSION_TYPE_NAME_KEY,
);
let is_input_ext = input_meta.contains_key(
arrow_schema::extension::EXTENSION_TYPE_NAME_KEY,
);

if !is_target_ext && !is_input_ext {
return if input_field.data_type()
== field.data_type()
{
Ok(col(input_field.name())
.alias(field.name()))
} else {
Ok(col(input_field.name())
.cast_to(field.data_type(), input_schema)?
.alias(field.name()))
};
}

let metadata_matches =
check_metadata_with_storage_equal(
(input_field.data_type(), Some(input_meta)),
(field.data_type(), Some(target_meta)),
"input",
"target",
)
.is_ok();

if metadata_matches {
Ok(col(input_field.name()).alias(field.name()))
} else {
Ok(Expr::Cast(
datafusion_expr::Cast::new_from_field(
Box::new(col(input_field.name())),
Arc::clone(field),
),
)
.alias(field.name()))
}
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>>>()?;

LogicalPlanBuilder::from(plan.clone())
.project(project_exprs)?
Expand Down
6 changes: 0 additions & 6 deletions datafusion/sql/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,6 @@ impl TypePlanner for CustomTypePlanner {
sql_type: &sqlparser::ast::DataType,
) -> Result<Option<FieldRef>> {
match sql_type {
sqlparser::ast::DataType::Uuid => Ok(Some(Arc::new(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🥳

Field::new("", DataType::FixedSizeBinary(16), true).with_metadata(
[("ARROW:extension:name".to_string(), "arrow.uuid".to_string())]
.into(),
),
))),
sqlparser::ast::DataType::Datetime(precision) => {
let precision = match precision {
Some(0) => TimeUnit::Second,
Expand Down
30 changes: 1 addition & 29 deletions datafusion/sqllogictest/src/test_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ use arrow::array::{
};
use arrow::buffer::ScalarBuffer;
use arrow::datatypes::{
DataType, Field, FieldRef, Fields, Schema, SchemaRef, TimeUnit, UInt32Type,
UnionFields,
DataType, Field, Fields, Schema, SchemaRef, TimeUnit, UInt32Type, UnionFields,
};
use arrow::record_batch::RecordBatch;
use datafusion::catalog::{
Expand All @@ -39,7 +38,6 @@ use datafusion::catalog::{
use datafusion::common::{DataFusionError, Result, not_impl_err};
use datafusion::functions::math::abs;
use datafusion::logical_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion::logical_expr::planner::TypePlanner;
use datafusion::logical_expr::{
ColumnarValue, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
Volatility, create_udf,
Expand All @@ -58,7 +56,6 @@ use datafusion::common::cast::as_float64_array;
use datafusion::execution::SessionStateBuilder;
use datafusion::execution::runtime_env::RuntimeEnv;
use log::info;
use sqlparser::ast;
use tempfile::TempDir;

/// Context for running tests
Expand All @@ -69,23 +66,6 @@ pub struct TestContext {
test_dir: Option<TempDir>,
}

/// Type planner that gives SQL `UUID` an Arrow-backed field carrying the
/// `arrow.uuid` extension name; it plans no other SQL type.
#[derive(Debug)]
struct SqlLogicTestTypePlanner;

impl TypePlanner for SqlLogicTestTypePlanner {
    /// Plans `sql_type` into a field.
    ///
    /// Returns `Ok(Some(field))` for `UUID`, where `field` is a nullable,
    /// unnamed `FixedSizeBinary(16)` field whose metadata contains
    /// `ARROW:extension:name = arrow.uuid`. Every other type yields `Ok(None)`.
    fn plan_type_field(&self, sql_type: &ast::DataType) -> Result<Option<FieldRef>> {
        // Anything that is not UUID is left unplanned.
        if !matches!(sql_type, ast::DataType::Uuid) {
            return Ok(None);
        }

        // Single metadata entry tagging the field with the extension name.
        let extension_name = (
            "ARROW:extension:name".to_string(),
            "arrow.uuid".to_string(),
        );
        let uuid_field = Field::new("", DataType::FixedSizeBinary(16), true)
            .with_metadata([extension_name].into());

        Ok(Some(Arc::new(uuid_field)))
    }
}

impl TestContext {
pub fn new(ctx: SessionContext) -> Self {
Self {
Expand Down Expand Up @@ -114,14 +94,6 @@ impl TestContext {
state_builder = state_builder.with_spark_features();
}

if matches!(
relative_path.file_name().and_then(|name| name.to_str()),
Some("cast_extension_type_metadata.slt")
) {
state_builder =
state_builder.with_type_planner(Arc::new(SqlLogicTestTypePlanner));
}

let state = state_builder.build();

let mut test_ctx = TestContext::new(SessionContext::new_with_state(state));
Expand Down
Loading