From 3b5b8a24e130c7f968b51ab140d2b54438da6664 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 21 May 2026 13:46:19 +0100 Subject: [PATCH 01/40] refactor: Move auth extractors into authentication module And split the auth check into two so that other methods can access the raw bearer token if required. --- src/blueapi/service/authentication.py | 55 ++++++++++++++++- src/blueapi/service/main.py | 31 +--------- .../unit_tests/service/test_authentication.py | 60 +++++++++++++++++-- 3 files changed, 112 insertions(+), 34 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index b107f7b2b2..944dccf5d3 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -6,16 +6,20 @@ import time import webbrowser from abc import ABC, abstractmethod +from collections.abc import Mapping from functools import cached_property from http import HTTPStatus from pathlib import Path -from typing import Any, cast +from typing import Annotated, Any, cast import httpx import jwt import requests +from fastapi import Depends, HTTPException, Request +from fastapi.security.utils import get_authorization_scheme_param from pydantic import TypeAdapter from requests.auth import AuthBase +from starlette.status import HTTP_401_UNAUTHORIZED from blueapi.config import OIDCConfig, ServiceAccount from blueapi.service.model import Cache @@ -272,3 +276,52 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +def unchecked_bearer_token(req: Request) -> str | None: + """Get bearer token value from authorization header""" + auth = req.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(auth) + if scheme.casefold() != "bearer": + return None + return param.strip() + + +UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)] + + +def build_access_token_check(config: OIDCConfig): + """ + Create a function to validate the bearer token of requests + + The returned function should be used via fastAPI's 'Depends' mechanism to + ensure users are authenticated + """ + jwkclient = jwt.PyJWKClient(config.jwks_uri) + + def validate_bearer_token(request: Request, token: UncheckedBearerToken): + """Check that a bearer token is valid and inject into request state""" + if not token: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + signing_key = jwkclient.get_signing_key_from_jwt(token) + decoded: dict[str, Any] = jwt.decode( + token, + signing_key.key, + algorithms=config.id_token_signing_alg_values_supported, + verify=True, + audience=config.client_audience, + issuer=config.issuer, + ) + request.state.decoded_access_token = decoded + + return validate_bearer_token + + +def access_token(request: Request) -> Mapping[str, Any] | None: + """Get the decoded and verified access token of the user making the request""" + return getattr(request.state, "decoded_access_token", None) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index a53c46885a..5dae9462a2 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -19,7 +19,6 @@ from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse -from fastapi.security import OAuth2AuthorizationCodeBearer from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -37,6 +36,7 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.authentication import build_access_token_check from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -61,6 +61,7 @@ RUNNER: WorkerDispatcher | None = None LOGGER = logging.getLogger(__name__) +TRACER = get_tracer("interface") def _runner() -> WorkerDispatcher: @@ -117,7 +118,7 @@ def get_app(config: ApplicationConfig): ) dependencies = [] if config.oidc: - dependencies.append(Depends(decode_access_token(config.oidc))) + dependencies.append(Depends(build_access_token_check(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } @@ -140,32 +141,6 @@ def get_app(config: ApplicationConfig): return app -def decode_access_token(config: OIDCConfig): - jwkclient = jwt.PyJWKClient(config.jwks_uri) - oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=config.authorization_endpoint, - tokenUrl=config.token_endpoint, - refreshUrl=config.token_endpoint, - ) - - def inner(request: Request, access_token: str = Depends(oauth_scheme)): - signing_key = jwkclient.get_signing_key_from_jwt(access_token) - decoded: dict[str, Any] = jwt.decode( - access_token, - signing_key.key, - algorithms=config.id_token_signing_alg_values_supported, - verify=True, - audience=config.client_audience, - issuer=config.issuer, - ) - request.state.decoded_access_token = decoded - - return inner - - -TRACER = get_tracer("interface") - - async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 88227706be..01bc426e20 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -8,15 +8,19 @@ import pytest import responses import respx +from fastapi import HTTPException from pydantic import SecretStr from starlette.status import HTTP_200_OK, HTTP_403_FORBIDDEN from blueapi.config import OIDCConfig, ServiceAccount -from blueapi.service import main +from blueapi.service import authentication from blueapi.service.authentication import ( SessionCacheManager, SessionManager, TiledAuth, + access_token, + build_access_token_check, + unchecked_bearer_token, ) @@ -124,9 +128,9 @@ def test_poll_for_token_timeout( def test_server_raises_exception_for_invalid_token( oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock ): - inner = main.decode_access_token(oidc_config) + inner = authentication.build_access_token_check(oidc_config) with pytest.raises(jwt.PyJWTError): - inner(Mock(), access_token="Invalid Token") + inner(Mock(), token="Invalid Token") def test_processes_valid_token( @@ -134,8 +138,8 @@ def test_processes_valid_token( mock_authn_server: responses.RequestsMock, valid_token_with_jwt, ): - inner = main.decode_access_token(oidc_config) - inner(Mock(), access_token=valid_token_with_jwt["access_token"]) + inner = authentication.build_access_token_check(oidc_config) + inner(Mock(), token=valid_token_with_jwt["access_token"]) def test_session_cache_manager_returns_writable_file_path(tmp_path): @@ -182,3 +186,49 @@ def test_tiled_auth_sync_auth_flow(): result = next(flow) assert result.headers["Authorization"] == f"Bearer {access_token}" + + +@pytest.mark.parametrize( + "header,token", + [ + (None, None), + ("ApiKey foobar", None), + ("Bearer foobar", "foobar"), + ("Bearer with_whitespace ", "with_whitespace"), + ("Bearerfoobar", None), + ], +) +def test_unchecked_bearer_token(header: str | None, token: str | None): + req = Mock() + req.headers.get.side_effect = lambda key: header if key == "Authorization" else None + + assert unchecked_bearer_token(req) == token + + +def test_access_token(): + req = Mock() + req.state.decoded_access_token = {"foo": "bar"} + + assert access_token(req) == {"foo": "bar"} + + +def test_access_token_without_token(): + req = Mock() + del req.state.decoded_access_token + + assert access_token(req) is None + + +@patch("blueapi.service.authentication.jwt") +def test_build_access_token(mock_jwt: Mock): + # Return None when building client to ensure no field/method access + mock_jwt.PyJWKClient.return_value = None + oidc_config = Mock() + req = Mock() + + validate_fn = build_access_token_check(oidc_config) + + with pytest.raises(HTTPException, match="401"): + validate_fn(req, token=None) + + mock_jwt.decode.assert_not_called() From e6ca161f88e3a8887a78d06e0f760a5c1b5fed0d Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 17 Apr 2026 16:21:11 +0100 Subject: [PATCH 02/40] Use Depends injection to extract user name --- src/blueapi/service/authentication.py | 9 +++++++++ src/blueapi/service/main.py | 15 ++++----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 944dccf5d3..64dfc30040 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -325,3 +325,12 @@ def validate_bearer_token(request: Request, token: UncheckedBearerToken): def access_token(request: Request) -> Mapping[str, Any] | None: """Get the decoded and verified access token of the user making the request""" return getattr(request.state, "decoded_access_token", None) + + +def fedid( + access_token: Annotated[Mapping[str, Any] | None, Depends(access_token)], +) -> str | None: + return access_token.get("fedid") if access_token else None + + +Fedid = Annotated[str | None, Depends(fedid)] diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 5dae9462a2..546155f5c2 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,7 +2,7 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager -from typing import Annotated, Any +from typing import Annotated import jwt from fastapi import ( @@ -36,7 +36,7 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface -from blueapi.service.authentication import build_access_token_check +from blueapi.service.authentication import Fedid, build_access_token_check from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum @@ -267,18 +267,11 @@ def submit_task( response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], runner: Annotated[WorkerDispatcher, Depends(_runner)], + user: Fedid, ) -> TaskResponse: """Submit a task to the worker.""" try: - # Extract user from jwt if using OIDC (if jwt exists) - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) - if access_token: - user: str = access_token.get("fedid", "Unknown") - else: - user = "Unknown" - + user = user or "Unknown" task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) From 0e68df413864361a92d45490d4cba7d537829062 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 21 May 2026 11:42:59 +0100 Subject: [PATCH 03/40] feat: Add OpaConfig for authorization configuration --- helm/blueapi/config_schema.json | 27 +++++++++++++++++++++++++++ helm/blueapi/values.schema.json | 26 ++++++++++++++++++++++++++ pyproject.toml | 1 + src/blueapi/config.py | 6 ++++++ tests/unit_tests/test_config.py | 4 ++++ uv.lock | 2 ++ 6 files changed, 66 insertions(+) diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index b5d0c9bf3d..1ad4e82bfe 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -330,6 +330,22 @@ "type": "object", "$id": "OIDCConfig" }, + "OpaConfig": { + "additionalProperties": false, + "properties": { + "root": { + "default": "http://localhost:8181/", + "format": "uri", + "maxLength": 2083, + "minLength": 1, + "title": "Root", + "type": "string" + } + }, + "title": "OpaConfig", + "type": "object", + "$id": "OpaConfig" + }, "PlanSource": { "additionalProperties": false, "properties": { @@ -612,6 +628,17 @@ } ], "default": null + }, + "opa": { + "anyOf": [ + { + "$ref": "OpaConfig" + }, + { + "type": "null" + } + ], + "default": null } }, "title": "ApplicationConfig", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 74deedadb2..6de532cc88 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -751,6 +751,22 @@ }, "additionalProperties": false }, + "OpaConfig": { + "$id": "OpaConfig", + "title": "OpaConfig", + "type": "object", + "properties": { + "root": { + "title": "Root", + "default": "http://localhost:8181/", + "type": "string", + "format": "uri", + "maxLength": 2083, + "minLength": 1 + } + }, + "additionalProperties": false + }, "PlanSource": { "$id": "PlanSource", "title": "PlanSource", @@ -1011,6 +1027,16 @@ } ] }, + "opa": { + "anyOf": [ + { + "$ref": "OpaConfig" + }, + { + "type": "null" + } + ] + }, "scratch": { "anyOf": [ { diff --git a/pyproject.toml b/pyproject.toml index 659779994a..9f4231c76c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "tomlkit", "graypy>=2.1.0", "httpx>=0.28.1", + "aiohttp>=3.13.5", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 83d6d70211..4c2431cf1e 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -296,6 +296,10 @@ class Tag(StrEnum): META = "Meta" +class OpaConfig(BlueapiBaseModel): + root: HttpUrl = HttpUrl("http://localhost:8181") + + class ApplicationConfig(BlueapiBaseModel): """ Config for the worker application as a whole. Root of @@ -335,6 +339,7 @@ class ApplicationConfig(BlueapiBaseModel): oidc: OIDCConfig | None = None auth_token_path: Path | None = None numtracker: NumtrackerConfig | None = None + opa: OpaConfig | None = None def __eq__(self, other: object) -> bool: if isinstance(other, ApplicationConfig): @@ -343,6 +348,7 @@ def __eq__(self, other: object) -> bool: & (self.env == other.env) & (self.logging == other.logging) & (self.api == other.api) + & (self.opa == other.opa) ) return False diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 30fe551c4c..7ce98f6fe6 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -337,6 +337,9 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "opa": { + "root": "http://opa.example.com/", + }, }, { "stomp": { @@ -392,6 +395,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "opa": None, }, ], indirect=True, diff --git a/uv.lock b/uv.lock index b2d10b5bd5..eab94a8670 100644 --- a/uv.lock +++ b/uv.lock @@ -420,6 +420,7 @@ name = "blueapi" source = { editable = "." } dependencies = [ { name = "aioca" }, + { name = "aiohttp" }, { name = "bluesky", extra = ["plotting"] }, { name = "bluesky-stomp" }, { name = "click" }, @@ -481,6 +482,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aioca" }, + { name = "aiohttp", specifier = ">=3.13.5" }, { name = "bluesky", extras = ["plotting"], specifier = ">=1.14.0" }, { name = "bluesky-stomp", specifier = ">=0.1.6" }, { name = "click", specifier = ">=8.2.0" }, From 6af44c062101c32c2b9b2e6ebfa88039fc09875b Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 21 May 2026 11:52:04 +0100 Subject: [PATCH 04/40] Add OpaClient to wrap OPA interactions --- helm/blueapi/config_schema.json | 5 ++++ helm/blueapi/values.schema.json | 5 ++++ src/blueapi/config.py | 1 + src/blueapi/service/authorization.py | 45 ++++++++++++++++++++++++++++ tests/unit_tests/test_config.py | 1 + 5 files changed, 57 insertions(+) create mode 100644 src/blueapi/service/authorization.py diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index 1ad4e82bfe..39a1b1aabe 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -340,6 +340,11 @@ "minLength": 1, "title": "Root", "type": "string" + }, + "audience": { + "default": "account", + "title": "Audience", + "type": "string" } }, "title": "OpaConfig", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 6de532cc88..f78dbfb848 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -756,6 +756,11 @@ "title": "OpaConfig", "type": "object", "properties": { + "audience": { + "title": "Audience", + "default": "account", + "type": "string" + }, "root": { "title": "Root", "default": "http://localhost:8181/", diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 4c2431cf1e..538d141589 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -298,6 +298,7 @@ class Tag(StrEnum): class OpaConfig(BlueapiBaseModel): root: HttpUrl = HttpUrl("http://localhost:8181") + audience: str = "account" class ApplicationConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py new file mode 100644 index 0000000000..151e60a0cb --- /dev/null +++ b/src/blueapi/service/authorization.py @@ -0,0 +1,45 @@ +import logging +from collections.abc import Mapping +from contextlib import AbstractAsyncContextManager, aclosing, nullcontext +from typing import Any, Self + +from aiohttp import ClientSession + +from blueapi.config import OpaConfig + +LOGGER = logging.getLogger(__name__) + + +class OpaClient: + def __init__(self, instrument: str, config: OpaConfig): + LOGGER.info("Creating OpaClient for %s with config %s", instrument, config) + self._instrument = instrument + self._config = config + self._session = ClientSession(base_url=config.root.encoded_string()) + self._audience = config.audience + + async def aclose(self): + LOGGER.info("Closing OPA session") + await self._session.close() + + async def _call_opa(self, endpoint: str, data: Mapping[str, Any]) -> bool: + resp = await self._session.post( + endpoint, + json={ + "input": { + "beamline": self._instrument, + "audience": self._audience, + **data, + } + }, + ) + return (await resp.json())["result"] + + @classmethod + def for_config( + cls, instrument: str, config: OpaConfig | None + ) -> AbstractAsyncContextManager[Self | None]: + if config: + return aclosing(cls(instrument, config)) + LOGGER.info("No OPA config provided - not creating OpaClient") + return nullcontext() diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 7ce98f6fe6..97a4ce3806 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -339,6 +339,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): }, "opa": { "root": "http://opa.example.com/", + "audience": "account", }, }, { From d2cee5427fea0299d682015c25a432c564164aed Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 21 May 2026 13:32:18 +0100 Subject: [PATCH 05/40] Create OpaClient as part of server lifecycle --- src/blueapi/service/main.py | 8 +++++++- tests/unit_tests/service/test_main.py | 18 +++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 546155f5c2..23baa6f13e 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -40,6 +40,7 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum +from .authorization import OpaClient from .model import ( DeviceModel, DeviceResponse, @@ -93,8 +94,13 @@ def teardown_runner(): def lifespan(config: ApplicationConfig): @asynccontextmanager async def inner(app: FastAPI): + if not (meta := config.env.metadata): + raise ValueError("Instrument name is required in metadata") + setup_runner(config) - yield + async with OpaClient.for_config(meta.instrument, config.opa) as opa: + app.state.authz = opa + yield teardown_runner() return inner diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index a7e04105c2..1801b8756b 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch import pytest from fastapi import FastAPI, Request @@ -10,6 +10,7 @@ from blueapi.service.main import ( add_version_headers, get_passthrough_headers, + lifespan, log_request_details, ) @@ -79,3 +80,18 @@ def test_get_passthrough_headers( request = Mock(spec=Request) request.headers = headers assert get_passthrough_headers(request) == expected_headers + + +@patch("blueapi.service.main.teardown_runner") +@patch("blueapi.service.main.setup_runner") +async def test_lifespan(setup: Mock, teardown: Mock): + conf = ApplicationConfig() + lifespan_fn = lifespan(conf) + + app = Mock() + + async with lifespan_fn(app): + setup.assert_called_once_with(conf) + teardown.assert_not_called() + + teardown.assert_called_once() From 7ddf56d7dc37b7729f4106c3bde022366dcf057e Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 27 May 2026 14:54:30 +0100 Subject: [PATCH 06/40] Move instrument requirement into OpaClient init --- src/blueapi/service/authorization.py | 4 +++- src/blueapi/service/main.py | 6 ++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index 151e60a0cb..aabd4a692a 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -37,9 +37,11 @@ async def _call_opa(self, endpoint: str, data: Mapping[str, Any]) -> bool: @classmethod def for_config( - cls, instrument: str, config: OpaConfig | None + cls, instrument: str | None, config: OpaConfig | None ) -> AbstractAsyncContextManager[Self | None]: if config: + if not instrument: + raise ValueError("Instrument name is required for OPA client") return aclosing(cls(instrument, config)) LOGGER.info("No OPA config provided - not creating OpaClient") return nullcontext() diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 23baa6f13e..462c9318f1 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -94,11 +94,9 @@ def teardown_runner(): def lifespan(config: ApplicationConfig): @asynccontextmanager async def inner(app: FastAPI): - if not (meta := config.env.metadata): - raise ValueError("Instrument name is required in metadata") - + meta = config.env.metadata setup_runner(config) - async with OpaClient.for_config(meta.instrument, config.opa) as opa: + async with OpaClient.for_config(meta and meta.instrument, config.opa) as opa: app.state.authz = opa yield teardown_runner() From 0374da6d0a6dab555b887955d5b4ff698abf0f3f Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 28 May 2026 17:21:00 +0100 Subject: [PATCH 07/40] Add tests for OpaClient --- .../unit_tests/service/test_authorization.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 tests/unit_tests/service/test_authorization.py diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py new file mode 100644 index 0000000000..608269d20b --- /dev/null +++ b/tests/unit_tests/service/test_authorization.py @@ -0,0 +1,62 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import HttpUrl + +from blueapi.config import OpaConfig +from blueapi.service.authorization import ( + OpaClient, +) + +# Reusable client patch decorator +patch_client_session = patch( + "blueapi.service.authorization.ClientSession", + name="mock_client_session", + spec=True, +) + + +@pytest.fixture(scope="module") +def opa_config() -> OpaConfig: + return OpaConfig( + root=HttpUrl("http://auth.example.com"), + ) + + +@patch_client_session +async def test_session_closed(session: MagicMock, opa_config: OpaConfig): + async with OpaClient.for_config("p45", opa_config): + pass + session().close.assert_called_once() + + +@patch_client_session +async def test_opa_client_for_config(session: MagicMock, opa_config: OpaConfig): + async with OpaClient.for_config("p45", opa_config) as opa: + assert opa is not None + session.assert_called_once_with(base_url="http://auth.example.com/") + + +@pytest.mark.parametrize("instrument", [None, "p99"]) +async def test_opa_client_without_config(instrument: str | None): + async with OpaClient.for_config(instrument, None) as opa: + assert opa is None + + +async def test_opa_fails_without_instrument(opa_config: OpaConfig): + with pytest.raises(ValueError, match="Instrument name is required"): + OpaClient.for_config(None, opa_config) + + +@patch_client_session +async def test_opa_adds_input_fields(session: MagicMock, opa_config: OpaConfig): + session.return_value.post = AsyncMock() + async with OpaClient.for_config("p45", opa_config) as opa: + assert opa is not None + await opa._call_opa("foo/bar", data={"foo": "bar"}) + + session.assert_called_once() + session().post.assert_called_once_with( + "foo/bar", + json={"input": {"beamline": "p45", "audience": "account", "foo": "bar"}}, + ) From 8bc7d84e38dac51d7a82255874d4fb43ccce933b Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 21 May 2026 16:45:44 +0100 Subject: [PATCH 08/40] Validate tiled service account configuration at startup --- helm/blueapi/config_schema.json | 7 +++++++ helm/blueapi/values.schema.json | 7 +++++++ src/blueapi/config.py | 1 + src/blueapi/service/authorization.py | 29 +++++++++++++++++++++++++++- src/blueapi/service/main.py | 3 ++- tests/unit_tests/test_config.py | 1 + 6 files changed, 46 insertions(+), 2 deletions(-) diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index 39a1b1aabe..dd8a48433a 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -345,8 +345,15 @@ "default": "account", "title": "Audience", "type": "string" + }, + "tiled_service_account_check": { + "title": "Tiled Service Account Check", + "type": "string" } }, + "required": [ + "tiled_service_account_check" + ], "title": "OpaConfig", "type": "object", "$id": "OpaConfig" diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index f78dbfb848..60310135cc 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -755,6 +755,9 @@ "$id": "OpaConfig", "title": "OpaConfig", "type": "object", + "required": [ + "tiled_service_account_check" + ], "properties": { "audience": { "title": "Audience", @@ -768,6 +771,10 @@ "format": "uri", "maxLength": 2083, "minLength": 1 + }, + "tiled_service_account_check": { + "title": "Tiled Service Account Check", + "type": "string" } }, "additionalProperties": false diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 538d141589..c56415bfe6 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -299,6 +299,7 @@ class Tag(StrEnum): class OpaConfig(BlueapiBaseModel): root: HttpUrl = HttpUrl("http://localhost:8181") audience: str = "account" + tiled_service_account_check: str class ApplicationConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index aabd4a692a..a4a7b5c985 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -5,7 +5,8 @@ from aiohttp import ClientSession -from blueapi.config import OpaConfig +from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount +from blueapi.service.authentication import TiledAuth LOGGER = logging.getLogger(__name__) @@ -45,3 +46,29 @@ def for_config( return aclosing(cls(instrument, config)) LOGGER.info("No OPA config provided - not creating OpaClient") return nullcontext() + + async def require_tiled_service_account(self, token: str): + if not await self._call_opa( + self._config.tiled_service_account_check, + {"token": token, "beamline": self._instrument}, + ): + raise ValueError( + f"Tiled service account is not valid for '{self._instrument}'" + ) + + +async def validate_tiled_config( + tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None +): + if not isinstance(tiled, ServiceAccount): + # can't validate an API key + return + + if not opa or not oidc: + LOGGER.info("Missing OPA or OIDC configuration required to validate tiled auth") + return + + LOGGER.info("Validating tiled configuration") + tiled.token_url = oidc.token_endpoint + auth = TiledAuth(tiled) + await opa.require_tiled_service_account(auth.get_access_token()) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 462c9318f1..3114fa73f5 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -40,7 +40,7 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum -from .authorization import OpaClient +from .authorization import OpaClient, validate_tiled_config from .model import ( DeviceModel, DeviceResponse, @@ -98,6 +98,7 @@ async def inner(app: FastAPI): setup_runner(config) async with OpaClient.for_config(meta and meta.instrument, config.opa) as opa: app.state.authz = opa + await validate_tiled_config(config.tiled.authentication, config.oidc, opa) yield teardown_runner() diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 97a4ce3806..5cbb00c1b7 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -340,6 +340,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): "opa": { "root": "http://opa.example.com/", "audience": "account", + "tiled_service_account_check": "v1/tiled_service_account", }, }, { From 41e28cae607108fee856a05f3ba8c55cc6e0b53f Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 28 May 2026 17:22:29 +0100 Subject: [PATCH 09/40] Add tests for tiled check --- .../unit_tests/service/test_authorization.py | 93 ++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index 608269d20b..2491985806 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -1,11 +1,13 @@ -from unittest.mock import AsyncMock, MagicMock, patch +from contextlib import AbstractContextManager, nullcontext +from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest from pydantic import HttpUrl -from blueapi.config import OpaConfig +from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authorization import ( OpaClient, + validate_tiled_config, ) # Reusable client patch decorator @@ -20,9 +22,50 @@ def opa_config() -> OpaConfig: return OpaConfig( root=HttpUrl("http://auth.example.com"), + tiled_service_account_check="/auth/tiled", ) +@patch_client_session +@pytest.mark.parametrize( + "result,context", + [ + (False, pytest.raises(ValueError, match="Tiled service account is not valid ")), + (True, nullcontext()), + ], +) +async def test_tiled_service_account( + session: MagicMock, + opa_config: OpaConfig, + result: bool, + context: AbstractContextManager, +): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + + client = OpaClient(instrument="p99", config=opa_config) + + session.assert_called_once_with(base_url="http://auth.example.com/") + with context: + await client.require_tiled_service_account(token="foo_bar") + session().post.assert_called_once_with( + "/auth/tiled", + json={"input": {"token": "foo_bar", "beamline": "p99", "audience": "account"}}, + ) + + +@patch_client_session +async def test_exception_raised_when_opa_fails( + session: MagicMock, opa_config: OpaConfig +): + session.return_value.post = AsyncMock(side_effect=RuntimeError("Connection failed")) + async with OpaClient.for_config("p45", opa_config) as client: + assert client is not None + with pytest.raises(RuntimeError, match="Connection failed"): + await client.require_tiled_service_account(token="foo_bar") + + @patch_client_session async def test_session_closed(session: MagicMock, opa_config: OpaConfig): async with OpaClient.for_config("p45", opa_config): @@ -60,3 +103,49 @@ async def test_opa_adds_input_fields(session: MagicMock, opa_config: OpaConfig): "foo/bar", json={"input": {"beamline": "p45", "audience": "account", "foo": "bar"}}, ) + + +async def test_validate_tiled_config(): + opa = MagicMock(spec=OpaClient) + tiled = ServiceAccount() + oidc = Mock(spec=OIDCConfig) + oidc.token_endpoint = "token-endpoint" + with patch("blueapi.service.authorization.TiledAuth") as auth: + auth.return_value.get_access_token.return_value = "tiled-token" + await validate_tiled_config(tiled, oidc, opa) + + auth.assert_called_once_with(tiled) + opa.require_tiled_service_account.assert_called_once_with("tiled-token") + + +@pytest.mark.parametrize( + "tiled_auth,oidc,opa_client", + [ + (None, None, MagicMock(spec=OpaClient)), + ( + None, + OIDCConfig(well_known_url="http://example.com", client_id="test-client"), + MagicMock(spec=OpaClient), + ), + ("api_key", None, MagicMock(spec=OpaClient)), + ( + "api_key", + OIDCConfig(well_known_url="http://example.com", client_id="test-client"), + MagicMock(spec=OpaClient), + ), + (ServiceAccount(), None, MagicMock(spec=OpaClient)), + ( + ServiceAccount(), + OIDCConfig(well_known_url="http://example.com", client_id="test-client"), + None, + ), + ], +) +async def test_validate_tiled_config_with_missing_config( + tiled_auth: ServiceAccount | str | None, + oidc: OIDCConfig | None, + opa_client: MagicMock | None, +): + assert await validate_tiled_config(tiled_auth, oidc, opa_client) is None + if opa_client is not None: + opa_client.require_tiled_service_account.assert_not_called() From 4495151547991b22baf1bc0aaf74b8b62d914b0d Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 26 May 2026 12:01:54 +0100 Subject: [PATCH 10/40] Add opa dependency function to create OpaUserClient --- src/blueapi/service/authorization.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index a4a7b5c985..c4f72ec45d 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -1,12 +1,14 @@ import logging from collections.abc import Mapping from contextlib import AbstractAsyncContextManager, aclosing, nullcontext -from typing import Any, Self +from typing import Any, Self, cast from aiohttp import ClientSession +from fastapi import Depends, HTTPException, Request +from starlette import status from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount -from blueapi.service.authentication import TiledAuth +from blueapi.service.authentication import TiledAuth, unchecked_bearer_token LOGGER = logging.getLogger(__name__) @@ -72,3 +74,14 @@ async def validate_tiled_config( tiled.token_url = oidc.token_endpoint auth = TiledAuth(tiled) await opa.require_tiled_service_account(auth.get_access_token()) + + +async def opa( + request: Request, token: str | None = Depends(unchecked_bearer_token) +) -> OpaUserClient | None: + + if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): + if not token: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + return opa.for_token(token) + return None From f2f02de2f5f69b4268c9e14b60bd33be689a991f Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 28 May 2026 17:24:54 +0100 Subject: [PATCH 11/40] test opa dependency function --- src/blueapi/service/authorization.py | 2 +- .../unit_tests/service/test_authorization.py | 27 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index c4f72ec45d..dbdc1f6853 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -83,5 +83,5 @@ async def opa( if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): if not token: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) - return opa.for_token(token) + return OpaUserClient(opa, token) return None diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index 2491985806..ff3949174c 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -2,11 +2,13 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest +from fastapi import HTTPException from pydantic import HttpUrl from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authorization import ( OpaClient, + opa, validate_tiled_config, ) @@ -149,3 +151,28 @@ async def test_validate_tiled_config_with_missing_config( assert await validate_tiled_config(tiled_auth, oidc, opa_client) is None if opa_client is not None: opa_client.require_tiled_service_account.assert_not_called() + + +async def test_opa_dependency_method(): + request = MagicMock() + + user_client = await opa(request, "foo_bar") + + assert user_client is not None + assert user_client.client == request.app.state.authz + assert user_client.token == "foo_bar" + + +async def test_opa_dependency_without_token(): + request = MagicMock() + + with pytest.raises(HTTPException, match="401"): + await opa(request, None) + + +@pytest.mark.parametrize("token", ["foo_bar", None]) +async def test_opa_dependency_without_authz(token): + request = MagicMock() + del request.app.state.authz + user_client = await opa(request, token) + assert user_client is None From 82ff44b27cdf2ad31315adf8d73566d8c3224054 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 26 May 2026 12:17:35 +0100 Subject: [PATCH 12/40] Add can_submit_task auth check method and config --- helm/blueapi/config_schema.json | 7 ++++- helm/blueapi/values.schema.json | 7 ++++- src/blueapi/config.py | 1 + src/blueapi/service/authorization.py | 30 +++++++++++++++++++ .../unit_tests/service/test_authorization.py | 1 + tests/unit_tests/test_config.py | 1 + 6 files changed, 45 insertions(+), 2 deletions(-) diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index dd8a48433a..e176e5c663 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -349,10 +349,15 @@ "tiled_service_account_check": { "title": "Tiled Service Account Check", "type": "string" + }, + "submit_task_check": { + "title": "Submit Task Check", + "type": "string" } }, "required": [ - "tiled_service_account_check" + "tiled_service_account_check", + "submit_task_check" ], "title": "OpaConfig", "type": "object", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 60310135cc..5808acc540 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -756,7 +756,8 @@ "title": "OpaConfig", "type": "object", "required": [ - "tiled_service_account_check" + "tiled_service_account_check", + "submit_task_check" ], "properties": { "audience": { @@ -772,6 +773,10 @@ "maxLength": 2083, "minLength": 1 }, + "submit_task_check": { + "title": "Submit Task Check", + "type": "string" + }, "tiled_service_account_check": { "title": "Tiled Service Account Check", "type": "string" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index c56415bfe6..a64491090e 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -300,6 +300,7 @@ class OpaConfig(BlueapiBaseModel): root: HttpUrl = HttpUrl("http://localhost:8181") audience: str = "account" tiled_service_account_check: str + submit_task_check: str class ApplicationConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index dbdc1f6853..9cc4ea7df2 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Mapping from contextlib import AbstractAsyncContextManager, aclosing, nullcontext from typing import Any, Self, cast @@ -9,8 +10,10 @@ from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authentication import TiledAuth, unchecked_bearer_token +from blueapi.service.model import TaskRequest LOGGER = logging.getLogger(__name__) +INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") class OpaClient: @@ -58,6 +61,33 @@ async def require_tiled_service_account(self, token: str): f"Tiled service account is not valid for '{self._instrument}'" ) + async def require_submit_task(self, instrument_session: str, token: str): + if not (match := INSTRUMENT_SESSION_RE.match(instrument_session)): + raise ValueError("Invalid instrument session") + + if not await self._call_opa( + self._conf.submit_task_check, + { + "token": token, + "proposal": int(match["proposal"]), + "visit": int(match["visit"]), + }, + ): + raise HTTPException(status_code=status.HTTP_403_UNORTHORIZED) + + +class OpaUserClient: + client: OpaClient + token: str + + def __init__(self, client: OpaClient, token: str): + self.client = client + self.token = token + + async def can_submit_task(self, task: TaskRequest): + LOGGER.info("Checking permissions to run task") + await self.client.require_submit_task(task.instrument_session, self.token) + async def validate_tiled_config( tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index ff3949174c..37c1d7e3f9 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -24,6 +24,7 @@ def opa_config() -> OpaConfig: return OpaConfig( root=HttpUrl("http://auth.example.com"), + submit_task_check="/auth/submit", tiled_service_account_check="/auth/tiled", ) diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 5cbb00c1b7..747e944d55 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -341,6 +341,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): "root": "http://opa.example.com/", "audience": "account", "tiled_service_account_check": "v1/tiled_service_account", + "submit_task_check": "v1/submit_task", }, }, { From 84efe6ef89a8d7da96e3871b1f60b6879919cb17 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 15 May 2026 08:27:14 +0000 Subject: [PATCH 13/40] feat: add authz dependency injection --- src/blueapi/service/main.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 3114fa73f5..7a8f46e422 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -40,7 +40,7 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum -from .authorization import OpaClient, validate_tiled_config +from .authorization import OpaClient, OpaUserClient, opa, validate_tiled_config from .model import ( DeviceModel, DeviceResponse, @@ -258,6 +258,13 @@ def get_device_by_name( ) +async def submission_check( + opa: Annotated[OpaUserClient, Depends(opa)], + task_request: TaskRequest, +): + await opa.can_submit_task(task_request) + + @secure_router_v1.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @secure_router.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @start_as_current_span( @@ -271,6 +278,7 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], + authz_check: Annotated[None, Depends(submission_check)], runner: Annotated[WorkerDispatcher, Depends(_runner)], user: Fedid, ) -> TaskResponse: From 06a8bda5e8ef34a4a6e8d63ea42150dc8ec9df14 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 18 May 2026 15:01:34 +0000 Subject: [PATCH 14/40] feat: add auth check dependency injections to task endpoints --- src/blueapi/service/authorization.py | 8 +++++++- src/blueapi/service/main.py | 16 +++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index 9cc4ea7df2..9d9c96ddd5 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -2,7 +2,7 @@ import re from collections.abc import Mapping from contextlib import AbstractAsyncContextManager, aclosing, nullcontext -from typing import Any, Self, cast +from typing import Annotated, Any, Self, cast from aiohttp import ClientSession from fastapi import Depends, HTTPException, Request @@ -115,3 +115,9 @@ async def opa( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) return OpaUserClient(opa, token) return None + +async def submit_permission( + opa: Annotated[OpaUserClient, Depends(opa)], + task_request: TaskRequest, +): + await opa.can_submit_task(task_request) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 7a8f46e422..9cd0694c2b 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -40,7 +40,7 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum -from .authorization import OpaClient, OpaUserClient, opa, validate_tiled_config +from .authorization import OpaClient, submit_permission, validate_tiled_config from .model import ( DeviceModel, DeviceResponse, @@ -258,13 +258,6 @@ def get_device_by_name( ) -async def submission_check( - opa: Annotated[OpaUserClient, Depends(opa)], - task_request: TaskRequest, -): - await opa.can_submit_task(task_request) - - @secure_router_v1.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @secure_router.post("/tasks", status_code=status.HTTP_201_CREATED, tags=[Tag.TASK]) @start_as_current_span( @@ -278,7 +271,7 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], - authz_check: Annotated[None, Depends(submission_check)], + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], user: Fedid, ) -> TaskResponse: @@ -317,6 +310,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -335,6 +329,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @start_as_current_span(TRACER) def get_tasks( runner: Annotated[WorkerDispatcher, Depends(_runner)], + _: Annotated[None, Depends(submit_permission)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -371,6 +366,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -401,6 +397,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -478,6 +475,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ From 7536275f575b8620a254820071efa82c355a944e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 08:13:42 +0000 Subject: [PATCH 15/40] feat: create new access task permission fns and add as dependencies --- src/blueapi/service/main.py | 57 +++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 9cd0694c2b..6b999cb6e3 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -146,6 +146,41 @@ def get_app(config: ApplicationConfig): return app +def access_task_permission( + request: Request, + task_id: str, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + try: + task = runner.run(interface.get_task_by_id, task_id) + except KeyError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None + + if ( + access_token + and task + and access_token.get("fedid") != task.task.metadata.get("user") + ): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + + +# start_task_permission is used when there is WorkerTask +def start_task_permission( + request: Request, + task: WorkerTask, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + if not task.task_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="No task id provided", + ) + access_task_permission(request, task.task_id, runner) + + async def on_key_error_404(_: Request, __: Exception): return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -310,7 +345,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -328,8 +363,8 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) def get_tasks( + request: Request, runner: Annotated[WorkerDispatcher, Depends(_runner)], - _: Annotated[None, Depends(submit_permission)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -349,6 +384,15 @@ def get_tasks( tasks = runner.run(interface.get_tasks_by_status, desired_status) else: tasks = runner.run(interface.get_tasks) + + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + user = access_token.get("fedid") if access_token else None + + if user: + tasks = [t for t in tasks if t.task.metadata.get("user") == user] + return TasksListResponse(tasks=tasks) @@ -366,7 +410,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(start_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -397,7 +441,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -475,7 +519,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, - _: Annotated[None, Depends(submit_permission)], + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ @@ -507,6 +551,9 @@ def set_state( elif new_state == WorkerState.RUNNING: runner.run(interface.resume_worker) elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: + # active = runner.run(interface.get_active_task) + # if active.task.metadata.get("user"): + try: runner.run( interface.cancel_active_task, From f066cbfdfb76f3b7ec44ef59d49699e2b2faf857 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 08:30:53 +0000 Subject: [PATCH 16/40] refactor: update rest api version --- src/blueapi/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blueapi/config.py b/src/blueapi/config.py index a64491090e..1f4da71618 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -310,7 +310,7 @@ class ApplicationConfig(BlueapiBaseModel): """ #: API version to publish in OpenAPI schema - REST_API_VERSION: ClassVar[str] = "1.3.0" + REST_API_VERSION: ClassVar[str] = "1.3.1" LICENSE_INFO: ClassVar[dict[str, str]] = { "name": "Apache 2.0", From a56893dffcaad7a5080a58362257fc67add7c752 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 10:13:46 +0000 Subject: [PATCH 17/40] comment out dependency addition in set_state --- src/blueapi/config.py | 2 +- src/blueapi/service/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 1f4da71618..a64491090e 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -310,7 +310,7 @@ class ApplicationConfig(BlueapiBaseModel): """ #: API version to publish in OpenAPI schema - REST_API_VERSION: ClassVar[str] = "1.3.1" + REST_API_VERSION: ClassVar[str] = "1.3.0" LICENSE_INFO: ClassVar[dict[str, str]] = { "name": "Apache 2.0", diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 6b999cb6e3..0ac4d94a4f 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -519,7 +519,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt def set_state( state_change_request: StateChangeRequest, response: Response, - _: Annotated[None, Depends(access_task_permission)], + # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ From 451882561344112fa70ded7bc7c997bdbf146316 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 20 May 2026 14:16:24 +0000 Subject: [PATCH 18/40] refactor: add admin check and check to set state function --- src/blueapi/service/authentication.py | 3 +++ src/blueapi/service/main.py | 30 ++++++++++++++--------- tests/unit_tests/service/test_rest_api.py | 2 +- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 64dfc30040..156540bb02 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -286,6 +286,9 @@ def unchecked_bearer_token(req: Request) -> str | None: return None return param.strip() + def admin(self): + return False + UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)] diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 0ac4d94a4f..4d6e1708a9 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -147,6 +147,7 @@ def get_app(config: ApplicationConfig): def access_task_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], request: Request, task_id: str, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -154,21 +155,19 @@ def access_task_permission( access_token: dict[str, Any] | None = getattr( request.state, "decoded_access_token", None ) - try: - task = runner.run(interface.get_task_by_id, task_id) - except KeyError: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) from None + task = runner.run(interface.get_task_by_id, task_id) - if ( + if not opa.admin() and ( access_token and task and access_token.get("fedid") != task.task.metadata.get("user") ): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # start_task_permission is used when there is WorkerTask def start_task_permission( + opa: Annotated[OPAClient, Depends(get_opa_client)], request: Request, task: WorkerTask, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -178,7 +177,7 @@ def start_task_permission( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="No task id provided", ) - access_task_permission(request, task.task_id, runner) + access_task_permission(opa, request, task.task_id, runner) async def on_key_error_404(_: Request, __: Exception): @@ -390,8 +389,7 @@ def get_tasks( ) user = access_token.get("fedid") if access_token else None - if user: - tasks = [t for t in tasks if t.task.metadata.get("user") == user] + tasks = [t for t in tasks if t.task.metadata.get("user") == user] return TasksListResponse(tasks=tasks) @@ -517,8 +515,10 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt ) @start_as_current_span(TRACER, "state_change_request.new_state") def set_state( + request: Request, state_change_request: StateChangeRequest, response: Response, + opa: Annotated[OPAClient, Depends(get_opa_client)], # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: @@ -546,14 +546,20 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): + active = runner.run(interface.get_active_task) + access_token: dict[str, Any] | None = getattr( + request.state, "decoded_access_token", None + ) + user = access_token.get("fedid") if access_token else None + + if not opa.admin() and active and active.task.metadata.get("user") != user: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) elif new_state == WorkerState.RUNNING: runner.run(interface.resume_worker) elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}: - # active = runner.run(interface.get_active_task) - # if active.task.metadata.get("user"): - try: runner.run( interface.cancel_active_task, diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index c1d3b6a957..bf0a6a9977 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -251,7 +251,7 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": None}) assert response.json() == {"task_id": task_id} From 060ec2e1d2552f3cacbfc92aca1cd1e6bf1b90b3 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 26 May 2026 12:34:41 +0100 Subject: [PATCH 19/40] Update dependency names --- src/blueapi/service/main.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 4d6e1708a9..22401afe39 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -40,7 +40,13 @@ from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum -from .authorization import OpaClient, submit_permission, validate_tiled_config +from .authorization import ( + OpaClient, + OpaUserClient, + opa, + submit_permission, + validate_tiled_config, +) from .model import ( DeviceModel, DeviceResponse, @@ -147,7 +153,7 @@ def get_app(config: ApplicationConfig): def access_task_permission( - opa: Annotated[OPAClient, Depends(get_opa_client)], + opa: Annotated[OpaUserClient, Depends(opa)], request: Request, task_id: str, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -167,7 +173,7 @@ def access_task_permission( # start_task_permission is used when there is WorkerTask def start_task_permission( - opa: Annotated[OPAClient, Depends(get_opa_client)], + opa: Annotated[OpaUserClient, Depends(opa)], request: Request, task: WorkerTask, runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -518,7 +524,7 @@ def set_state( request: Request, state_change_request: StateChangeRequest, response: Response, - opa: Annotated[OPAClient, Depends(get_opa_client)], + opa: Annotated[OpaUserClient, Depends(opa)], # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: From 29e2a5a0801b0667560133cf29a15b84326ddf8a Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 26 May 2026 12:53:37 +0100 Subject: [PATCH 20/40] Add missing admin check --- helm/blueapi/config_schema.json | 7 ++++++- helm/blueapi/values.schema.json | 7 ++++++- src/blueapi/config.py | 1 + src/blueapi/service/authorization.py | 7 +++++++ tests/unit_tests/test_config.py | 1 + 5 files changed, 21 insertions(+), 2 deletions(-) diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index e176e5c663..4f5d157eb8 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -353,11 +353,16 @@ "submit_task_check": { "title": "Submit Task Check", "type": "string" + }, + "admin_check": { + "title": "Admin Check", + "type": "string" } }, "required": [ "tiled_service_account_check", - "submit_task_check" + "submit_task_check", + "admin_check" ], "title": "OpaConfig", "type": "object", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 5808acc540..6083f77e5e 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -757,9 +757,14 @@ "type": "object", "required": [ "tiled_service_account_check", - "submit_task_check" + "submit_task_check", + "admin_check" ], "properties": { + "admin_check": { + "title": "Admin Check", + "type": "string" + }, "audience": { "title": "Audience", "default": "account", diff --git a/src/blueapi/config.py b/src/blueapi/config.py index a64491090e..a19e30c7bb 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -301,6 +301,7 @@ class OpaConfig(BlueapiBaseModel): audience: str = "account" tiled_service_account_check: str submit_task_check: str + admin_check: str class ApplicationConfig(BlueapiBaseModel): diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index 9d9c96ddd5..c84a7d1aa5 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -75,6 +75,9 @@ async def require_submit_task(self, instrument_session: str, token: str): ): raise HTTPException(status_code=status.HTTP_403_UNORTHORIZED) + async def is_admin(self, token: str) -> bool: + return await self._call_opa(self._conf.admin_check, {"token": token}) + class OpaUserClient: client: OpaClient @@ -88,6 +91,9 @@ async def can_submit_task(self, task: TaskRequest): LOGGER.info("Checking permissions to run task") await self.client.require_submit_task(task.instrument_session, self.token) + async def admin(self) -> bool: + return await self.client.is_admin(self.token) + async def validate_tiled_config( tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None @@ -116,6 +122,7 @@ async def opa( return OpaUserClient(opa, token) return None + async def submit_permission( opa: Annotated[OpaUserClient, Depends(opa)], task_request: TaskRequest, diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 747e944d55..f3d3d37fdc 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -342,6 +342,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): "audience": "account", "tiled_service_account_check": "v1/tiled_service_account", "submit_task_check": "v1/submit_task", + "admin_check": "v1/admin_check", }, }, { From 3757ac58426fda3bbc49a4a66bff26ed72a12348 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 26 May 2026 15:45:58 +0100 Subject: [PATCH 21/40] Handle missing opa and fix tests --- src/blueapi/service/authorization.py | 3 +- src/blueapi/service/main.py | 12 ++++++-- tests/unit_tests/service/test_rest_api.py | 35 +++++++++++++++++++---- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index c84a7d1aa5..c669b627c0 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -127,4 +127,5 @@ async def submit_permission( opa: Annotated[OpaUserClient, Depends(opa)], task_request: TaskRequest, ): - await opa.can_submit_task(task_request) + if opa: + await opa.can_submit_task(task_request) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 22401afe39..6e2bd2f884 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -153,11 +153,14 @@ def get_app(config: ApplicationConfig): def access_task_permission( - opa: Annotated[OpaUserClient, Depends(opa)], + opa: Annotated[OpaUserClient | None, Depends(opa)], request: Request, task_id: str, runner: Annotated[WorkerDispatcher, Depends(_runner)], ): + if not opa: + return + access_token: dict[str, Any] | None = getattr( request.state, "decoded_access_token", None ) @@ -558,7 +561,12 @@ def set_state( ) user = access_token.get("fedid") if access_token else None - if not opa.admin() and active and active.task.metadata.get("user") != user: + if ( + opa + and not opa.admin() + and active + and active.task.metadata.get("user") != user + ): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) if new_state == WorkerState.PAUSED: diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index bf0a6a9977..a2248e7985 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -574,7 +574,12 @@ def test_get_state(mock_runner: Mock, client: TestClient): def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -588,7 +593,12 @@ def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -602,7 +612,12 @@ def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -619,7 +634,12 @@ def test_set_state_running_to_stopping_including_reason( current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", @@ -635,7 +655,12 @@ def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - mock_runner.run.side_effect = [current_state, TransitionError(), final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + TransitionError(), + final_state, + ] response = client.put( "/worker/state", From b4c61a73b608bb3821c27de947228dd600712f01 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 26 May 2026 15:46:43 +0100 Subject: [PATCH 22/40] Remove old admin method --- src/blueapi/service/authentication.py | 3 -- src/blueapi/service/main.py | 42 +++++++-------------------- 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index 156540bb02..64dfc30040 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -286,9 +286,6 @@ def unchecked_bearer_token(req: Request) -> str | None: return None return param.strip() - def admin(self): - return False - UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)] diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 6e2bd2f884..06c8765c66 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -154,31 +154,21 @@ def get_app(config: ApplicationConfig): def access_task_permission( opa: Annotated[OpaUserClient | None, Depends(opa)], - request: Request, task_id: str, + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], ): - if not opa: - return - - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) task = runner.run(interface.get_task_by_id, task_id) - if not opa.admin() and ( - access_token - and task - and access_token.get("fedid") != task.task.metadata.get("user") - ): + if opa and not opa.admin() and (task and fedid != task.task.metadata.get("user")): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # start_task_permission is used when there is WorkerTask def start_task_permission( - opa: Annotated[OpaUserClient, Depends(opa)], - request: Request, task: WorkerTask, + opa: Annotated[OpaUserClient, Depends(opa)], + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], ): if not task.task_id: @@ -186,7 +176,7 @@ def start_task_permission( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="No task id provided", ) - access_task_permission(opa, request, task.task_id, runner) + access_task_permission(opa, task.task_id, fedid, runner) async def on_key_error_404(_: Request, __: Exception): @@ -316,12 +306,11 @@ def submit_task( task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], - user: Fedid, + fedid: Fedid, ) -> TaskResponse: """Submit a task to the worker.""" try: - user = user or "Unknown" - task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + task_id: str = runner.run(interface.submit_task, task_request, {"user": fedid}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: @@ -371,7 +360,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) def get_tasks( - request: Request, + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: @@ -393,12 +382,7 @@ def get_tasks( else: tasks = runner.run(interface.get_tasks) - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) - user = access_token.get("fedid") if access_token else None - - tasks = [t for t in tasks if t.task.metadata.get("user") == user] + tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] return TasksListResponse(tasks=tasks) @@ -524,9 +508,9 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt ) @start_as_current_span(TRACER, "state_change_request.new_state") def set_state( - request: Request, state_change_request: StateChangeRequest, response: Response, + fedid: Fedid, opa: Annotated[OpaUserClient, Depends(opa)], # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], @@ -556,16 +540,12 @@ def set_state( and new_state in _ALLOWED_TRANSITIONS[current_state] ): active = runner.run(interface.get_active_task) - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) - user = access_token.get("fedid") if access_token else None if ( opa and not opa.admin() and active - and active.task.metadata.get("user") != user + and active.task.metadata.get("user") != fedid ): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) From 38fc6381ef4a5e296629c3dd87c2337f15e43e3d Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 28 May 2026 17:17:51 +0100 Subject: [PATCH 23/40] Use starlette statuses directly --- src/blueapi/service/authorization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index c669b627c0..edcec5cbd6 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -6,7 +6,7 @@ from aiohttp import ClientSession from fastapi import Depends, HTTPException, Request -from starlette import status +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authentication import TiledAuth, unchecked_bearer_token @@ -73,7 +73,7 @@ async def require_submit_task(self, instrument_session: str, token: str): "visit": int(match["visit"]), }, ): - raise HTTPException(status_code=status.HTTP_403_UNORTHORIZED) + raise HTTPException(status_code=HTTP_403_FORBIDDEN) async def is_admin(self, token: str) -> bool: return await self._call_opa(self._conf.admin_check, {"token": token}) @@ -118,13 +118,13 @@ async def opa( if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): if not token: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) + raise HTTPException(status_code=HTTP_401_UNAUTHORIZED) return OpaUserClient(opa, token) return None async def submit_permission( - opa: Annotated[OpaUserClient, Depends(opa)], + opa: Annotated[OpaUserClient | None, Depends(opa)], task_request: TaskRequest, ): if opa: From b14a25647e739a5f190766ee2d3dfa99a2c663cf Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Thu, 28 May 2026 17:27:47 +0100 Subject: [PATCH 24/40] test task submission authz --- .../unit_tests/service/test_authorization.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index 37c1d7e3f9..65f5c44a61 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -8,9 +8,12 @@ from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authorization import ( OpaClient, + OpaUserClient, opa, + submit_permission, validate_tiled_config, ) +from blueapi.service.model import TaskRequest # Reusable client patch decorator patch_client_session = patch( @@ -25,6 +28,7 @@ def opa_config() -> OpaConfig: return OpaConfig( root=HttpUrl("http://auth.example.com"), submit_task_check="/auth/submit", + admin_check="/auth/admin", tiled_service_account_check="/auth/tiled", ) @@ -108,6 +112,105 @@ async def test_opa_adds_input_fields(session: MagicMock, opa_config: OpaConfig): ) +@pytest.mark.parametrize( + "result,context", + [(True, nullcontext()), (False, pytest.raises(HTTPException, match="403"))], +) +@patch_client_session +async def test_require_submit_task( + session: MagicMock, + opa_config: OpaConfig, + result: bool, + context: AbstractContextManager, +): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + + client = OpaClient(instrument="p99", config=opa_config) + + session.assert_called_once_with(base_url="http://auth.example.com/") + with context: + await client.require_submit_task( + instrument_session="cm12345-1", token="foo_bar" + ) + + session().post.assert_called_once_with( + "/auth/submit", + json={ + "input": { + "token": "foo_bar", + "beamline": "p99", + "audience": "account", + "visit": 1, + "proposal": 12345, + } + }, + ) + + +@patch_client_session +async def test_opa_require_submit_task_invalid_session( + session: MagicMock, opa_config: OpaConfig +): + client = OpaClient(instrument="p45", config=opa_config) + + with pytest.raises(ValueError): + await client.require_submit_task( + instrument_session="not a session", token="foo_bar" + ) + + +@pytest.mark.parametrize("result", [True, False]) +@patch_client_session +async def test_opa_is_admin(session: MagicMock, opa_config: OpaConfig, result: bool): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + client = OpaClient(instrument="p45", config=opa_config) + + admin = await client.is_admin("foo_bar") + + assert admin == result + + session().post.assert_called_once_with( + "/auth/admin", + json={"input": {"token": "foo_bar", "beamline": "p45", "audience": "account"}}, + ) + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_user_client_can_submit_task(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.require_submit_task = AsyncMock(side_effect=result) + + user_client = OpaUserClient(opa, "foo_bar") + + with context: + await user_client.can_submit_task( + TaskRequest(name="foo", params={}, instrument_session="cm12345-1") + ) + opa.require_submit_task.assert_called_once_with("cm12345-1", "foo_bar") + + +@pytest.mark.parametrize("result", [True, False]) +async def test_user_client_admin(result: bool): + opa = MagicMock(spec=OpaUserClient) + opa.is_admin = AsyncMock(return_value=result) + + user_client = OpaUserClient(opa, "foo_bar") + + admin = await user_client.admin() + + assert admin == result + + async def test_validate_tiled_config(): opa = MagicMock(spec=OpaClient) tiled = ServiceAccount() @@ -177,3 +280,21 @@ async def test_opa_dependency_without_authz(token): del request.app.state.authz user_client = await opa(request, token) assert user_client is None + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_submit_permission_dependency(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.can_submit_task.side_effect = result + with context: + await submit_permission(opa, Mock()) + + +async def test_submit_permission_dependency_without_opa(): + assert await submit_permission(None, Mock()) is None From 072c36c831fb6af58791d284b9c0c2743d8b369a Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 15:19:28 +0100 Subject: [PATCH 25/40] Use _config instead of _conf --- src/blueapi/service/authorization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index edcec5cbd6..5f1e2e9806 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -66,7 +66,7 @@ async def require_submit_task(self, instrument_session: str, token: str): raise ValueError("Invalid instrument session") if not await self._call_opa( - self._conf.submit_task_check, + self._config.submit_task_check, { "token": token, "proposal": int(match["proposal"]), @@ -76,7 +76,7 @@ async def require_submit_task(self, instrument_session: str, token: str): raise HTTPException(status_code=HTTP_403_FORBIDDEN) async def is_admin(self, token: str) -> bool: - return await self._call_opa(self._conf.admin_check, {"token": token}) + return await self._call_opa(self._config.admin_check, {"token": token}) class OpaUserClient: From f193a5b3d10d0f8904e4ff8bcfbbccaf6f6bde1e Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 16:36:19 +0100 Subject: [PATCH 26/40] Re-use instrument session regex --- src/blueapi/service/authorization.py | 3 +-- src/blueapi/utils/__init__.py | 3 +++ src/blueapi/utils/serialization.py | 10 +++------- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index 5f1e2e9806..e54d0daa93 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -1,5 +1,4 @@ import logging -import re from collections.abc import Mapping from contextlib import AbstractAsyncContextManager, aclosing, nullcontext from typing import Annotated, Any, Self, cast @@ -11,9 +10,9 @@ from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount from blueapi.service.authentication import TiledAuth, unchecked_bearer_token from blueapi.service.model import TaskRequest +from blueapi.utils import INSTRUMENT_SESSION_RE LOGGER = logging.getLogger(__name__) -INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") class OpaClient: diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index 4b2e41f2c1..bf96b7009d 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,3 +1,4 @@ +import re from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar @@ -31,6 +32,8 @@ Args = ParamSpec("Args") Return = TypeVar("Return") +INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") + def deprecated(alternative): from warnings import warn diff --git a/src/blueapi/utils/serialization.py b/src/blueapi/utils/serialization.py index deee82b1e1..8918cf8823 100644 --- a/src/blueapi/utils/serialization.py +++ b/src/blueapi/utils/serialization.py @@ -1,9 +1,10 @@ import json -import re from typing import Any from pydantic import BaseModel +from blueapi import utils + def serialize(obj: Any) -> Any: """ @@ -28,13 +29,8 @@ def serialize(obj: Any) -> Any: return obj -_INSTRUMENT_SESSION_AUTHZ_REGEX: re.Pattern = re.compile( - r"^[a-zA-Z]{2}(?P\d+)-(?P\d+)$" -) - - def access_blob(instrument_session: str, beamline: str) -> str: - m = _INSTRUMENT_SESSION_AUTHZ_REGEX.match(instrument_session) + m = utils.INSTRUMENT_SESSION_RE.match(instrument_session) if m is None: raise ValueError( "Unable to extract proposal and visit from " From a372c174b401cbe0dd878e27507e04effe674292 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 17:04:21 +0100 Subject: [PATCH 27/40] remove task access check --- src/blueapi/service/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 06c8765c66..aacba8d730 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -512,7 +512,6 @@ def set_state( response: Response, fedid: Fedid, opa: Annotated[OpaUserClient, Depends(opa)], - # _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ From cfc22d3d92ed3dcf3a20162a9dc3001cad31942b Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 17:06:25 +0100 Subject: [PATCH 28/40] Add match to raises check --- tests/unit_tests/service/test_authorization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py index 65f5c44a61..a2e602f211 100644 --- a/tests/unit_tests/service/test_authorization.py +++ b/tests/unit_tests/service/test_authorization.py @@ -155,7 +155,7 @@ async def test_opa_require_submit_task_invalid_session( ): client = OpaClient(instrument="p45", config=opa_config) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid instrument session"): await client.require_submit_task( instrument_session="not a session", token="foo_bar" ) From 340ef6542e4cffac705d3cdd9a8ce497b89ef75d Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 17:12:30 +0100 Subject: [PATCH 29/40] Add exception detail --- src/blueapi/service/authorization.py | 8 ++++++-- src/blueapi/service/main.py | 5 ++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py index e54d0daa93..f9008138a8 100644 --- a/src/blueapi/service/authorization.py +++ b/src/blueapi/service/authorization.py @@ -72,7 +72,9 @@ async def require_submit_task(self, instrument_session: str, token: str): "visit": int(match["visit"]), }, ): - raise HTTPException(status_code=HTTP_403_FORBIDDEN) + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authorized to submit task" + ) async def is_admin(self, token: str) -> bool: return await self._call_opa(self._config.admin_check, {"token": token}) @@ -117,7 +119,9 @@ async def opa( if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): if not token: - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED) + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail="Authentication missing" + ) return OpaUserClient(opa, token) return None diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index aacba8d730..073e6318f4 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -546,7 +546,10 @@ def set_state( and active and active.task.metadata.get("user") != fedid ): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to set worker state", + ) if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) From 8ad021f9d50664ab7d5f6439dedd7754ebaf1039 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 17:15:15 +0100 Subject: [PATCH 30/40] Let admin see all tasks --- src/blueapi/service/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 073e6318f4..680c36881a 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -362,6 +362,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: def get_tasks( fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], + opa: Annotated[OpaUserClient, Depends(opa)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -382,7 +383,8 @@ def get_tasks( else: tasks = runner.run(interface.get_tasks) - tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] + if opa and not opa.admin(): + tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] return TasksListResponse(tasks=tasks) From 6aa6121b81e07363bacce873aa5ef04d40e870b5 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Fri, 5 Jun 2026 17:15:55 +0100 Subject: [PATCH 31/40] Start of api authz tests --- tests/unit_tests/service/test_rest_api.py | 56 ++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index a2248e7985..d1ba177e8f 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -14,9 +14,15 @@ from pydantic_core import InitErrorDetails from super_state_machine.errors import TransitionError -from blueapi.config import ApplicationConfig, CORSConfig, OIDCConfig, RestConfig +from blueapi.config import ( + ApplicationConfig, + CORSConfig, + OIDCConfig, + RestConfig, +) from blueapi.core.bluesky_types import Plan from blueapi.service import main +from blueapi.service.authorization import OpaUserClient, opa from blueapi.service.interface import ( cancel_active_task, get_device, @@ -54,6 +60,11 @@ def mock_runner() -> Mock: return Mock(spec=WorkerDispatcher) +@pytest.fixture +def mock_opa_client() -> Mock: + return Mock(spec=OpaUserClient) + + @pytest.fixture def client(mock_runner: Mock) -> Iterator[TestClient]: with patch("blueapi.service.interface.worker"): @@ -79,6 +90,27 @@ def client_with_auth( main.teardown_runner() +@pytest.fixture +def access_token(valid_token_with_jwt: dict[str, Any]) -> str: + return valid_token_with_jwt["access_token"] + + +@pytest.fixture +def client_with_opa( + mock_runner: Mock, + oidc_config: OIDCConfig, + mock_opa_client: Mock, + mock_authn_server, +): + with patch("blueapi.service.interface.worker"): + main.setup_runner(runner=mock_runner) + app = main.get_app(ApplicationConfig(oidc=oidc_config)) + app.dependency_overrides[opa] = lambda: mock_opa_client + client = TestClient(app) + yield client + main.teardown_runner() + + @pytest.fixture def rest_config_with_cors() -> RestConfig: cors_config = CORSConfig( @@ -416,6 +448,28 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST +def test_get_tasks_filters_by_user( + mock_runner: Mock, + client_with_opa: TestClient, + access_token: str, + mock_opa_client: Mock, +): + + print("Start of test") + mock_runner.run.return_value = [ + TrackableTask(task_id="foo", task=Task(name="f1", metadata={"user": "jd1"})), + TrackableTask(task_id="bar", task=Task(name="f2", metadata={"user": "jd2"})), + ] + print(f"in test: {mock_opa_client=}") + mock_opa_client.admin.return_value = False + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + tasks = client_with_opa.get("/tasks").json().get("tasks") + print(tasks) + + assert len(tasks) == 1 + assert tasks[0]["task_id"] == "foo" + + def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) mock_runner.run.return_value = task_id From 824b4e852d6e70be46ebf68cca2d29b56c110076 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 8 Jun 2026 14:30:20 +0100 Subject: [PATCH 32/40] Make get_tasks async to access authz check --- src/blueapi/service/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 680c36881a..529cf0c870 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -359,7 +359,7 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router_v1.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) -def get_tasks( +async def get_tasks( fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], opa: Annotated[OpaUserClient, Depends(opa)], @@ -383,7 +383,7 @@ def get_tasks( else: tasks = runner.run(interface.get_tasks) - if opa and not opa.admin(): + if opa and not await opa.admin(): tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] return TasksListResponse(tasks=tasks) From e6a6917dcae6514291c4d4f6307378dd23ae1444 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 8 Jun 2026 14:45:41 +0100 Subject: [PATCH 33/40] Parametrise filter test to check with and without admin --- tests/unit_tests/service/test_rest_api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index d1ba177e8f..36c2d1b429 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -448,11 +448,14 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST +@pytest.mark.parametrize("admin,task_ids", [(True, ["foo", "bar"]), (False, ["foo"])]) def test_get_tasks_filters_by_user( mock_runner: Mock, client_with_opa: TestClient, access_token: str, mock_opa_client: Mock, + admin: bool, + task_ids: list[str], ): print("Start of test") @@ -461,13 +464,12 @@ def test_get_tasks_filters_by_user( TrackableTask(task_id="bar", task=Task(name="f2", metadata={"user": "jd2"})), ] print(f"in test: {mock_opa_client=}") - mock_opa_client.admin.return_value = False + mock_opa_client.admin.return_value = admin client_with_opa.headers["Authorization"] = f"Bearer {access_token}" tasks = client_with_opa.get("/tasks").json().get("tasks") print(tasks) - assert len(tasks) == 1 - assert tasks[0]["task_id"] == "foo" + assert [t["task_id"] for t in tasks] == task_ids def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: From 0518099fdaedb79d090d796cd69071bd52fc71cb Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 8 Jun 2026 15:17:11 +0100 Subject: [PATCH 34/40] Add test for deleting tasks --- src/blueapi/service/main.py | 12 ++++++++---- tests/unit_tests/service/test_rest_api.py | 24 ++++++++++++++++++++++- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 529cf0c870..5b8ca70e42 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -152,7 +152,7 @@ def get_app(config: ApplicationConfig): return app -def access_task_permission( +async def access_task_permission( opa: Annotated[OpaUserClient | None, Depends(opa)], task_id: str, fedid: Fedid, @@ -160,12 +160,16 @@ def access_task_permission( ): task = runner.run(interface.get_task_by_id, task_id) - if opa and not opa.admin() and (task and fedid != task.task.metadata.get("user")): + if ( + opa + and not await opa.admin() + and (task and fedid != task.task.metadata.get("user")) + ): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) # start_task_permission is used when there is WorkerTask -def start_task_permission( +async def start_task_permission( task: WorkerTask, opa: Annotated[OpaUserClient, Depends(opa)], fedid: Fedid, @@ -176,7 +180,7 @@ def start_task_permission( status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, detail="No task id provided", ) - access_task_permission(opa, task.task_id, fedid, runner) + await access_task_permission(opa, task.task_id, fedid, runner) async def on_key_error_404(_: Request, __: Exception): diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 36c2d1b429..9ed15ff4c6 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -21,7 +21,7 @@ RestConfig, ) from blueapi.core.bluesky_types import Plan -from blueapi.service import main +from blueapi.service import interface, main from blueapi.service.authorization import OpaUserClient, opa from blueapi.service.interface import ( cancel_active_task, @@ -479,6 +479,28 @@ def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: assert response.json() == {"task_id": f"{task_id}"} +def test_cant_delete_other_users_task( + mock_runner: Mock, + client_with_opa: TestClient, + access_token: str, + mock_opa_client: Mock, +): + mock_opa_client.admin.return_value = False + mock_runner.run.side_effect = lambda mth, *args: { + interface.get_task_by_id: TrackableTask( + task_id="bar", task=Task(name="t2", metadata={"user": "jd2"}) + ), + }[mth] + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + resp = client_with_opa.delete("/tasks/bar") + + # 404 to obfuscate whether task exists when inaccessible + assert resp.status_code == 404 + + mock_runner.run.assert_called_once() + + def test_set_active_task(client: TestClient) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) From fa02ed501987745b7b08df1630983c8d661cb8dc Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Mon, 8 Jun 2026 16:42:20 +0100 Subject: [PATCH 35/40] Add test for submit without permission --- tests/unit_tests/service/test_rest_api.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 9ed15ff4c6..90b04f04c6 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -7,7 +7,7 @@ import jwt import pytest from bluesky.protocols import Stoppable -from fastapi import status +from fastapi import HTTPException, status from fastapi.testclient import TestClient from httpx import Headers from pydantic import BaseModel, ValidationError @@ -287,6 +287,23 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: assert response.json() == {"task_id": task_id} +def test_submit_task_requires_permission( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, +): + task = TaskRequest(name="sleep", params={"time": 2}, instrument_session="cm12345-2") + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + mock_opa_client.can_submit_task.side_effect = HTTPException(status_code=403) + mock_runner.run.side_effect = RuntimeError("Task should not be submitted") + + resp = client_with_opa.post("/tasks", json=task.model_dump()) + + assert resp.status_code == 403 + mock_runner.run.assert_not_called() + + def test_create_task_inserts_auth_metadata( mock_runner: Mock, client_with_auth: TestClient, From a0161c395b8984d538d0e1735730df0b173d3412 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 9 Jun 2026 09:27:58 +0100 Subject: [PATCH 36/40] Test setting other user's task active --- tests/unit_tests/service/test_rest_api.py | 29 +++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 90b04f04c6..455471fda5 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -566,6 +566,35 @@ def test_set_active_task_worker_already_running( assert response.json() == {"detail": "Worker already active"} +@pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) +def test_set_other_users_task_active( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + + task_id = "foo" + task = WorkerTask(task_id=task_id) + mock_opa_client.admin.return_value = admin + + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + mock_runner.run.side_effect = lambda mth, *a, **kw: { + interface.get_task_by_id: TrackableTask( + task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + ), + interface.get_active_task: None, + interface.begin_task: None, + }[mth] + + resp = client_with_opa.put("/worker/task", json=task.model_dump()) + + assert resp.status_code == status + + def test_get_task(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( From 74fffa9a1b611a2cadc4a43b86c832ff54a6d0a2 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 9 Jun 2026 09:44:14 +0100 Subject: [PATCH 37/40] Test getting other users task --- tests/unit_tests/service/test_rest_api.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 455471fda5..697666f41b 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -627,6 +627,25 @@ def test_get_task(mock_runner: Mock, client: TestClient): } +@pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) +def test_get_other_users_task( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + mock_runner.run.return_value = TrackableTask( + task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + ) + mock_opa_client.admin.return_value = admin + + resp = client_with_opa.get("/tasks/foo") + assert resp.status_code == status + + def test_get_all_tasks(mock_runner: Mock, client: TestClient): task_id = str(uuid.uuid4()) tasks = [ From 50c2677b39697744cde08f6754b4c57ba9274af3 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 9 Jun 2026 11:05:06 +0100 Subject: [PATCH 38/40] Add tests for set state --- src/blueapi/service/main.py | 4 ++-- tests/unit_tests/service/test_rest_api.py | 29 +++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 5b8ca70e42..90fe28aaf6 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -513,7 +513,7 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt tags=[Tag.TASK], ) @start_as_current_span(TRACER, "state_change_request.new_state") -def set_state( +async def set_state( state_change_request: StateChangeRequest, response: Response, fedid: Fedid, @@ -548,7 +548,7 @@ def set_state( if ( opa - and not opa.admin() + and not await opa.admin() and active and active.task.metadata.get("user") != fedid ): diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 697666f41b..5d0de75481 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -830,6 +830,35 @@ def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): assert response.json() == final_state +@pytest.mark.parametrize("admin,status", [(True, 202), (False, 403)]) +def test_set_state_of_other_users_task( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + + mock_opa_client.admin.return_value = admin + mock_runner.run.side_effect = lambda mth, *a, **kw: { + interface.get_active_task: TrackableTask( + task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + ), + interface.get_worker_state: WorkerState.RUNNING, + interface.cancel_active_task: WorkerState.ABORTING, + }[mth] + + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + resp = client_with_opa.put( + "/worker/state", + json=StateChangeRequest(new_state=WorkerState.ABORTING).model_dump(), + ) + + assert resp.status_code == status + + def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None: environment_id = uuid.uuid4() mock_runner.state = EnvironmentResponse( From 09121588cba4bc0142bc3a76ae803e39e751f304 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Wed, 10 Jun 2026 10:27:17 +0100 Subject: [PATCH 39/40] Remove print debugging --- tests/unit_tests/service/test_rest_api.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 5d0de75481..a2f46e668e 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -475,16 +475,13 @@ def test_get_tasks_filters_by_user( task_ids: list[str], ): - print("Start of test") mock_runner.run.return_value = [ TrackableTask(task_id="foo", task=Task(name="f1", metadata={"user": "jd1"})), TrackableTask(task_id="bar", task=Task(name="f2", metadata={"user": "jd2"})), ] - print(f"in test: {mock_opa_client=}") mock_opa_client.admin.return_value = admin client_with_opa.headers["Authorization"] = f"Bearer {access_token}" tasks = client_with_opa.get("/tasks").json().get("tasks") - print(tasks) assert [t["task_id"] for t in tasks] == task_ids From b65f9b54dc44c119db3827867798bd42cf9d2506 Mon Sep 17 00:00:00 2001 From: Peter Holloway Date: Tue, 9 Jun 2026 18:13:05 +0100 Subject: [PATCH 40/40] Refactor mock_runner to make it clear what is being mocked --- tests/unit_tests/service/test_rest_api.py | 347 ++++++++++++++-------- 1 file changed, 219 insertions(+), 128 deletions(-) diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index a2f46e668e..1b15c77907 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -2,9 +2,8 @@ from collections.abc import Iterator from dataclasses import dataclass from typing import Any -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, PropertyMock, patch -import jwt import pytest from bluesky.protocols import Stoppable from fastapi import HTTPException, status @@ -27,7 +26,6 @@ cancel_active_task, get_device, get_plan, - pause_worker, resume_worker, submit_task, ) @@ -43,7 +41,7 @@ WorkerTask, ) from blueapi.service.runner import WorkerDispatcher -from blueapi.worker.event import WorkerState +from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask @@ -56,8 +54,78 @@ class MockCountModel(BaseModel): ... @pytest.fixture -def mock_runner() -> Mock: - return Mock(spec=WorkerDispatcher) +def mock_runner_data() -> Mock: + return Mock() + + +@pytest.fixture +def mock_runner(mock_runner_data: Mock) -> Mock: + runner = Mock(spec=WorkerDispatcher) + + def run(method, *args, **kwargs): + match method: + case interface.get_active_task: + running = run(interface.get_tasks_by_status, TaskStatusEnum.RUNNING) + return running[0] if running else None + case interface.get_task_by_id: + return {t.task_id: t for t in mock_runner_data.tasks}.get( + kwargs.get("task_id") or args[0] + ) + case interface.get_tasks: + return mock_runner_data.tasks + case interface.get_tasks_by_status: + status = kwargs.get("status") or args[0] + match status: + case TaskStatusEnum.RUNNING: + return [ + t + for t in mock_runner_data.tasks + if not t.is_pending and not t.is_complete + ] + case TaskStatusEnum.PENDING: + return [t for t in mock_runner_data.tasks if t.is_pending] + case TaskStatusEnum.COMPLETE: + return [t for t in mock_runner_data.tasks if t.is_complete] + case _: + return [] + + case interface.get_plans: + return mock_runner_data.plans + case interface.get_plan: + name = kwargs.get("name") or args[0] + plans = [p for p in mock_runner_data.plans if p.name == name] + if plans: + return plans[0] + raise KeyError(name) + case interface.get_devices: + return mock_runner_data.devices + case interface.get_device: + name = kwargs.get("name") or args[0] + devices = [d for d in mock_runner_data.devices if d.name == name] + if devices: + return devices[0] + raise KeyError(name) + case interface.get_oidc_config: + return mock_runner_data.oidc_config + case interface.get_worker_state: + return mock_runner_data.state + case interface.get_python_env: + return mock_runner_data.python_environment + case interface.submit_task: + return mock_runner_data.submit_task(*args, **kwargs) + case interface.begin_task: + return mock_runner_data.begin_task(*args, **kwargs) + case interface.clear_task: + return mock_runner_data.clear_task(*args, **kwargs) + case interface.cancel_active_task: + return mock_runner_data.cancel_active_task(*args, **kwargs) + case interface.pause_worker: + return mock_runner_data.pause_worker(*args, **kwargs) + case _: + raise ValueError("Unsupported method: " + method.__name__) + + runner.run.side_effect = run + return runner @pytest.fixture @@ -141,13 +209,13 @@ def stop(self, success: bool = True): def test_rest_config_with_cors_gets_plan( client_with_cors: TestClient, - mock_runner: Mock, + mock_runner_data: Mock, ): class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - mock_runner.run.return_value = [PlanModel.from_plan(plan)] + mock_runner_data.plans = [PlanModel.from_plan(plan)] response_get = client_with_cors.get("/plans") assert response_get.status_code == status.HTTP_200_OK @@ -155,7 +223,7 @@ class MyModel(BaseModel): def test_rest_config_with_cors( client_with_cors: TestClient, - mock_runner: Mock, + mock_runner_data: Mock, ): task = TaskRequest( name="my-plan", @@ -163,7 +231,7 @@ def test_rest_config_with_cors( instrument_session=FAKE_INSTRUMENT_SESSION, ) task_id = "f8424be3-203c-494e-b22f-219933b4fa67" - mock_runner.run.side_effect = [task_id] + mock_runner_data.submit_task.return_value = task_id # Allowed method response_post = client_with_cors.post( @@ -175,12 +243,12 @@ def test_rest_config_with_cors( assert response_post.headers["content-type"] == "application/json" -def test_get_plans(mock_runner: Mock, client: TestClient) -> None: +def test_get_plans(mock_runner_data: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - mock_runner.run.return_value = [PlanModel.from_plan(plan)] + mock_runner_data.plans = [PlanModel.from_plan(plan)] response = client.get("/plans") @@ -201,12 +269,14 @@ class MyModel(BaseModel): } -def test_get_plan_by_name(mock_runner: Mock, client: TestClient) -> None: +def test_get_plan_by_name( + mock_runner: Mock, mock_runner_data: Mock, client: TestClient +) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - mock_runner.run.return_value = PlanModel.from_plan(plan) + mock_runner_data.plans = [PlanModel.from_plan(plan)] response = client.get("/plans/my-plan") @@ -224,17 +294,19 @@ class MyModel(BaseModel): } -def test_get_non_existent_plan_by_name(mock_runner: Mock, client: TestClient) -> None: - mock_runner.run.side_effect = KeyError("my-plan") +def test_get_non_existent_plan_by_name( + mock_runner_data: Mock, client: TestClient +) -> None: + mock_runner_data.plans = [] response = client.get("/plans/my-plan") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -def test_get_devices(mock_runner: Mock, client: TestClient) -> None: +def test_get_devices(mock_runner_data: Mock, client: TestClient) -> None: device = MinimalDevice("my-device") - mock_runner.run.return_value = [DeviceModel.from_device(device)] + mock_runner_data.devices = [DeviceModel.from_device(device)] response = client.get("/devices") @@ -249,10 +321,12 @@ def test_get_devices(mock_runner: Mock, client: TestClient) -> None: } -def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: +def test_get_device_by_name( + mock_runner: Mock, mock_runner_data: Mock, client: TestClient +) -> None: device = MinimalDevice("my-device") - mock_runner.run.return_value = DeviceModel.from_device(device) + mock_runner_data.devices = [DeviceModel.from_device(device)] response = client.get("/devices/my-device") mock_runner.run.assert_called_once_with(get_device, "my-device") @@ -263,15 +337,15 @@ def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: } -def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None: - mock_runner.run.side_effect = KeyError("my-device") +def test_get_non_existent_device_by_name(mock_runner_data: Mock, client: TestClient): + mock_runner_data.devices = [] response = client.get("/devices/my-device") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -def test_create_task(mock_runner: Mock, client: TestClient) -> None: +def test_create_task(mock_runner: Mock, mock_runner_data: Mock, client: TestClient): task = TaskRequest( name="count", params={"detectors": ["x"]}, @@ -279,7 +353,7 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: ) task_id = str(uuid.uuid4()) - mock_runner.run.side_effect = [task_id] + mock_runner_data.submit_task.return_value = task_id response = client.post("/tasks", json=task.model_dump()) @@ -306,6 +380,7 @@ def test_submit_task_requires_permission( def test_create_task_inserts_auth_metadata( mock_runner: Mock, + mock_runner_data: Mock, client_with_auth: TestClient, ) -> None: task = TaskRequest( @@ -316,8 +391,7 @@ def test_create_task_inserts_auth_metadata( client_with_auth.follow_redirects = False task_id = str(uuid.uuid4()) - # mock_runner.run.side_effect = [task_id] - mock_runner.run.return_value = [task_id] + mock_runner_data.submit_task.return_value = task_id client_with_auth.post("/tasks", json=task.model_dump()) @@ -356,8 +430,10 @@ def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> } -def test_put_plan_begins_task(client: TestClient) -> None: +def test_put_plan_begins_task(client: TestClient, mock_runner_data: Mock) -> None: task_id = "04cd9aa6-b902-414b-ae4b-49ea4200e957" + mock_runner_data.tasks = [TrackableTask(task_id=task_id, task=Task(name="foo"))] + mock_runner_data.begin_task.side_effect = lambda task, **kw: task resp = client.put("/worker/task", json={"task_id": task_id}) @@ -365,14 +441,19 @@ def test_put_plan_begins_task(client: TestClient) -> None: assert resp.json() == {"task_id": task_id} -def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> None: +def test_put_plan_fails_if_not_idle(mock_runner_data: Mock, client: TestClient) -> None: task_id_current = "260f7de3-b608-4cdc-a66c-257e95809792" task_id_new = "07e98d68-21b5-4ad7-ac34-08b2cb992d42" # Set to non idle - mock_runner.run.return_value = TrackableTask( - task=Task(name="none"), task_id=task_id_current, is_complete=False - ) + mock_runner_data.tasks = [ + TrackableTask( + task=Task(name="none"), + task_id=task_id_current, + is_pending=False, + is_complete=False, + ) + ] resp = client.put("/worker/task", json={"task_id": task_id_new}) @@ -380,8 +461,8 @@ def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> No assert resp.json() == {"detail": "Worker already active"} -def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: - tasks = [ +def test_get_tasks(mock_runner_data: Mock, client: TestClient) -> None: + mock_runner_data.tasks = [ TrackableTask(task_id="0", task=Task(name="sleep", params={"time": 0.0})), TrackableTask( task_id="1", @@ -391,8 +472,6 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: ), ] - mock_runner.run.return_value = tasks - response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK @@ -428,8 +507,8 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: } -def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: - tasks = [ +def test_get_tasks_by_status(mock_runner_data: Mock, client: TestClient) -> None: + mock_runner_data.tasks = [ TrackableTask( task_id="3", task=Task(name="third_task"), @@ -438,9 +517,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: ), ] - mock_runner.run.return_value = tasks - - response = client.get("/tasks", params={"task_status": "PENDING"}) + response = client.get("/tasks", params={"task_status": "COMPLETE"}) assert response.json() == { "tasks": [ { @@ -467,7 +544,7 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: @pytest.mark.parametrize("admin,task_ids", [(True, ["foo", "bar"]), (False, ["foo"])]) def test_get_tasks_filters_by_user( - mock_runner: Mock, + mock_runner_data: Mock, client_with_opa: TestClient, access_token: str, mock_opa_client: Mock, @@ -475,7 +552,7 @@ def test_get_tasks_filters_by_user( task_ids: list[str], ): - mock_runner.run.return_value = [ + mock_runner_data.tasks = [ TrackableTask(task_id="foo", task=Task(name="f1", metadata={"user": "jd1"})), TrackableTask(task_id="bar", task=Task(name="f2", metadata={"user": "jd2"})), ] @@ -486,38 +563,43 @@ def test_get_tasks_filters_by_user( assert [t["task_id"] for t in tasks] == task_ids -def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: +def test_delete_submitted_task(mock_runner_data: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) - mock_runner.run.return_value = task_id + mock_runner_data.tasks = [TrackableTask(task_id=task_id, task=Task(name="foo"))] + mock_runner_data.clear_task.return_value = task_id response = client.delete(f"/tasks/{task_id}") assert response.json() == {"task_id": f"{task_id}"} + mock_runner_data.clear_task.assert_called_once_with(task_id) def test_cant_delete_other_users_task( mock_runner: Mock, + mock_runner_data: Mock, client_with_opa: TestClient, access_token: str, mock_opa_client: Mock, ): mock_opa_client.admin.return_value = False - mock_runner.run.side_effect = lambda mth, *args: { - interface.get_task_by_id: TrackableTask( - task_id="bar", task=Task(name="t2", metadata={"user": "jd2"}) - ), - }[mth] + mock_runner_data.tasks = [ + TrackableTask( + task_id="bar", + is_pending=False, + task=Task(name="t2", metadata={"user": "jd2"}), + ) + ] client_with_opa.headers["Authorization"] = f"Bearer {access_token}" resp = client_with_opa.delete("/tasks/bar") # 404 to obfuscate whether task exists when inaccessible assert resp.status_code == 404 - - mock_runner.run.assert_called_once() + mock_runner_data.clear_task.assert_not_called() -def test_set_active_task(client: TestClient) -> None: +def test_set_active_task(client: TestClient, mock_runner_data: Mock) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) + mock_runner_data.tasks = [TrackableTask(task_id=task_id, task=Task(name="foo"))] response = client.put("/worker/task", json=task.model_dump()) @@ -526,17 +608,19 @@ def test_set_active_task(client: TestClient) -> None: def test_set_active_task_active_task_complete( - mock_runner: Mock, client: TestClient + mock_runner_data: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - mock_runner.run.return_value = TrackableTask( - task_id="1", - task=Task(name="a_completed_task"), - is_complete=True, - is_pending=False, - ) + mock_runner_data.tasks = [ + TrackableTask( + task_id="1", + task=Task(name="a_completed_task"), + is_complete=True, + is_pending=False, + ) + ] response = client.put("/worker/task", json=task.model_dump()) @@ -545,17 +629,19 @@ def test_set_active_task_active_task_complete( def test_set_active_task_worker_already_running( - mock_runner: Mock, client: TestClient + mock_runner_data: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - mock_runner.run.return_value = TrackableTask( - task_id="1", - task=Task(name="a_running_task"), - is_complete=False, - is_pending=False, - ) + mock_runner_data.tasks = [ + TrackableTask( + task_id="1", + task=Task(name="a_running_task"), + is_complete=False, + is_pending=False, + ) + ] response = client.put("/worker/task", json=task.model_dump()) @@ -565,7 +651,7 @@ def test_set_active_task_worker_already_running( @pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) def test_set_other_users_task_active( - mock_runner: Mock, + mock_runner_data: Mock, client_with_opa: TestClient, mock_opa_client: Mock, access_token: str, @@ -579,32 +665,31 @@ def test_set_other_users_task_active( client_with_opa.headers["Authorization"] = f"Bearer {access_token}" - mock_runner.run.side_effect = lambda mth, *a, **kw: { - interface.get_task_by_id: TrackableTask( - task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) - ), - interface.get_active_task: None, - interface.begin_task: None, - }[mth] + mock_runner_data.tasks = [ + TrackableTask(task_id=task_id, task=Task(name="foo", metadata={"user": "jd2"})) + ] resp = client_with_opa.put("/worker/task", json=task.model_dump()) + if status >= 400: + mock_runner_data.begin_task.assert_not_called() + assert resp.status_code == status -def test_get_task(mock_runner: Mock, client: TestClient): +def test_get_task(mock_runner_data: Mock, client: TestClient): task_id = str(uuid.uuid4()) - task = TrackableTask( - task_id=task_id, - task=Task( - name="third_task", - metadata={ - "foo": "bar", - }, - ), - ) - - mock_runner.run.return_value = task + mock_runner_data.tasks = [ + TrackableTask( + task_id=task_id, + task=Task( + name="third_task", + metadata={ + "foo": "bar", + }, + ), + ) + ] response = client.get(f"/tasks/{task_id}") assert response.json() == { @@ -626,7 +711,7 @@ def test_get_task(mock_runner: Mock, client: TestClient): @pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) def test_get_other_users_task( - mock_runner: Mock, + mock_runner_data: Mock, client_with_opa: TestClient, mock_opa_client: Mock, access_token: str, @@ -634,25 +719,24 @@ def test_get_other_users_task( status: int, ): client_with_opa.headers["Authorization"] = f"Bearer {access_token}" - mock_runner.run.return_value = TrackableTask( - task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) - ) + mock_runner_data.tasks = [ + TrackableTask(task_id="foo", task=Task(name="bar", metadata={"user": "jd2"})) + ] mock_opa_client.admin.return_value = admin resp = client_with_opa.get("/tasks/foo") assert resp.status_code == status -def test_get_all_tasks(mock_runner: Mock, client: TestClient): +def test_get_all_tasks(mock_runner_data: Mock, client: TestClient): task_id = str(uuid.uuid4()) - tasks = [ + mock_runner_data.tasks = [ TrackableTask( task_id=task_id, task=Task(name="third_task"), ) ] - mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK assert response.json() == { @@ -674,58 +758,60 @@ def test_get_all_tasks(mock_runner: Mock, client: TestClient): } -def test_get_task_error(mock_runner: Mock, client: TestClient): +def test_get_task_error(mock_runner_data: Mock, client: TestClient): task_id = 567 - mock_runner.run.return_value = None + mock_runner_data.tasks = [] response = client.get(f"/tasks/{task_id}") assert response.json() == {"detail": "Item not found"} -def test_get_active_task(mock_runner: Mock, client: TestClient): +def test_get_active_task(mock_runner_data: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), + is_pending=False, + is_complete=False, ) - mock_runner.run.return_value = task + mock_runner_data.tasks = [task] response = client.get("/worker/task") assert response.json() == {"task_id": f"{task_id}"} -def test_get_active_task_none(mock_runner: Mock, client: TestClient): - mock_runner.run.return_value = None +def test_get_active_task_none(mock_runner_data: Mock, client: TestClient): + mock_runner_data.tasks = [] response = client.get("/worker/task") assert response.json() == {"task_id": None} -def test_get_state(mock_runner: Mock, client: TestClient): +def test_get_state(mock_runner_data: Mock, client: TestClient): state = WorkerState.SUSPENDING - mock_runner.run.return_value = state + mock_runner_data.state = state response = client.get("/worker/state") assert response.json() == state -def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): +def test_set_state_running_to_paused(mock_runner_data: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - mock_runner.run.side_effect = [ - current_state, + type(mock_runner_data).state = PropertyMock( + side_effect=[current_state, final_state] + ) + mock_runner_data.tasks = [ TrackableTask(task_id="foobar", task=Task(name="foo")), - None, - final_state, ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - mock_runner.run.assert_any_call(pause_worker, False) + mock_runner_data.pause_worker.assert_called_once_with(False) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state @@ -829,7 +915,7 @@ def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): @pytest.mark.parametrize("admin,status", [(True, 202), (False, 403)]) def test_set_state_of_other_users_task( - mock_runner: Mock, + mock_runner_data: Mock, client_with_opa: TestClient, mock_opa_client: Mock, access_token: str, @@ -838,13 +924,15 @@ def test_set_state_of_other_users_task( ): mock_opa_client.admin.return_value = admin - mock_runner.run.side_effect = lambda mth, *a, **kw: { - interface.get_active_task: TrackableTask( - task_id="foo", task=Task(name="bar", metadata={"user": "jd2"}) + mock_runner_data.tasks = [ + TrackableTask( + task_id="foo", + is_pending=False, + task=Task(name="bar", metadata={"user": "jd2"}), ), - interface.get_worker_state: WorkerState.RUNNING, - interface.cancel_active_task: WorkerState.ABORTING, - }[mth] + ] + mock_runner_data.state = WorkerState.RUNNING + mock_runner_data.cancel_active_task.return_value = WorkerState.ABORTING client_with_opa.headers["Authorization"] = f"Bearer {access_token}" @@ -853,6 +941,8 @@ def test_set_state_of_other_users_task( json=StateChangeRequest(new_state=WorkerState.ABORTING).model_dump(), ) + if not admin: + mock_runner_data.cancel_active_task.assert_not_called() assert resp.status_code == status @@ -895,36 +985,37 @@ def test_subprocess_enabled_by_default(mp_pool_mock: MagicMock): main.teardown_runner() -def test_get_without_authentication(mock_runner: Mock, client: TestClient) -> None: - mock_runner.run.side_effect = jwt.PyJWTError - response = client.get("/devices/my-device") +def test_get_without_authentication(mock_runner: Mock, client_with_auth: TestClient): + del client_with_auth.headers["Authorization"] + response = client_with_auth.get("/devices/my-device") + mock_runner.run.assert_not_called() assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.json() == {"detail": "Not authenticated"} def test_oidc_config_not_found_when_auth_is_disabled( - mock_runner: Mock, client: TestClient + mock_runner_data: Mock, client: TestClient ): - mock_runner.run.return_value = None + mock_runner_data.oidc_config = None response = client.get("/config/oidc") assert response.status_code == status.HTTP_204_NO_CONTENT assert response.text == "" def test_get_oidc_config( - mock_runner: Mock, + mock_runner_data: Mock, oidc_config: OIDCConfig, mock_authn_server, client_with_auth: TestClient, ): - mock_runner.run.return_value = oidc_config + mock_runner_data.oidc_config = oidc_config response = client_with_auth.get("/config/oidc") assert response.status_code == status.HTTP_200_OK assert response.json() == oidc_config.model_dump() -def test_get_python_environment(mock_runner: Mock, client: TestClient): +def test_get_python_environment(mock_runner_data: Mock, client: TestClient): packages = PythonEnvironmentResponse( installed_packages=[ PackageInfo( @@ -936,7 +1027,7 @@ def test_get_python_environment(mock_runner: Mock, client: TestClient): ) ] ) - mock_runner.run.return_value = packages + mock_runner_data.python_environment = packages response = client.get("/python_environment") assert response.status_code == status.HTTP_200_OK assert response.json() == packages.model_dump() @@ -950,13 +1041,13 @@ def test_health_probe(client: TestClient): def test_logout( - mock_runner: Mock, + mock_runner_data: Mock, mock_authn_server, oidc_config: OIDCConfig, client_with_auth: TestClient, ): oidc_config.logout_redirect_endpoint = "/oauth2/sign_out/" - mock_runner.run.return_value = oidc_config + mock_runner_data.oidc_config = oidc_config client_with_auth.follow_redirects = False response = client_with_auth.get("/logout") assert response.status_code == status.HTTP_308_PERMANENT_REDIRECT @@ -979,16 +1070,16 @@ def test_docs_redirect( @pytest.mark.parametrize("has_oidc_config", [True, False]) def test_logout_when_oidc_config_invalid( has_oidc_config: bool, - mock_runner: Mock, + mock_runner_data: Mock, oidc_config: OIDCConfig, mock_authn_server, client_with_auth: TestClient, ): if has_oidc_config: oidc_config.logout_redirect_endpoint = "" - mock_runner.run.return_value = oidc_config + mock_runner_data.oidc_config = oidc_config else: - mock_runner.run.return_value = None + mock_runner_data.oidc_config = None response = client_with_auth.get("/logout") assert response.status_code == status.HTTP_205_RESET_CONTENT