Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 28 additions & 33 deletions openfga_sdk/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,29 @@
import asyncio
import json
import math
import random
import sys

from datetime import datetime, timedelta

import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.constants import USER_AGENT
from openfga_sdk.constants import (
TOKEN_EXPIRY_JITTER_IN_SEC,
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC,
USER_AGENT,
)
from openfga_sdk.credentials import Credentials
from openfga_sdk.exceptions import AuthenticationError
from openfga_sdk.oauth2_common import _TokenState, jitter
from openfga_sdk.telemetry.attributes import TelemetryAttributes
from openfga_sdk.telemetry.telemetry import Telemetry


def jitter(loop_count, min_wait_in_ms):
"""
Generate a random jitter value for exponential backoff
"""
minimum = math.ceil(2**loop_count * min_wait_in_ms)
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
jitter = random.randrange(minimum, maximum) / 1000

# If running in pytest, set jitter to 0 to speed up tests
if "pytest" in sys.modules:
jitter = 0

return jitter


class OAuth2Client:
def __init__(self, credentials: Credentials, configuration=None):
self._credentials = credentials
self._access_token = None
self._access_expiry_time = None
self._token_state: _TokenState | None = None
self._lock = asyncio.Lock()
self._telemetry = Telemetry()

if configuration is None:
Expand All @@ -45,13 +33,13 @@ def __init__(self, credentials: Credentials, configuration=None):

def _token_valid(self):
"""
Return whether token is valid
Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens)
"""
if self._access_token is None or self._access_expiry_time is None:
return False
if self._access_expiry_time < datetime.now():
state = self._token_state # atomic snapshot — either old or new, never torn
if state is None:
return False
return True
remaining = (state.expiry_time - datetime.now()).total_seconds()
return remaining > state.expiry_buffer

