auth0-api-python 1.0.0b7__py3-none-any.whl → 1.0.0b9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,13 +5,34 @@ A lightweight Python SDK for verifying Auth0-issued access tokens
5
5
  in server-side APIs, using Authlib for OIDC discovery and JWKS fetching.
6
6
  """
7
7
 
8
+ from .act import get_current_actor, get_delegation_chain
8
9
  from .api_client import ApiClient
10
+ from .cache import CacheAdapter, InMemoryCache
9
11
  from .config import ApiClientOptions
10
- from .errors import ApiError, GetTokenByExchangeProfileError
12
+ from .errors import (
13
+ ApiError,
14
+ ConfigurationError,
15
+ DomainsResolverError,
16
+ GetTokenByExchangeProfileError,
17
+ )
18
+ from .types import (
19
+ DomainsResolver,
20
+ DomainsResolverContext,
21
+ OnBehalfOfTokenResult,
22
+ )
11
23
 
12
24
  __all__ = [
13
25
  "ApiClient",
14
26
  "ApiClientOptions",
15
27
  "ApiError",
16
- "GetTokenByExchangeProfileError"
28
+ "CacheAdapter",
29
+ "ConfigurationError",
30
+ "DomainsResolver",
31
+ "DomainsResolverContext",
32
+ "DomainsResolverError",
33
+ "GetTokenByExchangeProfileError",
34
+ "get_current_actor",
35
+ "get_delegation_chain",
36
+ "InMemoryCache",
37
+ "OnBehalfOfTokenResult",
17
38
  ]
@@ -0,0 +1,64 @@
1
+ """
2
+ Helpers for working with the `act` claim on verified access token claims.
3
+ """
4
+
5
+ from collections.abc import Mapping
6
+ from typing import Any, Optional
7
+
8
+ from .errors import VerifyAccessTokenError
9
+
10
INVALID_ACT_CLAIM_MESSAGE = "Invalid act claim"


def get_current_actor(claims: Mapping[str, Any]) -> Optional[str]:
    """
    Return the subject of the current actor, i.e. the outermost ``act.sub``.

    Only this outermost value should drive authorization decisions; any
    ``act`` objects nested inside it describe earlier actors and are
    informational only. Nested levels are not inspected here.

    Args:
        claims: Verified access token claims.

    Returns:
        The outermost ``act.sub`` string, or ``None`` when the claims carry
        no ``act`` entry (or it is explicitly ``None``).

    Raises:
        VerifyAccessTokenError: If ``claims`` is not a mapping, ``act`` is not
            a mapping, or ``act.sub`` is missing, not a string, or blank.
    """
    if isinstance(claims, Mapping):
        actor = claims.get("act")
    else:
        raise VerifyAccessTokenError(INVALID_ACT_CLAIM_MESSAGE)

    if actor is None:
        return None

    # A non-mapping ``act`` is folded into the "missing subject" case below.
    subject = actor.get("sub") if isinstance(actor, Mapping) else None
    if isinstance(subject, str) and subject.strip():
        return subject
    # Missing, non-string, or whitespace-only subjects are all malformed.
    raise VerifyAccessTokenError(INVALID_ACT_CLAIM_MESSAGE)
35
+
36
+
37
def get_delegation_chain(claims: Mapping[str, Any]) -> list[str]:
    """
    Return every actor subject in the ``act`` chain, newest first.

    Element 0 is the current actor (the outermost ``act.sub``). Each following
    element comes from the next nested ``act`` object and names an earlier
    actor; those entries are mainly useful for audit and attribution.

    Args:
        claims: Verified access token claims.

    Returns:
        The ``sub`` values from the outermost ``act`` inward; an empty list
        when the claims carry no ``act`` entry.

    Raises:
        VerifyAccessTokenError: If ``claims`` is not a mapping, or any ``act``
            level is not a mapping or lacks a non-blank string ``sub``.
    """
    if not isinstance(claims, Mapping):
        raise VerifyAccessTokenError(INVALID_ACT_CLAIM_MESSAGE)

    subjects: list[str] = []
    node = claims.get("act")
    # A missing/None ``act`` terminates the walk; at the top level that yields [].
    while node is not None:
        subject = node.get("sub") if isinstance(node, Mapping) else None
        if not (isinstance(subject, str) and subject.strip()):
            raise VerifyAccessTokenError(INVALID_ACT_CLAIM_MESSAGE)
        subjects.append(subject)
        node = node.get("act")
    return subjects
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import time
2
3
  from collections.abc import Mapping, Sequence
3
4
  from typing import Any, Optional, Union
@@ -5,10 +6,13 @@ from typing import Any, Optional, Union
5
6
  import httpx
6
7
  from authlib.jose import JsonWebKey, JsonWebToken
7
8
 
9
+ from .cache import InMemoryCache
8
10
  from .config import ApiClientOptions
9
11
  from .errors import (
10
12
  ApiError,
11
13
  BaseAuthError,
14
+ ConfigurationError,
15
+ DomainsResolverError,
12
16
  GetAccessTokenForConnectionError,
13
17
  GetTokenByExchangeProfileError,
14
18
  InvalidAuthSchemeError,
@@ -17,17 +21,21 @@ from .errors import (
17
21
  MissingRequiredArgumentError,
18
22
  VerifyAccessTokenError,
19
23
  )
24
+ from .types import OnBehalfOfTokenResult
20
25
  from .utils import (
21
26
  calculate_jwk_thumbprint,
22
27
  fetch_jwks,
23
28
  fetch_oidc_metadata,
24
29
  get_unverified_header,
30
+ get_unverified_payload,
31
+ normalize_domain,
25
32
  normalize_url_for_htu,
26
33
  sha256_base64url,
27
34
  )
28
35
 
29
36
  # Token Exchange constants
30
37
  TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" # noqa: S105
38
+ OBO_ACCESS_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" # noqa: S105
31
39
  MAX_ARRAY_VALUES_PER_KEY = 20 # DoS protection for extra parameter arrays
32
40
 
33
41
  # OAuth parameter denylist - parameters that cannot be overridden via extras
@@ -48,14 +56,62 @@ class ApiClient:
48
56
  """
