diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 414f8ce856..527e036a0e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,13 +23,11 @@ jobs: steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - - name: Install Rye - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Install dependencies run: rye sync --all-features @@ -48,13 +46,11 @@ jobs: steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - - name: Install Rye - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Install dependencies run: rye sync --all-features @@ -89,13 +85,11 @@ jobs: steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - - name: Install Rye - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Bootstrap run: ./scripts/bootstrap @@ -112,13 +106,11 @@ jobs: steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - - name: Install Rye - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Install dependencies run: | rye sync --all-features diff --git a/.github/workflows/create-releases.yml b/.github/workflows/create-releases.yml index 98b7f20ffe..f819d07310 100644 --- a/.github/workflows/create-releases.yml +++ b/.github/workflows/create-releases.yml @@ -22,14 +22,12 @@ jobs: repo: ${{ github.event.repository.full_name }} stainless-api-key: ${{ secrets.STAINLESS_API_KEY }} - - name: Install Rye + - name: Set up Rye if: ${{ steps.release.outputs.releases_created }} - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Publish to PyPI if: ${{ steps.release.outputs.releases_created }} diff --git a/.github/workflows/detect-breaking-changes.yml b/.github/workflows/detect-breaking-changes.yml index bd73bb6e2a..641129f010 100644 --- a/.github/workflows/detect-breaking-changes.yml +++ b/.github/workflows/detect-breaking-changes.yml @@ -20,13 +20,11 @@ jobs: # Ensure we can check out the pull request base in the script below. fetch-depth: ${{ env.FETCH_DEPTH }} - - name: Install Rye - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Install dependencies run: | rye sync --all-features @@ -49,14 +47,12 @@ jobs: with: path: openai-python - - name: Install Rye - working-directory: openai-python - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true + working-directory: openai-python - name: Install dependencies working-directory: openai-python @@ -85,4 +81,3 @@ jobs: - name: Run integration type checks working-directory: openai-agents-python run: make mypy - diff --git a/.github/workflows/publish-pypi.yml b/.github/workflows/publish-pypi.yml index f67347f2fa..a7c62c4c4d 100644 --- a/.github/workflows/publish-pypi.yml +++ b/.github/workflows/publish-pypi.yml @@ -13,13 +13,11 @@ jobs: steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6 - - name: Install Rye - run: | - curl -sSf https://rye.astral.sh/get | bash - echo "$HOME/.rye/shims" >> $GITHUB_PATH - env: - RYE_VERSION: '0.44.0' - RYE_INSTALL_OPTION: '--yes' + - name: Set up Rye + uses: eifinger/setup-rye@c694239a43768373e87d0103d7f547027a23f3c8 + with: + version: '0.44.0' + enable-cache: true - name: Publish to PyPI run: | diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 8f5c652f67..527d2e30b0 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.31.0" + ".": "2.32.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 461d4c7c07..f069d6c8b9 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 152 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai%2Fopenai-a6eca1bd01e0c434af356fe5275c206057216a4e626d1051d294c27016cd6d05.yml -openapi_spec_hash: 68abda9122013a9ae3f084cfdbe8e8c1 -config_hash: 4975e16a94e8f9901428022044131888 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai%2Fopenai-7c540cce6eb30401259f4831ea9803b6d88501605d13734f98212cbb3b199e10.yml +openapi_spec_hash: 06e656be22bbb92689954253668b42fc +config_hash: 1a88b104658b6c854117996c080ebe6b diff --git a/CHANGELOG.md b/CHANGELOG.md index bf8092f51f..2c3728032e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,27 @@ # Changelog +## 2.32.0 (2026-04-14) + +Full Changelog: [v2.31.0...v2.32.0](https://github.com/openai/openai-python/compare/v2.31.0...v2.32.0) + +### Features + +* **api:** Add detail to InputFileContent ([60de21d](https://github.com/openai/openai-python/commit/60de21d1fcfbcadea0d9b8d884c73c9dc49d14ff)) +* **api:** add OAuthErrorCode type ([0c8d2c3](https://github.com/openai/openai-python/commit/0c8d2c3b44242c9139dc554896ea489b56e236b8)) +* **client:** add event handler implementation for websockets ([0280d05](https://github.com/openai/openai-python/commit/0280d0568f706684ecbf0aabf3575cdcb7fd22d5)) +* **client:** allow enqueuing to websockets even when not connected ([67aa20e](https://github.com/openai/openai-python/commit/67aa20e69bc0e4a3b7694327c808606bfa24a966)) +* **client:** support reconnection in websockets ([eb72a95](https://github.com/openai/openai-python/commit/eb72a953ea9dc5beec3eef537be6eb32292c3f65)) + + +### Bug Fixes + +* ensure file data are only sent as 1 parameter ([c0c2ecd](https://github.com/openai/openai-python/commit/c0c2ecd0f6b64fa5fafda6134bb06995b143a2cf)) + + +### Documentation + +* improve examples ([84712fa](https://github.com/openai/openai-python/commit/84712fa0f094b53151a0fe6ac85aa98018b2a7e2)) + ## 2.31.0 (2026-04-08) Full Changelog: [v2.30.0...v2.31.0](https://github.com/openai/openai-python/compare/v2.30.0...v2.31.0) diff --git a/pyproject.toml b/pyproject.toml index c85db6b519..d0d533e8a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai" -version = "2.31.0" +version = "2.32.0" description = "The official Python library for the openai API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 3e0a135929..fc9675a8b5 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -29,14 +29,17 @@ InternalServerError, PermissionDeniedError, LengthFinishReasonError, + WebSocketQueueFullError, UnprocessableEntityError, APIResponseValidationError, InvalidWebhookSignatureError, ContentFilterFinishReasonError, + WebSocketConnectionClosedError, ) from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging from ._legacy_response import HttpxBinaryResponseContent as HttpxBinaryResponseContent +from .types.websocket_reconnection import ReconnectingEvent, ReconnectingOverrides __all__ = [ "types", @@ -84,6 +87,10 @@ "DefaultHttpxClient", "DefaultAsyncHttpxClient", "DefaultAioHttpClient", + "ReconnectingEvent", + "ReconnectingOverrides", + "WebSocketQueueFullError", + "WebSocketConnectionClosedError", ] if not _t.TYPE_CHECKING: diff --git a/src/openai/_event_handler.py b/src/openai/_event_handler.py new file mode 100644 index 0000000000..d9a3a59d71 --- /dev/null +++ b/src/openai/_event_handler.py @@ -0,0 +1,85 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import threading +from typing import Any, Callable + +EventHandler = Callable[..., Any] + + +class EventHandlerRegistry: + """Thread-safe (optional) registry of event handlers.""" + + def __init__(self, *, use_lock: bool = False) -> None: + self._handlers: dict[str, list[EventHandler]] = {} + self._once_ids: set[int] = set() + self._lock: threading.Lock | None = threading.Lock() if use_lock else None + + def _acquire(self) -> None: + if self._lock is not None: + self._lock.acquire() + + def _release(self) -> None: + if self._lock is not None: + self._lock.release() + + def add(self, event_type: str, handler: EventHandler, *, once: bool = False) -> None: + self._acquire() + try: + handlers = self._handlers.setdefault(event_type, []) + handlers.append(handler) + if once: + self._once_ids.add(id(handler)) + finally: + self._release() + + def remove(self, event_type: str, handler: EventHandler) -> None: + self._acquire() + try: + handlers = self._handlers.get(event_type) + if handlers is not None: + try: + handlers.remove(handler) + except ValueError: + pass + self._once_ids.discard(id(handler)) + finally: + self._release() + + def get_handlers(self, event_type: str) -> list[EventHandler]: + """Return a snapshot of handlers for the given event type, removing once-handlers.""" + self._acquire() + try: + handlers = self._handlers.get(event_type) + if not handlers: + return [] + result = list(handlers) + to_remove = [h for h in result if id(h) in self._once_ids] + for h in to_remove: + handlers.remove(h) + self._once_ids.discard(id(h)) + return result + finally: + self._release() + + def has_handlers(self, event_type: str) -> bool: + self._acquire() + try: + handlers = self._handlers.get(event_type) + return bool(handlers) + finally: + self._release() + + def merge_into(self, target: EventHandlerRegistry) -> None: + """Move all handlers from this registry into *target*, then clear self.""" + self._acquire() + try: + for event_type, handlers in self._handlers.items(): + for handler in handlers: + once = id(handler) in self._once_ids + target.add(event_type, handler, once=once) + self._handlers.clear() + self._once_ids.clear() + finally: + self._release() diff --git a/src/openai/_exceptions.py b/src/openai/_exceptions.py index a37ed8ca82..86f44b0e15 100644 --- a/src/openai/_exceptions.py +++ b/src/openai/_exceptions.py @@ -28,6 +28,8 @@ "ContentFilterFinishReasonError", "InvalidWebhookSignatureError", "SubjectTokenProviderError", + "WebSocketConnectionClosedError", + "WebSocketQueueFullError", ] @@ -187,3 +189,19 @@ def __init__(self) -> None: class InvalidWebhookSignatureError(ValueError): """Raised when a webhook signature is invalid, meaning the computed signature does not match the expected signature.""" + + +class WebSocketConnectionClosedError(OpenAIError): + """Raised when a WebSocket connection closes with unsent messages.""" + + unsent_messages: list[str] + + def __init__(self, message: str, *, unsent_messages: list[str]) -> None: + super().__init__(message) + self.unsent_messages = unsent_messages + + +class WebSocketQueueFullError(OpenAIError): + """Raised when the outgoing WebSocket message queue exceeds its byte-size limit.""" + + pass diff --git a/src/openai/_send_queue.py b/src/openai/_send_queue.py new file mode 100644 index 0000000000..b35d0fbcba --- /dev/null +++ b/src/openai/_send_queue.py @@ -0,0 +1,90 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import typing +import threading + +from ._exceptions import WebSocketQueueFullError + + +class SendQueue: + """Bounded byte-size queue for outgoing WebSocket messages. + + Messages are stored as pre-serialized strings. The queue enforces a + maximum byte budget so that unbounded buffering cannot occur during + reconnection windows. + """ + + def __init__(self, max_bytes: int = 1_048_576) -> None: + self._queue: list[tuple[str, int]] = [] # (data, byte_length) + self._bytes: int = 0 + self._max_bytes = max_bytes + self._lock = threading.Lock() + + def enqueue(self, data: str) -> None: + """Append *data* to the queue. + + Raises :class:`WebSocketQueueFullError` if the message would + exceed the byte-size limit. + """ + byte_length = len(data.encode("utf-8")) + with self._lock: + if self._bytes + byte_length > self._max_bytes: + raise WebSocketQueueFullError("send queue is full, message discarded") + self._queue.append((data, byte_length)) + self._bytes += byte_length + + def flush_sync(self, send: typing.Callable[[str], object]) -> None: + """Send every queued message via *send*. + + If *send* raises, the failing message and all subsequent messages + are re-queued and the error is re-raised. + """ + with self._lock: + pending = list(self._queue) + self._queue.clear() + self._bytes = 0 + + for i, (data, _byte_length) in enumerate(pending): + try: + send(data) + except Exception: + with self._lock: + remaining = pending[i:] + self._queue = remaining + self._queue + self._bytes = sum(bl for _, bl in self._queue) + raise + + async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object]]) -> None: + """Async variant of :meth:`flush_sync`.""" + with self._lock: + pending = list(self._queue) + self._queue.clear() + self._bytes = 0 + + for i, (data, _byte_length) in enumerate(pending): + try: + await send(data) + except Exception: + with self._lock: + remaining = pending[i:] + self._queue = remaining + self._queue + self._bytes = sum(bl for _, bl in self._queue) + raise + + def drain(self) -> list[str]: + """Remove and return all queued messages.""" + with self._lock: + items = [data for data, _ in self._queue] + self._queue.clear() + self._bytes = 0 + return items + + def __len__(self) -> int: + with self._lock: + return len(self._queue) + + def __bool__(self) -> bool: + with self._lock: + return len(self._queue) > 0 diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index 90494748cc..b1e8e0d041 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -90,8 +90,9 @@ def _extract_items( index += 1 if is_dict(obj): try: - # We are at the last entry in the path so we must remove the field - if (len(path)) == index: + # Remove the field if there are no more dict keys in the path, + # only "" traversal markers or end. + if all(p == "" for p in path[index:]): item = obj.pop(key) else: item = obj[key] diff --git a/src/openai/_version.py b/src/openai/_version.py index a435bdb765..a98695ba0e 100644 --- a/src/openai/_version.py +++ b/src/openai/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "openai" -__version__ = "2.31.0" # x-release-please-version +__version__ = "2.32.0" # x-release-please-version diff --git a/src/openai/resources/realtime/realtime.py b/src/openai/resources/realtime/realtime.py index 82c9815b03..e4c5bd8163 100644 --- a/src/openai/resources/realtime/realtime.py +++ b/src/openai/resources/realtime/realtime.py @@ -3,9 +3,11 @@ from __future__ import annotations import json +import time +import random import logging from types import TracebackType -from typing import TYPE_CHECKING, Any, Iterator, cast +from typing import TYPE_CHECKING, Any, Union, Callable, Iterator, Awaitable, cast from typing_extensions import AsyncIterator import httpx @@ -30,7 +32,8 @@ from ..._compat import cached_property from ..._models import construct_type_unchecked from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._exceptions import OpenAIError +from ..._exceptions import OpenAIError, WebSocketConnectionClosedError +from ..._send_queue import SendQueue from ..._base_client import _merge_mappings from .client_secrets import ( ClientSecrets, @@ -40,8 +43,11 @@ ClientSecretsWithStreamingResponse, AsyncClientSecretsWithStreamingResponse, ) +from ..._event_handler import EventHandlerRegistry from ...types.realtime import session_update_event_param +from ...types.websocket_reconnection import ReconnectingEvent, ReconnectingOverrides, is_recoverable_close from ...types.websocket_connection_options import WebSocketConnectionOptions +from ...types.realtime.realtime_error_event import RealtimeErrorEvent from ...types.realtime.realtime_client_event import RealtimeClientEvent from ...types.realtime.realtime_server_event import RealtimeServerEvent from ...types.realtime.conversation_item_param import ConversationItemParam @@ -97,6 +103,11 @@ def connect( extra_query: Query = {}, extra_headers: Headers = {}, websocket_connection_options: WebSocketConnectionOptions = {}, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> RealtimeConnectionManager: """ The Realtime API enables you to build low-latency, multi-modal conversational experiences. It currently supports text and audio as both input and output, as well as function calling. @@ -114,6 +125,11 @@ def connect( extra_query=extra_query, extra_headers=extra_headers, websocket_connection_options=websocket_connection_options, + on_reconnecting=on_reconnecting, + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + max_queue_size=max_queue_size, call_id=call_id, model=model, ) @@ -157,6 +173,11 @@ def connect( extra_query: Query = {}, extra_headers: Headers = {}, websocket_connection_options: WebSocketConnectionOptions = {}, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> AsyncRealtimeConnectionManager: """ The Realtime API enables you to build low-latency, multi-modal conversational experiences. It currently supports text and audio as both input and output, as well as function calling. @@ -174,6 +195,11 @@ def connect( extra_query=extra_query, extra_headers=extra_headers, websocket_connection_options=websocket_connection_options, + on_reconnecting=on_reconnecting, + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + max_queue_size=max_queue_size, call_id=call_id, model=model, ) @@ -242,8 +268,31 @@ class AsyncRealtimeConnection: _connection: AsyncWebSocketConnection - def __init__(self, connection: AsyncWebSocketConnection) -> None: + def __init__( + self, + connection: AsyncWebSocketConnection, + *, + make_ws: Callable[[Query, Headers], Awaitable[AsyncWebSocketConnection]] | None = None, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + extra_query: Query = {}, + extra_headers: Headers = {}, + send_queue: SendQueue | None = None, + ) -> None: self._connection = connection + self._make_ws = make_ws + self._on_reconnecting = on_reconnecting + self._max_retries = max_retries + self._initial_delay = initial_delay + self._max_delay = max_delay + self._extra_query = extra_query + self._extra_headers = extra_headers + self._intentionally_closed = False + self._is_reconnecting = False + self._send_queue = send_queue or SendQueue() + self._event_handler_registry = EventHandlerRegistry(use_lock=False) self.session = AsyncRealtimeSessionResource(self) self.response = AsyncRealtimeResponseResource(self) @@ -256,13 +305,22 @@ async def __aiter__(self) -> AsyncIterator[RealtimeServerEvent]: An infinite-iterator that will continue to yield events until the connection is closed. """ - from websockets.exceptions import ConnectionClosedOK + from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError - try: - while True: + while True: + try: yield await self.recv() - except ConnectionClosedOK: - return + except ConnectionClosedOK: + return + except ConnectionClosedError as exc: + if not await self._reconnect(exc): + unsent = self._send_queue.drain() + if unsent: + raise WebSocketConnectionClosedError( + "WebSocket connection closed with unsent messages", + unsent_messages=unsent, + ) from exc + raise async def recv(self) -> RealtimeServerEvent: """ @@ -290,12 +348,24 @@ async def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> N if isinstance(event, BaseModel) else json.dumps(await async_maybe_transform(event, RealtimeClientEventParam)) ) - await self._connection.send(data) + if self._is_reconnecting: + self._send_queue.enqueue(data) + return + try: + await self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise async def send_raw(self, data: bytes | str) -> None: + if self._is_reconnecting: + raw = data if isinstance(data, str) else data.decode("utf-8") + self._send_queue.enqueue(raw) + return await self._connection.send(data) async def close(self, *, code: int = 1000, reason: str = "") -> None: + self._intentionally_closed = True await self._connection.close(code=code, reason=reason) def parse_event(self, data: str | bytes) -> RealtimeServerEvent: @@ -308,6 +378,173 @@ def parse_event(self, data: str | bytes) -> RealtimeServerEvent: RealtimeServerEvent, construct_type_unchecked(value=json.loads(data), type_=cast(Any, RealtimeServerEvent)) ) + async def _reconnect(self, exc: Exception) -> bool: + """Attempt to reconnect after a connection failure. + + Returns ``True`` if a new connection was established, ``False`` if the + caller should re-raise the original exception. + """ + import asyncio + + if self._on_reconnecting is None or self._make_ws is None: + return False + + from websockets.exceptions import ConnectionClosedError + + close_code = 1006 + if isinstance(exc, ConnectionClosedError) and exc.rcvd is not None: + close_code = exc.rcvd.code + + if not is_recoverable_close(close_code): + return False + + self._is_reconnecting = True + + for attempt in range(1, self._max_retries + 1): + base_delay = min(self._initial_delay * (2 ** (attempt - 1)), self._max_delay) + jitter = 0.75 + random.random() * 0.25 + delay = base_delay * jitter + + event = ReconnectingEvent( + attempt=attempt, + max_attempts=self._max_retries, + delay=delay, + close_code=close_code, + extra_query=self._extra_query, + extra_headers=self._extra_headers, + ) + + try: + result = self._on_reconnecting(event) + except Exception: + self._is_reconnecting = False + return False + + if result is not None and result.get("abort"): + self._is_reconnecting = False + return False + + if result is not None: + if "extra_query" in result: + self._extra_query = result["extra_query"] + if "extra_headers" in result: + self._extra_headers = result["extra_headers"] + + log.info( + "Reconnecting to WebSocket API (attempt %d/%d) after %.1fs delay", + attempt, + self._max_retries, + delay, + ) + await asyncio.sleep(delay) + + if self._intentionally_closed: + self._is_reconnecting = False + return False + + try: + self._connection = await self._make_ws(self._extra_query, self._extra_headers) + log.info("Reconnected to WebSocket API") + self._is_reconnecting = False + await self._flush_send_queue() + return True + except Exception: + pass + + self._is_reconnecting = False + return False + + async def _flush_send_queue(self) -> None: + """Send all queued messages over the current connection.""" + + async def _send(data: str) -> None: + await self._connection.send(data) + + try: + await self._send_queue.flush_async(_send) + except Exception: + log.warning("Failed to flush send queue after reconnect", exc_info=True) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncRealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Adds the handler to the end of the handlers list for the given event type. + + No checks are made to see if the handler has already been added. Multiple calls + passing the same combination of event type and handler will result in the handler + being added, and called, multiple times. + + Can be used as a method (returns ``self`` for chaining):: + + connection.on("conversation.created", my_handler) + + Or as a decorator:: + + @connection.on("conversation.created") + async def my_handler(event): ... + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncRealtimeConnection: + """Remove a previously registered event handler.""" + self._event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncRealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler. + + Automatically removed after first invocation. + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator + + async def dispatch_events(self) -> None: + """Run the event loop, dispatching received events to registered handlers. + + Blocks until the connection is closed. This is the push-based + alternative to iterating with ``async for event in connection``. + + If an ``"error"`` event arrives and no handler is registered for + ``"error"`` or ``"event"``, an ``OpenAIError`` is raised. + """ + import asyncio + + async for event in self: + event_type = event.type + specific = self._event_handler_registry.get_handlers(event_type) + generic = self._event_handler_registry.get_handlers("event") + + if event_type == "error" and not specific and not generic: + if isinstance(event, RealtimeErrorEvent): + raise OpenAIError(f"WebSocket error: {event}") + + for handler in specific: + result = handler(event) + if asyncio.iscoroutine(result): + await result + + for handler in generic: + result = handler(event) + if asyncio.iscoroutine(result): + await result + class AsyncRealtimeConnectionManager: """ @@ -338,6 +575,11 @@ def __init__( extra_query: Query, extra_headers: Headers, websocket_connection_options: WebSocketConnectionOptions, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> None: self.__client = client self.__call_id = call_id @@ -346,10 +588,66 @@ def __init__( self.__extra_query = extra_query self.__extra_headers = extra_headers self.__websocket_connection_options = websocket_connection_options + self.__on_reconnecting = on_reconnecting + self.__max_retries = max_retries + self.__initial_delay = initial_delay + self.__max_delay = max_delay + self.__send_queue = SendQueue(max_bytes=max_queue_size) + self.__event_handler_registry = EventHandlerRegistry(use_lock=False) + + def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> None: + """Queue a message to be sent when the connection is established. + + This can be called before entering the context manager. Queued messages + are automatically sent once the WebSocket connection opens. + """ + data = ( + event.to_json(use_api_names=True, exclude_defaults=True, exclude_unset=True) + if isinstance(event, BaseModel) + else json.dumps(event) + ) + self.__send_queue.enqueue(data) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncRealtimeConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register an event handler before the connection is established. + + Handlers are transferred to the connection on enter. Supports the + same method and decorator forms as ``AsyncRealtimeConnection.on``. + """ + if handler is not None: + self.__event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncRealtimeConnectionManager: + """Remove a previously registered event handler.""" + self.__event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncRealtimeConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler before the connection is established.""" + if handler is not None: + self.__event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator async def __aenter__(self) -> AsyncRealtimeConnection: """ - ๐Ÿ‘‹ If your application doesn't work well with the context manager approach then you + If your application doesn't work well with the context manager approach then you can call this method directly to initiate a connection. **Warning**: You must remember to close the connection with `.close()`. @@ -360,15 +658,35 @@ async def __aenter__(self) -> AsyncRealtimeConnection: await connection.close() ``` """ + ws = await self._connect_ws(self.__extra_query, self.__extra_headers) + + self.__connection = AsyncRealtimeConnection( + ws, + make_ws=self._connect_ws if self.__on_reconnecting is not None else None, + on_reconnecting=self.__on_reconnecting, + max_retries=self.__max_retries, + initial_delay=self.__initial_delay, + max_delay=self.__max_delay, + extra_query=self.__extra_query, + extra_headers=self.__extra_headers, + send_queue=self.__send_queue, + ) + + self.__event_handler_registry.merge_into(self.__connection._event_handler_registry) + await self.__connection._flush_send_queue() + + return self.__connection + + enter = __aenter__ + + async def _connect_ws(self, extra_query: Query, extra_headers: Headers) -> AsyncWebSocketConnection: try: from websockets.asyncio.client import connect except ImportError as exc: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc - extra_query = self.__extra_query await self.__client._refresh_api_key() auth_headers = self.__client.auth_headers - extra_query = self.__extra_query if self.__call_id is not omit: extra_query = {**extra_query, "call_id": self.__call_id} if is_async_azure_client(self.__client): @@ -389,24 +707,18 @@ async def __aenter__(self) -> AsyncRealtimeConnection: if self.__websocket_connection_options: log.debug("Connection options: %s", self.__websocket_connection_options) - self.__connection = AsyncRealtimeConnection( - await connect( - str(url), - user_agent_header=self.__client.user_agent, - additional_headers=_merge_mappings( - { - **auth_headers, - }, - self.__extra_headers, - ), - **self.__websocket_connection_options, - ) + return await connect( + str(url), + user_agent_header=self.__client.user_agent, + additional_headers=_merge_mappings( + { + **auth_headers, + }, + extra_headers, + ), + **self.__websocket_connection_options, ) - return self.__connection - - enter = __aenter__ - def _prepare_url(self) -> httpx.URL: if self.__client.websocket_base_url is not None: base_url = httpx.URL(self.__client.websocket_base_url) @@ -436,8 +748,31 @@ class RealtimeConnection: _connection: WebSocketConnection - def __init__(self, connection: WebSocketConnection) -> None: + def __init__( + self, + connection: WebSocketConnection, + *, + make_ws: Callable[[Query, Headers], WebSocketConnection] | None = None, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + extra_query: Query = {}, + extra_headers: Headers = {}, + send_queue: SendQueue | None = None, + ) -> None: self._connection = connection + self._make_ws = make_ws + self._on_reconnecting = on_reconnecting + self._max_retries = max_retries + self._initial_delay = initial_delay + self._max_delay = max_delay + self._extra_query = extra_query + self._extra_headers = extra_headers + self._intentionally_closed = False + self._is_reconnecting = False + self._send_queue = send_queue or SendQueue() + self._event_handler_registry = EventHandlerRegistry(use_lock=True) self.session = RealtimeSessionResource(self) self.response = RealtimeResponseResource(self) @@ -450,13 +785,22 @@ def __iter__(self) -> Iterator[RealtimeServerEvent]: An infinite-iterator that will continue to yield events until the connection is closed. """ - from websockets.exceptions import ConnectionClosedOK + from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError - try: - while True: + while True: + try: yield self.recv() - except ConnectionClosedOK: - return + except ConnectionClosedOK: + return + except ConnectionClosedError as exc: + if not self._reconnect(exc): + unsent = self._send_queue.drain() + if unsent: + raise WebSocketConnectionClosedError( + "WebSocket connection closed with unsent messages", + unsent_messages=unsent, + ) from exc + raise def recv(self) -> RealtimeServerEvent: """ @@ -484,12 +828,24 @@ def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> None: if isinstance(event, BaseModel) else json.dumps(maybe_transform(event, RealtimeClientEventParam)) ) - self._connection.send(data) + if self._is_reconnecting: + self._send_queue.enqueue(data) + return + try: + self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise def send_raw(self, data: bytes | str) -> None: + if self._is_reconnecting: + raw = data if isinstance(data, str) else data.decode("utf-8") + self._send_queue.enqueue(raw) + return self._connection.send(data) def close(self, *, code: int = 1000, reason: str = "") -> None: + self._intentionally_closed = True self._connection.close(code=code, reason=reason) def parse_event(self, data: str | bytes) -> RealtimeServerEvent: @@ -502,6 +858,161 @@ def parse_event(self, data: str | bytes) -> RealtimeServerEvent: RealtimeServerEvent, construct_type_unchecked(value=json.loads(data), type_=cast(Any, RealtimeServerEvent)) ) + def _reconnect(self, exc: Exception) -> bool: + """Attempt to reconnect after a connection failure. + + Returns ``True`` if a new connection was established, ``False`` if the + caller should re-raise the original exception. + """ + if self._on_reconnecting is None or self._make_ws is None: + return False + + from websockets.exceptions import ConnectionClosedError + + close_code = 1006 + if isinstance(exc, ConnectionClosedError) and exc.rcvd is not None: + close_code = exc.rcvd.code + + if not is_recoverable_close(close_code): + return False + + self._is_reconnecting = True + + for attempt in range(1, self._max_retries + 1): + base_delay = min(self._initial_delay * (2 ** (attempt - 1)), self._max_delay) + jitter = 0.75 + random.random() * 0.25 + delay = base_delay * jitter + + event = ReconnectingEvent( + attempt=attempt, + max_attempts=self._max_retries, + delay=delay, + close_code=close_code, + extra_query=self._extra_query, + extra_headers=self._extra_headers, + ) + + try: + result = self._on_reconnecting(event) + except Exception: + self._is_reconnecting = False + return False + + if result is not None and result.get("abort"): + self._is_reconnecting = False + return False + + if result is not None: + if "extra_query" in result: + self._extra_query = result["extra_query"] + if "extra_headers" in result: + self._extra_headers = result["extra_headers"] + + log.info( + "Reconnecting to WebSocket API (attempt %d/%d) after %.1fs delay", + attempt, + self._max_retries, + delay, + ) + time.sleep(delay) + + if self._intentionally_closed: + self._is_reconnecting = False + return False + + try: + self._connection = self._make_ws(self._extra_query, self._extra_headers) + log.info("Reconnected to WebSocket API") + self._is_reconnecting = False + self._flush_send_queue() + return True + except Exception: + pass + + self._is_reconnecting = False + return False + + def _flush_send_queue(self) -> None: + """Send all queued messages over the current connection.""" + try: + self._send_queue.flush_sync(lambda data: self._connection.send(data)) + except Exception: + log.warning("Failed to flush send queue after reconnect", exc_info=True) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[RealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Adds the handler to the end of the handlers list for the given event type. + + No checks are made to see if the handler has already been added. Multiple calls + passing the same combination of event type and handler will result in the handler + being added, and called, multiple times. + + Can be used as a method (returns ``self`` for chaining):: + + connection.on("conversation.created", my_handler) + + Or as a decorator:: + + @connection.on("conversation.created") + def my_handler(event): ... + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> RealtimeConnection: + """Remove a previously registered event handler.""" + self._event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[RealtimeConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler. + + Automatically removed after first invocation. + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator + + def dispatch_events(self) -> None: + """Run the event loop, dispatching received events to registered handlers. + + Blocks the current thread until the connection is closed. This is the push-based + alternative to iterating with ``for event in connection``. + + If an ``"error"`` event arrives and no handler is registered for + ``"error"`` or ``"event"``, an ``OpenAIError`` is raised. + """ + for event in self: + event_type = event.type + specific = self._event_handler_registry.get_handlers(event_type) + generic = self._event_handler_registry.get_handlers("event") + + if event_type == "error" and not specific and not generic: + if isinstance(event, RealtimeErrorEvent): + raise OpenAIError(f"WebSocket error: {event}") + + for handler in specific: + handler(event) + + for handler in generic: + handler(event) + class RealtimeConnectionManager: """ @@ -532,6 +1043,11 @@ def __init__( extra_query: Query, extra_headers: Headers, websocket_connection_options: WebSocketConnectionOptions, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> None: self.__client = client self.__call_id = call_id @@ -540,10 +1056,66 @@ def __init__( self.__extra_query = extra_query self.__extra_headers = extra_headers self.__websocket_connection_options = websocket_connection_options + self.__on_reconnecting = on_reconnecting + self.__max_retries = max_retries + self.__initial_delay = initial_delay + self.__max_delay = max_delay + self.__send_queue = SendQueue(max_bytes=max_queue_size) + self.__event_handler_registry = EventHandlerRegistry(use_lock=True) + + def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> None: + """Queue a message to be sent when the connection is established. + + This can be called before entering the context manager. Queued messages + are automatically sent once the WebSocket connection opens. + """ + data = ( + event.to_json(use_api_names=True, exclude_defaults=True, exclude_unset=True) + if isinstance(event, BaseModel) + else json.dumps(event) + ) + self.__send_queue.enqueue(data) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[RealtimeConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register an event handler before the connection is established. + + Handlers are transferred to the connection on enter. Supports the + same method and decorator forms as ``RealtimeConnection.on``. + """ + if handler is not None: + self.__event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> RealtimeConnectionManager: + """Remove a previously registered event handler.""" + self.__event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[RealtimeConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler before the connection is established.""" + if handler is not None: + self.__event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator def __enter__(self) -> RealtimeConnection: """ - ๐Ÿ‘‹ If your application doesn't work well with the context manager approach then you + If your application doesn't work well with the context manager approach then you can call this method directly to initiate a connection. **Warning**: You must remember to close the connection with `.close()`. @@ -554,15 +1126,35 @@ def __enter__(self) -> RealtimeConnection: connection.close() ``` """ + ws = self._connect_ws(self.__extra_query, self.__extra_headers) + + self.__connection = RealtimeConnection( + ws, + make_ws=self._connect_ws if self.__on_reconnecting is not None else None, + on_reconnecting=self.__on_reconnecting, + max_retries=self.__max_retries, + initial_delay=self.__initial_delay, + max_delay=self.__max_delay, + extra_query=self.__extra_query, + extra_headers=self.__extra_headers, + send_queue=self.__send_queue, + ) + + self.__event_handler_registry.merge_into(self.__connection._event_handler_registry) + self.__connection._flush_send_queue() + + return self.__connection + + enter = __enter__ + + def _connect_ws(self, extra_query: Query, extra_headers: Headers) -> WebSocketConnection: try: from websockets.sync.client import connect except ImportError as exc: raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc - extra_query = self.__extra_query self.__client._refresh_api_key() auth_headers = self.__client.auth_headers - extra_query = self.__extra_query if self.__call_id is not omit: extra_query = {**extra_query, "call_id": self.__call_id} if is_azure_client(self.__client): @@ -583,24 +1175,18 @@ def __enter__(self) -> RealtimeConnection: if self.__websocket_connection_options: log.debug("Connection options: %s", self.__websocket_connection_options) - self.__connection = RealtimeConnection( - connect( - str(url), - user_agent_header=self.__client.user_agent, - additional_headers=_merge_mappings( - { - **auth_headers, - }, - self.__extra_headers, - ), - **self.__websocket_connection_options, - ) + return connect( + str(url), + user_agent_header=self.__client.user_agent, + additional_headers=_merge_mappings( + { + **auth_headers, + }, + extra_headers, + ), + **self.__websocket_connection_options, ) - return self.__connection - - enter = __enter__ - def _prepare_url(self) -> httpx.URL: if self.__client.websocket_base_url is not None: base_url = httpx.URL(self.__client.websocket_base_url) diff --git a/src/openai/resources/responses/responses.py b/src/openai/resources/responses/responses.py index 360c5afa1a..ab5f7688a5 100644 --- a/src/openai/resources/responses/responses.py +++ b/src/openai/resources/responses/responses.py @@ -3,10 +3,25 @@ from __future__ import annotations import json +import time +import random import logging from copy import copy from types import TracebackType -from typing import TYPE_CHECKING, Any, List, Type, Union, Iterable, Iterator, Optional, AsyncIterator, cast +from typing import ( + TYPE_CHECKING, + Any, + List, + Type, + Union, + Callable, + Iterable, + Iterator, + Optional, + Awaitable, + AsyncIterator, + cast, +) from functools import partial from typing_extensions import Literal, overload @@ -38,8 +53,10 @@ InputTokensWithStreamingResponse, AsyncInputTokensWithStreamingResponse, ) -from ..._exceptions import OpenAIError +from ..._exceptions import OpenAIError, WebSocketConnectionClosedError +from ..._send_queue import SendQueue from ..._base_client import _merge_mappings, make_request_options +from ..._event_handler import EventHandlerRegistry from ...types.responses import ( response_create_params, response_compact_params, @@ -54,6 +71,7 @@ from ...types.responses.response import Response from ...types.responses.tool_param import ToolParam, ParseableToolParam from ...types.shared_params.metadata import Metadata +from ...types.websocket_reconnection import ReconnectingEvent, ReconnectingOverrides, is_recoverable_close from ...types.shared_params.reasoning import Reasoning from ...types.responses.parsed_response import ParsedResponse from ...lib.streaming.responses._responses import ResponseStreamManager, AsyncResponseStreamManager @@ -61,6 +79,7 @@ from ...types.websocket_connection_options import WebSocketConnectionOptions from ...types.responses.response_includable import ResponseIncludable from ...types.shared_params.responses_model import ResponsesModel +from ...types.responses.response_error_event import ResponseErrorEvent from ...types.responses.response_input_param import ResponseInputParam from ...types.responses.response_prompt_param import ResponsePromptParam from ...types.responses.response_stream_event import ResponseStreamEvent @@ -1735,6 +1754,11 @@ def connect( extra_query: Query = {}, extra_headers: Headers = {}, websocket_connection_options: WebSocketConnectionOptions = {}, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> ResponsesConnectionManager: """Connect to a persistent Responses API WebSocket. @@ -1745,6 +1769,11 @@ def connect( extra_query=extra_query, extra_headers=extra_headers, websocket_connection_options=websocket_connection_options, + on_reconnecting=on_reconnecting, + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + max_queue_size=max_queue_size, ) @@ -3406,6 +3435,11 @@ def connect( extra_query: Query = {}, extra_headers: Headers = {}, websocket_connection_options: WebSocketConnectionOptions = {}, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> AsyncResponsesConnectionManager: """Connect to a persistent Responses API WebSocket. @@ -3416,6 +3450,11 @@ def connect( extra_query=extra_query, extra_headers=extra_headers, websocket_connection_options=websocket_connection_options, + on_reconnecting=on_reconnecting, + max_retries=max_retries, + initial_delay=initial_delay, + max_delay=max_delay, + max_queue_size=max_queue_size, ) @@ -3586,8 +3625,31 @@ class AsyncResponsesConnection: _connection: AsyncWebSocketConnection - def __init__(self, connection: AsyncWebSocketConnection) -> None: + def __init__( + self, + connection: AsyncWebSocketConnection, + *, + make_ws: Callable[[Query, Headers], Awaitable[AsyncWebSocketConnection]] | None = None, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + extra_query: Query = {}, + extra_headers: Headers = {}, + send_queue: SendQueue | None = None, + ) -> None: self._connection = connection + self._make_ws = make_ws + self._on_reconnecting = on_reconnecting + self._max_retries = max_retries + self._initial_delay = initial_delay + self._max_delay = max_delay + self._extra_query = extra_query + self._extra_headers = extra_headers + self._intentionally_closed = False + self._is_reconnecting = False + self._send_queue = send_queue or SendQueue() + self._event_handler_registry = EventHandlerRegistry(use_lock=False) self.response = AsyncResponsesResponseResource(self) @@ -3596,13 +3658,22 @@ async def __aiter__(self) -> AsyncIterator[ResponsesServerEvent]: An infinite-iterator that will continue to yield events until the connection is closed. """ - from websockets.exceptions import ConnectionClosedOK + from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError - try: - while True: + while True: + try: yield await self.recv() - except ConnectionClosedOK: - return + except ConnectionClosedOK: + return + except ConnectionClosedError as exc: + if not await self._reconnect(exc): + unsent = self._send_queue.drain() + if unsent: + raise WebSocketConnectionClosedError( + "WebSocket connection closed with unsent messages", + unsent_messages=unsent, + ) from exc + raise async def recv(self) -> ResponsesServerEvent: """ @@ -3630,12 +3701,24 @@ async def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> if isinstance(event, BaseModel) else json.dumps(await async_maybe_transform(event, ResponsesClientEventParam)) ) - await self._connection.send(data) + if self._is_reconnecting: + self._send_queue.enqueue(data) + return + try: + await self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise async def send_raw(self, data: bytes | str) -> None: + if self._is_reconnecting: + raw = data if isinstance(data, str) else data.decode("utf-8") + self._send_queue.enqueue(raw) + return await self._connection.send(data) async def close(self, *, code: int = 1000, reason: str = "") -> None: + self._intentionally_closed = True await self._connection.close(code=code, reason=reason) def parse_event(self, data: str | bytes) -> ResponsesServerEvent: @@ -3649,6 +3732,173 @@ def parse_event(self, data: str | bytes) -> ResponsesServerEvent: construct_type_unchecked(value=json.loads(data), type_=cast(Any, ResponsesServerEvent)), ) + async def _reconnect(self, exc: Exception) -> bool: + """Attempt to reconnect after a connection failure. + + Returns ``True`` if a new connection was established, ``False`` if the + caller should re-raise the original exception. + """ + import asyncio + + if self._on_reconnecting is None or self._make_ws is None: + return False + + from websockets.exceptions import ConnectionClosedError + + close_code = 1006 + if isinstance(exc, ConnectionClosedError) and exc.rcvd is not None: + close_code = exc.rcvd.code + + if not is_recoverable_close(close_code): + return False + + self._is_reconnecting = True + + for attempt in range(1, self._max_retries + 1): + base_delay = min(self._initial_delay * (2 ** (attempt - 1)), self._max_delay) + jitter = 0.75 + random.random() * 0.25 + delay = base_delay * jitter + + event = ReconnectingEvent( + attempt=attempt, + max_attempts=self._max_retries, + delay=delay, + close_code=close_code, + extra_query=self._extra_query, + extra_headers=self._extra_headers, + ) + + try: + result = self._on_reconnecting(event) + except Exception: + self._is_reconnecting = False + return False + + if result is not None and result.get("abort"): + self._is_reconnecting = False + return False + + if result is not None: + if "extra_query" in result: + self._extra_query = result["extra_query"] + if "extra_headers" in result: + self._extra_headers = result["extra_headers"] + + log.info( + "Reconnecting to WebSocket API (attempt %d/%d) after %.1fs delay", + attempt, + self._max_retries, + delay, + ) + await asyncio.sleep(delay) + + if self._intentionally_closed: + self._is_reconnecting = False + return False + + try: + self._connection = await self._make_ws(self._extra_query, self._extra_headers) + log.info("Reconnected to WebSocket API") + self._is_reconnecting = False + await self._flush_send_queue() + return True + except Exception: + pass + + self._is_reconnecting = False + return False + + async def _flush_send_queue(self) -> None: + """Send all queued messages over the current connection.""" + + async def _send(data: str) -> None: + await self._connection.send(data) + + try: + await self._send_queue.flush_async(_send) + except Exception: + log.warning("Failed to flush send queue after reconnect", exc_info=True) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncResponsesConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Adds the handler to the end of the handlers list for the given event type. + + No checks are made to see if the handler has already been added. Multiple calls + passing the same combination of event type and handler will result in the handler + being added, and called, multiple times. + + Can be used as a method (returns ``self`` for chaining):: + + connection.on("response.audio.delta", my_handler) + + Or as a decorator:: + + @connection.on("response.audio.delta") + async def my_handler(event): ... + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncResponsesConnection: + """Remove a previously registered event handler.""" + self._event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncResponsesConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler. + + Automatically removed after first invocation. + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator + + async def dispatch_events(self) -> None: + """Run the event loop, dispatching received events to registered handlers. + + Blocks until the connection is closed. This is the push-based + alternative to iterating with ``async for event in connection``. + + If an ``"error"`` event arrives and no handler is registered for + ``"error"`` or ``"event"``, an ``OpenAIError`` is raised. + """ + import asyncio + + async for event in self: + event_type = event.type + specific = self._event_handler_registry.get_handlers(event_type) + generic = self._event_handler_registry.get_handlers("event") + + if event_type == "error" and not specific and not generic: + if isinstance(event, ResponseErrorEvent): + raise OpenAIError(f"WebSocket error: {event}") + + for handler in specific: + result = handler(event) + if asyncio.iscoroutine(result): + await result + + for handler in generic: + result = handler(event) + if asyncio.iscoroutine(result): + await result + class AsyncResponsesConnectionManager: """ @@ -3677,16 +3927,77 @@ def __init__( extra_query: Query, extra_headers: Headers, websocket_connection_options: WebSocketConnectionOptions, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> None: self.__client = client self.__connection: AsyncResponsesConnection | None = None self.__extra_query = extra_query self.__extra_headers = extra_headers self.__websocket_connection_options = websocket_connection_options + self.__on_reconnecting = on_reconnecting + self.__max_retries = max_retries + self.__initial_delay = initial_delay + self.__max_delay = max_delay + self.__send_queue = SendQueue(max_bytes=max_queue_size) + self.__event_handler_registry = EventHandlerRegistry(use_lock=False) + + def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> None: + """Queue a message to be sent when the connection is established. + + This can be called before entering the context manager. Queued messages + are automatically sent once the WebSocket connection opens. + """ + data = ( + event.to_json(use_api_names=True, exclude_defaults=True, exclude_unset=True) + if isinstance(event, BaseModel) + else json.dumps(event) + ) + self.__send_queue.enqueue(data) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncResponsesConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register an event handler before the connection is established. + + Handlers are transferred to the connection on enter. Supports the + same method and decorator forms as ``AsyncResponsesConnection.on``. + """ + if handler is not None: + self.__event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> AsyncResponsesConnectionManager: + """Remove a previously registered event handler.""" + self.__event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[AsyncResponsesConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler before the connection is established.""" + if handler is not None: + self.__event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator async def __aenter__(self) -> AsyncResponsesConnection: """ - ๐Ÿ‘‹ If your application doesn't work well with the context manager approach then you + If your application doesn't work well with the context manager approach then you can call this method directly to initiate a connection. **Warning**: You must remember to close the connection with `.close()`. @@ -3697,6 +4008,28 @@ async def __aenter__(self) -> AsyncResponsesConnection: await connection.close() ``` """ + ws = await self._connect_ws(self.__extra_query, self.__extra_headers) + + self.__connection = AsyncResponsesConnection( + ws, + make_ws=self._connect_ws if self.__on_reconnecting is not None else None, + on_reconnecting=self.__on_reconnecting, + max_retries=self.__max_retries, + initial_delay=self.__initial_delay, + max_delay=self.__max_delay, + extra_query=self.__extra_query, + extra_headers=self.__extra_headers, + send_queue=self.__send_queue, + ) + + self.__event_handler_registry.merge_into(self.__connection._event_handler_registry) + await self.__connection._flush_send_queue() + + return self.__connection + + enter = __aenter__ + + async def _connect_ws(self, extra_query: Query, extra_headers: Headers) -> AsyncWebSocketConnection: try: from websockets.asyncio.client import connect except ImportError as exc: @@ -3705,31 +4038,25 @@ async def __aenter__(self) -> AsyncResponsesConnection: url = self._prepare_url().copy_with( params={ **self.__client.base_url.params, - **self.__extra_query, + **extra_query, }, ) log.debug("Connecting to %s", url) if self.__websocket_connection_options: log.debug("Connection options: %s", self.__websocket_connection_options) - self.__connection = AsyncResponsesConnection( - await connect( - str(url), - user_agent_header=self.__client.user_agent, - additional_headers=_merge_mappings( - { - **self.__client.auth_headers, - }, - self.__extra_headers, - ), - **self.__websocket_connection_options, - ) + return await connect( + str(url), + user_agent_header=self.__client.user_agent, + additional_headers=_merge_mappings( + { + **self.__client.auth_headers, + }, + extra_headers, + ), + **self.__websocket_connection_options, ) - return self.__connection - - enter = __aenter__ - def _prepare_url(self) -> httpx.URL: if self.__client.websocket_base_url is not None: base_url = httpx.URL(self.__client.websocket_base_url) @@ -3755,8 +4082,31 @@ class ResponsesConnection: _connection: WebSocketConnection - def __init__(self, connection: WebSocketConnection) -> None: + def __init__( + self, + connection: WebSocketConnection, + *, + make_ws: Callable[[Query, Headers], WebSocketConnection] | None = None, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + extra_query: Query = {}, + extra_headers: Headers = {}, + send_queue: SendQueue | None = None, + ) -> None: self._connection = connection + self._make_ws = make_ws + self._on_reconnecting = on_reconnecting + self._max_retries = max_retries + self._initial_delay = initial_delay + self._max_delay = max_delay + self._extra_query = extra_query + self._extra_headers = extra_headers + self._intentionally_closed = False + self._is_reconnecting = False + self._send_queue = send_queue or SendQueue() + self._event_handler_registry = EventHandlerRegistry(use_lock=True) self.response = ResponsesResponseResource(self) @@ -3765,13 +4115,22 @@ def __iter__(self) -> Iterator[ResponsesServerEvent]: An infinite-iterator that will continue to yield events until the connection is closed. """ - from websockets.exceptions import ConnectionClosedOK + from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError - try: - while True: + while True: + try: yield self.recv() - except ConnectionClosedOK: - return + except ConnectionClosedOK: + return + except ConnectionClosedError as exc: + if not self._reconnect(exc): + unsent = self._send_queue.drain() + if unsent: + raise WebSocketConnectionClosedError( + "WebSocket connection closed with unsent messages", + unsent_messages=unsent, + ) from exc + raise def recv(self) -> ResponsesServerEvent: """ @@ -3799,12 +4158,24 @@ def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> None: if isinstance(event, BaseModel) else json.dumps(maybe_transform(event, ResponsesClientEventParam)) ) - self._connection.send(data) + if self._is_reconnecting: + self._send_queue.enqueue(data) + return + try: + self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise def send_raw(self, data: bytes | str) -> None: + if self._is_reconnecting: + raw = data if isinstance(data, str) else data.decode("utf-8") + self._send_queue.enqueue(raw) + return self._connection.send(data) def close(self, *, code: int = 1000, reason: str = "") -> None: + self._intentionally_closed = True self._connection.close(code=code, reason=reason) def parse_event(self, data: str | bytes) -> ResponsesServerEvent: @@ -3818,6 +4189,161 @@ def parse_event(self, data: str | bytes) -> ResponsesServerEvent: construct_type_unchecked(value=json.loads(data), type_=cast(Any, ResponsesServerEvent)), ) + def _reconnect(self, exc: Exception) -> bool: + """Attempt to reconnect after a connection failure. + + Returns ``True`` if a new connection was established, ``False`` if the + caller should re-raise the original exception. + """ + if self._on_reconnecting is None or self._make_ws is None: + return False + + from websockets.exceptions import ConnectionClosedError + + close_code = 1006 + if isinstance(exc, ConnectionClosedError) and exc.rcvd is not None: + close_code = exc.rcvd.code + + if not is_recoverable_close(close_code): + return False + + self._is_reconnecting = True + + for attempt in range(1, self._max_retries + 1): + base_delay = min(self._initial_delay * (2 ** (attempt - 1)), self._max_delay) + jitter = 0.75 + random.random() * 0.25 + delay = base_delay * jitter + + event = ReconnectingEvent( + attempt=attempt, + max_attempts=self._max_retries, + delay=delay, + close_code=close_code, + extra_query=self._extra_query, + extra_headers=self._extra_headers, + ) + + try: + result = self._on_reconnecting(event) + except Exception: + self._is_reconnecting = False + return False + + if result is not None and result.get("abort"): + self._is_reconnecting = False + return False + + if result is not None: + if "extra_query" in result: + self._extra_query = result["extra_query"] + if "extra_headers" in result: + self._extra_headers = result["extra_headers"] + + log.info( + "Reconnecting to WebSocket API (attempt %d/%d) after %.1fs delay", + attempt, + self._max_retries, + delay, + ) + time.sleep(delay) + + if self._intentionally_closed: + self._is_reconnecting = False + return False + + try: + self._connection = self._make_ws(self._extra_query, self._extra_headers) + log.info("Reconnected to WebSocket API") + self._is_reconnecting = False + self._flush_send_queue() + return True + except Exception: + pass + + self._is_reconnecting = False + return False + + def _flush_send_queue(self) -> None: + """Send all queued messages over the current connection.""" + try: + self._send_queue.flush_sync(lambda data: self._connection.send(data)) + except Exception: + log.warning("Failed to flush send queue after reconnect", exc_info=True) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[ResponsesConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Adds the handler to the end of the handlers list for the given event type. + + No checks are made to see if the handler has already been added. Multiple calls + passing the same combination of event type and handler will result in the handler + being added, and called, multiple times. + + Can be used as a method (returns ``self`` for chaining):: + + connection.on("response.audio.delta", my_handler) + + Or as a decorator:: + + @connection.on("response.audio.delta") + def my_handler(event): ... + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> ResponsesConnection: + """Remove a previously registered event handler.""" + self._event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[ResponsesConnection, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler. + + Automatically removed after first invocation. + """ + if handler is not None: + self._event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self._event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator + + def dispatch_events(self) -> None: + """Run the event loop, dispatching received events to registered handlers. + + Blocks the current thread until the connection is closed. This is the push-based + alternative to iterating with ``for event in connection``. + + If an ``"error"`` event arrives and no handler is registered for + ``"error"`` or ``"event"``, an ``OpenAIError`` is raised. + """ + for event in self: + event_type = event.type + specific = self._event_handler_registry.get_handlers(event_type) + generic = self._event_handler_registry.get_handlers("event") + + if event_type == "error" and not specific and not generic: + if isinstance(event, ResponseErrorEvent): + raise OpenAIError(f"WebSocket error: {event}") + + for handler in specific: + handler(event) + + for handler in generic: + handler(event) + class ResponsesConnectionManager: """ @@ -3846,16 +4372,77 @@ def __init__( extra_query: Query, extra_headers: Headers, websocket_connection_options: WebSocketConnectionOptions, + on_reconnecting: Callable[[ReconnectingEvent], ReconnectingOverrides | None] | None = None, + max_retries: int = 5, + initial_delay: float = 0.5, + max_delay: float = 8.0, + max_queue_size: int = 1_048_576, ) -> None: self.__client = client self.__connection: ResponsesConnection | None = None self.__extra_query = extra_query self.__extra_headers = extra_headers self.__websocket_connection_options = websocket_connection_options + self.__on_reconnecting = on_reconnecting + self.__max_retries = max_retries + self.__initial_delay = initial_delay + self.__max_delay = max_delay + self.__send_queue = SendQueue(max_bytes=max_queue_size) + self.__event_handler_registry = EventHandlerRegistry(use_lock=True) + + def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> None: + """Queue a message to be sent when the connection is established. + + This can be called before entering the context manager. Queued messages + are automatically sent once the WebSocket connection opens. + """ + data = ( + event.to_json(use_api_names=True, exclude_defaults=True, exclude_unset=True) + if isinstance(event, BaseModel) + else json.dumps(event) + ) + self.__send_queue.enqueue(data) + + def on( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[ResponsesConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register an event handler before the connection is established. + + Handlers are transferred to the connection on enter. Supports the + same method and decorator forms as ``ResponsesConnection.on``. + """ + if handler is not None: + self.__event_handler_registry.add(event_type, handler) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn) + return fn + + return decorator + + def off(self, event_type: str, handler: Callable[..., Any]) -> ResponsesConnectionManager: + """Remove a previously registered event handler.""" + self.__event_handler_registry.remove(event_type, handler) + return self + + def once( + self, event_type: str, handler: Callable[..., Any] | None = None + ) -> Union[ResponsesConnectionManager, Callable[[Callable[..., Any]], Callable[..., Any]]]: + """Register a one-time event handler before the connection is established.""" + if handler is not None: + self.__event_handler_registry.add(event_type, handler, once=True) + return self + + def decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + self.__event_handler_registry.add(event_type, fn, once=True) + return fn + + return decorator def __enter__(self) -> ResponsesConnection: """ - ๐Ÿ‘‹ If your application doesn't work well with the context manager approach then you + If your application doesn't work well with the context manager approach then you can call this method directly to initiate a connection. **Warning**: You must remember to close the connection with `.close()`. @@ -3866,6 +4453,28 @@ def __enter__(self) -> ResponsesConnection: connection.close() ``` """ + ws = self._connect_ws(self.__extra_query, self.__extra_headers) + + self.__connection = ResponsesConnection( + ws, + make_ws=self._connect_ws if self.__on_reconnecting is not None else None, + on_reconnecting=self.__on_reconnecting, + max_retries=self.__max_retries, + initial_delay=self.__initial_delay, + max_delay=self.__max_delay, + extra_query=self.__extra_query, + extra_headers=self.__extra_headers, + send_queue=self.__send_queue, + ) + + self.__event_handler_registry.merge_into(self.__connection._event_handler_registry) + self.__connection._flush_send_queue() + + return self.__connection + + enter = __enter__ + + def _connect_ws(self, extra_query: Query, extra_headers: Headers) -> WebSocketConnection: try: from websockets.sync.client import connect except ImportError as exc: @@ -3874,31 +4483,25 @@ def __enter__(self) -> ResponsesConnection: url = self._prepare_url().copy_with( params={ **self.__client.base_url.params, - **self.__extra_query, + **extra_query, }, ) log.debug("Connecting to %s", url) if self.__websocket_connection_options: log.debug("Connection options: %s", self.__websocket_connection_options) - self.__connection = ResponsesConnection( - connect( - str(url), - user_agent_header=self.__client.user_agent, - additional_headers=_merge_mappings( - { - **self.__client.auth_headers, - }, - self.__extra_headers, - ), - **self.__websocket_connection_options, - ) + return connect( + str(url), + user_agent_header=self.__client.user_agent, + additional_headers=_merge_mappings( + { + **self.__client.auth_headers, + }, + extra_headers, + ), + **self.__websocket_connection_options, ) - return self.__connection - - enter = __enter__ - def _prepare_url(self) -> httpx.URL: if self.__client.websocket_base_url is not None: base_url = httpx.URL(self.__client.websocket_base_url) diff --git a/src/openai/types/__init__.py b/src/openai/types/__init__.py index 6074eb0a1d..dc2c8a348c 100644 --- a/src/openai/types/__init__.py +++ b/src/openai/types/__init__.py @@ -85,6 +85,10 @@ from .file_chunking_strategy import FileChunkingStrategy as FileChunkingStrategy from .image_gen_stream_event import ImageGenStreamEvent as ImageGenStreamEvent from .upload_complete_params import UploadCompleteParams as UploadCompleteParams +from .websocket_reconnection import ( + ReconnectingEvent as ReconnectingEvent, + ReconnectingOverrides as ReconnectingOverrides, +) from .container_create_params import ContainerCreateParams as ContainerCreateParams from .container_list_response import ContainerListResponse as ContainerListResponse from .embedding_create_params import EmbeddingCreateParams as EmbeddingCreateParams diff --git a/src/openai/types/responses/response_input_file.py b/src/openai/types/responses/response_input_file.py index 3e5fb70c5f..f07ff7c049 100644 --- a/src/openai/types/responses/response_input_file.py +++ b/src/openai/types/responses/response_input_file.py @@ -14,6 +14,13 @@ class ResponseInputFile(BaseModel): type: Literal["input_file"] """The type of the input item. Always `input_file`.""" + detail: Optional[Literal["low", "high"]] = None + """The detail level of the file to be sent to the model. + + Use `low` for the default rendering behavior, or `high` to render the file at + higher quality. Defaults to `low`. + """ + file_data: Optional[str] = None """The content of the file to be sent to the model.""" diff --git a/src/openai/types/responses/response_input_file_content.py b/src/openai/types/responses/response_input_file_content.py index f0dfef55d0..a0c2de4823 100644 --- a/src/openai/types/responses/response_input_file_content.py +++ b/src/openai/types/responses/response_input_file_content.py @@ -14,6 +14,13 @@ class ResponseInputFileContent(BaseModel): type: Literal["input_file"] """The type of the input item. Always `input_file`.""" + detail: Optional[Literal["low", "high"]] = None + """The detail level of the file to be sent to the model. + + Use `low` for the default rendering behavior, or `high` to render the file at + higher quality. Defaults to `low`. + """ + file_data: Optional[str] = None """The base64-encoded data of the file to be sent to the model.""" diff --git a/src/openai/types/responses/response_input_file_content_param.py b/src/openai/types/responses/response_input_file_content_param.py index 376f6c7a45..4206ea09f3 100644 --- a/src/openai/types/responses/response_input_file_content_param.py +++ b/src/openai/types/responses/response_input_file_content_param.py @@ -14,6 +14,13 @@ class ResponseInputFileContentParam(TypedDict, total=False): type: Required[Literal["input_file"]] """The type of the input item. Always `input_file`.""" + detail: Literal["low", "high"] + """The detail level of the file to be sent to the model. + + Use `low` for the default rendering behavior, or `high` to render the file at + higher quality. Defaults to `low`. + """ + file_data: Optional[str] """The base64-encoded data of the file to be sent to the model.""" diff --git a/src/openai/types/responses/response_input_file_param.py b/src/openai/types/responses/response_input_file_param.py index 8b5da20245..a6bdb6e18c 100644 --- a/src/openai/types/responses/response_input_file_param.py +++ b/src/openai/types/responses/response_input_file_param.py @@ -14,6 +14,13 @@ class ResponseInputFileParam(TypedDict, total=False): type: Required[Literal["input_file"]] """The type of the input item. Always `input_file`.""" + detail: Literal["low", "high"] + """The detail level of the file to be sent to the model. + + Use `low` for the default rendering behavior, or `high` to render the file at + higher quality. Defaults to `low`. + """ + file_data: str """The content of the file to be sent to the model.""" diff --git a/src/openai/types/websocket_reconnection.py b/src/openai/types/websocket_reconnection.py new file mode 100644 index 0000000000..9ba66d0ca5 --- /dev/null +++ b/src/openai/types/websocket_reconnection.py @@ -0,0 +1,64 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import dataclasses +from typing_extensions import TypedDict + +from .._types import Query, Headers + + +@dataclasses.dataclass(frozen=True) +class ReconnectingEvent: + """Information about a reconnection attempt, passed to the ``on_reconnecting`` handler.""" + + attempt: int + """Which retry attempt this is (1-based).""" + + max_attempts: int + """Total attempts that will be made.""" + + delay: float + """Delay in seconds before this attempt connects.""" + + close_code: int + """The WebSocket close code that triggered reconnection.""" + + extra_query: Query + """The current query parameters.""" + + extra_headers: Headers + """The current headers.""" + + +class ReconnectingOverrides(TypedDict, total=False): + """Optional overrides returned from the ``on_reconnecting`` handler + to customize the next reconnection attempt.""" + + extra_query: Query + """If provided, assigns the query parameters for the next connection.""" + + extra_headers: Headers + """If provided, assigns the headers for the next connection.""" + + abort: bool + """If set to ``True``, will stop attempting to reconnect.""" + + +# RFC 6455 ยง7.4.1 +_RECOVERABLE_CLOSE_CODES: frozenset[int] = frozenset( + { + 1001, # Going away (server shutting down) + 1005, # No status code (abnormal) + 1006, # Abnormal closure (network drop) + 1011, # Internal server error + 1012, # Service restart + 1013, # Try again later + 1015, # TLS handshake failure + } +) + + +def is_recoverable_close(code: int) -> bool: + """Return ``True`` if the WebSocket close *code* is worth retrying.""" + return code in _RECOVERABLE_CLOSE_CODES diff --git a/tests/api_resources/audio/test_speech.py b/tests/api_resources/audio/test_speech.py index a42c77126d..93ede5193d 100644 --- a/tests/api_resources/audio/test_speech.py +++ b/tests/api_resources/audio/test_speech.py @@ -27,8 +27,8 @@ def test_method_create(self, client: OpenAI, respx_mock: MockRouter) -> None: respx_mock.post("/audio/speech").mock(return_value=httpx.Response(200, json={"foo": "bar"})) speech = client.audio.speech.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", ) assert isinstance(speech, _legacy_response.HttpxBinaryResponseContent) assert speech.json() == {"foo": "bar"} @@ -39,8 +39,8 @@ def test_method_create_with_all_params(self, client: OpenAI, respx_mock: MockRou respx_mock.post("/audio/speech").mock(return_value=httpx.Response(200, json={"foo": "bar"})) speech = client.audio.speech.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", instructions="instructions", response_format="mp3", speed=0.25, @@ -56,8 +56,8 @@ def test_raw_response_create(self, client: OpenAI, respx_mock: MockRouter) -> No response = client.audio.speech.with_raw_response.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", ) assert response.is_closed is True @@ -71,8 +71,8 @@ def test_streaming_response_create(self, client: OpenAI, respx_mock: MockRouter) respx_mock.post("/audio/speech").mock(return_value=httpx.Response(200, json={"foo": "bar"})) with client.audio.speech.with_streaming_response.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -94,8 +94,8 @@ async def test_method_create(self, async_client: AsyncOpenAI, respx_mock: MockRo respx_mock.post("/audio/speech").mock(return_value=httpx.Response(200, json={"foo": "bar"})) speech = await async_client.audio.speech.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", ) assert isinstance(speech, _legacy_response.HttpxBinaryResponseContent) assert speech.json() == {"foo": "bar"} @@ -106,8 +106,8 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenAI, re respx_mock.post("/audio/speech").mock(return_value=httpx.Response(200, json={"foo": "bar"})) speech = await async_client.audio.speech.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", instructions="instructions", response_format="mp3", speed=0.25, @@ -123,8 +123,8 @@ async def test_raw_response_create(self, async_client: AsyncOpenAI, respx_mock: response = await async_client.audio.speech.with_raw_response.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", ) assert response.is_closed is True @@ -138,8 +138,8 @@ async def test_streaming_response_create(self, async_client: AsyncOpenAI, respx_ respx_mock.post("/audio/speech").mock(return_value=httpx.Response(200, json={"foo": "bar"})) async with async_client.audio.speech.with_streaming_response.create( input="string", - model="string", - voice="string", + model="tts-1", + voice="alloy", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/beta/test_assistants.py b/tests/api_resources/beta/test_assistants.py index 3e85b56dcc..91ff13dded 100644 --- a/tests/api_resources/beta/test_assistants.py +++ b/tests/api_resources/beta/test_assistants.py @@ -149,7 +149,7 @@ def test_method_update_with_all_params(self, client: OpenAI) -> None: description="description", instructions="instructions", metadata={"foo": "string"}, - model="string", + model="gpt-5", name="name", reasoning_effort="none", response_format="auto", @@ -414,7 +414,7 @@ async def test_method_update_with_all_params(self, async_client: AsyncOpenAI) -> description="description", instructions="instructions", metadata={"foo": "string"}, - model="string", + model="gpt-5", name="name", reasoning_effort="none", response_format="auto", diff --git a/tests/api_resources/beta/test_threads.py b/tests/api_resources/beta/test_threads.py index f392c86729..7163511951 100644 --- a/tests/api_resources/beta/test_threads.py +++ b/tests/api_resources/beta/test_threads.py @@ -248,7 +248,7 @@ def test_method_create_and_run_with_all_params_overload_1(self, client: OpenAI) max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, response_format="auto", stream=False, @@ -343,7 +343,7 @@ def test_method_create_and_run_with_all_params_overload_2(self, client: OpenAI) max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, response_format="auto", temperature=1, @@ -649,7 +649,7 @@ async def test_method_create_and_run_with_all_params_overload_1(self, async_clie max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, response_format="auto", stream=False, @@ -744,7 +744,7 @@ async def test_method_create_and_run_with_all_params_overload_2(self, async_clie max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, response_format="auto", temperature=1, diff --git a/tests/api_resources/beta/threads/test_runs.py b/tests/api_resources/beta/threads/test_runs.py index 3a6b36864d..aa85871325 100644 --- a/tests/api_resources/beta/threads/test_runs.py +++ b/tests/api_resources/beta/threads/test_runs.py @@ -57,7 +57,7 @@ def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None: max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, reasoning_effort="none", response_format="auto", @@ -148,7 +148,7 @@ def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None: max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, reasoning_effort="none", response_format="auto", @@ -607,7 +607,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, reasoning_effort="none", response_format="auto", @@ -698,7 +698,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn max_completion_tokens=256, max_prompt_tokens=256, metadata={"foo": "string"}, - model="string", + model="gpt-5.4", parallel_tool_calls=True, reasoning_effort="none", response_format="auto", diff --git a/tests/api_resources/chat/test_completions.py b/tests/api_resources/chat/test_completions.py index c55c132697..a75764b5e9 100644 --- a/tests/api_resources/chat/test_completions.py +++ b/tests/api_resources/chat/test_completions.py @@ -48,7 +48,7 @@ def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None: model="gpt-5.4", audio={ "format": "wav", - "voice": "string", + "voice": "alloy", }, frequency_penalty=-2, function_call="none", @@ -182,7 +182,7 @@ def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None: stream=True, audio={ "format": "wav", - "voice": "string", + "voice": "alloy", }, frequency_penalty=-2, function_call="none", @@ -491,7 +491,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn model="gpt-5.4", audio={ "format": "wav", - "voice": "string", + "voice": "alloy", }, frequency_penalty=-2, function_call="none", @@ -625,7 +625,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn stream=True, audio={ "format": "wav", - "voice": "string", + "voice": "alloy", }, frequency_penalty=-2, function_call="none", diff --git a/tests/api_resources/realtime/test_calls.py b/tests/api_resources/realtime/test_calls.py index 9e2810841d..43ab9afe01 100644 --- a/tests/api_resources/realtime/test_calls.py +++ b/tests/api_resources/realtime/test_calls.py @@ -48,7 +48,7 @@ def test_method_create_with_all_params(self, client: OpenAI, respx_mock: MockRou "noise_reduction": {"type": "near_field"}, "transcription": { "language": "language", - "model": "string", + "model": "whisper-1", "prompt": "prompt", }, "turn_detection": { @@ -67,13 +67,13 @@ def test_method_create_with_all_params(self, client: OpenAI, respx_mock: MockRou "type": "audio/pcm", }, "speed": 0.25, - "voice": "string", + "voice": "alloy", }, }, "include": ["item.input_audio_transcription.logprobs"], "instructions": "instructions", - "max_output_tokens": 0, - "model": "string", + "max_output_tokens": "inf", + "model": "gpt-realtime", "output_modalities": ["text"], "prompt": { "id": "id", @@ -147,7 +147,7 @@ def test_method_accept_with_all_params(self, client: OpenAI) -> None: "noise_reduction": {"type": "near_field"}, "transcription": { "language": "language", - "model": "string", + "model": "whisper-1", "prompt": "prompt", }, "turn_detection": { @@ -166,13 +166,13 @@ def test_method_accept_with_all_params(self, client: OpenAI) -> None: "type": "audio/pcm", }, "speed": 0.25, - "voice": "string", + "voice": "alloy", }, }, include=["item.input_audio_transcription.logprobs"], instructions="instructions", - max_output_tokens=0, - model="string", + max_output_tokens="inf", + model="gpt-realtime", output_modalities=["text"], prompt={ "id": "id", @@ -386,7 +386,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenAI, re "noise_reduction": {"type": "near_field"}, "transcription": { "language": "language", - "model": "string", + "model": "whisper-1", "prompt": "prompt", }, "turn_detection": { @@ -405,13 +405,13 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenAI, re "type": "audio/pcm", }, "speed": 0.25, - "voice": "string", + "voice": "alloy", }, }, "include": ["item.input_audio_transcription.logprobs"], "instructions": "instructions", - "max_output_tokens": 0, - "model": "string", + "max_output_tokens": "inf", + "model": "gpt-realtime", "output_modalities": ["text"], "prompt": { "id": "id", @@ -485,7 +485,7 @@ async def test_method_accept_with_all_params(self, async_client: AsyncOpenAI) -> "noise_reduction": {"type": "near_field"}, "transcription": { "language": "language", - "model": "string", + "model": "whisper-1", "prompt": "prompt", }, "turn_detection": { @@ -504,13 +504,13 @@ async def test_method_accept_with_all_params(self, async_client: AsyncOpenAI) -> "type": "audio/pcm", }, "speed": 0.25, - "voice": "string", + "voice": "alloy", }, }, include=["item.input_audio_transcription.logprobs"], instructions="instructions", - max_output_tokens=0, - model="string", + max_output_tokens="inf", + model="gpt-realtime", output_modalities=["text"], prompt={ "id": "id", diff --git a/tests/api_resources/realtime/test_client_secrets.py b/tests/api_resources/realtime/test_client_secrets.py index bfa0deac55..a354019eac 100644 --- a/tests/api_resources/realtime/test_client_secrets.py +++ b/tests/api_resources/realtime/test_client_secrets.py @@ -40,7 +40,7 @@ def test_method_create_with_all_params(self, client: OpenAI) -> None: "noise_reduction": {"type": "near_field"}, "transcription": { "language": "language", - "model": "string", + "model": "whisper-1", "prompt": "prompt", }, "turn_detection": { @@ -59,13 +59,13 @@ def test_method_create_with_all_params(self, client: OpenAI) -> None: "type": "audio/pcm", }, "speed": 0.25, - "voice": "string", + "voice": "alloy", }, }, "include": ["item.input_audio_transcription.logprobs"], "instructions": "instructions", - "max_output_tokens": 0, - "model": "string", + "max_output_tokens": "inf", + "model": "gpt-realtime", "output_modalities": ["text"], "prompt": { "id": "id", @@ -136,7 +136,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenAI) -> "noise_reduction": {"type": "near_field"}, "transcription": { "language": "language", - "model": "string", + "model": "whisper-1", "prompt": "prompt", }, "turn_detection": { @@ -155,13 +155,13 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenAI) -> "type": "audio/pcm", }, "speed": 0.25, - "voice": "string", + "voice": "alloy", }, }, "include": ["item.input_audio_transcription.logprobs"], "instructions": "instructions", - "max_output_tokens": 0, - "model": "string", + "max_output_tokens": "inf", + "model": "gpt-realtime", "output_modalities": ["text"], "prompt": { "id": "id", diff --git a/tests/api_resources/test_completions.py b/tests/api_resources/test_completions.py index a8fb0e59eb..d0081aaca2 100644 --- a/tests/api_resources/test_completions.py +++ b/tests/api_resources/test_completions.py @@ -20,7 +20,7 @@ class TestCompletions: @parametrize def test_method_create_overload_1(self, client: OpenAI) -> None: completion = client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", ) assert_matches_type(Completion, completion, path=["response"]) @@ -28,7 +28,7 @@ def test_method_create_overload_1(self, client: OpenAI) -> None: @parametrize def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None: completion = client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", best_of=0, echo=True, @@ -55,7 +55,7 @@ def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None: @parametrize def test_raw_response_create_overload_1(self, client: OpenAI) -> None: response = client.completions.with_raw_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", ) @@ -67,7 +67,7 @@ def test_raw_response_create_overload_1(self, client: OpenAI) -> None: @parametrize def test_streaming_response_create_overload_1(self, client: OpenAI) -> None: with client.completions.with_streaming_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", ) as response: assert not response.is_closed @@ -81,7 +81,7 @@ def test_streaming_response_create_overload_1(self, client: OpenAI) -> None: @parametrize def test_method_create_overload_2(self, client: OpenAI) -> None: completion_stream = client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, ) @@ -90,7 +90,7 @@ def test_method_create_overload_2(self, client: OpenAI) -> None: @parametrize def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None: completion_stream = client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, best_of=0, @@ -117,7 +117,7 @@ def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None: @parametrize def test_raw_response_create_overload_2(self, client: OpenAI) -> None: response = client.completions.with_raw_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, ) @@ -129,7 +129,7 @@ def test_raw_response_create_overload_2(self, client: OpenAI) -> None: @parametrize def test_streaming_response_create_overload_2(self, client: OpenAI) -> None: with client.completions.with_streaming_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, ) as response: @@ -150,7 +150,7 @@ class TestAsyncCompletions: @parametrize async def test_method_create_overload_1(self, async_client: AsyncOpenAI) -> None: completion = await async_client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", ) assert_matches_type(Completion, completion, path=["response"]) @@ -158,7 +158,7 @@ async def test_method_create_overload_1(self, async_client: AsyncOpenAI) -> None @parametrize async def test_method_create_with_all_params_overload_1(self, async_client: AsyncOpenAI) -> None: completion = await async_client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", best_of=0, echo=True, @@ -185,7 +185,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn @parametrize async def test_raw_response_create_overload_1(self, async_client: AsyncOpenAI) -> None: response = await async_client.completions.with_raw_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", ) @@ -197,7 +197,7 @@ async def test_raw_response_create_overload_1(self, async_client: AsyncOpenAI) - @parametrize async def test_streaming_response_create_overload_1(self, async_client: AsyncOpenAI) -> None: async with async_client.completions.with_streaming_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", ) as response: assert not response.is_closed @@ -211,7 +211,7 @@ async def test_streaming_response_create_overload_1(self, async_client: AsyncOpe @parametrize async def test_method_create_overload_2(self, async_client: AsyncOpenAI) -> None: completion_stream = await async_client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, ) @@ -220,7 +220,7 @@ async def test_method_create_overload_2(self, async_client: AsyncOpenAI) -> None @parametrize async def test_method_create_with_all_params_overload_2(self, async_client: AsyncOpenAI) -> None: completion_stream = await async_client.completions.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, best_of=0, @@ -247,7 +247,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn @parametrize async def test_raw_response_create_overload_2(self, async_client: AsyncOpenAI) -> None: response = await async_client.completions.with_raw_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, ) @@ -259,7 +259,7 @@ async def test_raw_response_create_overload_2(self, async_client: AsyncOpenAI) - @parametrize async def test_streaming_response_create_overload_2(self, async_client: AsyncOpenAI) -> None: async with async_client.completions.with_streaming_response.create( - model="string", + model="gpt-3.5-turbo-instruct", prompt="This is a test.", stream=True, ) as response: diff --git a/tests/api_resources/test_images.py b/tests/api_resources/test_images.py index 9862b79c65..13cbc0acb7 100644 --- a/tests/api_resources/test_images.py +++ b/tests/api_resources/test_images.py @@ -28,7 +28,7 @@ def test_method_create_variation(self, client: OpenAI) -> None: def test_method_create_variation_with_all_params(self, client: OpenAI) -> None: image = client.images.create_variation( image=b"Example data", - model="string", + model="gpt-image-1.5", n=1, response_format="url", size="1024x1024", @@ -76,7 +76,7 @@ def test_method_edit_with_all_params_overload_1(self, client: OpenAI) -> None: background="transparent", input_fidelity="high", mask=b"Example data", - model="string", + model="gpt-image-1.5", n=1, output_compression=100, output_format="png", @@ -133,7 +133,7 @@ def test_method_edit_with_all_params_overload_2(self, client: OpenAI) -> None: background="transparent", input_fidelity="high", mask=b"Example data", - model="string", + model="gpt-image-1.5", n=1, output_compression=100, output_format="png", @@ -184,7 +184,7 @@ def test_method_generate_with_all_params_overload_1(self, client: OpenAI) -> Non image = client.images.generate( prompt="A cute baby sea otter", background="transparent", - model="string", + model="gpt-image-1.5", moderation="low", n=1, output_compression=100, @@ -237,7 +237,7 @@ def test_method_generate_with_all_params_overload_2(self, client: OpenAI) -> Non prompt="A cute baby sea otter", stream=True, background="transparent", - model="string", + model="gpt-image-1.5", moderation="low", n=1, output_compression=100, @@ -293,7 +293,7 @@ async def test_method_create_variation(self, async_client: AsyncOpenAI) -> None: async def test_method_create_variation_with_all_params(self, async_client: AsyncOpenAI) -> None: image = await async_client.images.create_variation( image=b"Example data", - model="string", + model="gpt-image-1.5", n=1, response_format="url", size="1024x1024", @@ -341,7 +341,7 @@ async def test_method_edit_with_all_params_overload_1(self, async_client: AsyncO background="transparent", input_fidelity="high", mask=b"Example data", - model="string", + model="gpt-image-1.5", n=1, output_compression=100, output_format="png", @@ -398,7 +398,7 @@ async def test_method_edit_with_all_params_overload_2(self, async_client: AsyncO background="transparent", input_fidelity="high", mask=b"Example data", - model="string", + model="gpt-image-1.5", n=1, output_compression=100, output_format="png", @@ -449,7 +449,7 @@ async def test_method_generate_with_all_params_overload_1(self, async_client: As image = await async_client.images.generate( prompt="A cute baby sea otter", background="transparent", - model="string", + model="gpt-image-1.5", moderation="low", n=1, output_compression=100, @@ -502,7 +502,7 @@ async def test_method_generate_with_all_params_overload_2(self, async_client: As prompt="A cute baby sea otter", stream=True, background="transparent", - model="string", + model="gpt-image-1.5", moderation="low", n=1, output_compression=100, diff --git a/tests/api_resources/test_moderations.py b/tests/api_resources/test_moderations.py index 870c9e342f..8487d4a7a6 100644 --- a/tests/api_resources/test_moderations.py +++ b/tests/api_resources/test_moderations.py @@ -28,7 +28,7 @@ def test_method_create(self, client: OpenAI) -> None: def test_method_create_with_all_params(self, client: OpenAI) -> None: moderation = client.moderations.create( input="I want to kill them.", - model="string", + model="omni-moderation-latest", ) assert_matches_type(ModerationCreateResponse, moderation, path=["response"]) @@ -73,7 +73,7 @@ async def test_method_create(self, async_client: AsyncOpenAI) -> None: async def test_method_create_with_all_params(self, async_client: AsyncOpenAI) -> None: moderation = await async_client.moderations.create( input="I want to kill them.", - model="string", + model="omni-moderation-latest", ) assert_matches_type(ModerationCreateResponse, moderation, path=["response"]) diff --git a/tests/api_resources/test_responses.py b/tests/api_resources/test_responses.py index deaf35970f..0b871d525d 100644 --- a/tests/api_resources/test_responses.py +++ b/tests/api_resources/test_responses.py @@ -40,7 +40,7 @@ def test_method_create_with_all_params_overload_1(self, client: OpenAI) -> None: include=["file_search_call.results"], input="string", instructions="instructions", - max_output_tokens=0, + max_output_tokens=16, max_tool_calls=0, metadata={"foo": "string"}, model="gpt-5.1", @@ -128,7 +128,7 @@ def test_method_create_with_all_params_overload_2(self, client: OpenAI) -> None: include=["file_search_call.results"], input="string", instructions="instructions", - max_output_tokens=0, + max_output_tokens=16, max_tool_calls=0, metadata={"foo": "string"}, model="gpt-5.1", @@ -451,7 +451,7 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn include=["file_search_call.results"], input="string", instructions="instructions", - max_output_tokens=0, + max_output_tokens=16, max_tool_calls=0, metadata={"foo": "string"}, model="gpt-5.1", @@ -539,7 +539,7 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn include=["file_search_call.results"], input="string", instructions="instructions", - max_output_tokens=0, + max_output_tokens=16, max_tool_calls=0, metadata={"foo": "string"}, model="gpt-5.1", diff --git a/tests/api_resources/test_videos.py b/tests/api_resources/test_videos.py index 73acf6d05d..d6b51f9674 100644 --- a/tests/api_resources/test_videos.py +++ b/tests/api_resources/test_videos.py @@ -41,7 +41,7 @@ def test_method_create_with_all_params(self, client: OpenAI) -> None: video = client.videos.create( prompt="x", input_reference=b"Example data", - model="string", + model="sora-2", seconds="4", size="720x1280", ) @@ -442,7 +442,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncOpenAI) -> video = await async_client.videos.create( prompt="x", input_reference=b"Example data", - model="string", + model="sora-2", seconds="4", size="720x1280", ) diff --git a/tests/test_extract_files.py b/tests/test_extract_files.py index 0f6fb04d7d..dcf85bd7c7 100644 --- a/tests/test_extract_files.py +++ b/tests/test_extract_files.py @@ -35,6 +35,15 @@ def test_multiple_files() -> None: assert query == {"documents": [{}, {}]} +def test_top_level_file_array() -> None: + query = {"files": [b"file one", b"file two"], "title": "hello"} + assert extract_files(query, paths=[["files", ""]]) == [ + ("files[]", b"file one"), + ("files[]", b"file two"), + ] + assert query == {"title": "hello"} + + @pytest.mark.parametrize( "query,paths,expected", [ diff --git a/tests/test_send_queue.py b/tests/test_send_queue.py new file mode 100644 index 0000000000..61db916bc4 --- /dev/null +++ b/tests/test_send_queue.py @@ -0,0 +1,128 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import pytest + +from openai._exceptions import WebSocketQueueFullError +from openai._send_queue import SendQueue + + +class TestSendQueue: + def test_enqueue_and_drain(self) -> None: + q = SendQueue() + q.enqueue('{"type": "session.update"}') + q.enqueue('{"type": "response.create"}') + assert len(q) == 2 + + items = q.drain() + assert items == ['{"type": "session.update"}', '{"type": "response.create"}'] + assert len(q) == 0 + + def test_enqueue_respects_byte_limit(self) -> None: + q = SendQueue(max_bytes=10) + q.enqueue("12345") # 5 bytes, fits + with pytest.raises(WebSocketQueueFullError): + q.enqueue("123456") # 6 bytes, would exceed 10 + assert len(q) == 1 + + def test_drain_empties_queue(self) -> None: + q = SendQueue() + q.enqueue("hello") + q.drain() + assert len(q) == 0 + assert not q + + def test_bool(self) -> None: + q = SendQueue() + assert not q + q.enqueue("x") + assert q + + def test_flush_sync(self) -> None: + q = SendQueue() + q.enqueue("a") + q.enqueue("b") + q.enqueue("c") + + sent: list[str] = [] + q.flush_sync(sent.append) + assert sent == ["a", "b", "c"] + assert len(q) == 0 + + def test_flush_sync_requeues_on_failure(self) -> None: + q = SendQueue() + q.enqueue("a") + q.enqueue("b") + q.enqueue("c") + + sent: list[str] = [] + + def failing_send(data: str) -> None: + if data == "b": + raise RuntimeError("send failed") + sent.append(data) + + with pytest.raises(RuntimeError, match="send failed"): + q.flush_sync(failing_send) + + assert sent == ["a"] + # b and c should be re-queued + remaining = q.drain() + assert remaining == ["b", "c"] + + @pytest.mark.asyncio + async def test_flush_async(self) -> None: + q = SendQueue() + q.enqueue("a") + q.enqueue("b") + + sent: list[str] = [] + + async def async_send(data: str) -> None: + sent.append(data) + + await q.flush_async(async_send) + assert sent == ["a", "b"] + assert len(q) == 0 + + @pytest.mark.asyncio + async def test_flush_async_requeues_on_failure(self) -> None: + q = SendQueue() + q.enqueue("a") + q.enqueue("b") + q.enqueue("c") + + sent: list[str] = [] + + async def failing_send(data: str) -> None: + if data == "b": + raise RuntimeError("send failed") + sent.append(data) + + with pytest.raises(RuntimeError, match="send failed"): + await q.flush_async(failing_send) + + assert sent == ["a"] + remaining = q.drain() + assert remaining == ["b", "c"] + + def test_flush_sync_preserves_new_items_on_failure(self) -> None: + """If items are enqueued after flush starts and flush fails, + the re-queued items should come before the new items.""" + q = SendQueue() + q.enqueue("a") + q.enqueue("b") + + def failing_send(data: str) -> None: + if data == "b": + # Simulate another thread enqueuing during flush + q.enqueue("new") + raise RuntimeError("fail") + + with pytest.raises(RuntimeError): + q.flush_sync(failing_send) + + # "b" (failed) should come before "new" (added during flush) + remaining = q.drain() + assert remaining == ["b", "new"]