kailash 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +3 -3
- kailash/api/custom_nodes_secure.py +3 -3
- kailash/api/gateway.py +1 -1
- kailash/api/studio.py +1 -1
- kailash/api/workflow_api.py +2 -2
- kailash/core/resilience/bulkhead.py +475 -0
- kailash/core/resilience/circuit_breaker.py +92 -10
- kailash/core/resilience/health_monitor.py +578 -0
- kailash/edge/discovery.py +86 -0
- kailash/mcp_server/__init__.py +309 -33
- kailash/mcp_server/advanced_features.py +1022 -0
- kailash/mcp_server/ai_registry_server.py +27 -2
- kailash/mcp_server/auth.py +789 -0
- kailash/mcp_server/client.py +645 -378
- kailash/mcp_server/discovery.py +1593 -0
- kailash/mcp_server/errors.py +673 -0
- kailash/mcp_server/oauth.py +1727 -0
- kailash/mcp_server/protocol.py +1126 -0
- kailash/mcp_server/registry_integration.py +587 -0
- kailash/mcp_server/server.py +1228 -96
- kailash/mcp_server/transports.py +1169 -0
- kailash/mcp_server/utils/__init__.py +6 -1
- kailash/mcp_server/utils/cache.py +250 -7
- kailash/middleware/auth/auth_manager.py +3 -3
- kailash/middleware/communication/api_gateway.py +1 -1
- kailash/middleware/communication/realtime.py +1 -1
- kailash/middleware/mcp/enhanced_server.py +1 -1
- kailash/nodes/__init__.py +2 -0
- kailash/nodes/admin/audit_log.py +6 -6
- kailash/nodes/admin/permission_check.py +8 -8
- kailash/nodes/admin/role_management.py +32 -28
- kailash/nodes/admin/schema.sql +6 -1
- kailash/nodes/admin/schema_manager.py +13 -13
- kailash/nodes/admin/security_event.py +15 -15
- kailash/nodes/admin/tenant_isolation.py +3 -3
- kailash/nodes/admin/transaction_utils.py +3 -3
- kailash/nodes/admin/user_management.py +21 -21
- kailash/nodes/ai/a2a.py +11 -11
- kailash/nodes/ai/ai_providers.py +9 -12
- kailash/nodes/ai/embedding_generator.py +13 -14
- kailash/nodes/ai/intelligent_agent_orchestrator.py +19 -19
- kailash/nodes/ai/iterative_llm_agent.py +2 -2
- kailash/nodes/ai/llm_agent.py +210 -33
- kailash/nodes/ai/self_organizing.py +2 -2
- kailash/nodes/alerts/discord.py +4 -4
- kailash/nodes/api/graphql.py +6 -6
- kailash/nodes/api/http.py +10 -10
- kailash/nodes/api/rate_limiting.py +4 -4
- kailash/nodes/api/rest.py +15 -15
- kailash/nodes/auth/mfa.py +3 -3
- kailash/nodes/auth/risk_assessment.py +2 -2
- kailash/nodes/auth/session_management.py +5 -5
- kailash/nodes/auth/sso.py +143 -0
- kailash/nodes/base.py +8 -2
- kailash/nodes/base_async.py +16 -2
- kailash/nodes/base_with_acl.py +2 -2
- kailash/nodes/cache/__init__.py +9 -0
- kailash/nodes/cache/cache.py +1172 -0
- kailash/nodes/cache/cache_invalidation.py +874 -0
- kailash/nodes/cache/redis_pool_manager.py +595 -0
- kailash/nodes/code/async_python.py +2 -1
- kailash/nodes/code/python.py +194 -30
- kailash/nodes/compliance/data_retention.py +6 -6
- kailash/nodes/compliance/gdpr.py +5 -5
- kailash/nodes/data/__init__.py +10 -0
- kailash/nodes/data/async_sql.py +1956 -129
- kailash/nodes/data/optimistic_locking.py +906 -0
- kailash/nodes/data/readers.py +8 -8
- kailash/nodes/data/redis.py +378 -0
- kailash/nodes/data/sql.py +314 -3
- kailash/nodes/data/streaming.py +21 -0
- kailash/nodes/enterprise/__init__.py +8 -0
- kailash/nodes/enterprise/audit_logger.py +285 -0
- kailash/nodes/enterprise/batch_processor.py +22 -3
- kailash/nodes/enterprise/data_lineage.py +1 -1
- kailash/nodes/enterprise/mcp_executor.py +205 -0
- kailash/nodes/enterprise/service_discovery.py +150 -0
- kailash/nodes/enterprise/tenant_assignment.py +108 -0
- kailash/nodes/logic/async_operations.py +2 -2
- kailash/nodes/logic/convergence.py +1 -1
- kailash/nodes/logic/operations.py +1 -1
- kailash/nodes/monitoring/__init__.py +11 -1
- kailash/nodes/monitoring/health_check.py +456 -0
- kailash/nodes/monitoring/log_processor.py +817 -0
- kailash/nodes/monitoring/metrics_collector.py +627 -0
- kailash/nodes/monitoring/performance_benchmark.py +137 -11
- kailash/nodes/rag/advanced.py +7 -7
- kailash/nodes/rag/agentic.py +49 -2
- kailash/nodes/rag/conversational.py +3 -3
- kailash/nodes/rag/evaluation.py +3 -3
- kailash/nodes/rag/federated.py +3 -3
- kailash/nodes/rag/graph.py +3 -3
- kailash/nodes/rag/multimodal.py +3 -3
- kailash/nodes/rag/optimized.py +5 -5
- kailash/nodes/rag/privacy.py +3 -3
- kailash/nodes/rag/query_processing.py +6 -6
- kailash/nodes/rag/realtime.py +1 -1
- kailash/nodes/rag/registry.py +1 -1
- kailash/nodes/rag/router.py +1 -1
- kailash/nodes/rag/similarity.py +7 -7
- kailash/nodes/rag/strategies.py +4 -4
- kailash/nodes/security/abac_evaluator.py +6 -6
- kailash/nodes/security/behavior_analysis.py +5 -5
- kailash/nodes/security/credential_manager.py +1 -1
- kailash/nodes/security/rotating_credentials.py +11 -11
- kailash/nodes/security/threat_detection.py +8 -8
- kailash/nodes/testing/credential_testing.py +2 -2
- kailash/nodes/transform/processors.py +5 -5
- kailash/runtime/local.py +163 -9
- kailash/runtime/parameter_injection.py +425 -0
- kailash/runtime/parameter_injector.py +657 -0
- kailash/runtime/testing.py +2 -2
- kailash/testing/fixtures.py +2 -2
- kailash/workflow/builder.py +99 -14
- kailash/workflow/builder_improvements.py +207 -0
- kailash/workflow/input_handling.py +170 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/METADATA +22 -9
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/RECORD +122 -95
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/WHEEL +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/entry_points.txt +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.6.3.dist-info → kailash-0.6.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1727 @@
|
|
1
|
+
"""
|
2
|
+
OAuth 2.1 Authentication System for MCP.
|
3
|
+
|
4
|
+
This module implements a complete OAuth 2.1 authorization server and resource
|
5
|
+
server for MCP, following the latest OAuth 2.1 specification. It provides
|
6
|
+
secure authentication and authorization for MCP servers and clients.
|
7
|
+
|
8
|
+
Features:
|
9
|
+
- Complete OAuth 2.1 authorization server
|
10
|
+
- Dynamic client registration
|
11
|
+
- Multiple grant types (authorization code, client credentials)
|
12
|
+
- JWT access and refresh tokens
|
13
|
+
- Scope-based authorization
|
14
|
+
- PKCE support for public clients
|
15
|
+
- Token introspection and revocation
|
16
|
+
- Resource server middleware
|
17
|
+
- Well-known metadata endpoints
|
18
|
+
|
19
|
+
Examples:
|
20
|
+
OAuth 2.1 Authorization Server:
|
21
|
+
|
22
|
+
>>> from kailash.mcp_server.oauth import AuthorizationServer
|
23
|
+
>>>
|
24
|
+
>>> auth_server = AuthorizationServer(
|
25
|
+
... issuer="https://auth.example.com",
|
26
|
+
... private_key_path="private.pem",
|
27
|
+
... client_store=InMemoryClientStore()
|
28
|
+
... )
|
29
|
+
>>>
|
30
|
+
>>> # Register client
|
31
|
+
>>> client = await auth_server.register_client(
|
32
|
+
... client_name="MCP Client",
|
33
|
+
... redirect_uris=["http://localhost:8080/callback"],
|
34
|
+
... grant_types=["authorization_code"],
|
35
|
+
... scopes=["mcp.tools", "mcp.resources"]
|
36
|
+
... )
|
37
|
+
|
38
|
+
Resource Server Integration:
|
39
|
+
|
40
|
+
>>> from kailash.mcp_server.oauth import ResourceServer
|
41
|
+
>>> from kailash.mcp_server import MCPServer
|
42
|
+
>>>
|
43
|
+
>>> resource_server = ResourceServer(
|
44
|
+
... issuer="https://auth.example.com",
|
45
|
+
... audience="mcp-api"
|
46
|
+
... )
|
47
|
+
>>>
|
48
|
+
>>> server = MCPServer("protected-server", auth_provider=resource_server)
|
49
|
+
>>>
|
50
|
+
>>> @server.tool(required_permission="mcp.tools")
|
51
|
+
... def protected_tool():
|
52
|
+
... return "Only accessible with proper token"
|
53
|
+
|
54
|
+
Client Credentials Flow:
|
55
|
+
|
56
|
+
>>> from kailash.mcp_server.oauth import OAuth2Client
|
57
|
+
>>>
|
58
|
+
>>> oauth_client = OAuth2Client(
|
59
|
+
... client_id="client123",
|
60
|
+
... client_secret="secret456",
|
61
|
+
... token_endpoint="https://auth.example.com/token"
|
62
|
+
... )
|
63
|
+
>>>
|
64
|
+
>>> token = await oauth_client.get_client_credentials_token(
|
65
|
+
... scopes=["mcp.tools", "mcp.resources"]
|
66
|
+
... )
|
67
|
+
"""
|
68
|
+
|
69
|
+
import asyncio
|
70
|
+
import base64
|
71
|
+
import hashlib
|
72
|
+
import json
|
73
|
+
import logging
|
74
|
+
import secrets
|
75
|
+
import time
|
76
|
+
import uuid
|
77
|
+
from abc import ABC, abstractmethod
|
78
|
+
from dataclasses import asdict, dataclass, field
|
79
|
+
from datetime import datetime, timedelta
|
80
|
+
from enum import Enum
|
81
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
82
|
+
from urllib.parse import parse_qs, urlencode, urlparse
|
83
|
+
|
84
|
+
import aiohttp
|
85
|
+
import jwt
|
86
|
+
from cryptography.hazmat.primitives import hashes, serialization
|
87
|
+
from cryptography.hazmat.primitives.asymmetric import rsa
|
88
|
+
|
89
|
+
from .auth import AuthProvider
|
90
|
+
from .errors import AuthenticationError, AuthorizationError, MCPError
|
91
|
+
|
92
|
+
logger = logging.getLogger(__name__)
|
93
|
+
|
94
|
+
|
95
|
+
class GrantType(Enum):
    """Enumeration of the grant types an OAuth client may be registered for.

    Member values are the plain strings that ``OAuthClient.to_dict`` writes
    out and ``OAuthClient.from_dict`` reads back.
    """

    # Exchange of an authorization code for tokens.
    AUTHORIZATION_CODE = "authorization_code"
    # Tokens requested using the client's own credentials.
    CLIENT_CREDENTIALS = "client_credentials"
    # Tokens requested by presenting a refresh token.
    REFRESH_TOKEN = "refresh_token"