49
57
 
50
58
  def __init__(self, options: ApiClientOptions):
51
- if not options.domain:
52
- raise MissingRequiredArgumentError("domain")
59
+ # Validate audience is always required
53
60
  if not options.audience:
54
61
  raise MissingRequiredArgumentError("audience")
55
62
 
63
+ # Validate domains parameter if provided
64
+ if options.domains is not None:
65
+ if isinstance(options.domains, list):
66
+ # Static list validation
67
+ if len(options.domains) == 0:
68
+ raise ConfigurationError("domains list cannot be empty")
69
+ if not all(isinstance(d, str) and d.strip() for d in options.domains):
70
+ raise ConfigurationError(
71
+ "domains list must contain only non-empty strings"
72
+ )
73
+ # Normalize and store domains
74
+ self._allowed_domains = [normalize_domain(d) for d in options.domains]
75
+ elif callable(options.domains):
76
+ # Dynamic resolver - store the function
77
+ self._allowed_domains = options.domains
78
+ else:
79
+ raise ConfigurationError(
80
+ "domains must be either a list of domain strings or a callable resolver function"
81
+ )
82
+ else:
83
+ # Single domain mode
84
+ self._allowed_domains = None
85
+
86
+ # Validate domain/domains configuration
87
+ if not options.domain and not options.domains:
88
+ raise ConfigurationError(
89
+ "Must provide either 'domain' or 'domains' parameter. "
90
+ "Use 'domain' for single-domain mode, 'domains' for multi-domain support."
91
+ )
92
+
93
+ # Validate that domain is set when client_id is configured
94
+ if options.client_id and not options.domain:
95
+ raise ConfigurationError(
96
+ "The 'domain' parameter is required when 'client_id' is configured."
97
+ )
56
98
  self.options = options
