golf-mcp 0.1.20__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of golf-mcp might be problematic. Click here for more details.

Files changed (123)
  1. golf/__init__.py +9 -1
  2. golf/_endpoints.py +6 -0
  3. golf/_endpoints_fallback.py +10 -0
  4. golf/auth/__init__.py +235 -83
  5. golf/auth/api_key.py +6 -14
  6. golf/auth/factory.py +358 -0
  7. golf/auth/helpers.py +12 -42
  8. golf/auth/providers.py +446 -0
  9. golf/auth/registry.py +256 -0
  10. golf/cli/branding.py +192 -0
  11. golf/cli/main.py +28 -69
  12. golf/commands/__init__.py +2 -0
  13. golf/commands/build.py +4 -7
  14. golf/commands/init.py +30 -53
  15. golf/commands/run.py +50 -20
  16. golf/core/builder.py +355 -414
  17. golf/core/builder_auth.py +63 -144
  18. golf/core/builder_telemetry.py +26 -3
  19. golf/core/config.py +38 -59
  20. golf/core/parser.py +132 -139
  21. golf/core/platform.py +12 -10
  22. golf/core/telemetry.py +11 -19
  23. golf/core/transformer.py +38 -15
  24. golf/examples/__pycache__/__init__.cpython-311.pyc +0 -0
  25. golf/examples/basic/.coverage +0 -0
  26. golf/examples/basic/.env.example +8 -4
  27. golf/examples/basic/README.md +117 -45
  28. golf/examples/basic/__pycache__/auth.cpython-311.pyc +0 -0
  29. golf/examples/basic/auth.py +76 -0
  30. golf/examples/basic/golf.json +2 -5
  31. golf/examples/basic/htmlcov/.gitignore +2 -0
  32. golf/examples/basic/htmlcov/class_index.html +547 -0
  33. golf/examples/basic/htmlcov/coverage_html_cb_6fb7b396.js +733 -0
  34. golf/examples/basic/htmlcov/favicon_32_cb_58284776.png +0 -0
  35. golf/examples/basic/htmlcov/function_index.html +2091 -0
  36. golf/examples/basic/htmlcov/index.html +349 -0
  37. golf/examples/basic/htmlcov/keybd_closed_cb_ce680311.png +0 -0
  38. golf/examples/basic/htmlcov/status.json +1 -0
  39. golf/examples/basic/htmlcov/style_cb_8e611ae1.css +337 -0
  40. golf/examples/basic/htmlcov/z_1c9a91c0e91c8496___init___py.html +323 -0
  41. golf/examples/basic/htmlcov/z_1c9a91c0e91c8496_api_key_py.html +170 -0
  42. golf/examples/basic/htmlcov/z_1c9a91c0e91c8496_factory_py.html +430 -0
  43. golf/examples/basic/htmlcov/z_1c9a91c0e91c8496_helpers_py.html +288 -0
  44. golf/examples/basic/htmlcov/z_1c9a91c0e91c8496_providers_py.html +493 -0
  45. golf/examples/basic/htmlcov/z_1c9a91c0e91c8496_registry_py.html +353 -0
  46. golf/examples/basic/htmlcov/z_3ec3b3f490dc0950___init___py.html +120 -0
  47. golf/examples/basic/htmlcov/z_3ec3b3f490dc0950_instrumentation_py.html +1535 -0
  48. golf/examples/basic/htmlcov/z_4b8b9dd4ccccc5db___init___py.html +98 -0
  49. golf/examples/basic/htmlcov/z_4b8b9dd4ccccc5db_branding_py.html +289 -0
  50. golf/examples/basic/htmlcov/z_4b8b9dd4ccccc5db_main_py.html +476 -0
  51. golf/examples/basic/htmlcov/z_5a6c4e6bcc86fb2f___init___py.html +97 -0
  52. golf/examples/basic/htmlcov/z_6cadab9ec0df475d___init___py.html +102 -0
  53. golf/examples/basic/htmlcov/z_6cadab9ec0df475d_build_py.html +178 -0
  54. golf/examples/basic/htmlcov/z_6cadab9ec0df475d_init_py.html +387 -0
  55. golf/examples/basic/htmlcov/z_6cadab9ec0df475d_run_py.html +222 -0
  56. golf/examples/basic/htmlcov/z_6fcdee0582ba84e4___init___py.html +106 -0
  57. golf/examples/basic/htmlcov/z_6fcdee0582ba84e4__endpoints_fallback_py.html +107 -0
  58. golf/examples/basic/htmlcov/z_7ba499ed22986217___init___py.html +98 -0
  59. golf/examples/basic/htmlcov/z_7ba499ed22986217_builder_auth_py.html +306 -0
  60. golf/examples/basic/htmlcov/z_7ba499ed22986217_builder_metrics_py.html +329 -0
  61. golf/examples/basic/htmlcov/z_7ba499ed22986217_builder_py.html +1471 -0
  62. golf/examples/basic/htmlcov/z_7ba499ed22986217_builder_telemetry_py.html +186 -0
  63. golf/examples/basic/htmlcov/z_7ba499ed22986217_config_py.html +315 -0
  64. golf/examples/basic/htmlcov/z_7ba499ed22986217_parser_py.html +1149 -0
  65. golf/examples/basic/htmlcov/z_7ba499ed22986217_platform_py.html +279 -0
  66. golf/examples/basic/htmlcov/z_7ba499ed22986217_telemetry_py.html +589 -0
  67. golf/examples/basic/htmlcov/z_7ba499ed22986217_transformer_py.html +286 -0
  68. golf/examples/basic/htmlcov/z_7d7da37693a43688___init___py.html +107 -0
  69. golf/examples/basic/htmlcov/z_7d7da37693a43688_collector_py.html +417 -0
  70. golf/examples/basic/htmlcov/z_7d7da37693a43688_registry_py.html +109 -0
  71. golf/examples/basic/htmlcov/z_abe733142b40ad4e___init___py.html +109 -0
  72. golf/examples/basic/htmlcov/z_abe733142b40ad4e_context_py.html +150 -0
  73. golf/examples/basic/htmlcov/z_abe733142b40ad4e_elicitation_py.html +267 -0
  74. golf/examples/basic/htmlcov/z_abe733142b40ad4e_sampling_py.html +318 -0
  75. golf/examples/basic/prompts/__pycache__/welcome.cpython-311.pyc +0 -0
  76. golf/examples/basic/prompts/welcome.py +3 -5
  77. golf/examples/basic/resources/__pycache__/current_time.cpython-311.pyc +0 -0
  78. golf/examples/basic/resources/__pycache__/info.cpython-311.pyc +0 -0
  79. golf/examples/basic/resources/current_time.py +5 -13
  80. golf/examples/basic/resources/weather/__pycache__/common.cpython-311.pyc +0 -0
  81. golf/examples/basic/resources/weather/__pycache__/current.cpython-311.pyc +0 -0
  82. golf/examples/basic/resources/weather/__pycache__/forecast.cpython-311.pyc +0 -0
  83. golf/examples/basic/resources/weather/city.py +46 -0
  84. golf/examples/basic/resources/weather/common.py +4 -11
  85. golf/examples/basic/resources/weather/current.py +5 -5
  86. golf/examples/basic/resources/weather/forecast.py +5 -5
  87. golf/examples/basic/tools/__pycache__/calculator.cpython-311.pyc +0 -0
  88. golf/examples/basic/tools/calculator.py +94 -0
  89. golf/examples/basic/tools/say/__pycache__/hello.cpython-311.pyc +0 -0
  90. golf/examples/basic/tools/say/hello.py +65 -0
  91. golf/metrics/collector.py +100 -19
  92. golf/telemetry/__init__.py +4 -0
  93. golf/telemetry/instrumentation.py +484 -178
  94. golf/utilities/__init__.py +12 -0
  95. golf/utilities/context.py +53 -0
  96. golf/utilities/elicitation.py +170 -0
  97. golf/utilities/sampling.py +221 -0
  98. {golf_mcp-0.1.20.dist-info → golf_mcp-0.2.1.dist-info}/METADATA +51 -104
  99. golf_mcp-0.2.1.dist-info/RECORD +110 -0
  100. golf/auth/oauth.py +0 -861
  101. golf/auth/provider.py +0 -115
  102. golf/examples/api_key/.env +0 -2
  103. golf/examples/api_key/.env.example +0 -1
  104. golf/examples/api_key/README.md +0 -84
  105. golf/examples/api_key/golf.json +0 -8
  106. golf/examples/api_key/pre_build.py +0 -11
  107. golf/examples/api_key/tools/issues/create.py +0 -93
  108. golf/examples/api_key/tools/issues/list.py +0 -92
  109. golf/examples/api_key/tools/repos/list.py +0 -111
  110. golf/examples/api_key/tools/search/code.py +0 -106
  111. golf/examples/api_key/tools/users/get.py +0 -82
  112. golf/examples/basic/.env +0 -5
  113. golf/examples/basic/pre_build.py +0 -28
  114. golf/examples/basic/tools/github_user.py +0 -65
  115. golf/examples/basic/tools/hello.py +0 -34
  116. golf/examples/basic/tools/payments/charge.py +0 -70
  117. golf/examples/basic/tools/payments/common.py +0 -36
  118. golf/examples/basic/tools/payments/refund.py +0 -61
  119. golf_mcp-0.1.20.dist-info/RECORD +0 -60
  120. {golf_mcp-0.1.20.dist-info → golf_mcp-0.2.1.dist-info}/WHEEL +0 -0
  121. {golf_mcp-0.1.20.dist-info → golf_mcp-0.2.1.dist-info}/entry_points.txt +0 -0
  122. {golf_mcp-0.1.20.dist-info → golf_mcp-0.2.1.dist-info}/licenses/LICENSE +0 -0
  123. {golf_mcp-0.1.20.dist-info → golf_mcp-0.2.1.dist-info}/top_level.txt +0 -0
