diff --git a/helm/blueapi/config_schema.json b/helm/blueapi/config_schema.json index b5d0c9bf3..4f5d157eb 100644 --- a/helm/blueapi/config_schema.json +++ b/helm/blueapi/config_schema.json @@ -330,6 +330,44 @@ "type": "object", "$id": "OIDCConfig" }, + "OpaConfig": { + "additionalProperties": false, + "properties": { + "root": { + "default": "http://localhost:8181/", + "format": "uri", + "maxLength": 2083, + "minLength": 1, + "title": "Root", + "type": "string" + }, + "audience": { + "default": "account", + "title": "Audience", + "type": "string" + }, + "tiled_service_account_check": { + "title": "Tiled Service Account Check", + "type": "string" + }, + "submit_task_check": { + "title": "Submit Task Check", + "type": "string" + }, + "admin_check": { + "title": "Admin Check", + "type": "string" + } + }, + "required": [ + "tiled_service_account_check", + "submit_task_check", + "admin_check" + ], + "title": "OpaConfig", + "type": "object", + "$id": "OpaConfig" + }, "PlanSource": { "additionalProperties": false, "properties": { @@ -612,6 +650,17 @@ } ], "default": null + }, + "opa": { + "anyOf": [ + { + "$ref": "OpaConfig" + }, + { + "type": "null" + } + ], + "default": null } }, "title": "ApplicationConfig", diff --git a/helm/blueapi/values.schema.json b/helm/blueapi/values.schema.json index 74deedadb..6083f77e5 100644 --- a/helm/blueapi/values.schema.json +++ b/helm/blueapi/values.schema.json @@ -751,6 +751,44 @@ }, "additionalProperties": false }, + "OpaConfig": { + "$id": "OpaConfig", + "title": "OpaConfig", + "type": "object", + "required": [ + "tiled_service_account_check", + "submit_task_check", + "admin_check" + ], + "properties": { + "admin_check": { + "title": "Admin Check", + "type": "string" + }, + "audience": { + "title": "Audience", + "default": "account", + "type": "string" + }, + "root": { + "title": "Root", + "default": "http://localhost:8181/", + "type": "string", + "format": "uri", + "maxLength": 2083, + "minLength": 1 + }, + "submit_task_check": { + "title": "Submit Task Check", + "type": "string" + }, + "tiled_service_account_check": { + "title": "Tiled Service Account Check", + "type": "string" + } + }, + "additionalProperties": false + }, "PlanSource": { "$id": "PlanSource", "title": "PlanSource", @@ -1011,6 +1049,16 @@ } ] }, + "opa": { + "anyOf": [ + { + "$ref": "OpaConfig" + }, + { + "type": "null" + } + ] + }, "scratch": { "anyOf": [ { diff --git a/pyproject.toml b/pyproject.toml index 659779994..9f4231c76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = [ "tomlkit", "graypy>=2.1.0", "httpx>=0.28.1", + "aiohttp>=3.13.5", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 83d6d7021..a19e30c7b 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -296,6 +296,14 @@ class Tag(StrEnum): META = "Meta" +class OpaConfig(BlueapiBaseModel): + root: HttpUrl = HttpUrl("http://localhost:8181") + audience: str = "account" + tiled_service_account_check: str + submit_task_check: str + admin_check: str + + class ApplicationConfig(BlueapiBaseModel): """ Config for the worker application as a whole. Root of @@ -335,6 +343,7 @@ class ApplicationConfig(BlueapiBaseModel): oidc: OIDCConfig | None = None auth_token_path: Path | None = None numtracker: NumtrackerConfig | None = None + opa: OpaConfig | None = None def __eq__(self, other: object) -> bool: if isinstance(other, ApplicationConfig): @@ -343,6 +352,7 @@ def __eq__(self, other: object) -> bool: & (self.env == other.env) & (self.logging == other.logging) & (self.api == other.api) + & (self.opa == other.opa) ) return False diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index b107f7b2b..64dfc3004 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -6,16 +6,20 @@ import time import webbrowser from abc import ABC, abstractmethod +from collections.abc import Mapping from functools import cached_property from http import HTTPStatus from pathlib import Path -from typing import Any, cast +from typing import Annotated, Any, cast import httpx import jwt import requests +from fastapi import Depends, HTTPException, Request +from fastapi.security.utils import get_authorization_scheme_param from pydantic import TypeAdapter from requests.auth import AuthBase +from starlette.status import HTTP_401_UNAUTHORIZED from blueapi.config import OIDCConfig, ServiceAccount from blueapi.service.model import Cache @@ -272,3 +276,61 @@ def get_access_token(self): def sync_auth_flow(self, request): request.headers["Authorization"] = f"Bearer {self.get_access_token()}" yield request + + +def unchecked_bearer_token(req: Request) -> str | None: + """Get bearer token value from authorization header""" + auth = req.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(auth) + if scheme.casefold() != "bearer": + return None + return param.strip() + + +UncheckedBearerToken = Annotated[str | None, Depends(unchecked_bearer_token)] + + +def build_access_token_check(config: OIDCConfig): + """ + Create a function to validate the bearer token of requests + + The returned function should be used via fastAPI's 'Depends' mechanism to + ensure users are authenticated + """ + jwkclient = jwt.PyJWKClient(config.jwks_uri) + + def validate_bearer_token(request: Request, token: UncheckedBearerToken): + """Check that a bearer token is valid and inject into request state""" + if not token: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + signing_key = jwkclient.get_signing_key_from_jwt(token) + decoded: dict[str, Any] = jwt.decode( + token, + signing_key.key, + algorithms=config.id_token_signing_alg_values_supported, + verify=True, + audience=config.client_audience, + issuer=config.issuer, + ) + request.state.decoded_access_token = decoded + + return validate_bearer_token + + +def access_token(request: Request) -> Mapping[str, Any] | None: + """Get the decoded and verified access token of the user making the request""" + return getattr(request.state, "decoded_access_token", None) + + +def fedid( + access_token: Annotated[Mapping[str, Any] | None, Depends(access_token)], +) -> str | None: + return access_token.get("fedid") if access_token else None + + +Fedid = Annotated[str | None, Depends(fedid)] diff --git a/src/blueapi/service/authorization.py b/src/blueapi/service/authorization.py new file mode 100644 index 000000000..f9008138a --- /dev/null +++ b/src/blueapi/service/authorization.py @@ -0,0 +1,134 @@ +import logging +from collections.abc import Mapping +from contextlib import AbstractAsyncContextManager, aclosing, nullcontext +from typing import Annotated, Any, Self, cast + +from aiohttp import ClientSession +from fastapi import Depends, HTTPException, Request +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN + +from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount +from blueapi.service.authentication import TiledAuth, unchecked_bearer_token +from blueapi.service.model import TaskRequest +from blueapi.utils import INSTRUMENT_SESSION_RE + +LOGGER = logging.getLogger(__name__) + + +class OpaClient: + def __init__(self, instrument: str, config: OpaConfig): + LOGGER.info("Creating OpaClient for %s with config %s", instrument, config) + self._instrument = instrument + self._config = config + self._session = ClientSession(base_url=config.root.encoded_string()) + self._audience = config.audience + + async def aclose(self): + LOGGER.info("Closing OPA session") + await self._session.close() + + async def _call_opa(self, endpoint: str, data: Mapping[str, Any]) -> bool: + resp = await self._session.post( + endpoint, + json={ + "input": { + "beamline": self._instrument, + "audience": self._audience, + **data, + } + }, + ) + return (await resp.json())["result"] + + @classmethod + def for_config( + cls, instrument: str | None, config: OpaConfig | None + ) -> AbstractAsyncContextManager[Self | None]: + if config: + if not instrument: + raise ValueError("Instrument name is required for OPA client") + return aclosing(cls(instrument, config)) + LOGGER.info("No OPA config provided - not creating OpaClient") + return nullcontext() + + async def require_tiled_service_account(self, token: str): + if not await self._call_opa( + self._config.tiled_service_account_check, + {"token": token, "beamline": self._instrument}, + ): + raise ValueError( + f"Tiled service account is not valid for '{self._instrument}'" + ) + + async def require_submit_task(self, instrument_session: str, token: str): + if not (match := INSTRUMENT_SESSION_RE.match(instrument_session)): + raise ValueError("Invalid instrument session") + + if not await self._call_opa( + self._config.submit_task_check, + { + "token": token, + "proposal": int(match["proposal"]), + "visit": int(match["visit"]), + }, + ): + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authorized to submit task" + ) + + async def is_admin(self, token: str) -> bool: + return await self._call_opa(self._config.admin_check, {"token": token}) + + +class OpaUserClient: + client: OpaClient + token: str + + def __init__(self, client: OpaClient, token: str): + self.client = client + self.token = token + + async def can_submit_task(self, task: TaskRequest): + LOGGER.info("Checking permissions to run task") + await self.client.require_submit_task(task.instrument_session, self.token) + + async def admin(self) -> bool: + return await self.client.is_admin(self.token) + + +async def validate_tiled_config( + tiled: ServiceAccount | str | None, oidc: OIDCConfig | None, opa: OpaClient | None +): + if not isinstance(tiled, ServiceAccount): + # can't validate an API key + return + + if not opa or not oidc: + LOGGER.info("Missing OPA or OIDC configuration required to validate tiled auth") + return + + LOGGER.info("Validating tiled configuration") + tiled.token_url = oidc.token_endpoint + auth = TiledAuth(tiled) + await opa.require_tiled_service_account(auth.get_access_token()) + + +async def opa( + request: Request, token: str | None = Depends(unchecked_bearer_token) +) -> OpaUserClient | None: + + if opa := cast(OpaClient | None, getattr(request.app.state, "authz", None)): + if not token: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, detail="Authentication missing" + ) + return OpaUserClient(opa, token) + return None + + +async def submit_permission( + opa: Annotated[OpaUserClient | None, Depends(opa)], + task_request: TaskRequest, +): + if opa: + await opa.can_submit_task(task_request) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index a53c46885..90fe28aaf 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -2,7 +2,7 @@ import urllib.parse from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager -from typing import Annotated, Any +from typing import Annotated import jwt from fastapi import ( @@ -19,7 +19,6 @@ from fastapi.datastructures import Address from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse, StreamingResponse -from fastapi.security import OAuth2AuthorizationCodeBearer from observability_utils.tracing import ( add_span_attributes, get_tracer, @@ -37,9 +36,17 @@ from blueapi import __version__ from blueapi.config import ApplicationConfig, OIDCConfig, Tag from blueapi.service import interface +from blueapi.service.authentication import Fedid, build_access_token_check from blueapi.worker import TrackableTask, WorkerState from blueapi.worker.event import TaskStatusEnum +from .authorization import ( + OpaClient, + OpaUserClient, + opa, + submit_permission, + validate_tiled_config, +) from .model import ( DeviceModel, DeviceResponse, @@ -61,6 +68,7 @@ RUNNER: WorkerDispatcher | None = None LOGGER = logging.getLogger(__name__) +TRACER = get_tracer("interface") def _runner() -> WorkerDispatcher: @@ -92,8 +100,12 @@ def teardown_runner(): def lifespan(config: ApplicationConfig): @asynccontextmanager async def inner(app: FastAPI): + meta = config.env.metadata setup_runner(config) - yield + async with OpaClient.for_config(meta and meta.instrument, config.opa) as opa: + app.state.authz = opa + await validate_tiled_config(config.tiled.authentication, config.oidc, opa) + yield teardown_runner() return inner @@ -117,7 +129,7 @@ def get_app(config: ApplicationConfig): ) dependencies = [] if config.oidc: - dependencies.append(Depends(decode_access_token(config.oidc))) + dependencies.append(Depends(build_access_token_check(config.oidc))) app.swagger_ui_init_oauth = { "clientId": "NOT_SUPPORTED", } @@ -140,30 +152,35 @@ def get_app(config: ApplicationConfig): return app -def decode_access_token(config: OIDCConfig): - jwkclient = jwt.PyJWKClient(config.jwks_uri) - oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=config.authorization_endpoint, - tokenUrl=config.token_endpoint, - refreshUrl=config.token_endpoint, - ) - - def inner(request: Request, access_token: str = Depends(oauth_scheme)): - signing_key = jwkclient.get_signing_key_from_jwt(access_token) - decoded: dict[str, Any] = jwt.decode( - access_token, - signing_key.key, - algorithms=config.id_token_signing_alg_values_supported, - verify=True, - audience=config.client_audience, - issuer=config.issuer, - ) - request.state.decoded_access_token = decoded +async def access_task_permission( + opa: Annotated[OpaUserClient | None, Depends(opa)], + task_id: str, + fedid: Fedid, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + task = runner.run(interface.get_task_by_id, task_id) - return inner + if ( + opa + and not await opa.admin() + and (task and fedid != task.task.metadata.get("user")) + ): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) -TRACER = get_tracer("interface") +# start_task_permission is used when there is WorkerTask +async def start_task_permission( + task: WorkerTask, + opa: Annotated[OpaUserClient, Depends(opa)], + fedid: Fedid, + runner: Annotated[WorkerDispatcher, Depends(_runner)], +): + if not task.task_id: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail="No task id provided", + ) + await access_task_permission(opa, task.task_id, fedid, runner) async def on_key_error_404(_: Request, __: Exception): @@ -291,20 +308,13 @@ def submit_task( request: Request, response: Response, task_request: Annotated[TaskRequest, Body(..., examples=[example_task_request])], + _: Annotated[None, Depends(submit_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], + fedid: Fedid, ) -> TaskResponse: """Submit a task to the worker.""" try: - # Extract user from jwt if using OIDC (if jwt exists) - access_token: dict[str, Any] | None = getattr( - request.state, "decoded_access_token", None - ) - if access_token: - user: str = access_token.get("fedid", "Unknown") - else: - user = "Unknown" - - task_id: str = runner.run(interface.submit_task, task_request, {"user": user}) + task_id: str = runner.run(interface.submit_task, task_request, {"user": fedid}) response.headers["Location"] = f"{request.url}/{task_id}" return TaskResponse(task_id=task_id) except ValidationError as e: @@ -336,6 +346,7 @@ def submit_task( @start_as_current_span(TRACER, "task_id") def delete_submitted_task( task_id: str, + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TaskResponse: return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) @@ -352,8 +363,10 @@ def validate_task_status(v: str) -> TaskStatusEnum: @secure_router_v1.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) -def get_tasks( +async def get_tasks( + fedid: Fedid, runner: Annotated[WorkerDispatcher, Depends(_runner)], + opa: Annotated[OpaUserClient, Depends(opa)], task_status: str | SkipJsonSchema[None] = None, ) -> TasksListResponse: """ @@ -373,6 +386,10 @@ def get_tasks( tasks = runner.run(interface.get_tasks_by_status, desired_status) else: tasks = runner.run(interface.get_tasks) + + if opa and not await opa.admin(): + tasks = [t for t in tasks if t.task.metadata.get("user") == fedid] + return TasksListResponse(tasks=tasks) @@ -390,6 +407,7 @@ def get_tasks( def set_active_task( request: Request, task: WorkerTask, + _: Annotated[None, Depends(start_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerTask: """Set a task to active status, the worker should begin it as soon as possible. @@ -420,6 +438,7 @@ def get_passthrough_headers(request: Request) -> dict[str, str]: @start_as_current_span(TRACER, "task_id") def get_task( task_id: str, + _: Annotated[None, Depends(access_task_permission)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> TrackableTask: """Retrieve a task""" @@ -494,9 +513,11 @@ def get_state(runner: Annotated[WorkerDispatcher, Depends(_runner)]) -> WorkerSt tags=[Tag.TASK], ) @start_as_current_span(TRACER, "state_change_request.new_state") -def set_state( +async def set_state( state_change_request: StateChangeRequest, response: Response, + fedid: Fedid, + opa: Annotated[OpaUserClient, Depends(opa)], runner: Annotated[WorkerDispatcher, Depends(_runner)], ) -> WorkerState: """ @@ -523,6 +544,19 @@ def set_state( current_state in _ALLOWED_TRANSITIONS and new_state in _ALLOWED_TRANSITIONS[current_state] ): + active = runner.run(interface.get_active_task) + + if ( + opa + and not await opa.admin() + and active + and active.task.metadata.get("user") != fedid + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to set worker state", + ) + if new_state == WorkerState.PAUSED: runner.run(interface.pause_worker, state_change_request.defer) elif new_state == WorkerState.RUNNING: diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index 4b2e41f2c..bf96b7009 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,3 +1,4 @@ +import re from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar @@ -31,6 +32,8 @@ Args = ParamSpec("Args") Return = TypeVar("Return") +INSTRUMENT_SESSION_RE = re.compile(r"^[a-z]{2}(?P\d+)-(?P\d+)$") + def deprecated(alternative): from warnings import warn diff --git a/src/blueapi/utils/serialization.py b/src/blueapi/utils/serialization.py index deee82b1e..8918cf882 100644 --- a/src/blueapi/utils/serialization.py +++ b/src/blueapi/utils/serialization.py @@ -1,9 +1,10 @@ import json -import re from typing import Any from pydantic import BaseModel +from blueapi import utils + def serialize(obj: Any) -> Any: """ @@ -28,13 +29,8 @@ def serialize(obj: Any) -> Any: return obj -_INSTRUMENT_SESSION_AUTHZ_REGEX: re.Pattern = re.compile( - r"^[a-zA-Z]{2}(?P\d+)-(?P\d+)$" -) - - def access_blob(instrument_session: str, beamline: str) -> str: - m = _INSTRUMENT_SESSION_AUTHZ_REGEX.match(instrument_session) + m = utils.INSTRUMENT_SESSION_RE.match(instrument_session) if m is None: raise ValueError( "Unable to extract proposal and visit from " diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 88227706b..01bc426e2 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -8,15 +8,19 @@ import pytest import responses import respx +from fastapi import HTTPException from pydantic import SecretStr from starlette.status import HTTP_200_OK, HTTP_403_FORBIDDEN from blueapi.config import OIDCConfig, ServiceAccount -from blueapi.service import main +from blueapi.service import authentication from blueapi.service.authentication import ( SessionCacheManager, SessionManager, TiledAuth, + access_token, + build_access_token_check, + unchecked_bearer_token, ) @@ -124,9 +128,9 @@ def test_poll_for_token_timeout( def test_server_raises_exception_for_invalid_token( oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock ): - inner = main.decode_access_token(oidc_config) + inner = authentication.build_access_token_check(oidc_config) with pytest.raises(jwt.PyJWTError): - inner(Mock(), access_token="Invalid Token") + inner(Mock(), token="Invalid Token") def test_processes_valid_token( @@ -134,8 +138,8 @@ def test_processes_valid_token( mock_authn_server: responses.RequestsMock, valid_token_with_jwt, ): - inner = main.decode_access_token(oidc_config) - inner(Mock(), access_token=valid_token_with_jwt["access_token"]) + inner = authentication.build_access_token_check(oidc_config) + inner(Mock(), token=valid_token_with_jwt["access_token"]) def test_session_cache_manager_returns_writable_file_path(tmp_path): @@ -182,3 +186,49 @@ def test_tiled_auth_sync_auth_flow(): result = next(flow) assert result.headers["Authorization"] == f"Bearer {access_token}" + + +@pytest.mark.parametrize( + "header,token", + [ + (None, None), + ("ApiKey foobar", None), + ("Bearer foobar", "foobar"), + ("Bearer with_whitespace ", "with_whitespace"), + ("Bearerfoobar", None), + ], +) +def test_unchecked_bearer_token(header: str | None, token: str | None): + req = Mock() + req.headers.get.side_effect = lambda key: header if key == "Authorization" else None + + assert unchecked_bearer_token(req) == token + + +def test_access_token(): + req = Mock() + req.state.decoded_access_token = {"foo": "bar"} + + assert access_token(req) == {"foo": "bar"} + + +def test_access_token_without_token(): + req = Mock() + del req.state.decoded_access_token + + assert access_token(req) is None + + +@patch("blueapi.service.authentication.jwt") +def test_build_access_token(mock_jwt: Mock): + # Return None when building client to ensure no field/method access + mock_jwt.PyJWKClient.return_value = None + oidc_config = Mock() + req = Mock() + + validate_fn = build_access_token_check(oidc_config) + + with pytest.raises(HTTPException, match="401"): + validate_fn(req, token=None) + + mock_jwt.decode.assert_not_called() diff --git a/tests/unit_tests/service/test_authorization.py b/tests/unit_tests/service/test_authorization.py new file mode 100644 index 000000000..a2e602f21 --- /dev/null +++ b/tests/unit_tests/service/test_authorization.py @@ -0,0 +1,300 @@ +from contextlib import AbstractContextManager, nullcontext +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from fastapi import HTTPException +from pydantic import HttpUrl + +from blueapi.config import OIDCConfig, OpaConfig, ServiceAccount +from blueapi.service.authorization import ( + OpaClient, + OpaUserClient, + opa, + submit_permission, + validate_tiled_config, +) +from blueapi.service.model import TaskRequest + +# Reusable client patch decorator +patch_client_session = patch( + "blueapi.service.authorization.ClientSession", + name="mock_client_session", + spec=True, +) + + +@pytest.fixture(scope="module") +def opa_config() -> OpaConfig: + return OpaConfig( + root=HttpUrl("http://auth.example.com"), + submit_task_check="/auth/submit", + admin_check="/auth/admin", + tiled_service_account_check="/auth/tiled", + ) + + +@patch_client_session +@pytest.mark.parametrize( + "result,context", + [ + (False, pytest.raises(ValueError, match="Tiled service account is not valid ")), + (True, nullcontext()), + ], +) +async def test_tiled_service_account( + session: MagicMock, + opa_config: OpaConfig, + result: bool, + context: AbstractContextManager, +): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + + client = OpaClient(instrument="p99", config=opa_config) + + session.assert_called_once_with(base_url="http://auth.example.com/") + with context: + await client.require_tiled_service_account(token="foo_bar") + session().post.assert_called_once_with( + "/auth/tiled", + json={"input": {"token": "foo_bar", "beamline": "p99", "audience": "account"}}, + ) + + +@patch_client_session +async def test_exception_raised_when_opa_fails( + session: MagicMock, opa_config: OpaConfig +): + session.return_value.post = AsyncMock(side_effect=RuntimeError("Connection failed")) + async with OpaClient.for_config("p45", opa_config) as client: + assert client is not None + with pytest.raises(RuntimeError, match="Connection failed"): + await client.require_tiled_service_account(token="foo_bar") + + +@patch_client_session +async def test_session_closed(session: MagicMock, opa_config: OpaConfig): + async with OpaClient.for_config("p45", opa_config): + pass + session().close.assert_called_once() + + +@patch_client_session +async def test_opa_client_for_config(session: MagicMock, opa_config: OpaConfig): + async with OpaClient.for_config("p45", opa_config) as opa: + assert opa is not None + session.assert_called_once_with(base_url="http://auth.example.com/") + + +@pytest.mark.parametrize("instrument", [None, "p99"]) +async def test_opa_client_without_config(instrument: str | None): + async with OpaClient.for_config(instrument, None) as opa: + assert opa is None + + +async def test_opa_fails_without_instrument(opa_config: OpaConfig): + with pytest.raises(ValueError, match="Instrument name is required"): + OpaClient.for_config(None, opa_config) + + +@patch_client_session +async def test_opa_adds_input_fields(session: MagicMock, opa_config: OpaConfig): + session.return_value.post = AsyncMock() + async with OpaClient.for_config("p45", opa_config) as opa: + assert opa is not None + await opa._call_opa("foo/bar", data={"foo": "bar"}) + + session.assert_called_once() + session().post.assert_called_once_with( + "foo/bar", + json={"input": {"beamline": "p45", "audience": "account", "foo": "bar"}}, + ) + + +@pytest.mark.parametrize( + "result,context", + [(True, nullcontext()), (False, pytest.raises(HTTPException, match="403"))], +) +@patch_client_session +async def test_require_submit_task( + session: MagicMock, + opa_config: OpaConfig, + result: bool, + context: AbstractContextManager, +): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + + client = OpaClient(instrument="p99", config=opa_config) + + session.assert_called_once_with(base_url="http://auth.example.com/") + with context: + await client.require_submit_task( + instrument_session="cm12345-1", token="foo_bar" + ) + + session().post.assert_called_once_with( + "/auth/submit", + json={ + "input": { + "token": "foo_bar", + "beamline": "p99", + "audience": "account", + "visit": 1, + "proposal": 12345, + } + }, + ) + + +@patch_client_session +async def test_opa_require_submit_task_invalid_session( + session: MagicMock, opa_config: OpaConfig +): + client = OpaClient(instrument="p45", config=opa_config) + + with pytest.raises(ValueError, match="Invalid instrument session"): + await client.require_submit_task( + instrument_session="not a session", token="foo_bar" + ) + + +@pytest.mark.parametrize("result", [True, False]) +@patch_client_session +async def test_opa_is_admin(session: MagicMock, opa_config: OpaConfig, result: bool): + session.return_value.post = AsyncMock( + return_value=MagicMock(json=AsyncMock(return_value={"result": result})) + ) + client = OpaClient(instrument="p45", config=opa_config) + + admin = await client.is_admin("foo_bar") + + assert admin == result + + session().post.assert_called_once_with( + "/auth/admin", + json={"input": {"token": "foo_bar", "beamline": "p45", "audience": "account"}}, + ) + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_user_client_can_submit_task(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.require_submit_task = AsyncMock(side_effect=result) + + user_client = OpaUserClient(opa, "foo_bar") + + with context: + await user_client.can_submit_task( + TaskRequest(name="foo", params={}, instrument_session="cm12345-1") + ) + opa.require_submit_task.assert_called_once_with("cm12345-1", "foo_bar") + + +@pytest.mark.parametrize("result", [True, False]) +async def test_user_client_admin(result: bool): + opa = MagicMock(spec=OpaUserClient) + opa.is_admin = AsyncMock(return_value=result) + + user_client = OpaUserClient(opa, "foo_bar") + + admin = await user_client.admin() + + assert admin == result + + +async def test_validate_tiled_config(): + opa = MagicMock(spec=OpaClient) + tiled = ServiceAccount() + oidc = Mock(spec=OIDCConfig) + oidc.token_endpoint = "token-endpoint" + with patch("blueapi.service.authorization.TiledAuth") as auth: + auth.return_value.get_access_token.return_value = "tiled-token" + await validate_tiled_config(tiled, oidc, opa) + + auth.assert_called_once_with(tiled) + opa.require_tiled_service_account.assert_called_once_with("tiled-token") + + +@pytest.mark.parametrize( + "tiled_auth,oidc,opa_client", + [ + (None, None, MagicMock(spec=OpaClient)), + ( + None, + OIDCConfig(well_known_url="http://example.com", client_id="test-client"), + MagicMock(spec=OpaClient), + ), + ("api_key", None, MagicMock(spec=OpaClient)), + ( + "api_key", + OIDCConfig(well_known_url="http://example.com", client_id="test-client"), + MagicMock(spec=OpaClient), + ), + (ServiceAccount(), None, MagicMock(spec=OpaClient)), + ( + ServiceAccount(), + OIDCConfig(well_known_url="http://example.com", client_id="test-client"), + None, + ), + ], +) +async def test_validate_tiled_config_with_missing_config( + tiled_auth: ServiceAccount | str | None, + oidc: OIDCConfig | None, + opa_client: MagicMock | None, +): + assert await validate_tiled_config(tiled_auth, oidc, opa_client) is None + if opa_client is not None: + opa_client.require_tiled_service_account.assert_not_called() + + +async def test_opa_dependency_method(): + request = MagicMock() + + user_client = await opa(request, "foo_bar") + + assert user_client is not None + assert user_client.client == request.app.state.authz + assert user_client.token == "foo_bar" + + +async def test_opa_dependency_without_token(): + request = MagicMock() + + with pytest.raises(HTTPException, match="401"): + await opa(request, None) + + +@pytest.mark.parametrize("token", ["foo_bar", None]) +async def test_opa_dependency_without_authz(token): + request = MagicMock() + del request.app.state.authz + user_client = await opa(request, token) + assert user_client is None + + +@pytest.mark.parametrize( + "result,context", + [ + (None, nullcontext()), + (HTTPException(status_code=403), pytest.raises(HTTPException, match="403")), + ], +) +async def test_submit_permission_dependency(result, context: AbstractContextManager): + opa = MagicMock(spec=OpaUserClient) + opa.can_submit_task.side_effect = result + with context: + await submit_permission(opa, Mock()) + + +async def test_submit_permission_dependency_without_opa(): + assert await submit_permission(None, Mock()) is None diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index a7e04105c..1801b8756 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch import pytest from fastapi import FastAPI, Request @@ -10,6 +10,7 @@ from blueapi.service.main import ( add_version_headers, get_passthrough_headers, + lifespan, log_request_details, ) @@ -79,3 +80,18 @@ def test_get_passthrough_headers( request = Mock(spec=Request) request.headers = headers assert get_passthrough_headers(request) == expected_headers + + +@patch("blueapi.service.main.teardown_runner") +@patch("blueapi.service.main.setup_runner") +async def test_lifespan(setup: Mock, teardown: Mock): + conf = ApplicationConfig() + lifespan_fn = lifespan(conf) + + app = Mock() + + async with lifespan_fn(app): + setup.assert_called_once_with(conf) + teardown.assert_not_called() + + teardown.assert_called_once() diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index c1d3b6a95..1b15c7790 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -2,26 +2,30 @@ from collections.abc import Iterator from dataclasses import dataclass from typing import Any -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, PropertyMock, patch -import jwt import pytest from bluesky.protocols import Stoppable -from fastapi import status +from fastapi import HTTPException, status from fastapi.testclient import TestClient from httpx import Headers from pydantic import BaseModel, ValidationError from pydantic_core import InitErrorDetails from super_state_machine.errors import TransitionError -from blueapi.config import ApplicationConfig, CORSConfig, OIDCConfig, RestConfig +from blueapi.config import ( + ApplicationConfig, + CORSConfig, + OIDCConfig, + RestConfig, +) from blueapi.core.bluesky_types import Plan -from blueapi.service import main +from blueapi.service import interface, main +from blueapi.service.authorization import OpaUserClient, opa from blueapi.service.interface import ( cancel_active_task, get_device, get_plan, - pause_worker, resume_worker, submit_task, ) @@ -37,7 +41,7 @@ WorkerTask, ) from blueapi.service.runner import WorkerDispatcher -from blueapi.worker.event import WorkerState +from blueapi.worker.event import TaskStatusEnum, WorkerState from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask @@ -50,8 +54,83 @@ class MockCountModel(BaseModel): ... @pytest.fixture -def mock_runner() -> Mock: - return Mock(spec=WorkerDispatcher) +def mock_runner_data() -> Mock: + return Mock() + + +@pytest.fixture +def mock_runner(mock_runner_data: Mock) -> Mock: + runner = Mock(spec=WorkerDispatcher) + + def run(method, *args, **kwargs): + match method: + case interface.get_active_task: + running = run(interface.get_tasks_by_status, TaskStatusEnum.RUNNING) + return running[0] if running else None + case interface.get_task_by_id: + return {t.task_id: t for t in mock_runner_data.tasks}.get( + kwargs.get("task_id") or args[0] + ) + case interface.get_tasks: + return mock_runner_data.tasks + case interface.get_tasks_by_status: + status = kwargs.get("status") or args[0] + match status: + case TaskStatusEnum.RUNNING: + return [ + t + for t in mock_runner_data.tasks + if not t.is_pending and not t.is_complete + ] + case TaskStatusEnum.PENDING: + return [t for t in mock_runner_data.tasks if t.is_pending] + case TaskStatusEnum.COMPLETE: + return [t for t in mock_runner_data.tasks if t.is_complete] + case _: + return [] + + case interface.get_plans: + return mock_runner_data.plans + case interface.get_plan: + name = kwargs.get("name") or args[0] + plans = [p for p in mock_runner_data.plans if p.name == name] + if plans: + return plans[0] + raise KeyError(name) + case interface.get_devices: + return mock_runner_data.devices + case interface.get_device: + name = kwargs.get("name") or args[0] + devices = [d for d in mock_runner_data.devices if d.name == name] + if devices: + return devices[0] + raise KeyError(name) + case interface.get_oidc_config: + return mock_runner_data.oidc_config + case interface.get_worker_state: + return mock_runner_data.state + case interface.get_python_env: + return mock_runner_data.python_environment + case interface.submit_task: + return mock_runner_data.submit_task(*args, **kwargs) + case interface.begin_task: + return mock_runner_data.begin_task(*args, **kwargs) + case interface.clear_task: + return mock_runner_data.clear_task(*args, **kwargs) + case interface.cancel_active_task: + return mock_runner_data.cancel_active_task(*args, **kwargs) + case interface.pause_worker: + return mock_runner_data.pause_worker(*args, **kwargs) + case _: + raise ValueError("Unsupported method: " + method.__name__) + + runner.run.side_effect = run + return runner + + +@pytest.fixture +def mock_opa_client() -> Mock: + return Mock(spec=OpaUserClient) @pytest.fixture @@ -79,6 +158,27 @@ def client_with_auth( main.teardown_runner() +@pytest.fixture +def access_token(valid_token_with_jwt: dict[str, Any]) -> str: + return valid_token_with_jwt["access_token"] + + +@pytest.fixture +def client_with_opa( + mock_runner: Mock, + oidc_config: OIDCConfig, + mock_opa_client: Mock, + mock_authn_server, +): + with patch("blueapi.service.interface.worker"): + main.setup_runner(runner=mock_runner) + app = main.get_app(ApplicationConfig(oidc=oidc_config)) + app.dependency_overrides[opa] = lambda: mock_opa_client + client = TestClient(app) + yield client + main.teardown_runner() + + @pytest.fixture def rest_config_with_cors() -> RestConfig: cors_config = CORSConfig( @@ -109,13 +209,13 @@ def stop(self, success: bool = True): def test_rest_config_with_cors_gets_plan( client_with_cors: TestClient, - mock_runner: Mock, + mock_runner_data: Mock, ): class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - mock_runner.run.return_value = [PlanModel.from_plan(plan)] + mock_runner_data.plans = [PlanModel.from_plan(plan)] response_get = client_with_cors.get("/plans") assert response_get.status_code == status.HTTP_200_OK @@ -123,7 +223,7 @@ class MyModel(BaseModel): def test_rest_config_with_cors( client_with_cors: TestClient, - mock_runner: Mock, + mock_runner_data: Mock, ): task = TaskRequest( name="my-plan", @@ -131,7 +231,7 @@ def test_rest_config_with_cors( instrument_session=FAKE_INSTRUMENT_SESSION, ) task_id = "f8424be3-203c-494e-b22f-219933b4fa67" - mock_runner.run.side_effect = [task_id] + mock_runner_data.submit_task.return_value = task_id # Allowed method response_post = client_with_cors.post( @@ -143,12 +243,12 @@ def test_rest_config_with_cors( assert response_post.headers["content-type"] == "application/json" -def test_get_plans(mock_runner: Mock, client: TestClient) -> None: +def test_get_plans(mock_runner_data: Mock, client: TestClient) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - mock_runner.run.return_value = [PlanModel.from_plan(plan)] + mock_runner_data.plans = [PlanModel.from_plan(plan)] response = client.get("/plans") @@ -169,12 +269,14 @@ class MyModel(BaseModel): } -def test_get_plan_by_name(mock_runner: Mock, client: TestClient) -> None: +def test_get_plan_by_name( + mock_runner: Mock, mock_runner_data: Mock, client: TestClient +) -> None: class MyModel(BaseModel): id: str plan = Plan(name="my-plan", model=MyModel) - mock_runner.run.return_value = PlanModel.from_plan(plan) + mock_runner_data.plans = [PlanModel.from_plan(plan)] response = client.get("/plans/my-plan") @@ -192,17 +294,19 @@ class MyModel(BaseModel): } -def test_get_non_existent_plan_by_name(mock_runner: Mock, client: TestClient) -> None: - mock_runner.run.side_effect = KeyError("my-plan") +def test_get_non_existent_plan_by_name( + mock_runner_data: Mock, client: TestClient +) -> None: + mock_runner_data.plans = [] response = client.get("/plans/my-plan") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -def test_get_devices(mock_runner: Mock, client: TestClient) -> None: +def test_get_devices(mock_runner_data: Mock, client: TestClient) -> None: device = MinimalDevice("my-device") - mock_runner.run.return_value = [DeviceModel.from_device(device)] + mock_runner_data.devices = [DeviceModel.from_device(device)] response = client.get("/devices") @@ -217,10 +321,12 @@ def test_get_devices(mock_runner: Mock, client: TestClient) -> None: } -def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: +def test_get_device_by_name( + mock_runner: Mock, mock_runner_data: Mock, client: TestClient +) -> None: device = MinimalDevice("my-device") - mock_runner.run.return_value = DeviceModel.from_device(device) + mock_runner_data.devices = [DeviceModel.from_device(device)] response = client.get("/devices/my-device") mock_runner.run.assert_called_once_with(get_device, "my-device") @@ -231,15 +337,15 @@ def test_get_device_by_name(mock_runner: Mock, client: TestClient) -> None: } -def test_get_non_existent_device_by_name(mock_runner: Mock, client: TestClient) -> None: - mock_runner.run.side_effect = KeyError("my-device") +def test_get_non_existent_device_by_name(mock_runner_data: Mock, client: TestClient): + mock_runner_data.devices = [] response = client.get("/devices/my-device") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.json() == {"detail": "Item not found"} -def test_create_task(mock_runner: Mock, client: TestClient) -> None: +def test_create_task(mock_runner: Mock, mock_runner_data: Mock, client: TestClient): task = TaskRequest( name="count", params={"detectors": ["x"]}, @@ -247,16 +353,34 @@ def test_create_task(mock_runner: Mock, client: TestClient) -> None: ) task_id = str(uuid.uuid4()) - mock_runner.run.side_effect = [task_id] + mock_runner_data.submit_task.return_value = task_id response = client.post("/tasks", json=task.model_dump()) - mock_runner.run.assert_called_with(submit_task, task, {"user": "Unknown"}) + mock_runner.run.assert_called_with(submit_task, task, {"user": None}) assert response.json() == {"task_id": task_id} +def test_submit_task_requires_permission( + mock_runner: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, +): + task = TaskRequest(name="sleep", params={"time": 2}, instrument_session="cm12345-2") + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + mock_opa_client.can_submit_task.side_effect = HTTPException(status_code=403) + mock_runner.run.side_effect = RuntimeError("Task should not be submitted") + + resp = client_with_opa.post("/tasks", json=task.model_dump()) + + assert resp.status_code == 403 + mock_runner.run.assert_not_called() + + def test_create_task_inserts_auth_metadata( mock_runner: Mock, + mock_runner_data: Mock, client_with_auth: TestClient, ) -> None: task = TaskRequest( @@ -267,8 +391,7 @@ def test_create_task_inserts_auth_metadata( client_with_auth.follow_redirects = False task_id = str(uuid.uuid4()) - # mock_runner.run.side_effect = [task_id] - mock_runner.run.return_value = [task_id] + mock_runner_data.submit_task.return_value = task_id client_with_auth.post("/tasks", json=task.model_dump()) @@ -307,8 +430,10 @@ def test_create_task_validation_error(mock_runner: Mock, client: TestClient) -> } -def test_put_plan_begins_task(client: TestClient) -> None: +def test_put_plan_begins_task(client: TestClient, mock_runner_data: Mock) -> None: task_id = "04cd9aa6-b902-414b-ae4b-49ea4200e957" + mock_runner_data.tasks = [TrackableTask(task_id=task_id, task=Task(name="foo"))] + mock_runner_data.begin_task.side_effect = lambda task, **kw: task resp = client.put("/worker/task", json={"task_id": task_id}) @@ -316,14 +441,19 @@ def test_put_plan_begins_task(client: TestClient) -> None: assert resp.json() == {"task_id": task_id} -def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> None: +def test_put_plan_fails_if_not_idle(mock_runner_data: Mock, client: TestClient) -> None: task_id_current = "260f7de3-b608-4cdc-a66c-257e95809792" task_id_new = "07e98d68-21b5-4ad7-ac34-08b2cb992d42" # Set to non idle - mock_runner.run.return_value = TrackableTask( - task=Task(name="none"), task_id=task_id_current, is_complete=False - ) + mock_runner_data.tasks = [ + TrackableTask( + task=Task(name="none"), + task_id=task_id_current, + is_pending=False, + is_complete=False, + ) + ] resp = client.put("/worker/task", json={"task_id": task_id_new}) @@ -331,8 +461,8 @@ def test_put_plan_fails_if_not_idle(mock_runner: Mock, client: TestClient) -> No assert resp.json() == {"detail": "Worker already active"} -def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: - tasks = [ +def test_get_tasks(mock_runner_data: Mock, client: TestClient) -> None: + mock_runner_data.tasks = [ TrackableTask(task_id="0", task=Task(name="sleep", params={"time": 0.0})), TrackableTask( task_id="1", @@ -342,8 +472,6 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: ), ] - mock_runner.run.return_value = tasks - response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK @@ -379,8 +507,8 @@ def test_get_tasks(mock_runner: Mock, client: TestClient) -> None: } -def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: - tasks = [ +def test_get_tasks_by_status(mock_runner_data: Mock, client: TestClient) -> None: + mock_runner_data.tasks = [ TrackableTask( task_id="3", task=Task(name="third_task"), @@ -389,9 +517,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: ), ] - mock_runner.run.return_value = tasks - - response = client.get("/tasks", params={"task_status": "PENDING"}) + response = client.get("/tasks", params={"task_status": "COMPLETE"}) assert response.json() == { "tasks": [ { @@ -416,16 +542,64 @@ def test_get_tasks_by_status_invalid(client: TestClient) -> None: assert response.status_code == status.HTTP_400_BAD_REQUEST -def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: +@pytest.mark.parametrize("admin,task_ids", [(True, ["foo", "bar"]), (False, ["foo"])]) +def test_get_tasks_filters_by_user( + mock_runner_data: Mock, + client_with_opa: TestClient, + access_token: str, + mock_opa_client: Mock, + admin: bool, + task_ids: list[str], +): + + mock_runner_data.tasks = [ + TrackableTask(task_id="foo", task=Task(name="f1", metadata={"user": "jd1"})), + TrackableTask(task_id="bar", task=Task(name="f2", metadata={"user": "jd2"})), + ] + mock_opa_client.admin.return_value = admin + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + tasks = client_with_opa.get("/tasks").json().get("tasks") + + assert [t["task_id"] for t in tasks] == task_ids + + +def test_delete_submitted_task(mock_runner_data: Mock, client: TestClient) -> None: task_id = str(uuid.uuid4()) - mock_runner.run.return_value = task_id + mock_runner_data.tasks = [TrackableTask(task_id=task_id, task=Task(name="foo"))] + mock_runner_data.clear_task.return_value = task_id response = client.delete(f"/tasks/{task_id}") assert response.json() == {"task_id": f"{task_id}"} + mock_runner_data.clear_task.assert_called_once_with(task_id) -def test_set_active_task(client: TestClient) -> None: +def test_cant_delete_other_users_task( + mock_runner: Mock, + mock_runner_data: Mock, + client_with_opa: TestClient, + access_token: str, + mock_opa_client: Mock, +): + mock_opa_client.admin.return_value = False + mock_runner_data.tasks = [ + TrackableTask( + task_id="bar", + is_pending=False, + task=Task(name="t2", metadata={"user": "jd2"}), + ) + ] + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + resp = client_with_opa.delete("/tasks/bar") + + # 404 to obfuscate whether task exists when inaccessible + assert resp.status_code == 404 + mock_runner_data.clear_task.assert_not_called() + + +def test_set_active_task(client: TestClient, mock_runner_data: Mock) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) + mock_runner_data.tasks = [TrackableTask(task_id=task_id, task=Task(name="foo"))] response = client.put("/worker/task", json=task.model_dump()) @@ -434,17 +608,19 @@ def test_set_active_task(client: TestClient) -> None: def test_set_active_task_active_task_complete( - mock_runner: Mock, client: TestClient + mock_runner_data: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - mock_runner.run.return_value = TrackableTask( - task_id="1", - task=Task(name="a_completed_task"), - is_complete=True, - is_pending=False, - ) + mock_runner_data.tasks = [ + TrackableTask( + task_id="1", + task=Task(name="a_completed_task"), + is_complete=True, + is_pending=False, + ) + ] response = client.put("/worker/task", json=task.model_dump()) @@ -453,17 +629,19 @@ def test_set_active_task_active_task_complete( def test_set_active_task_worker_already_running( - mock_runner: Mock, client: TestClient + mock_runner_data: Mock, client: TestClient ) -> None: task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - mock_runner.run.return_value = TrackableTask( - task_id="1", - task=Task(name="a_running_task"), - is_complete=False, - is_pending=False, - ) + mock_runner_data.tasks = [ + TrackableTask( + task_id="1", + task=Task(name="a_running_task"), + is_complete=False, + is_pending=False, + ) + ] response = client.put("/worker/task", json=task.model_dump()) @@ -471,19 +649,47 @@ def test_set_active_task_worker_already_running( assert response.json() == {"detail": "Worker already active"} -def test_get_task(mock_runner: Mock, client: TestClient): - task_id = str(uuid.uuid4()) - task = TrackableTask( - task_id=task_id, - task=Task( - name="third_task", - metadata={ - "foo": "bar", - }, - ), - ) +@pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) +def test_set_other_users_task_active( + mock_runner_data: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): - mock_runner.run.return_value = task + task_id = "foo" + task = WorkerTask(task_id=task_id) + mock_opa_client.admin.return_value = admin + + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + mock_runner_data.tasks = [ + TrackableTask(task_id=task_id, task=Task(name="foo", metadata={"user": "jd2"})) + ] + + resp = client_with_opa.put("/worker/task", json=task.model_dump()) + + if status >= 400: + mock_runner_data.begin_task.assert_not_called() + + assert resp.status_code == status + + +def test_get_task(mock_runner_data: Mock, client: TestClient): + task_id = str(uuid.uuid4()) + mock_runner_data.tasks = [ + TrackableTask( + task_id=task_id, + task=Task( + name="third_task", + metadata={ + "foo": "bar", + }, + ), + ) + ] response = client.get(f"/tasks/{task_id}") assert response.json() == { @@ -503,16 +709,34 @@ def test_get_task(mock_runner: Mock, client: TestClient): } -def test_get_all_tasks(mock_runner: Mock, client: TestClient): +@pytest.mark.parametrize("admin,status", [(True, 200), (False, 404)]) +def test_get_other_users_task( + mock_runner_data: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + mock_runner_data.tasks = [ + TrackableTask(task_id="foo", task=Task(name="bar", metadata={"user": "jd2"})) + ] + mock_opa_client.admin.return_value = admin + + resp = client_with_opa.get("/tasks/foo") + assert resp.status_code == status + + +def test_get_all_tasks(mock_runner_data: Mock, client: TestClient): task_id = str(uuid.uuid4()) - tasks = [ + mock_runner_data.tasks = [ TrackableTask( task_id=task_id, task=Task(name="third_task"), ) ] - mock_runner.run.return_value = tasks response = client.get("/tasks") assert response.status_code == status.HTTP_200_OK assert response.json() == { @@ -534,53 +758,60 @@ def test_get_all_tasks(mock_runner: Mock, client: TestClient): } -def test_get_task_error(mock_runner: Mock, client: TestClient): +def test_get_task_error(mock_runner_data: Mock, client: TestClient): task_id = 567 - mock_runner.run.return_value = None + mock_runner_data.tasks = [] response = client.get(f"/tasks/{task_id}") assert response.json() == {"detail": "Item not found"} -def test_get_active_task(mock_runner: Mock, client: TestClient): +def test_get_active_task(mock_runner_data: Mock, client: TestClient): task_id = str(uuid.uuid4()) task = TrackableTask( task_id=task_id, task=Task(name="third_task"), + is_pending=False, + is_complete=False, ) - mock_runner.run.return_value = task + mock_runner_data.tasks = [task] response = client.get("/worker/task") assert response.json() == {"task_id": f"{task_id}"} -def test_get_active_task_none(mock_runner: Mock, client: TestClient): - mock_runner.run.return_value = None +def test_get_active_task_none(mock_runner_data: Mock, client: TestClient): + mock_runner_data.tasks = [] response = client.get("/worker/task") assert response.json() == {"task_id": None} -def test_get_state(mock_runner: Mock, client: TestClient): +def test_get_state(mock_runner_data: Mock, client: TestClient): state = WorkerState.SUSPENDING - mock_runner.run.return_value = state + mock_runner_data.state = state response = client.get("/worker/state") assert response.json() == state -def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): +def test_set_state_running_to_paused(mock_runner_data: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.PAUSED - mock_runner.run.side_effect = [current_state, None, final_state] + type(mock_runner_data).state = PropertyMock( + side_effect=[current_state, final_state] + ) + mock_runner_data.tasks = [ + TrackableTask(task_id="foobar", task=Task(name="foo")), + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) - mock_runner.run.assert_any_call(pause_worker, False) + mock_runner_data.pause_worker.assert_called_once_with(False) assert response.status_code == status.HTTP_202_ACCEPTED assert response.json() == final_state @@ -588,7 +819,12 @@ def test_set_state_running_to_paused(mock_runner: Mock, client: TestClient): def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): current_state = WorkerState.PAUSED final_state = WorkerState.RUNNING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -602,7 +838,12 @@ def test_set_state_paused_to_running(mock_runner: Mock, client: TestClient): def test_set_state_running_to_aborting(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.ABORTING - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() @@ -619,7 +860,12 @@ def test_set_state_running_to_stopping_including_reason( current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING reason = "blueapi is being stopped" - mock_runner.run.side_effect = [current_state, None, final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + None, + final_state, + ] response = client.put( "/worker/state", @@ -635,7 +881,12 @@ def test_set_state_transition_error(mock_runner: Mock, client: TestClient): current_state = WorkerState.RUNNING final_state = WorkerState.STOPPING - mock_runner.run.side_effect = [current_state, TransitionError(), final_state] + mock_runner.run.side_effect = [ + current_state, + TrackableTask(task_id="foobar", task=Task(name="foo")), + TransitionError(), + final_state, + ] response = client.put( "/worker/state", @@ -662,6 +913,39 @@ def test_set_state_invalid_transition(mock_runner: Mock, client: TestClient): assert response.json() == final_state +@pytest.mark.parametrize("admin,status", [(True, 202), (False, 403)]) +def test_set_state_of_other_users_task( + mock_runner_data: Mock, + client_with_opa: TestClient, + mock_opa_client: Mock, + access_token: str, + admin: bool, + status: int, +): + + mock_opa_client.admin.return_value = admin + mock_runner_data.tasks = [ + TrackableTask( + task_id="foo", + is_pending=False, + task=Task(name="bar", metadata={"user": "jd2"}), + ), + ] + mock_runner_data.state = WorkerState.RUNNING + mock_runner_data.cancel_active_task.return_value = WorkerState.ABORTING + + client_with_opa.headers["Authorization"] = f"Bearer {access_token}" + + resp = client_with_opa.put( + "/worker/state", + json=StateChangeRequest(new_state=WorkerState.ABORTING).model_dump(), + ) + + if not admin: + mock_runner_data.cancel_active_task.assert_not_called() + assert resp.status_code == status + + def test_get_environment_idle(mock_runner: Mock, client: TestClient) -> None: environment_id = uuid.uuid4() mock_runner.state = EnvironmentResponse( @@ -701,36 +985,37 @@ def test_subprocess_enabled_by_default(mp_pool_mock: MagicMock): main.teardown_runner() -def test_get_without_authentication(mock_runner: Mock, client: TestClient) -> None: - mock_runner.run.side_effect = jwt.PyJWTError - response = client.get("/devices/my-device") +def test_get_without_authentication(mock_runner: Mock, client_with_auth: TestClient): + del client_with_auth.headers["Authorization"] + response = client_with_auth.get("/devices/my-device") + mock_runner.run.assert_not_called() assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.json() == {"detail": "Not authenticated"} def test_oidc_config_not_found_when_auth_is_disabled( - mock_runner: Mock, client: TestClient + mock_runner_data: Mock, client: TestClient ): - mock_runner.run.return_value = None + mock_runner_data.oidc_config = None response = client.get("/config/oidc") assert response.status_code == status.HTTP_204_NO_CONTENT assert response.text == "" def test_get_oidc_config( - mock_runner: Mock, + mock_runner_data: Mock, oidc_config: OIDCConfig, mock_authn_server, client_with_auth: TestClient, ): - mock_runner.run.return_value = oidc_config + mock_runner_data.oidc_config = oidc_config response = client_with_auth.get("/config/oidc") assert response.status_code == status.HTTP_200_OK assert response.json() == oidc_config.model_dump() -def test_get_python_environment(mock_runner: Mock, client: TestClient): +def test_get_python_environment(mock_runner_data: Mock, client: TestClient): packages = PythonEnvironmentResponse( installed_packages=[ PackageInfo( @@ -742,7 +1027,7 @@ def test_get_python_environment(mock_runner: Mock, client: TestClient): ) ] ) - mock_runner.run.return_value = packages + mock_runner_data.python_environment = packages response = client.get("/python_environment") assert response.status_code == status.HTTP_200_OK assert response.json() == packages.model_dump() @@ -756,13 +1041,13 @@ def test_health_probe(client: TestClient): def test_logout( - mock_runner: Mock, + mock_runner_data: Mock, mock_authn_server, oidc_config: OIDCConfig, client_with_auth: TestClient, ): oidc_config.logout_redirect_endpoint = "/oauth2/sign_out/" - mock_runner.run.return_value = oidc_config + mock_runner_data.oidc_config = oidc_config client_with_auth.follow_redirects = False response = client_with_auth.get("/logout") assert response.status_code == status.HTTP_308_PERMANENT_REDIRECT @@ -785,16 +1070,16 @@ def test_docs_redirect( @pytest.mark.parametrize("has_oidc_config", [True, False]) def test_logout_when_oidc_config_invalid( has_oidc_config: bool, - mock_runner: Mock, + mock_runner_data: Mock, oidc_config: OIDCConfig, mock_authn_server, client_with_auth: TestClient, ): if has_oidc_config: oidc_config.logout_redirect_endpoint = "" - mock_runner.run.return_value = oidc_config + mock_runner_data.oidc_config = oidc_config else: - mock_runner.run.return_value = None + mock_runner_data.oidc_config = None response = client_with_auth.get("/logout") assert response.status_code == status.HTTP_205_RESET_CONTENT diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 30fe551c4..f3d3d37fd 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -337,6 +337,13 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "opa": { + "root": "http://opa.example.com/", + "audience": "account", + "tiled_service_account_check": "v1/tiled_service_account", + "submit_task_check": "v1/submit_task", + "admin_check": "v1/admin_check", + }, }, { "stomp": { @@ -392,6 +399,7 @@ def test_config_yaml_parsed(temp_yaml_config_file): } ], }, + "opa": None, }, ], indirect=True, diff --git a/uv.lock b/uv.lock index b2d10b5bd..eab94a867 100644 --- a/uv.lock +++ b/uv.lock @@ -420,6 +420,7 @@ name = "blueapi" source = { editable = "." } dependencies = [ { name = "aioca" }, + { name = "aiohttp" }, { name = "bluesky", extra = ["plotting"] }, { name = "bluesky-stomp" }, { name = "click" }, @@ -481,6 +482,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "aioca" }, + { name = "aiohttp", specifier = ">=3.13.5" }, { name = "bluesky", extras = ["plotting"], specifier = ">=1.14.0" }, { name = "bluesky-stomp", specifier = ">=0.1.6" }, { name = "click", specifier = ">=8.2.0" },