57
- self._metadata: Optional[dict[str, Any]] = None
58
- self._jwks_data: Optional[dict[str, Any]] = None
99
+
100
+ # Validate cache configuration
101
+ if not isinstance(options.cache_ttl_seconds, (int, float)) or options.cache_ttl_seconds < 0:
102
+ raise ConfigurationError("cache_ttl_seconds must be a non-negative number")
103
+
104
+ if not isinstance(options.cache_max_entries, int) or options.cache_max_entries < 2:
105
+ raise ConfigurationError("cache_max_entries must be an integer greater than 1")
106
+
107
+ if options.cache_adapter:
108
+ self._discovery_cache = options.cache_adapter
109
+ self._jwks_cache = options.cache_adapter
110
+ else:
111
+ self._discovery_cache = InMemoryCache(max_entries=options.cache_max_entries)
112
+ self._jwks_cache = InMemoryCache(max_entries=options.cache_max_entries)
113
+
114
+ self._cache_ttl = options.cache_ttl_seconds
59
115
 
60
116
  self._jwt = JsonWebToken(["RS256"])
61
117
 
@@ -66,6 +122,92 @@ class ApiClient:
66
122
  """Check if DPoP authentication is required."""
67
123
  return getattr(self.options, "dpop_required", False)
68
124
 
125
+ async def _resolve_allowed_domains(
126
+ self,
127
+ unverified_iss: str,
128
+ request_url: Optional[str] = None,
129
+ request_headers: Optional[dict] = None
130
+ ) -> Optional[list[str]]:
131
+ """
132
+ Resolve and validate allowed domains for the given issuer.
133
+
134
+ Handles three modes:
135
+ 1. Static list: Returns normalized list, validates issuer against it
136
+ 2. Dynamic resolver: Invokes resolver function, validates issuer against result
137
+ 3. Single domain: Returns None (backward compatibility, uses domain)
138
+
139
+ Args:
140
+ unverified_iss: The issuer claim from the token (not yet verified)
141
+ request_url: Optional request URL for dynamic resolvers
142
+ request_headers: Optional request headers for dynamic resolvers
143
+
144
+ Returns:
145
+ List of normalized allowed domain strings
146
+
147
+ Raises:
148
+ DomainsResolverError: If resolver invocation fails
149
+ VerifyAccessTokenError: If issuer is not in allowed domains
150
+ """
151
+ # Single domain mode
152
+ if self._allowed_domains is None:
153
+ return None
154
+
155
+ # Static list mode
156
+ if isinstance(self._allowed_domains, list):
157
+ allowed_domains = self._allowed_domains
158
+ # Dynamic resolver mode
159
+ elif callable(self._allowed_domains):
160
+ # Build resolver context
161
+ context = {
162
+ 'request_url': request_url,
163
+ 'request_headers': request_headers,
164
+ 'unverified_iss': unverified_iss
165
+ }
166
+
167
+ # Invoke resolver (supports both sync and async resolvers)
168
+ try:
169
+ result = self._allowed_domains(context)
170
+ if asyncio.iscoroutine(result) or asyncio.isfuture(result):
171
+ result = await result
172
+ except Exception as e:
173
+ raise DomainsResolverError(
174
+ f"Domains resolver function failed: {str(e)}"
175
+ ) from e
176
+
177
+ # Validate resolver result
178
+ if not isinstance(result, list):
179
+ raise DomainsResolverError(
180
+ "Domains resolver must return a list"
181
+ )
182
+
183
+ if len(result) == 0:
184
+ raise DomainsResolverError(
185
+ "Domains resolver returned an empty list"
186
+ )
187
+
188
+ if not all(isinstance(d, str) and d.strip() for d in result):
189
+ raise DomainsResolverError(
190
+ "Domains resolver must return a list of non-empty strings"
191
+ )
192
+
193
+ # Normalize domains from resolver
194
+ try:
195
+ allowed_domains = [normalize_domain(d) for d in result]
196
+ except ValueError as e:
197
+ raise DomainsResolverError(
198
+ f"Domains resolver returned invalid domain: {str(e)}"
199
+ ) from e
200
+ else:
201
+ # Should never happen due to __init__ validation
202
+ raise ConfigurationError("Invalid _allowed_domains type")
203
+
204
+ # Validate issuer is in allowed domains
205
+ if unverified_iss not in allowed_domains:
206
+ raise VerifyAccessTokenError(
207
+ "Token issuer is not in the list of allowed domains"
208
+ )
209
+
210
+ return allowed_domains
69
211
 