|
101
|
+
|
102
|
+
|
103
|
+
class TokenType(Enum):
    """Enumeration of the kinds of token this module names.

    The ``access_token`` and ``refresh_token`` values are the same strings
    that ``JWTManager`` writes into the ``token_type`` claim.
    """

    ACCESS_TOKEN = "access_token"
    REFRESH_TOKEN = "refresh_token"
    ID_TOKEN = "id_token"
|
109
|
+
|
110
|
+
|
111
|
+
class ClientType(Enum):
    """Enumeration of OAuth client categories.

    ``OAuthClient`` defaults to ``CONFIDENTIAL``; the string values are what
    ``OAuthClient.to_dict`` emits for the ``client_type`` key.
    """

    CONFIDENTIAL = "confidential"
    PUBLIC = "public"
|
116
|
+
|
117
|
+
|
118
|
+
@dataclass
class OAuthClient:
    """Registration record for an OAuth 2.1 client.

    Only ``client_id`` is required; every other field has a default.
    ``client_type`` and ``grant_types`` hold enum members in memory and are
    converted to/from their string values by :meth:`to_dict` and
    :meth:`from_dict`.
    """

    client_id: str
    client_name: str = ""
    client_type: ClientType = ClientType.CONFIDENTIAL
    redirect_uris: List[str] = field(default_factory=list)
    grant_types: List[GrantType] = field(default_factory=list)
    scopes: List[str] = field(default_factory=list)
    client_secret: Optional[str] = None
    response_types: List[str] = field(default_factory=lambda: ["code"])
    token_endpoint_auth_method: str = "client_secret_basic"
    created_at: float = field(default_factory=time.time)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the record as a plain dictionary.

        Enum-valued fields are replaced by their string values so the result
        can be fed back to :meth:`from_dict`.
        """
        result = asdict(self)
        result["client_type"] = self.client_type.value
        result["grant_types"] = [gt.value for gt in self.grant_types]
        return result

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OAuthClient":
        """Build a client from a dictionary such as the one :meth:`to_dict` emits.

        Args:
            data: Mapping of field names to values. ``client_type`` and
                ``grant_types`` may be given as strings or enum members and
                may be omitted, in which case the dataclass defaults apply.

        Returns:
            A new ``OAuthClient``.

        Raises:
            ValueError: If ``client_type`` or a grant type is not a known value.
            TypeError: If ``data`` lacks ``client_id`` or has unknown keys.
        """
        # Work on a shallow copy so the caller's mapping is never mutated.
        data = dict(data)
        # Convert only the keys that are present; absent optional keys fall
        # back to the field defaults instead of raising KeyError.
        if "client_type" in data:
            data["client_type"] = ClientType(data["client_type"])
        if "grant_types" in data:
            data["grant_types"] = [GrantType(gt) for gt in data["grant_types"]]
        return cls(**data)

    def supports_grant_type(self, grant_type: GrantType) -> bool:
        """Return True if ``grant_type`` is among the registered grant types."""
        return grant_type in self.grant_types

    def has_scope(self, scope: str) -> bool:
        """Return True if ``scope`` is among the registered scopes."""
        return scope in self.scopes

    def validate_redirect_uri(self, redirect_uri: str) -> bool:
        """Return True if ``redirect_uri`` exactly matches a registered URI."""
        return redirect_uri in self.redirect_uris

    def is_valid_redirect_uri(self, redirect_uri: str) -> bool:
        """Alias for :meth:`validate_redirect_uri`."""
        # Delegate so the two names can never drift apart.
        return self.validate_redirect_uri(redirect_uri)
|
164
|
+
|
165
|
+
|
166
|
+
@dataclass
class AccessToken:
    """An issued access token together with its bookkeeping data.

    ``subject``/``user_id`` mirror each other and ``scopes`` can seed
    ``scope``; see ``__post_init__`` for the exact fill-in rules.
    """

    token: str
    client_id: str
    token_type: str = "Bearer"
    expires_in: int = 3600
    scope: Optional[str] = None
    scopes: Optional[List[str]] = None
    subject: Optional[str] = None
    user_id: Optional[str] = None  # Alias for subject
    audience: Optional[List[str]] = None
    issued_at: float = field(default_factory=time.time)
    expires_at: Optional[float] = None

    def __post_init__(self):
        # When exactly one of user_id / subject is set, copy it to the other.
        user_given = bool(self.user_id)
        subject_given = bool(self.subject)
        if user_given != subject_given:
            identity = self.user_id if user_given else self.subject
            self.subject = identity
            self.user_id = identity

        # Derive the absolute expiry from the lifetime when it was not supplied.
        if self.expires_at is None:
            self.expires_at = self.issued_at + self.expires_in

        # A scopes list fills in a missing space-separated scope string.
        if not self.scope and self.scopes:
            self.scope = " ".join(self.scopes)

    def is_expired(self) -> bool:
        """Return True once the current time is past ``expires_at``."""
        return self.expires_at < time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Return the four token-response fields as a dictionary."""
        return dict(
            access_token=self.token,
            token_type=self.token_type,
            expires_in=self.expires_in,
            scope=self.scope,
        )

    def has_scope(self, scope: str) -> bool:
        """Return True if ``scope`` is granted by ``scopes`` or, failing that, ``scope``."""
        if self.scopes:
            granted = self.scopes
        elif self.scope:
            granted = self.scope.split()
        else:
            granted = []
        return scope in granted
