diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/InsertSummary.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/InsertSummary.java new file mode 100644 index 0000000000000..40f41bf238447 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/InsertSummary.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import org.apache.spark.annotation.Evolving; + +/** + * Provides an informational summary of the INSERT operation that produced this write. + * + * @since 4.2.0 + */ +@Evolving +public interface InsertSummary extends WriteSummary { + + /** + * Returns the number of inserted rows, or -1 if the count is unknown. + */ + long numInsertedRows(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala index 168f30623ee4c..8281f89bd2e8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, Attribu import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, JoinType, LeftAnti, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, DeleteAction, Filter, HintInfo, InsertAction, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta} +import org.apache.spark.sql.catalyst.plans.logical.{DeleteAction, Filter, HintInfo, InsertAction, InsertOnlyMerge, Join, JoinHint, LogicalPlan, MergeAction, MergeIntoTable, MergeRows, NO_BROADCAST_AND_REPLICATION, Project, ReplaceData, UpdateAction, WriteDelta} import org.apache.spark.sql.catalyst.plans.logical.MergeRows.{Copy, Delete, Discard, Insert, Instruction, Keep, ROW_ID, Split, Update} import org.apache.spark.sql.catalyst.util.RowDeltaUtils.{COPY_OPERATION, INSERT_OPERATION, OPERATION_COLUMN, UPDATE_OPERATION} import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations @@ -73,7 +73,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper } val project = Project(projectList, joinPlan) - AppendData.byPosition(r, project) + InsertOnlyMerge(r, project) case _ => m @@ -114,7 +114,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper output =
generateExpandOutput(r.output, outputs), joinPlan) - AppendData.byPosition(r, mergeRows) + InsertOnlyMerge(r, mergeRows) case _ => m diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 7b657ce34df45..b1ab46ee94817 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -194,6 +194,26 @@ object AppendData { } } +/** + * Append data to an existing table as the result of an insert-only MERGE rewrite. + * + * Functionally equivalent to [[AppendData]] but distinguishes the row-level MERGE rewrite path. + */ +case class InsertOnlyMerge( + table: NamedRelation, + query: LogicalPlan, + write: Option[Write] = None, + analyzedQuery: Option[LogicalPlan] = None) extends V2WriteCommand with TransactionalWrite { + override val isByName: Boolean = false + override val withSchemaEvolution: Boolean = false + override val writePrivileges: Set[TableWritePrivilege] = Set(TableWritePrivilege.INSERT) + override def withNewQuery(newQuery: LogicalPlan): InsertOnlyMerge = copy(query = newQuery) + override def withNewTable(newTable: NamedRelation): InsertOnlyMerge = copy(table = newTable) + override def storeAnalyzedQuery(): Command = copy(analyzedQuery = Some(query)) + override protected def withNewChildInternal(newChild: LogicalPlan): InsertOnlyMerge = + copy(query = newChild) +} + /** * Overwrite data matching a filter in an existing table. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/InsertSummaryImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/InsertSummaryImpl.scala new file mode 100644 index 0000000000000..97c2e082c2573 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/write/InsertSummaryImpl.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write + +/** + * Implementation of [[InsertSummary]] that provides a summary of an INSERT operation.
+ */ +private[sql] case class InsertSummaryImpl(numInsertedRows: Long) extends InsertSummary { +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 53e57153030b7..f582f3e408cb6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -788,30 +788,43 @@ abstract class InMemoryBaseTable( } override def abort(messages: Array[WriterCommitMessage]): Unit = {} + + protected def doCommit(messages: Array[WriterCommitMessage]): Unit + + override final def commit(messages: Array[WriterCommitMessage]): Unit = { + doCommit(messages) + commits += Commit(Instant.now().toEpochMilli) + } + + override final def commit( + messages: Array[WriterCommitMessage], + summary: WriteSummary): Unit = { + doCommit(messages) + commits += Commit(Instant.now().toEpochMilli, writeSummary = Some(summary)) + } } class Append(val info: LogicalWriteInfo) extends TestBatchWrite { - - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + override protected def doCommit( + messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { withData(messages.map(_.asInstanceOf[BufferedRows])) - commits += Commit(Instant.now().toEpochMilli) } } class DynamicOverwrite(val info: LogicalWriteInfo) extends TestBatchWrite { - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + override protected def doCommit( + messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val newData = messages.map(_.asInstanceOf[BufferedRows]) dataMap --= newData.flatMap(_.rows.map(getKey)) withData(newData) - commits += Commit(Instant.now().toEpochMilli) } } class TruncateAndAppend(val info: LogicalWriteInfo) extends TestBatchWrite { - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + override protected def doCommit( + messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { dataMap.clear() withData(messages.map(_.asInstanceOf[BufferedRows])) - commits += Commit(Instant.now().toEpochMilli) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 6778f4489e459..5c0bc0b143f3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector.catalog -import java.time.Instant import java.util import org.apache.spark.sql.catalyst.InternalRow @@ -26,7 +25,7 @@ import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{FieldReference, LogicalExpressions, NamedReference, SortDirection, SortOrder, Transform} import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} -import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, 
RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage, WriteSummary} +import org.apache.spark.sql.connector.write.{BatchWrite, DeltaBatchWrite, DeltaWrite, DeltaWriteBuilder, DeltaWriter, DeltaWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, RequiresDistributionAndOrdering, RowLevelOperation, RowLevelOperationBuilder, RowLevelOperationInfo, SupportsDelta, Write, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.RowLevelOperation.Command import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -143,18 +142,11 @@ class InMemoryRowLevelOperationTable private ( override def description(): String = "InMemoryPartitionReplaceOperation" } - abstract class RowLevelOperationBatchWrite extends TestBatchWrite { - - override def commit(messages: Array[WriterCommitMessage], metrics: WriteSummary): Unit = { - commit(messages) - commits += Commit(Instant.now().toEpochMilli, Some(metrics)) - } - } - private case class PartitionBasedReplaceData(scan: InMemoryBatchScan) - extends RowLevelOperationBatchWrite { + extends TestBatchWrite { - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + override protected def doCommit( + messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val newData = messages.map(_.asInstanceOf[BufferedRows]) val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows) val readPartitions = readRows.map(r => getKey(r, schema)).distinct @@ -216,12 +208,12 @@ class InMemoryRowLevelOperationTable private ( } } - private object TestDeltaBatchWrite extends RowLevelOperationBatchWrite with DeltaBatchWrite{ + private object TestDeltaBatchWrite extends TestBatchWrite with DeltaBatchWrite { override def createBatchWriterFactory(info: PhysicalWriteInfo): DeltaWriterFactory = { new DeltaBufferedRowsWriterFactory(CatalogV2Util.v2ColumnsToStructType(columns())) } - override def commit(messages: Array[WriterCommitMessage]): Unit = { + override protected def doCommit(messages: Array[WriterCommitMessage]): Unit = { val newData = messages.map(_.asInstanceOf[BufferedRows]) withDeletes(newData) withData(newData, columns()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 15ed4136dbda8..66db9c18fa981 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -205,7 +205,8 @@ class InMemoryTable( private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + override protected def doCommit( + messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = InMemoryTable.filtersToKeys( dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, filters) dataMap --= deleteKeys diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala index f2827faf59435..e9d73d0f9fe1e 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala @@ -140,7 +140,8 @@ class InMemoryTableWithV2Filter( private class Overwrite(predicates: Array[Predicate]) extends TestBatchWrite { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + override protected def doCommit( + messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = InMemoryTableWithV2Filter.filtersToKeys( dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, predicates) dataMap --= deleteKeys diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d9e915f82e07f..b970d52d985ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SparkException, SparkIllegalArgumentException} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.EXPR -import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedPersistentView, ResolvedTable, ResolvedTempView} +import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedIdentifier, ResolvedNamespace, ResolvedPartitionSpec, ResolvedPersistentView, ResolvedTable, ResolvedTempView} import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} @@ -41,7 +41,7 @@ import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream, SupportsRealTimeMode} -import org.apache.spark.sql.connector.write.V1Write +import org.apache.spark.sql.connector.write.{V1Write, Write} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{FilterExec, InSubqueryExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, ScalarSubquery => ExecScalarSubquery, SparkPlan, SparkStrategy => Strategy} import org.apache.spark.sql.execution.command.CommandUtils @@ -445,8 +445,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat invalidateCache) :: Nil } - case AppendData(r @ ExtractV2Table(v1: SupportsWrite), _, _, - _, _, Some(write), analyzedQuery) if v1.supports(TableCapability.V1_BATCH_WRITE) => + case AppendWrite(r @ ExtractV2Table(v1: SupportsWrite), Some(write), analyzedQuery) + if v1.supports(TableCapability.V1_BATCH_WRITE) => write match { case v1Write: V1Write => assert(analyzedQuery.isDefined) @@ -459,6 +459,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case AppendData(r: DataSourceV2Relation, query, _, _, _, 
Some(write), _) => AppendDataExec(planLater(query), refreshCache(r), write, r.name) :: Nil + case InsertOnlyMerge(r: DataSourceV2Relation, query, Some(write), _) => + InsertOnlyMergeExec(planLater(query), refreshCache(r), write, r.name) :: Nil + case OverwriteByExpression(r @ ExtractV2Table(v1: SupportsWrite), _, _, _, _, _, Some(write), analyzedQuery) if v1.supports(TableCapability.V1_BATCH_WRITE) => write match { @@ -796,6 +799,20 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } } +/** + * Pattern that matches either an [[AppendData]] or an [[InsertOnlyMerge]] and exposes the + * fields needed to plan the v1 batch-write fallback path. + */ +private object AppendWrite { + def unapply( + plan: LogicalPlan + ): Option[(NamedRelation, Option[Write], Option[LogicalPlan])] = plan match { + case a: AppendData => Some((a.table, a.write, a.analyzedQuery)) + case m: InsertOnlyMerge => Some((m.table, m.write, m.analyzedQuery)) + case _ => None + } +} + private[sql] object DataSourceV2Strategy extends Logging { private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala index b25059bd7bac1..988aa86db1d34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertOnlyMerge, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -49,6 +49,9 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) { case AppendData(r: DataSourceV2Relation, _, _, _, _, _, _) if !supportsBatchWrite(r.table) => throw QueryCompilationErrors.unsupportedAppendInBatchModeError(r.name) + case InsertOnlyMerge(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) => + throw QueryCompilationErrors.unsupportedAppendInBatchModeError(r.name) + case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _, _, _) if !r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_DYNAMIC) => throw QueryCompilationErrors.unsupportedDynamicOverwriteInBatchModeError(r.table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala index 70a2bf6d5b8b9..0cbf260457ffb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -22,7 +22,7 @@ import java.util.UUID import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.expressions.PredicateHelper -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceData, WriteDelta} +import 
org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertOnlyMerge, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceData, WriteDelta} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.catalyst.util.WriteDeltaProjections @@ -51,6 +51,13 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper { val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog) a.copy(write = Some(write), query = newQuery) + case m @ InsertOnlyMerge(r: DataSourceV2Relation, query, None, _) => + val writeOptions = r.options.asCaseSensitiveMap.asScala.toMap + val writeBuilder = newWriteBuilder(r.table, writeOptions, query.schema) + val write = writeBuilder.build() + val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, r.funCatalog) + m.copy(write = Some(write), query = newQuery) + case o @ OverwriteByExpression( r: DataSourceV2Relation, deleteExpr, query, options, _, _, None, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index ccfcdc1855f04..f36b0d1b599da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -33,16 +33,16 @@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Column, Identifier import org.apache.spark.sql.connector.catalog.transactions.Transaction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric -import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, RowLevelOperationTable, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeleteSummaryImpl, DeltaWrite, DeltaWriter, InsertSummaryImpl, MergeSummaryImpl, PhysicalWriteInfoImpl, RowLevelOperation, RowLevelOperationTable, UpdateSummaryImpl, Write, WriterCommitMessage, WriteSummary} import org.apache.spark.sql.connector.write.RowLevelOperation.Command._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES -import org.apache.spark.util.{LongAccumulator, Utils} import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.util.Utils /** * Deprecated logical plan for writing data into data source v2. 
This is being replaced by more @@ -294,6 +294,37 @@ case class AppendDataExec( override def withTransaction(txn: Option[Transaction]): AppendDataExec = copy(transaction = txn) override protected def withNewChildInternal(newChild: SparkPlan): AppendDataExec = copy(query = newChild) + + override protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = + Some(InsertSummaryImpl(numInsertedRows = numOutputRowsMetric.value)) +} + +/** + * Physical plan for an insert-only MERGE rewrite. Behaves like [[AppendDataExec]] but emits a + * [[org.apache.spark.sql.connector.write.MergeSummary]] so commit metadata reports the operation + * as a MERGE, with all output rows accounted for as inserts. + */ +case class InsertOnlyMergeExec( + query: SparkPlan, + refreshCache: () => Unit, + write: Write, + tableName: String, + transaction: Option[Transaction] = None) extends V2ExistingTableWriteExec { + override def withTransaction(txn: Option[Transaction]): InsertOnlyMergeExec = + copy(transaction = txn) + override protected def withNewChildInternal(newChild: SparkPlan): InsertOnlyMergeExec = + copy(query = newChild) + + override protected def getWriteSummary(query: SparkPlan): Option[WriteSummary] = + Some(MergeSummaryImpl( + numTargetRowsCopied = 0L, + numTargetRowsDeleted = 0L, + numTargetRowsUpdated = 0L, + numTargetRowsInserted = numOutputRowsMetric.value, + numTargetRowsMatchedUpdated = 0L, + numTargetRowsMatchedDeleted = 0L, + numTargetRowsNotMatchedBySourceUpdated = 0L, + numTargetRowsNotMatchedBySourceDeleted = 0L)) } /** @@ -477,17 +508,18 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec with TransactionalExec { trait RowLevelWriteExec extends V2ExistingTableWriteExec { def rowLevelCommand: RowLevelOperation.Command - override protected lazy val sparkMetrics: Map[String, SQLMetric] = rowLevelCommand match { - case UPDATE => - Map( - "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), - "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) - case DELETE => - Map( - "numDeletedRows" -> SQLMetrics.createMetric(sparkContext, "number of deleted rows"), - "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) - case _ => Map.empty - } + override protected lazy val sparkMetrics: Map[String, SQLMetric] = super.sparkMetrics ++ ( + rowLevelCommand match { + case UPDATE => + Map( + "numUpdatedRows" -> SQLMetrics.createMetric(sparkContext, "number of updated rows"), + "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) + case DELETE => + Map( + "numDeletedRows" -> SQLMetrics.createMetric(sparkContext, "number of deleted rows"), + "numCopiedRows" -> SQLMetrics.createMetric(sparkContext, "number of copied rows")) + case _ => Map.empty + }) /** * Returns the value of the named metric, or -1 if the metric is not found. 
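For context on how the new summary-aware commit overload is meant to be consumed, a minimal connector-side sketch follows. `SummaryAwareBatchWrite` and `recordCommit` are hypothetical names (not part of this patch), and `MergeSummary` is assumed to expose accessors mirroring the fields of `MergeSummaryImpl`:

```scala
import org.apache.spark.sql.connector.write.{
  BatchWrite, InsertSummary, MergeSummary, WriteSummary, WriterCommitMessage}

// Hypothetical connector-side BatchWrite: delegates the actual commit to the
// summary-less overload, then records the reported row counts as commit metadata.
abstract class SummaryAwareBatchWrite extends BatchWrite {

  // Stand-in for whatever commit-metadata sink the connector uses.
  protected def recordCommit(operation: String, counts: Map[String, Long]): Unit

  override def commit(messages: Array[WriterCommitMessage], summary: WriteSummary): Unit = {
    commit(messages)
    summary match {
      case m: MergeSummary =>
        // Insert-only MERGE rewrites (InsertOnlyMergeExec above) report all
        // rows as inserts and zeros for the copy/update/delete counters.
        recordCommit("MERGE", Map(
          "inserted" -> m.numTargetRowsInserted,
          "updated" -> m.numTargetRowsUpdated,
          "deleted" -> m.numTargetRowsDeleted,
          "copied" -> m.numTargetRowsCopied))
      case i: InsertSummary =>
        // Plain appends (AppendDataExec above) report a single insert count.
        recordCommit("INSERT", Map("inserted" -> i.numInsertedRows))
      case _ =>
        recordCommit("UNKNOWN", Map.empty)
    }
  }
}
```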
@@ -542,6 +574,12 @@ trait V2TableWriteExec override def customMetrics: Map[String, SQLMetric] = Map.empty + protected lazy val numOutputRowsMetric: SQLMetric = + SQLMetrics.createMetric(sparkContext, "number of output rows") + + override protected def sparkMetrics: Map[String, SQLMetric] = Map( + "numOutputRows" -> numOutputRowsMetric) + protected def writeWithV2(batchWrite: BatchWrite): Seq[InternalRow] = { val rdd: RDD[InternalRow] = { val tempRdd = query.execute() @@ -559,7 +597,6 @@ trait V2TableWriteExec PhysicalWriteInfoImpl(rdd.getNumPartitions)) val useCommitCoordinator = batchWrite.useCommitCoordinator val messages = new Array[WriterCommitMessage](rdd.partitions.length) - val totalNumRowsAccumulator = new LongAccumulator() logInfo(log"Start processing data source write support: " + log"${MDC(LogKeys.BATCH_WRITE, batchWrite)}. The input RDD has " + @@ -577,11 +614,14 @@ trait V2TableWriteExec (index, result: DataWritingSparkTaskResult) => { val commitMessage = result.writerCommitMessage messages(index) = commitMessage - totalNumRowsAccumulator.add(result.numRows) + numOutputRowsMetric.add(result.numRows) batchWrite.onDataWriterCommit(commitMessage) } ) + val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) + SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numOutputRowsMetric)) + val writeSummary = getWriteSummary(query) logInfo(log"Data source write support ${MDC(LogKeys.BATCH_WRITE, batchWrite)} is committing.") writeSummary match { @@ -589,7 +629,7 @@ trait V2TableWriteExec case None => batchWrite.commit(messages) } logInfo(log"Data source write support ${MDC(LogKeys.BATCH_WRITE, batchWrite)} committed.") - commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value)) + commitProgress = Some(StreamWriterCommitProgress(numOutputRowsMetric.value)) } catch { case cause: Throwable => logError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 4db59b36c1fec..97cdebe2d32df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -181,7 +181,7 @@ object InMemoryTableSessionCatalog { private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2SessionCatalogBase[T]] extends SharedSparkSession - with BeforeAndAfter { + with BeforeAndAfter { self: InsertIntoSQLOnlyTests => protected def catalog(name: String): CatalogPlugin = { spark.sessionState.catalogManager.catalog(name) @@ -215,6 +215,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio val t1 = "tbl" val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.format(v2Format).saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) } @@ -222,6 +223,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio val t1 = "tbl" val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.format(v2Format).mode("append").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) } @@ -245,10 +247,12 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio df.select("id", "data").write.format(v2Format).saveAsTable(t1) } 
df.write.format(v2Format).mode("append").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) // Check that appends are by name df.select($"data", $"id").write.format(v2Format).mode("append").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df.union(df)) } @@ -284,6 +288,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio val t1 = "tbl" val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.format(v2Format).mode("ignore").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index 5cda5169369e7..c532ef359a7c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -102,7 +102,9 @@ class DataSourceV2DataFrameSuite sql(s"CREATE TABLE $t2 (id bigint, data string) USING foo") val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.insertInto(t1) + checkInsertMetrics(t1, numInsertedRows = 3) spark.table(t1).write.insertInto(t2) + checkInsertMetrics(t2, numInsertedRows = 3) checkAnswer(spark.table(t2), df) } } @@ -112,6 +114,7 @@ class DataSourceV2DataFrameSuite withTable(t1) { val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) checkAnswer(spark.table(t1), df) } } @@ -129,6 +132,7 @@ class DataSourceV2DataFrameSuite // appends are by name not by position df.select($"data", $"id").write.mode("append").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) checkAnswer(spark.table(t1), df) } } @@ -157,6 +161,7 @@ class DataSourceV2DataFrameSuite withTable(t1) { val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.mode("ignore").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) checkAnswer(spark.table(t1), df) } } @@ -190,6 +195,7 @@ class DataSourceV2DataFrameSuite val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") df.write.option("other", "20").mode("append").saveAsTable(t1) + checkInsertMetrics(t1, numInsertedRows = 3) sparkContext.listenerBus.waitUntilEmpty() plan match { @@ -391,24 +397,29 @@ class DataSourceV2DataFrameSuite val df1 = Seq((1, "hr")).toDF("id", "dep") df1.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 1) sql(s"ALTER TABLE $tableName ADD COLUMN txt STRING DEFAULT 'initial-text'") val df2 = Seq((2, "hr"), (3, "software")).toDF("id", "dep") df2.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) sql(s"ALTER TABLE $tableName ALTER COLUMN txt SET DEFAULT 'new-text'") val df3 = Seq((4, "hr"), (5, "hr")).toDF("id", "dep") df3.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) val df4 = Seq((6, "hr", null), (7, "hr", "explicit-text")).toDF("id", "dep", "txt") df4.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) sql(s"ALTER TABLE $tableName ALTER COLUMN txt DROP DEFAULT") val df5 = Seq((8, "hr"), (9, "hr")).toDF("id", "dep") df5.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) checkAnswer( sql(s"SELECT * FROM $tableName"), @@ -432,11 +443,13 @@ class DataSourceV2DataFrameSuite val df1 = Seq(1, 2).toDF("id") 
df1.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) sql(s"ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT 'it'") val df2 = Seq(3, 4).toDF("id") df2.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) checkAnswer( sql(s"SELECT * FROM $tableName"), @@ -450,6 +463,7 @@ class DataSourceV2DataFrameSuite val df3 = Seq(1, 2).toDF("id") df3.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 2) checkAnswer( sql(s"SELECT * FROM $tableName"), @@ -493,11 +507,13 @@ class DataSourceV2DataFrameSuite val df1 = Seq(1).toDF("id") df1.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 1) sql(s"ALTER TABLE $tableName ALTER COLUMN dep SET DEFAULT ('i' || 't')") val df2 = Seq(2).toDF("id") df2.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $tableName"), @@ -536,6 +552,7 @@ class DataSourceV2DataFrameSuite val df3 = Seq(1).toDF("id") df3.writeTo(tableName).append() + checkInsertMetrics(tableName, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $tableName"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index e49cb692e3b3c..cb7531a0dbafd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -439,6 +439,7 @@ class DataSourceV2SQLSuiteV1Filter Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach { case (catalog, identifier) => spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT id, data FROM source") + checkInsertMetrics(identifier, numInsertedRows = 3) val table = catalog.loadTable(Identifier.of(Array(), "table_name")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala index fa6edc96ec9fd..4f023136a6fe1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/InsertIntoTests.scala @@ -24,6 +24,9 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.connector.catalog.InMemoryBaseTable +import org.apache.spark.sql.connector.write.InsertSummary +import org.apache.spark.sql.execution.datasources.v2.ExtractV2Table import org.apache.spark.sql.functions.{array, lit, map, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode} @@ -60,6 +63,7 @@ abstract class InsertIntoTests( sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") doInsert(t1, df) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) } @@ -70,6 +74,7 @@ abstract class InsertIntoTests( val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") doInsert(t1, dfr) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) } @@ -79,6 +84,7 @@ abstract class InsertIntoTests( sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") val df = Seq((1L, "a"), (2L, "b"), (3L, 
"c")).toDF("id", "data") doInsert(t1, df) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, df) } } @@ -89,6 +95,7 @@ abstract class InsertIntoTests( val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") val df2 = Seq((4L, "d"), (5L, "e"), (6L, "f")).toDF("id", "data") doInsert(t1, df) + checkInsertMetrics(t1, numInsertedRows = 3) doInsert(t1, df2, SaveMode.Overwrite) verifyTable(t1, df2) } @@ -99,6 +106,7 @@ abstract class InsertIntoTests( sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") val init = Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data") doInsert(t1, init) + checkInsertMetrics(t1, numInsertedRows = 2) val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") doInsert(t1, df, SaveMode.Overwrite) @@ -114,6 +122,7 @@ abstract class InsertIntoTests( sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") val init = Seq((2L, "dummy"), (4L, "keep")).toDF("id", "data") doInsert(t1, init) + checkInsertMetrics(t1, numInsertedRows = 2) val dfr = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("data", "id") doInsert(t1, dfr, SaveMode.Overwrite) @@ -210,6 +219,15 @@ trait InsertIntoSQLOnlyTests /** Check that the results in `tableName` match the `expected` DataFrame. */ protected def verifyTable(tableName: String, expected: DataFrame): Unit + protected def checkInsertMetrics(tableName: String, numInsertedRows: Long): Unit = { + val inMemoryTable = spark.table(tableName).queryExecution.analyzed.collectFirst { + case ExtractV2Table(t) => t.asInstanceOf[InMemoryBaseTable] + }.get + val summary = inMemoryTable.commits.last.writeSummary.get.asInstanceOf[InsertSummary] + assert(summary.numInsertedRows() === numInsertedRows, + s"Expected numInsertedRows=$numInsertedRows, got ${summary.numInsertedRows()}") + } + protected val v2Format: String protected val catalogAndNamespace: String @@ -293,6 +311,7 @@ trait InsertIntoSQLOnlyTests withTableAndData(t1) { view => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") sql(s"INSERT INTO $t1 PARTITION (id = 23) SELECT data FROM $view") + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, sql(s"SELECT 23, data FROM $view")) } } @@ -303,6 +322,7 @@ trait InsertIntoSQLOnlyTests withTableAndData(t1) { view => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a"), @@ -317,6 +337,7 @@ trait InsertIntoSQLOnlyTests withTableAndData(t1) { view => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a"), @@ -332,6 +353,7 @@ trait InsertIntoSQLOnlyTests withTableAndData(t1) { view => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM $view") verifyTable(t1, Seq( (1, "a"), @@ -346,6 +368,7 @@ trait InsertIntoSQLOnlyTests withTableAndData(t1) { view => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") sql(s"INSERT INTO $t1 VALUES (2L, 
'dummy'), (4L, 'keep')") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM $view") verifyTable(t1, Seq( (1, "a"), @@ -361,6 +384,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p1 int) " + s"USING $v2Format PARTITIONED BY (p1)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 23), @@ -377,6 +401,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 2), @@ -393,6 +418,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 2), @@ -409,6 +435,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 2), @@ -424,6 +451,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 2), @@ -439,6 +467,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 2), @@ -454,6 +483,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM $view") verifyTable(t1, Seq( (1, "a", 2), @@ -469,6 +499,7 @@ trait InsertIntoSQLOnlyTests sql(s"CREATE TABLE $t1 (id bigint, data string, p int) " + s"USING $v2Format PARTITIONED BY (id, p)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + checkInsertMetrics(t1, numInsertedRows = 2) sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM $view") verifyTable(t1, Seq( (2, "a", 2), @@ -491,6 +522,7 @@ trait InsertIntoSQLOnlyTests df.where("true").take(5) df.where("true").tail(5) + checkInsertMetrics(t1, numInsertedRows = 3) verifyTable(t1, spark.table(view)) } } @@ -510,9 +542,11 @@ trait InsertIntoSQLOnlyTests withTable(t1) { sql(s"CREATE TABLE $t1 (c1 INT DEFAULT 42, c2 STRING DEFAULT 'hello') USING $v2Format") sql(s"INSERT INTO $t1 VALUES (1, DEFAULT)") + 
checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer(sql(s"SELECT * FROM $t1"), Row(1, "hello")) sql(s"INSERT INTO $t1 VALUES (DEFAULT, DEFAULT)") + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1 ORDER BY c1"), Seq(Row(1, "hello"), Row(42, "hello"))) @@ -565,8 +599,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " + s"TBLPROPERTIES ('auto-schema-evolution' = 'false')") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) // Same column count, no evolution needed: should succeed even without capability. doInsertWithSchemaEvolution(t1, Seq((2L, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) } } @@ -576,7 +612,9 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "b")).toDF("x", "y")) + checkInsertMetrics(t1, numInsertedRows = 1) // No evolution verifyTable(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) } @@ -587,8 +625,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "b", true)).toDF("id", "data", "active")) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq[(Long, String, java.lang.Boolean)]( (1L, "a", null), (2L, "b", true) @@ -601,8 +641,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "b", true, 100L)).toDF("id", "data", "active", "score")) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq[(Long, String, java.lang.Boolean, java.lang.Long)]( (1L, "a", null, null), (2L, "b", true, 100L) @@ -615,7 +657,9 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "b", true)).toDF("x", "y", "z")) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq[(Long, String, java.lang.Boolean)]( (1L, "a", null), (2L, "b", true) @@ -629,6 +673,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsertWithSchemaEvolution(t1, Seq((1L, "a", true)).toDF("id", "data", "active")) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq( (1L, "a", true) ).toDF("id", "data", "active")) @@ -642,9 +687,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1, "Alice")).toDF("id", "name") .select($"id", struct($"name").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2, "Bob", 30)).toDF("id", "name", "age") .select($"id", struct($"name", $"age").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM 
$t1"), Seq(Row(1, Row("Alice", null)), Row(2, Row("Bob", 30)))) @@ -658,9 +705,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1, "Alice")).toDF("id", "name") .select($"id", struct($"name").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2, "Bob", 30)).toDF("id", "firstName", "age") .select($"id", struct($"firstName", $"age").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1, Row("Alice", null)), Row(2, Row("Bob", 30)))) @@ -672,8 +721,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq(("b", true, 2L)).toDF("data", "active", "id"), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq[(Long, String, java.lang.Boolean)]( (1L, "a", null), (2L, "b", true) @@ -688,10 +739,12 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1, "Alice")).toDF("id", "name") .select($"id", struct($"name").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2, 30, "Bob")).toDF("id", "age", "name") .select($"id", struct($"age", $"name").as("info")), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1, Row("Alice", null)), Row(2, Row("Bob", 30)))) @@ -705,10 +758,12 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1, "Alice")).toDF("id", "name") .select($"id", struct($"name").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2, 30, "Bob", "NYC")).toDF("id", "age", "name", "city") .select($"id", struct($"age", $"name", $"city").as("info")), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1, Row("Alice", null, null)), Row(2, Row("Bob", 30, "NYC")))) @@ -720,8 +775,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq(("b", 2L)).toDF("data", "id"), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) // No evolution verifyTable(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) } @@ -732,8 +789,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq(("b", 2L)).toDF("x", "y"), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq[(java.lang.Long, String, String, java.lang.Long)]( (1L, "a", null, null), (null, null, "b", 2L) @@ -748,6 +807,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1, "Alice")).toDF("id", "name") .select($"id", struct($"name").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2, 30, "Bob")).toDF("id", "age", "name") .select($"id", struct($"age", $"name").as("info")), @@ -764,6 +824,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { 
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") doInsert(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 2) // REPLACE WHERE only deletes rows matching the predicate, then inserts new data. doInsertWithSchemaEvolution(t1, Seq((2L, "x", true), (4L, "y", false)).toDF("id", "data", "active"), @@ -781,6 +842,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") doInsert(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 2) doInsertWithSchemaEvolution(t1, Seq((true, "x", 2L), (false, "y", 4L)).toDF("active", "data", "id"), byName = true, @@ -801,6 +863,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => val initDf = Seq((1L, "Alice"), (2L, "Bob")).toDF("id", "name") .select($"id", struct($"name").as("info")) doInsert(t1, initDf) + checkInsertMetrics(t1, numInsertedRows = 2) doInsertWithSchemaEvolution(t1, Seq((2L, "Bobby", 25)).toDF("id", "name", "age") .select($"id", struct($"name", $"age").as("info")), @@ -820,6 +883,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => val initDf = Seq((1L, "Alice"), (2L, "Bob")).toDF("id", "name") .select($"id", struct($"name").as("info")) doInsert(t1, initDf) + checkInsertMetrics(t1, numInsertedRows = 2) doInsertWithSchemaEvolution(t1, Seq((2L, "Bobby", 25)).toDF("id", "name", "age") .select($"id", struct($"age", $"name").as("info")), @@ -853,6 +917,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") doInsert(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 2) // Overwrite with schema evolution adding a new column, dynamic mode should only replace // partitions present in the inserted data. doInsertWithSchemaEvolution(t1, @@ -874,6 +939,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") doInsert(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 2) doInsertWithSchemaEvolution(t1, Seq((true, "x", 2L), (false, "y", 3L)).toDF("active", "data", "id"), mode = SaveMode.Overwrite, @@ -894,6 +960,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format PARTITIONED BY (id)") doInsert(t1, Seq((1L, "a"), (2L, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 2) // Static mode overwrites the entire table. 
doInsertWithSchemaEvolution(t1, Seq((2L, "x", true), (3L, "y", false)).toDF("id", "data", "active"), @@ -949,8 +1016,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => sql(s"CREATE TABLE $t1 (id bigint) USING $v2Format") doInsertWithSchemaEvolution(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "b", true)).toDF("id", "data", "active")) + checkInsertMetrics(t1, numInsertedRows = 1) verifyTable(t1, Seq[(Long, String, java.lang.Boolean)]( (1L, "a", null), (2L, "b", true) @@ -965,9 +1034,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1L, "Alice")).toDF("id", "name") .select($"id", struct(struct($"name").as("nested")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "Bob", 30)).toDF("id", "name", "age") .select($"id", struct(struct($"name", $"age").as("nested")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1L, Row(Row("Alice", null))), Row(2L, Row(Row("Bob", 30))))) @@ -981,10 +1052,12 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1L, "Alice")).toDF("id", "name") .select($"id", struct(struct($"name").as("nested")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "Bob", 30)).toDF("id", "name", "age") .select($"id", struct(struct($"age", $"name").as("nested")).as("info")), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1L, Row(Row("Alice", null))), Row(2L, Row(Row("Bob", 30))))) @@ -998,9 +1071,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1L, "Alice")).toDF("id", "name") .select($"id", array(struct($"name")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "Bob", 30)).toDF("id", "name", "age") .select($"id", array(struct($"name", $"age")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq( @@ -1016,9 +1091,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1L, "A", "Alice")).toDF("id", "key", "name") .select($"id", map($"key", struct($"name")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "B", "Bob", 30)).toDF("id", "key", "name", "age") .select($"id", map($"key", struct($"name", $"age")).as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq( @@ -1034,9 +1111,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1L, "Alice", "A")).toDF("id", "name", "value") .select($"id", map(struct($"name"), $"value").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((2L, "Bob", 30, "B")).toDF("id", "name", "age", "value") .select($"id", map(struct($"name", $"age"), $"value").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq( @@ -1050,8 +1129,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id int, data string) USING $v2Format") doInsert(t1, Seq((1, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq((Long.MaxValue, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) 
      checkAnswer(
        sql(s"SELECT * FROM $t1"),
        Seq(Row(1L, "a"), Row(Long.MaxValue, "b")))
@@ -1064,8 +1145,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id int, data string) USING $v2Format")
      doInsert(t1, Seq((1, "a")).toDF("id", "data"))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
        Seq(("b", Long.MaxValue)).toDF("data", "id"),
        byName = true)
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT * FROM $t1"),
        Seq(Row(1L, "a"), Row(Long.MaxValue, "b")))
@@ -1078,8 +1161,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id int, data string) USING $v2Format")
      doInsert(t1, Seq((1, "a")).toDF("id", "data"))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
        Seq((Long.MaxValue, "b", true)).toDF("id", "data", "active"))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT * FROM $t1"),
        Seq(
@@ -1098,9 +1183,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
      doInsert(t1,
        Seq((1L, "Alice", 100)).toDF("id", "name", "value")
          .select($"id", struct($"value", $"name").as("info")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
        Seq((2L, "Bob", Long.MaxValue)).toDF("id", "name", "value")
          .select($"id", struct($"value", $"name").as("info")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT id, info.value, info.name FROM $t1"),
        Seq(Row(1L, 100L, "Alice"), Row(2L, Long.MaxValue, "Bob")))
@@ -1116,10 +1203,12 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
      doInsert(t1,
        Seq((1L, "Alice", 100)).toDF("id", "name", "value")
          .select($"id", struct($"value", $"name").as("info")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
        Seq((2L, "Bob", Long.MaxValue)).toDF("id", "name", "value")
          .select($"id", struct($"name", $"value").as("info")),
        byName = true)
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT id, info.value, info.name FROM $t1"),
        Seq(Row(1L, 100L, "Alice"), Row(2L, Long.MaxValue, "Bob")))
@@ -1135,9 +1224,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
      doInsert(t1,
        Seq((1L, 100)).toDF("id", "value")
          .select($"id", array(struct($"value")).as("arr")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
        Seq((2L, Long.MaxValue)).toDF("id", "value")
          .select($"id", array(struct($"value")).as("arr")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT id, arr[0].value FROM $t1"),
        Seq(Row(1L, 100L), Row(2L, Long.MaxValue)))
@@ -1154,9 +1245,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
      doInsert(t1,
        Seq((1L, "k1", 100)).toDF("id", "key", "value")
          .select($"id", map($"key", struct($"value")).as("m")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
        Seq((2L, "k2", Long.MaxValue)).toDF("id", "key", "value")
          .select($"id", map($"key", struct($"value")).as("m")))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT id, m['k1'].value, m['k2'].value FROM $t1"),
        Seq(Row(1L, 100L, null), Row(2L, null, Long.MaxValue)))
@@ -1171,6 +1264,7 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id int, data string) USING $v2Format")
      doInsert(t1, Seq((1, "a")).toDF("id", "data"))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      doInsertWithSchemaEvolution(t1,
"b")).toDF("id", "data"), mode = SaveMode.Overwrite) checkAnswer( @@ -1185,8 +1279,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) // Inserting an int into a long column should not narrow the schema. doInsertWithSchemaEvolution(t1, Seq((2, "b")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1L, "a"), Row(2L, "b"))) @@ -1200,9 +1296,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id float, data string) USING $v2Format") doInsert(t1, Seq((1f, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) // Inserting a double into a float should widen the schema, inserting an int into a string // should retain the string type. doInsertWithSchemaEvolution(t1, Seq((2d, 3)).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1d, "a"), Row(2d, "3"))) @@ -1239,9 +1337,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) // Insert a null value with NullType - should not change the target column type. doInsertWithSchemaEvolution(t1, Seq(2L).toDF("id").withColumn("data", lit(null))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1L, "a"), Row(2L, null))) @@ -1254,9 +1354,11 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format") doInsert(t1, Seq((1L, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) doInsertWithSchemaEvolution(t1, Seq(2L).toDF("id").withColumn("data", lit(null)), byName = true) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT * FROM $t1"), Seq(Row(1L, "a"), Row(2L, null))) @@ -1271,10 +1373,12 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => doInsert(t1, Seq((1L, "Alice", 100)).toDF("id", "name", "value") .select($"id", struct($"value", $"name").as("info"))) + checkInsertMetrics(t1, numInsertedRows = 1) // Insert with NullType for nested field - should not change the struct field type. doInsertWithSchemaEvolution(t1, Seq(2L).toDF("id") .withColumn("info", struct(lit(null).as("value"), lit("Bob").as("name")))) + checkInsertMetrics(t1, numInsertedRows = 1) checkAnswer( sql(s"SELECT id, info.value, info.name FROM $t1"), Seq(Row(1L, 100, "Alice"), Row(2L, null, "Bob"))) @@ -1288,8 +1392,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests => withTable(t1) { sql(s"CREATE TABLE $t1 (id int, data string) USING $v2Format") doInsert(t1, Seq((1, "a")).toDF("id", "data")) + checkInsertMetrics(t1, numInsertedRows = 1) // Insert without schema evolution - should cast to target type, not widen. 
@@ -1288,8 +1392,10 @@ trait InsertIntoSchemaEvolutionTests { this: InsertIntoTests =>
    withTable(t1) {
      sql(s"CREATE TABLE $t1 (id int, data string) USING $v2Format")
      doInsert(t1, Seq((1, "a")).toDF("id", "data"))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      // Insert without schema evolution - should cast to target type, not widen.
      doInsert(t1, Seq((2L, "b")).toDF("id", "data"))
+      checkInsertMetrics(t1, numInsertedRows = 1)
      checkAnswer(
        sql(s"SELECT * FROM $t1"),
        Seq(Row(1, "a"), Row(2, "b")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
index aaf45f0f5f7a5..b902074b547cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala
@@ -484,6 +484,16 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
          Row(1, 100, "hr"), // insert
          Row(2, 200, "finance"), // insert
          Row(3, 300, "hr"))) // insert
+
+      val mergeSummary = getMergeSummary()
+      assert(mergeSummary.numTargetRowsInserted === 3L)
+      assert(mergeSummary.numTargetRowsCopied === 0L)
+      assert(mergeSummary.numTargetRowsUpdated === 0L)
+      assert(mergeSummary.numTargetRowsDeleted === 0L)
+      assert(mergeSummary.numTargetRowsMatchedUpdated === 0L)
+      assert(mergeSummary.numTargetRowsMatchedDeleted === 0L)
+      assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L)
+      assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L)
    }
  }
@@ -510,6 +520,16 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
        Seq(
          Row(2, 200, "finance"), // insert
          Row(3, 300, "hr"))) // insert
+
+      val mergeSummary = getMergeSummary()
+      assert(mergeSummary.numTargetRowsInserted === 2L)
+      assert(mergeSummary.numTargetRowsCopied === 0L)
+      assert(mergeSummary.numTargetRowsUpdated === 0L)
+      assert(mergeSummary.numTargetRowsDeleted === 0L)
+      assert(mergeSummary.numTargetRowsMatchedUpdated === 0L)
+      assert(mergeSummary.numTargetRowsMatchedDeleted === 0L)
+      assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L)
+      assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L)
    }
  }
@@ -539,6 +559,16 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
          Row(1, 100, "hr"), // insert
          Row(2, 200, "finance"), // insert
          Row(3, 300, "hr"))) // insert
+
+      val mergeSummary = getMergeSummary()
+      assert(mergeSummary.numTargetRowsInserted === 3L)
+      assert(mergeSummary.numTargetRowsCopied === 0L)
+      assert(mergeSummary.numTargetRowsUpdated === 0L)
+      assert(mergeSummary.numTargetRowsDeleted === 0L)
+      assert(mergeSummary.numTargetRowsMatchedUpdated === 0L)
+      assert(mergeSummary.numTargetRowsMatchedDeleted === 0L)
+      assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L)
+      assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L)
    }
  }
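All three insert-only MERGE suites above assert the same pattern: only numTargetRowsInserted is non-zero, and the matched / not-matched-by-source split counters stay at zero. Implicit in these assertions is a rollup relationship between the split counters and the totals; a hedged helper making that explicit (the MergeSummary type name is inferred from getMergeSummary() and may differ in the actual patch):

    // Assumed rollup invariant: the total updated/deleted counts are the sum
    // of their matched and not-matched-by-source parts. `MergeSummary` is an
    // assumed type name, not confirmed by this excerpt.
    def assertMergeCounterRollup(s: MergeSummary): Unit = {
      assert(s.numTargetRowsUpdated ==
        s.numTargetRowsMatchedUpdated + s.numTargetRowsNotMatchedBySourceUpdated)
      assert(s.numTargetRowsDeleted ==
        s.numTargetRowsMatchedDeleted + s.numTargetRowsNotMatchedBySourceDeleted)
    }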
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 721b86593bacb..3e48c5222e6f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -253,6 +253,9 @@ class V1WriteFallbackSessionCatalogSuite
  extends InsertIntoTests(supportsDynamicOverwrite = false, includeSQLOnlyTests = true)
  with SessionCatalogTest[InMemoryTableWithV1Fallback, V1FallbackTableCatalog] {

+  // V1 fallback writes do not flow through V2TableWriteExec, so no InsertSummary is emitted.
+  override protected def checkInsertMetrics(tableName: String, numInsertedRows: Long): Unit = ()
+
  override protected val v2Format = classOf[InMemoryV1Provider].getName
  override protected val catalogClassName: String = classOf[V1FallbackTableCatalog].getName
  override protected val catalogAndNamespace: String = ""
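For context on where the insert counts originate: a connector would implement the new InsertSummary interface and aggregate the count from its writers' commit messages. A minimal sketch under the assumptions that WriteSummary adds no further abstract members and that handing the summary back to Spark is separate connector wiring not shown here:

    import org.apache.spark.sql.connector.write.{InsertSummary, WriterCommitMessage}

    // A commit message carrying a per-writer row count (illustrative only).
    case class RowCountingCommit(rows: Long) extends WriterCommitMessage

    // Summary aggregated across all writers; -1 would signal an unknown count.
    case class BasicInsertSummary(numRows: Long) extends InsertSummary {
      override def numInsertedRows(): Long = numRows
    }

    def summarize(messages: Array[WriterCommitMessage]): InsertSummary =
      BasicInsertSummary(messages.collect { case RowCountingCommit(n) => n }.sum)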