70
212
  async def verify_request(
71
213
  self,
@@ -89,10 +231,10 @@ class ApiClient:
89
231
  - "authorization": The Authorization header value (required)
90
232
  - "dpop": The DPoP proof header value (required for DPoP)
91
233
  http_method: The HTTP method (required for DPoP)
92
- http_url: The HTTP URL (required for DPoP)
234
+ http_url: The HTTP URL (required for DPoP, also used for MCD resolver context)
93
235
 
94
236
  Returns:
95
- The decoded access token claims
237
+ The decoded access token claims, including `act` when present.
96
238
 
97
239
  Raises:
98
240
  MissingRequiredArgumentError: If required args are missing
@@ -171,7 +313,11 @@ class ApiClient:
171
313
  )
172
314
 
173
315
  try:
174
- access_token_claims = await self.verify_access_token(token)
316
+ access_token_claims = await self.verify_access_token(
317
+ token,
318
+ request_url=http_url,
319
+ request_headers=headers
320
+ )
175
321
  except VerifyAccessTokenError as e:
176
322
  raise self._prepare_error(e, auth_scheme=scheme)
177
323
 
@@ -219,7 +365,11 @@ class ApiClient:
219
365
 
220
366
  if scheme == "bearer":
221
367
  try:
222
- claims = await self.verify_access_token(token)
368
+ claims = await self.verify_access_token(
369
+ token,
370
+ request_url=http_url,
371
+ request_headers=headers
372
+ )
223
373
  if claims.get("cnf") and isinstance(claims["cnf"], dict) and claims["cnf"].get("jkt"):
224
374
  if self.options.dpop_enabled:
