openapi-mcp-gateway 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ from .gateway import Gateway
2
+ from .settings import CORSConfig, GatewayConfig, LoggingConfig, ServerConfig, StoreConfig
3
+ from .stores import MemoryTokenStore, TokenStore
4
+
5
+
6
# Public names re-exported at the package root (sorted alphabetically).
__all__ = [
    'CORSConfig',
    'Gateway',
    'GatewayConfig',
    'LoggingConfig',
    'MemoryTokenStore',
    'ServerConfig',
    'StoreConfig',
    'TokenStore',
]
@@ -0,0 +1,15 @@
1
+ from .detector import DetectedOAuthFlow, detect_oauth_flows, detect_primary_oauth_flow
2
+ from .provider import GatewayOAuthProvider
3
+ from .resolver import AuthResolver, NullAuthResolver, OAuthAuthResolver, StaticAuthResolver
4
+
5
+
6
# Public names re-exported by the auth package (sorted alphabetically).
__all__ = [
    'AuthResolver',
    'DetectedOAuthFlow',
    'GatewayOAuthProvider',
    'NullAuthResolver',
    'OAuthAuthResolver',
    'StaticAuthResolver',
    'detect_oauth_flows',
    'detect_primary_oauth_flow',
]
@@ -0,0 +1,64 @@
1
+ import typing
2
+
3
+ import pydantic
4
+
5
+ from ..openapi import OpenAPISpec
6
+
7
+
8
class DetectedOAuthFlow(pydantic.BaseModel):
    """OAuth2 flow details taken from one ``components.securitySchemes`` entry."""

    # Grant type of the flow; only these two flows are modelled.
    flow_type: typing.Literal['authorization_code', 'client_credentials']
    # Browser authorization endpoint; left as ``None`` unless the spec supplies one.
    authorization_url: str | None = None
    # Token endpoint used to exchange grants for access tokens.
    token_url: str
    # Scope name -> human-readable description, as listed in the spec.
    scopes: dict[str, str] = pydantic.Field(default_factory=dict)
15
+
16
+
17
def detect_oauth_flows(spec: OpenAPISpec) -> list[DetectedOAuthFlow]:
    """Return every OAuth2 flow advertised under ``securitySchemes``.

    Only ``authorizationCode`` and ``clientCredentials`` flows are reported. For each
    ``oauth2`` scheme, an ``authorizationCode`` flow is listed before a
    ``clientCredentials`` flow. ``flows``, a flow object, or ``scopes`` written as
    ``null`` in the document are treated as absent/empty rather than crashing.

    Raises:
        KeyError: if a reported flow has no ``tokenUrl``.
    """
    flows: list[DetectedOAuthFlow] = []

    # (OpenAPI flow key, DetectedOAuthFlow.flow_type) in reporting order.
    supported_flows = (
        ('authorizationCode', 'authorization_code'),
        ('clientCredentials', 'client_credentials'),
    )

    for scheme in spec.security_schemes.values():
        if scheme.get('type') != 'oauth2':
            continue

        # ``flows: null`` (e.g. an empty YAML key) would otherwise break the lookups below.
        oauth_flows = scheme.get('flows') or {}

        for openapi_key, flow_type in supported_flows:
            flow_data = oauth_flows.get(openapi_key)
            if flow_data is None:
                continue
            flows.append(
                DetectedOAuthFlow(
                    flow_type=flow_type,
                    # Only the authorization-code flow carries an authorizationUrl.
                    authorization_url=(
                        flow_data.get('authorizationUrl') if flow_type == 'authorization_code' else None
                    ),
                    token_url=flow_data['tokenUrl'],
                    # ``scopes: null`` would fail pydantic's dict validation.
                    scopes=flow_data.get('scopes') or {},
                )
            )

    return flows
49
+
50
+
51
def detect_primary_oauth_flow(spec: OpenAPISpec) -> DetectedOAuthFlow | None:
    """Choose the single OAuth2 flow to use for ``spec``.

    The first ``authorization_code`` flow wins. Otherwise the first detected flow is
    returned, and ``None`` means the document advertises no supported OAuth2 flow.
    """
    candidates = detect_oauth_flows(spec)
    preferred = next(
        (candidate for candidate in candidates if candidate.flow_type == 'authorization_code'),
        None,
    )
    if preferred is not None:
        return preferred
    return candidates[0] if candidates else None
