dara-core 1.13.1__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dara/core/auth/base.py CHANGED
@@ -18,7 +18,7 @@ limitations under the License.
18
18
  import abc
19
19
from typing import Any, ClassVar, Dict, Tuple, Union
20
20
 
21
- from fastapi import Response
21
+ from fastapi import HTTPException, Response
22
22
  from pydantic import BaseModel
23
23
  from typing_extensions import TypedDict
24
24
 
@@ -91,6 +91,18 @@ class BaseAuthConfig(BaseModel, abc.ABC):
91
91
  :param token: encoded token
92
92
  """
93
93
 
94
+ def refresh_token(self, old_token: TokenData, refresh_token: str) -> tuple[str, str]:
95
+ """
96
+ Create a new session token and refresh token from a refresh token.
97
+
98
+ Note: the new issued session token should include the same session_id as the old token
99
+
100
+ :param old_token: old session token data
101
+ :param refresh_token: encoded refresh token
102
+ :return: new session token, new refresh token
103
+ """
104
+ raise HTTPException(400, f'Auth config {self.__class__.__name__} does not support token refresh')
105
+
94
106
  def revoke_token(self, token: str, response: Response) -> Union[SuccessResponse, RedirectResponse]:
95
107
  """
96
108
  Revoke a session token.
dara/core/auth/routes.py CHANGED
@@ -16,9 +16,18 @@ limitations under the License.
16
16
  """
17
17
 
18
18
  from inspect import iscoroutinefunction
19
+ from typing import Union, cast
19
20
 
20
21
  import jwt
21
- from fastapi import APIRouter, Depends, HTTPException, Request, Response
22
+ from fastapi import (
23
+ APIRouter,
24
+ BackgroundTasks,
25
+ Cookie,
26
+ Depends,
27
+ HTTPException,
28
+ Request,
29
+ Response,
30
+ )
22
31
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
23
32
 
24
33
  from dara.core.auth.base import BaseAuthConfig
@@ -31,6 +40,7 @@ from dara.core.auth.definitions import (
31
40
  AuthError,
32
41
  SessionRequestBody,
33
42
  )
43
+ from dara.core.auth.utils import cached_refresh_token, decode_token
34
44
  from dara.core.logging import dev_logger
35
45
 
36
46
  auth_router = APIRouter()
@@ -103,6 +113,71 @@ async def _revoke_session(response: Response, credentials: HTTPAuthorizationCred
103
113
  raise HTTPException(status_code=400, detail=BAD_REQUEST_ERROR('No auth credentials passed'))
104
114
 
105
115
 
116
@auth_router.post('/refresh-token')
async def handle_refresh_token(
    response: Response,
    background_tasks: BackgroundTasks,
    dara_refresh_token: Union[str, None] = Cookie(default=None),
    credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
):
    """
    Given a refresh token, issues a new session token and refresh token cookie.

    The previous session token must be sent as a Bearer token. It is decoded with expiry
    verification disabled, so an expired session token can still be refreshed, and its data is
    passed to the auth config so the session_id can be preserved. If the refresh fails for any
    reason, the refresh token cookie is cleared.

    :param response: FastAPI response object, used to set or clear the refresh token cookie
    :param background_tasks: FastAPI background tasks (not used by this handler)
    :param dara_refresh_token: refresh token cookie
    :param credentials: bearer credentials carrying the previous session token
    :return: a dict with the new session token under the ``'token'`` key
    :raises HTTPException: 400 if no refresh token or a non-Bearer scheme is provided,
        401 (or the implementation's own HTTPException) if refreshing fails
    """
    if dara_refresh_token is None:
        raise HTTPException(status_code=400, detail=BAD_REQUEST_ERROR('No refresh token provided'))

    # Check scheme is correct
    if credentials.scheme != 'Bearer':
        raise HTTPException(
            status_code=400,
            detail=BAD_REQUEST_ERROR(
                'Invalid authentication scheme, previous Bearer token must be included in the refresh request'
            ),
        )

    from dara.core.internal.registries import auth_registry

    auth_config: BaseAuthConfig = auth_registry.get('auth_config')

    try:
        # decode the old token ignoring expiry date
        old_token_data = decode_token(credentials.credentials, options={'verify_exp': False})

        # Refresh logic up to implementation - passing in old token data so session_id can be preserved
        session_token, refresh_token = await cached_refresh_token(
            auth_config.refresh_token, old_token_data, dara_refresh_token
        )

        # Using 'Strict' as it is only used for the refresh-token endpoint so cross-site requests are not expected
        response.set_cookie(
            key='dara_refresh_token', value=refresh_token, secure=True, httponly=True, samesite='strict'
        )
        return {'token': session_token}
    except Exception as e:
        # Catch Exception rather than BaseException so task cancellation, KeyboardInterrupt and
        # SystemExit propagate unchanged instead of being converted into a 401 response.
        # For every other failure, clear the refresh token cookie.
        response.delete_cookie('dara_refresh_token')
        headers = {'set-cookie': response.headers['set-cookie']}

        # If an explicit HTTPException was raised, re-raise it with the cookie header merged into
        # any headers the implementation already attached
        if isinstance(e, HTTPException):
            dev_logger.error('Auth Error', error=e)
            e.headers = {**(e.headers or {}), **headers}
            raise

        # Explicitly handle expired signature error
        if isinstance(e, jwt.ExpiredSignatureError):
            dev_logger.error('Expired Token Signature', error=e)
            raise HTTPException(status_code=401, detail=EXPIRED_TOKEN_ERROR, headers=headers) from e

        # Otherwise show a generic invalid token error
        dev_logger.error('Invalid Token', error=e)
        raise HTTPException(status_code=401, detail=INVALID_TOKEN_ERROR, headers=headers) from e
179
+
180
+
106
181
  # Request to retrieve a session token from the backend. The app does this on startup.
107
182
  @auth_router.post('/session')
108
183
  async def _get_session(body: SessionRequestBody):
dara/core/auth/utils.py CHANGED
@@ -15,11 +15,13 @@ See the License for the specific language governing permissions and
15
15
  limitations under the License.
16
16
  """
17
17
 
18
+ import asyncio
18
19
  import uuid
19
20
  from datetime import datetime, timedelta, timezone
20
- from typing import List, Optional, Union
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
22
 
22
23
  import jwt
24
+ from anyio import to_thread
23
25
 
24
26
  from dara.core.auth.definitions import (
25
27
  EXPIRED_TOKEN_ERROR,
@@ -33,12 +35,15 @@ from dara.core.internal.settings import get_settings
33
35
  from dara.core.logging import dev_logger
34
36
 
35
37
 
36
- def decode_token(token: str) -> TokenData:
38
+ def decode_token(token: str, **kwargs) -> TokenData:
37
39
  """
38
40
  Decode a JWT token
41
+
42
+ :param token: the JWT token to decode
43
+ :param kwargs: additional arguments to pass to the jwt.decode function
39
44
  """
40
45
  try:
41
- return TokenData.parse_obj(jwt.decode(token, get_settings().jwt_secret, algorithms=[JWT_ALGO]))
46
+ return TokenData.parse_obj(jwt.decode(token, get_settings().jwt_secret, algorithms=[JWT_ALGO], **kwargs))
42
47
  except jwt.ExpiredSignatureError:
43
48
  raise AuthError(code=401, detail=EXPIRED_TOKEN_ERROR)
44
49
  except jwt.DecodeError:
@@ -52,11 +57,13 @@ def sign_jwt(
52
57
  groups: List[str],
53
58
  id_token: Optional[str] = None,
54
59
  exp: Optional[Union[datetime, int]] = None,
60
+ session_id: Optional[str] = None,
55
61
  ):
56
62
  """
57
63
  Create a new Dara JWT token
58
64
  """
59
- session_id = str(uuid.uuid4())
65
+ if session_id is None:
66
+ session_id = str(uuid.uuid4())
60
67
 
61
68
  # Default expiry is 1 day unless specified
62
69
  if exp is None:
@@ -95,3 +102,120 @@ def get_user_data():
95
102
  )
96
103
 
97
104
  return user_data
105
+
106
+
107
class AsyncTokenRefreshCache:
    """
    An asynchronous cache for token refresh operations that handles concurrent requests
    and provides time-based cache invalidation.

    This cache is designed to prevent multiple simultaneous refresh attempts with the
    same refresh token, while also providing a short-term cache to reduce unnecessary
    token refreshes from multiple tabs/windows.

    The per-key locks are ``asyncio.Lock`` instances, so the cache coordinates coroutines
    running on one event loop; it is not designed to be shared between threads.
    """

    def __init__(self, ttl_seconds: int = 5):
        # key -> (cached value, time the value was stored)
        self.cache: Dict[str, Tuple[Any, datetime]] = {}
        # key -> lock serialising refresh attempts for that key
        self.locks: Dict[str, asyncio.Lock] = {}
        # guards creation of new entries in self.locks
        self.locks_lock = asyncio.Lock()
        self.ttl = timedelta(seconds=ttl_seconds)

    async def _get_or_create_lock(self, key: str) -> asyncio.Lock:
        """
        Get an existing lock for the given key or create a new one if it doesn't exist.

        Creation is guarded by ``locks_lock`` so that coroutines on the same event loop
        receive the same lock object for a given key.

        :param key: The key to get or create a lock for.
        :return: The lock associated with ``key``.
        """
        async with self.locks_lock:
            if key not in self.locks:
                self.locks[key] = asyncio.Lock()
            return self.locks[key]

    def _cleanup_old_entries(self):
        """
        Remove expired cache entries and locks that are no longer needed.

        This method is called before each cache access to prevent memory leaks from
        accumulated entries. It contains no ``await``, so it runs without interleaving
        with other coroutines on the event loop.
        """
        current_time = datetime.now()
        expired_keys = [key for key, (_, timestamp) in self.cache.items() if current_time - timestamp > self.ttl]
        for key in expired_keys:
            self.cache.pop(key, None)
            self.locks.pop(key, None)

        # A lock whose key never received a cached value (e.g. the refresh raised) would otherwise
        # stay in self.locks forever; drop such locks as long as nobody currently holds them.
        orphaned_keys = [key for key, lock in self.locks.items() if key not in self.cache and not lock.locked()]
        for key in orphaned_keys:
            self.locks.pop(key, None)

    def get_cached_value(self, key: str) -> Tuple[Any, bool]:
        """
        Retrieve a value from the cache if it exists and hasn't expired.

        :param key: The key to retrieve from the cache.
        :return: A tuple containing the value and a boolean indicating whether the value was found.
        """
        self._cleanup_old_entries()
        if key in self.cache:
            value, timestamp = self.cache[key]
            if datetime.now() - timestamp <= self.ttl:
                return value, True
        return None, False

    def set_cached_value(self, key: str, value: Any):
        """
        Set a value in the cache with the current timestamp.

        :param key: The key to set in the cache.
        :param value: The value to set in the cache.
        """
        self.cache[key] = (value, datetime.now())

    def clear(self):
        """
        Clear the cache and locks dictionaries.
        """
        self.cache.clear()
        self.locks.clear()
180
+
181
+
182
#: Shared token refresh cache instance (5 second TTL), used by ``cached_refresh_token``
token_refresh_cache = AsyncTokenRefreshCache(ttl_seconds=5)
186
+
187
+
188
async def cached_refresh_token(
    do_refresh_token: Callable[[TokenData, str], Tuple[str, str]], old_token_data: TokenData, refresh_token: str
):
    """
    Run a token refresh at most once per refresh token within the shared cache's TTL.

    Concurrent callers presenting the same refresh token are serialised on a per-token lock,
    and callers arriving shortly after a successful refresh reuse its stored result (e.g. several
    tabs/windows refreshing at the same time). The synchronous refresh function is executed in a
    worker thread via ``anyio.to_thread.run_sync``.

    :param do_refresh_token: The function to perform the token refresh
    :param old_token_data: The old token data
    :param refresh_token: The refresh token to use
    :return: the (fresh or cached) result of ``do_refresh_token``
    """
    key = refresh_token

    # Fast path: a recent refresh with this token already produced a result
    value, hit = token_refresh_cache.get_cached_value(key)
    if hit:
        return value

    per_token_lock = await token_refresh_cache._get_or_create_lock(key)
    async with per_token_lock:
        # Another caller may have completed the refresh while this one waited for the lock
        value, hit = token_refresh_cache.get_cached_value(key)
        if not hit:
            value = await to_thread.run_sync(do_refresh_token, old_token_data, refresh_token)
            token_refresh_cache.set_cached_value(key, value)
    return value
@@ -208,9 +208,11 @@ def create_router(config: Configuration):
208
208
  'application_name': get_settings().project_name,
209
209
  }