225
375
  raise self._prepare_error(
@@ -245,6 +395,8 @@ class ApiClient:
245
395
  async def verify_access_token(
246
396
  self,
247
397
  access_token: str,
398
+ request_url: Optional[str] = None,
399
+ request_headers: Optional[dict] = None,
248
400
  required_claims: Optional[list[str]] = None
249
401
  ) -> dict[str, Any]:
250
402
  """
@@ -255,25 +407,113 @@ class ApiClient:
255
407
  - Checks standard claims: 'iss', 'aud', 'exp', 'iat'
256
408
  - Checks extra required claims if 'required_claims' is provided.
257
409
 
410
+ Args:
411
+ access_token: The JWT access token to verify
412
+ request_url: Optional request URL for dynamic domain resolvers
413
+ request_headers: Optional request headers dict for dynamic domain resolvers
414
+ required_claims: Optional list of additional claim names that must be present
415
+
258
416
  Returns:
259
- The decoded token claims if valid.
417
+ The decoded token claims if valid, including `act` when present.
260
418
 
261
419
  Raises:
262
420
  MissingRequiredArgumentError: If no token is provided.
263
421
  VerifyAccessTokenError: If verification fails (signature, claims mismatch, etc.).
422
+ DomainsResolverError: If domains resolver function fails.
264
423
  """
265
424
  if not access_token:
266
425
  raise MissingRequiredArgumentError("access_token")
267
426
 
268
427
  required_claims = required_claims or []
269
428
 
429
+ # Extract header and payload without signature verification
270
430
  try:
271
431
  header = get_unverified_header(access_token)
272
- kid = header["kid"]
273
432
  except Exception as e:
274
433
  raise VerifyAccessTokenError(f"Failed to parse token header: {str(e)}") from e
275
434
 
276
- jwks_data = await self._load_jwks()
435
+ # Reject symmetric algorithms
436
+ alg = header.get('alg', '')
437
+ if alg.startswith('HS'):
438
+ raise VerifyAccessTokenError(
439
+ f"Symmetric algorithm '{alg}' is not supported. "
440
+ "Only asymmetric algorithms (e.g., RS256) are allowed."
441
+ )
442
+
443
+ # Extract and validate issuer claim (before network calls)
444
+ try:
445
+ unverified_payload = get_unverified_payload(access_token)
446
+ except Exception as e:
447
+ raise VerifyAccessTokenError(f"Failed to parse token payload: {str(e)}") from e
448
+
449
+ unverified_iss = unverified_payload.get('iss')
450
+ if not unverified_iss:
451
+ raise VerifyAccessTokenError("Token missing 'iss' claim")
452
+
453
+ # Normalize issuer for validation
454
+ try:
455
+ normalized_iss = normalize_domain(unverified_iss)
456
+ except ValueError as e:
457
+ raise VerifyAccessTokenError(f"Invalid token issuer format: {str(e)}") from e
458
+
459
+ # Validate issuer against allowed domains (MCD)
460
+ if self._allowed_domains is not None:
461
+ await self._resolve_allowed_domains(
462
+ normalized_iss,
463
+ request_url=request_url,
464
+ request_headers=request_headers
465
+ )
466
+
467
+ # Fetch OIDC discovery metadata
468
+ try:
469
+ if self._allowed_domains is not None:
470
+ metadata = await self._discover(issuer=normalized_iss)
471
+ else:
472
+ metadata = await self._discover()
473
+ except VerifyAccessTokenError:
474
+ raise
475
+ except Exception as e:
476
+ raise VerifyAccessTokenError(
477
+ f"Failed to fetch OIDC discovery metadata: {str(e)}"
478
+ ) from e
479
+
480
+ # First issuer validation: Prevent issuer confusion attacks
481
+ discovery_issuer = metadata.get("issuer")
482
+ if not discovery_issuer:
483
+ raise VerifyAccessTokenError("Discovery metadata missing 'issuer' field")
484
+
485
+ # Normalize discovery issuer for comparison
486
+ try:
487
+ normalized_discovery_issuer = normalize_domain(discovery_issuer)
488
+ except ValueError as e:
489
+ raise VerifyAccessTokenError(f"Invalid discovery issuer format: {str(e)}") from e
490
+
491
+ if normalized_iss != normalized_discovery_issuer:
492
+ raise VerifyAccessTokenError(
493
+ "Token issuer does not match the discovery issuer"
494
+ )
495
+
496
+ # Extract JWKS URI from discovery metadata
497
+ jwks_uri = metadata.get("jwks_uri")
498
+ if not jwks_uri:
499
+ raise VerifyAccessTokenError("Discovery metadata missing 'jwks_uri' field")
500
+
501
+ # Fetch JWKS from discovery's jwks_uri
502
+ try:
503
+ jwks_data = await self._fetch_jwks(jwks_uri)
504
+ except VerifyAccessTokenError:
505
+ raise
506
+ except Exception as e:
507
+ raise VerifyAccessTokenError(
508
+ f"Failed to fetch JWKS: {str(e)}"
509
+ ) from e
510
+
511
+ # Extract kid for JWKS lookup
512
+ kid = header.get("kid")
513
+ if not kid:
514
+ raise VerifyAccessTokenError("Token header missing 'kid' claim")
515
+
516
+ # Find matching key
277
517
  matching_key_dict = None
278
518
  for key_dict in jwks_data["keys"]:
279
519
  if key_dict.get("kid") == kid:
@@ -281,8 +521,9 @@ class ApiClient:
281
521
  break
282
522
 
283
523
  if not matching_key_dict:
284
- raise VerifyAccessTokenError(f"No matching key found for kid: {kid}")
524
+ raise VerifyAccessTokenError("No matching key found in JWKS")
285
525
 
526
+ # Import public key and verify signature
286
527
  public_key = JsonWebKey.import_key(matching_key_dict)
287
528
 
288
529
  if isinstance(access_token, str) and access_token.startswith("b'"):
@@ -292,11 +533,11 @@ class ApiClient:
292
533
  except Exception as e:
293
534
  raise VerifyAccessTokenError(f"Signature verification failed: {str(e)}") from e
294
535
 
295
- metadata = await self._discover()
296
- issuer = metadata["issuer"]
297
-
298
- if claims.get("iss") != issuer:
299
- raise VerifyAccessTokenError("Issuer mismatch")
536
+ # Second issuer validation: Ensure verified token wasn't tampered
537
+ if claims.get("iss") != discovery_issuer:
538
+ raise VerifyAccessTokenError(
539
+ "Verified Token issuer does not match the discovery issuer"
540
+ )
300
541
 
301
542
  expected_aud = self.options.audience
302
543
  actual_aud = claims.get("aud")
@@ -555,7 +796,7 @@ class ApiClient:
555
796
  Dictionary containing:
556
797
  - access_token (str): The Auth0 access token
557
798
  - expires_in (int): Token lifetime in seconds
558
- - expires_at (int): Unix timestamp when token expires
799
+ - expires_at (int): Absolute expiration time as a Unix timestamp in seconds, calculated by the SDK from expires_in
559
800
  - id_token (str, optional): OpenID Connect ID token
560
801
  - refresh_token (str, optional): Refresh token
561
802
  - scope (str, optional): Granted scopes
@@ -723,6 +964,64 @@ class ApiClient:
723
964
  exc
724
965
  )
