auth0-server-python 1.0.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1123 @@
1
+ """
2
+ Main client for auth0-server-python SDK.
3
+ Handles authentication flows, token management, and user sessions.
4
+ """
5
+
6
+ import time
7
+ from typing import Dict, Any, Optional, List, Union, TypeVar, Generic, Callable
8
+ from urllib.parse import urlparse, parse_qs
9
+ import json
10
+ import asyncio
11
+ import jwt
12
+
13
+ from authlib.integrations.httpx_client import AsyncOAuth2Client
14
+ from authlib.integrations.base_client.errors import OAuthError
15
+ import httpx
16
+
17
+ from pydantic import BaseModel, ValidationError
18
+
19
+ from error import (
20
+ MissingTransactionError,
21
+ ApiError,
22
+ MissingRequiredArgumentError,
23
+ BackchannelLogoutError,
24
+ AccessTokenError,
25
+ AccessTokenForConnectionError,
26
+ StartLinkUserError,
27
+ AccessTokenErrorCode,
28
+ AccessTokenForConnectionErrorCode
29
+
30
+ )
31
+ from auth_types import (
32
+ StateData,
33
+ TransactionData,
34
+ UserClaims,
35
+ TokenSet,
36
+ LogoutTokenClaims,
37
+ StartInteractiveLoginOptions,
38
+ LogoutOptions
39
+ )
40
+ from utils import PKCE, State, URL
41
+
42
+
43
# Type variable that ServerClient is generic over: the type of the options
# object associated with the client's stores.
TStoreOptions = TypeVar("TStoreOptions")
45
+
46
+ class ServerClient(Generic[TStoreOptions]):
47
+ """
48
+ Main client for Auth0 server SDK. Handles authentication flows, session management,
49
+ and token operations using Authlib for OIDC functionality.
50
+ """
51
+
52
def __init__(
    self,
    domain: str,
    client_id: str,
    client_secret: str,
    redirect_uri: Optional[str] = None,
    secret: Optional[str] = None,
    transaction_store=None,
    state_store=None,
    transaction_identifier: str = "_a0_tx",
    state_identifier: str = "_a0_session",
    authorization_params: Optional[Dict[str, Any]] = None,
    pushed_authorization_requests: bool = False
):
    """
    Initialize the Auth0 server client.

    Args:
        domain: Auth0 domain (e.g., 'your-tenant.auth0.com'). OIDC metadata is
            discovered from ``https://{domain}/.well-known/openid-configuration``.
        client_id: Auth0 client ID.
        client_secret: Auth0 client secret.
        redirect_uri: Default redirect URI used when a login request does not
            supply its own ``redirect_uri``.
        secret: Secret used for encryption. Required; must be non-empty.
        transaction_store: Store for in-flight login/link transactions. It is
            used as given; no default store is created when it is None.
        state_store: Store for the user session state. It is used as given; no
            default store is created when it is None.
        transaction_identifier: Key prefix for transaction entries (the login
            ``state`` value is appended as ``{prefix}:{state}``).
        state_identifier: Key under which the session state is stored.
        authorization_params: Default parameters for authorization requests;
            per-call parameters are merged over these.
        pushed_authorization_requests: When True, interactive login pushes the
            authorization parameters to the PAR endpoint and redirects with the
            returned ``request_uri``.

    Raises:
        MissingRequiredArgumentError: If ``secret`` is missing or empty.
    """
    if not secret:
        raise MissingRequiredArgumentError("secret")

    # Store configuration
    self._domain = domain
    self._client_id = client_id
    self._client_secret = client_secret
    self._redirect_uri = redirect_uri
    self._secret = secret  # keep the validated secret instead of discarding it
    self._default_authorization_params = authorization_params or {}
    self._pushed_authorization_requests = pushed_authorization_requests

    # Stores are injected by the caller; they are used exactly as provided.
    self._transaction_store = transaction_store
    self._state_store = state_store
    self._transaction_identifier = transaction_identifier
    self._state_identifier = state_identifier

    # OAuth client used for authorization-URL creation and code exchange.
    self._oauth = AsyncOAuth2Client(
        client_id=client_id,
        client_secret=client_secret,
    )
103
+
104
async def _fetch_oidc_metadata(self, domain: str) -> dict:
    """
    Download and parse the OpenID Connect discovery document for ``domain``.

    Args:
        domain: Auth0 domain whose ``/.well-known/openid-configuration`` is read.

    Returns:
        The discovery document decoded from JSON.

    Raises:
        httpx.HTTPStatusError: If the discovery endpoint answers with an error status.
    """
    discovery_url = f"https://{domain}/.well-known/openid-configuration"
    async with httpx.AsyncClient() as http:
        discovery_response = await http.get(discovery_url)
        discovery_response.raise_for_status()
        document = discovery_response.json()
    return document