210
210
 
211
- @core_api_router.get('/auth-components')
212
- async def get_auth_components(): # pylint: disable=unused-variable
213
- return config.auth_config.component_config
211
+ @core_api_router.get('/auth-config')
212
+ async def get_auth_config(): # pylint: disable=unused-variable
213
+ return {
214
+ 'auth_components': config.auth_config.component_config.dict(),
215
+ }
214
216
 
215
217
  @core_api_router.get('/components', dependencies=[Depends(verify_session)])
216
218
  async def get_components(name: Optional[str] = None): # pylint: disable=unused-variable
@@ -167,9 +167,10 @@ class WebSocketHandler:
167
167
 
168
168
  def __init__(self, channel_id: str):
169
169
  send_stream, receive_stream = create_memory_object_stream[ServerMessage](math.inf)
170
- self.channel_id = channel_id
171
- self.send_stream = send_stream
172
170
  self.receive_stream = receive_stream
171
+ self.send_stream = send_stream
172
+
173
+ self.channel_id = channel_id
173
174
  self.pending_responses = {}
174
175
 
175
176
  async def send_message(self, message: ServerMessage):
@@ -446,17 +447,20 @@ async def ws_handler(websocket: WebSocket, token: Optional[str] = Query(default=
446
447
  else:
447
448
  sessions_registry.set(user_identifier, {token_content.session_id})
448
449
 
449
- # Set Auth context vars for the WS connection
450
- USER.set(
451
- UserData(
452
- identity_id=token_content.identity_id,
453
- identity_name=token_content.identity_name,
454
- identity_email=token_content.identity_email,
455
- groups=token_content.groups,
450
+ def update_context(token_data: TokenData):
451
+ USER.set(
452
+ UserData(
453
+ identity_id=token_data.identity_id,
454
+ identity_name=token_data.identity_name,
455
+ identity_email=token_data.identity_email,
456
+ groups=token_data.groups,
457
+ )
456
458
  )
457
- )
458
- SESSION_ID.set(token_content.session_id)
459
- ID_TOKEN.set(token_content.id_token)
459
+ SESSION_ID.set(token_data.session_id)
460
+ ID_TOKEN.set(token_data.id_token)
461
+
462
+ # Set initial Auth context vars for the WS connection
463
+ update_context(token_content)
460
464
 
461
465
  # Change protocol from http to ws - from this point exceptions can't be raised
462
466
  await websocket.accept()
@@ -491,6 +495,12 @@ async def ws_handler(websocket: WebSocket, token: Optional[str] = Query(default=
491
495
  # Heartbeat to keep connection alive
492
496
  if data['type'] == 'ping':
493
497
  await websocket.send_json({'type': 'pong', 'message': None})
498
+ elif data['type'] == 'token_update':
499
+ try:
500
+ # update Auth context vars for the WS connection
501
+ update_context(decode_token(data['message']))
502
+ except Exception as e:
503
+ eng_logger.error('Error updating token data', error=e)
494
504
  else:
495
505
  try:
496
506
  parsed_data = parse_obj_as(ClientMessage, data)