golf-mcp 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of golf-mcp might be problematic. Click here for more details.
- golf/__init__.py +1 -1
- golf/auth/__init__.py +38 -26
- golf/auth/api_key.py +16 -23
- golf/auth/helpers.py +68 -54
- golf/auth/oauth.py +340 -277
- golf/auth/provider.py +58 -53
- golf/cli/__init__.py +1 -1
- golf/cli/main.py +209 -87
- golf/commands/__init__.py +1 -1
- golf/commands/build.py +31 -25
- golf/commands/init.py +81 -53
- golf/commands/run.py +30 -15
- golf/core/__init__.py +1 -1
- golf/core/builder.py +493 -362
- golf/core/builder_auth.py +115 -107
- golf/core/builder_telemetry.py +12 -9
- golf/core/config.py +62 -46
- golf/core/parser.py +174 -136
- golf/core/telemetry.py +216 -95
- golf/core/transformer.py +53 -55
- golf/examples/__init__.py +0 -1
- golf/examples/api_key/pre_build.py +2 -2
- golf/examples/api_key/tools/issues/create.py +35 -36
- golf/examples/api_key/tools/issues/list.py +42 -37
- golf/examples/api_key/tools/repos/list.py +50 -29
- golf/examples/api_key/tools/search/code.py +50 -37
- golf/examples/api_key/tools/users/get.py +21 -20
- golf/examples/basic/pre_build.py +4 -4
- golf/examples/basic/prompts/welcome.py +6 -7
- golf/examples/basic/resources/current_time.py +10 -9
- golf/examples/basic/resources/info.py +6 -5
- golf/examples/basic/resources/weather/common.py +16 -10
- golf/examples/basic/resources/weather/current.py +15 -11
- golf/examples/basic/resources/weather/forecast.py +15 -11
- golf/examples/basic/tools/github_user.py +19 -21
- golf/examples/basic/tools/hello.py +10 -6
- golf/examples/basic/tools/payments/charge.py +34 -25
- golf/examples/basic/tools/payments/common.py +8 -6
- golf/examples/basic/tools/payments/refund.py +29 -25
- golf/telemetry/__init__.py +6 -6
- golf/telemetry/instrumentation.py +455 -310
- {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/METADATA +1 -1
- golf_mcp-0.1.13.dist-info/RECORD +55 -0
- golf_mcp-0.1.11.dist-info/RECORD +0 -55
- {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/WHEEL +0 -0
- {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/entry_points.txt +0 -0
- {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/licenses/LICENSE +0 -0
- {golf_mcp-0.1.11.dist-info → golf_mcp-0.1.13.dist-info}/top_level.txt +0 -0
golf/auth/oauth.py
CHANGED
|
@@ -5,25 +5,25 @@ interface for GolfMCP servers. It handles the OAuth 2.0 authentication flow,
|
|
|
5
5
|
token management, and client registration.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
+
import os
|
|
8
9
|
import time
|
|
9
10
|
import uuid
|
|
10
|
-
import jwt
|
|
11
|
-
import httpx
|
|
12
|
-
import os
|
|
13
|
-
from typing import Dict, List, Optional, Any, Union
|
|
14
11
|
from datetime import datetime
|
|
12
|
+
from typing import Any
|
|
15
13
|
|
|
14
|
+
import httpx
|
|
15
|
+
import jwt
|
|
16
16
|
from mcp.server.auth.provider import (
|
|
17
|
-
OAuthAuthorizationServerProvider,
|
|
18
17
|
AccessToken,
|
|
19
|
-
RefreshToken,
|
|
20
18
|
AuthorizationCode,
|
|
19
|
+
AuthorizationParams,
|
|
20
|
+
OAuthAuthorizationServerProvider,
|
|
21
|
+
RefreshToken,
|
|
21
22
|
RegistrationError,
|
|
22
|
-
AuthorizationParams
|
|
23
23
|
)
|
|
24
24
|
from mcp.shared.auth import (
|
|
25
|
-
OAuthToken,
|
|
26
25
|
OAuthClientInformationFull,
|
|
26
|
+
OAuthToken,
|
|
27
27
|
)
|
|
28
28
|
from starlette.responses import RedirectResponse
|
|
29
29
|
|
|
@@ -32,45 +32,47 @@ from .provider import ProviderConfig
|
|
|
32
32
|
|
|
33
33
|
class TokenStorage:
|
|
34
34
|
"""Simple in-memory token storage.
|
|
35
|
-
|
|
35
|
+
|
|
36
36
|
This class provides a simple in-memory storage for OAuth tokens,
|
|
37
37
|
authorization codes, and client information. In a production
|
|
38
38
|
environment, this should be replaced with a persistent storage
|
|
39
39
|
solution.
|
|
40
40
|
"""
|
|
41
|
-
|
|
42
|
-
def __init__(self):
|
|
41
|
+
|
|
42
|
+
def __init__(self) -> None:
|
|
43
43
|
"""Initialize the token storage."""
|
|
44
44
|
self.auth_codes = {} # code_str -> AuthorizationCode
|
|
45
45
|
self.refresh_tokens = {} # token_str -> RefreshToken
|
|
46
46
|
self.access_tokens = {} # token_str -> AccessToken
|
|
47
47
|
self.clients = {} # client_id -> OAuthClientInformationFull
|
|
48
48
|
self.provider_tokens = {} # mcp_access_token_str -> provider_access_token_str
|
|
49
|
-
self.auth_code_to_provider_token = {}
|
|
50
|
-
|
|
51
|
-
def store_auth_code(
|
|
49
|
+
self.auth_code_to_provider_token = {} # auth_code_str -> provider_access_token_str
|
|
50
|
+
|
|
51
|
+
def store_auth_code(
|
|
52
|
+
self, code: str, auth_code_obj: AuthorizationCode
|
|
53
|
+
) -> None: # Renamed auth_code to auth_code_obj for clarity
|
|
52
54
|
"""Store an authorization code.
|
|
53
|
-
|
|
55
|
+
|
|
54
56
|
Args:
|
|
55
57
|
code: The authorization code string
|
|
56
58
|
auth_code_obj: The authorization code object
|
|
57
59
|
"""
|
|
58
60
|
self.auth_codes[code] = auth_code_obj
|
|
59
|
-
|
|
60
|
-
def get_auth_code(self, code: str) ->
|
|
61
|
+
|
|
62
|
+
def get_auth_code(self, code: str) -> AuthorizationCode | None:
|
|
61
63
|
"""Get an authorization code by value.
|
|
62
|
-
|
|
64
|
+
|
|
63
65
|
Args:
|
|
64
66
|
code: The authorization code string
|
|
65
|
-
|
|
67
|
+
|
|
66
68
|
Returns:
|
|
67
69
|
The authorization code object or None if not found
|
|
68
70
|
"""
|
|
69
71
|
return self.auth_codes.get(code)
|
|
70
|
-
|
|
72
|
+
|
|
71
73
|
def delete_auth_code(self, code: str) -> None:
|
|
72
74
|
"""Delete an authorization code and its associated provider token mapping.
|
|
73
|
-
|
|
75
|
+
|
|
74
76
|
Args:
|
|
75
77
|
code: The authorization code string
|
|
76
78
|
"""
|
|
@@ -78,288 +80,299 @@ class TokenStorage:
|
|
|
78
80
|
del self.auth_codes[code]
|
|
79
81
|
if code in self.auth_code_to_provider_token:
|
|
80
82
|
del self.auth_code_to_provider_token[code]
|
|
81
|
-
|
|
82
|
-
def store_auth_code_provider_token_mapping(
|
|
83
|
+
|
|
84
|
+
def store_auth_code_provider_token_mapping(
|
|
85
|
+
self, auth_code_str: str, provider_token: str
|
|
86
|
+
) -> None:
|
|
83
87
|
"""Store a mapping from an auth_code string to a provider_token string."""
|
|
84
88
|
self.auth_code_to_provider_token[auth_code_str] = provider_token
|
|
85
89
|
|
|
86
|
-
def get_provider_token_for_auth_code(self, auth_code_str: str) ->
|
|
90
|
+
def get_provider_token_for_auth_code(self, auth_code_str: str) -> str | None:
|
|
87
91
|
"""Retrieve a provider_token string using an auth_code string."""
|
|
88
92
|
return self.auth_code_to_provider_token.get(auth_code_str)
|
|
89
|
-
|
|
93
|
+
|
|
90
94
|
def store_client(self, client_id: str, client: OAuthClientInformationFull) -> None:
|
|
91
95
|
"""Store client information.
|
|
92
|
-
|
|
96
|
+
|
|
93
97
|
Args:
|
|
94
98
|
client_id: The client ID
|
|
95
99
|
client: The client information
|
|
96
100
|
"""
|
|
97
101
|
self.clients[client_id] = client
|
|
98
|
-
|
|
99
|
-
def get_client(self, client_id: str) ->
|
|
102
|
+
|
|
103
|
+
def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
|
100
104
|
"""Get client information by ID.
|
|
101
|
-
|
|
105
|
+
|
|
102
106
|
Args:
|
|
103
107
|
client_id: The client ID
|
|
104
|
-
|
|
108
|
+
|
|
105
109
|
Returns:
|
|
106
110
|
The client information or None if not found
|
|
107
111
|
"""
|
|
108
112
|
# _diag_logger.info(f"TokenStorage: get_client called for client_id '{client_id}'. Known clients: {list(self.clients.keys())}") # Optional: uncomment for debugging
|
|
109
113
|
return self.clients.get(client_id)
|
|
110
|
-
|
|
114
|
+
|
|
111
115
|
def store_refresh_token(self, token: str, refresh_token: RefreshToken) -> None:
|
|
112
116
|
"""Store a refresh token.
|
|
113
|
-
|
|
117
|
+
|
|
114
118
|
Args:
|
|
115
119
|
token: The refresh token string
|
|
116
120
|
refresh_token: The refresh token object
|
|
117
121
|
"""
|
|
118
122
|
self.refresh_tokens[token] = refresh_token
|
|
119
|
-
|
|
120
|
-
def get_refresh_token(self, token: str) ->
|
|
123
|
+
|
|
124
|
+
def get_refresh_token(self, token: str) -> RefreshToken | None:
|
|
121
125
|
"""Get a refresh token by value.
|
|
122
|
-
|
|
126
|
+
|
|
123
127
|
Args:
|
|
124
128
|
token: The refresh token string
|
|
125
|
-
|
|
129
|
+
|
|
126
130
|
Returns:
|
|
127
131
|
The refresh token object or None if not found
|
|
128
132
|
"""
|
|
129
133
|
return self.refresh_tokens.get(token)
|
|
130
|
-
|
|
134
|
+
|
|
131
135
|
def delete_refresh_token(self, token: str) -> None:
|
|
132
136
|
"""Delete a refresh token.
|
|
133
|
-
|
|
137
|
+
|
|
134
138
|
Args:
|
|
135
139
|
token: The refresh token string
|
|
136
140
|
"""
|
|
137
141
|
if token in self.refresh_tokens:
|
|
138
142
|
del self.refresh_tokens[token]
|
|
139
|
-
|
|
143
|
+
|
|
140
144
|
def store_access_token(self, token: str, access_token: AccessToken) -> None:
|
|
141
145
|
"""Store an access token.
|
|
142
|
-
|
|
146
|
+
|
|
143
147
|
Args:
|
|
144
148
|
token: The access token string
|
|
145
149
|
access_token: The access token object
|
|
146
150
|
"""
|
|
147
151
|
self.access_tokens[token] = access_token
|
|
148
|
-
|
|
149
|
-
def get_access_token(self, token: str) ->
|
|
152
|
+
|
|
153
|
+
def get_access_token(self, token: str) -> AccessToken | None:
|
|
150
154
|
"""Get an access token by value.
|
|
151
|
-
|
|
155
|
+
|
|
152
156
|
Args:
|
|
153
157
|
token: The access token string
|
|
154
|
-
|
|
158
|
+
|
|
155
159
|
Returns:
|
|
156
160
|
The access token object or None if not found
|
|
157
161
|
"""
|
|
158
162
|
return self.access_tokens.get(token)
|
|
159
|
-
|
|
163
|
+
|
|
160
164
|
def delete_access_token(self, token: str) -> None:
|
|
161
165
|
"""Delete an access token.
|
|
162
|
-
|
|
166
|
+
|
|
163
167
|
Args:
|
|
164
168
|
token: The access token string
|
|
165
169
|
"""
|
|
166
170
|
if token in self.access_tokens:
|
|
167
171
|
del self.access_tokens[token]
|
|
168
|
-
|
|
172
|
+
|
|
169
173
|
def store_provider_token(self, mcp_token: str, provider_token: str) -> None:
|
|
170
174
|
"""Store a provider token mapping.
|
|
171
|
-
|
|
175
|
+
|
|
172
176
|
Args:
|
|
173
177
|
mcp_token: The MCP token string
|
|
174
178
|
provider_token: The provider token string (e.g., GitHub token)
|
|
175
179
|
"""
|
|
176
180
|
self.provider_tokens[mcp_token] = provider_token
|
|
177
|
-
|
|
178
|
-
def get_provider_token(self, mcp_token: str) ->
|
|
181
|
+
|
|
182
|
+
def get_provider_token(self, mcp_token: str) -> str | None:
|
|
179
183
|
"""Get the provider token associated with an MCP token.
|
|
180
|
-
|
|
184
|
+
|
|
181
185
|
This is a non-standard method to allow access to the provider token
|
|
182
186
|
(e.g., GitHub token) for a given MCP token. This can be used by
|
|
183
187
|
tools that need to access provider APIs.
|
|
184
|
-
|
|
188
|
+
|
|
185
189
|
Args:
|
|
186
190
|
mcp_token: The MCP token
|
|
187
|
-
|
|
191
|
+
|
|
188
192
|
Returns:
|
|
189
193
|
The provider token or None if not found
|
|
190
194
|
"""
|
|
191
|
-
return self.provider_tokens.get(mcp_token)
|
|
195
|
+
return self.provider_tokens.get(mcp_token) # Changed from self.storage to self
|
|
192
196
|
|
|
193
197
|
|
|
194
198
|
class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
195
199
|
"""OAuth provider implementation for GolfMCP.
|
|
196
|
-
|
|
200
|
+
|
|
197
201
|
This class implements the OAuthAuthorizationServerProvider interface
|
|
198
202
|
for GolfMCP servers. It handles the OAuth 2.0 authentication flow,
|
|
199
203
|
token management, and client registration.
|
|
200
204
|
"""
|
|
201
|
-
|
|
202
|
-
def __init__(self, config: ProviderConfig):
|
|
205
|
+
|
|
206
|
+
def __init__(self, config: ProviderConfig) -> None:
|
|
203
207
|
"""Initialize the provider.
|
|
204
|
-
|
|
208
|
+
|
|
205
209
|
Args:
|
|
206
210
|
config: The provider configuration
|
|
207
211
|
"""
|
|
208
212
|
self.config = config
|
|
209
213
|
self.storage = TokenStorage()
|
|
210
|
-
self.state_mapping:
|
|
211
|
-
|
|
214
|
+
self.state_mapping: dict[str, dict[str, Any]] = {} # Initialize state_mapping
|
|
215
|
+
|
|
212
216
|
# Register default client
|
|
213
217
|
self._register_default_client()
|
|
214
|
-
|
|
218
|
+
|
|
215
219
|
def _get_client_id(self) -> str:
|
|
216
220
|
"""Get the client ID from config or environment."""
|
|
217
221
|
if self.config.client_id:
|
|
218
222
|
return self.config.client_id
|
|
219
|
-
|
|
223
|
+
|
|
220
224
|
if self.config.client_id_env_var:
|
|
221
225
|
value = os.environ.get(self.config.client_id_env_var)
|
|
222
226
|
if value:
|
|
223
227
|
return value
|
|
224
|
-
|
|
228
|
+
|
|
225
229
|
return "missing-client-id"
|
|
226
|
-
|
|
230
|
+
|
|
227
231
|
def _get_client_secret(self) -> str:
|
|
228
232
|
"""Get the client secret from config or environment."""
|
|
229
233
|
if self.config.client_secret:
|
|
230
234
|
return self.config.client_secret
|
|
231
|
-
|
|
235
|
+
|
|
232
236
|
if self.config.client_secret_env_var:
|
|
233
237
|
value = os.environ.get(self.config.client_secret_env_var)
|
|
234
238
|
if value:
|
|
235
239
|
return value
|
|
236
|
-
|
|
240
|
+
|
|
237
241
|
return "missing-client-secret"
|
|
238
|
-
|
|
242
|
+
|
|
239
243
|
def _get_jwt_secret(self) -> str:
|
|
240
244
|
"""Get the JWT secret from config. It's expected to be resolved by server startup."""
|
|
241
245
|
if self.config.jwt_secret:
|
|
242
246
|
# _diag_logger.info(f"GolfOAuthProvider: Using JWT secret from config: {self.config.jwt_secret[:5]}...")
|
|
243
247
|
return self.config.jwt_secret
|
|
244
248
|
else:
|
|
245
|
-
raise ValueError(
|
|
246
|
-
|
|
249
|
+
raise ValueError(
|
|
250
|
+
"JWT Secret is not configured in the provider. Check server logs and environment variables."
|
|
251
|
+
)
|
|
252
|
+
|
|
247
253
|
def _register_default_client(self) -> None:
|
|
248
254
|
"""Register a default client for MCP."""
|
|
249
255
|
# These are the URIs where *this server* is allowed to redirect an MCP client
|
|
250
256
|
# after successful authentication and MCP auth code generation.
|
|
251
257
|
client_redirect_uris = [
|
|
252
258
|
# Common redirect URI for MCP Inspector running locally
|
|
253
|
-
"http://localhost:5173/callback",
|
|
259
|
+
"http://localhost:5173/callback",
|
|
254
260
|
"http://127.0.0.1:5173/callback",
|
|
255
261
|
# A generic callback relative to the server's issuer URL, if needed by some clients
|
|
256
262
|
# This assumes such a client-side endpoint exists.
|
|
257
|
-
f"{self.config.issuer_url.rstrip('/') if self.config.issuer_url else 'http://localhost:3000'}/client/callback"
|
|
263
|
+
f"{self.config.issuer_url.rstrip('/') if self.config.issuer_url else 'http://localhost:3000'}/client/callback",
|
|
258
264
|
]
|
|
259
265
|
|
|
260
266
|
default_client = OAuthClientInformationFull(
|
|
261
267
|
client_id="default",
|
|
262
268
|
client_name="Default MCP Client",
|
|
263
269
|
client_secret="", # Public client
|
|
264
|
-
redirect_uris=client_redirect_uris,
|
|
270
|
+
redirect_uris=client_redirect_uris,
|
|
265
271
|
grant_types=["authorization_code", "refresh_token"],
|
|
266
272
|
response_types=["code"],
|
|
267
273
|
token_endpoint_auth_method="none", # Public client
|
|
268
|
-
scope=" ".join(self.config.scopes)
|
|
274
|
+
scope=" ".join(self.config.scopes),
|
|
269
275
|
)
|
|
270
276
|
self.storage.store_client("default", default_client)
|
|
271
|
-
|
|
277
|
+
|
|
272
278
|
def _generate_jwt(
|
|
273
|
-
self,
|
|
274
|
-
subject: str,
|
|
275
|
-
scopes: List[str],
|
|
276
|
-
expires_in: int = None
|
|
279
|
+
self, subject: str, scopes: list[str], expires_in: int = None
|
|
277
280
|
) -> str:
|
|
278
281
|
"""Generate a JWT token.
|
|
279
|
-
|
|
282
|
+
|
|
280
283
|
Args:
|
|
281
284
|
subject: The subject of the token (usually client_id)
|
|
282
285
|
scopes: The scopes granted to the token
|
|
283
286
|
expires_in: The token lifetime in seconds (or None for default)
|
|
284
|
-
|
|
287
|
+
|
|
285
288
|
Returns:
|
|
286
289
|
The signed JWT token
|
|
287
290
|
"""
|
|
288
291
|
now = int(time.time())
|
|
289
292
|
expiry = now + (expires_in or self.config.token_expiration)
|
|
290
|
-
|
|
293
|
+
|
|
291
294
|
payload = {
|
|
292
295
|
"iss": self.config.issuer_url or "golf:auth",
|
|
293
296
|
"sub": subject,
|
|
294
297
|
"iat": now,
|
|
295
298
|
"exp": expiry,
|
|
296
|
-
"scp": scopes
|
|
299
|
+
"scp": scopes,
|
|
297
300
|
}
|
|
298
|
-
|
|
301
|
+
|
|
299
302
|
jwt_secret = self._get_jwt_secret()
|
|
300
303
|
return jwt.encode(payload, jwt_secret, algorithm="HS256")
|
|
301
|
-
|
|
302
|
-
def _verify_jwt(self, token: str) ->
|
|
304
|
+
|
|
305
|
+
def _verify_jwt(self, token: str) -> dict[str, Any] | None:
|
|
303
306
|
"""Verify a JWT token."""
|
|
304
|
-
jwt_secret = self._get_jwt_secret()
|
|
307
|
+
jwt_secret = self._get_jwt_secret() # Get secret first
|
|
305
308
|
# _diag_logger.info(f"GolfOAuthProvider: _verify_jwt attempting to use secret: {jwt_secret[:5]}...")
|
|
306
309
|
|
|
307
310
|
try:
|
|
308
|
-
payload = jwt.decode(
|
|
309
|
-
|
|
311
|
+
payload = jwt.decode(
|
|
312
|
+
token,
|
|
313
|
+
jwt_secret,
|
|
314
|
+
algorithms=["HS256"],
|
|
315
|
+
options={"verify_signature": True},
|
|
316
|
+
)
|
|
317
|
+
|
|
310
318
|
if payload.get("exp", 0) < time.time():
|
|
311
319
|
exp_timestamp = payload.get("exp")
|
|
312
320
|
current_timestamp = time.time()
|
|
313
|
-
|
|
314
|
-
|
|
321
|
+
(
|
|
322
|
+
str(datetime.fromtimestamp(exp_timestamp))
|
|
323
|
+
if exp_timestamp is not None
|
|
324
|
+
else "N/A"
|
|
325
|
+
)
|
|
326
|
+
str(datetime.fromtimestamp(current_timestamp))
|
|
315
327
|
return None
|
|
316
328
|
return payload
|
|
317
|
-
except jwt.ExpiredSignatureError
|
|
329
|
+
except jwt.ExpiredSignatureError:
|
|
318
330
|
return None
|
|
319
|
-
except jwt.PyJWTError
|
|
331
|
+
except jwt.PyJWTError:
|
|
320
332
|
return None
|
|
321
|
-
except Exception
|
|
333
|
+
except Exception: # Catch any other unexpected error during decode
|
|
322
334
|
return None
|
|
323
|
-
|
|
324
|
-
async def get_client(self, client_id: str) ->
|
|
335
|
+
|
|
336
|
+
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
|
325
337
|
"""Get client information by ID.
|
|
326
|
-
|
|
338
|
+
|
|
327
339
|
Args:
|
|
328
340
|
client_id: The client ID
|
|
329
|
-
|
|
341
|
+
|
|
330
342
|
Returns:
|
|
331
343
|
The client information or None if not found
|
|
332
344
|
"""
|
|
333
345
|
return self.storage.get_client(client_id)
|
|
334
|
-
|
|
335
|
-
async def register_client(
|
|
336
|
-
self,
|
|
337
|
-
client_info: OAuthClientInformationFull
|
|
338
|
-
) -> None:
|
|
346
|
+
|
|
347
|
+
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
|
339
348
|
"""Register a new client."""
|
|
340
349
|
# Add detailed logging at the beginning
|
|
341
|
-
|
|
350
|
+
getattr(
|
|
351
|
+
client_info, "client_id", "UNKNOWN (client_info has no client_id attribute)"
|
|
352
|
+
)
|
|
342
353
|
try:
|
|
343
354
|
# Validate the client information
|
|
344
355
|
if not client_info.client_id:
|
|
345
356
|
raise RegistrationError(
|
|
346
357
|
error="invalid_client_metadata",
|
|
347
|
-
error_description="Client ID is missing in client_info provided to register_client"
|
|
358
|
+
error_description="Client ID is missing in client_info provided to register_client",
|
|
348
359
|
)
|
|
349
360
|
|
|
350
361
|
if not client_info.redirect_uris:
|
|
351
362
|
raise RegistrationError(
|
|
352
363
|
error="invalid_redirect_uri",
|
|
353
|
-
error_description="At least one redirect URI is required"
|
|
364
|
+
error_description="At least one redirect URI is required",
|
|
354
365
|
)
|
|
355
|
-
|
|
366
|
+
|
|
356
367
|
# Store the client
|
|
357
368
|
self.storage.store_client(client_info.client_id, client_info)
|
|
358
|
-
except Exception
|
|
359
|
-
raise
|
|
360
|
-
|
|
369
|
+
except Exception:
|
|
370
|
+
raise # Re-raise the exception so FastMCP can handle it
|
|
371
|
+
|
|
361
372
|
async def authorize(
|
|
362
|
-
self,
|
|
373
|
+
self,
|
|
374
|
+
client: OAuthClientInformationFull,
|
|
375
|
+
params: AuthorizationParams, # params from MCP client
|
|
363
376
|
) -> str:
|
|
364
377
|
"""Handle an authorization request.
|
|
365
378
|
This method is called when an MCP client requests authorization.
|
|
@@ -369,118 +382,123 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
|
369
382
|
import urllib.parse
|
|
370
383
|
|
|
371
384
|
idp_flow_state = secrets.token_hex(16)
|
|
372
|
-
mcp_client_original_state = params.state
|
|
373
|
-
|
|
385
|
+
mcp_client_original_state = params.state
|
|
386
|
+
|
|
374
387
|
self.state_mapping[idp_flow_state] = {
|
|
375
388
|
"client_id": client.client_id,
|
|
376
389
|
"redirect_uri": str(params.redirect_uri),
|
|
377
390
|
"code_challenge": params.code_challenge,
|
|
378
|
-
"code_challenge_method":
|
|
391
|
+
"code_challenge_method": (
|
|
392
|
+
"S256" if params.code_challenge else None
|
|
393
|
+
), # Store S256 if challenge exists, else None
|
|
379
394
|
"scopes": params.scopes,
|
|
380
395
|
"redirect_uri_provided_explicitly": params.redirect_uri_provided_explicitly,
|
|
381
|
-
"mcp_client_original_state": mcp_client_original_state
|
|
396
|
+
"mcp_client_original_state": mcp_client_original_state,
|
|
382
397
|
}
|
|
383
|
-
|
|
398
|
+
|
|
384
399
|
# Use self.config.callback_path for consistency
|
|
385
|
-
idp_callback_uri =
|
|
400
|
+
idp_callback_uri = (
|
|
401
|
+
f"{self.config.issuer_url.rstrip('/')}{self.config.callback_path}"
|
|
402
|
+
)
|
|
386
403
|
|
|
387
404
|
client_id = self._get_client_id()
|
|
388
|
-
|
|
405
|
+
|
|
389
406
|
auth_params_for_idp = {
|
|
390
407
|
"client_id": client_id,
|
|
391
408
|
"redirect_uri": idp_callback_uri,
|
|
392
409
|
"scope": " ".join(self.config.scopes),
|
|
393
410
|
"state": idp_flow_state,
|
|
394
|
-
"response_type": "code"
|
|
411
|
+
"response_type": "code",
|
|
395
412
|
}
|
|
396
|
-
|
|
413
|
+
|
|
397
414
|
if params.code_challenge:
|
|
398
415
|
auth_params_for_idp["code_challenge"] = params.code_challenge
|
|
399
416
|
# Always use S256 if a challenge is present, as it's the standard and what the client sends.
|
|
400
|
-
auth_params_for_idp["code_challenge_method"] = "S256"
|
|
417
|
+
auth_params_for_idp["code_challenge_method"] = "S256"
|
|
401
418
|
|
|
402
419
|
query_for_idp = urllib.parse.urlencode(auth_params_for_idp)
|
|
403
|
-
|
|
420
|
+
|
|
404
421
|
return f"{self.config.authorize_url}?{query_for_idp}"
|
|
405
|
-
|
|
422
|
+
|
|
406
423
|
async def load_authorization_code(
|
|
407
|
-
self,
|
|
408
|
-
|
|
409
|
-
code: str
|
|
410
|
-
) -> Optional[AuthorizationCode]:
|
|
424
|
+
self, client: OAuthClientInformationFull, code: str
|
|
425
|
+
) -> AuthorizationCode | None:
|
|
411
426
|
"""Load an authorization code.
|
|
412
|
-
|
|
427
|
+
|
|
413
428
|
Args:
|
|
414
429
|
client: The client information
|
|
415
430
|
code: The authorization code
|
|
416
|
-
|
|
431
|
+
|
|
417
432
|
Returns:
|
|
418
433
|
The authorization code object or None if not found
|
|
419
434
|
"""
|
|
420
435
|
auth_code = self.storage.get_auth_code(code)
|
|
421
|
-
|
|
436
|
+
|
|
422
437
|
if not auth_code:
|
|
423
438
|
return None
|
|
424
|
-
|
|
439
|
+
|
|
425
440
|
# Verify the code belongs to this client
|
|
426
441
|
if auth_code.client_id != client.client_id:
|
|
427
442
|
return None
|
|
428
|
-
|
|
443
|
+
|
|
429
444
|
# Verify the code hasn't expired
|
|
430
445
|
if auth_code.expires_at and auth_code.expires_at < datetime.now().timestamp():
|
|
431
446
|
self.storage.delete_auth_code(code)
|
|
432
447
|
return None
|
|
433
|
-
|
|
448
|
+
|
|
434
449
|
return auth_code
|
|
435
|
-
|
|
450
|
+
|
|
436
451
|
async def exchange_authorization_code(
|
|
437
|
-
self,
|
|
438
|
-
client: OAuthClientInformationFull,
|
|
439
|
-
code: AuthorizationCode
|
|
452
|
+
self,
|
|
453
|
+
client: OAuthClientInformationFull,
|
|
454
|
+
code: AuthorizationCode, # This is AuthorizationCode object
|
|
440
455
|
) -> OAuthToken:
|
|
441
456
|
"""Exchange an authorization code for tokens.
|
|
442
|
-
|
|
457
|
+
|
|
443
458
|
Args:
|
|
444
459
|
client: The client information
|
|
445
460
|
code: The authorization code object
|
|
446
|
-
|
|
461
|
+
|
|
447
462
|
Returns:
|
|
448
463
|
The OAuth token response
|
|
449
|
-
|
|
464
|
+
|
|
450
465
|
Raises:
|
|
451
466
|
TokenError: If the code exchange fails
|
|
452
467
|
"""
|
|
453
468
|
# Retrieve the provider token that was stored temporarily during callback
|
|
454
469
|
provider_token = self.storage.get_provider_token_for_auth_code(code.code)
|
|
455
|
-
|
|
470
|
+
|
|
456
471
|
# Delete the code and its mapping to ensure one-time use
|
|
457
|
-
self.storage.delete_auth_code(code.code)
|
|
458
|
-
|
|
472
|
+
self.storage.delete_auth_code(code.code) # This now also deletes the mapping
|
|
473
|
+
|
|
459
474
|
# Generate an access token
|
|
460
|
-
access_token_str = self._generate_jwt(
|
|
461
|
-
subject=client.client_id,
|
|
462
|
-
scopes=code.scopes
|
|
475
|
+
access_token_str = self._generate_jwt( # Renamed for clarity
|
|
476
|
+
subject=client.client_id, scopes=code.scopes
|
|
463
477
|
)
|
|
464
|
-
|
|
478
|
+
|
|
465
479
|
# Generate a refresh token if needed
|
|
466
|
-
refresh_token_str =
|
|
467
|
-
|
|
480
|
+
refresh_token_str = (
|
|
481
|
+
str(uuid.uuid4()) if "refresh_token" in client.grant_types else None
|
|
482
|
+
) # Renamed for clarity
|
|
483
|
+
|
|
468
484
|
# Store the mapping from our new MCP access token to the provider's access token
|
|
469
485
|
if provider_token and access_token_str:
|
|
470
486
|
self.storage.store_provider_token(access_token_str, provider_token)
|
|
471
|
-
|
|
487
|
+
|
|
472
488
|
# Store the tokens
|
|
473
489
|
if refresh_token_str:
|
|
474
490
|
self.storage.store_refresh_token(
|
|
475
|
-
refresh_token_str,
|
|
491
|
+
refresh_token_str,
|
|
476
492
|
RefreshToken(
|
|
477
493
|
token=refresh_token_str,
|
|
478
494
|
client_id=client.client_id,
|
|
479
495
|
scopes=code.scopes,
|
|
480
|
-
expires_at=int(
|
|
481
|
-
|
|
496
|
+
expires_at=int(
|
|
497
|
+
datetime.now().timestamp() + (self.config.token_expiration * 24)
|
|
498
|
+
), # 24x longer, cast to int
|
|
499
|
+
),
|
|
482
500
|
)
|
|
483
|
-
|
|
501
|
+
|
|
484
502
|
# Store access token information for validation later
|
|
485
503
|
# Note: For JWTs, we might not need to store them if we can verify the signature
|
|
486
504
|
self.storage.store_access_token(
|
|
@@ -489,71 +507,71 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
|
489
507
|
token=access_token_str,
|
|
490
508
|
client_id=client.client_id,
|
|
491
509
|
scopes=code.scopes,
|
|
492
|
-
expires_at=int(
|
|
493
|
-
|
|
510
|
+
expires_at=int(
|
|
511
|
+
datetime.now().timestamp() + self.config.token_expiration
|
|
512
|
+
), # Cast to int
|
|
513
|
+
),
|
|
494
514
|
)
|
|
495
|
-
|
|
515
|
+
|
|
496
516
|
# Create and return the OAuth token response
|
|
497
517
|
return OAuthToken(
|
|
498
518
|
access_token=access_token_str,
|
|
499
519
|
token_type="bearer",
|
|
500
520
|
expires_in=self.config.token_expiration,
|
|
501
521
|
refresh_token=refresh_token_str,
|
|
502
|
-
scope=" ".join(code.scopes)
|
|
522
|
+
scope=" ".join(code.scopes),
|
|
503
523
|
)
|
|
504
|
-
|
|
524
|
+
|
|
505
525
|
async def load_refresh_token(
|
|
506
|
-
self,
|
|
507
|
-
|
|
508
|
-
refresh_token: str
|
|
509
|
-
) -> Optional[RefreshToken]:
|
|
526
|
+
self, client: OAuthClientInformationFull, refresh_token: str
|
|
527
|
+
) -> RefreshToken | None:
|
|
510
528
|
"""Load a refresh token.
|
|
511
|
-
|
|
529
|
+
|
|
512
530
|
Args:
|
|
513
531
|
client: The client information
|
|
514
532
|
refresh_token: The refresh token string
|
|
515
|
-
|
|
533
|
+
|
|
516
534
|
Returns:
|
|
517
535
|
The refresh token object or None if not found
|
|
518
536
|
"""
|
|
519
537
|
token = self.storage.get_refresh_token(refresh_token)
|
|
520
|
-
|
|
538
|
+
|
|
521
539
|
if not token:
|
|
522
540
|
return None
|
|
523
|
-
|
|
541
|
+
|
|
524
542
|
# Verify the token belongs to this client
|
|
525
543
|
if token.client_id != client.client_id:
|
|
526
544
|
return None
|
|
527
|
-
|
|
545
|
+
|
|
528
546
|
# Verify the token hasn't expired
|
|
529
547
|
if token.expires_at and token.expires_at < datetime.now().timestamp():
|
|
530
548
|
self.storage.delete_refresh_token(refresh_token)
|
|
531
549
|
return None
|
|
532
|
-
|
|
550
|
+
|
|
533
551
|
return token
|
|
534
|
-
|
|
552
|
+
|
|
535
553
|
async def exchange_refresh_token(
|
|
536
|
-
self,
|
|
537
|
-
client: OAuthClientInformationFull,
|
|
538
|
-
refresh_token: RefreshToken,
|
|
539
|
-
scopes:
|
|
554
|
+
self,
|
|
555
|
+
client: OAuthClientInformationFull,
|
|
556
|
+
refresh_token: RefreshToken,
|
|
557
|
+
scopes: list[str],
|
|
540
558
|
) -> OAuthToken:
|
|
541
559
|
"""Exchange a refresh token for a new token pair.
|
|
542
|
-
|
|
560
|
+
|
|
543
561
|
Args:
|
|
544
562
|
client: The client information
|
|
545
563
|
refresh_token: The refresh token object
|
|
546
564
|
scopes: The requested scopes (may be a subset of original)
|
|
547
|
-
|
|
565
|
+
|
|
548
566
|
Returns:
|
|
549
567
|
The new OAuth token response
|
|
550
|
-
|
|
568
|
+
|
|
551
569
|
Raises:
|
|
552
570
|
TokenError: If the token exchange fails
|
|
553
571
|
"""
|
|
554
572
|
# Delete the old refresh token (implement token rotation for security)
|
|
555
573
|
self.storage.delete_refresh_token(refresh_token.token)
|
|
556
|
-
|
|
574
|
+
|
|
557
575
|
# Determine the scopes for the new token
|
|
558
576
|
# If requested scopes are provided, they must be a subset of the original
|
|
559
577
|
if scopes:
|
|
@@ -562,20 +580,18 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
|
562
580
|
valid_scopes = refresh_token.scopes
|
|
563
581
|
else:
|
|
564
582
|
valid_scopes = refresh_token.scopes
|
|
565
|
-
|
|
583
|
+
|
|
566
584
|
# Generate a new access token
|
|
567
|
-
access_token = self._generate_jwt(
|
|
568
|
-
|
|
569
|
-
scopes=valid_scopes
|
|
570
|
-
)
|
|
571
|
-
|
|
585
|
+
access_token = self._generate_jwt(subject=client.client_id, scopes=valid_scopes)
|
|
586
|
+
|
|
572
587
|
# Generate a new refresh token
|
|
573
588
|
new_refresh_token = str(uuid.uuid4())
|
|
574
|
-
|
|
589
|
+
|
|
575
590
|
# Find the provider token if it exists from the old access token
|
|
576
591
|
# Note: This assumes each refresh generates only one access token
|
|
577
592
|
old_access_tokens = [
|
|
578
|
-
token
|
|
593
|
+
token
|
|
594
|
+
for token, data in self.storage.access_tokens.items()
|
|
579
595
|
if data.client_id == client.client_id
|
|
580
596
|
]
|
|
581
597
|
provider_token = None
|
|
@@ -585,18 +601,20 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
|
585
601
|
# Store the provider token mapping for the new access token
|
|
586
602
|
self.storage.store_provider_token(access_token, provider_token)
|
|
587
603
|
break
|
|
588
|
-
|
|
604
|
+
|
|
589
605
|
# Store the new tokens
|
|
590
606
|
self.storage.store_refresh_token(
|
|
591
|
-
new_refresh_token,
|
|
607
|
+
new_refresh_token,
|
|
592
608
|
RefreshToken(
|
|
593
609
|
token=new_refresh_token,
|
|
594
610
|
client_id=client.client_id,
|
|
595
611
|
scopes=valid_scopes,
|
|
596
|
-
expires_at=int(
|
|
597
|
-
|
|
612
|
+
expires_at=int(
|
|
613
|
+
datetime.now().timestamp() + (self.config.token_expiration * 24)
|
|
614
|
+
), # Cast to int
|
|
615
|
+
),
|
|
598
616
|
)
|
|
599
|
-
|
|
617
|
+
|
|
600
618
|
# Store access token information
|
|
601
619
|
self.storage.store_access_token(
|
|
602
620
|
access_token,
|
|
@@ -604,67 +622,68 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
|
604
622
|
token=access_token,
|
|
605
623
|
client_id=client.client_id,
|
|
606
624
|
scopes=valid_scopes,
|
|
607
|
-
expires_at=int(
|
|
608
|
-
|
|
625
|
+
expires_at=int(
|
|
626
|
+
datetime.now().timestamp() + self.config.token_expiration
|
|
627
|
+
), # Cast to int
|
|
628
|
+
),
|
|
609
629
|
)
|
|
610
|
-
|
|
630
|
+
|
|
611
631
|
# Create and return the OAuth token response
|
|
612
632
|
return OAuthToken(
|
|
613
633
|
access_token=access_token,
|
|
614
634
|
token_type="bearer",
|
|
615
635
|
expires_in=self.config.token_expiration,
|
|
616
636
|
refresh_token=new_refresh_token,
|
|
617
|
-
scope=" ".join(valid_scopes)
|
|
637
|
+
scope=" ".join(valid_scopes),
|
|
618
638
|
)
|
|
619
|
-
|
|
620
|
-
async def load_access_token(self, token: str) ->
|
|
639
|
+
|
|
640
|
+
async def load_access_token(self, token: str) -> AccessToken | None:
|
|
621
641
|
"""Load and validate an access token."""
|
|
622
|
-
|
|
642
|
+
|
|
623
643
|
payload = self._verify_jwt(token)
|
|
624
644
|
if not payload:
|
|
625
|
-
return None
|
|
626
|
-
|
|
645
|
+
return None
|
|
646
|
+
|
|
627
647
|
client_id = payload.get("sub")
|
|
628
648
|
scopes = payload.get("scp", [])
|
|
629
649
|
expires_at = payload.get("exp")
|
|
630
|
-
|
|
631
|
-
|
|
650
|
+
|
|
632
651
|
access_token_obj = AccessToken(
|
|
633
652
|
token=token,
|
|
634
653
|
client_id=client_id,
|
|
635
654
|
scopes=scopes,
|
|
636
|
-
expires_at=int(expires_at) if expires_at is not None else None
|
|
655
|
+
expires_at=int(expires_at) if expires_at is not None else None,
|
|
637
656
|
)
|
|
638
|
-
|
|
657
|
+
|
|
639
658
|
return access_token_obj
|
|
640
|
-
|
|
641
|
-
async def revoke_token(self, token:
|
|
659
|
+
|
|
660
|
+
async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
|
|
642
661
|
"""Revoke a token.
|
|
643
|
-
|
|
662
|
+
|
|
644
663
|
Args:
|
|
645
664
|
token: The token to revoke (access or refresh)
|
|
646
665
|
"""
|
|
647
666
|
# Try to revoke as access token
|
|
648
667
|
self.storage.delete_access_token(token.token)
|
|
649
|
-
|
|
668
|
+
|
|
650
669
|
# Try to revoke as refresh token
|
|
651
670
|
self.storage.delete_refresh_token(token.token)
|
|
652
|
-
|
|
671
|
+
|
|
653
672
|
# Clean up provider token mapping if it exists
|
|
654
673
|
provider_token = self.storage.get_provider_token(token.token)
|
|
655
674
|
if provider_token:
|
|
656
675
|
self.storage.provider_tokens.pop(token.token, None)
|
|
657
|
-
|
|
658
|
-
def get_provider_token(self, mcp_token: str) ->
|
|
676
|
+
|
|
677
|
+
def get_provider_token(self, mcp_token: str) -> str | None:
|
|
659
678
|
"""Get the provider token associated with an MCP token.
|
|
660
|
-
|
|
679
|
+
|
|
661
680
|
This is a non-standard method to allow access to the provider token
|
|
662
681
|
(e.g., GitHub token) for a given MCP token. This can be used by
|
|
663
682
|
tools that need to access provider APIs.
|
|
664
|
-
|
|
683
|
+
|
|
665
684
|
Args:
|
|
666
685
|
mcp_token: The MCP token
|
|
667
|
-
|
|
686
|
+
|
|
668
687
|
Returns:
|
|
669
688
|
The provider token or None if not found
|
|
670
689
|
"""
|
|
@@ -673,39 +692,48 @@ class GolfOAuthProvider(OAuthAuthorizationServerProvider):
|
|
|
673
692
|
|
|
674
693
|
def create_callback_handler(provider: GolfOAuthProvider):
|
|
675
694
|
"""Create a callback handler for OAuth authorization.
|
|
676
|
-
|
|
695
|
+
|
|
677
696
|
This function creates a callback handler that can be used to handle
|
|
678
697
|
the OAuth callback from the provider (e.g., GitHub).
|
|
679
|
-
|
|
698
|
+
|
|
680
699
|
Args:
|
|
681
700
|
provider: The OAuth provider
|
|
682
|
-
|
|
701
|
+
|
|
683
702
|
Returns:
|
|
684
703
|
An async function that handles the callback
|
|
685
704
|
"""
|
|
705
|
+
|
|
686
706
|
async def handle_callback(request):
|
|
687
707
|
"""Handle the OAuth callback.
|
|
688
|
-
|
|
708
|
+
|
|
689
709
|
Args:
|
|
690
710
|
request: The HTTP request
|
|
691
|
-
|
|
711
|
+
|
|
692
712
|
Returns:
|
|
693
713
|
The HTTP response
|
|
694
714
|
"""
|
|
695
715
|
# Extract the code and state from the request
|
|
696
|
-
idp_auth_code = request.query_params.get(
|
|
697
|
-
|
|
698
|
-
|
|
716
|
+
idp_auth_code = request.query_params.get(
|
|
717
|
+
"code"
|
|
718
|
+
) # Renamed for clarity: code from IdP
|
|
719
|
+
idp_state = request.query_params.get(
|
|
720
|
+
"state"
|
|
721
|
+
) # Renamed for clarity: state from IdP
|
|
722
|
+
|
|
699
723
|
if not idp_auth_code:
|
|
700
|
-
return RedirectResponse(
|
|
701
|
-
|
|
724
|
+
return RedirectResponse(
|
|
725
|
+
"/auth-error?error=no_code_from_idp"
|
|
726
|
+
) # More specific error
|
|
727
|
+
|
|
702
728
|
# Use provider.config.callback_path for consistency
|
|
703
729
|
# This is the redirect_uri registered with the IdP and used in the /authorize step
|
|
704
|
-
idp_callback_uri_for_token_exchange =
|
|
705
|
-
|
|
730
|
+
idp_callback_uri_for_token_exchange = (
|
|
731
|
+
f"{provider.config.issuer_url.rstrip('/')}{provider.config.callback_path}"
|
|
732
|
+
)
|
|
733
|
+
|
|
706
734
|
client_id_for_idp = provider._get_client_id()
|
|
707
735
|
client_secret_for_idp = provider._get_client_secret()
|
|
708
|
-
|
|
736
|
+
|
|
709
737
|
async with httpx.AsyncClient() as client:
|
|
710
738
|
response = await client.post(
|
|
711
739
|
provider.config.token_url,
|
|
@@ -713,86 +741,121 @@ def create_callback_handler(provider: GolfOAuthProvider):
|
|
|
713
741
|
data={
|
|
714
742
|
"client_id": client_id_for_idp,
|
|
715
743
|
"client_secret": client_secret_for_idp,
|
|
716
|
-
"code": idp_auth_code,
|
|
717
|
-
"redirect_uri": idp_callback_uri_for_token_exchange
|
|
718
|
-
}
|
|
744
|
+
"code": idp_auth_code, # Use code from IdP
|
|
745
|
+
"redirect_uri": idp_callback_uri_for_token_exchange,
|
|
746
|
+
},
|
|
719
747
|
)
|
|
720
|
-
|
|
748
|
+
|
|
721
749
|
if response.status_code != 200:
|
|
722
|
-
error_detail = response.text[:200]
|
|
723
|
-
return RedirectResponse(
|
|
724
|
-
|
|
750
|
+
error_detail = response.text[:200] # Limit error detail length
|
|
751
|
+
return RedirectResponse(
|
|
752
|
+
f"/auth-error?error=idp_token_exchange_failed&detail={urllib.parse.quote(error_detail)}"
|
|
753
|
+
)
|
|
754
|
+
|
|
725
755
|
# Get the provider token from the response
|
|
726
756
|
token_data = response.json()
|
|
727
|
-
provider_access_token = token_data.get(
|
|
728
|
-
|
|
757
|
+
provider_access_token = token_data.get(
|
|
758
|
+
"access_token"
|
|
759
|
+
) # This is the token from GitHub/Google etc.
|
|
760
|
+
|
|
729
761
|
if not provider_access_token:
|
|
730
762
|
return RedirectResponse("/auth-error?error=no_access_token_from_idp")
|
|
731
|
-
|
|
763
|
+
|
|
732
764
|
try:
|
|
733
765
|
# Get user information from the provider using the token (optional step)
|
|
734
766
|
# user_info = None (keep this if user_info is used later, otherwise remove)
|
|
735
767
|
# ... (userinfo fetching logic if needed) ...
|
|
736
|
-
|
|
737
|
-
original_mcp_client_details = provider.state_mapping.pop(
|
|
768
|
+
|
|
769
|
+
original_mcp_client_details = provider.state_mapping.pop(
|
|
770
|
+
idp_state, None
|
|
771
|
+
) # Use state from IdP
|
|
738
772
|
if not original_mcp_client_details:
|
|
739
|
-
return RedirectResponse(
|
|
773
|
+
return RedirectResponse("/auth-error?error=invalid_idp_state")
|
|
740
774
|
|
|
741
775
|
original_mcp_client_id = original_mcp_client_details["client_id"]
|
|
742
|
-
original_mcp_redirect_uri = original_mcp_client_details[
|
|
776
|
+
original_mcp_redirect_uri = original_mcp_client_details[
|
|
777
|
+
"redirect_uri"
|
|
778
|
+
] # MCP client's redirect_uri
|
|
743
779
|
original_code_challenge = original_mcp_client_details["code_challenge"]
|
|
744
|
-
original_code_challenge_method = original_mcp_client_details[
|
|
745
|
-
|
|
780
|
+
original_code_challenge_method = original_mcp_client_details[
|
|
781
|
+
"code_challenge_method"
|
|
782
|
+
]
|
|
783
|
+
|
|
746
784
|
requested_scopes_for_mcp_server_str = original_mcp_client_details["scopes"]
|
|
747
|
-
mcp_client_original_state_to_pass_back = original_mcp_client_details.get(
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
785
|
+
mcp_client_original_state_to_pass_back = original_mcp_client_details.get(
|
|
786
|
+
"mcp_client_original_state"
|
|
787
|
+
)
|
|
788
|
+
original_redirect_uri_provided_explicitly = original_mcp_client_details[
|
|
789
|
+
"redirect_uri_provided_explicitly"
|
|
790
|
+
]
|
|
791
|
+
|
|
792
|
+
mcp_client = await provider.get_client(
|
|
793
|
+
original_mcp_client_id
|
|
794
|
+
) # Renamed for clarity
|
|
751
795
|
if not mcp_client:
|
|
752
|
-
return RedirectResponse(
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
final_scopes_for_mcp_auth_code =
|
|
759
|
-
|
|
796
|
+
return RedirectResponse(
|
|
797
|
+
"/auth-error?error=mcp_client_not_found_post_callback"
|
|
798
|
+
)
|
|
799
|
+
|
|
800
|
+
final_scopes_for_mcp_auth_code: list[str]
|
|
801
|
+
if requested_scopes_for_mcp_server_str: # Scopes requested by MCP client
|
|
802
|
+
final_scopes_for_mcp_auth_code = (
|
|
803
|
+
requested_scopes_for_mcp_server_str.split()
|
|
804
|
+
)
|
|
805
|
+
else: # Default to client's registered scopes if none explicitly requested
|
|
806
|
+
final_scopes_for_mcp_auth_code = (
|
|
807
|
+
mcp_client.scope.split() if mcp_client.scope else []
|
|
808
|
+
)
|
|
809
|
+
|
|
760
810
|
# This is the auth code our GolfMCP server issues to the MCP client
|
|
761
|
-
mcp_auth_code_str = str(uuid.uuid4())
|
|
762
|
-
|
|
811
|
+
mcp_auth_code_str = str(uuid.uuid4())
|
|
812
|
+
|
|
763
813
|
# Store the mapping from our mcp_auth_code_str to the provider_access_token (e.g., GitHub token)
|
|
764
814
|
# This will be retrieved when the MCP client exchanges mcp_auth_code_str for an MCP access token
|
|
765
|
-
provider.storage.store_auth_code_provider_token_mapping(
|
|
766
|
-
|
|
815
|
+
provider.storage.store_auth_code_provider_token_mapping(
|
|
816
|
+
mcp_auth_code_str, provider_access_token
|
|
817
|
+
)
|
|
818
|
+
|
|
767
819
|
# Create the AuthorizationCode object for our server
|
|
768
|
-
mcp_auth_code_obj = AuthorizationCode(
|
|
820
|
+
mcp_auth_code_obj = AuthorizationCode( # Renamed for clarity
|
|
769
821
|
code=mcp_auth_code_str,
|
|
770
822
|
client_id=mcp_client.client_id,
|
|
771
|
-
redirect_uri=original_mcp_redirect_uri,
|
|
772
|
-
scopes=final_scopes_for_mcp_auth_code,
|
|
773
|
-
expires_at=int(
|
|
823
|
+
redirect_uri=original_mcp_redirect_uri,
|
|
824
|
+
scopes=final_scopes_for_mcp_auth_code,
|
|
825
|
+
expires_at=int(
|
|
826
|
+
datetime.now().timestamp() + 600
|
|
827
|
+
), # 10 minutes, cast to int
|
|
774
828
|
redirect_uri_provided_explicitly=original_redirect_uri_provided_explicitly,
|
|
775
829
|
code_challenge=original_code_challenge,
|
|
776
|
-
code_challenge_method=original_code_challenge_method
|
|
830
|
+
code_challenge_method=original_code_challenge_method,
|
|
777
831
|
)
|
|
778
|
-
|
|
832
|
+
|
|
779
833
|
# Store our auth code object (without provider_token as an attribute)
|
|
780
834
|
provider.storage.store_auth_code(mcp_auth_code_str, mcp_auth_code_obj)
|
|
781
|
-
|
|
835
|
+
|
|
782
836
|
query_params_for_mcp_client = {
|
|
783
|
-
"code": mcp_auth_code_str
|
|
837
|
+
"code": mcp_auth_code_str # Send our generated auth code to the MCP client
|
|
784
838
|
}
|
|
785
839
|
if mcp_client_original_state_to_pass_back:
|
|
786
|
-
query_params_for_mcp_client["state"] =
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
840
|
+
query_params_for_mcp_client["state"] = (
|
|
841
|
+
mcp_client_original_state_to_pass_back
|
|
842
|
+
)
|
|
843
|
+
|
|
844
|
+
import urllib.parse # Ensure it's imported here too
|
|
845
|
+
|
|
846
|
+
final_query_for_mcp_client = urllib.parse.urlencode(
|
|
847
|
+
query_params_for_mcp_client
|
|
848
|
+
)
|
|
849
|
+
final_redirect_to_mcp_client = (
|
|
850
|
+
f"{original_mcp_redirect_uri}?{final_query_for_mcp_client}"
|
|
851
|
+
)
|
|
852
|
+
|
|
792
853
|
return RedirectResponse(final_redirect_to_mcp_client)
|
|
793
|
-
|
|
794
|
-
except Exception
|
|
854
|
+
|
|
855
|
+
except Exception:
|
|
795
856
|
# Avoid sending raw exception details to the client for security
|
|
796
|
-
return RedirectResponse(
|
|
797
|
-
|
|
798
|
-
|
|
857
|
+
return RedirectResponse(
|
|
858
|
+
"/auth-error?error=callback_processing_failed&detail=internal_server_error"
|
|
859
|
+
)
|
|
860
|
+
|
|
861
|
+
return handle_callback
|