openapi-mcp-gateway 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openapi_mcp_gateway/__init__.py +15 -0
- openapi_mcp_gateway/auth/__init__.py +15 -0
- openapi_mcp_gateway/auth/detector.py +64 -0
- openapi_mcp_gateway/auth/provider.py +412 -0
- openapi_mcp_gateway/auth/resolver.py +51 -0
- openapi_mcp_gateway/cli.py +289 -0
- openapi_mcp_gateway/client.py +98 -0
- openapi_mcp_gateway/gateway.py +477 -0
- openapi_mcp_gateway/generator.py +206 -0
- openapi_mcp_gateway/logger.py +111 -0
- openapi_mcp_gateway/openapi.py +260 -0
- openapi_mcp_gateway/policy.py +51 -0
- openapi_mcp_gateway/settings.py +193 -0
- openapi_mcp_gateway/stores/__init__.py +28 -0
- openapi_mcp_gateway/stores/base.py +58 -0
- openapi_mcp_gateway/stores/memory.py +56 -0
- openapi_mcp_gateway/stores/redis.py +62 -0
- openapi_mcp_gateway-0.1.0.dist-info/METADATA +295 -0
- openapi_mcp_gateway-0.1.0.dist-info/RECORD +22 -0
- openapi_mcp_gateway-0.1.0.dist-info/WHEEL +4 -0
- openapi_mcp_gateway-0.1.0.dist-info/entry_points.txt +2 -0
- openapi_mcp_gateway-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .gateway import Gateway
|
|
2
|
+
from .settings import CORSConfig, GatewayConfig, LoggingConfig, ServerConfig, StoreConfig
|
|
3
|
+
from .stores import MemoryTokenStore, TokenStore
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Public API of the package: configuration models, the gateway, and token stores.
__all__ = [
    'CORSConfig', 'Gateway', 'GatewayConfig', 'LoggingConfig',
    'MemoryTokenStore', 'ServerConfig', 'StoreConfig', 'TokenStore',
]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .detector import DetectedOAuthFlow, detect_oauth_flows, detect_primary_oauth_flow
|
|
2
|
+
from .provider import GatewayOAuthProvider
|
|
3
|
+
from .resolver import AuthResolver, NullAuthResolver, OAuthAuthResolver, StaticAuthResolver
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Public API of the auth subpackage: flow detection, the OAuth provider, and auth resolvers.
__all__ = [
    'AuthResolver', 'DetectedOAuthFlow', 'GatewayOAuthProvider', 'NullAuthResolver',
    'OAuthAuthResolver', 'StaticAuthResolver', 'detect_oauth_flows', 'detect_primary_oauth_flow',
]
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import pydantic
|
|
4
|
+
|
|
5
|
+
from ..openapi import OpenAPISpec
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DetectedOAuthFlow(pydantic.BaseModel):
    """OAuth2 flow metadata taken from an OpenAPI ``components.securitySchemes`` entry.

    Only the handful of fields the gateway uses are modelled.
    """

    # Which kind of OAuth2 flow this is, in snake_case.
    flow_type: typing.Literal['authorization_code', 'client_credentials']
    # The flow's ``authorizationUrl``; stays None when the flow does not provide one.
    authorization_url: typing.Optional[str] = pydantic.Field(default=None)
    # The flow's ``tokenUrl``.
    token_url: str
    # The flow's ``scopes`` mapping, copied as-is; empty when the spec lists none.
    scopes: dict[str, str] = pydantic.Field(default_factory=dict)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def detect_oauth_flows(spec: OpenAPISpec) -> list[DetectedOAuthFlow]:
    """Return every OAuth2 flow advertised under ``securitySchemes``.

    Only ``authorizationCode`` and ``clientCredentials`` flows of ``oauth2``
    schemes are recognised; every other scheme type and flow kind is ignored.
    Results follow scheme order, with ``authorizationCode`` listed before
    ``clientCredentials`` within a single scheme.

    Args:
        spec: Parsed OpenAPI document exposing a ``security_schemes`` mapping.

    Returns:
        The detected flows, possibly empty.

    Raises:
        KeyError: if a recognised flow has no ``tokenUrl``.
    """
    flows: list[DetectedOAuthFlow] = []

    for scheme in spec.security_schemes.values():
        if scheme.get('type') != 'oauth2':
            continue

        # Keys written with no value in YAML (``flows:``, ``scopes:``) load as None,
        # which ``.get(key, default)`` would pass straight through; ``or {}`` covers both
        # the missing and the null case.
        oauth_flows = scheme.get('flows') or {}

        if 'authorizationCode' in oauth_flows:
            flow_data = oauth_flows['authorizationCode'] or {}
            flows.append(
                DetectedOAuthFlow(
                    flow_type='authorization_code',
                    authorization_url=flow_data.get('authorizationUrl'),
                    token_url=flow_data['tokenUrl'],
                    scopes=flow_data.get('scopes') or {},
                )
            )

        if 'clientCredentials' in oauth_flows:
            flow_data = oauth_flows['clientCredentials'] or {}
            flows.append(
                DetectedOAuthFlow(
                    flow_type='client_credentials',
                    token_url=flow_data['tokenUrl'],
                    scopes=flow_data.get('scopes') or {},
                )
            )

    return flows
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def detect_primary_oauth_flow(spec: OpenAPISpec) -> DetectedOAuthFlow | None:
    """Choose the single OAuth2 flow to use for ``spec``.

    The first ``authorization_code`` flow wins. If there is none, the first
    detected flow of any kind is returned. ``None`` means the document
    advertises no OAuth2 flows at all.
    """
    candidates = detect_oauth_flows(spec)
    if not candidates:
        return None
    return next(
        (candidate for candidate in candidates if candidate.flow_type == 'authorization_code'),
        candidates[0],
    )
