workspace-mcp 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
auth/oauth21/oauth2.py ADDED
@@ -0,0 +1,353 @@
1
+ """
2
+ OAuth 2.1 Authorization Flow Handler
3
+
4
+ Implements OAuth 2.1 authorization flow with PKCE (RFC7636) and Resource Indicators (RFC8707)
5
+ for secure authorization code exchange.
6
+ """
7
+
8
+ import base64
9
+ import logging
10
+ import secrets
11
+ from typing import Dict, Any, Optional, Tuple, List
12
+ from urllib.parse import urlencode, urlparse, parse_qs
13
+
14
+ import aiohttp
15
+ from cryptography.hazmat.primitives import hashes
16
+ from cryptography.hazmat.backends import default_backend
17
+
18
+ from .discovery import AuthorizationServerDiscovery
19
+
20
# Module-wide logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
21
+
22
+
23
class OAuth2AuthorizationFlow:
    """
    Handles OAuth 2.1 authorization flow with PKCE.

    Builds authorization URLs (PKCE S256 plus an optional RFC 8707 resource
    indicator), exchanges authorization codes for tokens, refreshes access
    tokens and parses authorization callback URLs. Authorization-server
    endpoints are resolved through the injected discovery service.
    """

    # Parameters generated by (or fixed by) this flow. They must not be
    # overridden via ``additional_params``: otherwise the URL would disagree
    # with the ``state`` / ``code_verifier`` returned to the caller.
    _RESERVED_AUTH_PARAMS = frozenset(
        {"response_type", "client_id", "state", "code_challenge", "code_challenge_method"}
    )

    # Headers sent with every token-endpoint request (form body, JSON reply).
    _TOKEN_REQUEST_HEADERS = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Accept": "application/json",
    }

    def __init__(
        self,
        client_id: str,
        client_secret: Optional[str] = None,
        discovery_service: Optional[AuthorizationServerDiscovery] = None,
    ):
        """
        Initialize the OAuth 2.1 flow handler.

        Args:
            client_id: OAuth 2.0 client identifier
            client_secret: OAuth 2.0 client secret (optional for public clients);
                when set it is sent in the token request body (client_secret_post)
            discovery_service: Authorization server discovery service; a new
                AuthorizationServerDiscovery is created when omitted
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.discovery = discovery_service or AuthorizationServerDiscovery()
        # Created lazily by _get_session() and reused across requests.
        self._session: Optional[aiohttp.ClientSession] = None

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=30),
                headers={"User-Agent": "MCP-OAuth2.1-Client/1.0"},
            )
        return self._session

    async def close(self):
        """Close the HTTP session (if open) and the discovery service."""
        if self._session and not self._session.closed:
            await self._session.close()
        await self.discovery.close()

    @staticmethod
    def _random_urlsafe_token() -> str:
        """Return 32 cryptographically random bytes as unpadded base64url (43 chars)."""
        return base64.urlsafe_b64encode(secrets.token_bytes(32)).decode('utf-8').rstrip('=')

    @staticmethod
    def _append_query_params(url: str, params: Dict[str, str]) -> str:
        """
        Return ``url`` with ``params`` form-encoded into its query string.

        Any query component already present on ``url`` is kept, as RFC 6749
        section 3.1 requires for the authorization endpoint. Blindly appending
        ``?`` would corrupt such endpoints.
        """
        parts = urlparse(url)
        encoded = urlencode(params)
        query = f"{parts.query}&{encoded}" if parts.query else encoded
        return parts._replace(query=query).geturl()

    def generate_pkce_parameters(self) -> Tuple[str, str]:
        """
        Generate PKCE code_verifier and code_challenge per RFC7636.

        Returns:
            Tuple of (code_verifier, code_challenge). The verifier is 43
            base64url characters (within RFC 7636's 43-128 range). The
            challenge is the unpadded base64url SHA-256 of the verifier (S256).
        """
        code_verifier = self._random_urlsafe_token()

        # code_challenge = BASE64URL(SHA256(ASCII(code_verifier)))
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(code_verifier.encode('utf-8'))
        code_challenge = base64.urlsafe_b64encode(digest.finalize()).decode('utf-8').rstrip('=')

        logger.debug("Generated PKCE parameters")
        return code_verifier, code_challenge

    def generate_state(self) -> str:
        """
        Generate a cryptographically secure state parameter.

        Returns:
            Random state string (unpadded base64url of 32 random bytes)
        """
        return self._random_urlsafe_token()

    async def build_authorization_url(
        self,
        authorization_server_url: str,
        redirect_uri: str,
        scopes: List[str],
        state: Optional[str] = None,
        resource: Optional[str] = None,
        additional_params: Optional[Dict[str, str]] = None,
    ) -> Tuple[str, str, str]:
        """
        Build OAuth 2.1 authorization URL with PKCE.

        Args:
            authorization_server_url: Authorization server base URL
            redirect_uri: Client redirect URI
            scopes: List of requested scopes (joined with spaces)
            state: State parameter (generated if not provided)
            resource: Resource indicator per RFC8707
            additional_params: Additional query parameters. Keys that the
                flow manages itself (response_type, client_id, state,
                code_challenge, code_challenge_method) are ignored with a
                warning so the returned state/code_verifier always match.

        Returns:
            Tuple of (authorization_url, state, code_verifier)

        Raises:
            ValueError: If authorization server metadata is invalid
            aiohttp.ClientError: If metadata cannot be fetched
        """
        as_metadata = await self.discovery.get_authorization_server_metadata(authorization_server_url)
        auth_endpoint = as_metadata.get("authorization_endpoint")

        if not auth_endpoint:
            raise ValueError(f"No authorization_endpoint in metadata for {authorization_server_url}")

        # Metadata may carry an explicit null; treat it like "not advertised".
        code_challenge_methods = as_metadata.get("code_challenge_methods_supported") or []
        if "S256" not in code_challenge_methods:
            logger.warning("Authorization server %s may not support PKCE S256", authorization_server_url)

        code_verifier, code_challenge = self.generate_pkce_parameters()

        if state is None:
            state = self.generate_state()

        auth_params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": redirect_uri,
            "scope": " ".join(scopes),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        # Resource indicator (RFC8707)
        if resource:
            auth_params["resource"] = resource

        if additional_params:
            for key, value in additional_params.items():
                if key in self._RESERVED_AUTH_PARAMS:
                    logger.warning(
                        "Ignoring additional authorization parameter %r: it is managed by the flow", key
                    )
                    continue
                auth_params[key] = value

        authorization_url = self._append_query_params(auth_endpoint, auth_params)

        logger.info("Built authorization URL for %s", authorization_server_url)
        logger.debug("Authorization URL: %s", authorization_url)

        return authorization_url, state, code_verifier

    async def _get_token_endpoint(self, authorization_server_url: str) -> str:
        """
        Resolve the token endpoint from authorization server metadata.

        Raises:
            ValueError: If the metadata has no token_endpoint
            aiohttp.ClientError: If metadata cannot be fetched
        """
        as_metadata = await self.discovery.get_authorization_server_metadata(authorization_server_url)
        token_endpoint = as_metadata.get("token_endpoint")
        if not token_endpoint:
            raise ValueError(f"No token_endpoint in metadata for {authorization_server_url}")
        return token_endpoint

    async def _post_token_request(
        self, token_endpoint: str, form_data: Dict[str, str], operation: str
    ) -> Dict[str, Any]:
        """
        POST a form-encoded request to the token endpoint and return the JSON object.

        Args:
            token_endpoint: Token endpoint URL
            form_data: Form fields to send
            operation: Human-readable label used in error messages
                (e.g. "Token exchange")

        Returns:
            Decoded JSON object from a 200 response

        Raises:
            ValueError: On a non-200 status, an unparseable body, or a body
                that is not a JSON object
            aiohttp.ClientError: If the HTTP request fails
        """
        session = await self._get_session()
        async with session.post(
            token_endpoint, data=form_data, headers=dict(self._TOKEN_REQUEST_HEADERS)
        ) as response:
            response_text = await response.text()

            if response.status != 200:
                logger.error("%s failed: %s %s", operation, response.status, response_text)
                raise ValueError(f"{operation} failed: {response.status} {response_text}")

            try:
                # content_type=None: accept JSON even when the server labels it
                # with a non-JSON Content-Type.
                token_response = await response.json(content_type=None)
            except ValueError as e:  # includes json.JSONDecodeError / UnicodeDecodeError
                logger.error("Failed to parse token response: %s", e)
                raise ValueError(f"Invalid token response format: {e}") from e

        # An empty body decodes to None; lists/scalars are also invalid here.
        if not isinstance(token_response, dict):
            logger.error("Token response is not a JSON object: %r", token_response)
            raise ValueError("Invalid token response format: expected a JSON object")

        return token_response

    async def exchange_code_for_token(
        self,
        authorization_server_url: str,
        authorization_code: str,
        code_verifier: str,
        redirect_uri: str,
        resource: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Exchange authorization code for access token using PKCE.

        Args:
            authorization_server_url: Authorization server base URL
            authorization_code: Authorization code from callback
            code_verifier: PKCE code verifier
            redirect_uri: Client redirect URI (must match authorization request)
            resource: Resource indicator per RFC8707

        Returns:
            Token response dictionary (contains at least access_token and token_type)

        Raises:
            ValueError: If token exchange fails or response is invalid
            aiohttp.ClientError: If HTTP request fails
        """
        token_endpoint = await self._get_token_endpoint(authorization_server_url)

        token_data = {
            "grant_type": "authorization_code",
            "code": authorization_code,
            "redirect_uri": redirect_uri,
            "client_id": self.client_id,
            "code_verifier": code_verifier,
        }

        if resource:
            token_data["resource"] = resource

        # Confidential clients authenticate with client_secret_post.
        if self.client_secret:
            token_data["client_secret"] = self.client_secret

        try:
            logger.debug("Exchanging authorization code at %s", token_endpoint)
            token_response = await self._post_token_request(token_endpoint, token_data, "Token exchange")
        except aiohttp.ClientError as e:
            logger.error("HTTP error during token exchange: %s", e)
            raise

        if "access_token" not in token_response:
            raise ValueError("Token response missing access_token")

        if "token_type" not in token_response:
            raise ValueError("Token response missing token_type")

        # Bearer is expected (case-insensitive); anything else is only logged.
        token_type = token_response["token_type"]
        if not isinstance(token_type, str) or token_type.lower() != "bearer":
            logger.warning("Unexpected token_type: %s", token_type)

        logger.info("Successfully exchanged authorization code for tokens")
        return token_response

    async def refresh_access_token(
        self,
        authorization_server_url: str,
        refresh_token: str,
        scopes: Optional[List[str]] = None,
        resource: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Refresh access token using refresh token.

        Args:
            authorization_server_url: Authorization server base URL
            refresh_token: Refresh token
            scopes: Optional scope restriction (joined with spaces)
            resource: Resource indicator per RFC8707

        Returns:
            Token response dictionary (contains at least access_token)

        Raises:
            ValueError: If token refresh fails or the response is invalid
            aiohttp.ClientError: If HTTP request fails
        """
        token_endpoint = await self._get_token_endpoint(authorization_server_url)

        refresh_data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self.client_id,
        }

        if scopes:
            refresh_data["scope"] = " ".join(scopes)

        if resource:
            refresh_data["resource"] = resource

        if self.client_secret:
            refresh_data["client_secret"] = self.client_secret

        try:
            logger.debug("Refreshing access token at %s", token_endpoint)
            token_response = await self._post_token_request(token_endpoint, refresh_data, "Token refresh")
        except aiohttp.ClientError as e:
            logger.error("HTTP error during token refresh: %s", e)
            raise

        if "access_token" not in token_response:
            raise ValueError("Refresh response missing access_token")

        logger.info("Successfully refreshed access token")
        return token_response

    def parse_authorization_response(self, authorization_response_url: str) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Parse authorization response URL to extract code, state, and error.

        Only the query component is inspected; the first value of each
        parameter is used.

        Args:
            authorization_response_url: Complete callback URL

        Returns:
            Tuple of (code, state, error). On an error response, code is None
            and error is "error: error_description" (or just "error" when no
            description is present); otherwise error is None.
        """
        query_params = parse_qs(urlparse(authorization_response_url).query)

        def first(name: str) -> Optional[str]:
            return query_params.get(name, [None])[0]

        code = first("code")
        state = first("state")
        error = first("error")

        if error:
            error_description = first("error_description")
            full_error = f"{error}: {error_description}" if error_description else error
            logger.error("Authorization error: %s", full_error)
            return None, state, full_error

        return code, state, None