110
+
111
+
112
async def start_interactive_login(
    self,
    options: Optional[StartInteractiveLoginOptions] = None,
    store_options: Optional[Dict[str, Any]] = None
) -> str:
    """
    Start the interactive (authorization code + PKCE) login.

    Returns the URL the user-agent should be redirected to.

    Args:
        options: Configuration for this login. Its ``authorization_params`` are
            merged over the client's default authorization params, and its
            ``app_state`` is saved with the transaction.
        store_options: Options forwarded to the transaction store.

    Returns:
        The authorization URL. With PAR enabled, this is the authorization
        endpoint carrying only ``request_uri``, ``response_type`` and
        ``client_id``.

    Raises:
        MissingRequiredArgumentError: If no redirect_uri is passed or configured.
        ApiError: If OIDC metadata cannot be fetched or lacks required
            endpoints, if the PAR request fails, or if the authorization URL
            cannot be built.
    """
    from urllib.parse import urlencode

    options = options or StartInteractiveLoginOptions()

    # Effective authorization params: per-call values override the defaults.
    auth_params = dict(self._default_authorization_params)
    if options.authorization_params:
        auth_params.update(options.authorization_params)

    if "redirect_uri" not in auth_params:
        if not self._redirect_uri:
            raise MissingRequiredArgumentError("redirect_uri")
        auth_params["redirect_uri"] = self._redirect_uri

    # PKCE: keep the verifier server-side and send only the S256 challenge.
    code_verifier = PKCE.generate_code_verifier()
    auth_params["code_challenge"] = PKCE.generate_code_challenge(code_verifier)
    auth_params["code_challenge_method"] = "S256"

    # CSRF protection; also the key that links the callback to this transaction.
    state = PKCE.generate_random_string(32)
    auth_params["state"] = state

    await self._transaction_store.set(
        f"{self._transaction_identifier}:{state}",
        TransactionData(code_verifier=code_verifier, app_state=options.app_state),
        options=store_options
    )

    try:
        self._oauth.metadata = await self._fetch_oidc_metadata(self._domain)
    except Exception as e:
        raise ApiError("metadata_error", "Failed to fetch OIDC metadata", e) from e

    authorization_endpoint = self._oauth.metadata.get("authorization_endpoint")
    if not authorization_endpoint:
        raise ApiError("configuration_error", "Authorization endpoint missing in OIDC metadata")

    if self._pushed_authorization_requests:
        par_endpoint = self._oauth.metadata.get("pushed_authorization_request_endpoint")
        if not par_endpoint:
            raise ApiError("configuration_error", "PAR is enabled but pushed_authorization_request_endpoint is missing in metadata")

        # The pushed request must carry response_type, and it is also echoed on
        # the redirect below. Nothing above sets it, so default to the
        # authorization-code flow that the PKCE parameters are built for.
        auth_params.setdefault("response_type", "code")
        auth_params["client_id"] = self._client_id

        async with httpx.AsyncClient() as client:
            par_response = await client.post(
                par_endpoint,
                data=auth_params,
                auth=(self._client_id, self._client_secret)
            )

        if par_response.status_code not in (200, 201):
            # Error bodies are not guaranteed to be JSON objects.
            try:
                error_data = par_response.json()
            except ValueError:
                error_data = {}
            if not isinstance(error_data, dict):
                error_data = {}
            raise ApiError(
                error_data.get("error", "par_error"),
                error_data.get("error_description", "Failed to obtain request_uri from PAR endpoint")
            )

        request_uri = par_response.json().get("request_uri")
        if not request_uri:
            raise ApiError("par_error", "No request_uri returned from PAR endpoint")

        # request_uri is an opaque URN; encode every value instead of splicing raw text.
        query = urlencode({
            "request_uri": request_uri,
            "response_type": auth_params["response_type"],
            "client_id": self._client_id,
        })
        return f"{authorization_endpoint}?{query}"

    try:
        auth_url, _ = self._oauth.create_authorization_url(authorization_endpoint, **auth_params)
    except Exception as e:
        raise ApiError("authorization_url_error", "Failed to create authorization URL", e) from e

    return auth_url