|
|
@@ -0,0 +1,412 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import secrets
|
|
3
|
+
import time
|
|
4
|
+
import typing
|
|
5
|
+
import urllib.parse
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
from mcp.server.auth.middleware.auth_context import get_access_token
|
|
9
|
+
from mcp.server.auth.provider import (
|
|
10
|
+
AccessToken,
|
|
11
|
+
AuthorizationCode,
|
|
12
|
+
AuthorizationParams,
|
|
13
|
+
RefreshToken,
|
|
14
|
+
TokenError,
|
|
15
|
+
construct_redirect_uri,
|
|
16
|
+
)
|
|
17
|
+
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
|
18
|
+
from starlette.exceptions import HTTPException
|
|
19
|
+
|
|
20
|
+
from ..stores.base import TokenStore
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)

# Lifetimes, in seconds, of the tokens the gateway mints for MCP clients.
MCP_ACCESS_TOKEN_TTL = 60 * 60  # 1 hour
MCP_REFRESH_TOKEN_TTL = 24 * 60 * 60  # 24 hours
# Scopes attached to every MCP authorization code the gateway issues.
MCP_SCOPES = ['api']
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class GatewayOAuthProvider:
    """Implement MCP's OAuth server provider API against an upstream OAuth2 API.

    Registers MCP clients, forwards browser authorization to the upstream IdP,
    exchanges grants with ``upstream_token_url``, and keeps MCP <-> upstream token
    mappings inside ``store``.

    ``store`` namespaces written by this class:

    * ``mcp_client`` -- registered MCP client metadata, keyed by ``client_id``.
    * ``mcp_auth_state`` -- pending authorize requests, keyed by a gateway-generated
      upstream ``state`` (900 s TTL).
    * ``mcp_auth_code`` -- single-use MCP authorization codes (300 s TTL).
    * ``mcp_access_token`` / ``mcp_refresh_token`` -- MCP tokens minted by the gateway.
    * ``api_access_token`` -- upstream access tokens, expiring with upstream ``expires_in``.
    * ``client`` -- the most recent upstream access token per MCP ``client_id``.
    """

    # Lifetime (seconds) assumed when the upstream token response omits ``expires_in``
    # or returns something that is not a positive integer.
    _DEFAULT_UPSTREAM_EXPIRES_IN = 3600

    def __init__(
        self,
        store: TokenStore,
        upstream_auth_url: str,
        upstream_token_url: str,
        client_id: str,
        client_secret: str,
        callback_url: str,
        scopes: list[str] | None = None,
        prefix: str = 'gateway',
    ) -> None:
        """Bind token ``store``, upstream endpoints, IdP client credentials, and callback.

        Args:
            store: Backend for client registrations, auth state, codes and tokens.
            upstream_auth_url: Upstream authorization endpoint the browser is sent to.
            upstream_token_url: Upstream token endpoint used for code and refresh grants.
            client_id: The gateway's client id at the upstream IdP.
            client_secret: The gateway's client secret at the upstream IdP.
            callback_url: Gateway URL sent as ``redirect_uri`` in both the upstream
                authorize request and the upstream code exchange.
            scopes: Upstream scopes to request; omitted from the authorize URL when empty.
            prefix: Identifier included in this provider's log messages.
        """
        self.store = store
        self.upstream_auth_url = upstream_auth_url
        self.upstream_token_url = upstream_token_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.callback_url = callback_url
        self.scopes = scopes or []
        self._prefix = prefix

    # ── MCP SDK OAuthAuthorizationServerProvider interface ──

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Load MCP OAuth client registration state by ``client_id``; ``None`` if unknown."""
        data = await self.store.get('mcp_client', client_id)
        if data:
            return OAuthClientInformationFull(**data)
        return None

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """Persist ``client_info`` for subsequent authorize/token exchanges.

        Raises:
            ValueError: if ``client_info`` has no ``client_id``.
        """
        if not client_info.client_id:
            raise ValueError('client_id is required')
        await self.store.set(
            'mcp_client',
            client_info.client_id,
            client_info.model_dump(exclude_none=True, mode='json'),
        )
        logger.info('Registered MCP client: client_id=%s prefix=%s', client_info.client_id, self._prefix)

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """Return the upstream authorize URL, stashing the MCP request under a fresh ``state``.

        The upstream ``state`` is always generated here. It keys ``mcp_auth_state``,
        so it must not be client-controlled; otherwise one client could overwrite
        another pending authorization (including its ``redirect_uri``) by reusing
        a known ``state``. The MCP client's own ``state`` is stored alongside and
        echoed back by :meth:`handle_upstream_callback`.
        """
        client_id = self._require_client_id(client)
        upstream_state = secrets.token_hex(32)

        state_data = {
            'redirect_uri': str(params.redirect_uri),
            'code_challenge': params.code_challenge,
            'redirect_uri_provided_explicitly': params.redirect_uri_provided_explicitly,
            'client_id': client_id,
            'client_state': params.state,
        }
        await self.store.set('mcp_auth_state', upstream_state, state_data, ttl=900)

        query_params = {
            'client_id': self.client_id,
            'redirect_uri': self.callback_url,
            'state': upstream_state,
            'response_type': 'code',
        }
        if self.scopes:
            query_params['scope'] = ' '.join(self.scopes)

        # Preserve any query string already present in the configured authorize URL.
        separator = '&' if '?' in self.upstream_auth_url else '?'
        logger.info('Upstream OAuth authorize: scopes=%s', self.scopes)
        return f'{self.upstream_auth_url}{separator}{urllib.parse.urlencode(query_params)}'

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCode | None:
        """Hydrate ``AuthorizationCode`` when it exists and belongs to ``client``."""
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_auth_code', authorization_code)
        if data and data['client_id'] == client_id:
            return AuthorizationCode(**data)
        return None

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
    ) -> OAuthToken:
        """Exchange an MCP authorization code for MCP access/refresh tokens.

        The code is deleted from ``store`` before tokens are minted, so each code
        can be redeemed at most once.

        Raises:
            TokenError: ``invalid_grant`` when the code is unknown, belongs to another
                client, or has no upstream token attached.
        """
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_auth_code', authorization_code.code)

        if not data or data['client_id'] != client_id:
            logger.warning('OAuth code exchange rejected: reason=invalid_code client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='Invalid authorization code')

        # NOTE(review): the per-client fallback returns whichever upstream token was most
        # recently obtained for this client_id; confirm client ids are never shared by users.
        api_access_token = await self.store.get_mapping(
            'mcp_auth_code', authorization_code.code, 'api_access_token'
        ) or await self.store.get_mapping('client', client_id, 'api_access_token')
        api_refresh_token = await self.store.get_mapping('mcp_auth_code', authorization_code.code, 'api_refresh_token')

        # Authorization codes are single use (RFC 6749 §4.1.2): invalidate before minting so
        # a replayed code fails the lookup above.
        await self.store.delete('mcp_auth_code', authorization_code.code)

        if not api_access_token:
            logger.warning('OAuth code exchange rejected: reason=no_upstream_token client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='No upstream API token found')

        return await self._issue_mcp_token(
            client_id=client_id,
            scopes=authorization_code.scopes,
            api_access_token=api_access_token,
            api_refresh_token=api_refresh_token,
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return MCP access token payload when ``token`` is present in ``store``."""
        data = await self.store.get('mcp_access_token', token)
        if data:
            return AccessToken(**data)
        return None

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
        """Return MCP refresh metadata scoped to ``client``."""
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_refresh_token', refresh_token)
        if data and data['client_id'] == client_id:
            return RefreshToken(**data)
        return None

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        """Rotate MCP tokens while preserving upstream refresh semantics.

        If the mapped upstream access token has expired from ``store``, the upstream
        refresh token is used to obtain a new one first. The old MCP access/refresh
        pair is deleted after the new pair is issued.

        Raises:
            TokenError: ``invalid_grant`` for unknown/foreign refresh tokens, lost
                mappings, or when the upstream token cannot be refreshed.
        """
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_refresh_token', refresh_token.token)

        if not data or data['client_id'] != client_id:
            logger.warning('OAuth refresh rejected: reason=invalid_refresh_token client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='Invalid refresh token')

        api_access_token = await self.store.get_mapping('mcp_refresh_token', refresh_token.token, 'api_access_token')
        api_refresh_token = await self.store.get_mapping('mcp_refresh_token', refresh_token.token, 'api_refresh_token')

        if not api_access_token:
            logger.warning('OAuth refresh rejected: reason=mapping_lost client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='Upstream token mapping lost')

        # Check if upstream access token is still valid (its store entry carries the upstream TTL).
        if not await self.store.get('api_access_token', api_access_token):
            if not api_refresh_token:
                logger.warning('OAuth refresh rejected: reason=upstream_expired_no_refresh client_id=%s', client_id)
                raise TokenError(
                    error='invalid_grant',
                    error_description='Upstream token expired and no refresh token available; re-authenticate',
                )
            try:
                api_access_token, new_refresh, expires_in = await self._request_upstream_token(
                    {
                        'client_id': self.client_id,
                        'client_secret': self.client_secret,
                        'refresh_token': api_refresh_token,
                        'grant_type': 'refresh_token',
                    }
                )
            except HTTPException as exc:
                # Report as TokenError, consistent with every other refresh failure in this method.
                logger.warning('OAuth refresh rejected: reason=upstream_refresh_failed client_id=%s', client_id)
                raise TokenError(
                    error='invalid_grant',
                    error_description='Upstream token refresh failed; re-authenticate',
                ) from exc
            # Some IdPs rotate refresh tokens, others keep the original one.
            api_refresh_token = new_refresh or api_refresh_token
            await self._store_api_token(client_id, api_access_token, expires_in)

        new_token = await self._issue_mcp_token(
            client_id=client_id,
            scopes=scopes or data.get('scopes', []),
            api_access_token=api_access_token,
            api_refresh_token=api_refresh_token,
        )

        # Revoke old tokens
        old_access = await self.store.get_mapping('mcp_refresh_token', refresh_token.token, 'mcp_access_token')
        if old_access:
            await self.store.delete('mcp_access_token', old_access)
        await self.store.delete('mcp_refresh_token', refresh_token.token)

        return new_token

    async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
        """Delete MCP tokens plus paired refresh/access mappings from ``store``."""
        kind = 'access' if isinstance(token, AccessToken) else 'refresh'
        if isinstance(token, AccessToken):
            paired = await self.store.get_mapping('mcp_access_token', token.token, 'mcp_refresh_token')
            if paired:
                await self.store.delete('mcp_refresh_token', paired)
            await self.store.delete('mcp_access_token', token.token)
        elif isinstance(token, RefreshToken):
            paired = await self.store.get_mapping('mcp_refresh_token', token.token, 'mcp_access_token')
            if paired:
                await self.store.delete('mcp_access_token', paired)
            await self.store.delete('mcp_refresh_token', token.token)
        logger.info('Revoked MCP %s token (prefix=%s)', kind, self._prefix)

    # ── Gateway-specific methods ──

    async def handle_upstream_callback(self, code: str, state: str) -> str:
        """Finish the browser redirect: swap upstream ``code`` for MCP authorization artifacts.

        Validates and consumes ``state``, exchanges tokens at ``upstream_token_url``,
        persists API credentials, builds an MCP authorization code, and returns the
        MCP client's redirect URI carrying that code and the client's own ``state``.

        Raises:
            HTTPException: 400 for an unknown/expired ``state`` or a failed upstream exchange.
        """
        state_data = await self.store.get('mcp_auth_state', state)
        if not state_data:
            logger.warning('OAuth callback rejected: reason=invalid_state state=%s', state)
            raise HTTPException(400, 'Invalid state parameter')

        # Consume the state before doing anything else so a replayed or concurrent
        # callback with the same state is rejected, even if the exchange below fails.
        await self.store.delete('mcp_auth_state', state)

        redirect_uri = state_data['redirect_uri']
        code_challenge = state_data['code_challenge']
        redirect_uri_provided_explicitly = state_data['redirect_uri_provided_explicitly']
        client_id = state_data['client_id']
        # Entries written by older versions have no 'client_state': there the upstream
        # state *was* the client's state.
        client_state = state_data.get('client_state', state)

        # Exchange upstream code for tokens
        api_access_token, api_refresh_token, expires_in = await self._request_upstream_token(
            {
                'client_id': self.client_id,
                'client_secret': self.client_secret,
                'code': code,
                'redirect_uri': self.callback_url,
                'grant_type': 'authorization_code',
            }
        )

        # Create MCP auth code
        mcp_auth_code = f'mcp_{secrets.token_hex(16)}'
        await self.store.set(
            'mcp_auth_code',
            mcp_auth_code,
            {
                'code': mcp_auth_code,
                'client_id': client_id,
                'redirect_uri': redirect_uri,
                'redirect_uri_provided_explicitly': redirect_uri_provided_explicitly,
                'expires_at': time.time() + 300,
                # Copy so the stored payload never aliases the module-level list.
                'scopes': list(MCP_SCOPES),
                'code_challenge': code_challenge,
            },
            ttl=300,
        )

        # Store upstream tokens
        await self._store_api_token(client_id, api_access_token, expires_in)

        # Create mappings
        await self.store.set_mapping('mcp_auth_code', mcp_auth_code, 'api_access_token', api_access_token, ttl=300)
        if api_refresh_token:
            await self.store.set_mapping(
                'mcp_auth_code', mcp_auth_code, 'api_refresh_token', api_refresh_token, ttl=300
            )
        await self.store.set_mapping('client', client_id, 'api_access_token', api_access_token)

        logger.info(
            'Upstream OAuth callback handled: client_id=%s expires_in=%s refresh=%s',
            client_id,
            expires_in,
            bool(api_refresh_token),
        )
        return construct_redirect_uri(redirect_uri, code=mcp_auth_code, state=client_state)

    async def get_api_access_token(self) -> str | None:
        """Map the active MCP access token (request context) to the upstream bearer secret."""
        mcp_access_token = get_access_token()
        if not mcp_access_token:
            return None
        return await self.store.get_mapping('mcp_access_token', mcp_access_token.token, 'api_access_token')

    # ── Private helpers ──

    @staticmethod
    def _require_client_id(client: OAuthClientInformationFull) -> str:
        """Return ``client.client_id`` or raise ``ValueError``."""
        if not client.client_id:
            raise ValueError('client_id is required')
        return client.client_id

    async def _issue_mcp_token(
        self,
        client_id: str,
        scopes: list[str],
        api_access_token: str,
        api_refresh_token: str | None,
    ) -> OAuthToken:
        """Mint MCP access/refresh tokens and map them to upstream API tokens."""
        mcp_access = f'mcp_{secrets.token_hex(32)}'
        mcp_refresh = f'mcp_refresh_{secrets.token_hex(32)}'
        now = int(time.time())

        await self.store.set(
            'mcp_access_token',
            mcp_access,
            {
                'token': mcp_access,
                'client_id': client_id,
                'scopes': scopes,
                'expires_at': now + MCP_ACCESS_TOKEN_TTL,
            },
            ttl=MCP_ACCESS_TOKEN_TTL,
        )

        await self.store.set(
            'mcp_refresh_token',
            mcp_refresh,
            {
                'token': mcp_refresh,
                'client_id': client_id,
                'scopes': scopes,
                'expires_at': now + MCP_REFRESH_TOKEN_TTL,
            },
            ttl=MCP_REFRESH_TOKEN_TTL,
        )

        # mcp_access → api_access (for tool calls)
        await self.store.set_mapping(
            'mcp_access_token', mcp_access, 'api_access_token', api_access_token, ttl=MCP_ACCESS_TOKEN_TTL
        )
        # mcp_refresh → api_access (for refresh chain)
        await self.store.set_mapping(
            'mcp_refresh_token', mcp_refresh, 'api_access_token', api_access_token, ttl=MCP_REFRESH_TOKEN_TTL
        )
        if api_refresh_token:
            await self.store.set_mapping(
                'mcp_refresh_token', mcp_refresh, 'api_refresh_token', api_refresh_token, ttl=MCP_REFRESH_TOKEN_TTL
            )
        # Pair access ↔ refresh for revoke lookup
        await self.store.set_mapping(
            'mcp_access_token', mcp_access, 'mcp_refresh_token', mcp_refresh, ttl=MCP_ACCESS_TOKEN_TTL
        )
        await self.store.set_mapping(
            'mcp_refresh_token', mcp_refresh, 'mcp_access_token', mcp_access, ttl=MCP_REFRESH_TOKEN_TTL
        )

        return OAuthToken(
            access_token=mcp_access,
            refresh_token=mcp_refresh,
            expires_in=MCP_ACCESS_TOKEN_TTL,
        )

    async def _store_api_token(self, client_id: str, token: str, expires_in: int) -> None:
        """Persist upstream access token metadata under ``api_access_token`` with TTL ``expires_in``."""
        await self.store.set(
            'api_access_token',
            token,
            {
                'token': token,
                'client_id': client_id,
                'expires_at': int(time.time()) + expires_in,
            },
            ttl=expires_in,
        )

    @classmethod
    def _coerce_expires_in(cls, raw: typing.Any) -> int:
        """Return ``raw`` as a positive number of seconds, else the default lifetime.

        IdPs variously omit ``expires_in``, send it as a string, or send null/0;
        all of these would otherwise break TTL arithmetic or produce an invalid TTL.
        """
        try:
            seconds = int(raw)
        except (TypeError, ValueError, OverflowError):
            return cls._DEFAULT_UPSTREAM_EXPIRES_IN
        return seconds if seconds > 0 else cls._DEFAULT_UPSTREAM_EXPIRES_IN

    async def _request_upstream_token(self, request_data: dict[str, typing.Any]) -> tuple[str, str | None, int]:
        """POST ``request_data`` to the upstream token URL.

        Returns:
            ``(access_token, refresh_token | None, expires_in)``; ``expires_in`` is
            always a positive int.

        Raises:
            HTTPException: 400 on a non-200 response, a body that is not a JSON
                object, or a response without ``access_token``.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.upstream_token_url,
                data=request_data,
                headers={'Accept': 'application/json'},
            )

        if response.status_code != 200:
            logger.warning(
                'Upstream token exchange failed: status=%d url=%s',
                response.status_code,
                self.upstream_token_url,
            )
            raise HTTPException(400, f'Upstream token exchange failed: {response.text}')

        try:
            data = response.json()
        except ValueError:
            data = None
        if not isinstance(data, dict):
            logger.warning('Upstream token exchange returned a non-JSON-object body: url=%s', self.upstream_token_url)
            raise HTTPException(400, 'Upstream returned an invalid token response')

        access_token = data.get('access_token')
        if not access_token:
            logger.warning('Upstream token exchange returned no access_token: url=%s', self.upstream_token_url)
            raise HTTPException(400, 'Upstream returned no access_token')

        logger.info(
            'Upstream token response: granted_scope=%r expires_in=%s', data.get('scope'), data.get('expires_in')
        )
        return access_token, data.get('refresh_token'), self._coerce_expires_in(data.get('expires_in'))
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
from mcp.server.fastmcp import Context
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AuthResolver(abc.ABC):
    """Abstract strategy that produces the ``Authorization`` header for upstream HTTP calls."""

    @abc.abstractmethod
    async def resolve(self, ctx: Context) -> str | None:
        """Compute the header value for one upstream call.

        Args:
            ctx: The FastMCP request context of the current call.

        Returns:
            The header value, typically a ``Bearer`` string, or ``None`` to send
            the request unauthenticated.
        """
        ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class NullAuthResolver(AuthResolver):
    """Resolver for public upstream APIs: no credentials are ever attached."""

    async def resolve(self, ctx: Context) -> str | None:
        """Ignore ``ctx`` and report that no ``Authorization`` header should be sent."""
        header = None
        return header
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class StaticAuthResolver(AuthResolver):
    """Resolver that returns one header value fixed when the gateway is configured.

    The value is used verbatim, so it may be a full ``Authorization`` value or a
    raw token, whichever the caller supplied.
    """

    def __init__(self, header_value: str) -> None:
        """Remember ``header_value``; :meth:`resolve` hands it back unchanged."""
        self._header_value = header_value

    async def resolve(self, ctx: Context) -> str | None:
        """Return the stored header value, independent of ``ctx``."""
        configured = self._header_value
        return configured
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class OAuthAuthResolver(AuthResolver):
    """Resolver that swaps the caller's MCP bearer token for the mapped upstream API token."""

    def __init__(self, provider: typing.Any) -> None:
        """Hold the OAuth provider used for token lookups.

        ``provider`` is a ``GatewayOAuthProvider``; it is typed as ``Any`` to avoid
        a circular import.
        """
        self._provider = provider

    async def resolve(self, ctx: Context) -> str | None:
        """Return ``'Bearer <upstream token>'`` for the current caller, or ``None`` if unmapped.

        ``ctx`` is not consulted; the provider looks the token up from its own
        request-scoped access-token context.
        """
        upstream_token = await self._provider.get_api_access_token()
        return f'Bearer {upstream_token}' if upstream_token else None
|