d365fo-client 0.2.3__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. d365fo_client/__init__.py +7 -1
  2. d365fo_client/auth.py +9 -21
  3. d365fo_client/cli.py +25 -13
  4. d365fo_client/client.py +8 -4
  5. d365fo_client/config.py +52 -30
  6. d365fo_client/credential_sources.py +5 -0
  7. d365fo_client/main.py +1 -1
  8. d365fo_client/mcp/__init__.py +3 -1
  9. d365fo_client/mcp/auth_server/__init__.py +5 -0
  10. d365fo_client/mcp/auth_server/auth/__init__.py +30 -0
  11. d365fo_client/mcp/auth_server/auth/auth.py +372 -0
  12. d365fo_client/mcp/auth_server/auth/oauth_proxy.py +989 -0
  13. d365fo_client/mcp/auth_server/auth/providers/__init__.py +0 -0
  14. d365fo_client/mcp/auth_server/auth/providers/azure.py +325 -0
  15. d365fo_client/mcp/auth_server/auth/providers/bearer.py +25 -0
  16. d365fo_client/mcp/auth_server/auth/providers/jwt.py +547 -0
  17. d365fo_client/mcp/auth_server/auth/redirect_validation.py +65 -0
  18. d365fo_client/mcp/auth_server/dependencies.py +136 -0
  19. d365fo_client/mcp/client_manager.py +16 -67
  20. d365fo_client/mcp/fastmcp_main.py +358 -0
  21. d365fo_client/mcp/fastmcp_server.py +598 -0
  22. d365fo_client/mcp/fastmcp_utils.py +431 -0
  23. d365fo_client/mcp/main.py +40 -13
  24. d365fo_client/mcp/mixins/__init__.py +24 -0
  25. d365fo_client/mcp/mixins/base_tools_mixin.py +55 -0
  26. d365fo_client/mcp/mixins/connection_tools_mixin.py +50 -0
  27. d365fo_client/mcp/mixins/crud_tools_mixin.py +311 -0
  28. d365fo_client/mcp/mixins/database_tools_mixin.py +685 -0
  29. d365fo_client/mcp/mixins/label_tools_mixin.py +87 -0
  30. d365fo_client/mcp/mixins/metadata_tools_mixin.py +565 -0
  31. d365fo_client/mcp/mixins/performance_tools_mixin.py +109 -0
  32. d365fo_client/mcp/mixins/profile_tools_mixin.py +713 -0
  33. d365fo_client/mcp/mixins/sync_tools_mixin.py +321 -0
  34. d365fo_client/mcp/prompts/action_execution.py +1 -1
  35. d365fo_client/mcp/prompts/sequence_analysis.py +1 -1
  36. d365fo_client/mcp/tools/crud_tools.py +3 -3
  37. d365fo_client/mcp/tools/sync_tools.py +1 -1
  38. d365fo_client/mcp/utilities/__init__.py +1 -0
  39. d365fo_client/mcp/utilities/auth.py +34 -0
  40. d365fo_client/mcp/utilities/logging.py +58 -0
  41. d365fo_client/mcp/utilities/types.py +426 -0
  42. d365fo_client/metadata_v2/sync_manager_v2.py +2 -0
  43. d365fo_client/metadata_v2/sync_session_manager.py +7 -7
  44. d365fo_client/models.py +139 -139
  45. d365fo_client/output.py +2 -2
  46. d365fo_client/profile_manager.py +62 -27
  47. d365fo_client/profiles.py +118 -113
  48. d365fo_client/settings.py +355 -0
  49. d365fo_client/sync_models.py +85 -2
  50. d365fo_client/utils.py +2 -1
  51. {d365fo_client-0.2.3.dist-info → d365fo_client-0.3.0.dist-info}/METADATA +1261 -810
  52. d365fo_client-0.3.0.dist-info/RECORD +84 -0
  53. d365fo_client-0.3.0.dist-info/entry_points.txt +4 -0
  54. d365fo_client-0.2.3.dist-info/RECORD +0 -56
  55. d365fo_client-0.2.3.dist-info/entry_points.txt +0 -3
  56. {d365fo_client-0.2.3.dist-info → d365fo_client-0.3.0.dist-info}/WHEEL +0 -0
  57. {d365fo_client-0.2.3.dist-info → d365fo_client-0.3.0.dist-info}/licenses/LICENSE +0 -0
  58. {d365fo_client-0.2.3.dist-info → d365fo_client-0.3.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,547 @@
1
+ """TokenVerifier implementations for FastMCP."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from dataclasses import dataclass
7
+ from typing import Any, cast
8
+
9
+ import httpx
10
+ from authlib.jose import JsonWebKey, JsonWebToken
11
+ from authlib.jose.errors import JoseError
12
+ from cryptography.hazmat.primitives import serialization
13
+ from cryptography.hazmat.primitives.asymmetric import rsa
14
+ from pydantic import AnyHttpUrl, SecretStr, field_validator
15
+ from pydantic_settings import BaseSettings, SettingsConfigDict
16
+ from typing_extensions import TypedDict
17
+
18
+ from ..auth import AccessToken, TokenVerifier
19
+ from d365fo_client.mcp.utilities.auth import parse_scopes
20
+ from d365fo_client.mcp.utilities.logging import get_logger
21
+ from d365fo_client.mcp.utilities.types import NotSet, NotSetT
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
class JWKData(TypedDict, total=False):
    """JSON Web Key data structure (one entry of a JWKS ``keys`` array).

    Declared with ``total=False``: type checkers treat every field as optional,
    including ``kty``, even though a real JWK always carries a key type.
    """

    kty: str  # Key type (e.g., "RSA"); always present in real JWKs, but optional here (total=False)
    kid: str  # Key ID (optional but recommended)
    use: str  # Usage (e.g., "sig")
    alg: str  # Algorithm (e.g., "RS256")
    n: str  # Modulus (for RSA keys)
    e: str  # Exponent (for RSA keys)
    x5c: list[str]  # X.509 certificate chain (for JWKs)
    x5t: str  # X.509 certificate thumbprint (for JWKs)
37
+
38
+
39
class JWKSData(TypedDict):
    """JSON Web Key Set data structure.

    A JWKS document is an object whose single (required) ``keys`` member is an
    array of :class:`JWKData` entries.
    """

    keys: list[JWKData]  # Every key published by the set
43
+
44
+
45
@dataclass(frozen=True, kw_only=True, repr=False)
class RSAKeyPair:
    """RSA key pair for JWT testing.

    ``repr=False`` means no dataclass-generated ``__repr__`` is added, so
    ``repr()`` does not print the key material. The private key is additionally
    wrapped in ``SecretStr``.
    """

    private_key: SecretStr  # PKCS#8 PEM, unencrypted (as produced by generate())
    public_key: str  # SubjectPublicKeyInfo PEM (as produced by generate())

    @classmethod
    def generate(cls) -> RSAKeyPair:
        """
        Generate an RSA key pair for testing.

        Returns:
            RSAKeyPair: Generated 2048-bit key pair (public exponent 65537)
        """
        # Generate private key
        private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
        )

        # Serialize private key to PEM format
        private_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")

        # Serialize public key to PEM format
        public_pem = (
            private_key.public_key()
            .public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
            .decode("utf-8")
        )

        return cls(
            private_key=SecretStr(private_pem),
            public_key=public_pem,
        )

    def create_token(
        self,
        subject: str = "fastmcp-user",
        issuer: str = "https://fastmcp.example.com",
        audience: str | list[str] | None = None,
        scopes: list[str] | None = None,
        expires_in_seconds: int = 3600,
        additional_claims: dict[str, Any] | None = None,
        kid: str | None = None,
    ) -> str:
        """
        Generate a test JWT token signed with RS256 for testing purposes.

        Args:
            subject: Subject claim (usually user ID)
            issuer: Issuer claim
            audience: Audience claim - can be a string or list of strings (optional;
                omitted when falsy)
            scopes: List of scopes, emitted as a space-separated ``scope`` claim
            expires_in_seconds: Token lifetime; ``exp`` is exactly ``iat`` plus this
            additional_claims: Extra claims merged last (they may override the above)
            kid: Key ID to include in header

        Returns:
            The compact-serialized JWT as a string.
        """
        # Create header
        header: dict[str, str] = {"alg": "RS256"}
        if kid:
            header["kid"] = kid

        # Read the clock once so that exp - iat == expires_in_seconds exactly;
        # two separate time.time() calls can straddle a second boundary.
        now = int(time.time())

        # Create payload
        payload: dict[str, Any] = {
            "sub": subject,
            "iss": issuer,
            "iat": now,
            "exp": now + expires_in_seconds,
        }

        if audience:
            payload["aud"] = audience

        if scopes:
            payload["scope"] = " ".join(scopes)

        if additional_claims:
            payload.update(additional_claims)

        # Create JWT
        jwt_lib = JsonWebToken(["RS256"])
        token_bytes = jwt_lib.encode(
            header, payload, self.private_key.get_secret_value()
        )

        return token_bytes.decode("utf-8")
139
+
140
+
141
class JWTVerifierSettings(BaseSettings):
    """Configuration values consumed by :class:`JWTVerifier`.

    Each field can also be provided via environment variables carrying the
    ``FASTMCP_SERVER_AUTH_JWT_`` prefix, or via a ``.env`` file; unknown
    variables are ignored rather than rejected.
    """

    model_config = SettingsConfigDict(
        env_prefix="FASTMCP_SERVER_AUTH_JWT_",
        env_file=".env",
        extra="ignore",
    )

    # PEM public key, or the shared secret when an HS* algorithm is used.
    public_key: str | None = None
    # URL of the JSON Web Key Set (alternative to public_key).
    jwks_uri: str | None = None
    # Expected "iss" claim (optional).
    issuer: str | None = None
    # Signing algorithm name, e.g. "RS256" (optional).
    algorithm: str | None = None
    # Expected "aud" claim: one value or several (optional).
    audience: str | list[str] | None = None
    # Scopes every accepted token must carry (optional).
    required_scopes: list[str] | None = None
    # Base URL forwarded to the TokenVerifier base class (optional).
    base_url: AnyHttpUrl | str | None = None

    @field_validator("required_scopes", mode="before")
    @classmethod
    def _parse_scopes(cls, raw):
        # mode="before": hand the raw input to parse_scopes before pydantic
        # applies its own list[str] validation.
        return parse_scopes(raw)
162
+
163
+
164
class JWTVerifier(TokenVerifier):
    """
    JWT token verifier supporting both asymmetric (RSA/ECDSA) and symmetric (HMAC) algorithms.

    This verifier validates JWT tokens using various signing algorithms:
    - **Asymmetric algorithms** (RS256/384/512, ES256/384/512, PS256/384/512):
      Uses public/private key pairs. Ideal for external clients and services where
      only the authorization server has the private key.
    - **Symmetric algorithms** (HS256/384/512): Uses a shared secret for both
      signing and verification. Perfect for internal microservices and trusted
      environments where the secret can be securely shared.

    Besides the signature, ``load_access_token`` checks the ``exp`` and ``nbf``
    time claims (when present), the issuer and audience (when configured) and
    the required scopes.

    Use this when:
    - You have JWT tokens issued by an external service (asymmetric)
    - You need JWKS support for automatic key rotation (asymmetric)
    - You have internal microservices sharing a secret key (symmetric)
    - Your tokens contain standard OAuth scopes and claims
    """

    # Algorithms accepted for the ``algorithm`` setting.
    _SUPPORTED_ALGORITHMS = frozenset(
        {
            "HS256",
            "HS384",
            "HS512",
            "RS256",
            "RS384",
            "RS512",
            "ES256",
            "ES384",
            "ES512",
            "PS256",
            "PS384",
            "PS512",
        }
    )

    def __init__(
        self,
        *,
        public_key: str | None | NotSetT = NotSet,
        jwks_uri: str | None | NotSetT = NotSet,
        issuer: str | None | NotSetT = NotSet,
        audience: str | list[str] | None | NotSetT = NotSet,
        algorithm: str | None | NotSetT = NotSet,
        required_scopes: list[str] | None | NotSetT = NotSet,
        base_url: AnyHttpUrl | str | None | NotSetT = NotSet,
    ):
        """
        Initialize the JWT token verifier.

        Arguments left as ``NotSet`` fall back to ``JWTVerifierSettings``
        (environment / ``.env``); explicitly passed values, including ``None``,
        take precedence.

        Args:
            public_key: For asymmetric algorithms (RS256, ES256, etc.): PEM-encoded public key.
                       For symmetric algorithms (HS256, HS384, HS512): The shared secret string.
            jwks_uri: URI to fetch JSON Web Key Set (only for asymmetric algorithms)
            issuer: Expected issuer claim
            audience: Expected audience claim(s)
            algorithm: JWT signing algorithm. Supported algorithms:
                      - Asymmetric: RS256/384/512, ES256/384/512, PS256/384/512 (default: RS256)
                      - Symmetric: HS256, HS384, HS512
            required_scopes: Required scopes for all tokens
            base_url: Base URL for TokenVerifier protocol

        Raises:
            ValueError: If neither or both of public_key/jwks_uri are configured,
                or the algorithm is not supported.
        """
        settings = JWTVerifierSettings.model_validate(
            {
                k: v
                for k, v in {
                    "public_key": public_key,
                    "jwks_uri": jwks_uri,
                    "issuer": issuer,
                    "audience": audience,
                    "algorithm": algorithm,
                    "required_scopes": required_scopes,
                    "base_url": base_url,
                }.items()
                if v is not NotSet
            }
        )

        if not settings.public_key and not settings.jwks_uri:
            raise ValueError("Either public_key or jwks_uri must be provided")

        if settings.public_key and settings.jwks_uri:
            raise ValueError("Provide either public_key or jwks_uri, not both")

        algorithm = settings.algorithm or "RS256"
        if algorithm not in self._SUPPORTED_ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}.")

        # Initialize parent TokenVerifier
        super().__init__(
            base_url=settings.base_url,
            required_scopes=settings.required_scopes,
        )

        self.algorithm = algorithm
        self.issuer = settings.issuer
        self.audience = settings.audience
        self.public_key = settings.public_key
        self.jwks_uri = settings.jwks_uri
        # The decoder accepts only the configured algorithm, so a token whose
        # header names a different "alg" is rejected.
        self.jwt = JsonWebToken([self.algorithm])
        self.logger = get_logger(__name__)

        # Simple JWKS cache: kid (or "_default") -> imported public key object.
        self._jwks_cache: dict[str, Any] = {}
        self._jwks_cache_time: float = 0
        self._cache_ttl = 3600  # 1 hour

    async def _get_verification_key(self, token: str) -> Any:
        """Return the key used to verify *token*'s signature.

        A configured static ``public_key`` (PEM key or HMAC secret) always wins.
        Otherwise the ``kid`` from the token's (unverified) JOSE header selects
        a key from the JWKS endpoint.

        Raises:
            ValueError: If the header cannot be decoded, or no suitable JWKS key
                can be obtained.
        """
        if self.public_key:
            return self.public_key

        import base64
        import json

        # Keep the try minimal: only header decoding is "key ID extraction".
        try:
            header_b64 = token.split(".")[0]
            # JWT segments are unpadded base64url; restore 0-3 '=' characters.
            header_b64 += "=" * (-len(header_b64) % 4)
            header = json.loads(base64.urlsafe_b64decode(header_b64))
        except Exception as e:
            raise ValueError(f"Failed to extract key ID from token: {e}") from e

        if not isinstance(header, dict):
            raise ValueError(
                "Failed to extract key ID from token: header is not a JSON object"
            )

        # JWKS errors carry their own messages and are not relabelled here.
        return await self._get_jwks_key(header.get("kid"))

    async def _get_jwks_key(self, kid: str | None) -> Any:
        """Return the JWKS public key for *kid*, refreshing the cache if needed.

        The cached key set is reused for ``self._cache_ttl`` seconds. A cache
        miss (unknown ``kid``, or a kid-less token without a single cached key)
        triggers a refetch so rotated keys are picked up.

        Raises:
            ValueError: If no JWKS URI is configured, the key set cannot be
                fetched, or no unambiguous matching key exists.
        """
        if not self.jwks_uri:
            raise ValueError("JWKS URI not configured")

        current_time = time.time()

        # Check cache first
        if current_time - self._jwks_cache_time < self._cache_ttl:
            if kid and kid in self._jwks_cache:
                return self._jwks_cache[kid]
            if not kid and len(self._jwks_cache) == 1:
                # If no kid but only one key cached, use it
                return next(iter(self._jwks_cache.values()))

        # The cache is replaced only after a complete key set has been built.
        self._jwks_cache = await self._fetch_jwks_keys()
        self._jwks_cache_time = current_time

        return self._select_jwks_key(kid)

    async def _fetch_jwks_keys(self) -> dict[str, Any]:
        """Download the JWKS document and import its keys, indexed by ``kid``.

        Keys without a ``kid`` are stored under ``"_default"`` (a later one
        overwrites an earlier one). An entry that cannot be imported is skipped
        and logged instead of invalidating the whole set.

        Raises:
            ValueError: If the document cannot be fetched or is malformed.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(self.jwks_uri)
                response.raise_for_status()
                jwks_data = response.json()
        except httpx.HTTPError as e:
            raise ValueError(f"Failed to fetch JWKS: {e}") from e
        except Exception as e:
            self.logger.debug("JWKS fetch failed: %s", e)
            raise ValueError(f"Failed to fetch JWKS: {e}") from e

        raw_keys = jwks_data.get("keys", []) if isinstance(jwks_data, dict) else None
        if not isinstance(raw_keys, list):
            raise ValueError("Failed to fetch JWKS: response has no 'keys' array")

        keys: dict[str, Any] = {}
        for key_data in raw_keys:
            if not isinstance(key_data, dict):
                # Log only the type: entries may contain key material.
                self.logger.debug(
                    "Skipping malformed JWKS entry of type %s",
                    type(key_data).__name__,
                )
                continue
            key_kid = key_data.get("kid")
            try:
                public_key = JsonWebKey.import_key(key_data).get_public_key()  # type: ignore
            except Exception as e:
                self.logger.debug("Skipping unusable JWKS key %r: %s", key_kid, e)
                continue
            # Key without kid - use a default identifier
            keys[key_kid or "_default"] = public_key
        return keys

    def _select_jwks_key(self, kid: str | None) -> Any:
        """Pick the key for *kid* from the (freshly fetched) JWKS cache.

        Raises:
            ValueError: If *kid* is unknown, or the token has no ``kid`` and the
                set does not contain exactly one key.
        """
        if kid:
            if kid not in self._jwks_cache:
                self.logger.debug(
                    "JWKS key lookup failed: key ID '%s' not found", kid
                )
                raise ValueError(f"Key ID '{kid}' not found in JWKS")
            return self._jwks_cache[kid]

        # No kid in token - only allow if there's exactly one key
        if len(self._jwks_cache) == 1:
            return next(iter(self._jwks_cache.values()))
        if len(self._jwks_cache) > 1:
            raise ValueError("Multiple keys in JWKS but no key ID (kid) in token")
        raise ValueError("No keys found in JWKS")

    def _extract_scopes(self, claims: dict[str, Any]) -> list[str]:
        """
        Extract scopes from JWT claims. Supports both 'scope' and 'scp'
        claims.

        Checks the `scope` claim first (standard OAuth2 claim), then the `scp`
        claim (used by some Identity Providers). A string value is split on
        whitespace; a list value is returned as-is; any other type is ignored.
        """
        for claim in ["scope", "scp"]:
            if claim in claims:
                if isinstance(claims[claim], str):
                    return claims[claim].split()
                elif isinstance(claims[claim], list):
                    return claims[claim]

        return []

    async def load_access_token(self, token: str) -> AccessToken | None:
        """
        Validates the provided JWT bearer token.

        Checks, in order: signature (static key or JWKS key), ``exp``, ``nbf``,
        issuer (if configured), audience (if configured), required scopes.

        Args:
            token: The JWT token string to validate

        Returns:
            AccessToken object if valid, None if invalid, expired or not yet
            valid. Failures are logged, never raised.
        """
        try:
            # Get verification key (static or from JWKS)
            verification_key = await self._get_verification_key(token)

            # decode() verifies the signature; the claims are checked below.
            claims = self.jwt.decode(token, verification_key)

            # Extract client ID early for logging
            client_id = claims.get("client_id") or claims.get("sub") or "unknown"
            now = time.time()

            # Validate expiration
            exp = claims.get("exp")
            if exp and exp < now:
                self.logger.debug(
                    "Token validation failed: expired token for client %s", client_id
                )
                self.logger.info("Bearer token rejected for client %s", client_id)
                return None

            # RFC 7519 §4.1.5: a JWT must not be accepted before its "nbf" time.
            nbf = claims.get("nbf")
            if nbf and nbf > now:
                self.logger.debug(
                    "Token validation failed: token not yet valid for client %s",
                    client_id,
                )
                self.logger.info("Bearer token rejected for client %s", client_id)
                return None

            # Validate issuer - note we use issuer instead of issuer_url here because
            # issuer is optional, allowing users to make this check optional
            if self.issuer:
                if claims.get("iss") != self.issuer:
                    self.logger.debug(
                        "Token validation failed: issuer mismatch for client %s",
                        client_id,
                    )
                    self.logger.info("Bearer token rejected for client %s", client_id)
                    return None

            # Validate audience if configured
            if self.audience:
                aud = claims.get("aud")

                # Handle different combinations of audience types
                audience_valid = False
                if isinstance(self.audience, list):
                    # self.audience is a list - check if any expected audience is present
                    if isinstance(aud, list):
                        # Both are lists - check for intersection
                        audience_valid = any(
                            expected in aud for expected in self.audience
                        )
                    else:
                        # aud is a string - check if it's in our expected list
                        audience_valid = aud in cast(list, self.audience)
                else:
                    # self.audience is a string
                    if isinstance(aud, list):
                        audience_valid = self.audience in aud
                    else:
                        audience_valid = aud == self.audience

                if not audience_valid:
                    self.logger.debug(
                        "Token validation failed: audience mismatch for client %s",
                        client_id,
                    )
                    self.logger.info("Bearer token rejected for client %s", client_id)
                    return None

            # Extract scopes
            scopes = self._extract_scopes(claims)

            # Check required scopes
            if self.required_scopes:
                token_scopes = set(scopes)
                required_scopes = set(self.required_scopes)
                if not required_scopes.issubset(token_scopes):
                    self.logger.debug(
                        "Token missing required scopes. Has: %s, Required: %s",
                        token_scopes,
                        required_scopes,
                    )
                    self.logger.info("Bearer token rejected for client %s", client_id)
                    return None

            return AccessToken(
                token=token,
                client_id=str(client_id),
                scopes=scopes,
                expires_at=int(exp) if exp else None,
                claims=claims,
            )

        except JoseError:
            self.logger.debug("Token validation failed: JWT signature/format invalid")
            return None
        except Exception as e:
            self.logger.debug("Token validation failed: %s", str(e))
            return None

    async def verify_token(self, token: str) -> AccessToken | None:
        """
        Verify a bearer token and return access info if valid.

        This method implements the TokenVerifier protocol by delegating
        to our existing load_access_token method.

        Args:
            token: The JWT token string to validate

        Returns:
            AccessToken object if valid, None if invalid or expired
        """
        return await self.load_access_token(token)
482
+
483
+
484
class StaticTokenVerifier(TokenVerifier):
    """
    Simple static token verifier for testing and development.

    This verifier validates tokens against a predefined dictionary of valid token
    strings and their associated claims. When a token string matches a key in the
    dictionary, the verifier returns the corresponding claims as if the token was
    validated by a real authorization server.

    Use this when:
    - You're developing or testing locally without a real OAuth server
    - You need predictable tokens for automated testing
    - You want to simulate different users/scopes without complex setup
    - You're prototyping and need simple API key-style authentication

    WARNING: Never use this in production - tokens are stored in plain text!
    """

    def __init__(
        self,
        tokens: dict[str, dict[str, Any]],
        required_scopes: list[str] | None = None,
    ):
        """
        Initialize the static token verifier.

        Args:
            tokens: Dict mapping token strings to token metadata
                   Each token should have: client_id, scopes, expires_at (optional)
            required_scopes: Required scopes for all tokens
        """
        super().__init__(required_scopes=required_scopes)
        self.tokens = tokens

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify *token* against the static token dictionary.

        Returns:
            An AccessToken built from the matching entry (the whole entry is
            exposed as ``claims``), or None when the token is unknown, its entry
            is empty, its ``expires_at`` lies in the past, or it lacks a
            required scope.

        Raises:
            KeyError: If the matching entry has no ``client_id``.
        """
        token_data = self.tokens.get(token)
        if not token_data:
            return None

        # Check expiration if present (compared against time.time(), i.e. epoch seconds)
        expires_at = token_data.get("expires_at")
        if expires_at is not None and expires_at < time.time():
            return None

        scopes = token_data.get("scopes", [])

        # Check required scopes
        if self.required_scopes:
            token_scopes = set(scopes)
            required_scopes = set(self.required_scopes)
            if not required_scopes.issubset(token_scopes):
                # Lazy %-style args, consistent with JWTVerifier's logging.
                logger.debug(
                    "Token missing required scopes. Has: %s, Required: %s",
                    token_scopes,
                    required_scopes,
                )
                return None

        return AccessToken(
            token=token,
            client_id=token_data["client_id"],
            scopes=scopes,
            expires_at=expires_at,
            claims=token_data,
        )
@@ -0,0 +1,65 @@
1
+ """Utilities for validating client redirect URIs in OAuth flows."""
2
+
3
+ import fnmatch
4
+
5
+ from pydantic import AnyUrl
6
+
7
+
8
def matches_allowed_pattern(uri: str, pattern: str) -> bool:
    """Check if a URI matches an allowed pattern with wildcard support.

    Patterns support * wildcard matching:
    - http://localhost:* matches any localhost port
    - http://127.0.0.1:* matches any 127.0.0.1 port
    - https://*.example.com/* matches any subdomain of example.com
    - https://app.example.com/auth/* matches any path under /auth/

    Matching is case-sensitive on every platform. A textual wildcard match is
    not sufficient: the URI must also keep the host the pattern names. A URI is
    rejected when any of the following holds:
    - it smuggles userinfo (``user@``) that the pattern does not contain;
    - it has a non-numeric or out-of-range port;
    - its hostname contains characters outside ``[a-z0-9._:-]``;
    - its hostname does not match the pattern's host part.

    Args:
        uri: The redirect URI to validate
        pattern: The allowed pattern (may contain wildcards)

    Returns:
        True if the URI matches the pattern
    """
    from urllib.parse import urlsplit

    # fnmatchcase: plain fnmatch applies os.path.normcase, which makes the
    # result platform dependent (case-insensitive on Windows).
    if not fnmatch.fnmatchcase(uri, pattern):
        return False

    # "*" also matches "@", "/", "?", "#" and "\", so the textual match above
    # can be satisfied by a URI that points at another host, e.g.
    # "http://localhost:@evil.example/" vs "http://localhost:*", or
    # "https://evil.example/x.example.com/" vs "https://*.example.com/*".
    # Re-check the authority component structurally.
    try:
        uri_parts = urlsplit(uri)
        pattern_parts = urlsplit(pattern)
        uri_host = uri_parts.hostname
        pattern_host = pattern_parts.hostname
        _ = uri_parts.port  # raises ValueError for a non-numeric/out-of-range port
    except ValueError:
        return False

    # Userinfo would move the real host after the "@".
    if uri_parts.username is not None and pattern_parts.username is None:
        return False

    if pattern_host is None:
        # The pattern constrains no host (e.g. "*" or "myapp:/callback").
        return True

    # urlsplit lower-cases hostnames; anything outside this set (notably "\")
    # may be re-interpreted by browsers as a different host.
    allowed_host_chars = frozenset("abcdefghijklmnopqrstuvwxyz0123456789._:-")
    if not uri_host or any(ch not in allowed_host_chars for ch in uri_host):
        return False

    return fnmatch.fnmatchcase(uri_host, pattern_host)
26
+
27
+
28
def validate_redirect_uri(
    redirect_uri: str | AnyUrl | None,
    allowed_patterns: list[str] | None,
) -> bool:
    """Decide whether a redirect URI is permitted by an allow-list.

    Args:
        redirect_uri: URI to check. ``None`` is always accepted, because the
            client's default redirect URI is used in that case.
        allowed_patterns: Wildcard patterns as understood by
            ``matches_allowed_pattern``. ``None`` accepts every URI (dynamic
            client registration needs this); an empty list accepts none. To
            allow only localhost, pass ``DEFAULT_LOCALHOST_PATTERNS`` explicitly.

    Returns:
        True when at least one pattern accepts the URI (or one of the
        unconditional cases above applies), otherwise False.
    """
    if redirect_uri is None:
        return True

    candidate = str(redirect_uri)

    # No allow-list at all means "anything goes": DCR clients bring their own URIs.
    if allowed_patterns is None:
        return True

    return any(
        matches_allowed_pattern(candidate, allowed) for allowed in allowed_patterns
    )
59
+
60
+
61
# Default patterns for localhost-only validation: plain-HTTP loopback redirect
# URIs on any port. These are not applied automatically; validate_redirect_uri
# allows every URI when allowed_patterns is None, so pass this list explicitly
# to restrict redirects to localhost. There are no entries for "https" or the
# IPv6 loopback "[::1]".
DEFAULT_LOCALHOST_PATTERNS = [
    "http://localhost:*",
    "http://127.0.0.1:*",
]
+ ]