golf-mcp 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of golf-mcp may be problematic; see the release details on the registry page for more information.

Files changed (48)
  1. golf/__init__.py +1 -1
  2. golf/auth/__init__.py +38 -26
  3. golf/auth/api_key.py +16 -23
  4. golf/auth/helpers.py +68 -54
  5. golf/auth/oauth.py +340 -277
  6. golf/auth/provider.py +58 -53
  7. golf/cli/__init__.py +1 -1
  8. golf/cli/main.py +209 -87
  9. golf/commands/__init__.py +1 -1
  10. golf/commands/build.py +31 -25
  11. golf/commands/init.py +81 -53
  12. golf/commands/run.py +30 -15
  13. golf/core/__init__.py +1 -1
  14. golf/core/builder.py +493 -362
  15. golf/core/builder_auth.py +115 -107
  16. golf/core/builder_telemetry.py +12 -9
  17. golf/core/config.py +62 -46
  18. golf/core/parser.py +174 -136
  19. golf/core/telemetry.py +216 -95
  20. golf/core/transformer.py +53 -55
  21. golf/examples/__init__.py +0 -1
  22. golf/examples/api_key/pre_build.py +2 -2
  23. golf/examples/api_key/tools/issues/create.py +35 -36
  24. golf/examples/api_key/tools/issues/list.py +42 -37
  25. golf/examples/api_key/tools/repos/list.py +50 -29
  26. golf/examples/api_key/tools/search/code.py +50 -37
  27. golf/examples/api_key/tools/users/get.py +21 -20
  28. golf/examples/basic/pre_build.py +4 -4
  29. golf/examples/basic/prompts/welcome.py +6 -7
  30. golf/examples/basic/resources/current_time.py +10 -9
  31. golf/examples/basic/resources/info.py +6 -5
  32. golf/examples/basic/resources/weather/common.py +16 -10
  33. golf/examples/basic/resources/weather/current.py +15 -11
  34. golf/examples/basic/resources/weather/forecast.py +15 -11
  35. golf/examples/basic/tools/github_user.py +19 -21
  36. golf/examples/basic/tools/hello.py +10 -6
  37. golf/examples/basic/tools/payments/charge.py +34 -25
  38. golf/examples/basic/tools/payments/common.py +8 -6
  39. golf/examples/basic/tools/payments/refund.py +29 -25
  40. golf/telemetry/__init__.py +6 -6
  41. golf/telemetry/instrumentation.py +455 -310
  42. {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/METADATA +1 -1
  43. golf_mcp-0.1.13.dist-info/RECORD +55 -0
  44. golf_mcp-0.1.11.dist-info/RECORD +0 -55
  45. {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/WHEEL +0 -0
  46. {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
  47. {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
  48. {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/top_level.txt +0 -0
golf/auth/oauth.py CHANGED
@@ -5,25 +5,25 @@ interface for GolfMCP servers. It handles the OAuth 2.0 authentication flow,
5
5
  token management, and client registration.
6
6
  """
7
7
 
8
+ import os
8
9
  import time
9
10
  import uuid
10
- import jwt
11
- import httpx
12
- import os
13
- from typing import Dict, List, Optional, Any, Union
14
11
  from datetime import datetime
12
+ from typing import Any
15
13
 
14
+ import httpx
15
+ import jwt
16
16
  from mcp.server.auth.provider import (
17
- OAuthAuthorizationServerProvider,
18
17
  AccessToken,
19
- RefreshToken,
20
18
  AuthorizationCode,
19
+ AuthorizationParams,
20
+ OAuthAuthorizationServerProvider,
21
+ RefreshToken,
21
22
  RegistrationError,
22
- AuthorizationParams
23
23
  )
24
24
  from mcp.shared.auth import (
25
- OAuthToken,
26
25
  OAuthClientInformationFull,
26
+ OAuthToken,
27
27
  )
28
28
  from starlette.responses import RedirectResponse
29
29
 
@@ -32,45 +32,47 @@ from .provider import ProviderConfig
32
32
 
33
33
  class TokenStorage:
34
34
  """Simple in-memory token storage.
35
-
35
+
36
36
  This class provides a simple in-memory storage for OAuth tokens,
37
37
  authorization codes, and client information. In a production
38
38
  environment, this should be replaced with a persistent storage
39
39
  solution.
40
40
  """
41
-
42
- def __init__(self):
41
+
42
+ def __init__(self) -> None:
43
43
  """Initialize the token storage."""
44
44
  self.auth_codes = {} # code_str -> AuthorizationCode
45
45
  self.refresh_tokens = {} # token_str -> RefreshToken
46
46
  self.access_tokens = {} # token_str -> AccessToken
47
47
  self.clients = {} # client_id -> OAuthClientInformationFull
48
48
  self.provider_tokens = {} # mcp_access_token_str -> provider_access_token_str
49
- self.auth_code_to_provider_token = {} # auth_code_str -> provider_access_token_str
50
-
51
- def store_auth_code(self, code: str, auth_code_obj: AuthorizationCode) -> None: # Renamed auth_code to auth_code_obj for clarity
49
+ self.auth_code_to_provider_token = {} # auth_code_str -> provider_access_token_str
50
+
51
+ def store_auth_code(
52
+ self, code: str, auth_code_obj: AuthorizationCode
53
+ ) -> None: # Renamed auth_code to auth_code_obj for clarity
52
54
  """Store an authorization code.
53
-
55
+
54
56
  Args:
55
57
  code: The authorization code string
56
58
  auth_code_obj: The authorization code object
57
59
  """
58
60
  self.auth_codes[code] = auth_code_obj
59
-
60
- def get_auth_code(self, code: str) -> Optional[AuthorizationCode]:
61
+
62
+ def get_auth_code(self, code: str) -> AuthorizationCode | None:
61
63
  """Get an authorization code by value.
62
-
64
+
63
65
  Args:
64
66
  code: The authorization code string
65
-
67
+
66
68
  Returns:
67
69
  The authorization code object or None if not found
68
70
  """
69
71
  return self.auth_codes.get(code)
70
-
72
+
71
73
  def delete_auth_code(self, code: str) -> None:
72
74
  """Delete an authorization code and its associated provider token mapping.
73
-
75
+
74
76
  Args:
75
77
  code: The authorization code string
76
78
  """
@@ -78,288 +80,299 @@ class TokenStorage:
78
80
  del self.auth_codes[code]
79
81
  if code in self.auth_code_to_provider_token:
80
82
  del self.auth_code_to_provider_token[code]
81
-
82
- def store_auth_code_provider_token_mapping(self, auth_code_str: str, provider_token: str) -> None:
83
+
84
+ def store_auth_code_provider_token_mapping(
85
+ self, auth_code_str: str, provider_token: str
86
+ ) -> None:
83
87
  """Store a mapping from an auth_code string to a provider_token string."""
84
88
  self.auth_code_to_provider_token[auth_code_str] = provider_token
85
89
 
86
- def get_provider_token_for_auth_code(self, auth_code_str: str) -> Optional[str]:
90
+ def get_provider_token_for_auth_code(self, auth_code_str: str) -> str | None:
87
91
  """Retrieve a provider_token string using an auth_code string."""
88
92
  return self.auth_code_to_provider_token.get(auth_code_str)
89
-
93
+
90
94
  def store_client(self, client_id: str, client: OAuthClientInformationFull) -> None:
91
95
  """Store client information.
92
-
96
+
93
97
  Args:
94
98
  client_id: The client ID
95
99
  client: The client information
96
100
  """
97
101
  self.clients[client_id] = client
98
-
99
- def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]:
102
+
103
+ def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
100
104
  """Get client information by ID.
101
-
105
+
102
106
  Args:
103
107
  client_id: The client ID
104
-
108
+
105
109
  Returns:
106
110
  The client information or None if not found
107
111
  """
108
112
  # _diag_logger.info(f"TokenStorage: get_client called for client_id '{client_id}'. Known clients: {list(self.clients.keys())}") # Optional: uncomment for debugging
109
113
  return self.clients.get(client_id)
110
-
114
+
111
115
  def store_refresh_token(self, token: str, refresh_token: RefreshToken) -> None:
112
116
  """Store a refresh token.
113
-
117
+
114
118
  Args:
115
119
  token: The refresh token string
116
120
  refresh_token: The refresh token object
117
121
  """
118
122
  self.refresh_tokens[token] = refresh_token
119
-
120
- def get_refresh_token(self, token: str) -> Optional[RefreshToken]:
123
+
124
+ def get_refresh_token(self, token: str) -> RefreshToken | None:
121
125
  """Get a refresh token by value.
122
-
126
+
123
127
  Args:
124
128
  token: The refresh token string
125
-
129
+
126
130
  Returns:
127
131
  The refresh token object or None if not found
128
132
  """
129
133
  return self.refresh_tokens.get(token)
130
-
134
+
131
135
  def delete_refresh_token(self, token: str) -> None:
132
136
  """Delete a refresh token.
133
-
137
+
134
138
  Args:
135
139
  token: The refresh token string
136
140
  """
137
141
  if token in self.refresh_tokens:
138
142
  del self.refresh_tokens[token]
139
-
143
+
140
144
  def store_access_token(self, token: str, access_token: AccessToken) -> None:
141
145
  """Store an access token.
142
-
146
+
143
147
  Args:
144
148
  token: The access token string
145
149
  access_token: The access token object
146
150
  """
147
151
  self.access_tokens[token] = access_token
148
-
149
- def get_access_token(self, token: str) -> Optional[AccessToken]:
152
+
153
+ def get_access_token(self, token: str) -> AccessToken | None:
150
154
  """Get an access token by value.
151
-
155
+
152
156
  Args:
153
157
  token: The access token string
154
-
158
+
155
159
  Returns:
156
160
  The access token object or None if not found
157
161
  """
158
162
  return self.access_tokens.get(token)
159
-
163
+
160
164
  def delete_access_token(self, token: str) -> None:
161
165
  """Delete an access token.
162
-
166
+
163
167
  Args:
164
168
  token: The access token string
165
169
  """
166
170
  if token in self.access_tokens:
167
171
  del self.access_tokens[token]
168
-
172
+
169
173
  def store_provider_token(self, mcp_token: str, provider_token: str) -> None:
170
174
  """Store a provider token mapping.
171
-
175
+
172
176
  Args:
173
177
  mcp_token: The MCP token string
174
178
  provider_token: The provider token string (e.g., GitHub token)
175
179
  """
176
180
  self.provider_tokens[mcp_token] = provider_token
177
-
178
- def get_provider_token(self, mcp_token: str) -> Optional[str]:
181
+
182
+ def get_provider_token(self, mcp_token: str) -> str | None:
179
183
  """Get the provider token associated with an MCP token.
180
-
184
+
181
185
  This is a non-standard method to allow access to the provider token
182
186
  (e.g., GitHub token) for a given MCP token. This can be used by
183
187
  tools that need to access provider APIs.
184
-
188
+
185
189
  Args:
186
190
  mcp_token: The MCP token
187
-
191
+
188
192
  Returns:
189
193
  The provider token or None if not found
190
194
  """
191
- return self.provider_tokens.get(mcp_token) # Changed from self.storage to self
195
+ return self.provider_tokens.get(mcp_token) # Changed from self.storage to self
192
196
 
193
197
 
194
198
  class GolfOAuthProvider(OAuthAuthorizationServerProvider):
195
199
  """OAuth provider implementation for GolfMCP.
196
-
200
+
197
201
  This class implements the OAuthAuthorizationServerProvider interface
198
202
  for GolfMCP servers. It handles the OAuth 2.0 authentication flow,
199
203
  token management, and client registration.
200
204
  """
201
-
202
- def __init__(self, config: ProviderConfig):
205
+
206
+ def __init__(self, config: ProviderConfig) -> None:
203
207
  """Initialize the provider.
204
-
208
+
205
209
  Args:
206
210
  config: The provider configuration
207
211
  """
208
212
  self.config = config
209
213
  self.storage = TokenStorage()
210
- self.state_mapping: Dict[str, Dict[str, Any]] = {} # Initialize state_mapping
211
-
214
+ self.state_mapping: dict[str, dict[str, Any]] = {} # Initialize state_mapping
215
+
212
216
  # Register default client
213
217
  self._register_default_client()
214
-
218
+
215
219
  def _get_client_id(self) -> str:
216
220
  """Get the client ID from config or environment."""
217
221
  if self.config.client_id:
218
222
  return self.config.client_id
219
-
223
+
220
224
  if self.config.client_id_env_var:
221
225
  value = os.environ.get(self.config.client_id_env_var)
222
226
  if value:
223
227
  return value
224
-
228
+
225
229
  return "missing-client-id"
226
-
230
+
227
231
  def _get_client_secret(self) -> str:
228
232
  """Get the client secret from config or environment."""
229
233
  if self.config.client_secret:
230
234
  return self.config.client_secret
231
-
235
+
232
236
  if self.config.client_secret_env_var:
233
237
  value = os.environ.get(self.config.client_secret_env_var)
234
238
  if value:
235
239
  return value
236
-
240
+
237
241
  return "missing-client-secret"
238
-
242
+
239
243
  def _get_jwt_secret(self) -> str:
240
244
  """Get the JWT secret from config. It's expected to be resolved by server startup."""
241
245
  if self.config.jwt_secret:
242
246
  # _diag_logger.info(f"GolfOAuthProvider: Using JWT secret from config: {self.config.jwt_secret[:5]}...")
243
247
  return self.config.jwt_secret
244
248
  else:
245
- raise ValueError("JWT Secret is not configured in the provider. Check server logs and environment variables.")
246
-
249
+ raise ValueError(
250
+ "JWT Secret is not configured in the provider. Check server logs and environment variables."
251
+ )
252
+
247
253
  def _register_default_client(self) -> None:
248
254
  """Register a default client for MCP."""
249
255
  # These are the URIs where *this server* is allowed to redirect an MCP client
250
256
  # after successful authentication and MCP auth code generation.
251
257
  client_redirect_uris = [
252
258
  # Common redirect URI for MCP Inspector running locally
253
- "http://localhost:5173/callback",
259
+ "http://localhost:5173/callback",
254
260
  "http://127.0.0.1:5173/callback",
255
261
  # A generic callback relative to the server's issuer URL, if needed by some clients
256
262
  # This assumes such a client-side endpoint exists.
257
- f"{self.config.issuer_url.rstrip('/') if self.config.issuer_url else 'http://localhost:3000'}/client/callback"
263
+ f"{self.config.issuer_url.rstrip('/') if self.config.issuer_url else 'http://localhost:3000'}/client/callback",
258
264
  ]
259
265
 
260
266
  default_client = OAuthClientInformationFull(
261
267
  client_id="default",
262
268
  client_name="Default MCP Client",
263
269
  client_secret="", # Public client
264
- redirect_uris=client_redirect_uris,
270
+ redirect_uris=client_redirect_uris,
265
271
  grant_types=["authorization_code", "refresh_token"],
266
272
  response_types=["code"],
267
273
  token_endpoint_auth_method="none", # Public client
268
- scope=" ".join(self.config.scopes)
274
+ scope=" ".join(self.config.scopes),
269
275
  )
270
276
  self.storage.store_client("default", default_client)
271
-
277
+
272
278
  def _generate_jwt(
273
- self,
274
- subject: str,
275
- scopes: List[str],
276
- expires_in: int = None
279
+ self, subject: str, scopes: list[str], expires_in: int = None
277
280
  ) -> str:
278
281
  """Generate a JWT token.
279
-
282
+
280
283
  Args:
281
284
  subject: The subject of the token (usually client_id)
282
285
  scopes: The scopes granted to the token
283
286
  expires_in: The token lifetime in seconds (or None for default)
284
-
287
+
285
288
  Returns:
286
289
  The signed JWT token
287
290
  """
288
291
  now = int(time.time())
289
292
  expiry = now + (expires_in or self.config.token_expiration)
290
-
293
+
291
294
  payload = {
292
295
  "iss": self.config.issuer_url or "golf:auth",
293
296
  "sub": subject,
294
297
  "iat": now,
295
298
  "exp": expiry,
296
- "scp": scopes
299
+ "scp": scopes,
297
300
  }
298
-
301
+
299
302
  jwt_secret = self._get_jwt_secret()
300
303
  return jwt.encode(payload, jwt_secret, algorithm="HS256")
301
-
302
- def _verify_jwt(self, token: str) -> Optional[Dict[str, Any]]:
304
+
305
+ def _verify_jwt(self, token: str) -> dict[str, Any] | None:
303
306
  """Verify a JWT token."""
304
- jwt_secret = self._get_jwt_secret() # Get secret first
307
+ jwt_secret = self._get_jwt_secret() # Get secret first
305
308
  # _diag_logger.info(f"GolfOAuthProvider: _verify_jwt attempting to use secret: {jwt_secret[:5]}...")
306
309
 
307
310
  try:
308
- payload = jwt.decode(token, jwt_secret, algorithms=["HS256"], options={"verify_signature": True})
309
-
311
+ payload = jwt.decode(
312
+ token,
313
+ jwt_secret,
314
+ algorithms=["HS256"],
315
+ options={"verify_signature": True},
316
+ )
317
+
310
318
  if payload.get("exp", 0) < time.time():
311
319
  exp_timestamp = payload.get("exp")
312
320
  current_timestamp = time.time()
313
- exp_datetime_str = str(datetime.fromtimestamp(exp_timestamp)) if exp_timestamp is not None else "N/A"
314
- current_datetime_str = str(datetime.fromtimestamp(current_timestamp))
321
+ (
322
+ str(datetime.fromtimestamp(exp_timestamp))
323
+ if exp_timestamp is not None
324
+ else "N/A"
325
+ )
326
+ str(datetime.fromtimestamp(current_timestamp))
315
327
  return None
316
328
  return payload
317
- except jwt.ExpiredSignatureError as e:
329
+ except jwt.ExpiredSignatureError:
318
330
  return None
319
- except jwt.PyJWTError as e:
331
+ except jwt.PyJWTError:
320
332
  return None
321
- except Exception as e: # Catch any other unexpected error during decode
333
+ except Exception: # Catch any other unexpected error during decode
322
334
  return None
323
-
324
- async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]:
335
+
336
+ async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
325
337
  """Get client information by ID.
326
-
338
+
327
339
  Args:
328
340
  client_id: The client ID
329
-
341
+
330
342
  Returns:
331
343
  The client information or None if not found
332
344
  """
333
345
  return self.storage.get_client(client_id)
334
-
335
- async def register_client(
336
- self,
337
- client_info: OAuthClientInformationFull
338
- ) -> None:
346
+
347
+ async def register_client(self, client_info: OAuthClientInformationFull) -> None:
339
348
  """Register a new client."""
340
349
  # Add detailed logging at the beginning
341
- client_id_to_register = getattr(client_info, 'client_id', 'UNKNOWN (client_info has no client_id attribute)')
350
+ getattr(
351
+ client_info, "client_id", "UNKNOWN (client_info has no client_id attribute)"
352
+ )
342
353
  try:
343
354
  # Validate the client information
344
355
  if not client_info.client_id:
345
356
  raise RegistrationError(
346
357
  error="invalid_client_metadata",
347
- error_description="Client ID is missing in client_info provided to register_client"
358
+ error_description="Client ID is missing in client_info provided to register_client",
348
359
  )
349
360
 
350
361
  if not client_info.redirect_uris:
351
362
  raise RegistrationError(
352
363
  error="invalid_redirect_uri",
353
- error_description="At least one redirect URI is required"
364
+ error_description="At least one redirect URI is required",
354
365
  )
355
-
366
+
356
367
  # Store the client
357
368
  self.storage.store_client(client_info.client_id, client_info)
358
- except Exception as e:
359
- raise # Re-raise the exception so FastMCP can handle it
360
-
369
+ except Exception:
370
+ raise # Re-raise the exception so FastMCP can handle it
371
+
361
372
  async def authorize(
362
- self, client: OAuthClientInformationFull, params: AuthorizationParams # params from MCP client
373
+ self,
374
+ client: OAuthClientInformationFull,
375
+ params: AuthorizationParams, # params from MCP client
363
376
  ) -> str:
364
377
  """Handle an authorization request.
365
378
  This method is called when an MCP client requests authorization.
@@ -369,118 +382,123 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
369
382
  import urllib.parse
370
383
 
371
384
  idp_flow_state = secrets.token_hex(16)
372
- mcp_client_original_state = params.state
373
-
385
+ mcp_client_original_state = params.state
386
+
374
387
  self.state_mapping[idp_flow_state] = {
375
388
  "client_id": client.client_id,
376
389
  "redirect_uri": str(params.redirect_uri),
377
390
  "code_challenge": params.code_challenge,
378
- "code_challenge_method": "S256" if params.code_challenge else None, # Store S256 if challenge exists, else None
391
+ "code_challenge_method": (
392
+ "S256" if params.code_challenge else None
393
+ ), # Store S256 if challenge exists, else None
379
394
  "scopes": params.scopes,
380
395
  "redirect_uri_provided_explicitly": params.redirect_uri_provided_explicitly,
381
- "mcp_client_original_state": mcp_client_original_state
396
+ "mcp_client_original_state": mcp_client_original_state,
382
397
  }
383
-
398
+
384
399
  # Use self.config.callback_path for consistency
385
- idp_callback_uri = f"{self.config.issuer_url.rstrip('/')}{self.config.callback_path}"
400
+ idp_callback_uri = (
401
+ f"{self.config.issuer_url.rstrip('/')}{self.config.callback_path}"
402
+ )
386
403
 
387
404
  client_id = self._get_client_id()
388
-
405
+
389
406
  auth_params_for_idp = {
390
407
  "client_id": client_id,
391
408
  "redirect_uri": idp_callback_uri,
392
409
  "scope": " ".join(self.config.scopes),
393
410
  "state": idp_flow_state,
394
- "response_type": "code"
411
+ "response_type": "code",
395
412
  }
396
-
413
+
397
414
  if params.code_challenge:
398
415
  auth_params_for_idp["code_challenge"] = params.code_challenge
399
416
  # Always use S256 if a challenge is present, as it's the standard and what the client sends.
400
- auth_params_for_idp["code_challenge_method"] = "S256"
417
+ auth_params_for_idp["code_challenge_method"] = "S256"
401
418
 
402
419
  query_for_idp = urllib.parse.urlencode(auth_params_for_idp)
403
-
420
+
404
421
  return f"{self.config.authorize_url}?{query_for_idp}"
405
-
422
+
406
423
  async def load_authorization_code(
407
- self,
408
- client: OAuthClientInformationFull,
409
- code: str
410
- ) -> Optional[AuthorizationCode]:
424
+ self, client: OAuthClientInformationFull, code: str
425
+ ) -> AuthorizationCode | None:
411
426
  """Load an authorization code.
412
-
427
+
413
428
  Args:
414
429
  client: The client information
415
430
  code: The authorization code
416
-
431
+
417
432
  Returns:
418
433
  The authorization code object or None if not found
419
434
  """
420
435
  auth_code = self.storage.get_auth_code(code)
421
-
436
+
422
437
  if not auth_code:
423
438
  return None
424
-
439
+
425
440
  # Verify the code belongs to this client
426
441
  if auth_code.client_id != client.client_id:
427
442
  return None
428
-
443
+
429
444
  # Verify the code hasn't expired
430
445
  if auth_code.expires_at and auth_code.expires_at < datetime.now().timestamp():
431
446
  self.storage.delete_auth_code(code)
432
447
  return None
433
-
448
+
434
449
  return auth_code
435
-
450
+
436
451
  async def exchange_authorization_code(
437
- self,
438
- client: OAuthClientInformationFull,
439
- code: AuthorizationCode # This is AuthorizationCode object
452
+ self,
453
+ client: OAuthClientInformationFull,
454
+ code: AuthorizationCode, # This is AuthorizationCode object
440
455
  ) -> OAuthToken:
441
456
  """Exchange an authorization code for tokens.
442
-
457
+
443
458
  Args:
444
459
  client: The client information
445
460
  code: The authorization code object
446
-
461
+
447
462
  Returns:
448
463
  The OAuth token response
449
-
464
+
450
465
  Raises:
451
466
  TokenError: If the code exchange fails
452
467
  """
453
468
  # Retrieve the provider token that was stored temporarily during callback
454
469
  provider_token = self.storage.get_provider_token_for_auth_code(code.code)
455
-
470
+
456
471
  # Delete the code and its mapping to ensure one-time use
457
- self.storage.delete_auth_code(code.code) # This now also deletes the mapping
458
-
472
+ self.storage.delete_auth_code(code.code) # This now also deletes the mapping
473
+
459
474
  # Generate an access token
460
- access_token_str = self._generate_jwt( # Renamed for clarity
461
- subject=client.client_id,
462
- scopes=code.scopes
475
+ access_token_str = self._generate_jwt( # Renamed for clarity
476
+ subject=client.client_id, scopes=code.scopes
463
477
  )
464
-
478
+
465
479
  # Generate a refresh token if needed
466
- refresh_token_str = str(uuid.uuid4()) if "refresh_token" in client.grant_types else None # Renamed for clarity
467
-
480
+ refresh_token_str = (
481
+ str(uuid.uuid4()) if "refresh_token" in client.grant_types else None
482
+ ) # Renamed for clarity
483
+
468
484
  # Store the mapping from our new MCP access token to the provider's access token
469
485
  if provider_token and access_token_str:
470
486
  self.storage.store_provider_token(access_token_str, provider_token)
471
-
487
+
472
488
  # Store the tokens
473
489
  if refresh_token_str:
474
490
  self.storage.store_refresh_token(
475
- refresh_token_str,
491
+ refresh_token_str,
476
492
  RefreshToken(
477
493
  token=refresh_token_str,
478
494
  client_id=client.client_id,
479
495
  scopes=code.scopes,
480
- expires_at=int(datetime.now().timestamp() + (self.config.token_expiration * 24)) # 24x longer, cast to int
481
- )
496
+ expires_at=int(
497
+ datetime.now().timestamp() + (self.config.token_expiration * 24)
498
+ ), # 24x longer, cast to int
499
+ ),
482
500
  )
483
-
501
+
484
502
  # Store access token information for validation later
485
503
  # Note: For JWTs, we might not need to store them if we can verify the signature
486
504
  self.storage.store_access_token(
@@ -489,71 +507,71 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
489
507
  token=access_token_str,
490
508
  client_id=client.client_id,
491
509
  scopes=code.scopes,
492
- expires_at=int(datetime.now().timestamp() + self.config.token_expiration) # Cast to int
493
- )
510
+ expires_at=int(
511
+ datetime.now().timestamp() + self.config.token_expiration
512
+ ), # Cast to int
513
+ ),
494
514
  )
495
-
515
+
496
516
  # Create and return the OAuth token response
497
517
  return OAuthToken(
498
518
  access_token=access_token_str,
499
519
  token_type="bearer",
500
520
  expires_in=self.config.token_expiration,
501
521
  refresh_token=refresh_token_str,
502
- scope=" ".join(code.scopes)
522
+ scope=" ".join(code.scopes),
503
523
  )
504
-
524
+
505
525
  async def load_refresh_token(
506
- self,
507
- client: OAuthClientInformationFull,
508
- refresh_token: str
509
- ) -> Optional[RefreshToken]:
526
+ self, client: OAuthClientInformationFull, refresh_token: str
527
+ ) -> RefreshToken | None:
510
528
  """Load a refresh token.
511
-
529
+
512
530
  Args:
513
531
  client: The client information
514
532
  refresh_token: The refresh token string
515
-
533
+
516
534
  Returns:
517
535
  The refresh token object or None if not found
518
536
  """
519
537
  token = self.storage.get_refresh_token(refresh_token)
520
-
538
+
521
539
  if not token:
522
540
  return None
523
-
541
+
524
542
  # Verify the token belongs to this client
525
543
  if token.client_id != client.client_id:
526
544
  return None
527
-
545
+
528
546
  # Verify the token hasn't expired
529
547
  if token.expires_at and token.expires_at < datetime.now().timestamp():
530
548
  self.storage.delete_refresh_token(refresh_token)
531
549
  return None
532
-
550
+
533
551
  return token
534
-
552
+
535
553
  async def exchange_refresh_token(
536
- self,
537
- client: OAuthClientInformationFull,
538
- refresh_token: RefreshToken,
539
- scopes: List[str]
554
+ self,
555
+ client: OAuthClientInformationFull,
556
+ refresh_token: RefreshToken,
557
+ scopes: list[str],
540
558
  ) -> OAuthToken:
541
559
  """Exchange a refresh token for a new token pair.
542
-
560
+
543
561
  Args:
544
562
  client: The client information
545
563
  refresh_token: The refresh token object
546
564
  scopes: The requested scopes (may be a subset of original)
547
-
565
+
548
566
  Returns:
549
567
  The new OAuth token response
550
-
568
+
551
569
  Raises:
552
570
  TokenError: If the token exchange fails
553
571
  """
554
572
  # Delete the old refresh token (implement token rotation for security)
555
573
  self.storage.delete_refresh_token(refresh_token.token)
556
-
574
+
557
575
  # Determine the scopes for the new token
558
576
  # If requested scopes are provided, they must be a subset of the original
559
577
  if scopes:
@@ -562,20 +580,18 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
562
580
  valid_scopes = refresh_token.scopes
563
581
  else:
564
582
  valid_scopes = refresh_token.scopes
565
-
583
+
566
584
  # Generate a new access token
567
- access_token = self._generate_jwt(
568
- subject=client.client_id,
569
- scopes=valid_scopes
570
- )
571
-
585
+ access_token = self._generate_jwt(subject=client.client_id, scopes=valid_scopes)
586
+
572
587
  # Generate a new refresh token
573
588
  new_refresh_token = str(uuid.uuid4())
574
-
589
+
575
590
  # Find the provider token if it exists from the old access token
576
591
  # Note: This assumes each refresh generates only one access token
577
592
  old_access_tokens = [
578
- token for token, data in self.storage.access_tokens.items()
593
+ token
594
+ for token, data in self.storage.access_tokens.items()
579
595
  if data.client_id == client.client_id
580
596
  ]
581
597
  provider_token = None
@@ -585,18 +601,20 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
585
601
  # Store the provider token mapping for the new access token
586
602
  self.storage.store_provider_token(access_token, provider_token)
587
603
  break
588
-
604
+
589
605
  # Store the new tokens
590
606
  self.storage.store_refresh_token(
591
- new_refresh_token,
607
+ new_refresh_token,
592
608
  RefreshToken(
593
609
  token=new_refresh_token,
594
610
  client_id=client.client_id,
595
611
  scopes=valid_scopes,
596
- expires_at=int(datetime.now().timestamp() + (self.config.token_expiration * 24)) # Cast to int
597
- )
612
+ expires_at=int(
613
+ datetime.now().timestamp() + (self.config.token_expiration * 24)
614
+ ), # Cast to int
615
+ ),
598
616
  )
599
-
617
+
600
618
  # Store access token information
601
619
  self.storage.store_access_token(
602
620
  access_token,
@@ -604,67 +622,68 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
604
622
  token=access_token,
605
623
  client_id=client.client_id,
606
624
  scopes=valid_scopes,
607
- expires_at=int(datetime.now().timestamp() + self.config.token_expiration) # Cast to int
608
- )
625
+ expires_at=int(
626
+ datetime.now().timestamp() + self.config.token_expiration
627
+ ), # Cast to int
628
+ ),
609
629
  )
610
-
630
+
611
631
  # Create and return the OAuth token response
612
632
  return OAuthToken(
613
633
  access_token=access_token,
614
634
  token_type="bearer",
615
635
  expires_in=self.config.token_expiration,
616
636
  refresh_token=new_refresh_token,
617
- scope=" ".join(valid_scopes)
637
+ scope=" ".join(valid_scopes),
618
638
  )
619
-
620
- async def load_access_token(self, token: str) -> Optional[AccessToken]:
639
+
640
+ async def load_access_token(self, token: str) -> AccessToken | None:
621
641
  """Load and validate an access token."""
622
-
642
+
623
643
  payload = self._verify_jwt(token)
624
644
  if not payload:
625
- return None
626
-
645
+ return None
646
+
627
647
  client_id = payload.get("sub")
628
648
  scopes = payload.get("scp", [])
629
649
  expires_at = payload.get("exp")
630
-
631
-
650
+
632
651
  access_token_obj = AccessToken(
633
652
  token=token,
634
653
  client_id=client_id,
635
654
  scopes=scopes,
636
- expires_at=int(expires_at) if expires_at is not None else None
655
+ expires_at=int(expires_at) if expires_at is not None else None,
637
656
  )
638
-
657
+
639
658
  return access_token_obj
640
-
641
- async def revoke_token(self, token: Union[AccessToken, RefreshToken]) -> None:
659
+
660
+ async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
642
661
  """Revoke a token.
643
-
662
+
644
663
  Args:
645
664
  token: The token to revoke (access or refresh)
646
665
  """
647
666
  # Try to revoke as access token
648
667
  self.storage.delete_access_token(token.token)
649
-
668
+
650
669
  # Try to revoke as refresh token
651
670
  self.storage.delete_refresh_token(token.token)
652
-
671
+
653
672
  # Clean up provider token mapping if it exists
654
673
  provider_token = self.storage.get_provider_token(token.token)
655
674
  if provider_token:
656
675
  self.storage.provider_tokens.pop(token.token, None)
657
-
658
- def get_provider_token(self, mcp_token: str) -> Optional[str]:
676
+
677
+ def get_provider_token(self, mcp_token: str) -> str | None:
659
678
  """Get the provider token associated with an MCP token.
660
-
679
+
661
680
  This is a non-standard method to allow access to the provider token
662
681
  (e.g., GitHub token) for a given MCP token. This can be used by
663
682
  tools that need to access provider APIs.
664
-
683
+
665
684
  Args:
666
685
  mcp_token: The MCP token
667
-
686
+
668
687
  Returns:
669
688
  The provider token or None if not found
670
689
  """
@@ -673,39 +692,48 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
673
692
 
674
693
  def create_callback_handler(provider: GolfOAuthProvider):
675
694
  """Create a callback handler for OAuth authorization.
676
-
695
+
677
696
  This function creates a callback handler that can be used to handle
678
697
  the OAuth callback from the provider (e.g., GitHub).
679
-
698
+
680
699
  Args:
681
700
  provider: The OAuth provider
682
-
701
+
683
702
  Returns:
684
703
  An async function that handles the callback
685
704
  """
705
+
686
706
  async def handle_callback(request):
687
707
  """Handle the OAuth callback.
688
-
708
+
689
709
  Args:
690
710
  request: The HTTP request
691
-
711
+
692
712
  Returns:
693
713
  The HTTP response
694
714
  """
695
715
  # Extract the code and state from the request
696
- idp_auth_code = request.query_params.get("code") # Renamed for clarity: code from IdP
697
- idp_state = request.query_params.get("state") # Renamed for clarity: state from IdP
698
-
716
+ idp_auth_code = request.query_params.get(
717
+ "code"
718
+ ) # Renamed for clarity: code from IdP
719
+ idp_state = request.query_params.get(
720
+ "state"
721
+ ) # Renamed for clarity: state from IdP
722
+
699
723
  if not idp_auth_code:
700
- return RedirectResponse("/auth-error?error=no_code_from_idp") # More specific error
701
-
724
+ return RedirectResponse(
725
+ "/auth-error?error=no_code_from_idp"
726
+ ) # More specific error
727
+
702
728
  # Use provider.config.callback_path for consistency
703
729
  # This is the redirect_uri registered with the IdP and used in the /authorize step
704
- idp_callback_uri_for_token_exchange = f"{provider.config.issuer_url.rstrip('/')}{provider.config.callback_path}"
705
-
730
+ idp_callback_uri_for_token_exchange = (
731
+ f"{provider.config.issuer_url.rstrip('/')}{provider.config.callback_path}"
732
+ )
733
+
706
734
  client_id_for_idp = provider._get_client_id()
707
735
  client_secret_for_idp = provider._get_client_secret()
708
-
736
+
709
737
  async with httpx.AsyncClient() as client:
710
738
  response = await client.post(
711
739
  provider.config.token_url,
@@ -713,86 +741,121 @@ def create_callback_handler(provider: GolfOAuthProvider):
713
741
  data={
714
742
  "client_id": client_id_for_idp,
715
743
  "client_secret": client_secret_for_idp,
716
- "code": idp_auth_code, # Use code from IdP
717
- "redirect_uri": idp_callback_uri_for_token_exchange
718
- }
744
+ "code": idp_auth_code, # Use code from IdP
745
+ "redirect_uri": idp_callback_uri_for_token_exchange,
746
+ },
719
747
  )
720
-
748
+
721
749
  if response.status_code != 200:
722
- error_detail = response.text[:200] # Limit error detail length
723
- return RedirectResponse(f"/auth-error?error=idp_token_exchange_failed&detail={urllib.parse.quote(error_detail)}")
724
-
750
+ error_detail = response.text[:200] # Limit error detail length
751
+ return RedirectResponse(
752
+ f"/auth-error?error=idp_token_exchange_failed&detail={urllib.parse.quote(error_detail)}"
753
+ )
754
+
725
755
  # Get the provider token from the response
726
756
  token_data = response.json()
727
- provider_access_token = token_data.get("access_token") # This is the token from GitHub/Google etc.
728
-
757
+ provider_access_token = token_data.get(
758
+ "access_token"
759
+ ) # This is the token from GitHub/Google etc.
760
+
729
761
  if not provider_access_token:
730
762
  return RedirectResponse("/auth-error?error=no_access_token_from_idp")
731
-
763
+
732
764
  try:
733
765
  # Get user information from the provider using the token (optional step)
734
766
  # user_info = None (keep this if user_info is used later, otherwise remove)
735
767
  # ... (userinfo fetching logic if needed) ...
736
-
737
- original_mcp_client_details = provider.state_mapping.pop(idp_state, None) # Use state from IdP
768
+
769
+ original_mcp_client_details = provider.state_mapping.pop(
770
+ idp_state, None
771
+ ) # Use state from IdP
738
772
  if not original_mcp_client_details:
739
- return RedirectResponse(f"/auth-error?error=invalid_idp_state")
773
+ return RedirectResponse("/auth-error?error=invalid_idp_state")
740
774
 
741
775
  original_mcp_client_id = original_mcp_client_details["client_id"]
742
- original_mcp_redirect_uri = original_mcp_client_details["redirect_uri"] # MCP client's redirect_uri
776
+ original_mcp_redirect_uri = original_mcp_client_details[
777
+ "redirect_uri"
778
+ ] # MCP client's redirect_uri
743
779
  original_code_challenge = original_mcp_client_details["code_challenge"]
744
- original_code_challenge_method = original_mcp_client_details["code_challenge_method"]
745
-
780
+ original_code_challenge_method = original_mcp_client_details[
781
+ "code_challenge_method"
782
+ ]
783
+
746
784
  requested_scopes_for_mcp_server_str = original_mcp_client_details["scopes"]
747
- mcp_client_original_state_to_pass_back = original_mcp_client_details.get("mcp_client_original_state")
748
- original_redirect_uri_provided_explicitly = original_mcp_client_details["redirect_uri_provided_explicitly"]
749
-
750
- mcp_client = await provider.get_client(original_mcp_client_id) # Renamed for clarity
785
+ mcp_client_original_state_to_pass_back = original_mcp_client_details.get(
786
+ "mcp_client_original_state"
787
+ )
788
+ original_redirect_uri_provided_explicitly = original_mcp_client_details[
789
+ "redirect_uri_provided_explicitly"
790
+ ]
791
+
792
+ mcp_client = await provider.get_client(
793
+ original_mcp_client_id
794
+ ) # Renamed for clarity
751
795
  if not mcp_client:
752
- return RedirectResponse(f"/auth-error?error=mcp_client_not_found_post_callback")
753
-
754
- final_scopes_for_mcp_auth_code: List[str]
755
- if requested_scopes_for_mcp_server_str: # Scopes requested by MCP client
756
- final_scopes_for_mcp_auth_code = requested_scopes_for_mcp_server_str.split()
757
- else: # Default to client's registered scopes if none explicitly requested
758
- final_scopes_for_mcp_auth_code = mcp_client.scope.split() if mcp_client.scope else []
759
-
796
+ return RedirectResponse(
797
+ "/auth-error?error=mcp_client_not_found_post_callback"
798
+ )
799
+
800
+ final_scopes_for_mcp_auth_code: list[str]
801
+ if requested_scopes_for_mcp_server_str: # Scopes requested by MCP client
802
+ final_scopes_for_mcp_auth_code = (
803
+ requested_scopes_for_mcp_server_str.split()
804
+ )
805
+ else: # Default to client's registered scopes if none explicitly requested
806
+ final_scopes_for_mcp_auth_code = (
807
+ mcp_client.scope.split() if mcp_client.scope else []
808
+ )
809
+
760
810
  # This is the auth code our GolfMCP server issues to the MCP client
761
- mcp_auth_code_str = str(uuid.uuid4())
762
-
811
+ mcp_auth_code_str = str(uuid.uuid4())
812
+
763
813
  # Store the mapping from our mcp_auth_code_str to the provider_access_token (e.g., GitHub token)
764
814
  # This will be retrieved when the MCP client exchanges mcp_auth_code_str for an MCP access token
765
- provider.storage.store_auth_code_provider_token_mapping(mcp_auth_code_str, provider_access_token)
766
-
815
+ provider.storage.store_auth_code_provider_token_mapping(
816
+ mcp_auth_code_str, provider_access_token
817
+ )
818
+
767
819
  # Create the AuthorizationCode object for our server
768
- mcp_auth_code_obj = AuthorizationCode( # Renamed for clarity
820
+ mcp_auth_code_obj = AuthorizationCode( # Renamed for clarity
769
821
  code=mcp_auth_code_str,
770
822
  client_id=mcp_client.client_id,
771
- redirect_uri=original_mcp_redirect_uri,
772
- scopes=final_scopes_for_mcp_auth_code,
773
- expires_at=int(datetime.now().timestamp() + 600), # 10 minutes, cast to int
823
+ redirect_uri=original_mcp_redirect_uri,
824
+ scopes=final_scopes_for_mcp_auth_code,
825
+ expires_at=int(
826
+ datetime.now().timestamp() + 600
827
+ ), # 10 minutes, cast to int
774
828
  redirect_uri_provided_explicitly=original_redirect_uri_provided_explicitly,
775
829
  code_challenge=original_code_challenge,
776
- code_challenge_method=original_code_challenge_method
830
+ code_challenge_method=original_code_challenge_method,
777
831
  )
778
-
832
+
779
833
  # Store our auth code object (without provider_token as an attribute)
780
834
  provider.storage.store_auth_code(mcp_auth_code_str, mcp_auth_code_obj)
781
-
835
+
782
836
  query_params_for_mcp_client = {
783
- "code": mcp_auth_code_str # Send our generated auth code to the MCP client
837
+ "code": mcp_auth_code_str # Send our generated auth code to the MCP client
784
838
  }
785
839
  if mcp_client_original_state_to_pass_back:
786
- query_params_for_mcp_client["state"] = mcp_client_original_state_to_pass_back
787
-
788
- import urllib.parse # Ensure it's imported here too
789
- final_query_for_mcp_client = urllib.parse.urlencode(query_params_for_mcp_client)
790
- final_redirect_to_mcp_client = f"{original_mcp_redirect_uri}?{final_query_for_mcp_client}"
791
-
840
+ query_params_for_mcp_client["state"] = (
841
+ mcp_client_original_state_to_pass_back
842
+ )
843
+
844
+ import urllib.parse # Ensure it's imported here too
845
+
846
+ final_query_for_mcp_client = urllib.parse.urlencode(
847
+ query_params_for_mcp_client
848
+ )
849
+ final_redirect_to_mcp_client = (
850
+ f"{original_mcp_redirect_uri}?{final_query_for_mcp_client}"
851
+ )
852
+
792
853
  return RedirectResponse(final_redirect_to_mcp_client)
793
-
794
- except Exception as e:
854
+
855
+ except Exception:
795
856
  # Avoid sending raw exception details to the client for security
796
- return RedirectResponse("/auth-error?error=callback_processing_failed&detail=internal_server_error")
797
-
798
- return handle_callback
857
+ return RedirectResponse(
858
+ "/auth-error?error=callback_processing_failed&detail=internal_server_error"
859
+ )
860
+
861
+ return handle_callback