|
217
|
+
|
218
|
+
|
219
|
+
@dataclass
class RefreshToken:
    """An issued refresh token with its owner, scope and revocation flag.

    A token without ``expires_at`` never reports itself as expired.
    """

    token: str
    client_id: str
    subject: Optional[str] = None
    user_id: Optional[str] = None  # Alias for subject
    scope: Optional[str] = None
    scopes: Optional[List[str]] = None
    issued_at: float = field(default_factory=time.time)
    expires_at: Optional[float] = None
    is_revoked: bool = False

    def __post_init__(self):
        # When exactly one of user_id / subject is set, copy it to the other.
        user_given = bool(self.user_id)
        subject_given = bool(self.subject)
        if user_given != subject_given:
            identity = self.user_id if user_given else self.subject
            self.subject = identity
            self.user_id = identity

        # A scopes list fills in a missing space-separated scope string.
        if not self.scope and self.scopes:
            self.scope = " ".join(self.scopes)

    def is_expired(self) -> bool:
        """Return True once the current time is past ``expires_at``, if one is set."""
        deadline = self.expires_at
        return deadline is not None and time.time() > deadline

    def revoke(self) -> None:
        """Flag this token as revoked."""
        self.is_revoked = True
|
253
|
+
|
254
|
+
|
255
|
+
@dataclass
|
256
|
+
class AuthorizationCode:
|
257
|
+
"""OAuth 2.1 authorization code."""
|
258
|
+
|
259
|
+
code: str
|
260
|
+
client_id: str
|
261
|
+
redirect_uri: str
|
262
|
+
scope: Optional[str] = None
|
263
|
+
scopes: Optional[List[str]] = None
|
264
|
+
subject: Optional[str] = None
|
265
|
+
user_id: Optional[str] = None # Alias for subject
|
266
|
+
code_challenge: Optional[str] = None
|
267
|
+
code_challenge_method: Optional[str] = None
|
268
|
+
issued_at: float = field(default_factory=time.time)
|
269
|
+
expires_at: float = field(default_factory=lambda: time.time() + 600) # 10 minutes
|
270
|
+
|
271
|
+
def __post_init__(self):
|
272
|
+
# Handle user_id as alias for subject
|
273
|
+
if self.user_id and not self.subject:
|
274
|
+
self.subject = self.user_id
|
275
|
+
elif self.subject and not self.user_id:
|
276
|
+
self.user_id = self.subject
|
277
|
+
|
278
|
+
# Convert scopes list to scope string if needed
|
279
|
+
if self.scopes and not self.scope:
|
280
|
+
self.scope = " ".join(self.scopes)
|
281
|
+
|
282
|
+
def is_expired(self) -> bool:
|
283
|
+
"""Check if code is expired."""
|
284
|
+
return time.time() > self.expires_at
|
285
|
+
|
286
|
+
def validate_pkce(self, code_verifier: str) -> bool:
|
287
|
+
"""Validate PKCE code verifier."""
|
288
|
+
if not self.code_challenge:
|
289
|
+
return True # PKCE not used
|
290
|
+
|
291
|
+
if self.code_challenge_method == "S256":
|
292
|
+
# SHA256 challenge method
|
293
|
+
verifier_hash = hashlib.sha256(code_verifier.encode()).digest()
|
294
|
+
verifier_challenge = (
|
295
|
+
base64.urlsafe_b64encode(verifier_hash).decode().rstrip("=")
|
296
|
+
)
|
297
|
+
return verifier_challenge == self.code_challenge
|
298
|
+
elif self.code_challenge_method == "plain":
|
299
|
+
# Plain challenge method
|
300
|
+
return code_verifier == self.code_challenge
|
301
|
+
else:
|
302
|
+
return False
|
303
|
+
|
304
|
+
|
305
|
+
class ClientStore(ABC):
    """Interface that every OAuth client storage backend must implement.

    All four operations are coroutines; concrete subclasses supply the
    actual persistence.
    """

    @abstractmethod
    async def store_client(self, client: OAuthClient) -> None:
        """Persist ``client``."""

    @abstractmethod
    async def get_client(self, client_id: str) -> Optional[OAuthClient]:
        """Look up the client registered under ``client_id``."""

    @abstractmethod
    async def delete_client(self, client_id: str) -> bool:
        """Remove the client registered under ``client_id``."""

    @abstractmethod
    async def list_clients(self) -> List[OAuthClient]:
        """Return every stored client."""
|
327
|
+
|
328
|
+
|
329
|
+
class InMemoryClientStore(ClientStore):
    """OAuth client store backed by a dictionary keyed by ``client_id``."""

    def __init__(self):
        """Create an empty store."""
        self._clients: Dict[str, OAuthClient] = {}

    async def store_client(self, client: OAuthClient) -> None:
        """Store ``client``, replacing any entry with the same ``client_id``."""
        self._clients[client.client_id] = client

    async def get_client(self, client_id: str) -> Optional[OAuthClient]:
        """Return the client stored under ``client_id``, or None."""
        return self._clients.get(client_id)

    async def delete_client(self, client_id: str) -> bool:
        """Delete the client stored under ``client_id``.

        Returns:
            True if an entry was removed, False if there was none.
        """
        if client_id in self._clients:
            del self._clients[client_id]
            return True
        return False

    async def list_clients(self) -> List[OAuthClient]:
        """Return a new list containing every stored client."""
        return list(self._clients.values())

    async def authenticate_client(
        self, client_id: str, client_secret: str
    ) -> Optional[OAuthClient]:
        """Return the client if ``client_secret`` matches its stored secret.

        Args:
            client_id: Identifier of the client to authenticate.
            client_secret: Secret presented by the caller.

        Returns:
            The matching ``OAuthClient``, or None if the client is unknown
            or the secret differs.
        """
        client = await self.get_client(client_id)
        if client is None:
            return None

        stored_secret = client.client_secret
        if isinstance(stored_secret, str) and isinstance(client_secret, str):
            # Constant-time comparison so timing does not reveal how much of
            # the secret matched. Bytes are used because compare_digest
            # rejects non-ASCII str.
            secrets_match = secrets.compare_digest(
                stored_secret.encode("utf-8", "surrogatepass"),
                client_secret.encode("utf-8", "surrogatepass"),
            )
        else:
            # NOTE(review): a client stored without a secret still matches a
            # caller that presents None, exactly as before. Confirm whether
            # secret-less clients should be able to "authenticate" this way.
            secrets_match = stored_secret == client_secret

        return client if secrets_match else None
|
363
|
+
|
364
|
+
|
365
|
+
class TokenStore(ABC):
    """Interface that every token storage backend must implement.

    Covers three kinds of record: access tokens, refresh tokens and
    authorization codes. All operations are coroutines.
    """

    # --- access tokens ---

    @abstractmethod
    async def store_access_token(self, token: AccessToken) -> None:
        """Persist an access token."""

    @abstractmethod
    async def get_access_token(self, token: str) -> Optional[AccessToken]:
        """Look up an access token by its token string."""

    @abstractmethod
    async def revoke_access_token(self, token: str) -> bool:
        """Revoke the access token with the given token string."""

    # --- refresh tokens ---

    @abstractmethod
    async def store_refresh_token(self, token: RefreshToken) -> None:
        """Persist a refresh token."""

    @abstractmethod
    async def get_refresh_token(self, token: str) -> Optional[RefreshToken]:
        """Look up a refresh token by its token string."""

    @abstractmethod
    async def revoke_refresh_token(self, token: str) -> bool:
        """Revoke the refresh token with the given token string."""

    # --- authorization codes ---

    @abstractmethod
    async def store_authorization_code(self, code: AuthorizationCode) -> None:
        """Persist an authorization code."""

    @abstractmethod
    async def get_authorization_code(self, code: str) -> Optional[AuthorizationCode]:
        """Look up an authorization code by its code string."""

    @abstractmethod
    async def consume_authorization_code(
        self, code: str
    ) -> Optional[AuthorizationCode]:
        """Fetch an authorization code and delete it in the same call."""