209
+
210
async def complete_interactive_login(
    self,
    url: str,
    store_options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Complete the login after the user-agent is redirected back to the app.

    The authorization code is exchanged for tokens, and the resulting session
    is written to the state store. The login transaction is then deleted.

    Args:
        url: The full callback URL, including its query parameters.
        store_options: Options forwarded to the transaction and state stores.

    Returns:
        A dict with:
        - ``state_data``: the persisted session as a dict;
        - ``app_state``: only when the transaction carried one;
        - ``authorization_details``: only when the token response included them.

    Raises:
        MissingRequiredArgumentError: If the callback lacks ``state`` or ``code``.
        MissingTransactionError: If no transaction is stored for ``state``.
        ApiError: If the callback carries an ``error``, if OIDC metadata cannot
            be fetched, or if the token exchange fails.
    """
    query_params = parse_qs(urlparse(url).query)

    state = query_params.get("state", [""])[0]
    if not state:
        raise MissingRequiredArgumentError("state")

    transaction_identifier = f"{self._transaction_identifier}:{state}"
    transaction_data = await self._transaction_store.get(transaction_identifier, options=store_options)
    if not transaction_data:
        raise MissingTransactionError()

    # Error response from the authorization server.
    if "error" in query_params:
        error = query_params.get("error", [""])[0]
        error_description = query_params.get("error_description", [""])[0]
        raise ApiError(error, error_description)

    code = query_params.get("code", [""])[0]
    if not code:
        raise MissingRequiredArgumentError("code")

    if not self._oauth.metadata or "token_endpoint" not in self._oauth.metadata:
        try:
            self._oauth.metadata = await self._fetch_oidc_metadata(self._domain)
        except Exception as e:
            raise ApiError("metadata_error", "Failed to fetch OIDC metadata", e) from e

    try:
        token_endpoint = self._oauth.metadata["token_endpoint"]
        # NOTE(review): the client-level redirect_uri is sent here. If the login
        # was started with a per-call redirect_uri in authorization_params,
        # the two may differ; confirm whether that needs to be persisted.
        token_response = await self._oauth.fetch_token(
            token_endpoint,
            code=code,
            code_verifier=transaction_data.code_verifier,
            redirect_uri=self._redirect_uri,
        )
    except OAuthError as e:
        raise ApiError("token_error", f"Token exchange failed: {str(e)}", e) from e

    # Claims come from "userinfo" when present; otherwise they are read from the
    # ID token. The ID token was received directly from the token endpoint
    # above (not through the browser), and it is decoded without signature
    # verification.
    user_info = token_response.get("userinfo")
    id_token = token_response.get("id_token")
    if user_info:
        claims = user_info
    elif id_token:
        claims = jwt.decode(id_token, options={"verify_signature": False})
    else:
        claims = None
    user_claims = UserClaims.parse_obj(claims) if claims is not None else None

    token_set = TokenSet(
        audience=token_response.get("audience", "default"),
        access_token=token_response.get("access_token", ""),
        scope=token_response.get("scope", ""),
        expires_at=int(time.time()) + token_response.get("expires_in", 3600)
    )

    # Prefer the provider's session id from whichever claims source was used.
    # This lets a back-channel logout that carries a sid match this session.
    # Fall back to a random id when none is present.
    sid = claims.get("sid") if claims else None
    if not sid:
        sid = PKCE.generate_random_string(32)

    state_data = StateData(
        user=user_claims,
        id_token=id_token,
        refresh_token=token_response.get("refresh_token"),  # None if not issued
        token_sets=[token_set],
        internal={
            "sid": sid,
            "created_at": int(time.time())
        }
    )

    await self._state_store.set(self._state_identifier, state_data, options=store_options)

    # The transaction is single-use; remove it once the session is stored.
    await self._transaction_store.delete(transaction_identifier, options=store_options)

    result = {"state_data": state_data.dict()}
    if transaction_data.app_state:
        result["app_state"] = transaction_data.app_state

    # Rich Authorization Requests (RAR)
    authorization_details = token_response.get("authorization_details")
    if authorization_details:
        result["authorization_details"] = authorization_details

    return result
318
+
319
async def start_link_user(
    self,
    options,
    store_options: Optional[Dict[str, Any]] = None
):
    """
    Start the user-linking flow and return the URL to redirect the user-agent to.

    Args:
        options: Dict-like options, read with ``.get``:
            - ``connection``: the connection to link;
            - ``connection_scope`` (or the legacy ``connectionScope``): optional;
            - ``authorization_params``: optional;
            - ``app_state``: optional, saved with the transaction.
        store_options: Options forwarded to the transaction and state stores.

    Returns:
        URL to redirect the user to for authentication.

    Raises:
        StartLinkUserError: If there is no stored session with an ``id_token``.
    """
    state_data = await self._state_store.get(self._state_identifier, store_options)

    # The store may hand back the StateData model (complete_interactive_login
    # persists one) rather than a plain dict; normalise before using .get().
    if state_data is not None and callable(getattr(state_data, "dict", None)):
        state_data = state_data.dict()

    if not state_data or not state_data.get("id_token"):
        raise StartLinkUserError(
            "Unable to start the user linking process without a logged in user. Ensure to login using the SDK before starting the user linking process."
        )

    # PKCE verifier and CSRF state for the linking round-trip.
    code_verifier = PKCE.generate_code_verifier()
    state = PKCE.generate_random_string(32)

    # Accept the snake_case key used by the other options, falling back to the
    # original camelCase key.
    connection_scope = options.get("connection_scope") or options.get("connectionScope")

    link_user_url = await self._build_link_user_url(
        connection=options.get("connection"),
        connection_scope=connection_scope,
        id_token=state_data["id_token"],
        code_verifier=code_verifier,
        state=state,
        authorization_params=options.get("authorization_params")
    )

    transaction_data = TransactionData(
        code_verifier=code_verifier,
        app_state=options.get("app_state")
    )

    await self._transaction_store.set(
        f"{self._transaction_identifier}:{state}",
        transaction_data,
        options=store_options
    )

    return link_user_url
368
+
369
async def complete_link_user(
    self,
    url: str,
    store_options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Finish the account-linking redirect and hand back the caller's app state.

    The callback is processed by ``complete_interactive_login``. Only the
    ``app_state`` entry of its result is returned.

    Args:
        url: Callback URL whose query parameters are processed.
        store_options: Options forwarded to the stores.

    Returns:
        ``{"app_state": ...}``; the value is None when no app state was stored.
    """
    login_result = await self.complete_interactive_login(url, store_options)
    app_state = login_result.get("app_state")
    return {"app_state": app_state}
392
+
393
async def start_unlink_user(
    self,
    options,
    store_options: Optional[Dict[str, Any]] = None
):
    """
    Start the user-unlinking flow and return the URL to redirect the user-agent to.

    Args:
        options: Dict-like options, read with ``.get``:
            - ``connection``: the connection to unlink;
            - ``authorization_params``: optional;
            - ``app_state``: optional, saved with the transaction.
        store_options: Options forwarded to the transaction and state stores.

    Returns:
        URL to redirect the user to for authentication.

    Raises:
        StartLinkUserError: If there is no stored session with an ``id_token``.
    """
    state_data = await self._state_store.get(self._state_identifier, store_options)

    # The store may hand back the StateData model (complete_interactive_login
    # persists one) rather than a plain dict; normalise before using .get().
    if state_data is not None and callable(getattr(state_data, "dict", None)):
        state_data = state_data.dict()

    if not state_data or not state_data.get("id_token"):
        raise StartLinkUserError(
            "Unable to start the user unlinking process without a logged in user. Ensure to login using the SDK before starting the user unlinking process."
        )

    # PKCE verifier and CSRF state for the unlinking round-trip.
    code_verifier = PKCE.generate_code_verifier()
    state = PKCE.generate_random_string(32)

    unlink_user_url = await self._build_unlink_user_url(
        connection=options.get("connection"),
        id_token=state_data["id_token"],
        code_verifier=code_verifier,
        state=state,
        authorization_params=options.get("authorization_params")
    )

    transaction_data = TransactionData(
        code_verifier=code_verifier,
        app_state=options.get("app_state")
    )

    await self._transaction_store.set(
        f"{self._transaction_identifier}:{state}",
        transaction_data,
        options=store_options
    )

    return unlink_user_url
441
+
442
async def complete_unlink_user(
    self,
    url: str,
    store_options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Finish the account-unlinking redirect and hand back the caller's app state.

    The callback is processed by ``complete_interactive_login``. Only the
    ``app_state`` entry of its result is returned.

    Args:
        url: Callback URL whose query parameters are processed.
        store_options: Options forwarded to the stores.

    Returns:
        ``{"app_state": ...}``; the value is None when no app state was stored.
    """
    login_result = await self.complete_interactive_login(url, store_options)
    app_state = login_result.get("app_state")
    return {"app_state": app_state}
465
+
466
+
467
+
468
async def login_backchannel(
    self,
    options: Dict[str, Any],
    store_options: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Log in with Client-Initiated Backchannel Authentication (CIBA).

    The returned tokens are merged into the stored session.

    Note:
        CIBA must be enabled for the application in the Auth0 dashboard.

    See:
        https://auth0.com/docs/get-started/authentication-and-authorization-flow/client-initiated-backchannel-authentication-flow

    Args:
        options: Reads ``binding_message``, ``login_hint`` and
            ``authorization_params``; all three are forwarded to
            ``backchannel_authentication``.
        store_options: Options forwarded to the state store.

    Returns:
        ``{"authorization_details": ...}``: the value from the token response,
        or None when the response has none (it is populated when Rich
        Authorization Requests were used).
    """
    ciba_request = {
        "binding_message": options.get("binding_message"),
        "login_hint": options.get("login_hint"),
        "authorization_params": options.get("authorization_params"),
    }
    tokens = await self.backchannel_authentication(ciba_request)

    current_state = await self._state_store.get(self._state_identifier, store_options)

    # Tokens are recorded under the client's default audience.
    merged_state = State.update_state_data(
        self._default_authorization_params.get("audience", "default"),
        current_state,
        tokens
    )
    await self._state_store.set(self._state_identifier, merged_state, store_options)

    return {"authorization_details": tokens.get("authorization_details")}
512
+
513
async def get_user(self, store_options: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """
    Return the ``user`` entry of the stored session.

    Args:
        store_options: Options forwarded to the state store.

    Returns:
        The stored user, or None when there is no (truthy) session.
    """
    stored = await self._state_store.get(self._state_identifier, store_options)
    if not stored:
        return None

    # Model objects expose .dict(); plain dicts are used as-is.
    to_dict = getattr(stored, "dict", None)
    if callable(to_dict):
        stored = to_dict()
    return stored.get("user")
530
+
531
async def get_session(self, store_options: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """
    Return the stored session without its ``internal`` bookkeeping entry.

    Args:
        store_options: Options forwarded to the state store.

    Returns:
        A new dict with every session key except ``internal``, or None when
        there is no (truthy) session.
    """
    stored = await self._state_store.get(self._state_identifier, store_options)
    if not stored:
        return None

    # Model objects expose .dict(); plain dicts are used as-is.
    if hasattr(stored, "dict") and callable(stored.dict):
        stored = stored.dict()

    visible_items = (pair for pair in stored.items() if pair[0] != "internal")
    return dict(visible_items)
549
+
550
async def get_access_token(self, store_options: Optional[Dict[str, Any]] = None) -> str:
    """
    Return an access token, using the stored one when valid or refreshing it.

    The session's ``token_sets`` are searched for an entry matching the
    client's default ``audience`` (and ``scope``, when one is configured). An
    unexpired match is returned as-is. Otherwise the stored refresh token is
    exchanged for a new token, and the state store is updated with it.

    Args:
        store_options: Options forwarded to the state store.

    Returns:
        The access token, retrieved from the store or from Auth0.

    Raises:
        AccessTokenError: With ``MISSING_REFRESH_TOKEN`` if no valid token is
            stored and no refresh token is available; with
            ``REFRESH_TOKEN_ERROR`` if refreshing or persisting fails.
    """
    state_data = await self._state_store.get(self._state_identifier, store_options)

    auth_params = self._default_authorization_params or {}
    audience = auth_params.get("audience", "default")
    scope = auth_params.get("scope")

    # Model objects expose .dict(); plain dicts are used as-is.
    if state_data and callable(getattr(state_data, "dict", None)):
        state_data_dict = state_data.dict()
    else:
        state_data_dict = state_data or {}

    # ``token_sets`` may be absent or None; treat both as "no cached tokens".
    token_set = next(
        (
            ts for ts in (state_data_dict.get("token_sets") or [])
            if ts.get("audience") == audience and (not scope or ts.get("scope") == scope)
        ),
        None,
    )

    if token_set and token_set.get("expires_at", 0) > time.time():
        return token_set["access_token"]

    refresh_token = state_data_dict.get("refresh_token")
    if not refresh_token:
        raise AccessTokenError(
            AccessTokenErrorCode.MISSING_REFRESH_TOKEN,
            "The access token has expired and a refresh token was not provided. The user needs to re-authenticate."
        )

    try:
        token_endpoint_response = await self.get_token_by_refresh_token({
            "refresh_token": refresh_token
        })

        # Re-read the state just before merging so the new token set is
        # applied to the latest stored state.
        existing_state_data = await self._state_store.get(self._state_identifier, store_options)
        updated_state_data = State.update_state_data(audience, existing_state_data, token_endpoint_response)
        await self._state_store.set(self._state_identifier, updated_state_data, options=store_options)

        return token_endpoint_response["access_token"]
    except AccessTokenError:
        raise
    except Exception as e:
        raise AccessTokenError(
            AccessTokenErrorCode.REFRESH_TOKEN_ERROR,
            f"Failed to get token with refresh token: {str(e)}"
        ) from e
618
+
619
async def get_access_token_for_connection(
    self,
    options: Dict[str, Any],
    store_options: Optional[Dict[str, Any]] = None
) -> str:
    """
    Return an access token for a connection, refreshing it through Auth0 if needed.

    An unexpired token cached in the session's ``connection_token_sets`` for
    the requested connection is returned directly. Otherwise the session's
    refresh token is exchanged for a new connection token, which is merged
    into the stored state.

    Args:
        options: Must contain ``connection``. ``login_hint`` is optional and
            is forwarded with the token request.
        store_options: Options forwarded to the state store.

    Returns:
        The access token for the connection.

    Raises:
        MissingRequiredArgumentError: If ``options`` has no ``connection``.
        AccessTokenForConnectionError: If no refresh token is stored. Errors
            raised by ``get_token_for_connection`` propagate unchanged.
    """
    connection = options.get("connection")
    if not connection:
        raise MissingRequiredArgumentError("connection")

    state_data = await self._state_store.get(self._state_identifier, store_options)

    # Model objects expose .dict(); plain dicts are used as-is.
    if state_data and callable(getattr(state_data, "dict", None)):
        state_data_dict = state_data.dict()
    else:
        state_data_dict = state_data or {}

    # The key may be absent (e.g. plain-dict stores) or None; treat both as
    # "no cached connection tokens" instead of failing on direct indexing.
    connection_token_set = next(
        (
            ts for ts in (state_data_dict.get("connection_token_sets") or [])
            if ts.get("connection") == connection
        ),
        None,
    )

    if connection_token_set and connection_token_set.get("expires_at", 0) > time.time():
        return connection_token_set["access_token"]

    refresh_token = state_data_dict.get("refresh_token")
    if not refresh_token:
        raise AccessTokenForConnectionError(
            AccessTokenForConnectionErrorCode.MISSING_REFRESH_TOKEN,
            "A refresh token was not found but is required to be able to retrieve an access token for a connection."
        )

    token_endpoint_response = await self.get_token_for_connection({
        "connection": connection,
        "login_hint": options.get("login_hint"),
        "refresh_token": refresh_token
    })

    updated_state_data = State.update_state_data_for_connection_token_set(options, state_data_dict, token_endpoint_response)
    await self._state_store.set(self._state_identifier, updated_state_data, store_options)

    return token_endpoint_response["access_token"]
682
+
683
+
684
async def logout(
    self,
    options: Optional[LogoutOptions] = None,
    store_options: Optional[Dict[str, Any]] = None
) -> str:
    """
    Remove the local session and return the Auth0 logout URL.

    Args:
        options: Logout options; ``return_to`` is passed to the logout-URL
            builder. A default ``LogoutOptions()`` is used when omitted.
        store_options: Options forwarded to the state store.

    Returns:
        The URL built by ``URL.create_logout_url`` from the domain, client id
        and ``return_to``.
    """
    effective_options = options if options else LogoutOptions()

    await self._state_store.delete(self._state_identifier, store_options)

    return URL.create_logout_url(self._domain, self._client_id, effective_options.return_to)
698
+
699
async def handle_backchannel_logout(
    self,
    logout_token: str,
    store_options: Optional[Dict[str, Any]] = None
) -> None:
    """
    Handle a back-channel logout request by deleting the matching sessions.

    Args:
        logout_token: The logout token (JWT) sent by Auth0.
        store_options: Options forwarded to the state store.

    Raises:
        BackchannelLogoutError: If the token is missing or cannot be decoded,
            if it is not a back-channel logout event, if it carries neither
            ``sub`` nor ``sid``, or if its claims fail validation.
    """
    if not logout_token:
        raise BackchannelLogoutError("Missing logout token")

    try:
        # NOTE(security): the signature is NOT verified here, so anyone able to
        # invoke this method with an arbitrary token can end arbitrary
        # sessions. Verify against the tenant's JWKS before trusting it.
        claims = jwt.decode(logout_token, options={"verify_signature": False})
    except jwt.PyJWTError as e:
        # PyJWT's base exception. The previous `jwt.JoseError` does not exist
        # in PyJWT, so evaluating it raised AttributeError and masked the
        # original error.
        raise BackchannelLogoutError(f"Error processing logout token: {str(e)}") from e

    events = claims.get("events") or {}
    if "http://schemas.openid.net/event/backchannel-logout" not in events:
        raise BackchannelLogoutError("Invalid logout token: not a backchannel logout event")

    sub = claims.get("sub")
    sid = claims.get("sid")
    # Without either identifier there is nothing to match sessions against.
    if not sub and not sid:
        raise BackchannelLogoutError("Invalid logout token: must contain a sub or sid claim")

    try:
        logout_claims = LogoutTokenClaims(sub=sub, sid=sid)
    except ValidationError as e:
        raise BackchannelLogoutError(f"Error processing logout token: {str(e)}") from e

    await self._state_store.delete_by_logout_token(logout_claims.dict(), store_options)
733
+
734
+
735
+
736
+ # Authlib Helpers
737
+
738
async def _build_link_user_url(
    self,
    connection: str,
    id_token: str,
    code_verifier: str,
    state: str,
    connection_scope: Optional[str] = None,
    authorization_params: Optional[Dict[str, Any]] = None
) -> str:
    """
    Build the authorization URL that starts linking ``connection`` to the current user.

    Args:
        connection: Connection to link (sent as ``requested_connection``).
        id_token: Current user's ID token (sent as ``id_token_hint``).
        code_verifier: PKCE verifier; only its S256 challenge is sent.
        state: CSRF state value for the round-trip.
        connection_scope: Optional scope for the linked connection, sent as
            ``requested_connection_scope`` only when provided.
        authorization_params: Extra parameters; they override the defaults below.

    Returns:
        The authorization URL built by ``URL.build_url``.
    """
    code_challenge = PKCE.generate_code_challenge(code_verifier)

    # Metadata is cached on the instance after the first fetch.
    if not hasattr(self, '_oauth_metadata'):
        self._oauth_metadata = await self._fetch_oidc_metadata(self._domain)

    auth_endpoint = self._oauth_metadata.get("authorization_endpoint",
                                             f"https://{self._domain}/authorize")

    params = {
        "client_id": self._client_id,
        "code_challenge": code_challenge,
        "code_challenge_method": "S256",
        "state": state,
        "requested_connection": connection,
    }
    # Only send requested_connection_scope when a value was supplied, rather
    # than placing a literal None into the query parameters.
    if connection_scope:
        params["requested_connection_scope"] = connection_scope
    params.update({
        "response_type": "code",
        "id_token_hint": id_token,
        "scope": "openid link_account",
        "audience": "my-account",
    })

    if authorization_params:
        params.update(authorization_params)
    return URL.build_url(auth_endpoint, params)
781
+
782
async def _build_unlink_user_url(
    self,
    connection: str,
    id_token: str,
    code_verifier: str,
    state: str,
    authorization_params: Optional[Dict[str, Any]] = None
) -> str:
    """
    Build the authorization URL that starts unlinking ``connection`` from the current user.

    Args:
        connection: Connection to unlink (sent as ``requested_connection``).
        id_token: Current user's ID token (sent as ``id_token_hint``).
        code_verifier: PKCE verifier; only its S256 challenge is sent.
        state: CSRF state value for the round-trip.
        authorization_params: Extra parameters; they override the defaults below.

    Returns:
        The authorization URL built by ``URL.build_url``.
    """
    challenge = PKCE.generate_code_challenge(code_verifier)

    # Metadata is cached on the instance after the first fetch.
    if not hasattr(self, '_oauth_metadata'):
        self._oauth_metadata = await self._fetch_oidc_metadata(self._domain)

    fallback_endpoint = f"https://{self._domain}/authorize"
    authorize_url = self._oauth_metadata.get("authorization_endpoint", fallback_endpoint)

    query = dict(
        client_id=self._client_id,
        code_challenge=challenge,
        code_challenge_method="S256",
        state=state,
        requested_connection=connection,
        response_type="code",
        id_token_hint=id_token,
        scope="openid unlink_account",
        audience="my-account",
    )
    query.update(authorization_params or {})

    return URL.build_url(authorize_url, query)
819
+
820
+ async def backchannel_authentication(
821
+ self,
822
+ options: Dict[str, Any]
823
+ ) -> Dict[str, Any]:
824
+ """
825
+ Initiates backchannel authentication with Auth0.
826
+
827
+ This method starts a Client-Initiated Backchannel Authentication (CIBA) flow,
828
+ which allows an application to request authentication from a user via a separate
829
+ device or channel.
830
+
831
+ Args:
832
+ options: Configuration options for backchannel authentication
833
+
834
+ Returns:
835
+ Token response data from the backchannel authentication
836
+
837
+ Raises:
838
+ ApiError: If the backchannel authentication fails
839
+ """
840
+ try:
841
+ # Fetch OpenID Connect metadata if not already fetched
842
+ if not hasattr(self, '_oauth_metadata'):
843
+ self._oauth_metadata = await self._fetch_oidc_metadata(self._domain)
844
+
845
+ # Get the issuer from metadata
846
+ issuer = self._oauth_metadata.get("issuer") or f"https://{self._domain}/"
847
+
848
+ # Get backchannel authentication endpoint
849
+ backchannel_endpoint = self._oauth_metadata.get("backchannel_authentication_endpoint")
850
+ if not backchannel_endpoint:
851
+ raise ApiError(
852
+ "configuration_error",
853
+ "Backchannel authentication is not supported by the authorization server"
854
+ )
855
+
856
+ # Get token endpoint
857
+ token_endpoint = self._oauth_metadata.get("token_endpoint")
858
+ if not token_endpoint:
859
+ raise ApiError(
860
+ "configuration_error",
861
+ "Token endpoint is missing in OIDC metadata"
862
+ )
863
+
864
+ sub = sub = options.get('login_hint', {}).get("sub")
865
+ if not sub:
866
+ raise ApiError(
867
+ "invalid_parameter",
868
+ "login_hint must contain a 'sub' field"
869
+ )
870
+
871
+ # Prepare login hint in the required format
872
+ login_hint = json.dumps({
873
+ "format": "iss_sub",
874
+ "iss": issuer,
875
+ "sub": sub
876
+ })
877
+
878
+ # The Request Parameters
879
+ params = {
880
+ "client_id": self._client_id,
881
+ "scope": "openid profile email", # DEFAULT_SCOPES
882
+ "login_hint": login_hint,
883
+ }
884
+
885
+
886
+ # Add binding message if provided
887
+ if options.get('binding_message'):
888
+ params["binding_message"] = options.get('binding_message')
889
+
890
+ # Add any additional authorization parameters
891
+ if self._default_authorization_params:
892
+ params.update(self._default_authorization_params)
893
+
894
+ if options.get('authorization_params'):
895
+ params.update(options.get('authorization_params'))
896
+
897
+ # Make the backchannel authentication request
898
+ async with httpx.AsyncClient() as client:
899
+ backchannel_response = await client.post(
900
+ backchannel_endpoint,
901
+ data=params,
902
+ auth=(self._client_id, self._client_secret)
903
+ )
904
+
905
+ if backchannel_response.status_code != 200:
906
+ error_data = backchannel_response.json()
907
+ raise ApiError(
908
+ error_data.get("error", "backchannel_error"),
909
+ error_data.get("error_description", "Backchannel authentication request failed")
910
+ )
911
+
912
+ backchannel_data = backchannel_response.json()
913
+ auth_req_id = backchannel_data.get("auth_req_id")
914
+ expires_in = backchannel_data.get("expires_in", 120) # Default to 2 minutes
915
+ interval = backchannel_data.get("interval", 5) # Default to 5 seconds
916
+
917
+ if not auth_req_id:
918
+ raise ApiError(
919
+ "invalid_response",
920
+ "Missing auth_req_id in backchannel authentication response"
921
+ )
922
+
923
+ # Poll for token using the auth_req_id
924
+ token_params = {
925
+ "grant_type": "urn:openid:params:grant-type:ciba",
926
+ "auth_req_id": auth_req_id,
927
+ "client_id": self._client_id,
928
+ "client_secret": self._client_secret
929
+ }
930
+
931
+ # Calculate when to stop polling
932
+ end_time = time.time() + expires_in
933
+
934
+ # Poll until we get a response or timeout
935
+ while time.time() < end_time:
936
+ # Make token request
937
+ token_response = await client.post(token_endpoint, data=token_params)
938
+
939
+ # Check for success (200 OK)
940
+ if token_response.status_code == 200:
941
+ # Success! Parse and return the token response
942
+ return token_response.json()
943
+
944
+ # Check for specific error that indicates we should continue polling
945
+ if token_response.status_code == 400:
946
+ error_data = token_response.json()
947
+ error = error_data.get("error")
948
+
949
+ # authorization_pending means we should keep polling
950
+ if error == "authorization_pending":
951
+ # Wait for the specified interval before polling again
952
+ await asyncio.sleep(interval)
953
+ continue
954
+
955
+ # Other errors should be raised
956
+ raise ApiError(
957
+ error,
958
+ error_data.get("error_description", "Token request failed")
959
+ )
960
+
961
+ # Any other status code is an error
962
+ raise ApiError(
963
+ "token_error",
964
+ f"Unexpected status code: {token_response.status_code}"
965
+ )
966
+
967
+ # If we get here, we've timed out
968
+ raise ApiError("timeout", "Backchannel authentication timed out")
969
+
970
+ except Exception as e:
971
+ print("Caught exception:", type(e), e.args, repr(e))
972
+ raise ApiError(
973
+ "backchannel_error",
974
+ f"Backchannel authentication failed: {str(e) or 'Unknown error'}",
975
+ e
976
+ )
977
+
978
async def get_token_by_refresh_token(self, options: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieves a token by exchanging a refresh token at the token endpoint.

    Args:
        options: Dictionary containing the refresh token and any additional options.
            Must include a non-empty 'refresh_token' key.

    Raises:
        MissingRequiredArgumentError: If 'refresh_token' is missing or empty.
        ApiError: If the OIDC metadata has no token endpoint, or the token
            endpoint answers with a non-200 status (the server's 'error' /
            'error_description' are used when the body is a JSON object).
        AccessTokenError: If any other error occurs while refreshing
            (e.g. a network failure or a malformed success response).

    Returns:
        A dictionary containing the token response from Auth0. If the response
        has 'expires_in' but no 'expires_at', an absolute 'expires_at'
        (Unix time in seconds) is added.
    """
    refresh_token = options.get("refresh_token")
    if not refresh_token:
        raise MissingRequiredArgumentError("refresh_token")

    try:
        # Lazily fetch the OIDC discovery document and cache it on the OAuth client.
        if not getattr(self._oauth, "metadata", None):
            self._oauth.metadata = await self._fetch_oidc_metadata(self._domain)

        token_endpoint = self._oauth.metadata.get("token_endpoint")
        if not token_endpoint:
            raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata")

        # Prepare the token request parameters
        token_params = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self._client_id,
        }

        # Re-request the configured default scope, if any. The defaults may be
        # unset (None), so normalise before the membership test.
        default_params = self._default_authorization_params or {}
        if "scope" in default_params:
            token_params["scope"] = default_params["scope"]

        # Exchange the refresh token for an access token
        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_endpoint,
                data=token_params,
                auth=(self._client_id, self._client_secret)
            )

            if response.status_code != 200:
                # Error bodies are normally a JSON object, but gateways/proxies
                # can return HTML or plain text; a decode failure must not mask
                # the HTTP error itself.
                try:
                    error_data = response.json()
                except ValueError:
                    error_data = {}
                if not isinstance(error_data, dict):
                    error_data = {}
                raise ApiError(
                    error_data.get("error", "refresh_token_error"),
                    error_data.get("error_description", "Failed to exchange refresh token")
                )

            token_response = response.json()

        # Derive an absolute expiry when the server only sent a relative one.
        if "expires_in" in token_response and "expires_at" not in token_response:
            token_response["expires_at"] = int(time.time()) + int(token_response["expires_in"])

        return token_response

    except ApiError:
        # Server/configuration errors are already meaningful; propagate as-is.
        raise
    except Exception as e:
        raise AccessTokenError(
            AccessTokenErrorCode.REFRESH_TOKEN_ERROR,
            "The access token has expired and there was an error while trying to refresh it.",
            e
        ) from e
1047
+
1048
async def get_token_for_connection(self, options: Dict[str, Any]) -> Dict[str, Any]:
    """
    Retrieves an access token for a federated connection by exchanging a
    refresh token at the token endpoint.

    Args:
        options: Options for retrieving an access token for a connection.
            Must include 'connection' and 'refresh_token' keys.
            May optionally include 'login_hint' (sent only when truthy).

    Raises:
        AccessTokenForConnectionError: With code API_ERROR when the token
            endpoint is missing from the OIDC metadata or the endpoint answers
            with a non-200 status; with code FETCH_ERROR for any other failure
            (including a missing 'connection' / 'refresh_token' key).

    Returns:
        Dictionary with 'access_token', 'expires_at' (Unix time in seconds,
        computed from 'expires_in', defaulting to 3600 s when absent) and
        'scope' (empty string when absent).
    """
    # Constants
    SUBJECT_TYPE_REFRESH_TOKEN = "urn:ietf:params:oauth:token-type:refresh_token"
    REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "http://auth0.com/oauth/token-type/federated-connection-access-token"
    GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN = "urn:auth0:params:oauth:grant-type:token-exchange:federated-connection-access-token"
    try:
        # Lazily fetch the OIDC discovery document and cache it on the OAuth client.
        if not getattr(self._oauth, "metadata", None):
            self._oauth.metadata = await self._fetch_oidc_metadata(self._domain)

        token_endpoint = self._oauth.metadata.get("token_endpoint")
        if not token_endpoint:
            raise ApiError("configuration_error", "Token endpoint missing in OIDC metadata")

        # Prepare parameters
        params = {
            "connection": options["connection"],
            "subject_token_type": SUBJECT_TYPE_REFRESH_TOKEN,
            "subject_token": options["refresh_token"],
            "requested_token_type": REQUESTED_TOKEN_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN,
            "grant_type": GRANT_TYPE_FEDERATED_CONNECTION_ACCESS_TOKEN,
            "client_id": self._client_id
        }

        # Add login_hint only when a non-empty value was provided
        login_hint = options.get("login_hint")
        if login_hint:
            params["login_hint"] = login_hint

        # Make the request
        async with httpx.AsyncClient() as client:
            response = await client.post(
                token_endpoint,
                data=params,
                auth=(self._client_id, self._client_secret)
            )

            if response.status_code != 200:
                # Parse leniently instead of comparing the content-type exactly:
                # servers typically send "application/json; charset=utf-8",
                # which an exact match would reject, dropping the error details.
                try:
                    error_data = response.json()
                except ValueError:
                    error_data = {}
                if not isinstance(error_data, dict):
                    error_data = {}
                raise ApiError(
                    error_data.get("error", "connection_token_error"),
                    error_data.get("error_description", f"Failed to get token for connection: {response.status_code}")
                )

            token_endpoint_response = response.json()

        return {
            "access_token": token_endpoint_response.get("access_token"),
            "expires_at": int(time.time()) + int(token_endpoint_response.get("expires_in", 3600)),
            "scope": token_endpoint_response.get("scope", "")
        }

    except ApiError as e:
        raise AccessTokenForConnectionError(
            AccessTokenForConnectionErrorCode.API_ERROR,
            str(e)
        ) from e
    except Exception as e:
        raise AccessTokenForConnectionError(
            AccessTokenForConnectionErrorCode.FETCH_ERROR,
            "There was an error while trying to retrieve an access token for a connection.",
            e
        ) from e