kailash 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. kailash/__init__.py +3 -3
  2. kailash/api/custom_nodes_secure.py +3 -3
  3. kailash/api/gateway.py +1 -1
  4. kailash/api/studio.py +2 -3
  5. kailash/api/workflow_api.py +3 -4
  6. kailash/core/resilience/bulkhead.py +460 -0
  7. kailash/core/resilience/circuit_breaker.py +92 -10
  8. kailash/edge/discovery.py +86 -0
  9. kailash/mcp_server/__init__.py +334 -0
  10. kailash/mcp_server/advanced_features.py +1022 -0
  11. kailash/{mcp → mcp_server}/ai_registry_server.py +29 -4
  12. kailash/mcp_server/auth.py +789 -0
  13. kailash/mcp_server/client.py +712 -0
  14. kailash/mcp_server/discovery.py +1593 -0
  15. kailash/mcp_server/errors.py +673 -0
  16. kailash/mcp_server/oauth.py +1727 -0
  17. kailash/mcp_server/protocol.py +1126 -0
  18. kailash/mcp_server/registry_integration.py +587 -0
  19. kailash/mcp_server/server.py +1747 -0
  20. kailash/{mcp → mcp_server}/servers/ai_registry.py +2 -2
  21. kailash/mcp_server/transports.py +1169 -0
  22. kailash/mcp_server/utils/cache.py +510 -0
  23. kailash/middleware/auth/auth_manager.py +3 -3
  24. kailash/middleware/communication/api_gateway.py +2 -9
  25. kailash/middleware/communication/realtime.py +1 -1
  26. kailash/middleware/mcp/client_integration.py +1 -1
  27. kailash/middleware/mcp/enhanced_server.py +2 -2
  28. kailash/nodes/__init__.py +2 -0
  29. kailash/nodes/admin/audit_log.py +6 -6
  30. kailash/nodes/admin/permission_check.py +8 -8
  31. kailash/nodes/admin/role_management.py +32 -28
  32. kailash/nodes/admin/schema.sql +6 -1
  33. kailash/nodes/admin/schema_manager.py +13 -13
  34. kailash/nodes/admin/security_event.py +16 -20
  35. kailash/nodes/admin/tenant_isolation.py +3 -3
  36. kailash/nodes/admin/transaction_utils.py +3 -3
  37. kailash/nodes/admin/user_management.py +21 -22
  38. kailash/nodes/ai/a2a.py +11 -11
  39. kailash/nodes/ai/ai_providers.py +9 -12
  40. kailash/nodes/ai/embedding_generator.py +13 -14
  41. kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
  42. kailash/nodes/ai/iterative_llm_agent.py +3 -3
  43. kailash/nodes/ai/llm_agent.py +213 -36
  44. kailash/nodes/ai/self_organizing.py +2 -2
  45. kailash/nodes/alerts/discord.py +4 -4
  46. kailash/nodes/api/graphql.py +6 -6
  47. kailash/nodes/api/http.py +12 -17
  48. kailash/nodes/api/rate_limiting.py +4 -4
  49. kailash/nodes/api/rest.py +15 -15
  50. kailash/nodes/auth/mfa.py +3 -4
  51. kailash/nodes/auth/risk_assessment.py +2 -2
  52. kailash/nodes/auth/session_management.py +5 -5
  53. kailash/nodes/auth/sso.py +143 -0
  54. kailash/nodes/base.py +6 -2
  55. kailash/nodes/base_async.py +16 -2
  56. kailash/nodes/base_with_acl.py +2 -2
  57. kailash/nodes/cache/__init__.py +9 -0
  58. kailash/nodes/cache/cache.py +1172 -0
  59. kailash/nodes/cache/cache_invalidation.py +870 -0
  60. kailash/nodes/cache/redis_pool_manager.py +595 -0
  61. kailash/nodes/code/async_python.py +2 -1
  62. kailash/nodes/code/python.py +196 -35
  63. kailash/nodes/compliance/data_retention.py +6 -6
  64. kailash/nodes/compliance/gdpr.py +5 -5
  65. kailash/nodes/data/__init__.py +10 -0
  66. kailash/nodes/data/optimistic_locking.py +906 -0
  67. kailash/nodes/data/readers.py +8 -8
  68. kailash/nodes/data/redis.py +349 -0
  69. kailash/nodes/data/sql.py +314 -3
  70. kailash/nodes/data/streaming.py +21 -0
  71. kailash/nodes/enterprise/__init__.py +8 -0
  72. kailash/nodes/enterprise/audit_logger.py +285 -0
  73. kailash/nodes/enterprise/batch_processor.py +22 -3
  74. kailash/nodes/enterprise/data_lineage.py +1 -1
  75. kailash/nodes/enterprise/mcp_executor.py +205 -0
  76. kailash/nodes/enterprise/service_discovery.py +150 -0
  77. kailash/nodes/enterprise/tenant_assignment.py +108 -0
  78. kailash/nodes/logic/async_operations.py +2 -2
  79. kailash/nodes/logic/convergence.py +1 -1
  80. kailash/nodes/logic/operations.py +1 -1
  81. kailash/nodes/monitoring/__init__.py +11 -1
  82. kailash/nodes/monitoring/health_check.py +456 -0
  83. kailash/nodes/monitoring/log_processor.py +817 -0
  84. kailash/nodes/monitoring/metrics_collector.py +627 -0
  85. kailash/nodes/monitoring/performance_benchmark.py +137 -11
  86. kailash/nodes/rag/advanced.py +7 -7
  87. kailash/nodes/rag/agentic.py +49 -2
  88. kailash/nodes/rag/conversational.py +3 -3
  89. kailash/nodes/rag/evaluation.py +3 -3
  90. kailash/nodes/rag/federated.py +3 -3
  91. kailash/nodes/rag/graph.py +3 -3
  92. kailash/nodes/rag/multimodal.py +3 -3
  93. kailash/nodes/rag/optimized.py +5 -5
  94. kailash/nodes/rag/privacy.py +3 -3
  95. kailash/nodes/rag/query_processing.py +6 -6
  96. kailash/nodes/rag/realtime.py +1 -1
  97. kailash/nodes/rag/registry.py +2 -6
  98. kailash/nodes/rag/router.py +1 -1
  99. kailash/nodes/rag/similarity.py +7 -7
  100. kailash/nodes/rag/strategies.py +4 -4
  101. kailash/nodes/security/abac_evaluator.py +6 -6
  102. kailash/nodes/security/behavior_analysis.py +5 -6
  103. kailash/nodes/security/credential_manager.py +1 -1
  104. kailash/nodes/security/rotating_credentials.py +11 -11
  105. kailash/nodes/security/threat_detection.py +8 -8
  106. kailash/nodes/testing/credential_testing.py +2 -2
  107. kailash/nodes/transform/processors.py +5 -5
  108. kailash/runtime/local.py +162 -14
  109. kailash/runtime/parameter_injection.py +425 -0
  110. kailash/runtime/parameter_injector.py +657 -0
  111. kailash/runtime/testing.py +2 -2
  112. kailash/testing/fixtures.py +2 -2
  113. kailash/workflow/builder.py +99 -18
  114. kailash/workflow/builder_improvements.py +207 -0
  115. kailash/workflow/input_handling.py +170 -0
  116. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/METADATA +21 -8
  117. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/RECORD +126 -101
  118. kailash/mcp/__init__.py +0 -53
  119. kailash/mcp/client.py +0 -445
  120. kailash/mcp/server.py +0 -292
  121. kailash/mcp/server_enhanced.py +0 -449
  122. kailash/mcp/utils/cache.py +0 -267
  123. /kailash/{mcp → mcp_server}/client_new.py +0 -0
  124. /kailash/{mcp → mcp_server}/utils/__init__.py +0 -0
  125. /kailash/{mcp → mcp_server}/utils/config.py +0 -0
  126. /kailash/{mcp → mcp_server}/utils/formatters.py +0 -0
  127. /kailash/{mcp → mcp_server}/utils/metrics.py +0 -0
  128. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/WHEEL +0 -0
  129. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/entry_points.txt +0 -0
  130. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/licenses/LICENSE +0 -0
  131. {kailash-0.6.2.dist-info → kailash-0.6.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1727 @@
1
+ """
2
+ OAuth 2.1 Authentication System for MCP.
3
+
4
+ This module implements a complete OAuth 2.1 authorization server and resource
5
+ server for MCP, following the latest OAuth 2.1 specification. It provides
6
+ secure authentication and authorization for MCP servers and clients.
7
+
8
+ Features:
9
+ - Complete OAuth 2.1 authorization server
10
+ - Dynamic client registration
11
+ - Multiple grant types (authorization code, client credentials)
12
+ - JWT access and refresh tokens
13
+ - Scope-based authorization
14
+ - PKCE support for public clients
15
+ - Token introspection and revocation
16
+ - Resource server middleware
17
+ - Well-known metadata endpoints
18
+
19
+ Examples:
20
+ OAuth 2.1 Authorization Server:
21
+
22
+ >>> from kailash.mcp_server.oauth import AuthorizationServer
23
+ >>>
24
+ >>> auth_server = AuthorizationServer(
25
+ ... issuer="https://auth.example.com",
26
+ ... private_key_path="private.pem",
27
+ ... client_store=InMemoryClientStore()
28
+ ... )
29
+ >>>
30
+ >>> # Register client
31
+ >>> client = await auth_server.register_client(
32
+ ... client_name="MCP Client",
33
+ ... redirect_uris=["http://localhost:8080/callback"],
34
+ ... grant_types=["authorization_code"],
35
+ ... scopes=["mcp.tools", "mcp.resources"]
36
+ ... )
37
+
38
+ Resource Server Integration:
39
+
40
+ >>> from kailash.mcp_server.oauth import ResourceServer
41
+ >>> from kailash.mcp_server import MCPServer
42
+ >>>
43
+ >>> resource_server = ResourceServer(
44
+ ... issuer="https://auth.example.com",
45
+ ... audience="mcp-api"
46
+ ... )
47
+ >>>
48
+ >>> server = MCPServer("protected-server", auth_provider=resource_server)
49
+ >>>
50
+ >>> @server.tool(required_permission="mcp.tools")
51
+ >>> def protected_tool():
52
+ ... return "Only accessible with proper token"
53
+
54
+ Client Credentials Flow:
55
+
56
+ >>> from kailash.mcp_server.oauth import OAuth2Client
57
+ >>>
58
+ >>> oauth_client = OAuth2Client(
59
+ ... client_id="client123",
60
+ ... client_secret="secret456",
61
+ ... token_endpoint="https://auth.example.com/token"
62
+ ... )
63
+ >>>
64
+ >>> token = await oauth_client.get_client_credentials_token(
65
+ ... scopes=["mcp.tools", "mcp.resources"]
66
+ ... )
67
+ """
68
+
69
+ import asyncio
70
+ import base64
71
+ import hashlib
72
+ import json
73
+ import logging
74
+ import secrets
75
+ import time
76
+ import uuid
77
+ from abc import ABC, abstractmethod
78
+ from dataclasses import asdict, dataclass, field
79
+ from datetime import datetime, timedelta
80
+ from enum import Enum
81
+ from typing import Any, Dict, List, Optional, Set, Tuple, Union
82
+ from urllib.parse import parse_qs, urlencode, urlparse
83
+
84
+ import aiohttp
85
+ import jwt
86
+ from cryptography.hazmat.primitives import hashes, serialization
87
+ from cryptography.hazmat.primitives.asymmetric import rsa
88
+
89
+ from .auth import AuthProvider
90
+ from .errors import AuthenticationError, AuthorizationError, MCPError
91
+
92
# Module-level logger: records client registrations and refresh-token
# verification failures.
logger = logging.getLogger(__name__)
93
+
94
+
95
class GrantType(Enum):
    """Supported OAuth 2.1 grant types.

    Each member's value is the string form of the grant type, so
    ``GrantType("client_credentials")`` converts a plain string to a member.
    """

    # Redirect-based flow that exchanges an authorization code for tokens.
    AUTHORIZATION_CODE = "authorization_code"
    # Client authenticates as itself, with no end user involved.
    CLIENT_CREDENTIALS = "client_credentials"
    # Exchange of a refresh token for a new access token.
    REFRESH_TOKEN = "refresh_token"
101
+
102
+
103
class TokenType(Enum):
    """Kinds of tokens handled by this module, keyed by their string names."""

    # Bearer credential presented to resource servers.
    ACCESS_TOKEN = "access_token"
    # Long-lived credential used to obtain new access tokens.
    REFRESH_TOKEN = "refresh_token"
    # Identity token.
    ID_TOKEN = "id_token"
109
+
110
+
111
class ClientType(Enum):
    """OAuth client categories, keyed by their string names."""

    # Client able to hold a client secret.
    CONFIDENTIAL = "confidential"
    # Client that cannot hold a secret.
    PUBLIC = "public"
116
+
117
+
118
@dataclass
class OAuthClient:
    """OAuth 2.1 client registration.

    ``client_type`` and ``grant_types`` hold enum members in memory; ``to_dict``
    and ``from_dict`` convert them to and from their string values.
    """

    client_id: str
    client_name: str = ""
    client_type: ClientType = ClientType.CONFIDENTIAL
    redirect_uris: List[str] = field(default_factory=list)
    grant_types: List[GrantType] = field(default_factory=list)
    scopes: List[str] = field(default_factory=list)
    client_secret: Optional[str] = None
    response_types: List[str] = field(default_factory=lambda: ["code"])
    token_endpoint_auth_method: str = "client_secret_basic"
    created_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a plain dict, replacing enum members with their values."""
        result = asdict(self)
        result["client_type"] = self.client_type.value
        result["grant_types"] = [gt.value for gt in self.grant_types]
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OAuthClient":
        """Create a client from a dict such as the one produced by ``to_dict``.

        ``client_type`` and ``grant_types`` may be given as string values or as
        enum members. If either key is absent, the dataclass default is used
        instead of raising ``KeyError``.

        Args:
            data: Field values keyed by field name (must include ``client_id``).

        Returns:
            The reconstructed client.

        Raises:
            ValueError: If ``client_type`` or a grant type is not a known value.
            TypeError: If ``data`` lacks ``client_id`` or has unknown keys.
        """
        data = dict(data)  # never mutate the caller's mapping
        if "client_type" in data:
            data["client_type"] = ClientType(data["client_type"])
        if "grant_types" in data:
            data["grant_types"] = [GrantType(gt) for gt in data["grant_types"]]
        return cls(**data)

    def supports_grant_type(self, grant_type: GrantType) -> bool:
        """Return True if ``grant_type`` is among the client's grant types."""
        return grant_type in self.grant_types

    def has_scope(self, scope: str) -> bool:
        """Return True if ``scope`` is among the client's allowed scopes."""
        return scope in self.scopes

    def validate_redirect_uri(self, redirect_uri: str) -> bool:
        """Return True if ``redirect_uri`` exactly matches a registered URI."""
        return redirect_uri in self.redirect_uris

    def is_valid_redirect_uri(self, redirect_uri: str) -> bool:
        """Alias for ``validate_redirect_uri`` (kept for compatibility)."""
        return self.validate_redirect_uri(redirect_uri)
164
+
165
+
166
@dataclass
class AccessToken:
    """OAuth 2.1 access token together with its grant metadata.

    After construction:
    - ``subject`` and ``user_id`` mirror each other when only one is given;
    - ``expires_at`` defaults to ``issued_at + expires_in``;
    - ``scope`` (space-separated) is derived from ``scopes`` when only the
      list is given.
    """

    token: str
    client_id: str
    token_type: str = "Bearer"
    expires_in: int = 3600
    scope: Optional[str] = None
    scopes: Optional[List[str]] = None
    subject: Optional[str] = None
    user_id: Optional[str] = None  # Alias for subject
    audience: Optional[List[str]] = None
    issued_at: float = field(default_factory=time.time)
    expires_at: Optional[float] = None

    def __post_init__(self):
        # Fill whichever of subject/user_id is missing from the other.
        if not self.subject and self.user_id:
            self.subject = self.user_id
        elif not self.user_id and self.subject:
            self.user_id = self.subject

        # Absolute expiry falls back to issue time plus lifetime.
        if self.expires_at is None:
            self.expires_at = self.issued_at + self.expires_in

        # Derive the space-delimited scope string from the list form.
        if not self.scope and self.scopes:
            self.scope = " ".join(self.scopes)

    def is_expired(self) -> bool:
        """Return True once the current time is past ``expires_at``."""
        now = time.time()
        return now > self.expires_at

    def to_dict(self) -> Dict[str, Any]:
        """Return the token in the shape of an OAuth token response."""
        response: Dict[str, Any] = {}
        response["access_token"] = self.token
        response["token_type"] = self.token_type
        response["expires_in"] = self.expires_in
        response["scope"] = self.scope
        return response

    def has_scope(self, scope: str) -> bool:
        """Return True if ``scope`` was granted to this token.

        A non-empty ``scopes`` list is authoritative; otherwise the
        space-separated ``scope`` string is consulted.
        """
        if self.scopes:
            granted = self.scopes
        elif self.scope:
            granted = self.scope.split()
        else:
            return False
        return scope in granted
217
+
218
+
219
@dataclass
class RefreshToken:
    """OAuth 2.1 refresh token record.

    ``subject``/``user_id`` mirror each other when only one is given, and
    ``scope`` is derived from ``scopes`` when only the list is given. A token
    with ``expires_at=None`` never expires; revocation is tracked separately
    via ``is_revoked``.
    """

    token: str
    client_id: str
    subject: Optional[str] = None
    user_id: Optional[str] = None  # Alias for subject
    scope: Optional[str] = None
    scopes: Optional[List[str]] = None
    issued_at: float = field(default_factory=time.time)
    expires_at: Optional[float] = None
    is_revoked: bool = False

    def __post_init__(self):
        # Fill whichever of subject/user_id is missing from the other.
        if not self.subject and self.user_id:
            self.subject = self.user_id
        elif not self.user_id and self.subject:
            self.user_id = self.subject

        # Derive the space-delimited scope string from the list form.
        if not self.scope and self.scopes:
            self.scope = " ".join(self.scopes)

    def is_expired(self) -> bool:
        """Return True once ``expires_at`` has passed (never, if it is None)."""
        return self.expires_at is not None and time.time() > self.expires_at

    def revoke(self) -> None:
        """Mark this refresh token as revoked."""
        self.is_revoked = True
253
+
254
+
255
@dataclass
class AuthorizationCode:
    """OAuth 2.1 authorization code, optionally bound to a PKCE challenge.

    By default a code expires 600 seconds (10 minutes) after creation.
    """

    code: str
    client_id: str
    redirect_uri: str
    scope: Optional[str] = None
    scopes: Optional[List[str]] = None
    subject: Optional[str] = None
    user_id: Optional[str] = None  # Alias for subject
    code_challenge: Optional[str] = None
    code_challenge_method: Optional[str] = None
    issued_at: float = field(default_factory=time.time)
    expires_at: float = field(default_factory=lambda: time.time() + 600)  # 10 minutes

    def __post_init__(self):
        # Handle user_id as alias for subject
        if self.user_id and not self.subject:
            self.subject = self.user_id
        elif self.subject and not self.user_id:
            self.user_id = self.subject

        # Convert scopes list to scope string if needed
        if self.scopes and not self.scope:
            self.scope = " ".join(self.scopes)

    def is_expired(self) -> bool:
        """Return True once the current time is past ``expires_at``."""
        return time.time() > self.expires_at

    def validate_pkce(self, code_verifier: Optional[str]) -> bool:
        """Validate a PKCE code verifier against the stored challenge.

        Args:
            code_verifier: Verifier presented when redeeming the code.

        Returns:
            True if no challenge was recorded (PKCE not used), or if the
            verifier matches the challenge under ``code_challenge_method``
            ("S256" or "plain"). False for a missing verifier, a mismatch,
            or any other (including absent) challenge method.
        """
        if not self.code_challenge:
            return True  # PKCE not used

        if not code_verifier:
            # A challenge was recorded, so a verifier is mandatory.
            return False

        if self.code_challenge_method == "S256":
            # BASE64URL(SHA256(verifier)) without padding
            digest = hashlib.sha256(code_verifier.encode("utf-8")).digest()
            expected = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")
        elif self.code_challenge_method == "plain":
            expected = code_verifier
        else:
            return False

        # Constant-time comparison avoids leaking how many leading characters
        # matched. Compare bytes: compare_digest raises TypeError on non-ASCII str.
        return secrets.compare_digest(
            expected.encode("utf-8"), self.code_challenge.encode("utf-8")
        )
303
+
304
+
305
class ClientStore(ABC):
    """Storage interface for registered OAuth clients.

    Every operation is a coroutine, so concrete stores may perform async I/O.
    """

    @abstractmethod
    async def store_client(self, client: OAuthClient) -> None:
        """Persist ``client`` so it can later be fetched by its ``client_id``."""
        ...

    @abstractmethod
    async def get_client(self, client_id: str) -> Optional[OAuthClient]:
        """Return the client registered under ``client_id``, or None."""
        ...

    @abstractmethod
    async def delete_client(self, client_id: str) -> bool:
        """Remove the client; return whether anything was deleted."""
        ...

    @abstractmethod
    async def list_clients(self) -> List[OAuthClient]:
        """Return every stored client."""
        ...
327
+
328
+
329
class InMemoryClientStore(ClientStore):
    """In-memory OAuth client store backed by a dict keyed by ``client_id``.

    Contents live only as long as this instance.
    """

    def __init__(self):
        """Initialize an empty store."""
        self._clients: Dict[str, OAuthClient] = {}

    async def store_client(self, client: OAuthClient) -> None:
        """Insert or overwrite ``client`` under its ``client_id``."""
        self._clients[client.client_id] = client

    async def get_client(self, client_id: str) -> Optional[OAuthClient]:
        """Return the client registered as ``client_id``, or None."""
        return self._clients.get(client_id)

    async def delete_client(self, client_id: str) -> bool:
        """Remove ``client_id``; return True if it was present."""
        if client_id in self._clients:
            del self._clients[client_id]
            return True
        return False

    async def list_clients(self) -> List[OAuthClient]:
        """Return a new list containing all stored clients."""
        return list(self._clients.values())

    async def authenticate_client(
        self, client_id: str, client_secret: Optional[str]
    ) -> Optional[OAuthClient]:
        """Authenticate a client by ID and secret.

        Secrets are compared in constant time so response timing does not
        reveal how much of a guessed secret was correct.

        Args:
            client_id: Client identifier.
            client_secret: Presented secret. ``None`` matches only a client
                that has no secret, as with the earlier ``==`` comparison.

        Returns:
            The client on success, otherwise None.
        """
        client = await self.get_client(client_id)
        if client is None:
            return None

        expected = client.client_secret
        if expected is None or client_secret is None:
            # Secret-less case: succeed only when both sides lack a secret.
            return client if expected is None and client_secret is None else None

        if secrets.compare_digest(
            expected.encode("utf-8"), client_secret.encode("utf-8")
        ):
            return client
        return None
363
+
364
+
365
class TokenStore(ABC):
    """Storage interface for access tokens, refresh tokens and auth codes.

    All operations are coroutines so implementations may use async I/O.
    """

    # --- access tokens -------------------------------------------------

    @abstractmethod
    async def store_access_token(self, token: AccessToken) -> None:
        """Persist an access token."""
        ...

    @abstractmethod
    async def get_access_token(self, token: str) -> Optional[AccessToken]:
        """Look up an access token by its token string, or return None."""
        ...

    @abstractmethod
    async def revoke_access_token(self, token: str) -> bool:
        """Revoke an access token; return whether it existed."""
        ...

    # --- refresh tokens ------------------------------------------------

    @abstractmethod
    async def store_refresh_token(self, token: RefreshToken) -> None:
        """Persist a refresh token."""
        ...

    @abstractmethod
    async def get_refresh_token(self, token: str) -> Optional[RefreshToken]:
        """Look up a refresh token by its token string, or return None."""
        ...

    @abstractmethod
    async def revoke_refresh_token(self, token: str) -> bool:
        """Revoke a refresh token; return whether it existed."""
        ...

    # --- authorization codes -------------------------------------------

    @abstractmethod
    async def store_authorization_code(self, code: AuthorizationCode) -> None:
        """Persist an authorization code."""
        ...

    @abstractmethod
    async def get_authorization_code(self, code: str) -> Optional[AuthorizationCode]:
        """Look up an authorization code without consuming it."""
        ...

    @abstractmethod
    async def consume_authorization_code(
        self, code: str
    ) -> Optional[AuthorizationCode]:
        """Fetch an authorization code and remove it from the store."""
        ...
414
+
415
+
416
class InMemoryTokenStore(TokenStore):
    """Dict-backed token store.

    Expired entries are evicted lazily: a lookup that finds an expired item
    deletes it and reports it as missing.
    """

    def __init__(self):
        """Create empty maps for each token kind, keyed by token/code string."""
        self._access_tokens: Dict[str, AccessToken] = {}
        self._refresh_tokens: Dict[str, RefreshToken] = {}
        self._authorization_codes: Dict[str, AuthorizationCode] = {}

    async def store_access_token(self, token: AccessToken) -> None:
        """Save ``token`` under its token string."""
        self._access_tokens[token.token] = token

    async def get_access_token(self, token: str) -> Optional[AccessToken]:
        """Return the live access token for ``token``, or None."""
        found = self._access_tokens.get(token)
        if found is None:
            return None
        if found.is_expired():
            del self._access_tokens[token]
            return None
        return found

    async def revoke_access_token(self, token: str) -> bool:
        """Drop the access token; True if it was stored."""
        return self._access_tokens.pop(token, None) is not None

    async def store_refresh_token(self, token: RefreshToken) -> None:
        """Save ``token`` under its token string."""
        self._refresh_tokens[token.token] = token

    async def get_refresh_token(self, token: str) -> Optional[RefreshToken]:
        """Return the unexpired refresh token for ``token``, or None."""
        found = self._refresh_tokens.get(token)
        if found is None:
            return None
        if found.is_expired():
            del self._refresh_tokens[token]
            return None
        return found

    async def revoke_refresh_token(self, token: str) -> bool:
        """Drop the refresh token; True if it was stored."""
        return self._refresh_tokens.pop(token, None) is not None

    async def store_authorization_code(self, code: AuthorizationCode) -> None:
        """Save ``code`` under its code string."""
        self._authorization_codes[code.code] = code

    async def get_authorization_code(self, code: str) -> Optional[AuthorizationCode]:
        """Return the unexpired authorization code, or None."""
        found = self._authorization_codes.get(code)
        if found is None:
            return None
        if found.is_expired():
            del self._authorization_codes[code]
            return None
        return found

    async def consume_authorization_code(
        self, code: str
    ) -> Optional[AuthorizationCode]:
        """Return the authorization code and delete it, so it can be used once."""
        found = await self.get_authorization_code(code)
        if found is not None:
            del self._authorization_codes[code]
        return found
483
+
484
+
485
class JWTManager:
    """JWT token manager for OAuth 2.1.

    Signs and verifies access and refresh tokens. If no private key is
    supplied, a new 2048-bit RSA key pair is generated for this instance,
    so its tokens are only verifiable with that instance's public key.
    """

    def __init__(
        self,
        private_key: Optional[str] = None,
        public_key: Optional[str] = None,
        algorithm: str = "RS256",
        issuer: Optional[str] = None,
        private_key_pem: Optional[str] = None,  # Backward compatibility
        public_key_pem: Optional[str] = None,  # Backward compatibility
    ):
        """Initialize JWT manager.

        Args:
            private_key: Private key for signing (unencrypted PEM)
            public_key: Public key for verification (PEM); derived from the
                private key when omitted
            algorithm: JWT algorithm
            issuer: Token issuer, written to ``iss`` and checked on verify
            private_key_pem: Legacy alias for ``private_key``
            public_key_pem: Legacy alias for ``public_key``
        """
        self.algorithm = algorithm
        self.issuer = issuer

        # Handle backward compatibility
        private_key = private_key or private_key_pem
        public_key = public_key or public_key_pem

        if private_key:
            self.private_key = serialization.load_pem_private_key(
                private_key.encode(), password=None
            )
        else:
            # Ephemeral key pair: valid only for the lifetime of this instance.
            self.private_key = rsa.generate_private_key(
                public_exponent=65537, key_size=2048
            )

        if public_key:
            self.public_key = serialization.load_pem_public_key(public_key.encode())
        else:
            self.public_key = self.private_key.public_key()

    def create_access_token(
        self,
        subject: Optional[Union[str, Dict[str, Any]]] = None,
        client_id: Optional[str] = None,
        scope: Optional[str] = None,
        audience: Optional[List[str]] = None,
        expires_in: int = 3600,
    ) -> Union[AccessToken, str]:
        """Create JWT access token.

        Args:
            subject: Token subject (user ID), or a legacy dict carrying
                ``user_id``/``client_id``/``scope``/``scopes``/``audience``/
                ``expires_in``
            client_id: OAuth client ID
            scope: Token scope (ignored in the dict form)
            audience: Token audience
            expires_in: Token lifetime in seconds

        Returns:
            An ``AccessToken``; or, for the legacy dict form, the encoded JWT
            string.
        """
        # Handle dictionary input for backward compatibility
        token_data_dict = None
        if isinstance(subject, dict):
            token_data_dict = subject
            subject = token_data_dict.get("user_id")
            client_id = token_data_dict.get("client_id", client_id)
            scope = token_data_dict.get("scope")
            if not scope and "scopes" in token_data_dict:
                scope = " ".join(token_data_dict["scopes"])
            audience = token_data_dict.get("audience", audience)
            expires_in = token_data_dict.get("expires_in", expires_in)

        now = time.time()
        expires_at = now + expires_in

        payload = {
            "iss": self.issuer,
            "iat": int(now),
            "exp": int(expires_at),
            "jti": str(uuid.uuid4()),
            "token_type": "access_token",
        }

        if subject:
            payload["sub"] = subject
        if client_id:
            payload["client_id"] = client_id
        if scope:
            payload["scope"] = scope
        if audience:
            payload["aud"] = audience

        # Legacy dict form: also copy these raw fields as custom claims.
        if token_data_dict:
            for key in ("user_id", "scopes"):
                if key in token_data_dict:
                    payload[key] = token_data_dict[key]

        token = jwt.encode(payload, self.private_key, algorithm=self.algorithm)

        # For backward compatibility, return string if called with dict
        if token_data_dict:
            return token

        return AccessToken(
            token=token,
            expires_in=expires_in,
            scope=scope,
            client_id=client_id,
            subject=subject,
            audience=audience,
            issued_at=now,
            expires_at=expires_at,
        )

    def verify_access_token(
        self, token: str, audience: Optional[Union[str, List[str]]] = None
    ) -> Optional[Dict[str, Any]]:
        """Verify JWT access token.

        Args:
            token: JWT token to verify
            audience: Expected audience. When given, the token's ``aud`` claim
                must match. When omitted, ``aud`` is not checked; without this,
                PyJWT rejects any token carrying an ``aud`` claim, including
                those created here with ``audience``.

        Returns:
            Token payload, or None if the token is not an access token.

        Raises:
            AuthenticationError: If the signature, expiry, audience or issuer
                is invalid.
        """
        options = {} if audience is not None else {"verify_aud": False}
        try:
            payload = jwt.decode(
                token,
                self.public_key,
                algorithms=[self.algorithm],
                audience=audience,
                options=options,
            )
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(f"Invalid token: {e}") from e

        # Verify token type
        if payload.get("token_type") != "access_token":
            return None

        # Check issuer (same rule as verify_refresh_token)
        if self.issuer and payload.get("iss") != self.issuer:
            raise AuthenticationError("Invalid issuer")

        return payload

    def create_refresh_token(
        self,
        token_data: Union[Dict[str, Any], str],
        expires_in: int = 2592000,  # 30 days
    ) -> str:
        """Create JWT refresh token.

        Args:
            token_data: Dict with optional ``client_id``/``user_id``, or a
                bare client ID string
            expires_in: Token lifetime in seconds

        Returns:
            JWT refresh token string
        """
        now = time.time()
        expires_at = now + expires_in

        payload = {
            "iss": self.issuer,
            "iat": int(now),
            "exp": int(expires_at),
            "jti": str(uuid.uuid4()),
            "token_type": "refresh_token",
        }

        if isinstance(token_data, dict):
            if "client_id" in token_data:
                payload["client_id"] = token_data["client_id"]
            if "user_id" in token_data:
                payload["sub"] = token_data["user_id"]
                payload["user_id"] = token_data["user_id"]
        else:
            payload["client_id"] = token_data

        return jwt.encode(payload, self.private_key, algorithm=self.algorithm)

    def verify_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify JWT refresh token.

        Args:
            token: JWT token to verify

        Returns:
            Token payload if valid; None if decoding fails with an unexpected
            (non-JWT) error.

        Raises:
            AuthenticationError: If the token is expired, malformed, not a
                refresh token, or from a different issuer.
        """
        try:
            payload = jwt.decode(
                token,
                self.public_key,
                algorithms=[self.algorithm],
                options={"verify_aud": False},
            )
        except jwt.ExpiredSignatureError as e:
            raise AuthenticationError("Token expired") from e
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(f"Invalid token: {e}") from e
        except Exception as e:
            logger.error("Token verification failed: %s", e)
            return None

        # These checks sit outside the try so the generic handler above
        # cannot swallow their AuthenticationError.
        if payload.get("token_type") != "refresh_token":
            raise AuthenticationError("Invalid token type")

        if self.issuer and payload.get("iss") != self.issuer:
            raise AuthenticationError("Invalid issuer")

        return payload

    def get_public_key_jwks(self) -> Dict[str, Any]:
        """Get public key in JWKS format.

        Reads the RSA public numbers (modulus ``n``, exponent ``e``), so this
        assumes an RSA key.

        Returns:
            JWKS document with a single signing key
        """
        public_numbers = self.public_key.public_numbers()

        def int_to_base64url(value: int) -> str:
            # Big-endian bytes, base64url-encoded without padding.
            byte_length = (value.bit_length() + 7) // 8
            bytes_value = value.to_bytes(byte_length, byteorder="big")
            return base64.urlsafe_b64encode(bytes_value).decode().rstrip("=")

        return {
            "keys": [
                {
                    "kty": "RSA",
                    "use": "sig",
                    "alg": self.algorithm,
                    "n": int_to_base64url(public_numbers.n),
                    "e": int_to_base64url(public_numbers.e),
                }
            ]
        }
724
+
725
+
726
+ class AuthorizationServer:
727
+ """OAuth 2.1 Authorization Server."""
728
+
729
def __init__(
    self,
    issuer: str,
    client_store: Optional[ClientStore] = None,
    token_store: Optional[TokenStore] = None,
    jwt_manager: Optional[JWTManager] = None,
    default_scopes: Optional[List[str]] = None,
    private_key_path: Optional[str] = None,  # For backward compatibility
):
    """Initialize authorization server.

    Args:
        issuer: Server issuer URL
        client_store: Client storage (defaults to a new InMemoryClientStore)
        token_store: Token storage (defaults to a new InMemoryTokenStore)
        jwt_manager: JWT manager; takes precedence over ``private_key_path``
        default_scopes: Scopes assigned to clients registered without any
            (an empty/None value falls back to ``["mcp.basic"]``)
        private_key_path: Path to a PEM private key used to sign tokens. If
            the file does not exist, a warning is logged and an ephemeral
            signing key is generated instead.
    """
    self.issuer = issuer
    # Compare against None so a custom store that happens to be falsy
    # (e.g. defines __len__ and is empty) is not silently replaced.
    self.client_store = (
        client_store if client_store is not None else InMemoryClientStore()
    )
    self.token_store = (
        token_store if token_store is not None else InMemoryTokenStore()
    )

    if jwt_manager is not None:
        self.jwt_manager = jwt_manager
    elif private_key_path:
        try:
            with open(private_key_path, "r", encoding="utf-8") as f:
                private_key = f.read()
        except FileNotFoundError:
            # Lenient fallback (used by tests), but make it visible: tokens
            # signed with an ephemeral key stop verifying after a restart.
            logger.warning(
                "OAuth private key file %s not found; using an ephemeral "
                "signing key",
                private_key_path,
            )
            self.jwt_manager = JWTManager(issuer=issuer)
        else:
            self.jwt_manager = JWTManager(issuer=issuer, private_key=private_key)
    else:
        self.jwt_manager = JWTManager(issuer=issuer)

    self.default_scopes = default_scopes or ["mcp.basic"]
767
+
768
async def register_client(
    self,
    client_name: str,
    redirect_uris: Optional[List[str]] = None,
    grant_types: Optional[List[str]] = None,
    scopes: Optional[List[str]] = None,
    client_type: Optional[str] = None,
    **metadata,
) -> OAuthClient:
    """Register OAuth client.

    Args:
        client_name: Client name
        redirect_uris: Redirect URIs (default: none)
        grant_types: Allowed grant types as strings
            (default: ``["authorization_code"]``)
        scopes: Allowed scopes (default: a copy of ``self.default_scopes``)
        client_type: "confidential" or "public" (default: confidential)
        **metadata: Additional metadata stored on the client

    Returns:
        Registered client; confidential clients receive a generated secret.

    Raises:
        ValueError: If ``client_type`` or a grant type is not a known value.
    """
    client_id = f"client_{uuid.uuid4().hex[:16]}"

    client_type_enum = (
        ClientType(client_type) if client_type else ClientType.CONFIDENTIAL
    )

    # Only confidential clients are issued a secret.
    client_secret = (
        secrets.token_urlsafe(32)
        if client_type_enum == ClientType.CONFIDENTIAL
        else None
    )

    grant_type_enums = (
        [GrantType(gt) for gt in grant_types]
        if grant_types
        else [GrantType.AUTHORIZATION_CODE]
    )

    client = OAuthClient(
        client_id=client_id,
        client_secret=client_secret,
        client_name=client_name,
        client_type=client_type_enum,
        # Copy caller-supplied lists so later mutation by the caller cannot
        # silently change the registered client.
        redirect_uris=list(redirect_uris) if redirect_uris else [],
        grant_types=grant_type_enums,
        scopes=list(scopes) if scopes else list(self.default_scopes),
        metadata=metadata,
    )

    await self.client_store.store_client(client)

    logger.info("Registered OAuth client: %s (%s)", client_name, client_id)
    return client
834
+
835
async def create_authorization_url(
    self,
    client_id: str,
    redirect_uri: str,
    scope: Optional[str] = None,
    state: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
) -> str:
    """Build the ``/authorize`` URL on this issuer for a registered client.

    Args:
        client_id: OAuth client ID (must exist in ``self.client_store``).
        redirect_uri: Redirect URI; must pass ``client.validate_redirect_uri``.
        scope: Space-delimited scope string, included only when truthy.
        state: Opaque state value, included only when truthy.
        code_challenge: PKCE code challenge, included only when truthy.
        code_challenge_method: PKCE method sent with the challenge; ``"S256"``
            is used when a challenge is given without a method.

    Returns:
        ``"{issuer}/authorize?..."`` with URL-encoded query parameters.

    Raises:
        AuthorizationError: Unknown client or unregistered redirect URI.
    """
    client = await self.client_store.get_client(client_id)
    if not client:
        raise AuthorizationError("Invalid client")
    if not client.validate_redirect_uri(redirect_uri):
        raise AuthorizationError("Invalid redirect URI")

    params = {
        "response_type": "code",
        "client_id": client_id,
        "redirect_uri": redirect_uri,
    }
    # Optional parameters are only emitted when they carry a value.
    optional = {"scope": scope, "state": state}
    params.update({key: value for key, value in optional.items() if value})

    if code_challenge:
        params["code_challenge"] = code_challenge
        params["code_challenge_method"] = code_challenge_method or "S256"

    return f"{self.issuer}/authorize?{urlencode(params)}"
+
883
async def generate_authorization_code(
    self,
    client_id: str,
    user_id: str,
    redirect_uri: str,
    scopes: Optional[List[str]] = None,
    state: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
) -> str:
    """Issue and store an authorization code for ``user_id``.

    Args:
        client_id: OAuth client ID (must exist in ``self.client_store``).
        user_id: User the code is issued for (stored as subject and user_id).
        redirect_uri: Redirect URI; must pass ``client.validate_redirect_uri``.
        scopes: Requested scopes; each must be allowed for the client.
        state: State parameter (accepted for API symmetry; not stored on the
            code by this method).
        code_challenge: PKCE code challenge stored on the code.
        code_challenge_method: PKCE challenge method stored on the code.

    Returns:
        The random authorization code string.

    Raises:
        AuthorizationError: Unknown client, unregistered redirect URI, or a
            requested scope the client is not allowed to use.
    """
    client = await self.client_store.get_client(client_id)
    if not client:
        raise AuthorizationError("Invalid client")

    if not client.validate_redirect_uri(redirect_uri):
        raise AuthorizationError("Invalid redirect URI")

    # A client may only be granted scopes it was registered with -- the same
    # rule client_credentials_grant enforces. Without this, any client could
    # obtain a code (and later an access token) for arbitrary scopes.
    if scopes:
        for requested_scope in scopes:
            if not client.has_scope(requested_scope):
                raise AuthorizationError(f"Invalid scope: {requested_scope}")

    # Tokens carry scopes as a single space-delimited string.
    scope = " ".join(scopes) if scopes else None

    auth_code = AuthorizationCode(
        code=secrets.token_urlsafe(32),
        client_id=client_id,
        redirect_uri=redirect_uri,
        scope=scope,
        scopes=scopes,
        subject=user_id,
        user_id=user_id,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
    )

    await self.token_store.store_authorization_code(auth_code)

    return auth_code.code
+
935
async def exchange_authorization_code(
    self,
    client_id: str,
    client_secret: Optional[str],
    code: str,
    redirect_uri: str,
    code_verifier: Optional[str] = None,
) -> Dict[str, Any]:
    """Exchange a one-time authorization code for access and refresh tokens.

    The code is passed to ``consume_authorization_code`` before the
    client-ID / redirect-URI / PKCE checks. A mismatched presentation
    therefore (presumably) still uses the code up.

    Args:
        client_id: OAuth client ID
        client_secret: OAuth client secret (required for confidential clients)
        code: Authorization code
        redirect_uri: Redirect URI; must equal the one the code was issued for
        code_verifier: PKCE code verifier (required if the code has a challenge)

    Returns:
        ``access_token.to_dict()`` plus a ``"refresh_token"`` entry.

    Raises:
        AuthorizationError: Unknown client, bad confidential-client
            credentials, invalid/expired/mismatched code, or failed PKCE.
    """
    client = await self.client_store.get_client(client_id)
    if not client:
        raise AuthorizationError("Invalid client")

    if client.client_type == ClientType.CONFIDENTIAL:
        registered = client.client_secret
        # Constant-time comparison so response timing does not reveal how much
        # of the secret matched. Comparing bytes keeps non-ASCII input from
        # making compare_digest raise TypeError.
        if (
            not client_secret
            or not registered
            or not secrets.compare_digest(
                client_secret.encode("utf-8"), registered.encode("utf-8")
            )
        ):
            raise AuthorizationError("Invalid client credentials")

    auth_code = await self.token_store.consume_authorization_code(code)
    if not auth_code:
        raise AuthorizationError("Invalid or expired authorization code")

    if auth_code.client_id != client_id:
        raise AuthorizationError("Authorization code mismatch")

    if auth_code.redirect_uri != redirect_uri:
        raise AuthorizationError("Redirect URI mismatch")

    # PKCE is enforced only when a challenge was recorded on the code.
    if auth_code.code_challenge:
        if not code_verifier:
            raise AuthorizationError("Code verifier required")

        if not auth_code.validate_pkce(code_verifier):
            raise AuthorizationError("Invalid code verifier")

    access_token_jwt = self.jwt_manager.create_access_token(
        subject=auth_code.subject,
        client_id=client_id,
        scope=auth_code.scope,
        audience=["mcp-api"],
    )

    # The JWT manager may return either an encoded string or a token object.
    if isinstance(access_token_jwt, str):
        access_token = AccessToken(
            token=access_token_jwt,
            client_id=client_id,
            subject=auth_code.subject,
            user_id=auth_code.user_id,
            scope=auth_code.scope,
            scopes=auth_code.scopes,
        )
    else:
        access_token = access_token_jwt

    # NOTE(review): this payload carries no scopes, while refresh_token_grant
    # reads "scopes" from the decoded refresh JWT -- confirm that
    # create_refresh_token adds them, otherwise refreshed tokens lose scope.
    refresh_token_jwt = self.jwt_manager.create_refresh_token(
        {"client_id": client_id, "user_id": auth_code.subject}
    )

    refresh_token = RefreshToken(
        token=refresh_token_jwt,
        client_id=client_id,
        subject=auth_code.subject,
        scope=auth_code.scope,
    )

    await self.token_store.store_access_token(access_token)
    await self.token_store.store_refresh_token(refresh_token)

    response = access_token.to_dict()
    response["refresh_token"] = refresh_token.token

    return response
+
1028
async def client_credentials_grant(
    self, client_id: str, client_secret: str, scopes: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Handle the client credentials grant.

    Args:
        client_id: OAuth client ID
        client_secret: OAuth client secret
        scopes: Requested scopes; each must be allowed for the client

    Returns:
        Token response (``access_token.to_dict()``)

    Raises:
        AuthorizationError: Unknown client, grant type not enabled for the
            client, missing/wrong secret, or a disallowed scope.
    """
    client = await self.client_store.get_client(client_id)
    if not client:
        raise AuthorizationError("Invalid client")

    if not client.supports_grant_type(GrantType.CLIENT_CREDENTIALS):
        raise AuthorizationError("Grant type not supported")

    registered = client.client_secret
    # This grant authenticates the client by its secret alone, so a secret
    # must exist on both sides. A client registered without one (e.g. a
    # public client) must not pass by also sending none. The comparison is
    # constant-time on bytes (non-ASCII input cannot raise TypeError).
    if (
        not client_secret
        or not registered
        or not secrets.compare_digest(
            client_secret.encode("utf-8"), registered.encode("utf-8")
        )
    ):
        raise AuthorizationError("Invalid client credentials")

    scope = None
    if scopes:
        for requested_scope in scopes:
            if not client.has_scope(requested_scope):
                raise AuthorizationError(f"Invalid scope: {requested_scope}")
        scope = " ".join(scopes)

    # No subject: the token represents the client itself.
    access_token_jwt = self.jwt_manager.create_access_token(
        client_id=client_id, scope=scope, audience=["mcp-api"]
    )

    # The JWT manager may return either an encoded string or a token object.
    if isinstance(access_token_jwt, str):
        access_token = AccessToken(
            token=access_token_jwt,
            client_id=client_id,
            scope=scope,
            scopes=scopes,
        )
    else:
        access_token = access_token_jwt

    await self.token_store.store_access_token(access_token)

    return access_token.to_dict()
+
1082
async def refresh_token_grant(
    self, client_id: str, client_secret: Optional[str], refresh_token: str
) -> Dict[str, Any]:
    """Handle the refresh token grant.

    The refresh token is first verified as a JWT. If verification fails, or
    yields no claims, the token is looked up in ``self.token_store`` instead.

    Args:
        client_id: OAuth client ID
        client_secret: OAuth client secret (required for confidential clients)
        refresh_token: Refresh token

    Returns:
        Token response (``access_token.to_dict()``) for a new access token

    Raises:
        AuthorizationError: Unknown client, bad confidential-client
            credentials, unknown refresh token, or a token issued to a
            different client.
    """
    client = await self.client_store.get_client(client_id)
    if not client:
        raise AuthorizationError("Invalid client")

    if client.client_type == ClientType.CONFIDENTIAL:
        registered = client.client_secret
        # Constant-time comparison on bytes (non-ASCII input cannot raise).
        if (
            not client_secret
            or not registered
            or not secrets.compare_digest(
                client_secret.encode("utf-8"), registered.encode("utf-8")
            )
        ):
            raise AuthorizationError("Invalid client credentials")

    # NOTE(review): a refresh token that still verifies as a JWT is accepted
    # without consulting the token store, so revoke_token() does not appear
    # to block it -- confirm whether JWTManager tracks revocation.
    refresh_token_obj = None
    try:
        token_data = self.jwt_manager.verify_refresh_token(refresh_token)
        if token_data:
            token_scopes = token_data.get("scopes")
            refresh_token_obj = RefreshToken(
                token=refresh_token,
                client_id=token_data.get("client_id", client_id),
                subject=token_data.get("sub") or token_data.get("user_id"),
                user_id=token_data.get("user_id") or token_data.get("sub"),
                scope=" ".join(token_scopes) if token_scopes else None,
                scopes=token_scopes,
            )
    except Exception:
        # Not a verifiable JWT (or its claims are unusable): use the token
        # store below. `Exception` rather than a bare except, so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        refresh_token_obj = None

    # Also reached when verification returns no claims; without this, that
    # case hit an unbound local variable.
    if refresh_token_obj is None:
        refresh_token_obj = await self.token_store.get_refresh_token(refresh_token)
        if not refresh_token_obj:
            raise AuthorizationError("Invalid refresh token")

    if refresh_token_obj.client_id != client_id:
        raise AuthorizationError("Client mismatch")

    access_token_jwt = self.jwt_manager.create_access_token(
        subject=refresh_token_obj.subject,
        client_id=client_id,
        scope=refresh_token_obj.scope,
        audience=["mcp-api"],
    )

    # The JWT manager may return either an encoded string or a token object.
    if isinstance(access_token_jwt, str):
        access_token = AccessToken(
            token=access_token_jwt,
            client_id=client_id,
            subject=refresh_token_obj.subject,
            user_id=refresh_token_obj.user_id,
            scope=refresh_token_obj.scope,
            scopes=getattr(refresh_token_obj, "scopes", None),
        )
    else:
        access_token = access_token_jwt

    await self.token_store.store_access_token(access_token)

    return access_token.to_dict()
+
1161
async def introspect_token(self, token: str) -> Dict[str, Any]:
    """Describe an access token in token-introspection form.

    Only tokens that ``self.jwt_manager.verify_access_token`` accepts are
    reported as active. The token store is not consulted here.

    Args:
        token: Token to introspect

    Returns:
        ``{"active": False}`` when verification raises ``AuthenticationError``
        or yields no claims; otherwise ``"active": True`` plus the
        ``client_id``, ``scope``, ``exp``, ``iat``, ``sub``, ``aud`` and
        ``token_type`` fields taken from the verified payload.
    """
    try:
        claims = self.jwt_manager.verify_access_token(token)
    except AuthenticationError:
        # Invalid or expired token.
        return {"active": False}

    if not claims:
        return {"active": False}

    scope_text = claims.get("scope", "")
    # Prefer an explicit "scopes" list; otherwise split the scope string.
    scope_list = claims.get("scopes", [])
    if not scope_list and scope_text:
        scope_list = scope_text.split()

    return {
        "active": True,
        "client_id": claims.get("client_id"),
        "scope": " ".join(scope_list) if scope_list else scope_text,
        "exp": claims.get("exp"),
        "iat": claims.get("iat", time.time()),
        "sub": claims.get("sub") or claims.get("user_id"),
        "aud": claims.get("aud", []),
        "token_type": "access_token",
    }
+
1203
async def revoke_token(
    self,
    token: str,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
) -> bool:
    """Revoke an access or refresh token.

    When ``client_id`` is given, the client must exist, and confidential
    clients must present their secret. When ``client_id`` is omitted, no
    client authentication is performed at all.

    Args:
        token: Token to revoke
        client_id: OAuth client ID
        client_secret: OAuth client secret

    Returns:
        True if the token store revoked it as an access token or, failing
        that, as a refresh token; False on failed client authentication or
        when neither revocation succeeded.
    """
    if client_id:
        client = await self.client_store.get_client(client_id)
        if not client:
            return False

        if client.client_type == ClientType.CONFIDENTIAL:
            registered = client.client_secret
            # Constant-time comparison on bytes (non-ASCII input cannot raise).
            if (
                not client_secret
                or not registered
                or not secrets.compare_digest(
                    client_secret.encode("utf-8"), registered.encode("utf-8")
                )
            ):
                return False

    if await self.token_store.revoke_access_token(token):
        return True

    return await self.token_store.revoke_refresh_token(token)
+
1237
async def refresh_access_token(
    self, client_id: str, client_secret: Optional[str], refresh_token: str
) -> Dict[str, Any]:
    """Alias of ``refresh_token_grant`` with the same arguments and result.

    Args:
        client_id: OAuth client ID
        client_secret: OAuth client secret
        refresh_token: Refresh token

    Returns:
        Whatever ``self.refresh_token_grant`` returns (the token response).
    """
    token_response = await self.refresh_token_grant(
        client_id, client_secret, refresh_token
    )
    return token_response
+
1252
def get_well_known_metadata(self) -> Dict[str, Any]:
    """Return authorization server metadata for ``/.well-known`` discovery.

    All endpoint URLs are derived from ``self.issuer``. ``scopes_supported``
    is a copy of ``self.default_scopes``, so callers may modify the returned
    dict without changing the server's defaults.

    Returns:
        Authorization server metadata
    """
    base = self.issuer
    return {
        "issuer": base,
        "authorization_endpoint": f"{base}/authorize",
        "token_endpoint": f"{base}/token",
        "introspection_endpoint": f"{base}/introspect",
        "revocation_endpoint": f"{base}/revoke",
        "jwks_uri": f"{base}/.well-known/jwks.json",
        "registration_endpoint": f"{base}/register",
        # Copy: handing out the live list would let callers mutate defaults.
        "scopes_supported": list(self.default_scopes),
        "response_types_supported": ["code"],
        "grant_types_supported": [
            "authorization_code",
            "client_credentials",
            "refresh_token",
        ],
        "token_endpoint_auth_methods_supported": [
            "client_secret_basic",
            "client_secret_post",
        ],
        "code_challenge_methods_supported": ["S256", "plain"],
    }
+
1280
+
1281
class ResourceServer:
    """OAuth 2.1 Resource Server for MCP.

    Accepts bearer access tokens, verifies them with ``jwt_manager``, and
    then enforces the configured audience and required scopes.
    """

    def __init__(
        self,
        issuer: str,
        audience: str,
        jwt_manager: Optional[JWTManager] = None,
        required_scopes: Optional[List[str]] = None,
    ):
        """Initialize resource server.

        Args:
            issuer: Authorization server issuer (also used to build a default
                ``JWTManager`` when none is supplied)
            audience: Value that must appear in the token's ``aud`` claim
            jwt_manager: JWT manager for token verification
            required_scopes: Scopes every token must carry (none by default)
        """
        self.issuer = issuer
        self.audience = audience
        self.jwt_manager = jwt_manager or JWTManager(issuer=issuer)
        self.required_scopes = required_scopes or []

    @staticmethod
    def _strip_bearer(token: str) -> str:
        """Drop a leading ``Bearer`` scheme (matched case-insensitively)."""
        # HTTP auth-scheme names are case-insensitive (RFC 7235), so
        # "bearer x" and "BEARER x" are accepted as well as "Bearer x".
        if token[:7].lower() == "bearer ":
            return token[7:].lstrip()
        return token

    @staticmethod
    def _token_scopes(payload: Dict[str, Any]) -> List[str]:
        """Return the scopes granted by a verified token payload.

        Uses the space-delimited ``scope`` claim. When that is absent or
        empty, falls back to a ``scopes`` list claim, since tokens may carry
        either form.
        """
        scope = payload.get("scope")
        if scope:
            return scope.split()
        scopes = payload.get("scopes")
        if isinstance(scopes, (list, tuple)):
            return list(scopes)
        return []

    async def authenticate(
        self, credentials: Union[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Authenticate using an OAuth 2.1 access token.

        Args:
            credentials: Token string or dict with a 'token' key; an optional
                "Bearer " prefix is removed

        Returns:
            Authentication result with id, client_id, subject, user_id,
            scopes and token_type

        Raises:
            AuthenticationError: No token, token rejected by the JWT manager,
                or a required scope is missing.
            AuthorizationError: The token's audience does not include
                ``self.audience``.
        """
        token = credentials if isinstance(credentials, str) else credentials.get("token")
        if not token:
            raise AuthenticationError("No token provided")

        token = self._strip_bearer(token)
        if not token:
            raise AuthenticationError("No token provided")

        payload = self.jwt_manager.verify_access_token(token)
        if not payload:
            raise AuthenticationError("Invalid token")

        # "aud" may be a single string or a list; a null claim counts as empty.
        token_audience = payload.get("aud") or []
        if isinstance(token_audience, str):
            token_audience = [token_audience]

        if self.audience not in token_audience:
            raise AuthorizationError("Invalid token audience")

        token_scopes = self._token_scopes(payload)

        # NOTE(review): a missing required scope raises AuthenticationError,
        # whereas check_permission raises AuthorizationError for a missing
        # scope; kept as-is for callers that catch AuthenticationError here.
        for required_scope in self.required_scopes:
            if required_scope not in token_scopes:
                raise AuthenticationError(f"Missing required scope: {required_scope}")

        return {
            "id": payload.get("sub") or payload.get("client_id"),
            "client_id": payload.get("client_id"),
            "subject": payload.get("sub"),
            "user_id": payload.get("sub") or payload.get("user_id"),
            "scopes": token_scopes,
            "token_type": "Bearer",
        }

    async def check_permission(
        self, auth_info: Dict[str, Any], required_permission: str
    ) -> None:
        """Check if authenticated entity has required permission.

        Args:
            auth_info: Authentication information from authenticate()
            required_permission: Required permission/scope

        Raises:
            AuthorizationError: If permission is missing
        """
        scopes = auth_info.get("scopes", [])
        if required_permission not in scopes:
            raise AuthorizationError(
                f"Missing required permission: {required_permission}"
            )

    async def get_headers(self) -> Dict[str, str]:
        """Get headers for authentication (empty for resource server).

        Returns:
            Empty dict as resource server doesn't add headers
        """
        return {}
+
1383
+
1384
class OAuth2Client:
    """OAuth 2.1 client for MCP.

    Requests tokens from an authorization server's token endpoint (client
    credentials, authorization code with optional PKCE, refresh) and caches
    the most recent access token, refresh token and expiry time.
    """

    # Lifetime assumed when a token response omits (or nulls) "expires_in".
    _DEFAULT_EXPIRES_IN = 3600
    # get_valid_token treats a token as expired this many seconds early.
    _EXPIRY_SKEW_SECONDS = 60

    def __init__(
        self,
        client_id: str,
        client_secret: Optional[str] = None,
        token_endpoint: Optional[str] = None,
        authorization_endpoint: Optional[str] = None,
        redirect_uri: Optional[str] = None,
    ):
        """Initialize OAuth 2.1 client.

        Args:
            client_id: OAuth client ID
            client_secret: OAuth client secret
            token_endpoint: Token endpoint URL
            authorization_endpoint: Authorization endpoint URL
            redirect_uri: Redirect URI
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_endpoint = token_endpoint
        self.authorization_endpoint = authorization_endpoint
        self.redirect_uri = redirect_uri

        # Token cache, filled from successful token responses.
        self._access_token: Optional[str] = None
        self._refresh_token: Optional[str] = None
        self._token_expires_at: Optional[float] = None

    async def _post_form(
        self,
        url: str,
        data: Dict[str, str],
        error_prefix: str,
        oauth_error_body: bool = False,
    ) -> Dict[str, Any]:
        """POST ``data`` form-encoded to ``url`` and return the JSON response.

        Args:
            url: Endpoint URL
            data: Form fields
            error_prefix: Start of the error message on a non-200 response
            oauth_error_body: Format a JSON error body as
                ``"{error}: {error_description}"`` (``error_prefix`` is the
                default description)

        Raises:
            AuthenticationError: On any non-200 response. The message is
                ``"{error_prefix}: {body}"`` unless ``oauth_error_body`` is
                set and the body is a JSON object.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                url,
                data=data,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    message = f"{error_prefix}: {error_text}"
                    if oauth_error_body:
                        # The body is already read, so json() re-parses it. A
                        # non-JSON error page (e.g. a proxy's HTML) must still
                        # surface as AuthenticationError, not a decode error.
                        try:
                            error_data = await response.json()
                        except (aiohttp.ContentTypeError, ValueError):
                            error_data = None
                        if isinstance(error_data, dict):
                            error = error_data.get("error", "unknown_error")
                            description = error_data.get(
                                "error_description", error_prefix
                            )
                            message = f"{error}: {description}"
                    raise AuthenticationError(message)

                return await response.json()

    def _store_tokens(
        self, token_response: Dict[str, Any], default_refresh_token: Optional[str]
    ) -> None:
        """Cache the access token, refresh token and expiry from a response.

        ``default_refresh_token`` is kept when the response has no
        ``refresh_token``.
        """
        if not isinstance(token_response, dict) or "access_token" not in token_response:
            raise AuthenticationError("Token response missing access_token")

        self._access_token = token_response["access_token"]
        self._refresh_token = token_response.get("refresh_token", default_refresh_token)

        expires_in = token_response.get("expires_in")
        if expires_in is None:
            expires_in = self._DEFAULT_EXPIRES_IN
        # float() accepts servers that send expires_in as a numeric string.
        self._token_expires_at = time.time() + float(expires_in)

    @staticmethod
    def _derive_introspection_endpoint(token_endpoint: str) -> str:
        """Replace the last ``/token`` in ``token_endpoint`` with ``/introspect``.

        Only the last occurrence is replaced, so a ``/token`` elsewhere in
        the URL (e.g. in a host path prefix) is left intact. URLs without
        ``/token`` are returned unchanged.
        """
        head, sep, tail = token_endpoint.rpartition("/token")
        if not sep:
            return token_endpoint
        return f"{head}/introspect{tail}"

    async def get_client_credentials_token(
        self, scopes: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Get an access token using the client credentials grant.

        Args:
            scopes: Requested scopes

        Returns:
            Token response dict

        Raises:
            AuthenticationError: Missing endpoint/secret or a failed request.
        """
        if not self.token_endpoint:
            raise AuthenticationError("Token endpoint not configured")

        if not self.client_secret:
            raise AuthenticationError("Client secret required for client credentials")

        data = {
            "grant_type": "client_credentials",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }
        if scopes:
            data["scope"] = " ".join(scopes)

        token_response = await self._post_form(
            self.token_endpoint, data, "Token request failed"
        )
        self._store_tokens(token_response, None)
        return token_response

    def get_authorization_url(
        self,
        scopes: Optional[List[str]] = None,
        state: Optional[str] = None,
        use_pkce: bool = True,
    ) -> Tuple[str, Optional[str]]:
        """Get the authorization URL for the authorization code flow.

        Args:
            scopes: Requested scopes
            state: State parameter
            use_pkce: Add an S256 PKCE challenge

        Returns:
            Tuple of (authorization_url, code_verifier); the verifier is None
            when PKCE is not used and must be kept for the code exchange
        """
        if not self.authorization_endpoint:
            raise AuthenticationError("Authorization endpoint not configured")

        if not self.redirect_uri:
            raise AuthenticationError("Redirect URI not configured")

        params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
        }
        if scopes:
            params["scope"] = " ".join(scopes)
        if state:
            params["state"] = state

        code_verifier = None
        if use_pkce:
            # Verifier: 32 random bytes, base64url without padding (RFC 7636).
            code_verifier = (
                base64.urlsafe_b64encode(secrets.token_bytes(32)).decode().rstrip("=")
            )
            # Challenge: base64url(SHA-256(verifier)) without padding.
            code_challenge = (
                base64.urlsafe_b64encode(
                    hashlib.sha256(code_verifier.encode()).digest()
                )
                .decode()
                .rstrip("=")
            )
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"

        authorization_url = f"{self.authorization_endpoint}?{urlencode(params)}"
        return authorization_url, code_verifier

    async def exchange_authorization_code(
        self, code: str, code_verifier: Optional[str] = None
    ) -> Dict[str, Any]:
        """Exchange an authorization code for tokens.

        Args:
            code: Authorization code
            code_verifier: PKCE code verifier

        Returns:
            Token response dict (the access token is also cached)

        Raises:
            AuthenticationError: Missing endpoint/redirect URI or a failed
                exchange.
        """
        if not self.token_endpoint:
            raise AuthenticationError("Token endpoint not configured")

        if not self.redirect_uri:
            raise AuthenticationError("Redirect URI not configured")

        data = {
            "grant_type": "authorization_code",
            "client_id": self.client_id,
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret
        if code_verifier:
            data["code_verifier"] = code_verifier

        token_response = await self._post_form(
            self.token_endpoint, data, "Token exchange failed"
        )
        self._store_tokens(token_response, None)
        return token_response

    async def get_valid_token(self) -> Optional[str]:
        """Get a valid access token, refreshing if necessary.

        Returns:
            The cached token if it is not within the expiry skew window of
            its expiry; otherwise a refreshed token when a refresh token is
            cached and the refresh succeeds; otherwise None
        """
        if self._access_token and self._token_expires_at:
            if time.time() < self._token_expires_at - self._EXPIRY_SKEW_SECONDS:
                return self._access_token

        if self._refresh_token:
            try:
                return await self._refresh_access_token()
            except Exception as e:
                logger.error("Token refresh failed: %s", e)

        return None

    async def _refresh_access_token(self) -> str:
        """Refresh the cached access token using the cached refresh token.

        Returns:
            New access token
        """
        if not self.token_endpoint or not self._refresh_token:
            raise AuthenticationError("Cannot refresh token")

        data = {
            "grant_type": "refresh_token",
            "client_id": self.client_id,
            "refresh_token": self._refresh_token,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret

        token_response = await self._post_form(
            self.token_endpoint, data, "Token refresh failed"
        )
        # Keep the current refresh token unless the server rotated it.
        self._store_tokens(token_response, self._refresh_token)
        return self._access_token

    async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
        """Refresh an access token with the given refresh token.

        Args:
            refresh_token: Refresh token

        Returns:
            Token response

        Raises:
            AuthenticationError: Missing endpoint or a failed refresh; a JSON
                error body is reported as "{error}: {error_description}".
        """
        if not self.token_endpoint:
            raise AuthenticationError("Token endpoint not configured")

        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self.client_id,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret

        token_response = await self._post_form(
            self.token_endpoint, data, "Token refresh failed", oauth_error_body=True
        )
        self._store_tokens(token_response, refresh_token)
        return token_response

    async def introspect_token(
        self, token: str, introspection_endpoint: Optional[str] = None
    ) -> Dict[str, Any]:
        """Introspect a token.

        Args:
            token: Token to introspect
            introspection_endpoint: Introspection endpoint URL; derived from
                the token endpoint when omitted

        Returns:
            Introspection response
        """
        if not introspection_endpoint and self.token_endpoint:
            introspection_endpoint = self._derive_introspection_endpoint(
                self.token_endpoint
            )

        if not introspection_endpoint:
            raise AuthenticationError("Introspection endpoint not configured")

        data = {
            "token": token,
            "client_id": self.client_id,
        }
        if self.client_secret:
            data["client_secret"] = self.client_secret

        return await self._post_form(
            introspection_endpoint, data, "Token introspection failed"
        )