|
414
|
+
|
415
|
+
|
416
|
+
class InMemoryTokenStore(TokenStore):
    """Token store that keeps every record in per-kind dictionaries.

    Expired entries are dropped on read: each ``get_*`` method deletes an
    entry whose ``is_expired()`` is true and reports it as missing.
    """

    def __init__(self):
        """Create three empty tables keyed by token/code string."""
        self._access_tokens: Dict[str, AccessToken] = {}
        self._refresh_tokens: Dict[str, RefreshToken] = {}
        self._authorization_codes: Dict[str, AuthorizationCode] = {}

    @staticmethod
    def _get_unexpired(table, key):
        """Return ``table[key]`` if present and not expired, evicting it otherwise."""
        entry = table.get(key)
        if not entry or not entry.is_expired():
            return entry
        del table[key]
        return None

    @staticmethod
    def _discard(table, key) -> bool:
        """Delete ``key`` from ``table`` and report whether it was there."""
        try:
            del table[key]
        except KeyError:
            return False
        return True

    async def store_access_token(self, token: AccessToken) -> None:
        """Save an access token under its token string."""
        self._access_tokens[token.token] = token

    async def get_access_token(self, token: str) -> Optional[AccessToken]:
        """Return the stored access token, or None if absent or expired."""
        return self._get_unexpired(self._access_tokens, token)

    async def revoke_access_token(self, token: str) -> bool:
        """Delete an access token; True if one was stored."""
        return self._discard(self._access_tokens, token)

    async def store_refresh_token(self, token: RefreshToken) -> None:
        """Save a refresh token under its token string."""
        self._refresh_tokens[token.token] = token

    async def get_refresh_token(self, token: str) -> Optional[RefreshToken]:
        """Return the stored refresh token, or None if absent or expired."""
        return self._get_unexpired(self._refresh_tokens, token)

    async def revoke_refresh_token(self, token: str) -> bool:
        """Delete a refresh token; True if one was stored."""
        return self._discard(self._refresh_tokens, token)

    async def store_authorization_code(self, code: AuthorizationCode) -> None:
        """Save an authorization code under its code string."""
        self._authorization_codes[code.code] = code

    async def get_authorization_code(self, code: str) -> Optional[AuthorizationCode]:
        """Return the stored authorization code, or None if absent or expired."""
        return self._get_unexpired(self._authorization_codes, code)

    async def consume_authorization_code(
        self, code: str
    ) -> Optional[AuthorizationCode]:
        """Return the authorization code and remove it so it cannot be reused."""
        pending = await self.get_authorization_code(code)
        if not pending:
            return pending
        self._authorization_codes.pop(code)
        return pending
