amazon-ads-mcp 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amazon_ads_mcp/__init__.py +11 -0
- amazon_ads_mcp/auth/__init__.py +33 -0
- amazon_ads_mcp/auth/base.py +211 -0
- amazon_ads_mcp/auth/hooks.py +172 -0
- amazon_ads_mcp/auth/manager.py +791 -0
- amazon_ads_mcp/auth/oauth_state_store.py +277 -0
- amazon_ads_mcp/auth/providers/__init__.py +14 -0
- amazon_ads_mcp/auth/providers/direct.py +393 -0
- amazon_ads_mcp/auth/providers/example_auth0.py.example +216 -0
- amazon_ads_mcp/auth/providers/openbridge.py +512 -0
- amazon_ads_mcp/auth/registry.py +146 -0
- amazon_ads_mcp/auth/secure_token_store.py +297 -0
- amazon_ads_mcp/auth/token_store.py +723 -0
- amazon_ads_mcp/config/__init__.py +5 -0
- amazon_ads_mcp/config/sampling.py +111 -0
- amazon_ads_mcp/config/settings.py +366 -0
- amazon_ads_mcp/exceptions.py +314 -0
- amazon_ads_mcp/middleware/__init__.py +11 -0
- amazon_ads_mcp/middleware/authentication.py +1474 -0
- amazon_ads_mcp/middleware/caching.py +177 -0
- amazon_ads_mcp/middleware/oauth.py +175 -0
- amazon_ads_mcp/middleware/sampling.py +112 -0
- amazon_ads_mcp/models/__init__.py +320 -0
- amazon_ads_mcp/models/amc_models.py +837 -0
- amazon_ads_mcp/models/api_responses.py +847 -0
- amazon_ads_mcp/models/base_models.py +215 -0
- amazon_ads_mcp/models/builtin_responses.py +496 -0
- amazon_ads_mcp/models/dsp_models.py +556 -0
- amazon_ads_mcp/models/stores_brands.py +610 -0
- amazon_ads_mcp/server/__init__.py +6 -0
- amazon_ads_mcp/server/__main__.py +6 -0
- amazon_ads_mcp/server/builtin_prompts.py +269 -0
- amazon_ads_mcp/server/builtin_tools.py +962 -0
- amazon_ads_mcp/server/file_routes.py +547 -0
- amazon_ads_mcp/server/html_templates.py +149 -0
- amazon_ads_mcp/server/mcp_server.py +327 -0
- amazon_ads_mcp/server/openapi_utils.py +158 -0
- amazon_ads_mcp/server/sampling_handler.py +251 -0
- amazon_ads_mcp/server/server_builder.py +751 -0
- amazon_ads_mcp/server/sidecar_loader.py +178 -0
- amazon_ads_mcp/server/transform_executor.py +827 -0
- amazon_ads_mcp/tools/__init__.py +22 -0
- amazon_ads_mcp/tools/cache_management.py +105 -0
- amazon_ads_mcp/tools/download_tools.py +267 -0
- amazon_ads_mcp/tools/identity.py +236 -0
- amazon_ads_mcp/tools/oauth.py +598 -0
- amazon_ads_mcp/tools/profile.py +150 -0
- amazon_ads_mcp/tools/profile_listing.py +285 -0
- amazon_ads_mcp/tools/region.py +320 -0
- amazon_ads_mcp/tools/region_identity.py +175 -0
- amazon_ads_mcp/utils/__init__.py +6 -0
- amazon_ads_mcp/utils/async_compat.py +215 -0
- amazon_ads_mcp/utils/errors.py +452 -0
- amazon_ads_mcp/utils/export_content_type_resolver.py +249 -0
- amazon_ads_mcp/utils/export_download_handler.py +579 -0
- amazon_ads_mcp/utils/header_resolver.py +81 -0
- amazon_ads_mcp/utils/http/__init__.py +56 -0
- amazon_ads_mcp/utils/http/circuit_breaker.py +127 -0
- amazon_ads_mcp/utils/http/client_manager.py +329 -0
- amazon_ads_mcp/utils/http/request.py +207 -0
- amazon_ads_mcp/utils/http/resilience.py +512 -0
- amazon_ads_mcp/utils/http/resilient_client.py +195 -0
- amazon_ads_mcp/utils/http/retry.py +76 -0
- amazon_ads_mcp/utils/http_client.py +873 -0
- amazon_ads_mcp/utils/media/__init__.py +21 -0
- amazon_ads_mcp/utils/media/negotiator.py +243 -0
- amazon_ads_mcp/utils/media/types.py +199 -0
- amazon_ads_mcp/utils/openapi/__init__.py +16 -0
- amazon_ads_mcp/utils/openapi/json.py +55 -0
- amazon_ads_mcp/utils/openapi/loader.py +263 -0
- amazon_ads_mcp/utils/openapi/refs.py +46 -0
- amazon_ads_mcp/utils/region_config.py +200 -0
- amazon_ads_mcp/utils/response_wrapper.py +171 -0
- amazon_ads_mcp/utils/sampling_helpers.py +156 -0
- amazon_ads_mcp/utils/sampling_wrapper.py +173 -0
- amazon_ads_mcp/utils/security.py +630 -0
- amazon_ads_mcp/utils/tool_naming.py +137 -0
- amazon_ads_mcp-0.2.7.dist-info/METADATA +664 -0
- amazon_ads_mcp-0.2.7.dist-info/RECORD +82 -0
- amazon_ads_mcp-0.2.7.dist-info/WHEEL +4 -0
- amazon_ads_mcp-0.2.7.dist-info/entry_points.txt +3 -0
- amazon_ads_mcp-0.2.7.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""Response caching middleware for Amazon Ads MCP.
|
|
2
|
+
|
|
3
|
+
This module provides a security-aware caching configuration that prevents
|
|
4
|
+
cross-account data leakage in multi-tenant scenarios.
|
|
5
|
+
|
|
6
|
+
Security Considerations
|
|
7
|
+
-----------------------
|
|
8
|
+
Amazon Ads MCP operates in a multi-tenant context where:
|
|
9
|
+
- Different profiles have different data access
|
|
10
|
+
- Region affects API endpoints and responses
|
|
11
|
+
- Account ID determines data isolation
|
|
12
|
+
|
|
13
|
+
The default FastMCP cache key (method + arguments) is UNSAFE because:
|
|
14
|
+
- Profile context is implicit (via Amazon-Advertising-API-Scope header)
|
|
15
|
+
- Results vary by active profile, but cache key doesn't include it
|
|
16
|
+
- OpenBridge identities add another dimension of isolation
|
|
17
|
+
|
|
18
|
+
Safe Caching Strategy
|
|
19
|
+
---------------------
|
|
20
|
+
1. WHITELIST ONLY - explicit list of safe-to-cache tools
|
|
21
|
+
2. STATIC DATA ONLY - tools that return the same data regardless of context
|
|
22
|
+
3. NO API CALLS - only cache server-local metadata
|
|
23
|
+
|
|
24
|
+
Examples
|
|
25
|
+
--------
|
|
26
|
+
.. code-block:: python
|
|
27
|
+
|
|
28
|
+
from amazon_ads_mcp.middleware.caching import create_caching_middleware
|
|
29
|
+
|
|
30
|
+
middleware = create_caching_middleware()
|
|
31
|
+
server.add_middleware(middleware)
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
import logging
|
|
35
|
+
from typing import Optional, Set
|
|
36
|
+
|
|
37
|
+
from fastmcp.server.middleware.caching import (
|
|
38
|
+
CallToolSettings,
|
|
39
|
+
ListPromptsSettings,
|
|
40
|
+
ListResourcesSettings,
|
|
41
|
+
ListToolsSettings,
|
|
42
|
+
ResponseCachingMiddleware,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
logger = logging.getLogger(__name__)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
# Tools that are SAFE to cache (static data, no profile dependency).
# Every tool listed here is cached for the full static TTL, so only tools whose
# result cannot change between calls belong in this set.
SAFE_TO_CACHE_TOOLS: Set[str] = {
    # Region configuration - static server data
    "list_regions",
}

# Tools that MUST NOT be cached (profile-dependent, write operations, or dynamic)
# This is a documentation list - actual enforcement is via whitelist above
NEVER_CACHE_TOOLS: Set[str] = {
    # Profile/Identity management - state changes
    "set_active_profile",
    "get_active_profile",  # Depends on server state
    "clear_active_profile",
    "set_active_identity",
    "get_active_identity",  # Depends on server state
    "list_identities",  # Depends on auth provider
    # Region management - state changes
    "set_region",
    "get_region",  # Depends on server state
    "get_routing_state",  # Depends on current routing config
    # OAuth operations - security sensitive
    "start_oauth_flow",
    "check_oauth_status",
    "refresh_oauth_token",
    "clear_oauth_tokens",
    # Download operations - side effects
    "download_export",
    # Download listing - reads the local filesystem, which is not static: the
    # listing changes whenever files are added or removed (presumably by
    # download_export - confirm against the download tools), so a cached copy
    # would be stale for up to the static TTL.
    "list_downloads",
    # Sampling - dynamic operations
    "test_sampling",
    # ALL OpenAPI-generated tools - profile-dependent API responses
    # These are excluded by whitelist (not in SAFE_TO_CACHE_TOOLS)
}

# TTL values in seconds
STATIC_DATA_TTL = 3600  # 1 hour for truly static data
LIST_METADATA_TTL = 60  # 1 minute for tool/resource/prompt lists
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def create_caching_middleware(
    enabled: bool = True,
    static_ttl: int = STATIC_DATA_TTL,
    list_ttl: int = LIST_METADATA_TTL,
    additional_safe_tools: Optional[Set[str]] = None,
) -> ResponseCachingMiddleware:
    """Create a security-aware caching middleware.

    This middleware implements a conservative whitelist approach:

    - Only explicitly listed tools are cached
    - All OpenAPI-generated tools are excluded (profile-dependent)
    - Resource reads and prompt gets are explicitly disabled
    - Tool, resource and prompt *lists* are cached with ``list_ttl``

    :param enabled: Whether caching is enabled globally
    :param static_ttl: TTL in seconds for whitelisted tool calls
        (e.g., list_regions)
    :param list_ttl: TTL in seconds for list operations
        (tools, resources, prompts)
    :param additional_safe_tools: Additional tool names safe to cache; merged
        with ``SAFE_TO_CACHE_TOOLS`` without mutating that constant
    :return: Configured ResponseCachingMiddleware

    Example
    -------
    .. code-block:: python

        middleware = create_caching_middleware(
            static_ttl=1800,  # 30 minutes
            additional_safe_tools={"my_static_tool"}
        )
        server.add_middleware(middleware)
    """
    # Imported here so this function brings its own dependencies into scope;
    # these come from the same fastmcp module the file already imports from.
    from fastmcp.server.middleware.caching import (
        GetPromptSettings,
        ReadResourceSettings,
    )

    safe_tools = SAFE_TO_CACHE_TOOLS.copy()
    if additional_safe_tools:
        safe_tools.update(additional_safe_tools)

    logger.info(
        f"Creating caching middleware with {len(safe_tools)} safe tools: {safe_tools}"
    )

    return ResponseCachingMiddleware(
        # Tool call caching - WHITELIST ONLY
        call_tool_settings=CallToolSettings(
            # An empty whitelist must mean "cache no tools". Switch tool
            # caching off in that case instead of relying on how the
            # middleware interprets an empty include list.
            enabled=enabled and bool(safe_tools),
            ttl=static_ttl,
            # Sorted so the configuration is deterministic across runs
            included_tools=sorted(safe_tools),  # Only these tools are cached
        ),
        # List operations - safe to cache (server metadata)
        list_tools_settings=ListToolsSettings(
            enabled=enabled,
            ttl=list_ttl,
        ),
        list_resources_settings=ListResourcesSettings(
            enabled=enabled,
            ttl=list_ttl,
        ),
        list_prompts_settings=ListPromptsSettings(
            enabled=enabled,
            ttl=list_ttl,
        ),
        # Resource reads - DISABLED (may be profile-dependent).
        # Must be passed explicitly: leaving the argument out hands control to
        # the middleware's own defaults rather than turning caching off.
        read_resource_settings=ReadResourceSettings(enabled=False),
        # Prompt gets - DISABLED (may be profile-dependent), same reasoning
        get_prompt_settings=GetPromptSettings(enabled=False),
    )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# Future enhancement: Custom cache key middleware
|
|
153
|
+
# This would allow caching OpenAPI tools safely by including
|
|
154
|
+
# profile_id/region/account_id in the cache key
|
|
155
|
+
#
|
|
156
|
+
# class ContextAwareCacheKeyMiddleware(Middleware):
|
|
157
|
+
# """Middleware that injects routing context into cache keys.
|
|
158
|
+
#
|
|
159
|
+
# This middleware captures the current profile_id, region, and
|
|
160
|
+
# account_id and stores them in context state for use by a
|
|
161
|
+
# custom caching implementation.
|
|
162
|
+
# """
|
|
163
|
+
#
|
|
164
|
+
# async def on_call_tool(self, context: MiddlewareContext, call_next):
|
|
165
|
+
# # Get current routing context
|
|
166
|
+
# from ..utils.http_client import get_routing_state
|
|
167
|
+
# routing = get_routing_state()
|
|
168
|
+
#
|
|
169
|
+
# # Store in context for cache key generation
|
|
170
|
+
# if context.fastmcp_context:
|
|
171
|
+
# context.fastmcp_context.set_state("cache_context", {
|
|
172
|
+
# "region": routing.get("region"),
|
|
173
|
+
# "profile_id": routing.get("profile_id"),
|
|
174
|
+
# "account_id": routing.get("account_id"),
|
|
175
|
+
# })
|
|
176
|
+
#
|
|
177
|
+
# return await call_next(context)
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""OAuth middleware for automatic token injection."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime, timedelta, timezone
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
|
9
|
+
|
|
10
|
+
from ..tools.oauth import OAuthTokens
|
|
11
|
+
from ..utils.region_config import RegionConfig
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class OAuthTokenMiddleware(Middleware):
    """
    Middleware that automatically injects OAuth tokens into API calls.

    This middleware:
    1. Checks for stored OAuth tokens in the context state
    2. Refreshes expired access tokens automatically
    3. Injects tokens into the authentication flow

    Note: This middleware uses the AuthManager public API:
    - get_active_identity() to check current identity
    - set_active_identity() to switch to OAuth identity
    - Stores tokens in context state for providers to access

    All token handling is best-effort: any failure is logged and the tool
    call proceeds without OAuth tokens.
    """

    def __init__(self, client_id: str, client_secret: str, region: str = "na"):
        """Store the OAuth client credentials and the region.

        :param client_id: OAuth client id sent with refresh requests
        :param client_secret: OAuth client secret sent with refresh requests
        :param region: Region passed to ``RegionConfig.get_oauth_endpoint``
            to select the token endpoint
        """
        self.client_id = client_id
        self.client_secret = client_secret
        self.region = region

    async def refresh_token(self, refresh_token: str) -> Optional[dict]:
        """Refresh an expired access token.

        :param refresh_token: Refresh token to exchange
        :return: Parsed JSON body of the token response on HTTP 200,
            otherwise ``None`` (non-200 status or any error; both are logged)
        """
        token_url = RegionConfig.get_oauth_endpoint(self.region)
        token_data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self.client_id,
            "client_secret": self.client_secret,
        }

        try:
            # Use explicit timeout for OAuth token refresh
            timeout = httpx.Timeout(connect=10.0, read=30.0, write=10.0, pool=10.0)
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(token_url, data=token_data)

                if response.status_code == 200:
                    return response.json()

                logger.error(
                    f"Failed to refresh token: {response.status_code} - {response.text}"
                )
                return None
        except Exception as e:
            # Best-effort: callers treat None as "refresh failed"
            logger.error(f"Error refreshing token: {e}")
            return None

    async def _refresh_if_expired(self, fastmcp_context, tokens) -> None:
        """Refresh ``tokens`` in place when expired and a refresh token exists.

        On success the updated tokens are written back to the ``oauth_tokens``
        context state. On failure ``tokens`` is left untouched.
        """
        if not (tokens.is_expired and tokens.refresh_token):
            return

        logger.info("OAuth access token expired, refreshing...")

        token_response = await self.refresh_token(tokens.refresh_token)
        if not token_response:
            return

        # Update tokens
        tokens.access_token = token_response["access_token"]
        tokens.expires_in = token_response.get("expires_in", 3600)
        tokens.obtained_at = datetime.now(timezone.utc)

        # Keep the old refresh token unless the response carries a new one
        if "refresh_token" in token_response:
            tokens.refresh_token = token_response["refresh_token"]

        # Store updated tokens
        await fastmcp_context.set_state("oauth_tokens", tokens.model_dump())
        logger.info("OAuth access token refreshed successfully")

    async def _publish_tokens(self, fastmcp_context, tokens) -> None:
        """Hand ``tokens`` to the auth manager, or to context state as fallback."""
        if not hasattr(fastmcp_context, "auth_manager"):
            # Fallback: Store tokens in context for backward compatibility
            await fastmcp_context.set_state(
                "current_access_token", tokens.access_token
            )
            await fastmcp_context.set_state(
                "current_refresh_token", tokens.refresh_token
            )
            return

        auth_manager = fastmcp_context.auth_manager

        # Store tokens in unified token store
        if hasattr(auth_manager, "set_token"):
            from ..auth.token_store import TokenKind

            # Store access token
            expires_at = tokens.obtained_at + timedelta(seconds=tokens.expires_in)
            await auth_manager.set_token(
                provider_type="oauth",
                identity_id="oauth",
                token_kind=TokenKind.ACCESS,
                token=tokens.access_token,
                expires_at=expires_at,
                metadata={"token_type": "Bearer"},
            )

            # Store refresh token - only when one exists, so a missing token
            # is never persisted as a long-lived entry
            if tokens.refresh_token:
                await auth_manager.set_token(
                    provider_type="oauth",
                    identity_id="oauth",
                    token_kind=TokenKind.REFRESH,
                    token=tokens.refresh_token,
                    expires_at=datetime.now(timezone.utc)
                    + timedelta(days=365),  # Long-lived
                    metadata={},
                )
            logger.debug("Stored OAuth tokens in unified token store")

        # Check current active identity
        active_identity = auth_manager.get_active_identity()

        # If not using OAuth identity, try to switch
        if not active_identity or active_identity.id != "oauth":
            try:
                # Try to set OAuth as active identity
                # This assumes OAuth provider is configured or identity exists
                await auth_manager.set_active_identity("oauth")
                logger.info("Switched to OAuth authentication identity")
            except Exception as e:
                # OAuth identity doesn't exist or provider not configured for it
                logger.debug(f"Could not switch to OAuth identity: {e}")

    async def on_call_tool(self, context: MiddlewareContext, call_next):
        """
        Intercept tool calls to inject OAuth tokens if available.

        :param context: Middleware context of the tool call
        :param call_next: Next handler in the middleware chain
        :return: Whatever ``call_next(context)`` returns
        """
        # Skip OAuth tools themselves to avoid recursion
        if context.message and hasattr(context.message, "name"):
            tool_name = context.message.name
            if tool_name and "oauth" in tool_name.lower():
                return await call_next(context)

        # Check for OAuth tokens in state
        fastmcp_context = context.fastmcp_context
        if fastmcp_context:
            try:
                tokens_data = await fastmcp_context.get_state("oauth_tokens")

                if tokens_data:
                    tokens = OAuthTokens(**tokens_data)
                    await self._refresh_if_expired(fastmcp_context, tokens)
                    await self._publish_tokens(fastmcp_context, tokens)

            except Exception as e:
                # Deliberate best-effort: continue without OAuth tokens
                logger.debug(f"OAuth middleware check: {e}")

        # Continue with the tool call
        return await call_next(context)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def create_oauth_middleware():
    """Build an ``OAuthTokenMiddleware`` from the application settings.

    Reads ``oauth_client_id`` and ``oauth_client_secret`` from the settings
    object. The middleware's region is not set here, so it keeps the
    constructor default.

    :return: A configured ``OAuthTokenMiddleware``, or ``None`` (after logging
        a warning) when either credential is missing or empty
    """
    from ..config.settings import settings

    if settings.oauth_client_id and settings.oauth_client_secret:
        return OAuthTokenMiddleware(
            client_id=settings.oauth_client_id,
            client_secret=settings.oauth_client_secret,
        )

    logger.warning("OAuth client credentials not configured")
    return None
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""Middleware to attach server-side sampling handler to request context."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import Any, Callable
|
|
5
|
+
|
|
6
|
+
from fastmcp import Context
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def create_sampling_middleware(sampling_handler: Any = None) -> Callable:
    """
    Build middleware that attaches the sampling wrapper to each request context.

    This allows sample_with_fallback() to discover the handler when the client
    doesn't support sampling.

    Args:
        sampling_handler: The server's sampling handler instance. This
            implementation does not read it; the handler is obtained from
            get_sampling_wrapper() instead.

    Returns:
        Async middleware function ``(request, handler)`` that can be added to
        the server
    """

    def _attach_wrapper(candidate: Any, processed_message: str) -> None:
        """Attach the sampling wrapper to ``candidate`` if it is a Context."""
        if not isinstance(candidate, Context):
            return

        # Use the wrapper to provide sampling
        from ..utils.sampling_wrapper import get_sampling_wrapper

        sampling_wrapper = get_sampling_wrapper()
        if not sampling_wrapper.has_handler():
            return

        # Only the public setter is used; private attributes are never set
        if hasattr(candidate, "set_sampling_handler"):
            candidate.set_sampling_handler(sampling_wrapper)
        else:
            logger.debug("Skipping sampling attachment - no public API available")
        logger.debug(processed_message)

    async def sampling_middleware(request: Any, handler: Callable) -> Any:
        """
        Attach the sampling wrapper to the request's context, then continue.

        Args:
            request: The incoming request
            handler: The next handler in the chain

        Returns:
            Response from the handler chain
        """
        try:
            if hasattr(request, "context"):
                # Method 1: context attribute directly on the request
                _attach_wrapper(
                    request.context,
                    "Processed sampling handler for request context",
                )
            elif hasattr(request, "state") and hasattr(
                request.state, "fastmcp_context"
            ):
                # Method 2: FastMCP context stored in request state
                _attach_wrapper(
                    request.state.fastmcp_context,
                    "Processed sampling handler for FastMCP context",
                )
            else:
                # Method 3: the private _current_context API is deliberately unused
                logger.debug("Skipping contextvar method to avoid private API usage")
        except Exception as e:
            # Attachment is best-effort; the request must still be handled
            logger.debug(f"Could not attach sampling handler to context: {e}")

        # Continue with the request
        return await handler(request)

    return sampling_middleware
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def attach_sampling_to_context(ctx: Context) -> None:
    """
    Attach the sampling wrapper directly to a context.

    Intended for tool handlers or other places that already hold the context.
    Nothing happens when the wrapper has no handler. When the context has no
    public ``set_sampling_handler`` method, the attachment is skipped and a
    debug message is logged.

    Args:
        ctx: The FastMCP context
    """
    from ..utils.sampling_wrapper import get_sampling_wrapper

    sampling_wrapper = get_sampling_wrapper()
    if not sampling_wrapper.has_handler():
        return

    if not hasattr(ctx, "set_sampling_handler"):
        # Only the public setter is used; private attributes are never set
        logger.debug("Skipping sampling attachment - no public API available")
        return

    ctx.set_sampling_handler(sampling_wrapper)
    logger.debug("Sampling handler attached to context")
|