Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 93 additions & 19 deletions sqlspec/adapters/oracledb/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""OracleDB database configuration with direct field-based configuration."""

from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from collections.abc import Awaitable, Callable
from inspect import isawaitable
from ssl import TLSVersion
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, cast

import oracledb
from oracledb import AuthMode, PoolGetMode, Purity
from typing_extensions import NotRequired

from sqlspec.adapters.oracledb._json_handlers import register_json_handlers # pyright: ignore[reportPrivateUsage]
Expand Down Expand Up @@ -32,11 +36,8 @@
from sqlspec.utils.config_tools import normalize_connection_config

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from types import TracebackType

from oracledb import AuthMode

from sqlspec.core import StatementConfig


Expand All @@ -49,44 +50,96 @@
)


OracleAccessToken = str | tuple[str, str] | Callable[..., str | tuple[str, str]]
OracleAppContext = tuple[str, str, str]
OracleProtocol = Literal["tcp", "tcps"]
OracleServerType = Literal["dedicated", "shared", "pooled"]
OraclePoolBoundary = Literal["statement", "transaction"]
OracleVectorReturnFormat = Literal["array", "list", "numpy"]
OracleEventsBackend = Literal["advanced_queue", "table_queue"]


class OracleConnectionParams(TypedDict):
"""OracleDB connection parameters."""

dsn: NotRequired[str]
pool_alias: NotRequired[str]
user: NotRequired[str]
proxy_user: NotRequired[str]
password: NotRequired[str]
newpassword: NotRequired[str]
wallet_password: NotRequired[str]
access_token: NotRequired[OracleAccessToken]
host: NotRequired[str]
port: NotRequired[int]
protocol: NotRequired[OracleProtocol]
https_proxy: NotRequired[str]
https_proxy_port: NotRequired[int]
service_name: NotRequired[str]
instance_name: NotRequired[str]
sid: NotRequired[str]
server_type: NotRequired[OracleServerType]
cclass: NotRequired[str]
purity: NotRequired[Purity]
expire_time: NotRequired[int]
externalauth: NotRequired[bool]
mode: NotRequired[AuthMode]
wallet_location: NotRequired[str]
wallet_password: NotRequired[str]
config_dir: NotRequired[str]
tcp_connect_timeout: NotRequired[float]
retry_count: NotRequired[int]
retry_delay: NotRequired[int]
mode: NotRequired["AuthMode"]
ssl_server_dn_match: NotRequired[bool]
ssl_server_cert_dn: NotRequired[str]
events: NotRequired[bool]
disable_oob: NotRequired[bool]
stmtcachesize: NotRequired[int]
edition: NotRequired[str]
tag: NotRequired[str]
matchanytag: NotRequired[bool]
appcontext: NotRequired[list[OracleAppContext]]
shardingkey: NotRequired[list[Any]]
supershardingkey: NotRequired[list[Any]]
debug_jdwp: NotRequired[str]
connection_id_prefix: NotRequired[str]
ssl_context: NotRequired[Any]
sdu: NotRequired[int]
pool_boundary: NotRequired[OraclePoolBoundary]
use_tcp_fast_open: NotRequired[bool]
ssl_version: NotRequired[TLSVersion]
program: NotRequired[str]
machine: NotRequired[str]
terminal: NotRequired[str]
osuser: NotRequired[str]
driver_name: NotRequired[str]
use_sni: NotRequired[bool]
thick_mode_dsn_passthrough: NotRequired[bool]
extra_auth_params: NotRequired[dict[str, Any]]
pool_name: NotRequired[str]
on_connect_callback: NotRequired[Callable[..., Any]]
handle: NotRequired[int]
extra: NotRequired[dict[str, Any]]


class OraclePoolParams(OracleConnectionParams):
"""OracleDB pool parameters."""

pool_class: NotRequired[type[Any]]
params: NotRequired[oracledb.PoolParams]
min: NotRequired[int]
max: NotRequired[int]
increment: NotRequired[int]
threaded: NotRequired[bool]
getmode: NotRequired[Any]
connectiontype: NotRequired[type[Any]]
getmode: NotRequired[PoolGetMode]
homogeneous: NotRequired[bool]
timeout: NotRequired[int]
wait_timeout: NotRequired[int]
max_lifetime_session: NotRequired[int]
session_callback: NotRequired["Callable[..., Any]"]
session_callback: NotRequired[Callable[..., Any]]
max_sessions_per_shard: NotRequired[int]
soda_metadata_cache: NotRequired[bool]
ping_interval: NotRequired[int]
extra: NotRequired["dict[str, Any]"]
ping_timeout: NotRequired[int]


class OracleDriverFeatures(TypedDict):
Expand Down Expand Up @@ -122,26 +175,31 @@ class OracleDriverFeatures(TypedDict):
For sync: Callable[[OracleSyncConnection, str], None] - receives connection and tag
For async: Callable[[OracleAsyncConnection, str], Awaitable[None]]
Called after internal setup (numpy vectors, UUID handlers).
enable_events: Enable database event channel support.
enable_events: Enable SQLSpec event queue support.
Defaults to True when extension_config["events"] is configured.
Provides pub/sub capabilities via Oracle Advanced Queuing or table-backed fallback.
Requires extension_config["events"] for migration setup when using table_queue backend.
This is separate from connection_config["events"], which enables python-oracledb
Thick mode database event notifications for HA and continuous query notification.
events_backend: Event channel backend selection.
Options: "advanced_queue", "table_queue"
- "advanced_queue": Oracle Advanced Queuing (native messaging, requires DBMS_AQADM privileges)
- "table_queue": Durable table-backed queue with retries and exactly-once delivery
Defaults to "table_queue" (works on all Oracle editions without special privileges).
Native pipeline execution is runtime-gated by driver API support, Oracle Database
version, and the SQLSPEC_ORACLE_DISABLE_PIPELINE environment override; there is
no adapter config switch that can force-enable unsupported pipeline execution.
"""