|
483
|
+
|
484
|
+
|
485
|
+
class JWTManager:
    """Create and verify signed JWT access and refresh tokens.

    Tokens are signed with ``private_key`` and verified with ``public_key``
    using ``algorithm``. Every token carries ``iss``, ``iat``, ``exp``,
    ``jti`` and a ``token_type`` claim that distinguishes access tokens from
    refresh tokens.
    """

    def __init__(
        self,
        private_key: Optional[str] = None,
        public_key: Optional[str] = None,
        algorithm: str = "RS256",
        issuer: Optional[str] = None,
        private_key_pem: Optional[str] = None,  # Backward compatibility
        public_key_pem: Optional[str] = None,  # Backward compatibility
    ):
        """Initialize JWT manager.

        Args:
            private_key: Private key for signing (PEM format). When omitted
                (and ``private_key_pem`` is too) a fresh 2048-bit RSA key is
                generated.
            public_key: Public key for verification (PEM format). When
                omitted it is derived from the private key.
            algorithm: JWT algorithm
            issuer: Token issuer, written to the ``iss`` claim
            private_key_pem: Legacy name for ``private_key``
            public_key_pem: Legacy name for ``public_key``
        """
        self.algorithm = algorithm
        self.issuer = issuer

        # Handle backward compatibility
        private_key = private_key or private_key_pem
        public_key = public_key or public_key_pem

        if private_key:
            self.private_key = serialization.load_pem_private_key(
                private_key.encode(), password=None
            )
        else:
            # Generate key pair
            self.private_key = rsa.generate_private_key(
                public_exponent=65537, key_size=2048
            )

        if public_key:
            self.public_key = serialization.load_pem_public_key(public_key.encode())
        else:
            self.public_key = self.private_key.public_key()

    def create_access_token(
        self,
        subject: Optional[Union[str, Dict[str, Any]]] = None,
        client_id: Optional[str] = None,
        scope: Optional[str] = None,
        audience: Optional[List[str]] = None,
        expires_in: int = 3600,
    ) -> Union[AccessToken, str]:
        """Create JWT access token.

        Args:
            subject: Token subject (user ID). For backward compatibility a
                dict may be passed instead; its ``user_id``, ``client_id``,
                ``scope``/``scopes``, ``audience`` and ``expires_in`` keys
                are then used, and ``scope`` is taken only from the dict.
            client_id: OAuth client ID
            scope: Token scope
            audience: Token audience
            expires_in: Token lifetime in seconds

        Returns:
            An ``AccessToken`` wrapping the encoded JWT, or just the encoded
            JWT string when ``subject`` was given as a dict.
        """
        # Handle dictionary input for backward compatibility
        token_data_dict = None
        if isinstance(subject, dict):
            token_data_dict = subject
            subject = token_data_dict.get("user_id")
            client_id = token_data_dict.get("client_id", client_id)
            scope = token_data_dict.get("scope")
            if not scope and "scopes" in token_data_dict:
                scope = " ".join(token_data_dict["scopes"])
            audience = token_data_dict.get("audience", audience)
            expires_in = token_data_dict.get("expires_in", expires_in)

        now = time.time()
        expires_at = now + expires_in

        payload = {
            "iss": self.issuer,
            "iat": int(now),
            "exp": int(expires_at),
            "jti": str(uuid.uuid4()),
            "token_type": "access_token",
        }

        # Optional claims are only written when they carry a value.
        if subject:
            payload["sub"] = subject
        if client_id:
            payload["client_id"] = client_id
        if scope:
            payload["scope"] = scope
        if audience:
            payload["aud"] = audience

        # Dict callers also get their raw user_id / scopes echoed as claims.
        if token_data_dict:
            for key in ("user_id", "scopes"):
                if key in token_data_dict:
                    payload[key] = token_data_dict[key]

        token = jwt.encode(payload, self.private_key, algorithm=self.algorithm)

        # For backward compatibility, return string if called with dict
        if token_data_dict:
            return token

        return AccessToken(
            token=token,
            expires_in=expires_in,
            scope=scope,
            client_id=client_id,
            subject=subject,
            audience=audience,
            issued_at=now,
            expires_at=expires_at,
        )

    def verify_access_token(
        self,
        token: str,
        audience: Optional[Union[str, List[str]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Verify JWT access token.

        Args:
            token: JWT token to verify
            audience: Expected audience. When given, the token's ``aud``
                claim must match it. When omitted, the ``aud`` claim is not
                checked, so that tokens minted with an audience by
                ``create_access_token`` can still be verified.

        Returns:
            Token payload, or None if the token is valid but its
            ``token_type`` claim is not ``access_token``.

        Raises:
            AuthenticationError: If the signature, expiry or (when requested)
                audience check fails, or the token is malformed.
        """
        decode_kwargs: Dict[str, Any] = {"algorithms": [self.algorithm]}
        if audience is None:
            # Without this, PyJWT rejects any token that carries an "aud"
            # claim when no expected audience is supplied.
            decode_kwargs["options"] = {"verify_aud": False}
        else:
            decode_kwargs["audience"] = audience
        # NOTE(review): unlike verify_refresh_token, the "iss" claim is not
        # compared with self.issuer here; confirm whether that is intended.

        try:
            payload = jwt.decode(token, self.public_key, **decode_kwargs)
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(f"Invalid token: {e}") from e

        # Verify token type
        if payload.get("token_type") != "access_token":
            return None

        return payload

    def create_refresh_token(
        self,
        token_data: Union[Dict[str, Any], str],
        expires_in: int = 2592000,  # 30 days
    ) -> str:
        """Create JWT refresh token.

        Args:
            token_data: Token data dict (``client_id`` and ``user_id`` keys
                are used when present) or a client ID string
            expires_in: Token lifetime in seconds

        Returns:
            JWT refresh token string
        """
        now = time.time()
        expires_at = now + expires_in

        payload = {
            "iss": self.issuer,
            "iat": int(now),
            "exp": int(expires_at),
            "jti": str(uuid.uuid4()),
            "token_type": "refresh_token",
        }

        if isinstance(token_data, dict):
            if "client_id" in token_data:
                payload["client_id"] = token_data["client_id"]
            if "user_id" in token_data:
                # The user is recorded under both the standard "sub" claim
                # and a "user_id" claim.
                payload["sub"] = token_data["user_id"]
                payload["user_id"] = token_data["user_id"]
        else:
            payload["client_id"] = token_data

        return jwt.encode(payload, self.private_key, algorithm=self.algorithm)

    def verify_refresh_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify JWT refresh token.

        Args:
            token: JWT token to verify

        Returns:
            Token payload if valid. None if verification fails with an
            exception not handled by the jwt-specific handlers below.

        Raises:
            AuthenticationError: If the token is expired or otherwise
                rejected by the JWT library.
        """
        try:
            payload = jwt.decode(
                token,
                self.public_key,
                algorithms=[self.algorithm],
                options={"verify_aud": False},
            )

            # NOTE(review): the two AuthenticationError raises below sit
            # inside the try, so (presumably, if AuthenticationError derives
            # from Exception and not from jwt.InvalidTokenError) the generic
            # handler at the bottom logs them and the caller gets None.
            # Behavior is left unchanged; confirm which outcome callers expect.

            # Check token type
            if payload.get("token_type") != "refresh_token":
                raise AuthenticationError("Invalid token type")

            # Check issuer
            if self.issuer and payload.get("iss") != self.issuer:
                raise AuthenticationError("Invalid issuer")

            return payload

        except jwt.ExpiredSignatureError as e:
            raise AuthenticationError("Token expired") from e
        except jwt.InvalidTokenError as e:
            raise AuthenticationError(f"Invalid token: {e}") from e
        except Exception as e:
            logger.error(f"Token verification failed: {e}")
            return None

    def get_public_key_jwks(self) -> Dict[str, Any]:
        """Get public key in JWKS format.

        Returns:
            A dict with a single-entry ``keys`` list describing the RSA
            public key (``kty``, ``use``, ``alg``, ``n``, ``e``).
        """
        public_numbers = self.public_key.public_numbers()

        def int_to_base64url(value: int) -> str:
            """Encode an integer as unpadded base64url of its big-endian bytes."""
            byte_length = (value.bit_length() + 7) // 8
            bytes_value = value.to_bytes(byte_length, byteorder="big")
            return base64.urlsafe_b64encode(bytes_value).decode().rstrip("=")

        return {
            "keys": [
                {
                    "kty": "RSA",
                    "use": "sig",
                    "alg": self.algorithm,
                    "n": int_to_base64url(public_numbers.n),
                    "e": int_to_base64url(public_numbers.e),
                }
            ]
        }
|
724
|
+
|
725
|
+
|
726
|
+
class AuthorizationServer:
    """OAuth 2.1 Authorization Server.

    Registers clients, issues authorization codes and serves the
    authorization-code, client-credentials and refresh-token grants, plus
    token introspection, revocation and server metadata.  Clients and tokens
    are kept in the configured stores; tokens are created and verified by
    the configured JWT manager.
    """

    def __init__(
        self,
        issuer: str,
        client_store: Optional["ClientStore"] = None,
        token_store: Optional["TokenStore"] = None,
        jwt_manager: Optional["JWTManager"] = None,
        default_scopes: Optional[List[str]] = None,
        private_key_path: Optional[str] = None,  # For backward compatibility
    ):
        """Initialize authorization server.

        Args:
            issuer: Server issuer URL; also the base of all endpoint URLs.
            client_store: Client storage (in-memory store when omitted).
            token_store: Token storage (in-memory store when omitted).
            jwt_manager: JWT manager; takes precedence over
                ``private_key_path``.
            default_scopes: Scopes given to clients registered without
                scopes (``["mcp.basic"]`` when omitted).
            private_key_path: Path of a private key file used to build the
                JWT manager when ``jwt_manager`` is not given.
        """
        self.issuer = issuer
        self.client_store = client_store or InMemoryClientStore()
        self.token_store = token_store or InMemoryTokenStore()

        if jwt_manager:
            self.jwt_manager = jwt_manager
        elif private_key_path:
            try:
                with open(private_key_path, "r") as f:
                    private_key = f.read()
            except FileNotFoundError:
                # Kept for testing setups, but made visible: tokens will not
                # be signed with the key the caller configured.
                logger.warning(
                    f"Private key file not found: {private_key_path}; "
                    "falling back to a default JWT manager"
                )
                self.jwt_manager = JWTManager(issuer=issuer)
            else:
                self.jwt_manager = JWTManager(issuer=issuer, private_key=private_key)
        else:
            self.jwt_manager = JWTManager(issuer=issuer)

        self.default_scopes = default_scopes or ["mcp.basic"]

    @staticmethod
    def _secret_matches(expected: Optional[str], provided: Optional[str]) -> bool:
        """Compare client secrets in constant time; a missing secret never matches."""
        if not isinstance(expected, str) or not isinstance(provided, str):
            return False
        if not expected or not provided:
            return False
        # Encode first: compare_digest rejects non-ASCII str arguments.
        return secrets.compare_digest(
            expected.encode("utf-8"), provided.encode("utf-8")
        )

    @staticmethod
    def _as_access_token(issued: Any, client_id: str, **fields: Any) -> Any:
        """Wrap a raw JWT string in an AccessToken; return token objects unchanged."""
        if isinstance(issued, str):
            return AccessToken(token=issued, client_id=client_id, **fields)
        return issued

    async def register_client(
        self,
        client_name: str,
        redirect_uris: Optional[List[str]] = None,
        grant_types: Optional[List[str]] = None,
        scopes: Optional[List[str]] = None,
        client_type: Optional[str] = None,
        **metadata,
    ) -> "OAuthClient":
        """Register OAuth client.

        Args:
            client_name: Client name.
            redirect_uris: Redirect URIs (empty list when omitted).
            grant_types: Allowed grant type values (authorization code only
                when omitted).
            scopes: Allowed scopes (a copy of the server defaults when
                omitted).
            client_type: Client type value; confidential when omitted.
            **metadata: Additional metadata stored on the client.

        Returns:
            The registered client.  Confidential clients receive a freshly
            generated secret; other client types get ``None``.
        """
        client_id = f"client_{uuid.uuid4().hex[:16]}"

        client_type_enum = (
            ClientType(client_type) if client_type else ClientType.CONFIDENTIAL
        )

        # Only confidential clients get a secret.
        client_secret = None
        if client_type_enum == ClientType.CONFIDENTIAL:
            client_secret = secrets.token_urlsafe(32)

        if grant_types:
            grant_type_enums = [GrantType(gt) for gt in grant_types]
        else:
            grant_type_enums = [GrantType.AUTHORIZATION_CODE]

        if not scopes:
            scopes = self.default_scopes.copy()

        if not redirect_uris:
            redirect_uris = []

        client = OAuthClient(
            client_id=client_id,
            client_secret=client_secret,
            client_name=client_name,
            client_type=client_type_enum,
            redirect_uris=redirect_uris,
            grant_types=grant_type_enums,
            scopes=scopes,
            metadata=metadata,
        )

        await self.client_store.store_client(client)

        logger.info(f"Registered OAuth client: {client_name} ({client_id})")
        return client

    async def create_authorization_url(
        self,
        client_id: str,
        redirect_uri: str,
        scope: Optional[str] = None,
        state: Optional[str] = None,
        code_challenge: Optional[str] = None,
        code_challenge_method: Optional[str] = None,
    ) -> str:
        """Create authorization URL.

        Args:
            client_id: OAuth client ID.
            redirect_uri: Redirect URI; must be accepted by the client.
            scope: Requested scope string.
            state: State parameter.
            code_challenge: PKCE code challenge.
            code_challenge_method: PKCE challenge method; ``"S256"`` is used
                when a challenge is given without a method.

        Returns:
            ``{issuer}/authorize`` with the URL-encoded query string.

        Raises:
            AuthorizationError: If the client is unknown or rejects the
                redirect URI.
        """
        client = await self.client_store.get_client(client_id)
        if not client:
            raise AuthorizationError("Invalid client")

        if not client.validate_redirect_uri(redirect_uri):
            raise AuthorizationError("Invalid redirect URI")

        params = {
            "response_type": "code",
            "client_id": client_id,
            "redirect_uri": redirect_uri,
        }

        if scope:
            params["scope"] = scope
        if state:
            params["state"] = state
        if code_challenge:
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = code_challenge_method or "S256"

        query_string = urlencode(params)
        return f"{self.issuer}/authorize?{query_string}"

    async def generate_authorization_code(
        self,
        client_id: str,
        user_id: str,
        redirect_uri: str,
        scopes: Optional[List[str]] = None,
        state: Optional[str] = None,
        code_challenge: Optional[str] = None,
        code_challenge_method: Optional[str] = None,
    ) -> str:
        """Generate authorization code for the user.

        Args:
            client_id: OAuth client ID.
            user_id: User ID; stored as both subject and user ID.
            redirect_uri: Redirect URI; must be accepted by the client.
            scopes: Requested scopes.
            state: State parameter (accepted but not used here).
            code_challenge: PKCE code challenge.
            code_challenge_method: PKCE challenge method.

        Returns:
            The newly stored authorization code string.

        Raises:
            AuthorizationError: If the client is unknown or rejects the
                redirect URI.
        """
        client = await self.client_store.get_client(client_id)
        if not client:
            raise AuthorizationError("Invalid client")

        if not client.validate_redirect_uri(redirect_uri):
            raise AuthorizationError("Invalid redirect URI")

        # The code carries the scopes both as a list and as a joined string.
        scope = " ".join(scopes) if scopes else None

        auth_code = AuthorizationCode(
            code=secrets.token_urlsafe(32),
            client_id=client_id,
            redirect_uri=redirect_uri,
            scope=scope,
            scopes=scopes,
            subject=user_id,
            user_id=user_id,
            code_challenge=code_challenge,
            code_challenge_method=code_challenge_method,
        )

        await self.token_store.store_authorization_code(auth_code)

        return auth_code.code

    async def exchange_authorization_code(
        self,
        client_id: str,
        client_secret: Optional[str],
        code: str,
        redirect_uri: str,
        code_verifier: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Exchange authorization code for tokens.

        Args:
            client_id: OAuth client ID.
            client_secret: OAuth client secret (required for confidential
                clients).
            code: Authorization code; it is consumed by this call.
            redirect_uri: Redirect URI; must equal the one stored with the
                code.
            code_verifier: PKCE code verifier; required when the code was
                issued with a challenge.

        Returns:
            The access token's dict form plus a ``"refresh_token"`` entry.

        Raises:
            AuthorizationError: If the client, its credentials, the code,
                the redirect URI or the PKCE verifier is invalid.
        """
        client = await self.client_store.get_client(client_id)
        if not client:
            raise AuthorizationError("Invalid client")

        if client.client_type == ClientType.CONFIDENTIAL:
            if not self._secret_matches(client.client_secret, client_secret):
                raise AuthorizationError("Invalid client credentials")

        # The code is consumed before the remaining checks, so a failed
        # exchange still uses it up.
        auth_code = await self.token_store.consume_authorization_code(code)
        if not auth_code:
            raise AuthorizationError("Invalid or expired authorization code")

        if auth_code.client_id != client_id:
            raise AuthorizationError("Authorization code mismatch")

        if auth_code.redirect_uri != redirect_uri:
            raise AuthorizationError("Redirect URI mismatch")

        if auth_code.code_challenge:
            if not code_verifier:
                raise AuthorizationError("Code verifier required")

            if not auth_code.validate_pkce(code_verifier):
                raise AuthorizationError("Invalid code verifier")

        access_token = self._as_access_token(
            self.jwt_manager.create_access_token(
                subject=auth_code.subject,
                client_id=client_id,
                scope=auth_code.scope,
                audience=["mcp-api"],
            ),
            client_id,
            subject=auth_code.subject,
            user_id=auth_code.user_id,
            scope=auth_code.scope,
            scopes=auth_code.scopes,
        )

        refresh_token_jwt = self.jwt_manager.create_refresh_token(
            {"client_id": client_id, "user_id": auth_code.subject}
        )

        refresh_token = RefreshToken(
            token=refresh_token_jwt,
            client_id=client_id,
            subject=auth_code.subject,
            scope=auth_code.scope,
        )

        await self.token_store.store_access_token(access_token)
        await self.token_store.store_refresh_token(refresh_token)

        response = access_token.to_dict()
        response["refresh_token"] = refresh_token.token

        return response

    async def client_credentials_grant(
        self, client_id: str, client_secret: str, scopes: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Handle client credentials grant.

        Args:
            client_id: OAuth client ID.
            client_secret: OAuth client secret.
            scopes: Requested scopes; each must be allowed for the client.

        Returns:
            The access token's dict form.

        Raises:
            AuthorizationError: If the client is unknown, does not support
                the grant, presents a wrong or missing secret, or requests
                a scope it does not have.
        """
        client = await self.client_store.get_client(client_id)
        if not client:
            raise AuthorizationError("Invalid client")

        if not client.supports_grant_type(GrantType.CLIENT_CREDENTIALS):
            raise AuthorizationError("Grant type not supported")

        # A client without a secret must never pass this check, so an absent
        # secret on both sides is a mismatch rather than a match.
        if not self._secret_matches(client.client_secret, client_secret):
            raise AuthorizationError("Invalid client credentials")

        scope = None
        if scopes:
            for requested_scope in scopes:
                if not client.has_scope(requested_scope):
                    raise AuthorizationError(f"Invalid scope: {requested_scope}")
            scope = " ".join(scopes)

        access_token = self._as_access_token(
            self.jwt_manager.create_access_token(
                client_id=client_id, scope=scope, audience=["mcp-api"]
            ),
            client_id,
            scope=scope,
            scopes=scopes,
        )

        await self.token_store.store_access_token(access_token)

        return access_token.to_dict()

    async def refresh_token_grant(
        self, client_id: str, client_secret: Optional[str], refresh_token: str
    ) -> Dict[str, Any]:
        """Handle refresh token grant.

        The refresh token is first verified as a JWT; if that fails or
        yields no payload, it is looked up in the token store instead.

        Args:
            client_id: OAuth client ID.
            client_secret: OAuth client secret (required for confidential
                clients).
            refresh_token: Refresh token string.

        Returns:
            The new access token's dict form.

        Raises:
            AuthorizationError: If the client or its credentials are
                invalid, the refresh token is unknown, or it belongs to a
                different client.
        """
        client = await self.client_store.get_client(client_id)
        if not client:
            raise AuthorizationError("Invalid client")

        if client.client_type == ClientType.CONFIDENTIAL:
            if not self._secret_matches(client.client_secret, client_secret):
                raise AuthorizationError("Invalid client credentials")

        # NOTE(review): a token that verifies as a JWT is accepted without
        # consulting the token store, so revoke_token() presumably does not
        # invalidate such tokens -- confirm against the store implementation.
        refresh_token_obj = None
        try:
            token_data = self.jwt_manager.verify_refresh_token(refresh_token)
            if token_data:
                token_scopes = token_data.get("scopes")
                refresh_token_obj = RefreshToken(
                    token=refresh_token,
                    client_id=token_data.get("client_id", client_id),
                    subject=token_data.get("sub") or token_data.get("user_id"),
                    user_id=token_data.get("user_id") or token_data.get("sub"),
                    # Prefer the scope list; fall back to a space-separated
                    # "scope" claim so the refreshed token keeps its scope.
                    scope=(
                        " ".join(token_scopes)
                        if token_scopes
                        else (token_data.get("scope") or None)
                    ),
                    scopes=token_scopes,
                )
        except Exception as e:
            logger.debug(f"Refresh token not usable as JWT, trying token store: {e}")

        # Also reached when verification returned no payload, which would
        # otherwise leave refresh_token_obj undefined.
        if refresh_token_obj is None:
            refresh_token_obj = await self.token_store.get_refresh_token(refresh_token)

        if not refresh_token_obj:
            raise AuthorizationError("Invalid refresh token")

        if refresh_token_obj.client_id != client_id:
            raise AuthorizationError("Client mismatch")

        access_token = self._as_access_token(
            self.jwt_manager.create_access_token(
                subject=refresh_token_obj.subject,
                client_id=client_id,
                scope=refresh_token_obj.scope,
                audience=["mcp-api"],
            ),
            client_id,
            subject=refresh_token_obj.subject,
            user_id=refresh_token_obj.user_id,
            scope=refresh_token_obj.scope,
            scopes=getattr(refresh_token_obj, "scopes", None),
        )

        await self.token_store.store_access_token(access_token)

        return access_token.to_dict()

    async def introspect_token(self, token: str) -> Dict[str, Any]:
        """Introspect token.

        Args:
            token: Token to introspect; verified as a JWT access token.

        Returns:
            ``{"active": False}`` when verification raises
            ``AuthenticationError`` or yields no payload; otherwise a dict
            with ``active``, ``client_id``, ``scope``, ``exp``, ``iat``,
            ``sub``, ``aud`` and ``token_type``.
        """
        try:
            payload = self.jwt_manager.verify_access_token(token)
            if payload:
                client_id = payload.get("client_id")
                scope = payload.get("scope", "")
                exp = payload.get("exp")
                iat = payload.get("iat", time.time())
                sub = payload.get("sub") or payload.get("user_id")
                aud = payload.get("aud", [])

                # A "scopes" list wins; otherwise split the "scope" string.
                scopes = payload.get("scopes", [])
                if not scopes and scope:
                    scopes = scope.split()

                return {
                    "active": True,
                    "client_id": client_id,
                    "scope": " ".join(scopes) if scopes else scope,
                    "exp": exp,
                    "iat": iat,
                    "sub": sub,
                    "aud": aud,
                    "token_type": "access_token",
                }
        except AuthenticationError:
            # Token is invalid or expired
            pass

        return {"active": False}

    async def revoke_token(
        self,
        token: str,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
    ) -> bool:
        """Revoke token.

        Args:
            token: Token to revoke.
            client_id: OAuth client ID; when given, the client is validated
                first.
            client_secret: OAuth client secret (required for confidential
                clients when ``client_id`` is given).

        Returns:
            ``False`` when client validation fails; otherwise the result of
            revoking the token as an access token or, failing that, as a
            refresh token.
        """
        if client_id:
            client = await self.client_store.get_client(client_id)
            if not client:
                return False

            if client.client_type == ClientType.CONFIDENTIAL:
                if not self._secret_matches(client.client_secret, client_secret):
                    return False

        if await self.token_store.revoke_access_token(token):
            return True

        return await self.token_store.revoke_refresh_token(token)

    async def refresh_access_token(
        self, client_id: str, client_secret: Optional[str], refresh_token: str
    ) -> Dict[str, Any]:
        """Refresh access token (alias for refresh_token_grant).

        Args:
            client_id: OAuth client ID.
            client_secret: OAuth client secret.
            refresh_token: Refresh token.

        Returns:
            Token response.
        """
        return await self.refresh_token_grant(client_id, client_secret, refresh_token)

    def get_well_known_metadata(self) -> Dict[str, Any]:
        """Get well-known authorization server metadata.

        Returns:
            Metadata dict with the issuer, the endpoint URLs derived from
            it, and the supported scopes, response types, grant types,
            token endpoint auth methods and PKCE challenge methods.
        """
        return {
            "issuer": self.issuer,
            "authorization_endpoint": f"{self.issuer}/authorize",
            "token_endpoint": f"{self.issuer}/token",
            "introspection_endpoint": f"{self.issuer}/introspect",
            "revocation_endpoint": f"{self.issuer}/revoke",
            "jwks_uri": f"{self.issuer}/.well-known/jwks.json",
            "registration_endpoint": f"{self.issuer}/register",
            # Copy so callers cannot mutate the server's default scopes.
            "scopes_supported": list(self.default_scopes),
            "response_types_supported": ["code"],
            "grant_types_supported": [
                "authorization_code",
                "client_credentials",
                "refresh_token",
            ],
            "token_endpoint_auth_methods_supported": [
                "client_secret_basic",
                "client_secret_post",
            ],
            "code_challenge_methods_supported": ["S256", "plain"],
        }
|
1279
|
+
|
1280
|
+
|
1281
|
+
class ResourceServer:
    """OAuth 2.1 Resource Server for MCP.

    Verifies bearer access tokens through the JWT manager and checks their
    audience and scopes.
    """

    def __init__(
        self,
        issuer: str,
        audience: str,
        jwt_manager: Optional["JWTManager"] = None,
        required_scopes: Optional[List[str]] = None,
    ):
        """Initialize resource server.

        Args:
            issuer: Authorization server issuer.
            audience: Audience value that accepted tokens must contain.
            jwt_manager: JWT manager for token verification; one is built
                from ``issuer`` when omitted.
            required_scopes: Scopes every token must carry (none when
                omitted).
        """
        self.issuer = issuer
        self.audience = audience
        self.jwt_manager = jwt_manager or JWTManager(issuer=issuer)
        self.required_scopes = required_scopes or []

    async def authenticate(
        self, credentials: Union[str, Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Authenticate using OAuth 2.1 access token.

        Args:
            credentials: Token string, optionally prefixed with ``Bearer ``
                (matched case-insensitively), or a dict with a ``'token'``
                key.

        Returns:
            Dict with ``id``, ``client_id``, ``subject``, ``user_id``,
            ``scopes`` and ``token_type``.

        Raises:
            AuthenticationError: If no usable token is supplied, the token
                fails verification, or a required scope is missing.
            AuthorizationError: If the token audience does not include this
                server's audience.
        """
        if isinstance(credentials, str):
            token = credentials
        elif credentials is None:
            token = None
        else:
            token = credentials.get("token")

        # Reject missing and non-string tokens up front instead of failing
        # later with an AttributeError on the prefix check.
        if not token or not isinstance(token, str):
            raise AuthenticationError("No token provided")

        # The authentication scheme name is case-insensitive.
        if token[:7].lower() == "bearer ":
            token = token[7:].strip()
            if not token:
                raise AuthenticationError("No token provided")

        payload = self.jwt_manager.verify_access_token(token)
        if not payload:
            raise AuthenticationError("Invalid token")

        # "aud" may be a single string or a list of audiences.
        token_audience = payload.get("aud", [])
        if isinstance(token_audience, str):
            token_audience = [token_audience]

        if self.audience not in token_audience:
            raise AuthorizationError("Invalid token audience")

        token_scope = payload.get("scope", "")
        token_scopes = token_scope.split() if token_scope else []

        # Raises AuthenticationError (not AuthorizationError) for a missing
        # scope; kept as is because callers may catch that type.
        for required_scope in self.required_scopes:
            if required_scope not in token_scopes:
                raise AuthenticationError(f"Missing required scope: {required_scope}")

        return {
            "id": payload.get("sub") or payload.get("client_id"),
            "client_id": payload.get("client_id"),
            "subject": payload.get("sub"),
            "user_id": payload.get("sub") or payload.get("user_id"),
            "scopes": token_scopes,
            "token_type": "Bearer",
        }

    async def check_permission(
        self, auth_info: Dict[str, Any], required_permission: str
    ) -> None:
        """Check if authenticated entity has required permission.

        Args:
            auth_info: Authentication information from authenticate().
            required_permission: Permission that must appear in the
                ``"scopes"`` entry of ``auth_info``.

        Raises:
            AuthorizationError: If permission is missing.
        """
        scopes = auth_info.get("scopes", [])
        if required_permission not in scopes:
            raise AuthorizationError(
                f"Missing required permission: {required_permission}"
            )

    async def get_headers(self) -> Dict[str, str]:
        """Get headers for authentication (empty for resource server).

        Returns:
            Empty dict as resource server doesn't add headers.
        """
        return {}
|
1382
|
+
|
1383
|
+
|
1384
|
+
class OAuth2Client:
|
1385
|
+
"""OAuth 2.1 client for MCP."""
|
1386
|
+
|
1387
|
+
def __init__(
    self,
    client_id: str,
    client_secret: Optional[str] = None,
    token_endpoint: Optional[str] = None,
    authorization_endpoint: Optional[str] = None,
    redirect_uri: Optional[str] = None,
):
    """Set up the client configuration and an empty token cache.

    Args:
        client_id: OAuth client ID.
        client_secret: OAuth client secret.
        token_endpoint: Token endpoint URL.
        authorization_endpoint: Authorization endpoint URL.
        redirect_uri: Redirect URI.
    """
    # Client identity.
    self.client_id = client_id
    self.client_secret = client_secret

    # Server endpoints and the redirect target.
    self.token_endpoint = token_endpoint
    self.authorization_endpoint = authorization_endpoint
    self.redirect_uri = redirect_uri

    # Nothing is cached until a token request succeeds.
    self._access_token: Optional[str] = None
    self._refresh_token: Optional[str] = None
    self._token_expires_at: Optional[float] = None
|
1414
|
+
|
1415
|
+
async def get_client_credentials_token(
    self, scopes: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Request an access token with the client credentials grant.

    Posts a form-encoded request to the token endpoint and caches the
    returned access token, optional refresh token and expiry time on the
    instance.

    Args:
        scopes: Requested scopes, sent space-separated when given.

    Returns:
        The token response decoded from JSON.

    Raises:
        AuthenticationError: If the token endpoint or client secret is not
            configured, or the endpoint answers with a non-200 status.
    """
    if not self.token_endpoint:
        raise AuthenticationError("Token endpoint not configured")
    if not self.client_secret:
        raise AuthenticationError("Client secret required for client credentials")

    form = {
        "grant_type": "client_credentials",
        "client_id": self.client_id,
        "client_secret": self.client_secret,
    }
    if scopes:
        form["scope"] = " ".join(scopes)

    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}

    async with aiohttp.ClientSession() as session:
        async with session.post(
            self.token_endpoint, data=form, headers=form_headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise AuthenticationError(f"Token request failed: {error_text}")

            token_response = await response.json()

            # Cache the tokens; the lifetime defaults to one hour when the
            # server does not report "expires_in".
            self._access_token = token_response["access_token"]
            self._refresh_token = token_response.get("refresh_token")
            lifetime = token_response.get("expires_in", 3600)
            self._token_expires_at = time.time() + lifetime

            return token_response
|
1463
|
+
|
1464
|
+
def get_authorization_url(
    self,
    scopes: Optional[List[str]] = None,
    state: Optional[str] = None,
    use_pkce: bool = True,
) -> Tuple[str, Optional[str]]:
    """Build the authorization URL for the authorization code flow.

    Args:
        scopes: Requested scopes, sent space-separated when given.
        state: State parameter.
        use_pkce: When true, generate a PKCE verifier and send its SHA-256
            challenge with method ``S256``.

    Returns:
        Tuple of (authorization_url, code_verifier); the verifier is
        ``None`` when PKCE is not used.

    Raises:
        AuthenticationError: If the authorization endpoint or redirect URI
            is not configured.
    """
    if not self.authorization_endpoint:
        raise AuthenticationError("Authorization endpoint not configured")
    if not self.redirect_uri:
        raise AuthenticationError("Redirect URI not configured")

    query = {
        "response_type": "code",
        "client_id": self.client_id,
        "redirect_uri": self.redirect_uri,
    }
    if scopes:
        query["scope"] = " ".join(scopes)
    if state:
        query["state"] = state

    verifier = None
    if use_pkce:
        # 32 random bytes as unpadded base64url text.
        verifier = secrets.token_urlsafe(32)
        digest = hashlib.sha256(verifier.encode()).digest()
        challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
        query["code_challenge"] = challenge
        query["code_challenge_method"] = "S256"

    return f"{self.authorization_endpoint}?{urlencode(query)}", verifier
|
1519
|
+
|
1520
|
+
async def exchange_authorization_code(
    self, code: str, code_verifier: Optional[str] = None
) -> Dict[str, Any]:
    """Exchange authorization code for tokens.

    Posts a form-encoded request to the token endpoint and caches the
    returned access token, optional refresh token and expiry time on the
    instance.

    Args:
        code: Authorization code.
        code_verifier: PKCE code verifier, sent when given.

    Returns:
        The full token response decoded from JSON (not just the access
        token string).

    Raises:
        AuthenticationError: If the token endpoint or redirect URI is not
            configured, or the endpoint answers with a non-200 status.
    """
    if not self.token_endpoint:
        raise AuthenticationError("Token endpoint not configured")

    if not self.redirect_uri:
        raise AuthenticationError("Redirect URI not configured")

    data = {
        "grant_type": "authorization_code",
        "client_id": self.client_id,
        "code": code,
        "redirect_uri": self.redirect_uri,
    }

    # The secret and the PKCE verifier are only sent when available.
    if self.client_secret:
        data["client_secret"] = self.client_secret

    if code_verifier:
        data["code_verifier"] = code_verifier

    async with aiohttp.ClientSession() as session:
        async with session.post(
            self.token_endpoint,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise AuthenticationError(f"Token exchange failed: {error_text}")

            token_response = await response.json()

            self._access_token = token_response["access_token"]
            self._refresh_token = token_response.get("refresh_token")

            # Lifetime defaults to one hour when "expires_in" is absent.
            expires_in = token_response.get("expires_in", 3600)
            self._token_expires_at = time.time() + expires_in

            return token_response
|
1573
|
+
|
1574
|
+
async def get_valid_token(self) -> Optional[str]:
    """Return a usable access token, refreshing it when needed.

    Returns:
        The cached access token while it has more than 60 seconds left;
        otherwise the result of a refresh when a refresh token is held;
        ``None`` when no refresh token is held or the refresh fails (the
        failure is logged).
    """
    cached = self._access_token
    expiry = self._token_expires_at

    # Stop using a cached token one minute before it expires.
    if cached and expiry and time.time() < expiry - 60:
        return cached

    if not self._refresh_token:
        return None

    try:
        return await self._refresh_access_token()
    except Exception as e:
        logger.error(f"Token refresh failed: {e}")
        return None
|
1593
|
+
|
1594
|
+
async def _refresh_access_token(self) -> str:
    """Obtain a new access token with the cached refresh token.

    Posts a refresh-token grant to the token endpoint and updates the
    cached access token, expiry time and, when the server returns one, the
    refresh token.

    Returns:
        The new access token.

    Raises:
        AuthenticationError: If the token endpoint or refresh token is
            missing, or the endpoint answers with a non-200 status.
    """
    if not (self.token_endpoint and self._refresh_token):
        raise AuthenticationError("Cannot refresh token")

    form = {
        "grant_type": "refresh_token",
        "client_id": self.client_id,
        "refresh_token": self._refresh_token,
    }
    if self.client_secret:
        form["client_secret"] = self.client_secret

    form_headers = {"Content-Type": "application/x-www-form-urlencoded"}

    async with aiohttp.ClientSession() as session:
        async with session.post(
            self.token_endpoint, data=form, headers=form_headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise AuthenticationError(f"Token refresh failed: {error_text}")

            token_response = await response.json()

            self._access_token = token_response["access_token"]

            # Keep the current refresh token unless a new one is returned.
            if "refresh_token" in token_response:
                self._refresh_token = token_response["refresh_token"]

            lifetime = token_response.get("expires_in", 3600)
            self._token_expires_at = time.time() + lifetime

            return self._access_token
|
1637
|
+
|
1638
|
+
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
    """Refresh the access token using an explicit refresh token.

    Posts a ``refresh_token`` grant to ``self.token_endpoint`` and stores
    the resulting access token, refresh token and expiry on the instance.

    Args:
        refresh_token: Refresh token to exchange.

    Returns:
        The decoded JSON token response from the token endpoint.

    Raises:
        AuthenticationError: If no token endpoint is configured, or the
            endpoint answers with a non-200 status.
    """
    if not self.token_endpoint:
        raise AuthenticationError("Token endpoint not configured")

    # Prepare token request
    data = {
        "grant_type": "refresh_token",
        "refresh_token": refresh_token,
        "client_id": self.client_id,
    }

    if self.client_secret:
        data["client_secret"] = self.client_secret

    # Make token request
    async with aiohttp.ClientSession() as session:
        async with session.post(
            self.token_endpoint,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        ) as response:
            if response.status != 200:
                # Error bodies are not guaranteed to be JSON (gateway HTML
                # pages, plain text, empty bodies). Parse tolerantly so the
                # caller always gets an AuthenticationError rather than a
                # content-type or decode error from the HTTP client.
                try:
                    error_data = await response.json(content_type=None)
                except ValueError:
                    error_data = None

                if isinstance(error_data, dict):
                    error = error_data.get("error", "unknown_error")
                    error_description = error_data.get(
                        "error_description", "Token refresh failed"
                    )
                else:
                    error = "unknown_error"
                    error_description = (
                        f"Token refresh failed (HTTP {response.status})"
                    )
                raise AuthenticationError(f"{error}: {error_description}")

            token_response = await response.json()

            # Store token information; keep the supplied refresh token when
            # the server does not return a new one.
            self._access_token = token_response["access_token"]
            self._refresh_token = token_response.get("refresh_token", refresh_token)

            # Fall back to a one-hour lifetime when expires_in is omitted.
            expires_in = token_response.get("expires_in", 3600)
            self._token_expires_at = time.time() + expires_in

            return token_response
|
1685
|
+
|
1686
|
+
async def introspect_token(
    self, token: str, introspection_endpoint: Optional[str] = None
) -> Dict[str, Any]:
    """Introspect a token.

    Args:
        token: Token to introspect
        introspection_endpoint: Introspection endpoint URL. When omitted,
            it is derived from ``self.token_endpoint`` by replacing the
            last ``/token`` in the URL path with ``/introspect``.

    Returns:
        The decoded JSON introspection response.

    Raises:
        AuthenticationError: If no introspection endpoint is given and
            none can be derived, or the endpoint answers with a non-200
            status.
    """
    if not introspection_endpoint and self.token_endpoint:
        # Local import keeps this block self-contained.
        from urllib.parse import urlsplit, urlunsplit

        # Derive the endpoint from the URL *path* only. A plain
        # str.replace over the whole URL would also rewrite "/token" in
        # the authority (e.g. "https://token.example.com/...") and send
        # the token and client secret to a different host.
        parts = urlsplit(self.token_endpoint)
        head, matched, tail = parts.path.rpartition("/token")
        if matched:
            introspection_endpoint = urlunsplit(
                parts._replace(path=f"{head}/introspect{tail}")
            )
        # With no "/token" in the path nothing sensible can be derived;
        # leave the endpoint unset so the check below reports it instead
        # of posting the token to the token endpoint itself.

    if not introspection_endpoint:
        raise AuthenticationError("Introspection endpoint not configured")

    data = {
        "token": token,
        "client_id": self.client_id,
    }

    # The client secret is sent only when one is configured.
    if self.client_secret:
        data["client_secret"] = self.client_secret

    async with aiohttp.ClientSession() as session:
        async with session.post(
            introspection_endpoint,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise AuthenticationError(
                    f"Token introspection failed: {error_text}"
                )

            return await response.json()
|