diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index 27eac5a..de00961 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -1,41 +1,29 @@ import asyncio import json -import math import random -import sys from datetime import datetime, timedelta import urllib3 from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import ( + TOKEN_EXPIRY_JITTER_IN_SEC, + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, + USER_AGENT, +) from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError +from openfga_sdk.oauth2_common import _TokenState, jitter from openfga_sdk.telemetry.attributes import TelemetryAttributes from openfga_sdk.telemetry.telemetry import Telemetry -def jitter(loop_count, min_wait_in_ms): - """ - Generate a random jitter value for exponential backoff - """ - minimum = math.ceil(2**loop_count * min_wait_in_ms) - maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) - jitter = random.randrange(minimum, maximum) / 1000 - - # If running in pytest, set jitter to 0 to speed up tests - if "pytest" in sys.modules: - jitter = 0 - - return jitter - - class OAuth2Client: def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials - self._access_token = None - self._access_expiry_time = None + self._token_state: _TokenState | None = None + self._lock = asyncio.Lock() self._telemetry = Telemetry() if configuration is None: @@ -45,13 +33,13 @@ def __init__(self, credentials: Credentials, configuration=None): def _token_valid(self): """ - Return whether token is valid + Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens) """ - if self._access_token is None or self._access_expiry_time is None: - return False - if self._access_expiry_time < datetime.now(): + state = self._token_state # atomic snapshot — either old or new, never torn + if state is None: return False - return True + remaining = (state.expiry_time - datetime.now()).total_seconds() + return remaining > state.expiry_buffer async def _obtain_token(self, client): """ @@ -76,7 +64,9 @@ async def _obtain_token(self, client): # Add scope parameter if scopes are configured if configuration.scopes is not None: if isinstance(configuration.scopes, list): - scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip()) + scope_str = " ".join( + s.strip() for s in configuration.scopes if s and s.strip() + ) else: scope_str = ( configuration.scopes.strip() @@ -136,10 +126,15 @@ async def _obtain_token(self, client): raise AuthenticationError(http_resp=raw_response) if api_response.get("expires_in") and api_response.get("access_token"): - self._access_expiry_time = datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) + self._token_state = _TokenState( + access_token=api_response.get("access_token"), + expiry_time=datetime.now() + + timedelta(seconds=int(api_response.get("expires_in"))), + expiry_buffer=( + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + ), ) - self._access_token = api_response.get("access_token") self._telemetry.metrics.credentialsRequest( attributes={ TelemetryAttributes.fga_client_request_client_id: configuration.client_id @@ -154,8 +149,8 @@ async def get_authentication_header(self, client): """ If configured, return the header for authentication """ - # check to see token is valid if not self._token_valid(): - # In this case, the token is not valid, we need to get the refresh the token - await self._obtain_token(client) - return {"Authorization": f"Bearer {self._access_token}"} + async with self._lock: + if not self._token_valid(): + await self._obtain_token(client) + return {"Authorization": f"Bearer {self._token_state.access_token}"} diff --git a/openfga_sdk/oauth2_common.py b/openfga_sdk/oauth2_common.py new file mode 100644 index 0000000..71562a6 --- /dev/null +++ b/openfga_sdk/oauth2_common.py @@ -0,0 +1,28 @@ +import math +import random +import sys + +from dataclasses import dataclass +from datetime import datetime + + +@dataclass(frozen=True) +class _TokenState: + access_token: str + expiry_time: datetime + expiry_buffer: float + + +def jitter(loop_count, min_wait_in_ms): + """ + Generate a random jitter value for exponential backoff + """ + minimum = math.ceil(2**loop_count * min_wait_in_ms) + maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) + jitter = random.randrange(minimum, maximum) / 1000 + + # If running in pytest, set jitter to 0 to speed up tests + if "pytest" in sys.modules: + jitter = 0 + + return jitter diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index 0f5bc09..d23cc93 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -1,7 +1,6 @@ import json -import math import random -import sys +import threading import time from datetime import datetime, timedelta @@ -9,33 +8,23 @@ import urllib3 from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import ( + TOKEN_EXPIRY_JITTER_IN_SEC, + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, + USER_AGENT, +) from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError +from openfga_sdk.oauth2_common import _TokenState, jitter from openfga_sdk.telemetry.attributes import TelemetryAttributes from openfga_sdk.telemetry.telemetry import Telemetry -def jitter(loop_count, min_wait_in_ms): - """ - Generate a random jitter value for exponential backoff - """ - minimum = math.ceil(2**loop_count * min_wait_in_ms) - maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) - jitter = random.randrange(minimum, maximum) / 1000 - - # If running in pytest, set jitter to 0 to speed up tests - if "pytest" in sys.modules: - jitter = 0 - - return jitter - - class OAuth2Client: def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials - self._access_token = None - self._access_expiry_time = None + self._token_state: _TokenState | None = None + self._lock = threading.Lock() self._telemetry = Telemetry() if configuration is None: @@ -45,13 +34,13 @@ def __init__(self, credentials: Credentials, configuration=None): def _token_valid(self): """ - Return whether token is valid + Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens) """ - if self._access_token is None or self._access_expiry_time is None: - return False - if self._access_expiry_time < datetime.now(): + state = self._token_state # atomic snapshot — either old or new, never torn + if state is None: return False - return True + remaining = (state.expiry_time - datetime.now()).total_seconds() + return remaining > state.expiry_buffer def _obtain_token(self, client): """ @@ -76,7 +65,9 @@ def _obtain_token(self, client): # Add scope parameter if scopes are configured if configuration.scopes is not None: if isinstance(configuration.scopes, list): - scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip()) + scope_str = " ".join( + s.strip() for s in configuration.scopes if s and s.strip() + ) else: scope_str = ( configuration.scopes.strip() @@ -136,10 +127,15 @@ def _obtain_token(self, client): raise AuthenticationError(http_resp=raw_response) if api_response.get("expires_in") and api_response.get("access_token"): - self._access_expiry_time = datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) + self._token_state = _TokenState( + access_token=api_response.get("access_token"), + expiry_time=datetime.now() + + timedelta(seconds=int(api_response.get("expires_in"))), + expiry_buffer=( + TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC + + random.random() * TOKEN_EXPIRY_JITTER_IN_SEC + ), ) - self._access_token = api_response.get("access_token") self._telemetry.metrics.credentialsRequest( attributes={ TelemetryAttributes.fga_client_request_client_id: configuration.client_id @@ -154,8 +150,8 @@ def get_authentication_header(self, client): """ If configured, return the header for authentication """ - # check to see token is valid if not self._token_valid(): - # In this case, the token is not valid, we need to get the refresh the token - self._obtain_token(client) - return {"Authorization": f"Bearer {self._access_token}"} + with self._lock: + if not self._token_valid(): + self._obtain_token(client) + return {"Authorization": f"Bearer {self._token_state.access_token}"} diff --git a/test/oauth2_test.py b/test/oauth2_test.py index 48b5030..ffa3e5a 100644 --- a/test/oauth2_test.py +++ b/test/oauth2_test.py @@ -1,3 +1,5 @@ +import asyncio + from datetime import datetime, timedelta from unittest import IsolatedAsyncioTestCase from unittest.mock import patch @@ -6,10 +8,11 @@ from openfga_sdk import rest from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError from openfga_sdk.oauth2 import OAuth2Client +from openfga_sdk.oauth2_common import _TokenState # Helper function to construct mock response @@ -33,8 +36,11 @@ async def test_get_authentication_valid_client_credentials(self): Test getting authentication header when method is client credentials """ client = OAuth2Client(None) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() + timedelta(seconds=60) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=0, + ) auth_header = await client.get_authentication_header(None) self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"}) @@ -65,9 +71,9 @@ async def test_get_authentication_obtain_client_credentials(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -149,8 +155,11 @@ async def test_get_authentication_obtain_with_expired_client_credentials_failed( rest_client = rest.RESTClientObject(Configuration()) client = OAuth2Client(credentials) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() - timedelta(seconds=240) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() - timedelta(seconds=240), + expiry_buffer=0, + ) with self.assertRaises(AuthenticationError): await client.get_authentication_header(rest_client) @@ -291,9 +300,9 @@ async def test_get_authentication_keep_full_url(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -346,9 +355,9 @@ async def test_get_authentication_add_scheme(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -401,9 +410,9 @@ async def test_get_authentication_add_path(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -456,9 +465,9 @@ async def test_get_authentication_add_scheme_and_path(self, mock_request): client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -514,9 +523,9 @@ async def test_get_authentication_obtain_client_credentials_with_scopes_list( client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -573,9 +582,9 @@ async def test_get_authentication_obtain_client_credentials_with_scopes_string( client = OAuth2Client(credentials) auth_header = await client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -651,6 +660,87 @@ async def test_get_authentication_without_audience(self, mock_request): ) await rest_client.close() + @patch.object(rest.RESTClientObject, "request") + @patch("openfga_sdk.oauth2.random") + async def test_get_authentication_refreshes_near_expiry_token( + self, mock_random, mock_request + ): + """ + Token close to expiry (within buffer window) should trigger a proactive refresh + """ + mock_random.random.return_value = 0 + short_lived_secs = max(1, TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - 1) + + mock_request.side_effect = [ + mock_response( + f'{{"expires_in": {short_lived_secs}, "access_token": "short-lived-token"}}', + 200, + ), + mock_response( + '{"expires_in": 3600, "access_token": "refreshed-token"}', + 200, + ), + ] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + header1 = await client.get_authentication_header(rest_client) + header2 = await client.get_authentication_header(rest_client) + + self.assertEqual(header1, {"Authorization": "Bearer short-lived-token"}) + self.assertEqual(header2, {"Authorization": "Bearer refreshed-token"}) + self.assertEqual(mock_request.call_count, 2) + + await rest_client.close() + + async def test_concurrent_requests_only_fetch_token_once(self): + """ + Multiple concurrent requests while the token is invalid should result in + only one token fetch — subsequent coroutines wait on the lock and reuse + the token obtained by the first. + """ + obtain_calls = [] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + oauth_client = OAuth2Client(credentials) + + async def mock_obtain_token(client): + obtain_calls.append(1) + await asyncio.sleep(0) # yield so other coroutines reach the lock + oauth_client._token_state = _TokenState( + access_token="concurrent-token", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=300, + ) + + with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token): + results = await asyncio.gather( + *[oauth_client.get_authentication_header(None) for _ in range(5)] + ) + + self.assertEqual(len(obtain_calls), 1) + self.assertTrue( + all(r == {"Authorization": "Bearer concurrent-token"} for r in results) + ) + @patch.object(rest.RESTClientObject, "request") async def test_get_authentication_with_scopes_no_audience(self, mock_request): """ diff --git a/test/sync/oauth2_test.py b/test/sync/oauth2_test.py index d0dc387..40109fa 100644 --- a/test/sync/oauth2_test.py +++ b/test/sync/oauth2_test.py @@ -1,3 +1,6 @@ +import threading +import time + from datetime import datetime, timedelta from unittest import IsolatedAsyncioTestCase from unittest.mock import patch @@ -5,9 +8,10 @@ import urllib3 from openfga_sdk.configuration import Configuration -from openfga_sdk.constants import USER_AGENT +from openfga_sdk.constants import TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC, USER_AGENT from openfga_sdk.credentials import CredentialConfiguration, Credentials from openfga_sdk.exceptions import AuthenticationError +from openfga_sdk.oauth2_common import _TokenState from openfga_sdk.sync import rest from openfga_sdk.sync.oauth2 import OAuth2Client @@ -33,8 +37,11 @@ def test_get_authentication_valid_client_credentials(self): Test getting authentication header when method is client credentials """ client = OAuth2Client(None) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() + timedelta(seconds=60) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=0, + ) auth_header = client.get_authentication_header(None) self.assertEqual(auth_header, {"Authorization": "Bearer XYZ123"}) @@ -65,9 +72,9 @@ def test_get_authentication_obtain_client_credentials(self, mock_request): client = OAuth2Client(credentials) auth_header = client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -123,9 +130,9 @@ def test_get_authentication_obtain_client_credentials_with_scopes_list( client = OAuth2Client(credentials) auth_header = client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -182,9 +189,9 @@ def test_get_authentication_obtain_client_credentials_with_scopes_string( client = OAuth2Client(credentials) auth_header = client.get_authentication_header(rest_client) self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) - self.assertEqual(client._access_token, "AABBCCDD") + self.assertEqual(client._token_state.access_token, "AABBCCDD") self.assertGreaterEqual( - client._access_expiry_time, current_time + timedelta(seconds=120) + client._token_state.expiry_time, current_time + timedelta(seconds=120) ) expected_header = urllib3.response.HTTPHeaderDict( { @@ -265,8 +272,11 @@ async def test_get_authentication_obtain_with_expired_client_credentials_failed( rest_client = rest.RESTClientObject(Configuration()) client = OAuth2Client(credentials) - client._access_token = "XYZ123" - client._access_expiry_time = datetime.now() - timedelta(seconds=240) + client._token_state = _TokenState( + access_token="XYZ123", + expiry_time=datetime.now() - timedelta(seconds=240), + expiry_buffer=0, + ) with self.assertRaises(AuthenticationError): client.get_authentication_header(rest_client) @@ -427,6 +437,95 @@ def test_get_authentication_without_audience(self, mock_request): ) rest_client.close() + @patch.object(rest.RESTClientObject, "request") + @patch("openfga_sdk.sync.oauth2.random") + def test_get_authentication_refreshes_near_expiry_token( + self, mock_random, mock_request + ): + """ + Token close to expiry (within buffer window) should trigger a proactive refresh + """ + mock_random.random.return_value = 0 + short_lived_secs = max(1, TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC - 1) + + mock_request.side_effect = [ + mock_response( + f'{{"expires_in": {short_lived_secs}, "access_token": "short-lived-token"}}', + 200, + ), + mock_response( + '{"expires_in": 3600, "access_token": "refreshed-token"}', + 200, + ), + ] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + header1 = client.get_authentication_header(rest_client) + header2 = client.get_authentication_header(rest_client) + + self.assertEqual(header1, {"Authorization": "Bearer short-lived-token"}) + self.assertEqual(header2, {"Authorization": "Bearer refreshed-token"}) + self.assertEqual(mock_request.call_count, 2) + + rest_client.close() + + def test_concurrent_requests_only_fetch_token_once(self): + """ + Multiple concurrent threads while the token is invalid should result in + only one token fetch — subsequent threads wait on the lock and reuse + the token obtained by the first. + """ + obtain_calls = [] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + oauth_client = OAuth2Client(credentials) + + def mock_obtain_token(client): + obtain_calls.append(1) + time.sleep(0.05) # hold the lock briefly so other threads queue up + oauth_client._token_state = _TokenState( + access_token="concurrent-token", + expiry_time=datetime.now() + timedelta(seconds=3600), + expiry_buffer=300, + ) + + results = [] + + def call(): + results.append(oauth_client.get_authentication_header(None)) + + with patch.object(oauth_client, "_obtain_token", side_effect=mock_obtain_token): + threads = [threading.Thread(target=call) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(results), 5) + self.assertEqual(len(obtain_calls), 1) + self.assertTrue( + all(r == {"Authorization": "Bearer concurrent-token"} for r in results) + ) + @patch.object(rest.RESTClientObject, "request") def test_get_authentication_with_scopes_no_audience(self, mock_request): """ @@ -477,4 +576,3 @@ def test_get_authentication_with_scopes_no_audience(self, mock_request): }, ) rest_client.close() -