725
966
 
967
async def get_token_on_behalf_of(
    self,
    access_token: str,
    audience: str,
    scope: Optional[str] = None,
) -> OnBehalfOfTokenResult:
    """
    Exchange an Auth0 access token for one aimed at a downstream API (On-Behalf-Of).

    Thin wrapper over ``get_token_by_exchange_profile()`` that pins both the
    subject token type and the requested token type to the RFC 8693
    access-token type identifier, so the caller only supplies the token, the
    target audience, and optional scopes.

    Args:
        access_token: The Auth0 access token to exchange.
        audience: Target API identifier for the exchanged access token.
        scope: Optional space-separated OAuth 2.0 scopes to request.

    Returns:
        Dictionary containing:
        - access_token (str): The exchanged Auth0 access token
        - expires_in (int): Token lifetime in seconds
        - expires_at (int): Absolute expiration time as a Unix timestamp in seconds, calculated by the SDK from expires_in
        - scope (str, optional): Granted scopes
        - token_type (str, optional): Token type (typically "Bearer")
        - issued_token_type (str, optional): RFC 8693 issued token type identifier

    Raises:
        MissingRequiredArgumentError: If ``audience`` is empty.
        GetTokenByExchangeProfileError: Propagated from the underlying exchange
            (e.g. missing client credentials or failed validation).
        ApiError: Propagated when the token endpoint returns an error.
    """
    if not audience:
        raise MissingRequiredArgumentError("audience")

    exchanged = await self.get_token_by_exchange_profile(
        subject_token=access_token,
        subject_token_type=OBO_ACCESS_TOKEN_TYPE,
        audience=audience,
        scope=scope,
        requested_token_type=OBO_ACCESS_TOKEN_TYPE,
    )

    # Required fields are read by key; optional ones are copied only when the
    # exchange response actually contains them.
    obo_result: OnBehalfOfTokenResult = {
        "access_token": exchanged["access_token"],
        "expires_in": exchanged["expires_in"],
        "expires_at": exchanged["expires_at"],
    }
    if "scope" in exchanged:
        obo_result["scope"] = exchanged["scope"]
    if "token_type" in exchanged:
        obo_result["token_type"] = exchanged["token_type"]
    if "issued_token_type" in exchanged:
        obo_result["issued_token_type"] = exchanged["issued_token_type"]

    return obo_result