async def _obtain_token(self, client):
"""
Expand All @@ -76,7 +64,9 @@ async def _obtain_token(self, client):
# Add scope parameter if scopes are configured
if configuration.scopes is not None:
if isinstance(configuration.scopes, list):
scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip())
scope_str = " ".join(
s.strip() for s in configuration.scopes if s and s.strip()
)
else:
scope_str = (
configuration.scopes.strip()
Expand Down Expand Up @@ -136,10 +126,15 @@ async def _obtain_token(self, client):
raise AuthenticationError(http_resp=raw_response)

if api_response.get("expires_in") and api_response.get("access_token"):
self._access_expiry_time = datetime.now() + timedelta(
seconds=int(api_response.get("expires_in"))
self._token_state = _TokenState(
access_token=api_response.get("access_token"),
expiry_time=datetime.now()
+ timedelta(seconds=int(api_response.get("expires_in"))),
expiry_buffer=(
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC
+ random.random() * TOKEN_EXPIRY_JITTER_IN_SEC
),
)
self._access_token = api_response.get("access_token")
self._telemetry.metrics.credentialsRequest(
attributes={
TelemetryAttributes.fga_client_request_client_id: configuration.client_id
Expand All @@ -154,8 +149,8 @@ async def get_authentication_header(self, client):
"""
If configured, return the header for authentication
"""
# check to see token is valid
if not self._token_valid():
# In this case, the token is not valid, we need to get the refresh the token
await self._obtain_token(client)
return {"Authorization": f"Bearer {self._access_token}"}
async with self._lock:
if not self._token_valid():
await self._obtain_token(client)
return {"Authorization": f"Bearer {self._token_state.access_token}"}
28 changes: 28 additions & 0 deletions openfga_sdk/oauth2_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import math
import random
import sys

from dataclasses import dataclass
from datetime import datetime


@dataclass(frozen=True)
class _TokenState:
access_token: str
expiry_time: datetime
expiry_buffer: float


def jitter(loop_count, min_wait_in_ms):
"""
Generate a random jitter value for exponential backoff
"""
minimum = math.ceil(2**loop_count * min_wait_in_ms)
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
jitter = random.randrange(minimum, maximum) / 1000

# If running in pytest, set jitter to 0 to speed up tests
if "pytest" in sys.modules:
jitter = 0

return jitter
62 changes: 29 additions & 33 deletions openfga_sdk/sync/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,30 @@
import json
import math
import random
import sys
import threading
import time

from datetime import datetime, timedelta

import urllib3

from openfga_sdk.configuration import Configuration
from openfga_sdk.constants import USER_AGENT
from openfga_sdk.constants import (
TOKEN_EXPIRY_JITTER_IN_SEC,
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC,
USER_AGENT,
)
from openfga_sdk.credentials import Credentials
from openfga_sdk.exceptions import AuthenticationError
from openfga_sdk.oauth2_common import _TokenState, jitter
from openfga_sdk.telemetry.attributes import TelemetryAttributes
from openfga_sdk.telemetry.telemetry import Telemetry


def jitter(loop_count, min_wait_in_ms):
"""
Generate a random jitter value for exponential backoff
"""
minimum = math.ceil(2**loop_count * min_wait_in_ms)
maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms)
jitter = random.randrange(minimum, maximum) / 1000

# If running in pytest, set jitter to 0 to speed up tests
if "pytest" in sys.modules:
jitter = 0

return jitter


class OAuth2Client:
def __init__(self, credentials: Credentials, configuration=None):
self._credentials = credentials
self._access_token = None
self._access_expiry_time = None
self._token_state: _TokenState | None = None
self._lock = threading.Lock()
self._telemetry = Telemetry()

if configuration is None:
Expand All @@ -45,13 +34,13 @@ def __init__(self, credentials: Credentials, configuration=None):

def _token_valid(self):
"""
Return whether token is valid
Return whether token is valid (with proactive expiry buffer to avoid using near-expired tokens)
"""
if self._access_token is None or self._access_expiry_time is None:
return False
if self._access_expiry_time < datetime.now():
state = self._token_state # atomic snapshot — either old or new, never torn
if state is None:
return False
return True
remaining = (state.expiry_time - datetime.now()).total_seconds()
return remaining > state.expiry_buffer

def _obtain_token(self, client):
"""
Expand All @@ -76,7 +65,9 @@ def _obtain_token(self, client):
# Add scope parameter if scopes are configured
if configuration.scopes is not None:
if isinstance(configuration.scopes, list):
scope_str = " ".join(s.strip() for s in configuration.scopes if s and s.strip())
scope_str = " ".join(
s.strip() for s in configuration.scopes if s and s.strip()
)
else:
scope_str = (
configuration.scopes.strip()
Expand Down Expand Up @@ -136,10 +127,15 @@ def _obtain_token(self, client):
raise AuthenticationError(http_resp=raw_response)

if api_response.get("expires_in") and api_response.get("access_token"):
self._access_expiry_time = datetime.now() + timedelta(
seconds=int(api_response.get("expires_in"))
self._token_state = _TokenState(
access_token=api_response.get("access_token"),
expiry_time=datetime.now()
+ timedelta(seconds=int(api_response.get("expires_in"))),
expiry_buffer=(
TOKEN_EXPIRY_THRESHOLD_BUFFER_IN_SEC
+ random.random() * TOKEN_EXPIRY_JITTER_IN_SEC
),
)
self._access_token = api_response.get("access_token")
self._telemetry.metrics.credentialsRequest(
attributes={
TelemetryAttributes.fga_client_request_client_id: configuration.client_id
Expand All @@ -154,8 +150,8 @@ def get_authentication_header(self, client):
"""
If configured, return the header for authentication
"""
# check to see token is valid
if not self._token_valid():
# In this case, the token is not valid, we need to get the refresh the token
self._obtain_token(client)
return {"Authorization": f"Bearer {self._access_token}"}
with self._lock:
if not self._token_valid():
self._obtain_token(client)
return {"Authorization": f"Bearer {self._token_state.access_token}"}
Loading
Loading