enable_numpy_vectors: NotRequired[bool]
enable_lowercase_column_names: NotRequired[bool]
enable_uuid_binary: NotRequired[bool]
vector_return_format: NotRequired[str]
vector_return_format: NotRequired[OracleVectorReturnFormat]
oracle_varchar2_byte_limit: NotRequired[int]
oracle_raw_byte_limit: NotRequired[int]
on_connection_create: "NotRequired[Callable[..., Any]]"
on_connection_create: NotRequired[Callable[..., Any]]
enable_events: NotRequired[bool]
events_backend: NotRequired[str]
events_backend: NotRequired[OracleEventsBackend]


class OracleSyncConnectionContext(SyncPoolConnectionContext):
Expand Down Expand Up @@ -193,7 +251,7 @@ def release_connection(self, _conn: "OracleSyncConnection", **kwargs: Any) -> No
class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "OracleSyncConnectionPool", OracleSyncDriver]):
"""Configuration for Oracle synchronous database connections."""

__slots__ = ("_user_connection_hook",)
__slots__ = ("_pool_session_callback", "_user_connection_hook")

driver_type: ClassVar[type[OracleSyncDriver]] = OracleSyncDriver
connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection
Expand Down Expand Up @@ -234,6 +292,9 @@ def __init__(
**kwargs: Additional keyword arguments.
"""
connection_config = normalize_connection_config(connection_config)
self._pool_session_callback = cast(
"Callable[[OracleSyncConnection, str], None] | None", connection_config.pop("session_callback", None)
)
statement_config = statement_config or default_statement_config

driver_features = apply_driver_features(driver_features)
Expand All @@ -259,6 +320,7 @@ def _create_pool(self) -> "OracleSyncConnectionPool":
"""Create the actual connection pool."""
config = dict(self.connection_config)

config.pop("threaded", None)
config["session_callback"] = self._init_connection

return oracledb.create_pool(**config)
Expand Down Expand Up @@ -297,7 +359,9 @@ def _init_connection(self, connection: "OracleSyncConnection", tag: str) -> None
# dispatch without re-reading driver-feature defaults on every fetch.
setattr(connection, "_sqlspec_vector_return_format", self.driver_features.get("vector_return_format"))

# Call user-provided callback after internal setup
if self._pool_session_callback is not None:
self._pool_session_callback(connection, tag)

if self._user_connection_hook is not None:
self._user_connection_hook(connection, tag)

Expand Down Expand Up @@ -392,7 +456,7 @@ class _OracleAsyncSessionConnectionHandler(AsyncPoolSessionFactory):
class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "OracleAsyncConnectionPool", OracleAsyncDriver]):
"""Configuration for Oracle asynchronous database connections."""

__slots__ = ("_user_connection_hook",)
__slots__ = ("_pool_session_callback", "_user_connection_hook")

connection_type: "ClassVar[type[OracleAsyncConnection]]" = OracleAsyncConnection
driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver
Expand Down Expand Up @@ -435,6 +499,9 @@ def __init__(
**kwargs: Additional keyword arguments.
"""
connection_config = normalize_connection_config(connection_config)
self._pool_session_callback = cast(
"Callable[[OracleAsyncConnection, str], Any] | None", connection_config.pop("session_callback", None)
)

driver_features = apply_driver_features(driver_features)

Expand All @@ -459,6 +526,7 @@ async def _create_pool(self) -> "OracleAsyncConnectionPool":
"""Create the actual async connection pool."""
config = dict(self.connection_config)

config.pop("threaded", None)
config["session_callback"] = self._init_connection

return oracledb.create_pool_async(**config)
Expand Down Expand Up @@ -493,9 +561,15 @@ async def _init_connection(self, connection: "OracleAsyncConnection", tag: str)
# dispatch without re-reading driver-feature defaults on every fetch.
setattr(connection, "_sqlspec_vector_return_format", self.driver_features.get("vector_return_format"))

# Call user-provided callback after internal setup
if self._pool_session_callback is not None:
session_callback_result = self._pool_session_callback(connection, tag)
if isawaitable(session_callback_result):
await session_callback_result

if self._user_connection_hook is not None:
await self._user_connection_hook(connection, tag)
hook_result = self._user_connection_hook(connection, tag)
if isawaitable(hook_result):
await hook_result

async def _close_pool(self) -> None:
"""Close the actual async connection pool."""
Expand Down
Loading
Loading