1024
+
726
1025
  # ===== Private Methods =====
727
1026
 
728
1027
  def _apply_extra(
@@ -767,25 +1066,73 @@ class ApiClient:
767
1066
  else:
768
1067
  params[key] = str(v)
769
1068
 
770
- async def _discover(self) -> dict[str, Any]:
771
- """Lazy-load OIDC discovery metadata."""
772
- if self._metadata is None:
773
- self._metadata = await fetch_oidc_metadata(
774
- domain=self.options.domain,
775
- custom_fetch=self.options.custom_fetch
776
- )
777
- return self._metadata
778
-
779
- async def _load_jwks(self) -> dict[str, Any]:
780
- """Fetches and caches JWKS data from the OIDC metadata."""
781
- if self._jwks_data is None:
782
- metadata = await self._discover()
783
- jwks_uri = metadata["jwks_uri"]
784
- self._jwks_data = await fetch_jwks(
785
- jwks_uri=jwks_uri,
786
- custom_fetch=self.options.custom_fetch
787
- )
788
- return self._jwks_data
1069
+ async def _discover(self, issuer: Optional[str] = None) -> dict[str, Any]:
1070
+ """
1071
+ Lazy-load OIDC discovery metadata.
1072
+
1073
+ Args:
1074
+ issuer: Optional issuer URL to fetch discovery from (MCD mode).
1075
+ If provided, extracts domain from issuer URL.
1076
+ If None, uses configured domain.
1077
+
1078
+ Returns:
1079
+ OIDC discovery metadata dictionary
1080
+ """
1081
+ if issuer:
1082
+ cache_key = issuer # Already normalized by caller
1083
+ domain = issuer.replace('https://', '').replace('http://', '').rstrip('/')
1084
+ else:
1085
+ domain = self.options.domain
1086
+ cache_key = normalize_domain(f"https://{domain}")
1087
+
1088
+ cached = self._discovery_cache.get(cache_key)
1089
+ if cached:
1090
+ return cached
1091
+
1092
+ metadata, max_age = await fetch_oidc_metadata(
1093
+ domain=domain,
1094
+ custom_fetch=self.options.custom_fetch
1095
+ )
1096
+
1097
+ effective_ttl = self._cache_ttl
1098
+ if max_age is not None and self._cache_ttl is not None:
1099
+ effective_ttl = min(max_age, self._cache_ttl)
1100
+ elif max_age is not None:
1101
+ effective_ttl = max_age
1102
+
1103
+ self._discovery_cache.set(cache_key, metadata, ttl_seconds=effective_ttl)
1104
+ return metadata
1105
+
1106
+ async def _fetch_jwks(self, jwks_uri: str) -> dict[str, Any]:
1107
+ """
1108
+ Fetch JWKS with per-URI caching.
1109
+
1110
+ Args:
1111
+ jwks_uri: The JWKS URI to fetch from
1112
+
1113
+ Returns:
1114
+ JWKS data dictionary
1115
+
1116
+ """
1117
+ cache_key = jwks_uri
1118
+
1119
+ cached = self._jwks_cache.get(cache_key)
1120
+ if cached:
1121
+ return cached
1122
+
1123
+ jwks_data, max_age = await fetch_jwks(
1124
+ jwks_uri=jwks_uri,
1125
+ custom_fetch=self.options.custom_fetch
1126
+ )
1127
+
1128
+ effective_ttl = self._cache_ttl
1129
+ if max_age is not None and self._cache_ttl is not None:
1130
+ effective_ttl = min(max_age, self._cache_ttl)
1131
+ elif max_age is not None:
1132
+ effective_ttl = max_age
1133
+
1134
+ self._jwks_cache.set(cache_key, jwks_data, ttl_seconds=effective_ttl)
1135
+ return jwks_data
789
1136
 
790
1137
  def _validate_claims_presence(
791
1138
  self,