golf/auth/oauth.py DELETED
@@ -1,861 +0,0 @@
1
- """OAuth provider implementation for GolfMCP.
2
-
3
- This module provides an implementation of the MCP OAuthAuthorizationServerProvider
4
- interface for GolfMCP servers. It handles the OAuth 2.0 authentication flow,
5
- token management, and client registration.
6
- """
7
-
8
import os
import secrets
import time
import urllib.parse
import uuid
from datetime import datetime
from typing import Any

import httpx
import jwt
from mcp.server.auth.provider import (
    AccessToken,
    AuthorizationCode,
    AuthorizationParams,
    OAuthAuthorizationServerProvider,
    RefreshToken,
    RegistrationError,
)
from mcp.shared.auth import (
    OAuthClientInformationFull,
    OAuthToken,
)
from starlette.responses import RedirectResponse

from .provider import ProviderConfig
31
-
32
-
33
class TokenStorage:
    """Simple in-memory token storage.

    Holds OAuth authorization codes, access and refresh tokens, registered
    client information, and the mappings from this server's tokens to the
    upstream provider's tokens. Everything lives in plain dicts on the
    instance, so all state is lost when the instance goes away; a production
    deployment should replace this with persistent storage.
    """

    def __init__(self) -> None:
        """Initialize empty storage."""
        self.auth_codes: dict[str, AuthorizationCode] = {}
        self.refresh_tokens: dict[str, RefreshToken] = {}
        self.access_tokens: dict[str, AccessToken] = {}
        self.clients: dict[str, OAuthClientInformationFull] = {}
        # MCP token string -> upstream provider access token string
        self.provider_tokens: dict[str, str] = {}
        # Authorization code string -> upstream provider access token string
        self.auth_code_to_provider_token: dict[str, str] = {}

    def store_auth_code(self, code: str, auth_code_obj: AuthorizationCode) -> None:
        """Store an authorization code.

        Args:
            code: The authorization code string
            auth_code_obj: The authorization code object
        """
        self.auth_codes[code] = auth_code_obj

    def get_auth_code(self, code: str) -> AuthorizationCode | None:
        """Get an authorization code by value.

        Args:
            code: The authorization code string

        Returns:
            The authorization code object or None if not found
        """
        return self.auth_codes.get(code)

    def delete_auth_code(self, code: str) -> None:
        """Delete an authorization code and its associated provider token mapping.

        Missing entries are ignored, so this is safe to call more than once.

        Args:
            code: The authorization code string
        """
        self.auth_codes.pop(code, None)
        self.auth_code_to_provider_token.pop(code, None)

    def store_auth_code_provider_token_mapping(
        self, auth_code_str: str, provider_token: str
    ) -> None:
        """Store a mapping from an auth_code string to a provider_token string."""
        self.auth_code_to_provider_token[auth_code_str] = provider_token

    def get_provider_token_for_auth_code(self, auth_code_str: str) -> str | None:
        """Retrieve a provider_token string using an auth_code string."""
        return self.auth_code_to_provider_token.get(auth_code_str)

    def store_client(self, client_id: str, client: OAuthClientInformationFull) -> None:
        """Store client information.

        Args:
            client_id: The client ID
            client: The client information
        """
        self.clients[client_id] = client

    def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Get client information by ID.

        Args:
            client_id: The client ID

        Returns:
            The client information or None if not found
        """
        return self.clients.get(client_id)

    def store_refresh_token(self, token: str, refresh_token: RefreshToken) -> None:
        """Store a refresh token.

        Args:
            token: The refresh token string
            refresh_token: The refresh token object
        """
        self.refresh_tokens[token] = refresh_token

    def get_refresh_token(self, token: str) -> RefreshToken | None:
        """Get a refresh token by value.

        Args:
            token: The refresh token string

        Returns:
            The refresh token object or None if not found
        """
        return self.refresh_tokens.get(token)

    def delete_refresh_token(self, token: str) -> None:
        """Delete a refresh token; unknown tokens are ignored.

        Args:
            token: The refresh token string
        """
        self.refresh_tokens.pop(token, None)

    def store_access_token(self, token: str, access_token: AccessToken) -> None:
        """Store an access token.

        Args:
            token: The access token string
            access_token: The access token object
        """
        self.access_tokens[token] = access_token

    def get_access_token(self, token: str) -> AccessToken | None:
        """Get an access token by value.

        Args:
            token: The access token string

        Returns:
            The access token object or None if not found
        """
        return self.access_tokens.get(token)

    def delete_access_token(self, token: str) -> None:
        """Delete an access token; unknown tokens are ignored.

        Note that this does not remove any entry in ``provider_tokens``.

        Args:
            token: The access token string
        """
        self.access_tokens.pop(token, None)

    def store_provider_token(self, mcp_token: str, provider_token: str) -> None:
        """Store a provider token mapping.

        Args:
            mcp_token: The MCP token string
            provider_token: The provider token string (e.g., GitHub token)
        """
        self.provider_tokens[mcp_token] = provider_token

    def get_provider_token(self, mcp_token: str) -> str | None:
        """Get the provider token associated with an MCP token.

        This is a non-standard method to allow access to the provider token
        (e.g., GitHub token) for a given MCP token. This can be used by
        tools that need to access provider APIs.

        Args:
            mcp_token: The MCP token

        Returns:
            The provider token or None if not found
        """
        return self.provider_tokens.get(mcp_token)
196
-
197
-
198
class GolfOAuthProvider(OAuthAuthorizationServerProvider):
    """OAuth provider implementation for GolfMCP.

    Implements the MCP ``OAuthAuthorizationServerProvider`` interface. User
    authentication is delegated to an upstream identity provider (IdP, e.g.
    GitHub) described by ``ProviderConfig``. After the IdP round-trip, this
    server issues its own HS256-signed JWT access tokens and opaque refresh
    tokens to MCP clients. It also records which upstream (provider) token
    belongs to each of them, so tools can call the IdP's APIs.

    All bookkeeping (``TokenStorage``, ``state_mapping`` and the access-token
    revocation list) is held in memory on this instance. Access tokens are
    validated by signature, so they remain accepted after a restart as long
    as the JWT secret is unchanged. Revocations, however, do not survive a
    restart.
    """

    def __init__(self, config: ProviderConfig) -> None:
        """Initialize the provider.

        Args:
            config: The provider configuration
        """
        self.config = config
        self.storage = TokenStorage()
        # IdP ``state`` value -> details of the original MCP authorization
        # request; consumed by the OAuth callback handler.
        # NOTE(review): entries for abandoned flows are never pruned.
        self.state_mapping: dict[str, dict[str, Any]] = {}
        # Revoked access token string -> its ``expires_at``. Required because
        # ``load_access_token`` validates JWTs by signature alone and would
        # otherwise keep accepting revoked tokens until they expire.
        self._revoked_access_tokens: dict[str, int | None] = {}

        # Register default client
        self._register_default_client()

    @staticmethod
    def _resolve_setting(value: str | None, env_var: str | None, fallback: str) -> str:
        """Return ``value``, else the non-empty env var ``env_var``, else ``fallback``."""
        if value:
            return value
        if env_var:
            from_env = os.environ.get(env_var)
            if from_env:
                return from_env
        return fallback

    def _get_client_id(self) -> str:
        """Get the upstream client ID from config or environment.

        Falls back to the placeholder ``"missing-client-id"`` rather than
        raising, so a misconfiguration surfaces as an IdP error.
        """
        return self._resolve_setting(
            self.config.client_id, self.config.client_id_env_var, "missing-client-id"
        )

    def _get_client_secret(self) -> str:
        """Get the upstream client secret from config or environment.

        Falls back to the placeholder ``"missing-client-secret"`` rather
        than raising.
        """
        return self._resolve_setting(
            self.config.client_secret,
            self.config.client_secret_env_var,
            "missing-client-secret",
        )

    def _get_jwt_secret(self) -> str:
        """Get the JWT secret from config. It's expected to be resolved by server startup.

        Raises:
            ValueError: If no JWT secret is configured.
        """
        if not self.config.jwt_secret:
            raise ValueError(
                "JWT Secret is not configured in the provider. Check server logs and environment variables."
            )
        return self.config.jwt_secret

    def _register_default_client(self) -> None:
        """Register the public ``"default"`` MCP client."""
        # These are the URIs where *this server* is allowed to redirect an MCP
        # client after successful authentication and MCP auth code generation.
        issuer_base = (
            self.config.issuer_url.rstrip("/")
            if self.config.issuer_url
            else "http://localhost:3000"
        )
        client_redirect_uris = [
            # Common redirect URIs for MCP Inspector running locally
            "http://localhost:5173/callback",
            "http://127.0.0.1:5173/callback",
            # A generic callback relative to the server's issuer URL; assumes
            # such a client-side endpoint exists.
            f"{issuer_base}/client/callback",
        ]

        default_client = OAuthClientInformationFull(
            client_id="default",
            client_name="Default MCP Client",
            client_secret="",  # Public client
            redirect_uris=client_redirect_uris,
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            token_endpoint_auth_method="none",  # Public client
            scope=" ".join(self.config.scopes),
        )
        self.storage.store_client("default", default_client)

    def _generate_jwt(
        self, subject: str, scopes: list[str], expires_in: int | None = None
    ) -> str:
        """Generate an HS256-signed JWT access token.

        Args:
            subject: The subject of the token (usually client_id)
            scopes: The scopes granted to the token (stored in ``scp``)
            expires_in: The token lifetime in seconds; ``None`` uses
                ``config.token_expiration``

        Returns:
            The signed JWT token
        """
        now = int(time.time())
        lifetime = self.config.token_expiration if expires_in is None else expires_in
        payload = {
            "iss": self.config.issuer_url or "golf:auth",
            "sub": subject,
            "iat": now,
            "exp": now + lifetime,
            "scp": scopes,
        }
        return jwt.encode(payload, self._get_jwt_secret(), algorithm="HS256")

    def _verify_jwt(self, token: str) -> dict[str, Any] | None:
        """Verify a JWT's signature and expiry.

        Returns:
            The decoded payload, or None if the token is invalid, expired,
            or has no ``exp`` claim.

        Raises:
            ValueError: If no JWT secret is configured (from ``_get_jwt_secret``).
        """
        jwt_secret = self._get_jwt_secret()
        try:
            # PyJWT checks the signature and ``exp``; ``require`` rejects
            # tokens that carry no expiry at all.
            return jwt.decode(
                token,
                jwt_secret,
                algorithms=["HS256"],
                options={"require": ["exp"]},
            )
        except jwt.PyJWTError:
            return None
        except Exception:
            # Fail closed: any unexpected decoding error rejects the token.
            return None

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Get client information by ID.

        Args:
            client_id: The client ID

        Returns:
            The client information or None if not found
        """
        return self.storage.get_client(client_id)

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """Register a new client.

        Raises:
            RegistrationError: If the client ID or all redirect URIs are missing.
        """
        if not client_info.client_id:
            raise RegistrationError(
                error="invalid_client_metadata",
                error_description="Client ID is missing in client_info provided to register_client",
            )
        if not client_info.redirect_uris:
            raise RegistrationError(
                error="invalid_redirect_uri",
                error_description="At least one redirect URI is required",
            )
        self.storage.store_client(client_info.client_id, client_info)

    async def authorize(
        self,
        client: OAuthClientInformationFull,
        params: AuthorizationParams,  # params from MCP client
    ) -> str:
        """Handle an authorization request.

        Records the MCP client's request under a fresh random ``state``, then
        returns the upstream IdP authorization URL to redirect the user to.
        """
        idp_flow_state = secrets.token_hex(16)
        self.state_mapping[idp_flow_state] = {
            "client_id": client.client_id,
            "redirect_uri": str(params.redirect_uri),
            "code_challenge": params.code_challenge,
            "code_challenge_method": "S256" if params.code_challenge else None,
            "scopes": params.scopes,
            "redirect_uri_provided_explicitly": params.redirect_uri_provided_explicitly,
            "mcp_client_original_state": params.state,
        }

        # The MCP client's PKCE challenge is kept above for the MCP-side
        # authorization code and is deliberately NOT forwarded to the IdP.
        # Only the MCP client holds the matching verifier, and the callback's
        # upstream code exchange has no way to supply it. An IdP that honours
        # PKCE would therefore reject that exchange.
        idp_callback_uri = (
            f"{self.config.issuer_url.rstrip('/')}{self.config.callback_path}"
        )
        query_for_idp = urllib.parse.urlencode(
            {
                "client_id": self._get_client_id(),
                "redirect_uri": idp_callback_uri,
                "scope": " ".join(self.config.scopes),
                "state": idp_flow_state,
                "response_type": "code",
            }
        )
        return f"{self.config.authorize_url}?{query_for_idp}"

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, code: str
    ) -> AuthorizationCode | None:
        """Load an authorization code issued to ``client``.

        Expired codes are deleted and reported as missing.

        Args:
            client: The client information
            code: The authorization code

        Returns:
            The authorization code object, or None if unknown, owned by
            another client, or expired
        """
        auth_code = self.storage.get_auth_code(code)
        if auth_code is None or auth_code.client_id != client.client_id:
            return None
        if auth_code.expires_at and auth_code.expires_at < time.time():
            self.storage.delete_auth_code(code)
            return None
        return auth_code

    async def exchange_authorization_code(
        self,
        client: OAuthClientInformationFull,
        code: AuthorizationCode,
    ) -> OAuthToken:
        """Exchange an authorization code for tokens.

        The code is consumed (one-time use). A refresh token is issued only
        if the client's ``grant_types`` include ``"refresh_token"``.

        Args:
            client: The client information
            code: The authorization code object

        Returns:
            The OAuth token response
        """
        # Read the upstream token first: deleting the code also drops the
        # code -> provider-token mapping.
        provider_token = self.storage.get_provider_token_for_auth_code(code.code)
        self.storage.delete_auth_code(code.code)
        return self._issue_tokens(
            client,
            code.scopes,
            provider_token,
            with_refresh_token="refresh_token" in client.grant_types,
        )

    async def load_refresh_token(
        self, client: OAuthClientInformationFull, refresh_token: str
    ) -> RefreshToken | None:
        """Load a refresh token issued to ``client``.

        Expired tokens are deleted and reported as missing.

        Args:
            client: The client information
            refresh_token: The refresh token string

        Returns:
            The refresh token object, or None if unknown, owned by another
            client, or expired
        """
        token = self.storage.get_refresh_token(refresh_token)
        if token is None or token.client_id != client.client_id:
            return None
        if token.expires_at and token.expires_at < time.time():
            self.storage.delete_refresh_token(refresh_token)
            return None
        return token

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        """Exchange a refresh token for a new token pair (with rotation).

        Args:
            client: The client information
            refresh_token: The refresh token object
            scopes: The requested scopes (may be a subset of original)

        Returns:
            The new OAuth token response
        """
        # Rotate: the old refresh token is single-use.
        self.storage.delete_refresh_token(refresh_token.token)
        # The upstream token travels with the refresh token it was issued
        # alongside. Looking it up by client_id instead (as before) could hand
        # one user's upstream token to another user of the shared "default"
        # client.
        provider_token = self.storage.provider_tokens.pop(refresh_token.token, None)

        # Narrow to the requested scopes that were originally granted; if none
        # of them were, keep the original grant.
        valid_scopes = [s for s in scopes if s in refresh_token.scopes] if scopes else []
        if not valid_scopes:
            valid_scopes = refresh_token.scopes

        return self._issue_tokens(
            client, valid_scopes, provider_token, with_refresh_token=True
        )

    def _issue_tokens(
        self,
        client: OAuthClientInformationFull,
        scopes: list[str],
        provider_token: str | None,
        *,
        with_refresh_token: bool,
    ) -> OAuthToken:
        """Mint, record and return a new access token and optional refresh token."""
        now = time.time()
        lifetime = self.config.token_expiration

        access_token = self._generate_jwt(subject=client.client_id, scopes=scopes)
        self.storage.store_access_token(
            access_token,
            AccessToken(
                token=access_token,
                client_id=client.client_id,
                scopes=scopes,
                expires_at=int(now + lifetime),
            ),
        )

        refresh_token: str | None = None
        if with_refresh_token:
            refresh_token = str(uuid.uuid4())
            self.storage.store_refresh_token(
                refresh_token,
                RefreshToken(
                    token=refresh_token,
                    client_id=client.client_id,
                    scopes=scopes,
                    # Refresh tokens live 24x longer than access tokens.
                    expires_at=int(now + lifetime * 24),
                ),
            )

        if provider_token:
            self.storage.store_provider_token(access_token, provider_token)
            if refresh_token:
                self.storage.store_provider_token(refresh_token, provider_token)

        return OAuthToken(
            access_token=access_token,
            token_type="bearer",
            expires_in=lifetime,
            refresh_token=refresh_token,
            scope=" ".join(scopes),
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Load and validate an access token.

        Returns:
            The access token built from the JWT claims, or None if the token
            was revoked, fails verification, or lacks a ``sub`` claim.
        """
        if token in self._revoked_access_tokens:
            return None

        payload = self._verify_jwt(token)
        if not payload:
            return None

        client_id = payload.get("sub")
        if not client_id:
            return None

        expires_at = payload.get("exp")
        return AccessToken(
            token=token,
            client_id=client_id,
            scopes=payload.get("scp", []),
            expires_at=int(expires_at) if expires_at is not None else None,
        )

    async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
        """Revoke a token.

        Args:
            token: The token to revoke (access or refresh)
        """
        if isinstance(token, AccessToken):
            now = time.time()
            # Drop revocations for tokens that have expired anyway (signature
            # verification rejects them), so the list stays small.
            self._revoked_access_tokens = {
                value: exp
                for value, exp in self._revoked_access_tokens.items()
                if exp is None or exp > now
            }
            self._revoked_access_tokens[token.token] = token.expires_at

        self.storage.delete_access_token(token.token)
        self.storage.delete_refresh_token(token.token)
        # Clean up the provider token mapping if it exists.
        self.storage.provider_tokens.pop(token.token, None)

    def get_provider_token(self, mcp_token: str) -> str | None:
        """Get the provider token associated with an MCP token.

        This is a non-standard method to allow access to the provider token
        (e.g., GitHub token) for a given MCP token. This can be used by
        tools that need to access provider APIs.

        Args:
            mcp_token: The MCP token

        Returns:
            The provider token or None if not found
        """
        return self.storage.get_provider_token(mcp_token)
691
-
692
-
693
def create_callback_handler(provider: GolfOAuthProvider):
    """Create the handler for the upstream identity provider's OAuth callback.

    The returned coroutine completes the proxied authorization flow:

    1. It exchanges the authorization code sent by the identity provider
       (IdP, e.g. GitHub) for a provider access token.
    2. It mints this server's own short-lived authorization code for the
       MCP client that started the flow.
    3. It redirects the browser back to that client's ``redirect_uri``.

    Args:
        provider: The OAuth provider. It supplies the IdP configuration and
            client credentials, the pending-``state`` mapping recorded at the
            authorize step, and the code/token storage.

    Returns:
        An async function ``handle_callback(request)``. It returns a
        ``RedirectResponse`` to the MCP client on success, or to
        ``/auth-error`` with an ``error`` query parameter on failure.
    """

    async def handle_callback(request):
        """Handle the IdP redirect carrying ``code`` and ``state``.

        Args:
            request: The incoming HTTP request; only ``query_params`` is read.

        Returns:
            A ``RedirectResponse`` to the MCP client or to ``/auth-error``.
        """
        # These imports sit at the top on purpose. An ``import`` anywhere in
        # a function makes the name local to the *whole* function. A late
        # ``import urllib.parse`` therefore made the earlier
        # ``urllib.parse.quote`` call (token-exchange failure path) raise
        # UnboundLocalError.
        import logging
        import urllib.parse

        # Code and state issued by the IdP.
        idp_auth_code = request.query_params.get("code")
        idp_state = request.query_params.get("state")

        if not idp_auth_code:
            return RedirectResponse("/auth-error?error=no_code_from_idp")

        # Validate (and consume) ``state`` before spending the IdP code.
        # This rejects callbacks for a state we never issued (CSRF
        # protection), and makes each state single-use.
        original_mcp_client_details = (
            provider.state_mapping.pop(idp_state, None) if idp_state else None
        )
        if not original_mcp_client_details:
            return RedirectResponse("/auth-error?error=invalid_idp_state")

        # Must equal the redirect_uri registered with the IdP and sent in the
        # /authorize step, otherwise the IdP rejects the code exchange.
        idp_callback_uri_for_token_exchange = (
            f"{provider.config.issuer_url.rstrip('/')}{provider.config.callback_path}"
        )

        client_id_for_idp = provider._get_client_id()
        client_secret_for_idp = provider._get_client_secret()

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    provider.config.token_url,
                    headers={"Accept": "application/json"},
                    data={
                        "client_id": client_id_for_idp,
                        "client_secret": client_secret_for_idp,
                        "code": idp_auth_code,
                        "redirect_uri": idp_callback_uri_for_token_exchange,
                    },
                )
        except httpx.HTTPError:
            # Network/transport failure while talking to the IdP.
            return RedirectResponse(
                "/auth-error?error=idp_token_exchange_failed&detail=request_error"
            )

        if response.status_code != 200:
            error_detail = response.text[:200]  # Limit error detail length
            return RedirectResponse(
                f"/auth-error?error=idp_token_exchange_failed&detail={urllib.parse.quote(error_detail)}"
            )

        # Treat a malformed body (not JSON, or not a JSON object) like a
        # missing token, instead of letting it escape as an unhandled error.
        try:
            token_data = response.json()
        except ValueError:
            token_data = None
        provider_access_token = (
            token_data.get("access_token") if isinstance(token_data, dict) else None
        )  # The token from GitHub/Google etc.

        if not provider_access_token:
            return RedirectResponse("/auth-error?error=no_access_token_from_idp")

        try:
            original_mcp_client_id = original_mcp_client_details["client_id"]
            original_mcp_redirect_uri = original_mcp_client_details["redirect_uri"]
            original_code_challenge = original_mcp_client_details["code_challenge"]
            original_code_challenge_method = original_mcp_client_details[
                "code_challenge_method"
            ]
            requested_scopes_for_mcp_server_str = original_mcp_client_details["scopes"]
            mcp_client_original_state_to_pass_back = original_mcp_client_details.get(
                "mcp_client_original_state"
            )
            original_redirect_uri_provided_explicitly = original_mcp_client_details[
                "redirect_uri_provided_explicitly"
            ]

            mcp_client = await provider.get_client(original_mcp_client_id)
            if not mcp_client:
                return RedirectResponse(
                    "/auth-error?error=mcp_client_not_found_post_callback"
                )

            # Use the scopes the MCP client requested. If it requested none,
            # fall back to the client's registered scopes.
            final_scopes_for_mcp_auth_code: list[str]
            if requested_scopes_for_mcp_server_str:
                final_scopes_for_mcp_auth_code = (
                    requested_scopes_for_mcp_server_str.split()
                )
            else:
                final_scopes_for_mcp_auth_code = (
                    mcp_client.scope.split() if mcp_client.scope else []
                )

            # The auth code this server issues to the MCP client.
            mcp_auth_code_str = str(uuid.uuid4())

            # Remember which provider token belongs to this code. It is looked
            # up when the MCP client exchanges the code for an MCP access token.
            provider.storage.store_auth_code_provider_token_mapping(
                mcp_auth_code_str, provider_access_token
            )

            mcp_auth_code_obj = AuthorizationCode(
                code=mcp_auth_code_str,
                client_id=mcp_client.client_id,
                redirect_uri=original_mcp_redirect_uri,
                scopes=final_scopes_for_mcp_auth_code,
                expires_at=int(datetime.now().timestamp() + 600),  # 10 minutes
                redirect_uri_provided_explicitly=original_redirect_uri_provided_explicitly,
                code_challenge=original_code_challenge,
                code_challenge_method=original_code_challenge_method,
            )

            # Stored without the provider token; that lives in the mapping above.
            provider.storage.store_auth_code(mcp_auth_code_str, mcp_auth_code_obj)

            query_params_for_mcp_client = {"code": mcp_auth_code_str}
            if mcp_client_original_state_to_pass_back:
                query_params_for_mcp_client["state"] = (
                    mcp_client_original_state_to_pass_back
                )

            # Merge our params into the client's redirect URI. Any query it
            # already carries must be retained (RFC 6749 §3.1.2), and a
            # fragment, if present, must stay after the query.
            redirect_parts = urllib.parse.urlsplit(str(original_mcp_redirect_uri))
            final_query_for_mcp_client = urllib.parse.urlencode(
                query_params_for_mcp_client
            )
            if redirect_parts.query:
                final_query_for_mcp_client = (
                    f"{redirect_parts.query}&{final_query_for_mcp_client}"
                )
            final_redirect_to_mcp_client = urllib.parse.urlunsplit(
                redirect_parts._replace(query=final_query_for_mcp_client)
            )

            return RedirectResponse(final_redirect_to_mcp_client)

        except Exception:
            # Log server-side, but never send raw exception details to the client.
            logging.getLogger(__name__).exception("OAuth callback processing failed")
            return RedirectResponse(
                "/auth-error?error=callback_processing_failed&detail=internal_server_error"
            )

    return handle_callback