diff --git a/src/datacustomcode/__init__.py b/src/datacustomcode/__init__.py
index 00cfae3..85cfa54 100644
--- a/src/datacustomcode/__init__.py
+++ b/src/datacustomcode/__init__.py
@@ -13,11 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from datacustomcode.client import Client
-from datacustomcode.credentials import AuthType, Credentials
-from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
-from datacustomcode.io.writer.print import PrintDataCloudWriter
-
 __all__ = [
     "AuthType",
     "Client",
@@ -25,3 +20,28 @@
     "PrintDataCloudWriter",
     "QueryAPIDataCloudReader",
 ]
+
+
+def __getattr__(name: str):
+    """Lazy import heavy dependencies."""
+    if name == "Client":
+        from datacustomcode.client import Client
+
+        return Client
+    elif name == "AuthType":
+        from datacustomcode.credentials import AuthType
+
+        return AuthType
+    elif name == "Credentials":
+        from datacustomcode.credentials import Credentials
+
+        return Credentials
+    elif name == "PrintDataCloudWriter":
+        from datacustomcode.io.writer.print import PrintDataCloudWriter
+
+        return PrintDataCloudWriter
+    elif name == "QueryAPIDataCloudReader":
+        from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
+
+        return QueryAPIDataCloudReader
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
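The lazy exports above rely on PEP 562's module-level `__getattr__`: importing `datacustomcode` no longer pulls in the pyspark- and pandas-backed modules, and each public symbol is resolved on first attribute access. A minimal sketch of the observable behavior, assuming a fresh interpreter where nothing else has imported the submodules yet:

    import sys

    import datacustomcode  # cheap: heavy submodules are not imported here

    assert "datacustomcode.client" not in sys.modules
    client_cls = datacustomcode.Client  # module __getattr__ runs and imports the submodule
    assert "datacustomcode.client" in sys.modules
    assert client_cls is datacustomcode.Client  # repeat access hits the sys.modules cache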
diff --git a/src/datacustomcode/client.py b/src/datacustomcode/client.py
index faecf0a..9ad95be 100644
--- a/src/datacustomcode/client.py
+++ b/src/datacustomcode/client.py
@@ -112,8 +112,8 @@ class Client:
     def __new__(
         cls,
         reader: Optional[BaseDataCloudReader] = None,
-        writer: Optional["BaseDataCloudWriter"] = None,
-        spark_provider: Optional["BaseSparkSessionProvider"] = None,
+        writer: Optional[BaseDataCloudWriter] = None,
+        spark_provider: Optional[BaseSparkSessionProvider] = None,
         code_type: str = "script",
     ) -> Client:
 
diff --git a/src/datacustomcode/function/feature_types/chunking.py b/src/datacustomcode/function/feature_types/chunking.py
index 1425921..26aedb5 100644
--- a/src/datacustomcode/function/feature_types/chunking.py
+++ b/src/datacustomcode/function/feature_types/chunking.py
@@ -50,16 +50,16 @@ class ChunkType(str, Enum):
 class SearchIndexChunkingV1PrependField(BaseModel):
     """Field to prepend to chunk content"""
 
-    dmo_name: str = Field(
-        default="", description="Data Model Object name", examples=["udmo_1__dlm"]
+    dmo_name: Optional[str] = Field(
+        default=None, description="Data Model Object name", examples=["udmo_1__dlm"]
     )
-    field_name: str = Field(
-        default="",
+    field_name: Optional[str] = Field(
+        default=None,
         description="Field name to prepend",
         examples=["ResolvedFilePath__c"],
     )
-    value: str = Field(
-        default="",
+    value: Optional[str] = Field(
+        default=None,
         description="Field value to prepend",
         examples=["udlo_1__dll:quarterly_report.pdf"],
     )
@@ -67,20 +67,20 @@
 
 
 class SearchIndexChunkingV1TranscriptField(BaseModel):
-    """Field to prepend to chunk content"""
+    """Transcript timing and speaker metadata for audio/video documents"""
 
-    speaker: str = Field(
-        default="",
+    speaker: Optional[str] = Field(
+        default=None,
         description="Speaker name for audio/video transcripts",
         examples=["Agent"],
     )
-    start_timestamp: str = Field(
-        default="",
+    start_timestamp: Optional[str] = Field(
+        default=None,
         description="Start timestamp in ISO8601 format: YYYY-MM-DDTHH:MM:SS.ffffff",
         examples=["2026-03-25T02:01:24.918000"],
     )
-    end_timestamp: str = Field(
-        default="",
+    end_timestamp: Optional[str] = Field(
+        default=None,
         description="End timestamp in ISO8601 format: YYYY-MM-DDTHH:MM:SS.ffffff",
         examples=["2026-03-25T02:01:30.500000"],
     )
@@ -88,44 +88,76 @@ class SearchIndexChunkingV1TranscriptField(BaseModel):
 
 
 class SearchIndexChunkingV1Metadata(BaseModel):
-    """Metadata for input documents"""
+    """Metadata for input documents."""
 
-    type: DocumentType = Field(
-        default=DocumentType.TEXT, description="Document type (text)", examples=["text"]
-    )
-    transcript_fields: SearchIndexChunkingV1TranscriptField = Field(
-        default_factory=SearchIndexChunkingV1TranscriptField,
+    type: Optional[DocumentType] = Field(
+        default=DocumentType.TEXT,
         description=(
-            "Transcript information. Will only be there in case of audio-video files"
+            "Document type of the chunk input. Currently only 'text' is supported."
         ),
+        examples=["text"],
     )
-    page_number: int = Field(
-        default=0,
-        description="Page number in the source document (0-based)",
+    page_number: Optional[int] = Field(
+        default=None,
+        description="Page number in the source document (0-based).",
         examples=[1],
     )
+    transcript_fields: Optional[SearchIndexChunkingV1TranscriptField] = Field(
+        default=None,
+        description=(
+            "Speaker and timestamp metadata for audio/video transcripts. "
+            "Optional — only present when the source document is a transcript."
+        ),
+    )
     text_as_html: Optional[str] = Field(
         default=None,
-        description="HTML representation of the document text",
+        description="HTML representation of the chunk text, if available.",
         examples=["<table><tr><td>Online Remittance Instructions</td></tr></table>"],
     )
-    source_dmo_fields: Dict[str, Union[str, int]] = Field(
-        default_factory=dict,
+    source_dmo_fields: Optional[Dict[str, Union[str, int, float]]] = Field(
+        default=None,
         description=(
-            "Source Data Model Object fields as key-value pairs "
-            "(values can be string or int)"
+            "Source Data Model Object fields as key-value pairs. "
+            "Values can be string, int, or float."
        ),
         examples=[
             {
                 "FilePath__c": "quarterly_report.pdf",
-                "Size__c": 1377454,
+                "Size__c": 1377454.0,
                 "ContentType__c": "pdf",
                 "LastModified__c": "2026-03-25T02:01:24.918000",
             }
         ],
     )
-    prepend: List[SearchIndexChunkingV1PrependField] = Field(
-        default_factory=list, description="List of fields to prepend to each chunk"
+    prepend: Optional[List[SearchIndexChunkingV1PrependField]] = Field(
+        default=None,
+        description=(
+            "List of DMO fields whose values are prepended to the chunk "
+            "text before indexing"
+        ),
+    )
+    image_base64: Optional[str] = Field(
+        default=None,
+        description=(
+            "Base64-encoded image data associated with this chunk. "
+            "Optional — only applicable for image-type document elements."
+        ),
+    )
+    image_mime_type: Optional[str] = Field(
+        default=None,
+        description=(
+            "MIME type of the associated image (e.g., 'image/png', 'image/jpeg'). "
+            "Optional — should be provided alongside image_base64 when present."
+        ),
+        examples=["image/png", "image/jpeg"],
+    )
+    image_type: Optional[str] = Field(
+        default=None,
+        description=(
+            "Semantic category of the image content "
+            "(e.g., 'diagram', 'screenshot', 'chart'). Optional."
+        ),
+        examples=["diagram", "screenshot"],
     )
 
     model_config = ConfigDict(extra="ignore")
@@ -143,9 +175,12 @@ class SearchIndexChunkingV1DocElement(BaseModel):
             )
         ],
     )
-    metadata: SearchIndexChunkingV1Metadata = Field(
-        default_factory=SearchIndexChunkingV1Metadata,
-        description="Source document metadata",
+    metadata: Optional[SearchIndexChunkingV1Metadata] = Field(
+        default=None,
+        description=(
+            "Source document metadata. Optional — may be absent if no "
+            "metadata is available for the document element."
+        ),
     )
 
     model_config = ConfigDict(extra="ignore")
@@ -159,21 +194,25 @@
         examples=["Online Remittance Instructions"],
     )
     seq_no: int = Field(
-        default=0, description="Sequential chunk number (1-based)", ge=1, examples=[1]
-    )
-    chunk_id: str = Field(
-        default="",
-        description="Unique identifier for this chunk (UUID format)",
-        examples=["550e8400-e29b-41d4-a716-446655440000"],
+        default=0,
+        description=(
+            "Sequential order of this chunk within the output. "
+            "Represents chunk ordering within the source document (1-based)."
+        ),
+        ge=1,
+        examples=[1],
     )
     chunk_type: ChunkType = Field(
         default=ChunkType.TEXT,
-        description="Type of chunk (e.g., 'text')",
+        description="Type of chunk. Fixed value — always 'text'.",
         examples=["text"],
     )
-    citations: Dict[str, str] = Field(
-        default_factory=dict,
-        description="Citation information as key-value pairs",
+    citations: Optional[Dict[str, str]] = Field(
+        default=None,
+        description=(
+            "Citation metadata associated with this chunk as key-value "
+            "pairs. Optional — defaults to None if no citations are present."
+        ),
         examples=[{"source": "quarterly_report.pdf"}],
     )
 
     model_config = ConfigDict(extra="ignore")
@@ -194,4 +233,3 @@ class SearchIndexChunkingV1Response(BaseModel):
     output: List[SearchIndexChunkingV1Output] = Field(
         default_factory=list, description="Flat list of chunks from all docs"
     )
-    model_config = ConfigDict(extra="ignore")
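Since every metadata field is now optional, callers can build sparse payloads, and `chunk_id` is dropped entirely. A quick sketch with hypothetical values (not taken from the PR):

    from datacustomcode.function.feature_types.chunking import (
        SearchIndexChunkingV1Metadata,
        SearchIndexChunkingV1Output,
    )

    # Only two metadata fields set; the rest stay None, and unknown keys
    # are silently dropped by ConfigDict(extra="ignore").
    meta = SearchIndexChunkingV1Metadata(
        page_number=1,
        source_dmo_fields={"FilePath__c": "quarterly_report.pdf", "Size__c": 1377454.0},
        unknown_key="ignored",
    )
    assert meta.transcript_fields is None and meta.image_base64 is None

    # chunk_id is gone; citations may be omitted and default to None.
    chunk = SearchIndexChunkingV1Output(text="Online Remittance Instructions", seq_no=1)
    assert chunk.citations is None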
"], ) - source_dmo_fields: Dict[str, Union[str, int]] = Field( - default_factory=dict, + source_dmo_fields: Optional[Dict[str, Union[str, int, float]]] = Field( + default=None, description=( - "Source Data Model Object fields as key-value pairs " - "(values can be string or int)" + "Source Data Model Object fields as key-value pairs. " + "Values can be string, int, or float." ), examples=[ { "FilePath__c": "quarterly_report.pdf", - "Size__c": 1377454, + "Size__c": 1377454.0, "ContentType__c": "pdf", "LastModified__c": "2026-03-25T02:01:24.918000", } ], ) - prepend: List[SearchIndexChunkingV1PrependField] = Field( - default_factory=list, description="List of fields to prepend to each chunk" + prepend: Optional[List[SearchIndexChunkingV1PrependField]] = Field( + default=None, + description=( + "List of DMO fields whose values are prepended to the chunk " + "text before indexing" + ), + ) + image_base64: Optional[str] = Field( + default=None, + description=( + "Base64-encoded image data associated with this chunk. " + "Optional — only applicable for image-type document elements." + ), + ) + image_mime_type: Optional[str] = Field( + default=None, + description=( + "MIME type of the associated image (e.g., 'image/png', 'image/jpeg'). " + "Optional — should be provided alongside image_base64 when present." + ), + examples=["image/png", "image/jpeg"], + ) + image_type: Optional[str] = Field( + default=None, + description=( + "Semantic category of the image content" + "(e.g., 'diagram', 'screenshot', 'chart'). Optional." + ), + examples=["diagram", "screenshot"], ) model_config = ConfigDict(extra="ignore") @@ -143,9 +175,12 @@ class SearchIndexChunkingV1DocElement(BaseModel): ) ], ) - metadata: SearchIndexChunkingV1Metadata = Field( - default_factory=SearchIndexChunkingV1Metadata, - description="Source document metadata", + metadata: Optional[SearchIndexChunkingV1Metadata] = Field( + default=None, + description=( + "Source document metadata. Optional — may be absent if no " + "metadata is available for the document element." + ), ) model_config = ConfigDict(extra="ignore") @@ -159,21 +194,25 @@ class SearchIndexChunkingV1Output(BaseModel): examples=["Online Remittance Instructions"], ) seq_no: int = Field( - default=0, description="Sequential chunk number (1-based)", ge=1, examples=[1] - ) - chunk_id: str = Field( - default="", - description="Unique identifier for this chunk (UUID format)", - examples=["550e8400-e29b-41d4-a716-446655440000"], + default=0, + description=( + "Sequential order of this chunk within the output " + "Represents chunk ordering within the source document (1-based)." + ), + ge=1, + examples=[1], ) chunk_type: ChunkType = Field( default=ChunkType.TEXT, - description="Type of chunk (e.g., 'text')", + description="Type of chunk. Fixed value — always 'text'.", examples=["text"], ) - citations: Dict[str, str] = Field( - default_factory=dict, - description="Citation information as key-value pairs", + citations: Optional[Dict[str, str]] = Field( + default=None, + description=( + "Citation metadata associated with this chunk as key-value " + "pairs. Optional — defaults to None if no citations are present." 
diff --git a/src/datacustomcode/io/reader/sf_cli.py b/src/datacustomcode/io/reader/sf_cli.py
index cfeb06e..37284e4 100644
--- a/src/datacustomcode/io/reader/sf_cli.py
+++ b/src/datacustomcode/io/reader/sf_cli.py
@@ -23,7 +23,6 @@
     Union,
 )
 
-import pandas as pd
 import requests
 
 from datacustomcode.io.reader.base import BaseDataCloudReader
@@ -31,6 +30,7 @@
 from datacustomcode.token_provider import SFCLITokenProvider
 
 if TYPE_CHECKING:
+    import pandas as pd
     from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession
     from pyspark.sql.types import AtomicType, StructType
 
@@ -97,6 +97,8 @@ def _execute_query(self, sql: str) -> pd.DataFrame:
         Raises:
             RuntimeError: On HTTP errors or unexpected response shapes.
         """
+        import pandas as pd
+
         access_token, instance_url = self._get_token()
 
         url = f"{instance_url}/services/data/{API_VERSION}/ssot/query-sql"
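With `pandas` imported only under `TYPE_CHECKING` (plus the function-local import), importing the reader module stays cheap. The unchanged `-> pd.DataFrame` annotation on `_execute_query` then only resolves for type checkers, which assumes the module defers annotation evaluation (or quotes the annotation). A self-contained sketch of the pattern, with a hypothetical `execute_query` standing in for the method:

    from __future__ import annotations  # annotations stay strings; pd not needed at runtime

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import pandas as pd  # seen by type checkers; never executed by the interpreter


    def execute_query(sql: str) -> pd.DataFrame:
        import pandas as pd  # real import deferred until the first query runs

        return pd.DataFrame({"sql": [sql]})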
""" + import pandas as pd + access_token, instance_url = self._get_token() url = f"{instance_url}/services/data/{API_VERSION}/ssot/query-sql" diff --git a/src/datacustomcode/io/reader/utils.py b/src/datacustomcode/io/reader/utils.py index b8e65e6..681cb84 100644 --- a/src/datacustomcode/io/reader/utils.py +++ b/src/datacustomcode/io/reader/utils.py @@ -16,32 +16,32 @@ from typing import TYPE_CHECKING -import pandas.api.types as pd_types -from pyspark.sql.types import ( - BooleanType, - DoubleType, - LongType, - StringType, - StructField, - StructType, - TimestampType, -) - if TYPE_CHECKING: import pandas - from pyspark.sql.types import AtomicType - -PANDAS_TYPE_MAPPING = { - "object": StringType(), - "int64": LongType(), - "float64": DoubleType(), - "bool": BooleanType(), -} + from pyspark.sql.types import AtomicType, StructType def _pandas_to_spark_schema( pandas_df: pandas.DataFrame, nullable: bool = True ) -> StructType: + import pandas.api.types as pd_types + from pyspark.sql.types import ( + BooleanType, + DoubleType, + LongType, + StringType, + StructField, + StructType, + TimestampType, + ) + + PANDAS_TYPE_MAPPING = { + "object": StringType(), + "int64": LongType(), + "float64": DoubleType(), + "bool": BooleanType(), + } + fields = [] for column, dtype in pandas_df.dtypes.items(): spark_type: AtomicType diff --git a/src/datacustomcode/io/writer/csv.py b/src/datacustomcode/io/writer/csv.py index 53a9ecc..3d037d9 100644 --- a/src/datacustomcode/io/writer/csv.py +++ b/src/datacustomcode/io/writer/csv.py @@ -13,8 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations -from pyspark.sql import DataFrame as PySparkDataFrame +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pyspark.sql import DataFrame as PySparkDataFrame from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode diff --git a/src/datacustomcode/io/writer/print.py b/src/datacustomcode/io/writer/print.py index 2fcec24..c4d2a75 100644 --- a/src/datacustomcode/io/writer/print.py +++ b/src/datacustomcode/io/writer/print.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations -from typing import Optional +from typing import TYPE_CHECKING, Optional -from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession +if TYPE_CHECKING: + from pyspark.sql import DataFrame as PySparkDataFrame, SparkSession + + from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader -from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode @@ -61,6 +64,8 @@ def __init__( sf_cli_org: Optional SF CLI org alias or username. If provided, credentials are fetched via `sf org display`. """ + from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader + super().__init__(spark) if reader is None: self.reader = QueryAPIDataCloudReader( diff --git a/src/datacustomcode/mixin.py b/src/datacustomcode/mixin.py index f2872c6..3751351 100644 --- a/src/datacustomcode/mixin.py +++ b/src/datacustomcode/mixin.py @@ -72,6 +72,35 @@ def subclass_from_config_name(cls: type[_V], config_name: str) -> type[_V]: Args: config_name: should match a subclass's ``CONFIG_NAME``. 
""" + # First, check if already registered (from __init_subclass__) + if config_name in UserExtendableNamedConfigMixin._registered_config_names: + candidate = UserExtendableNamedConfigMixin._registered_config_names[ + config_name + ] + # Verify it's actually a subclass of cls (respects hierarchy) + if candidate is cls or issubclass(candidate, cls): + return candidate + + # If not found, try to trigger lazy import via __getattr__ + # This handles the case where subclasses use lazy loading + try: + import datacustomcode + + # Attempt to trigger __getattr__ by accessing the name + getattr(datacustomcode, config_name, None) + except (ImportError, AttributeError): + pass + + # Check again after potential lazy import + if config_name in UserExtendableNamedConfigMixin._registered_config_names: + candidate = UserExtendableNamedConfigMixin._registered_config_names[ + config_name + ] + # Verify it's actually a subclass of cls (respects hierarchy) + if candidate is cls or issubclass(candidate, cls): + return candidate + + # Fallback to dynamic lookup (for user-added subclasses) subclass_config_name_map = {} for type_ in _get_all_subclass_descendants(cls): if name := getattr(type_, "CONFIG_NAME", ""): diff --git a/src/datacustomcode/spark/base.py b/src/datacustomcode/spark/base.py index fe7bf92..cb684f1 100644 --- a/src/datacustomcode/spark/base.py +++ b/src/datacustomcode/spark/base.py @@ -25,5 +25,5 @@ class BaseSparkSessionProvider(UserExtendableNamedConfigMixin): - def get_session(self, spark_config: SparkConfig) -> "SparkSession": + def get_session(self, spark_config: SparkConfig) -> SparkSession: raise NotImplementedError diff --git a/src/datacustomcode/spark/default.py b/src/datacustomcode/spark/default.py index d020dd1..418751c 100644 --- a/src/datacustomcode/spark/default.py +++ b/src/datacustomcode/spark/default.py @@ -27,7 +27,7 @@ class DefaultSparkSessionProvider(BaseSparkSessionProvider): CONFIG_NAME = "DefaultSparkSessionProvider" - def get_session(self, spark_config: SparkConfig) -> "SparkSession": + def get_session(self, spark_config: SparkConfig) -> SparkSession: from pyspark.sql import SparkSession builder = SparkSession.builder diff --git a/src/datacustomcode/templates/function/chunking/payload/entrypoint.py b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py index dd199a7..8200e0f 100644 --- a/src/datacustomcode/templates/function/chunking/payload/entrypoint.py +++ b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py @@ -1,5 +1,4 @@ import logging -import uuid from datacustomcode.function import Runtime from datacustomcode.function.feature_types.chunking import ( @@ -124,12 +123,11 @@ def function( for chunk_text in text_chunks: # Create citations from source_dmo_fields if available citations = {} - if metadata.source_dmo_fields: + if metadata and metadata.source_dmo_fields: for key, value in metadata.source_dmo_fields.items(): citations[key] = str(value) chunk_output = SearchIndexChunkingV1Output( - chunk_id=str(uuid.uuid4()), chunk_type=ChunkType.TEXT, text=chunk_text.strip(), seq_no=seq_no, diff --git a/tests/test_function_utils.py b/tests/test_function_utils.py index cc0f51d..64a7b99 100644 --- a/tests/test_function_utils.py +++ b/tests/test_function_utils.py @@ -193,7 +193,10 @@ def function(request: SimpleRequest): assert "message" in data assert data["count"] == 5 assert data["version"] == "v1" - assert data["tags"] == [] + # tags now gets sample data generated (not empty list) + assert "tags" in data + assert isinstance(data["tags"], 
diff --git a/src/datacustomcode/spark/base.py b/src/datacustomcode/spark/base.py
index fe7bf92..cb684f1 100644
--- a/src/datacustomcode/spark/base.py
+++ b/src/datacustomcode/spark/base.py
@@ -25,5 +25,5 @@
 
 
 class BaseSparkSessionProvider(UserExtendableNamedConfigMixin):
-    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
+    def get_session(self, spark_config: SparkConfig) -> SparkSession:
         raise NotImplementedError
diff --git a/src/datacustomcode/spark/default.py b/src/datacustomcode/spark/default.py
index d020dd1..418751c 100644
--- a/src/datacustomcode/spark/default.py
+++ b/src/datacustomcode/spark/default.py
@@ -27,7 +27,7 @@ class DefaultSparkSessionProvider(BaseSparkSessionProvider):
 
     CONFIG_NAME = "DefaultSparkSessionProvider"
 
-    def get_session(self, spark_config: SparkConfig) -> "SparkSession":
+    def get_session(self, spark_config: SparkConfig) -> SparkSession:
         from pyspark.sql import SparkSession
 
         builder = SparkSession.builder
diff --git a/src/datacustomcode/templates/function/chunking/payload/entrypoint.py b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py
index dd199a7..8200e0f 100644
--- a/src/datacustomcode/templates/function/chunking/payload/entrypoint.py
+++ b/src/datacustomcode/templates/function/chunking/payload/entrypoint.py
@@ -1,5 +1,4 @@
 import logging
-import uuid
 
 from datacustomcode.function import Runtime
 from datacustomcode.function.feature_types.chunking import (
@@ -124,12 +123,11 @@ def function(
     for chunk_text in text_chunks:
         # Create citations from source_dmo_fields if available
         citations = {}
-        if metadata.source_dmo_fields:
+        if metadata and metadata.source_dmo_fields:
            for key, value in metadata.source_dmo_fields.items():
                citations[key] = str(value)
 
         chunk_output = SearchIndexChunkingV1Output(
-            chunk_id=str(uuid.uuid4()),
             chunk_type=ChunkType.TEXT,
             text=chunk_text.strip(),
             seq_no=seq_no,
diff --git a/tests/test_function_utils.py b/tests/test_function_utils.py
index cc0f51d..64a7b99 100644
--- a/tests/test_function_utils.py
+++ b/tests/test_function_utils.py
@@ -193,7 +193,10 @@ def function(request: SimpleRequest):
         assert "message" in data
         assert data["count"] == 5
         assert data["version"] == "v1"
-        assert data["tags"] == []
+        # tags now gets sample data generated (not empty list)
+        assert "tags" in data
+        assert isinstance(data["tags"], list)
+        assert len(data["tags"]) > 0
 
         # Test 2: Complex request type with nested models
         entrypoint_complex = os.path.join(temp_dir, "entrypoint_complex.py")
@@ -225,8 +228,9 @@ def function(request: ComplexRequest):
         assert "port" in complex_data["config"]
         assert complex_data["config"]["port"] == 8080
         assert complex_data["config"]["enabled"] is True
+        # metadata now gets sample data generated (not empty dict)
         assert "metadata" in complex_data
-        assert complex_data["metadata"] == {}
+        assert isinstance(complex_data["metadata"], dict)
 
     finally:
         if temp_dir in sys.path: