dara-core 1.22.4__py3-none-any.whl → 1.23.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,523 @@
1
+ from typing import ClassVar
2
+ from urllib.parse import urlencode
3
+
4
+ import httpx
5
+ import jwt
6
+ from fastapi import HTTPException, Response
7
+ from jwt import PyJWKClient
8
+ from pydantic import Field
9
+ from pydantic.config import ConfigDict
10
+
11
+ from dara.core.definitions import ApiRoute
12
+ from dara.core.internal.settings import get_settings
13
+ from dara.core.logging import dev_logger
14
+
15
+ from ..base import AuthComponent, AuthComponentConfig, BaseAuthConfig
16
+ from ..definitions import (
17
+ ID_TOKEN,
18
+ INVALID_TOKEN_ERROR,
19
+ JWT_ALGO,
20
+ SESSION_ID,
21
+ UNAUTHORIZED_ERROR,
22
+ USER,
23
+ AuthError,
24
+ RedirectResponse,
25
+ SessionRequestBody,
26
+ SuccessResponse,
27
+ TokenData,
28
+ TokenResponse,
29
+ UserData,
30
+ UserGroup,
31
+ )
32
+ from ..utils import decode_token, sign_jwt
33
+ from .definitions import (
34
+ JWK_CLIENT_REGISTRY_KEY,
35
+ REFRESH_TOKEN_COOKIE_NAME,
36
+ IdTokenClaims,
37
+ OIDCDiscoveryMetadata,
38
+ StateObject,
39
+ )
40
+ from .routes import sso_callback
41
+ from .settings import get_oidc_settings
42
+ from .utils import decode_id_token, get_token_from_idp
43
+
44
# Front-end auth components wired into OIDCAuthConfig.component_config below
# (login, logout and the 'sso-callback' extra). All three are exported by the
# '@darajs/core' JS package and the 'dara.core' Python package.
OIDCAuthLogin = AuthComponent(
    js_name='OIDCAuthLogin',
    js_module='@darajs/core',
    py_module='dara.core',
)

OIDCAuthLogout = AuthComponent(
    js_name='OIDCAuthLogout',
    js_module='@darajs/core',
    py_module='dara.core',
)

OIDCAuthSSOCallback = AuthComponent(
    js_name='OIDCAuthSSOCallback',
    js_module='@darajs/core',
    py_module='dara.core',
)
49
+
50
+
51
class OIDCAuthConfig(BaseAuthConfig):
    """
    Generic OIDC auth config.

    This config requires the following ENV variables to be set:
    - SSO_ISSUER_URL - URL of the identity provider issuer; should expose a `SSO_ISSUER_URL/.well-known/openid-configuration` endpoint for discovery
    - SSO_CLIENT_ID - client_id generated for the application by the identity provider
    - SSO_CLIENT_SECRET - client_secret generated for the application by the identity provider
    - SSO_REDIRECT_URI - URL that identity provider should redirect back to, in most cases https://deployed-app-url/sso-callback
    - SSO_GROUPS - comma separated list of allowed SSO groups

    In addition, the following ENV variables can be set:
    - SSO_ALLOWED_IDENTITY_ID - if set, only the user with matching identity_id will be allowed to access the app
    - SSO_VERIFY_AUDIENCE - if set, the ID token will be verified against the configured audience, by default `sso_client_id`
    - SSO_EXTRA_AUDIENCE - if set, extra audiences to verify against the ID token in addition to `sso_client_id`
    - SSO_SCOPES - space-separated list of scopes to request from the identity provider, defaults to `openid`
    - SSO_JWT_ALGO - algorithm to use for verifying IDP-provided JWTs, defaults to `ES256`
    """

    # NOTE: the config follows OIDC specification, but makes a few concessions
    # to be more lenient with the internal IDP. These are marked with CONCESSION comments.

    # Routes that must be registered for this auth config to work (the SSO callback handler)
    required_routes: ClassVar[list[ApiRoute]] = [sso_callback]

    # Front-end components used for login, logout and the SSO callback page
    component_config: ClassVar[AuthComponentConfig] = AuthComponentConfig(
        login=OIDCAuthLogin,
        logout=OIDCAuthLogout,
        extra={
            'sso-callback': OIDCAuthSSOCallback,
        },
    )

    # Shared HTTP client; excluded from serialization. It is entered in startup_hook.
    client: httpx.AsyncClient = Field(default_factory=httpx.AsyncClient, exclude=True)

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Populated during startup_hook
    _discovery: OIDCDiscoveryMetadata | None = None

    @property
    def discovery(self) -> OIDCDiscoveryMetadata:
        """Get the OIDC discovery metadata. Raises if not initialized."""
        if self._discovery is None:
            raise RuntimeError('OIDC discovery metadata not initialized. Ensure startup_hook has been called.')
        return self._discovery

    @property
    def allowed_groups(self) -> dict[str, UserGroup]:
        """
        Allowed user groups, parsed from the comma-separated SSO_GROUPS setting.

        Entries are whitespace-stripped. Blank entries, such as those produced by a
        trailing comma or an empty setting, are ignored.

        :return: Mapping of group name to UserGroup
        """
        env_groups = get_oidc_settings().groups
        names = (group.strip() for group in env_groups.split(','))
        # Skip blanks so '' never becomes an "allowed" group
        return {name: UserGroup(name=name) for name in names if name}

    def get_discovery_url(self) -> str:
        """
        Build the OIDC discovery document URL from the configured issuer.

        Per OpenID Connect Discovery 1.0 section 4, any terminating '/' on the issuer
        is removed before appending '/.well-known/openid-configuration'. This avoids
        a '//' in the URL for issuers such as 'https://tenant.example.com/'.

        :return: Discovery document URL
        """
        issuer_url = str(get_oidc_settings().issuer_url).rstrip('/')
        return f'{issuer_url}/.well-known/openid-configuration'

    async def startup_hook(self) -> None:
        """
        Initialise the config at app startup.

        Opens the HTTP client, validates settings, fetches and parses the OIDC
        discovery document, and registers a PyJWKClient for the discovered jwks_uri.

        :raises RuntimeError: If the discovery document cannot be fetched or parsed
        """
        await self.client.__aenter__()

        # 1. Enforce SSO env vars are set - this will run validation and raise if not set
        get_settings.cache_clear()
        get_settings()
        get_oidc_settings.cache_clear()
        oidc_settings = get_oidc_settings()

        # 2. Fetch OIDC discovery document
        discovery_url = self.get_discovery_url()
        dev_logger.info(f'Fetching OIDC discovery document from {discovery_url}...')
        try:
            response = await self.client.get(discovery_url)
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f'Failed to fetch OIDC discovery document from {discovery_url}: HTTP {e.response.status_code}'
            ) from e
        except httpx.RequestError as e:
            raise RuntimeError(f'Failed to fetch OIDC discovery document from {discovery_url}: {e}') from e

        try:
            self._discovery = OIDCDiscoveryMetadata.model_validate(response.json())
        except Exception as e:
            raise RuntimeError(f'Failed to parse OIDC discovery document from {discovery_url}: {e}') from e

        dev_logger.info(f'Successfully fetched OIDC discovery document from {discovery_url}')

        # 3. Register a PyJWKClient instance bound to the jwks_uri from discovery
        from dara.core.internal.registries import utils_registry

        py_jwk_client = PyJWKClient(self.discovery.jwks_uri, lifespan=oidc_settings.jwks_lifespan)
        utils_registry.register(JWK_CLIENT_REGISTRY_KEY, py_jwk_client)

    def generate_state(self, redirect_to: str | None = None) -> str:
        """
        Generate a signed JWT state parameter for CSRF protection.

        The state is a JWT signed with the application's secret containing:
        - nonce: cryptographically random value for uniqueness
        - redirect_to: optional URL to redirect to after authentication
        - exp: expiration timestamp

        :param redirect_to: Optional URL to redirect to after successful authentication
        :return: Signed JWT string to use as the state parameter
        """
        payload = StateObject(redirect_to=redirect_to)
        return jwt.encode(payload.model_dump(), get_settings().jwt_secret, algorithm=JWT_ALGO)

    def verify_state(self, state: str) -> StateObject:
        """
        Verify and decode the state JWT.

        :param state: The state JWT string from the callback
        :return: Decoded payload containing nonce and optional redirect_to
        :raises jwt.ExpiredSignatureError: If the state has expired
        :raises jwt.InvalidTokenError: If the state is invalid
        """
        return StateObject.model_validate(jwt.decode(state, get_settings().jwt_secret, algorithms=[JWT_ALGO]))

    def get_authorization_params(self, state: str) -> dict[str, str]:
        """
        Build the query parameters for the authorization request per OpenID Connect Core 1.0 Section 3.1.2.1.

        Required parameters:
        - scope: Must contain 'openid', may contain additional scopes (from SSO_SCOPES setting)
        - response_type: 'code' for Authorization Code Flow
        - client_id: OAuth 2.0 Client Identifier (from SSO_CLIENT_ID setting)
        - redirect_uri: Redirection URI for the response (from SSO_REDIRECT_URI setting)

        Recommended parameters:
        - state: Opaque value for CSRF protection (signed JWT containing nonce and optional redirect URL)

        Override this method to add optional parameters like nonce, display, prompt, max_age, etc.
        """
        oidc_settings = get_oidc_settings()
        return {
            'scope': oidc_settings.scopes,
            'response_type': 'code',
            'client_id': oidc_settings.client_id,
            'redirect_uri': oidc_settings.redirect_uri,
            'state': state,
        }

    def get_authorization_url(self, state: str) -> str:
        """
        Build the full authorization URL using the discovery document's authorization_endpoint.

        :param state: Signed state JWT to embed in the request
        :return: Authorization endpoint URL with URL-encoded query parameters
        """
        params = self.get_authorization_params(state)
        return f'{self.discovery.authorization_endpoint}?{urlencode(params)}'

    def get_token(self, body: SessionRequestBody) -> TokenResponse | RedirectResponse:
        """
        Get token from the IDP - redirect to the authorization endpoint.

        Generates a signed JWT state parameter containing a nonce for CSRF protection
        and optionally the redirect URL for post-authentication navigation.

        :param body: Request body, may contain redirect_to for post-auth navigation
        :return: RedirectResponse pointing at the IDP authorization URL
        """
        state = self.generate_state(redirect_to=body.redirect_to)
        return RedirectResponse(redirect_uri=self.get_authorization_url(state))

    def extract_user_data_from_id_token(self, claims: IdTokenClaims) -> UserData:
        """
        Extract user data from ID token claims.

        Override this method in subclasses to handle provider-specific claim structures.
        The default implementation uses standard OIDC claims, with support for the
        non-standard 'identity' claim.

        The display name falls back through: identity.name, name, preferred_username,
        nickname, "given_name family_name" (only the parts present), email, sub.

        :param claims: Decoded ID token claims
        :return: UserData extracted from the claims
        """
        # Check for non-standard 'identity' claim (Causalens IDP)
        # This is a nested object with id, name, email fields
        identity_claim = getattr(claims, 'identity', None)
        if isinstance(identity_claim, dict):
            identity_id = identity_claim.get('id') or claims.sub
            identity_name = identity_claim.get('name')
            identity_email = identity_claim.get('email') or claims.email
        else:
            # Standard OIDC: use 'sub' as the identity ID
            identity_id = claims.sub
            identity_email = claims.email
            identity_name = None

        # Fall back to standard claims for name if not set
        if not identity_name:
            # Join only the name parts that are present, so a missing part
            # doesn't render as the literal string 'None'
            full_name = ' '.join(part for part in (claims.given_name, claims.family_name) if part).strip()
            identity_name = (
                claims.name
                or claims.preferred_username
                or claims.nickname
                or full_name
                or identity_email
                or claims.sub
            )

        return UserData(
            identity_id=identity_id,
            identity_name=identity_name,
            identity_email=identity_email,
            groups=claims.groups,
        )

    def verify_token(self, token: str) -> TokenData:
        """
        Verify a session token.

        Handles both:
        1. Dara-issued session tokens (wrapped tokens signed with jwt_secret)
        2. Raw IDP tokens (ID tokens signed by the OIDC provider)

        Sets SESSION_ID, USER, and ID_TOKEN context variables.

        :param token: encoded JWT token (either Dara session token or raw IDP token)
        :return: TokenData for the verified token
        :raises AuthError: 401 if the token cannot be decoded at all
        """
        # First, decode without verification to check the issuer; the signature is
        # verified afterwards by the branch-specific verifier
        try:
            unverified = jwt.decode(token, options={'verify_signature': False})
        except jwt.DecodeError as e:
            raise AuthError(code=401, detail=INVALID_TOKEN_ERROR) from e

        # Check if this is a raw IDP token (issuer matches the configured SSO issuer)
        if unverified.get('iss') == get_oidc_settings().issuer_url:
            return self._verify_idp_token(token)
        return self._verify_dara_token(token)

    def _verify_idp_token(self, token: str) -> TokenData:
        """
        Verify a raw ID token from the IDP.

        :param token: Raw ID token JWT
        :return: TokenData extracted from the ID token
        """
        # Decode and verify the ID token signature using JWKS
        claims = decode_id_token(token)

        # Extract user data (can be overridden for provider-specific claim structures)
        user_data = self.extract_user_data_from_id_token(claims)

        # Verify user has access based on groups
        self.verify_user_access(user_data)

        # Set context variables
        SESSION_ID.set(user_data.identity_id)
        USER.set(user_data)
        ID_TOKEN.set(token)

        # Return TokenData structure
        return TokenData(
            session_id=user_data.identity_id,
            exp=claims.exp,
            identity_id=user_data.identity_id,
            identity_name=user_data.identity_name,
            identity_email=user_data.identity_email,
            id_token=token,
            groups=user_data.groups,
        )

    def _verify_dara_token(self, token: str) -> TokenData:
        """
        Verify a Dara-issued session token.

        :param token: Dara session token JWT
        :return: TokenData from the decoded token
        """
        # Decode and verify with our jwt_secret
        token_data = decode_token(token)

        user_data = UserData.from_token_data(token_data)

        # Verify user has access based on groups
        self.verify_user_access(user_data)

        # Set context variables
        SESSION_ID.set(token_data.session_id)
        USER.set(user_data)
        ID_TOKEN.set(token_data.id_token)

        return token_data

    def verify_user_access(self, user_data: UserData) -> None:
        """
        Verify that the user is allowed to access the app.

        Two checks are applied in order:
        1. If SSO_ALLOWED_IDENTITY_ID is configured, the user's identity_id must match it.
        2. At least one of the user's groups must be in the allowed groups (SSO_GROUPS).

        :param user_data: User data whose identity_id and groups are checked
        :raises HTTPException: 403 if either check fails
        """
        # Identity verification enabled
        if (allowed_identity_id := get_oidc_settings().allowed_identity_id) is not None:
            identity_id = user_data.identity_id
            if identity_id != allowed_identity_id:
                dev_logger.error(
                    'User identity does not have access to this app',
                    error=Exception(UNAUTHORIZED_ERROR),
                    extra={
                        'identity_id': identity_id,
                    },
                )
                raise HTTPException(status_code=403, detail=UNAUTHORIZED_ERROR)

        allowed_groups = set(self.allowed_groups.keys())
        user_group_set = set(user_data.groups or [])

        # Check if there's any intersection between allowed and user groups
        if not allowed_groups.intersection(user_group_set):
            dev_logger.error(
                'User group does not have access to this app',
                error=Exception('Unauthorized'),
                extra={'user_groups': user_data.groups or [], 'allowed_groups': list(allowed_groups)},
            )
            raise HTTPException(status_code=403, detail=UNAUTHORIZED_ERROR)

    def get_token_endpoint(self) -> str:
        """
        Get the token endpoint URL from discovery.

        :return: Token endpoint URL
        :raises RuntimeError: If the discovery metadata has not been initialised (startup_hook not run)
        """
        return self.discovery.token_endpoint

    async def refresh_token(self, old_token: TokenData, refresh_token: str) -> tuple[str, str]:
        """
        Refresh the session using an OIDC refresh token.

        Per RFC 6749 Section 6, sends a refresh token grant to the token endpoint
        to obtain new access/id tokens.

        Note: the new issued session token includes the same session_id as the old token
        to maintain session continuity.

        :param old_token: Old session token data (used to preserve session_id)
        :param refresh_token: OIDC refresh token
        :return: Tuple of (new_session_token, new_refresh_token)
        :raises HTTPException: 401 if the IDP response has no id_token; 403 if the user no longer has access
        """
        # Request new tokens from the IDP
        oidc_tokens = await get_token_from_idp(
            self,
            {
                'grant_type': 'refresh_token',
                'refresh_token': refresh_token,
            },
        )

        # Ensure we got an id_token back
        if not oidc_tokens.id_token:
            raise HTTPException(status_code=401, detail=INVALID_TOKEN_ERROR)

        # Decode and verify the new ID token
        claims = decode_id_token(oidc_tokens.id_token)

        # Extract user data from claims
        user_data = self.extract_user_data_from_id_token(claims)

        # Verify user still has access
        self.verify_user_access(user_data)

        # Create a new Dara session token, preserving the original session_id
        new_session_token = sign_jwt(
            identity_id=user_data.identity_id,
            identity_name=user_data.identity_name,
            identity_email=user_data.identity_email,
            groups=user_data.groups or [],
            id_token=oidc_tokens.id_token,
            exp=int(claims.exp),
            session_id=old_token.session_id,
        )

        # Return new session token and refresh token (or the old one if not rotated)
        new_refresh_token = oidc_tokens.refresh_token or refresh_token

        return new_session_token, new_refresh_token

    def get_end_session_endpoint(self) -> str | None:
        """
        Get the end session endpoint URL.

        Uses the end_session_endpoint from OIDC discovery if available.

        Override this method in subclasses to customize the logout endpoint.
        """
        return self.discovery.end_session_endpoint

    def get_logout_params(self, id_token: str | None) -> dict[str, str]:
        """
        Build the query parameters for the logout/end session request.

        Per OpenID Connect RP-Initiated Logout 1.0:
        - id_token_hint: RECOMMENDED. ID Token previously issued by the OP, used as a hint
          about the End-User's current authenticated session.
        - client_id: OPTIONAL. OAuth 2.0 Client Identifier. Required if id_token_hint is not provided.
        - post_logout_redirect_uri: OPTIONAL. URI to redirect to after logout.
        - state: OPTIONAL. Opaque value for maintaining state.

        Override this method to add custom parameters like post_logout_redirect_uri.

        :param id_token: The ID token to use as a hint, or None if not available
        :return: Dictionary of query parameters
        """
        oidc_settings = get_oidc_settings()
        params: dict[str, str] = {}

        if id_token:
            params['id_token_hint'] = id_token

        # Include client_id if we're verifying audience, or if no id_token_hint is provided
        if oidc_settings.verify_audience or not id_token:
            params['client_id'] = oidc_settings.client_id

        return params

    def get_logout_url(self, id_token: str | None = None) -> str | None:
        """
        Build the full logout URL for RP-Initiated Logout.

        :param id_token: The ID token to use as a hint, or None if not available
        :return: Full logout URL with query parameters, or None if the IDP exposes no end session endpoint
        """
        endpoint = self.get_end_session_endpoint()

        if not endpoint:
            return None

        params = self.get_logout_params(id_token)

        if params:
            return f'{endpoint}?{urlencode(params)}'
        return endpoint

    def revoke_token(self, token: str, response: Response) -> SuccessResponse | RedirectResponse:
        """
        Revoke the session and redirect to the OP's end session endpoint.

        Per OpenID Connect RP-Initiated Logout 1.0, this initiates logout at the OP
        by redirecting to the end_session_endpoint with the id_token_hint.

        :param token: Session token to revoke (Dara-issued or raw IDP token)
        :param response: FastAPI response object
        :return: RedirectResponse to the logout URL, or a success payload if the IDP
            does not support RP-Initiated Logout
        """
        oidc_settings = get_oidc_settings()

        # Clean up the refresh token cookie
        response.delete_cookie(REFRESH_TOKEN_COOKIE_NAME)

        # Extract the ID token to use as a hint
        id_token: str | None = None

        try:
            # Decode without verification to check the issuer
            unverified = jwt.decode(token, options={'verify_signature': False})

            # Raw IDP token -> use directly as the id_token_hint
            # Dara-issued token -> extract the embedded id_token
            id_token = token if unverified.get('iss') == oidc_settings.issuer_url else unverified.get('id_token')
        except jwt.DecodeError:
            # If we can't decode the token, proceed without id_token_hint
            dev_logger.warning('Failed to decode token for logout, proceeding without id_token_hint')

        logout_url = self.get_logout_url(id_token)

        # IDP doesn't support RP-Initiated Logout, so treat logout as success
        if not logout_url:
            return {'success': True}

        return RedirectResponse(redirect_uri=logout_url)