diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index ae18741b8..8aac72144 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -1,8 +1,12 @@ """OracleDB database configuration with direct field-based configuration.""" -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from collections.abc import Awaitable, Callable +from inspect import isawaitable +from ssl import TLSVersion +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast import oracledb +from oracledb import AuthMode, PoolGetMode, Purity from typing_extensions import NotRequired from sqlspec.adapters.oracledb._json_handlers import register_json_handlers # pyright: ignore[reportPrivateUsage] @@ -32,11 +36,8 @@ from sqlspec.utils.config_tools import normalize_connection_config if TYPE_CHECKING: - from collections.abc import Awaitable, Callable from types import TracebackType - from oracledb import AuthMode - from sqlspec.core import StatementConfig @@ -49,44 +50,96 @@ ) +OracleAccessToken = str | tuple[str, str] | Callable[..., str | tuple[str, str]] +OracleAppContext = tuple[str, str, str] +OracleProtocol = Literal["tcp", "tcps"] +OracleServerType = Literal["dedicated", "shared", "pooled"] +OraclePoolBoundary = Literal["statement", "transaction"] +OracleVectorReturnFormat = Literal["array", "list", "numpy"] +OracleEventsBackend = Literal["advanced_queue", "table_queue"] + + class OracleConnectionParams(TypedDict): """OracleDB connection parameters.""" dsn: NotRequired[str] + pool_alias: NotRequired[str] user: NotRequired[str] + proxy_user: NotRequired[str] password: NotRequired[str] + newpassword: NotRequired[str] + wallet_password: NotRequired[str] + access_token: NotRequired[OracleAccessToken] host: NotRequired[str] port: NotRequired[int] + protocol: NotRequired[OracleProtocol] + https_proxy: NotRequired[str] + https_proxy_port: NotRequired[int] service_name: NotRequired[str] + instance_name: NotRequired[str] sid: NotRequired[str] + server_type: NotRequired[OracleServerType] + cclass: NotRequired[str] + purity: NotRequired[Purity] + expire_time: NotRequired[int] + externalauth: NotRequired[bool] + mode: NotRequired[AuthMode] wallet_location: NotRequired[str] - wallet_password: NotRequired[str] config_dir: NotRequired[str] tcp_connect_timeout: NotRequired[float] retry_count: NotRequired[int] retry_delay: NotRequired[int] - mode: NotRequired["AuthMode"] + ssl_server_dn_match: NotRequired[bool] + ssl_server_cert_dn: NotRequired[str] events: NotRequired[bool] + disable_oob: NotRequired[bool] + stmtcachesize: NotRequired[int] edition: NotRequired[str] + tag: NotRequired[str] + matchanytag: NotRequired[bool] + appcontext: NotRequired[list[OracleAppContext]] + shardingkey: NotRequired[list[Any]] + supershardingkey: NotRequired[list[Any]] + debug_jdwp: NotRequired[str] + connection_id_prefix: NotRequired[str] + ssl_context: NotRequired[Any] + sdu: NotRequired[int] + pool_boundary: NotRequired[OraclePoolBoundary] + use_tcp_fast_open: NotRequired[bool] + ssl_version: NotRequired[TLSVersion] + program: NotRequired[str] + machine: NotRequired[str] + terminal: NotRequired[str] + osuser: NotRequired[str] + driver_name: NotRequired[str] + use_sni: NotRequired[bool] + thick_mode_dsn_passthrough: NotRequired[bool] + extra_auth_params: NotRequired[dict[str, Any]] + pool_name: NotRequired[str] + on_connect_callback: NotRequired[Callable[..., Any]] + handle: NotRequired[int] + extra: NotRequired[dict[str, Any]] class OraclePoolParams(OracleConnectionParams): """OracleDB pool parameters.""" + pool_class: NotRequired[type[Any]] + params: NotRequired[oracledb.PoolParams] min: NotRequired[int] max: NotRequired[int] increment: NotRequired[int] - threaded: NotRequired[bool] - getmode: NotRequired[Any] + connectiontype: NotRequired[type[Any]] + getmode: NotRequired[PoolGetMode] homogeneous: NotRequired[bool] timeout: NotRequired[int] wait_timeout: NotRequired[int] max_lifetime_session: NotRequired[int] - session_callback: NotRequired["Callable[..., Any]"] + session_callback: NotRequired[Callable[..., Any]] max_sessions_per_shard: NotRequired[int] soda_metadata_cache: NotRequired[bool] ping_interval: NotRequired[int] - extra: NotRequired["dict[str, Any]"] + ping_timeout: NotRequired[int] class OracleDriverFeatures(TypedDict): @@ -122,26 +175,31 @@ class OracleDriverFeatures(TypedDict): For sync: Callable[[OracleSyncConnection, str], None] - receives connection and tag For async: Callable[[OracleAsyncConnection, str], Awaitable[None]] Called after internal setup (numpy vectors, UUID handlers). - enable_events: Enable database event channel support. + enable_events: Enable SQLSpec event queue support. Defaults to True when extension_config["events"] is configured. Provides pub/sub capabilities via Oracle Advanced Queuing or table-backed fallback. Requires extension_config["events"] for migration setup when using table_queue backend. + This is separate from connection_config["events"], which enables python-oracledb + Thick mode database event notifications for HA and continuous query notification. events_backend: Event channel backend selection. Options: "advanced_queue", "table_queue" - "advanced_queue": Oracle Advanced Queuing (native messaging, requires DBMS_AQADM privileges) - "table_queue": Durable table-backed queue with retries and exactly-once delivery Defaults to "table_queue" (works on all Oracle editions without special privileges). + Native pipeline execution is runtime-gated by driver API support, Oracle Database + version, and the SQLSPEC_ORACLE_DISABLE_PIPELINE environment override; there is + no adapter config switch that can force-enable unsupported pipeline execution. """ enable_numpy_vectors: NotRequired[bool] enable_lowercase_column_names: NotRequired[bool] enable_uuid_binary: NotRequired[bool] - vector_return_format: NotRequired[str] + vector_return_format: NotRequired[OracleVectorReturnFormat] oracle_varchar2_byte_limit: NotRequired[int] oracle_raw_byte_limit: NotRequired[int] - on_connection_create: "NotRequired[Callable[..., Any]]" + on_connection_create: NotRequired[Callable[..., Any]] enable_events: NotRequired[bool] - events_backend: NotRequired[str] + events_backend: NotRequired[OracleEventsBackend] class OracleSyncConnectionContext(SyncPoolConnectionContext): @@ -193,7 +251,7 @@ def release_connection(self, _conn: "OracleSyncConnection", **kwargs: Any) -> No class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "OracleSyncConnectionPool", OracleSyncDriver]): """Configuration for Oracle synchronous database connections.""" - __slots__ = ("_user_connection_hook",) + __slots__ = ("_pool_session_callback", "_user_connection_hook") driver_type: ClassVar[type[OracleSyncDriver]] = OracleSyncDriver connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection @@ -234,6 +292,9 @@ def __init__( **kwargs: Additional keyword arguments. """ connection_config = normalize_connection_config(connection_config) + self._pool_session_callback = cast( + "Callable[[OracleSyncConnection, str], None] | None", connection_config.pop("session_callback", None) + ) statement_config = statement_config or default_statement_config driver_features = apply_driver_features(driver_features) @@ -259,6 +320,7 @@ def _create_pool(self) -> "OracleSyncConnectionPool": """Create the actual connection pool.""" config = dict(self.connection_config) + config.pop("threaded", None) config["session_callback"] = self._init_connection return oracledb.create_pool(**config) @@ -297,7 +359,9 @@ def _init_connection(self, connection: "OracleSyncConnection", tag: str) -> None # dispatch without re-reading driver-feature defaults on every fetch. setattr(connection, "_sqlspec_vector_return_format", self.driver_features.get("vector_return_format")) - # Call user-provided callback after internal setup + if self._pool_session_callback is not None: + self._pool_session_callback(connection, tag) + if self._user_connection_hook is not None: self._user_connection_hook(connection, tag) @@ -392,7 +456,7 @@ class _OracleAsyncSessionConnectionHandler(AsyncPoolSessionFactory): class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncConnectionPool", OracleAsyncDriver]): """Configuration for Oracle asynchronous database connections.""" - __slots__ = ("_user_connection_hook",) + __slots__ = ("_pool_session_callback", "_user_connection_hook") connection_type: "ClassVar[type[OracleAsyncConnection]]" = OracleAsyncConnection driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver @@ -435,6 +499,9 @@ def __init__( **kwargs: Additional keyword arguments. """ connection_config = normalize_connection_config(connection_config) + self._pool_session_callback = cast( + "Callable[[OracleAsyncConnection, str], Any] | None", connection_config.pop("session_callback", None) + ) driver_features = apply_driver_features(driver_features) @@ -459,6 +526,7 @@ async def _create_pool(self) -> "OracleAsyncConnectionPool": """Create the actual async connection pool.""" config = dict(self.connection_config) + config.pop("threaded", None) config["session_callback"] = self._init_connection return oracledb.create_pool_async(**config) @@ -493,9 +561,15 @@ async def _init_connection(self, connection: "OracleAsyncConnection", tag: str) # dispatch without re-reading driver-feature defaults on every fetch. setattr(connection, "_sqlspec_vector_return_format", self.driver_features.get("vector_return_format")) - # Call user-provided callback after internal setup + if self._pool_session_callback is not None: + session_callback_result = self._pool_session_callback(connection, tag) + if isawaitable(session_callback_result): + await session_callback_result + if self._user_connection_hook is not None: - await self._user_connection_hook(connection, tag) + hook_result = self._user_connection_hook(connection, tag) + if isawaitable(hook_result): + await hook_result async def _close_pool(self) -> None: """Close the actual async connection pool.""" diff --git a/tests/unit/adapters/test_oracledb/test_config.py b/tests/unit/adapters/test_oracledb/test_config.py new file mode 100644 index 000000000..ee95f03d4 --- /dev/null +++ b/tests/unit/adapters/test_oracledb/test_config.py @@ -0,0 +1,207 @@ +"""OracleDB configuration tests covering driver kwargs and typed options.""" + +from collections.abc import Awaitable, Callable +from ssl import TLSVersion +from typing import Any, cast, get_args, get_origin, get_type_hints + +import pytest +from oracledb import AuthMode, PoolGetMode, Purity +from typing_extensions import NotRequired + +from sqlspec.adapters.oracledb import config as oracle_config_module +from sqlspec.adapters.oracledb.config import ( + OracleAsyncConfig, + OracleConnectionParams, + OracleDriverFeatures, + OraclePoolParams, + OracleSyncConfig, +) + + +class _StubConnection: + version = "23.5.0.0.0" + + +def _unwrap_not_required(annotation: object) -> object: + assert get_origin(annotation) is NotRequired + return get_args(annotation)[0] + + +def _oracle_config_hints(typeddict: type[object]) -> dict[str, object]: + globalns = dict(vars(oracle_config_module)) + globalns.update({ + "AuthMode": AuthMode, + "Awaitable": Awaitable, + "Callable": Callable, + "PoolGetMode": PoolGetMode, + "Purity": Purity, + "TLSVersion": TLSVersion, + }) + return get_type_hints(typeddict, globalns=globalns, localns=globalns, include_extras=True) + + +def _stub_sync_connection_setup(monkeypatch: pytest.MonkeyPatch, calls: list[str]) -> None: + monkeypatch.setattr(oracle_config_module, "register_numpy_handlers", lambda _connection: calls.append("numpy")) + monkeypatch.setattr(oracle_config_module, "register_json_handlers", lambda _connection: calls.append("json")) + monkeypatch.setattr(oracle_config_module, "register_uuid_handlers", lambda _connection: calls.append("uuid")) + monkeypatch.setattr(oracle_config_module, "_extract_oracle_major", lambda _connection: 23) + + +def test_oracle_connection_params_expose_current_driver_options() -> None: + """Connection params should mirror current python-oracledb connection knobs.""" + annotations = _oracle_config_hints(OracleConnectionParams) + + expected_options = { + "access_token", + "appcontext", + "cclass", + "connection_id_prefix", + "debug_jdwp", + "disable_oob", + "driver_name", + "events", + "expire_time", + "extra", + "extra_auth_params", + "externalauth", + "handle", + "https_proxy", + "https_proxy_port", + "instance_name", + "machine", + "matchanytag", + "mode", + "newpassword", + "on_connect_callback", + "osuser", + "pool_boundary", + "pool_name", + "program", + "protocol", + "proxy_user", + "purity", + "sdu", + "server_type", + "shardingkey", + "ssl_context", + "ssl_server_cert_dn", + "ssl_server_dn_match", + "ssl_version", + "stmtcachesize", + "supershardingkey", + "tag", + "terminal", + "thick_mode_dsn_passthrough", + "use_sni", + "use_tcp_fast_open", + "wallet_password", + } + + assert expected_options <= annotations.keys() + + +def test_oracle_pool_params_expose_current_pool_options_and_remove_threaded() -> None: + """Pool params should include current pool options without stale ``threaded``.""" + annotations = _oracle_config_hints(OraclePoolParams) + + assert { + "connectiontype", + "getmode", + "homogeneous", + "max_lifetime_session", + "max_sessions_per_shard", + "on_connect_callback", + "ping_timeout", + "pool_alias", + "pool_class", + "soda_metadata_cache", + "wait_timeout", + } <= annotations.keys() + assert "threaded" not in annotations + + +def test_oracle_config_finite_options_use_literals_and_driver_enums() -> None: + """Finite Oracle settings should be typed more narrowly than plain ``str`` or ``Any``.""" + connection_hints = _oracle_config_hints(OracleConnectionParams) + pool_hints = _oracle_config_hints(OraclePoolParams) + driver_feature_hints = _oracle_config_hints(OracleDriverFeatures) + + assert set(get_args(_unwrap_not_required(connection_hints["protocol"]))) == {"tcp", "tcps"} + assert set(get_args(_unwrap_not_required(connection_hints["server_type"]))) == {"dedicated", "pooled", "shared"} + assert _unwrap_not_required(connection_hints["mode"]) is AuthMode + assert _unwrap_not_required(connection_hints["purity"]) is Purity + assert _unwrap_not_required(pool_hints["getmode"]) is PoolGetMode + assert set(get_args(_unwrap_not_required(driver_feature_hints["vector_return_format"]))) == { + "array", + "list", + "numpy", + } + assert set(get_args(_unwrap_not_required(driver_feature_hints["events_backend"]))) == { + "advanced_queue", + "table_queue", + } + + +def test_oracle_sync_create_pool_merges_extra_and_drops_stale_threaded(monkeypatch: pytest.MonkeyPatch) -> None: + """``extra`` should merge as kwargs, while stale ``threaded`` should not reach python-oracledb.""" + seen_kwargs: dict[str, object] = {} + + def fake_create_pool(**kwargs: object) -> object: + seen_kwargs.update(kwargs) + return object() + + monkeypatch.setattr(oracle_config_module.oracledb, "create_pool", fake_create_pool) + config = OracleSyncConfig( + connection_config={"threaded": True, "user": "scott", "extra": {"pool_alias": "sqlspec-main", "use_sni": True}} + ) + + config._create_pool() # pyright: ignore[reportPrivateUsage] + + assert seen_kwargs["user"] == "scott" + assert seen_kwargs["pool_alias"] == "sqlspec-main" + assert seen_kwargs["use_sni"] is True + assert "extra" not in seen_kwargs + assert "threaded" not in seen_kwargs + + +def test_oracle_sync_connection_config_session_callback_is_preserved(monkeypatch: pytest.MonkeyPatch) -> None: + """Native pool ``session_callback`` should run in addition to SQLSpec setup.""" + calls: list[str] = [] + _stub_sync_connection_setup(monkeypatch, calls) + + def session_callback(_connection: object, _tag: str) -> None: + calls.append("session_callback") + + def on_connection_create(_connection: object, _tag: str) -> None: + calls.append("on_connection_create") + + config = OracleSyncConfig( + connection_config={"session_callback": session_callback}, + driver_features={"on_connection_create": on_connection_create}, + ) + + config._init_connection(cast(Any, _StubConnection()), "analytics") # pyright: ignore[reportPrivateUsage] + + assert calls == ["numpy", "json", "uuid", "session_callback", "on_connection_create"] + + +@pytest.mark.anyio +async def test_oracle_async_connection_config_session_callback_is_preserved(monkeypatch: pytest.MonkeyPatch) -> None: + """Async native pool ``session_callback`` should be awaited when it returns an awaitable.""" + calls: list[str] = [] + _stub_sync_connection_setup(monkeypatch, calls) + + async def session_callback(_connection: object, _tag: str) -> None: + calls.append("session_callback") + + async def on_connection_create(_connection: object, _tag: str) -> None: + calls.append("on_connection_create") + + config = OracleAsyncConfig( + connection_config={"session_callback": session_callback}, + driver_features={"on_connection_create": on_connection_create}, + ) + + await config._init_connection(cast(Any, _StubConnection()), "analytics") # pyright: ignore[reportPrivateUsage] + + assert calls == ["numpy", "json", "uuid", "session_callback", "on_connection_create"]