@@ -0,0 +1,412 @@
1
+ import logging
2
+ import secrets
3
+ import time
4
+ import typing
5
+ import urllib.parse
6
+
7
+ import httpx
8
+ from mcp.server.auth.middleware.auth_context import get_access_token
9
+ from mcp.server.auth.provider import (
10
+ AccessToken,
11
+ AuthorizationCode,
12
+ AuthorizationParams,
13
+ RefreshToken,
14
+ TokenError,
15
+ construct_redirect_uri,
16
+ )
17
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
18
+ from starlette.exceptions import HTTPException
19
+
20
+ from ..stores.base import TokenStore
21
+
22
+
23
logger = logging.getLogger(__name__)

# Lifetimes, in seconds, of the tokens this gateway mints for MCP clients.
MCP_ACCESS_TOKEN_TTL = 60 * 60  # one hour
MCP_REFRESH_TOKEN_TTL = 24 * 60 * 60  # one day
# Scopes recorded on every MCP authorization code issued by the gateway.
MCP_SCOPES = ['api']
28
+
29
+
30
class GatewayOAuthProvider:
    """Implement MCP's OAuth server provider API against an upstream OAuth2 API.

    Registers MCP clients, forwards browser authorization to the upstream IdP,
    exchanges grants with ``upstream_token_url``, and keeps MCP ↔ upstream token
    mappings inside ``store``.

    ``store`` namespaces written by this class:

    * ``mcp_client`` — registered MCP clients, keyed by ``client_id``.
    * ``mcp_auth_state`` — pending browser authorizations, keyed by ``state`` (900 s TTL).
    * ``mcp_auth_code`` — single-use MCP authorization codes (300 s TTL).
    * ``mcp_access_token`` / ``mcp_refresh_token`` — MCP tokens minted by the gateway.
    * ``api_access_token`` — upstream access tokens, expiring with upstream ``expires_in``.
    """

    def __init__(
        self,
        store: TokenStore,
        upstream_auth_url: str,
        upstream_token_url: str,
        client_id: str,
        client_secret: str,
        callback_url: str,
        scopes: list[str] | None = None,
        prefix: str = 'gateway',
    ) -> None:
        """Bind token ``store``, upstream endpoints, IdP client credentials, and callback.

        Args:
            store: Backend holding clients, states, codes, tokens, and their mappings.
            upstream_auth_url: Upstream authorization endpoint the browser is sent to.
            upstream_token_url: Upstream token endpoint for code and refresh grants.
            client_id: The gateway's client id at the upstream IdP.
            client_secret: The gateway's client secret at the upstream IdP.
            callback_url: Gateway URL sent upstream as ``redirect_uri``.
            scopes: Upstream scopes to request; ``None``/empty sends no ``scope``.
            prefix: Label included in log lines.
        """
        self.store = store
        self.upstream_auth_url = upstream_auth_url
        self.upstream_token_url = upstream_token_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.callback_url = callback_url
        self.scopes = scopes or []
        self._prefix = prefix

    # ── MCP SDK OAuthAuthorizationServerProvider interface ──

    async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
        """Load MCP OAuth client registration state by ``client_id``; ``None`` if unknown."""
        data = await self.store.get('mcp_client', client_id)
        if data:
            return OAuthClientInformationFull(**data)
        return None

    async def register_client(self, client_info: OAuthClientInformationFull) -> None:
        """Persist ``client_info`` for subsequent authorize/token exchanges.

        Raises:
            ValueError: if ``client_info`` has no ``client_id``.
        """
        client_id = self._require_client_id(client_info)
        await self.store.set(
            'mcp_client',
            client_id,
            client_info.model_dump(exclude_none=True, mode='json'),
        )
        logger.info('Registered MCP client: client_id=%s prefix=%s', client_id, self._prefix)

    async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
        """Store the MCP client's redirect/PKCE details under ``state`` and return the upstream authorize URL.

        The stored entry expires after 900 seconds and is consumed by
        ``handle_upstream_callback``.
        """
        client_id = self._require_client_id(client)
        state = params.state or secrets.token_hex(16)

        state_data = {
            'redirect_uri': str(params.redirect_uri),
            'code_challenge': params.code_challenge,
            'redirect_uri_provided_explicitly': params.redirect_uri_provided_explicitly,
            'client_id': client_id,
        }
        await self.store.set('mcp_auth_state', state, state_data, ttl=900)

        query_params = {
            'client_id': self.client_id,
            'redirect_uri': self.callback_url,
            'state': state,
            'response_type': 'code',
        }
        if self.scopes:
            query_params['scope'] = ' '.join(self.scopes)

        logger.info('Upstream OAuth authorize: scopes=%s', self.scopes)
        return f'{self.upstream_auth_url}?{urllib.parse.urlencode(query_params)}'

    async def load_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: str
    ) -> AuthorizationCode | None:
        """Hydrate ``AuthorizationCode`` when it exists and belongs to ``client``."""
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_auth_code', authorization_code)
        if data and data['client_id'] == client_id:
            return AuthorizationCode(**data)
        return None

    async def exchange_authorization_code(
        self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
    ) -> OAuthToken:
        """Exchange an MCP authorization code for MCP access/refresh tokens.

        The code is single-use. It is deleted from ``store`` once its upstream token
        mappings have been read, so replaying the same code is rejected.

        Raises:
            TokenError: ``invalid_grant`` when the code is unknown, belongs to another
                client, or has no upstream access token attached.
        """
        client_id = self._require_client_id(client)
        code = authorization_code.code
        data = await self.store.get('mcp_auth_code', code)

        if not data or data['client_id'] != client_id:
            logger.warning('OAuth code exchange rejected: reason=invalid_code client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='Invalid authorization code')

        api_access_token = await self.store.get_mapping(
            'mcp_auth_code', code, 'api_access_token'
        ) or await self.store.get_mapping('client', client_id, 'api_access_token')
        api_refresh_token = await self.store.get_mapping('mcp_auth_code', code, 'api_refresh_token')

        # Authorization codes must be single-use (RFC 6749 §4.1.2). Consume the code
        # before minting tokens so a replay within its TTL fails the check above.
        await self.store.delete('mcp_auth_code', code)

        if not api_access_token:
            logger.warning('OAuth code exchange rejected: reason=no_upstream_token client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='No upstream API token found')

        return await self._issue_mcp_token(
            client_id=client_id,
            scopes=authorization_code.scopes,
            api_access_token=api_access_token,
            api_refresh_token=api_refresh_token,
        )

    async def load_access_token(self, token: str) -> AccessToken | None:
        """Return the MCP access token payload when ``token`` is present in ``store``."""
        data = await self.store.get('mcp_access_token', token)
        if data:
            return AccessToken(**data)
        return None

    async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
        """Return MCP refresh metadata when it exists and belongs to ``client``."""
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_refresh_token', refresh_token)
        if data and data['client_id'] == client_id:
            return RefreshToken(**data)
        return None

    async def exchange_refresh_token(
        self,
        client: OAuthClientInformationFull,
        refresh_token: RefreshToken,
        scopes: list[str],
    ) -> OAuthToken:
        """Rotate MCP tokens while preserving upstream refresh semantics.

        If the mapped upstream access token has expired, it is refreshed upstream
        first. The old MCP access/refresh pair is deleted after new tokens are minted.

        Raises:
            TokenError: ``invalid_grant`` for an unknown/foreign refresh token, a lost
                upstream mapping, an expired upstream token with no refresh token, or a
                failed upstream refresh.
        """
        client_id = self._require_client_id(client)
        data = await self.store.get('mcp_refresh_token', refresh_token.token)

        if not data or data['client_id'] != client_id:
            logger.warning('OAuth refresh rejected: reason=invalid_refresh_token client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='Invalid refresh token')

        api_access_token = await self.store.get_mapping('mcp_refresh_token', refresh_token.token, 'api_access_token')
        api_refresh_token = await self.store.get_mapping('mcp_refresh_token', refresh_token.token, 'api_refresh_token')

        if not api_access_token:
            logger.warning('OAuth refresh rejected: reason=mapping_lost client_id=%s', client_id)
            raise TokenError(error='invalid_grant', error_description='Upstream token mapping lost')

        # Check if upstream access token is still valid
        if not await self.store.get('api_access_token', api_access_token):
            if not api_refresh_token:
                logger.warning('OAuth refresh rejected: reason=upstream_expired_no_refresh client_id=%s', client_id)
                raise TokenError(
                    error='invalid_grant',
                    error_description='Upstream token expired and no refresh token available; re-authenticate',
                )
            try:
                api_access_token, new_refresh, expires_in = await self._request_upstream_token(
                    {
                        'client_id': self.client_id,
                        'client_secret': self.client_secret,
                        'refresh_token': api_refresh_token,
                        'grant_type': 'refresh_token',
                    }
                )
            except HTTPException as exc:
                # The SDK token endpoint only turns TokenError into an OAuth error
                # response. Surface an upstream refresh failure the same way so the
                # MCP client knows to re-authenticate.
                logger.warning('OAuth refresh rejected: reason=upstream_refresh_failed client_id=%s', client_id)
                raise TokenError(
                    error='invalid_grant',
                    error_description='Upstream token refresh failed; re-authenticate',
                ) from exc
            api_refresh_token = new_refresh or api_refresh_token
            await self._store_api_token(client_id, api_access_token, expires_in)

        new_token = await self._issue_mcp_token(
            client_id=client_id,
            scopes=scopes or data.get('scopes', []),
            api_access_token=api_access_token,
            api_refresh_token=api_refresh_token,
        )

        # Revoke old tokens
        old_access = await self.store.get_mapping('mcp_refresh_token', refresh_token.token, 'mcp_access_token')
        if old_access:
            await self.store.delete('mcp_access_token', old_access)
        await self.store.delete('mcp_refresh_token', refresh_token.token)

        return new_token

    async def revoke_token(self, token: AccessToken | RefreshToken) -> None:
        """Delete an MCP token and its paired access/refresh counterpart from ``store``."""
        kind = 'access' if isinstance(token, AccessToken) else 'refresh'
        if isinstance(token, AccessToken):
            paired = await self.store.get_mapping('mcp_access_token', token.token, 'mcp_refresh_token')
            if paired:
                await self.store.delete('mcp_refresh_token', paired)
            await self.store.delete('mcp_access_token', token.token)
        elif isinstance(token, RefreshToken):
            paired = await self.store.get_mapping('mcp_refresh_token', token.token, 'mcp_access_token')
            if paired:
                await self.store.delete('mcp_access_token', paired)
            await self.store.delete('mcp_refresh_token', token.token)
        logger.info('Revoked MCP %s token (prefix=%s)', kind, self._prefix)

    # ── Gateway-specific methods ──

    async def handle_upstream_callback(self, code: str, state: str) -> str:
        """Finish the browser redirect: swap upstream ``code`` for MCP authorization artifacts.

        Validates and consumes ``state``, exchanges tokens at ``upstream_token_url``,
        persists API credentials, builds an MCP authorization code, and returns the
        client redirect URI.

        Raises:
            HTTPException: 400 for an unknown/expired/already-used ``state`` or a
                failed upstream exchange.
        """
        state_data = await self.store.get('mcp_auth_state', state)
        if not state_data:
            logger.warning('OAuth callback rejected: reason=invalid_state state=%s', state)
            raise HTTPException(400, 'Invalid state parameter')

        # State is single-use. Consume it immediately so a replayed or concurrent
        # callback is rejected, and nothing stale is left if the exchange below fails.
        await self.store.delete('mcp_auth_state', state)

        redirect_uri = state_data['redirect_uri']
        code_challenge = state_data['code_challenge']
        redirect_uri_provided_explicitly = state_data['redirect_uri_provided_explicitly']
        client_id = state_data['client_id']

        # Exchange upstream code for tokens
        api_access_token, api_refresh_token, expires_in = await self._request_upstream_token(
            {
                'client_id': self.client_id,
                'client_secret': self.client_secret,
                'code': code,
                'redirect_uri': self.callback_url,
                'grant_type': 'authorization_code',
            }
        )

        # Create MCP auth code
        mcp_auth_code = f'mcp_{secrets.token_hex(16)}'
        await self.store.set(
            'mcp_auth_code',
            mcp_auth_code,
            {
                'code': mcp_auth_code,
                'client_id': client_id,
                'redirect_uri': redirect_uri,
                'redirect_uri_provided_explicitly': redirect_uri_provided_explicitly,
                'expires_at': time.time() + 300,
                'scopes': MCP_SCOPES,
                'code_challenge': code_challenge,
            },
            ttl=300,
        )

        # Store upstream tokens
        await self._store_api_token(client_id, api_access_token, expires_in)

        # Create mappings
        await self.store.set_mapping('mcp_auth_code', mcp_auth_code, 'api_access_token', api_access_token, ttl=300)
        if api_refresh_token:
            await self.store.set_mapping(
                'mcp_auth_code', mcp_auth_code, 'api_refresh_token', api_refresh_token, ttl=300
            )
        await self.store.set_mapping('client', client_id, 'api_access_token', api_access_token)

        logger.info(
            'Upstream OAuth callback handled: client_id=%s expires_in=%s refresh=%s',
            client_id,
            expires_in,
            bool(api_refresh_token),
        )
        return construct_redirect_uri(redirect_uri, code=mcp_auth_code, state=state)

    async def get_api_access_token(self) -> str | None:
        """Map the active MCP access token (request context) to the upstream bearer secret."""
        mcp_access_token = get_access_token()
        if not mcp_access_token:
            return None
        return await self.store.get_mapping('mcp_access_token', mcp_access_token.token, 'api_access_token')

    # ── Private helpers ──

    @staticmethod
    def _require_client_id(client: OAuthClientInformationFull) -> str:
        """Return ``client.client_id`` or raise ``ValueError``."""
        if not client.client_id:
            raise ValueError('client_id is required')
        return client.client_id

    async def _issue_mcp_token(
        self,
        client_id: str,
        scopes: list[str],
        api_access_token: str,
        api_refresh_token: str | None,
    ) -> OAuthToken:
        """Mint MCP access/refresh tokens and map them to upstream API tokens."""
        mcp_access = f'mcp_{secrets.token_hex(32)}'
        mcp_refresh = f'mcp_refresh_{secrets.token_hex(32)}'
        now = int(time.time())

        await self.store.set(
            'mcp_access_token',
            mcp_access,
            {
                'token': mcp_access,
                'client_id': client_id,
                'scopes': scopes,
                'expires_at': now + MCP_ACCESS_TOKEN_TTL,
            },
            ttl=MCP_ACCESS_TOKEN_TTL,
        )

        await self.store.set(
            'mcp_refresh_token',
            mcp_refresh,
            {
                'token': mcp_refresh,
                'client_id': client_id,
                'scopes': scopes,
                'expires_at': now + MCP_REFRESH_TOKEN_TTL,
            },
            ttl=MCP_REFRESH_TOKEN_TTL,
        )

        # mcp_access → api_access (for tool calls)
        await self.store.set_mapping(
            'mcp_access_token', mcp_access, 'api_access_token', api_access_token, ttl=MCP_ACCESS_TOKEN_TTL
        )
        # mcp_refresh → api_access (for refresh chain)
        await self.store.set_mapping(
            'mcp_refresh_token', mcp_refresh, 'api_access_token', api_access_token, ttl=MCP_REFRESH_TOKEN_TTL
        )
        if api_refresh_token:
            await self.store.set_mapping(
                'mcp_refresh_token', mcp_refresh, 'api_refresh_token', api_refresh_token, ttl=MCP_REFRESH_TOKEN_TTL
            )
        # Pair access ↔ refresh for revoke lookup
        await self.store.set_mapping(
            'mcp_access_token', mcp_access, 'mcp_refresh_token', mcp_refresh, ttl=MCP_ACCESS_TOKEN_TTL
        )
        await self.store.set_mapping(
            'mcp_refresh_token', mcp_refresh, 'mcp_access_token', mcp_access, ttl=MCP_REFRESH_TOKEN_TTL
        )

        return OAuthToken(
            access_token=mcp_access,
            refresh_token=mcp_refresh,
            expires_in=MCP_ACCESS_TOKEN_TTL,
        )

    async def _store_api_token(self, client_id: str, token: str, expires_in: int) -> None:
        """Persist upstream access token metadata under ``api_access_token`` with TTL ``expires_in``."""
        await self.store.set(
            'api_access_token',
            token,
            {
                'token': token,
                'client_id': client_id,
                'expires_at': int(time.time()) + expires_in,
            },
            ttl=expires_in,
        )

    @staticmethod
    def _coerce_expires_in(value: typing.Any, default: int = 3600) -> int:
        """Return upstream ``expires_in`` as positive whole seconds, else ``default``.

        RFC 6749 §5.1 makes ``expires_in`` an optional number, but some IdPs send a
        numeric string or ``null``. Either would break the TTL arithmetic in
        ``_store_api_token``.
        """
        if isinstance(value, bool):
            return default
        try:
            seconds = int(value)
        except (TypeError, ValueError, OverflowError):
            return default
        return seconds if seconds > 0 else default

    async def _request_upstream_token(self, request_data: dict[str, typing.Any]) -> tuple[str, str | None, int]:
        """POST ``request_data`` to upstream token URL; return ``(access, refresh | None, expires_in)``.

        ``expires_in`` falls back to 3600 seconds when upstream omits it or sends a
        value that is not a positive integer.

        Raises:
            HTTPException: 400 when upstream answers non-200, returns a body that is
                not a JSON object, or omits ``access_token``.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.upstream_token_url,
                data=request_data,
                headers={'Accept': 'application/json'},
            )

        if response.status_code != 200:
            logger.warning(
                'Upstream token exchange failed: status=%d url=%s',
                response.status_code,
                self.upstream_token_url,
            )
            raise HTTPException(400, f'Upstream token exchange failed: {response.text}')

        try:
            data = response.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            data = None
        if not isinstance(data, dict):
            # Without this guard, a 200 with an HTML/form body or a non-object JSON
            # value would escape as an unhandled error (HTTP 500).
            logger.warning('Upstream token exchange returned a non-JSON-object body: url=%s', self.upstream_token_url)
            raise HTTPException(400, 'Upstream returned an invalid token response')

        access_token = data.get('access_token')
        if not access_token:
            logger.warning('Upstream token exchange returned no access_token: url=%s', self.upstream_token_url)
            raise HTTPException(400, 'Upstream returned no access_token')

        logger.info(
            'Upstream token response: granted_scope=%r expires_in=%s', data.get('scope'), data.get('expires_in')
        )
        return access_token, data.get('refresh_token'), self._coerce_expires_in(data.get('expires_in'))
@@ -0,0 +1,51 @@
1
+ import abc
2
+ import typing
3
+
4
+ from mcp.server.fastmcp import Context
5
+
6
+
7
class AuthResolver(abc.ABC):
    """Abstract strategy that produces the ``Authorization`` value for upstream HTTP calls."""

    @abc.abstractmethod
    async def resolve(self, ctx: Context) -> str | None:
        """Compute the upstream header value for the request described by ``ctx``.

        Concrete resolvers usually return a ``Bearer ...`` string. They return
        ``None`` when no credentials are available.
        """
16
+
17
+
18
class NullAuthResolver(AuthResolver):
    """No-op resolver for public upstream APIs that need no credentials."""

    async def resolve(self, ctx: Context) -> str | None:
        """Never supply an ``Authorization`` value."""
        del ctx  # unused: the answer does not depend on the request
        return None
24
+
25
+
26
class StaticAuthResolver(AuthResolver):
    """Resolver that returns one header value (or raw token) fixed at gateway start-up."""

    def __init__(self, header_value: str) -> None:
        """Keep ``header_value`` verbatim; ``resolve`` returns it unchanged."""
        self._header_value = header_value

    async def resolve(self, ctx: Context) -> str | None:
        """Return the configured value, independent of ``ctx``."""
        configured = self._header_value
        return configured
36
+
37
+
38
class OAuthAuthResolver(AuthResolver):
    """Translate the caller's MCP bearer token into the mapped upstream bearer token."""

    def __init__(self, provider: typing.Any) -> None:
        """Hold the ``GatewayOAuthProvider`` used for token lookups.

        The parameter is typed ``Any`` rather than ``GatewayOAuthProvider`` to avoid
        a circular import.
        """
        self._provider = provider

    async def resolve(self, ctx: Context) -> str | None:
        """Return ``Bearer <upstream token>`` for the current MCP access token, or ``None``."""
        upstream_token = await self._provider.get_api_access_token()
        return f'Bearer {upstream_token}' if upstream_token else None