fastmcp 2.8.1__py3-none-any.whl → 2.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastmcp/cli/cli.py +99 -1
- fastmcp/cli/run.py +1 -3
- fastmcp/client/auth/oauth.py +1 -2
- fastmcp/client/client.py +21 -5
- fastmcp/client/transports.py +17 -2
- fastmcp/contrib/mcp_mixin/README.md +79 -2
- fastmcp/contrib/mcp_mixin/mcp_mixin.py +14 -0
- fastmcp/prompts/prompt.py +91 -11
- fastmcp/prompts/prompt_manager.py +119 -43
- fastmcp/resources/resource.py +11 -1
- fastmcp/resources/resource_manager.py +249 -76
- fastmcp/resources/template.py +27 -1
- fastmcp/server/auth/providers/bearer.py +32 -10
- fastmcp/server/context.py +41 -2
- fastmcp/server/http.py +8 -0
- fastmcp/server/middleware/__init__.py +6 -0
- fastmcp/server/middleware/error_handling.py +206 -0
- fastmcp/server/middleware/logging.py +165 -0
- fastmcp/server/middleware/middleware.py +236 -0
- fastmcp/server/middleware/rate_limiting.py +231 -0
- fastmcp/server/middleware/timing.py +156 -0
- fastmcp/server/proxy.py +250 -140
- fastmcp/server/server.py +320 -242
- fastmcp/settings.py +2 -2
- fastmcp/tools/tool.py +6 -2
- fastmcp/tools/tool_manager.py +114 -45
- fastmcp/utilities/components.py +22 -2
- fastmcp/utilities/inspect.py +326 -0
- fastmcp/utilities/json_schema.py +67 -23
- fastmcp/utilities/mcp_config.py +13 -7
- fastmcp/utilities/openapi.py +5 -3
- fastmcp/utilities/tests.py +1 -1
- fastmcp/utilities/types.py +90 -1
- {fastmcp-2.8.1.dist-info → fastmcp-2.9.0.dist-info}/METADATA +2 -2
- {fastmcp-2.8.1.dist-info → fastmcp-2.9.0.dist-info}/RECORD +38 -31
- {fastmcp-2.8.1.dist-info → fastmcp-2.9.0.dist-info}/WHEEL +0 -0
- {fastmcp-2.8.1.dist-info → fastmcp-2.9.0.dist-info}/entry_points.txt +0 -0
- {fastmcp-2.8.1.dist-info → fastmcp-2.9.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from collections.abc import Awaitable
|
|
5
|
+
from dataclasses import dataclass, field, replace
|
|
6
|
+
from datetime import datetime, timezone
|
|
7
|
+
from functools import partial
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Any,
|
|
11
|
+
Generic,
|
|
12
|
+
Literal,
|
|
13
|
+
Protocol,
|
|
14
|
+
TypeVar,
|
|
15
|
+
runtime_checkable,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
import mcp.types as mt
|
|
19
|
+
|
|
20
|
+
from fastmcp.prompts.prompt import Prompt
|
|
21
|
+
from fastmcp.resources.resource import Resource
|
|
22
|
+
from fastmcp.resources.template import ResourceTemplate
|
|
23
|
+
from fastmcp.tools.tool import Tool
|
|
24
|
+
|
|
25
|
+
if TYPE_CHECKING:
|
|
26
|
+
from fastmcp.server.context import Context
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
T = TypeVar("T")
|
|
32
|
+
R = TypeVar("R", covariant=True)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@runtime_checkable
|
|
36
|
+
class CallNext(Protocol[T, R]):
|
|
37
|
+
def __call__(self, context: MiddlewareContext[T]) -> Awaitable[R]: ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
ServerResultT = TypeVar(
|
|
41
|
+
"ServerResultT",
|
|
42
|
+
bound=mt.EmptyResult
|
|
43
|
+
| mt.InitializeResult
|
|
44
|
+
| mt.CompleteResult
|
|
45
|
+
| mt.GetPromptResult
|
|
46
|
+
| mt.ListPromptsResult
|
|
47
|
+
| mt.ListResourcesResult
|
|
48
|
+
| mt.ListResourceTemplatesResult
|
|
49
|
+
| mt.ReadResourceResult
|
|
50
|
+
| mt.CallToolResult
|
|
51
|
+
| mt.ListToolsResult,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass(kw_only=True)
|
|
56
|
+
class CallToolResult:
|
|
57
|
+
content: list[mt.Content]
|
|
58
|
+
isError: bool = False
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass(kw_only=True)
|
|
62
|
+
class ListToolsResult:
|
|
63
|
+
tools: dict[str, Tool]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(kw_only=True)
|
|
67
|
+
class ListResourcesResult:
|
|
68
|
+
resources: list[Resource]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(kw_only=True)
|
|
72
|
+
class ListResourceTemplatesResult:
|
|
73
|
+
resource_templates: list[ResourceTemplate]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(kw_only=True)
|
|
77
|
+
class ListPromptsResult:
|
|
78
|
+
prompts: list[Prompt]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@runtime_checkable
|
|
82
|
+
class ServerResultProtocol(Protocol[ServerResultT]):
|
|
83
|
+
root: ServerResultT
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass(kw_only=True, frozen=True)
|
|
87
|
+
class MiddlewareContext(Generic[T]):
|
|
88
|
+
"""
|
|
89
|
+
Unified context for all middleware operations.
|
|
90
|
+
"""
|
|
91
|
+
|
|
92
|
+
message: T
|
|
93
|
+
|
|
94
|
+
fastmcp_context: Context | None = None
|
|
95
|
+
|
|
96
|
+
# Common metadata
|
|
97
|
+
source: Literal["client", "server"] = "client"
|
|
98
|
+
type: Literal["request", "notification"] = "request"
|
|
99
|
+
method: str | None = None
|
|
100
|
+
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
|
101
|
+
|
|
102
|
+
def copy(self, **kwargs: Any) -> MiddlewareContext[T]:
|
|
103
|
+
return replace(self, **kwargs)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def make_middleware_wrapper(
|
|
107
|
+
middleware: Middleware, call_next: CallNext[T, R]
|
|
108
|
+
) -> CallNext[T, R]:
|
|
109
|
+
"""Create a wrapper that applies a single middleware to a context. The
|
|
110
|
+
closure bakes in the middleware and call_next function, so it can be
|
|
111
|
+
passed to other functions that expect a call_next function."""
|
|
112
|
+
|
|
113
|
+
async def wrapper(context: MiddlewareContext[T]) -> R:
|
|
114
|
+
return await middleware(context, call_next)
|
|
115
|
+
|
|
116
|
+
return wrapper
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class Middleware:
|
|
120
|
+
"""Base class for FastMCP middleware with dispatching hooks."""
|
|
121
|
+
|
|
122
|
+
async def __call__(
|
|
123
|
+
self,
|
|
124
|
+
context: MiddlewareContext[T],
|
|
125
|
+
call_next: CallNext[T, Any],
|
|
126
|
+
) -> Any:
|
|
127
|
+
"""Main entry point that orchestrates the pipeline."""
|
|
128
|
+
handler_chain = await self._dispatch_handler(
|
|
129
|
+
context,
|
|
130
|
+
call_next=call_next,
|
|
131
|
+
)
|
|
132
|
+
return await handler_chain(context)
|
|
133
|
+
|
|
134
|
+
async def _dispatch_handler(
|
|
135
|
+
self, context: MiddlewareContext[Any], call_next: CallNext[Any, Any]
|
|
136
|
+
) -> CallNext[Any, Any]:
|
|
137
|
+
"""Builds a chain of handlers for a given message."""
|
|
138
|
+
handler = call_next
|
|
139
|
+
|
|
140
|
+
match context.method:
|
|
141
|
+
case "tools/call":
|
|
142
|
+
handler = partial(self.on_call_tool, call_next=handler)
|
|
143
|
+
case "resources/read":
|
|
144
|
+
handler = partial(self.on_read_resource, call_next=handler)
|
|
145
|
+
case "prompts/get":
|
|
146
|
+
handler = partial(self.on_get_prompt, call_next=handler)
|
|
147
|
+
case "tools/list":
|
|
148
|
+
handler = partial(self.on_list_tools, call_next=handler)
|
|
149
|
+
case "resources/list":
|
|
150
|
+
handler = partial(self.on_list_resources, call_next=handler)
|
|
151
|
+
case "resources/templates/list":
|
|
152
|
+
handler = partial(self.on_list_resource_templates, call_next=handler)
|
|
153
|
+
case "prompts/list":
|
|
154
|
+
handler = partial(self.on_list_prompts, call_next=handler)
|
|
155
|
+
|
|
156
|
+
match context.type:
|
|
157
|
+
case "request":
|
|
158
|
+
handler = partial(self.on_request, call_next=handler)
|
|
159
|
+
case "notification":
|
|
160
|
+
handler = partial(self.on_notification, call_next=handler)
|
|
161
|
+
|
|
162
|
+
handler = partial(self.on_message, call_next=handler)
|
|
163
|
+
|
|
164
|
+
return handler
|
|
165
|
+
|
|
166
|
+
async def on_message(
|
|
167
|
+
self,
|
|
168
|
+
context: MiddlewareContext[Any],
|
|
169
|
+
call_next: CallNext[Any, Any],
|
|
170
|
+
) -> Any:
|
|
171
|
+
return await call_next(context)
|
|
172
|
+
|
|
173
|
+
async def on_request(
|
|
174
|
+
self,
|
|
175
|
+
context: MiddlewareContext[mt.Request],
|
|
176
|
+
call_next: CallNext[mt.Request, Any],
|
|
177
|
+
) -> Any:
|
|
178
|
+
return await call_next(context)
|
|
179
|
+
|
|
180
|
+
async def on_notification(
|
|
181
|
+
self,
|
|
182
|
+
context: MiddlewareContext[mt.Notification],
|
|
183
|
+
call_next: CallNext[mt.Notification, Any],
|
|
184
|
+
) -> Any:
|
|
185
|
+
return await call_next(context)
|
|
186
|
+
|
|
187
|
+
async def on_call_tool(
|
|
188
|
+
self,
|
|
189
|
+
context: MiddlewareContext[mt.CallToolRequestParams],
|
|
190
|
+
call_next: CallNext[mt.CallToolRequestParams, mt.CallToolResult],
|
|
191
|
+
) -> mt.CallToolResult:
|
|
192
|
+
return await call_next(context)
|
|
193
|
+
|
|
194
|
+
async def on_read_resource(
|
|
195
|
+
self,
|
|
196
|
+
context: MiddlewareContext[mt.ReadResourceRequestParams],
|
|
197
|
+
call_next: CallNext[mt.ReadResourceRequestParams, mt.ReadResourceResult],
|
|
198
|
+
) -> mt.ReadResourceResult:
|
|
199
|
+
return await call_next(context)
|
|
200
|
+
|
|
201
|
+
async def on_get_prompt(
|
|
202
|
+
self,
|
|
203
|
+
context: MiddlewareContext[mt.GetPromptRequestParams],
|
|
204
|
+
call_next: CallNext[mt.GetPromptRequestParams, mt.GetPromptResult],
|
|
205
|
+
) -> mt.GetPromptResult:
|
|
206
|
+
return await call_next(context)
|
|
207
|
+
|
|
208
|
+
async def on_list_tools(
|
|
209
|
+
self,
|
|
210
|
+
context: MiddlewareContext[mt.ListToolsRequest],
|
|
211
|
+
call_next: CallNext[mt.ListToolsRequest, ListToolsResult],
|
|
212
|
+
) -> ListToolsResult:
|
|
213
|
+
return await call_next(context)
|
|
214
|
+
|
|
215
|
+
async def on_list_resources(
|
|
216
|
+
self,
|
|
217
|
+
context: MiddlewareContext[mt.ListResourcesRequest],
|
|
218
|
+
call_next: CallNext[mt.ListResourcesRequest, ListResourcesResult],
|
|
219
|
+
) -> ListResourcesResult:
|
|
220
|
+
return await call_next(context)
|
|
221
|
+
|
|
222
|
+
async def on_list_resource_templates(
|
|
223
|
+
self,
|
|
224
|
+
context: MiddlewareContext[mt.ListResourceTemplatesRequest],
|
|
225
|
+
call_next: CallNext[
|
|
226
|
+
mt.ListResourceTemplatesRequest, ListResourceTemplatesResult
|
|
227
|
+
],
|
|
228
|
+
) -> ListResourceTemplatesResult:
|
|
229
|
+
return await call_next(context)
|
|
230
|
+
|
|
231
|
+
async def on_list_prompts(
|
|
232
|
+
self,
|
|
233
|
+
context: MiddlewareContext[mt.ListPromptsRequest],
|
|
234
|
+
call_next: CallNext[mt.ListPromptsRequest, ListPromptsResult],
|
|
235
|
+
) -> ListPromptsResult:
|
|
236
|
+
return await call_next(context)
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""Rate limiting middleware for protecting FastMCP servers from abuse."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
from collections import defaultdict, deque
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from mcp import McpError
|
|
10
|
+
from mcp.types import ErrorData
|
|
11
|
+
|
|
12
|
+
from .middleware import CallNext, Middleware, MiddlewareContext
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RateLimitError(McpError):
|
|
16
|
+
"""Error raised when rate limit is exceeded."""
|
|
17
|
+
|
|
18
|
+
def __init__(self, message: str = "Rate limit exceeded"):
|
|
19
|
+
super().__init__(ErrorData(code=-32000, message=message))
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class TokenBucketRateLimiter:
|
|
23
|
+
"""Token bucket implementation for rate limiting."""
|
|
24
|
+
|
|
25
|
+
def __init__(self, capacity: int, refill_rate: float):
|
|
26
|
+
"""Initialize token bucket.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
capacity: Maximum number of tokens in the bucket
|
|
30
|
+
refill_rate: Tokens added per second
|
|
31
|
+
"""
|
|
32
|
+
self.capacity = capacity
|
|
33
|
+
self.refill_rate = refill_rate
|
|
34
|
+
self.tokens = capacity
|
|
35
|
+
self.last_refill = time.time()
|
|
36
|
+
self._lock = asyncio.Lock()
|
|
37
|
+
|
|
38
|
+
async def consume(self, tokens: int = 1) -> bool:
|
|
39
|
+
"""Try to consume tokens from the bucket.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
tokens: Number of tokens to consume
|
|
43
|
+
|
|
44
|
+
Returns:
|
|
45
|
+
True if tokens were available and consumed, False otherwise
|
|
46
|
+
"""
|
|
47
|
+
async with self._lock:
|
|
48
|
+
now = time.time()
|
|
49
|
+
elapsed = now - self.last_refill
|
|
50
|
+
|
|
51
|
+
# Add tokens based on elapsed time
|
|
52
|
+
self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
|
|
53
|
+
self.last_refill = now
|
|
54
|
+
|
|
55
|
+
if self.tokens >= tokens:
|
|
56
|
+
self.tokens -= tokens
|
|
57
|
+
return True
|
|
58
|
+
return False
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class SlidingWindowRateLimiter:
|
|
62
|
+
"""Sliding window rate limiter implementation."""
|
|
63
|
+
|
|
64
|
+
def __init__(self, max_requests: int, window_seconds: int):
|
|
65
|
+
"""Initialize sliding window rate limiter.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
max_requests: Maximum requests allowed in the time window
|
|
69
|
+
window_seconds: Time window in seconds
|
|
70
|
+
"""
|
|
71
|
+
self.max_requests = max_requests
|
|
72
|
+
self.window_seconds = window_seconds
|
|
73
|
+
self.requests = deque()
|
|
74
|
+
self._lock = asyncio.Lock()
|
|
75
|
+
|
|
76
|
+
async def is_allowed(self) -> bool:
|
|
77
|
+
"""Check if a request is allowed."""
|
|
78
|
+
async with self._lock:
|
|
79
|
+
now = time.time()
|
|
80
|
+
cutoff = now - self.window_seconds
|
|
81
|
+
|
|
82
|
+
# Remove old requests outside the window
|
|
83
|
+
while self.requests and self.requests[0] < cutoff:
|
|
84
|
+
self.requests.popleft()
|
|
85
|
+
|
|
86
|
+
if len(self.requests) < self.max_requests:
|
|
87
|
+
self.requests.append(now)
|
|
88
|
+
return True
|
|
89
|
+
return False
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class RateLimitingMiddleware(Middleware):
|
|
93
|
+
"""Middleware that implements rate limiting to prevent server abuse.
|
|
94
|
+
|
|
95
|
+
Uses a token bucket algorithm by default, allowing for burst traffic
|
|
96
|
+
while maintaining a sustainable long-term rate.
|
|
97
|
+
|
|
98
|
+
Example:
|
|
99
|
+
```python
|
|
100
|
+
from fastmcp.server.middleware.rate_limiting import RateLimitingMiddleware
|
|
101
|
+
|
|
102
|
+
# Allow 10 requests per second with bursts up to 20
|
|
103
|
+
rate_limiter = RateLimitingMiddleware(
|
|
104
|
+
max_requests_per_second=10,
|
|
105
|
+
burst_capacity=20
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
mcp = FastMCP("MyServer")
|
|
109
|
+
mcp.add_middleware(rate_limiter)
|
|
110
|
+
```
|
|
111
|
+
"""
|
|
112
|
+
|
|
113
|
+
def __init__(
|
|
114
|
+
self,
|
|
115
|
+
max_requests_per_second: float = 10.0,
|
|
116
|
+
burst_capacity: int | None = None,
|
|
117
|
+
get_client_id: Callable[[MiddlewareContext], str] | None = None,
|
|
118
|
+
global_limit: bool = False,
|
|
119
|
+
):
|
|
120
|
+
"""Initialize rate limiting middleware.
|
|
121
|
+
|
|
122
|
+
Args:
|
|
123
|
+
max_requests_per_second: Sustained requests per second allowed
|
|
124
|
+
burst_capacity: Maximum burst capacity. If None, defaults to 2x max_requests_per_second
|
|
125
|
+
get_client_id: Function to extract client ID from context. If None, uses global limiting
|
|
126
|
+
global_limit: If True, apply limit globally; if False, per-client
|
|
127
|
+
"""
|
|
128
|
+
self.max_requests_per_second = max_requests_per_second
|
|
129
|
+
self.burst_capacity = burst_capacity or int(max_requests_per_second * 2)
|
|
130
|
+
self.get_client_id = get_client_id
|
|
131
|
+
self.global_limit = global_limit
|
|
132
|
+
|
|
133
|
+
# Storage for rate limiters per client
|
|
134
|
+
self.limiters: dict[str, TokenBucketRateLimiter] = defaultdict(
|
|
135
|
+
lambda: TokenBucketRateLimiter(
|
|
136
|
+
self.burst_capacity, self.max_requests_per_second
|
|
137
|
+
)
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# Global rate limiter
|
|
141
|
+
if self.global_limit:
|
|
142
|
+
self.global_limiter = TokenBucketRateLimiter(
|
|
143
|
+
self.burst_capacity, self.max_requests_per_second
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
def _get_client_identifier(self, context: MiddlewareContext) -> str:
|
|
147
|
+
"""Get client identifier for rate limiting."""
|
|
148
|
+
if self.get_client_id:
|
|
149
|
+
return self.get_client_id(context)
|
|
150
|
+
return "global"
|
|
151
|
+
|
|
152
|
+
async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
|
|
153
|
+
"""Apply rate limiting to requests."""
|
|
154
|
+
if self.global_limit:
|
|
155
|
+
# Global rate limiting
|
|
156
|
+
allowed = await self.global_limiter.consume()
|
|
157
|
+
if not allowed:
|
|
158
|
+
raise RateLimitError("Global rate limit exceeded")
|
|
159
|
+
else:
|
|
160
|
+
# Per-client rate limiting
|
|
161
|
+
client_id = self._get_client_identifier(context)
|
|
162
|
+
limiter = self.limiters[client_id]
|
|
163
|
+
allowed = await limiter.consume()
|
|
164
|
+
if not allowed:
|
|
165
|
+
raise RateLimitError(f"Rate limit exceeded for client: {client_id}")
|
|
166
|
+
|
|
167
|
+
return await call_next(context)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class SlidingWindowRateLimitingMiddleware(Middleware):
|
|
171
|
+
"""Middleware that implements sliding window rate limiting.
|
|
172
|
+
|
|
173
|
+
Uses a sliding window approach which provides more precise rate limiting
|
|
174
|
+
but uses more memory to track individual request timestamps.
|
|
175
|
+
|
|
176
|
+
Example:
|
|
177
|
+
```python
|
|
178
|
+
from fastmcp.server.middleware.rate_limiting import SlidingWindowRateLimitingMiddleware
|
|
179
|
+
|
|
180
|
+
# Allow 100 requests per minute
|
|
181
|
+
rate_limiter = SlidingWindowRateLimitingMiddleware(
|
|
182
|
+
max_requests=100,
|
|
183
|
+
window_minutes=1
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
mcp = FastMCP("MyServer")
|
|
187
|
+
mcp.add_middleware(rate_limiter)
|
|
188
|
+
```
|
|
189
|
+
"""
|
|
190
|
+
|
|
191
|
+
def __init__(
|
|
192
|
+
self,
|
|
193
|
+
max_requests: int,
|
|
194
|
+
window_minutes: int = 1,
|
|
195
|
+
get_client_id: Callable[[MiddlewareContext], str] | None = None,
|
|
196
|
+
):
|
|
197
|
+
"""Initialize sliding window rate limiting middleware.
|
|
198
|
+
|
|
199
|
+
Args:
|
|
200
|
+
max_requests: Maximum requests allowed in the time window
|
|
201
|
+
window_minutes: Time window in minutes
|
|
202
|
+
get_client_id: Function to extract client ID from context
|
|
203
|
+
"""
|
|
204
|
+
self.max_requests = max_requests
|
|
205
|
+
self.window_seconds = window_minutes * 60
|
|
206
|
+
self.get_client_id = get_client_id
|
|
207
|
+
|
|
208
|
+
# Storage for rate limiters per client
|
|
209
|
+
self.limiters: dict[str, SlidingWindowRateLimiter] = defaultdict(
|
|
210
|
+
lambda: SlidingWindowRateLimiter(self.max_requests, self.window_seconds)
|
|
211
|
+
)
|
|
212
|
+
|
|
213
|
+
def _get_client_identifier(self, context: MiddlewareContext) -> str:
|
|
214
|
+
"""Get client identifier for rate limiting."""
|
|
215
|
+
if self.get_client_id:
|
|
216
|
+
return self.get_client_id(context)
|
|
217
|
+
return "global"
|
|
218
|
+
|
|
219
|
+
async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
|
|
220
|
+
"""Apply sliding window rate limiting to requests."""
|
|
221
|
+
client_id = self._get_client_identifier(context)
|
|
222
|
+
limiter = self.limiters[client_id]
|
|
223
|
+
|
|
224
|
+
allowed = await limiter.is_allowed()
|
|
225
|
+
if not allowed:
|
|
226
|
+
raise RateLimitError(
|
|
227
|
+
f"Rate limit exceeded: {self.max_requests} requests per "
|
|
228
|
+
f"{self.window_seconds // 60} minutes for client: {client_id}"
|
|
229
|
+
)
|
|
230
|
+
|
|
231
|
+
return await call_next(context)
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Timing middleware for measuring and logging request performance."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from .middleware import CallNext, Middleware, MiddlewareContext
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TimingMiddleware(Middleware):
|
|
11
|
+
"""Middleware that logs the execution time of requests.
|
|
12
|
+
|
|
13
|
+
Only measures and logs timing for request messages (not notifications).
|
|
14
|
+
Provides insights into performance characteristics of your MCP server.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
```python
|
|
18
|
+
from fastmcp.server.middleware.timing import TimingMiddleware
|
|
19
|
+
|
|
20
|
+
mcp = FastMCP("MyServer")
|
|
21
|
+
mcp.add_middleware(TimingMiddleware())
|
|
22
|
+
|
|
23
|
+
# Now all requests will be timed and logged
|
|
24
|
+
```
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
def __init__(
|
|
28
|
+
self, logger: logging.Logger | None = None, log_level: int = logging.INFO
|
|
29
|
+
):
|
|
30
|
+
"""Initialize timing middleware.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
logger: Logger instance to use. If None, creates a logger named 'fastmcp.timing'
|
|
34
|
+
log_level: Log level for timing messages (default: INFO)
|
|
35
|
+
"""
|
|
36
|
+
self.logger = logger or logging.getLogger("fastmcp.timing")
|
|
37
|
+
self.log_level = log_level
|
|
38
|
+
|
|
39
|
+
async def on_request(self, context: MiddlewareContext, call_next: CallNext) -> Any:
|
|
40
|
+
"""Time request execution and log the results."""
|
|
41
|
+
method = context.method or "unknown"
|
|
42
|
+
|
|
43
|
+
start_time = time.perf_counter()
|
|
44
|
+
try:
|
|
45
|
+
result = await call_next(context)
|
|
46
|
+
duration_ms = (time.perf_counter() - start_time) * 1000
|
|
47
|
+
self.logger.log(
|
|
48
|
+
self.log_level, f"Request {method} completed in {duration_ms:.2f}ms"
|
|
49
|
+
)
|
|
50
|
+
return result
|
|
51
|
+
except Exception as e:
|
|
52
|
+
duration_ms = (time.perf_counter() - start_time) * 1000
|
|
53
|
+
self.logger.log(
|
|
54
|
+
self.log_level,
|
|
55
|
+
f"Request {method} failed after {duration_ms:.2f}ms: {e}",
|
|
56
|
+
)
|
|
57
|
+
raise
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class DetailedTimingMiddleware(Middleware):
|
|
61
|
+
"""Enhanced timing middleware with per-operation breakdowns.
|
|
62
|
+
|
|
63
|
+
Provides detailed timing information for different types of MCP operations,
|
|
64
|
+
allowing you to identify performance bottlenecks in specific operations.
|
|
65
|
+
|
|
66
|
+
Example:
|
|
67
|
+
```python
|
|
68
|
+
from fastmcp.server.middleware.timing import DetailedTimingMiddleware
|
|
69
|
+
import logging
|
|
70
|
+
|
|
71
|
+
# Configure logging to see the output
|
|
72
|
+
logging.basicConfig(level=logging.INFO)
|
|
73
|
+
|
|
74
|
+
mcp = FastMCP("MyServer")
|
|
75
|
+
mcp.add_middleware(DetailedTimingMiddleware())
|
|
76
|
+
```
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
def __init__(
|
|
80
|
+
self, logger: logging.Logger | None = None, log_level: int = logging.INFO
|
|
81
|
+
):
|
|
82
|
+
"""Initialize detailed timing middleware.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
logger: Logger instance to use. If None, creates a logger named 'fastmcp.timing.detailed'
|
|
86
|
+
log_level: Log level for timing messages (default: INFO)
|
|
87
|
+
"""
|
|
88
|
+
self.logger = logger or logging.getLogger("fastmcp.timing.detailed")
|
|
89
|
+
self.log_level = log_level
|
|
90
|
+
|
|
91
|
+
async def _time_operation(
|
|
92
|
+
self, context: MiddlewareContext, call_next: CallNext, operation_name: str
|
|
93
|
+
) -> Any:
|
|
94
|
+
"""Helper method to time any operation."""
|
|
95
|
+
start_time = time.perf_counter()
|
|
96
|
+
try:
|
|
97
|
+
result = await call_next(context)
|
|
98
|
+
duration_ms = (time.perf_counter() - start_time) * 1000
|
|
99
|
+
self.logger.log(
|
|
100
|
+
self.log_level, f"{operation_name} completed in {duration_ms:.2f}ms"
|
|
101
|
+
)
|
|
102
|
+
return result
|
|
103
|
+
except Exception as e:
|
|
104
|
+
duration_ms = (time.perf_counter() - start_time) * 1000
|
|
105
|
+
self.logger.log(
|
|
106
|
+
self.log_level,
|
|
107
|
+
f"{operation_name} failed after {duration_ms:.2f}ms: {e}",
|
|
108
|
+
)
|
|
109
|
+
raise
|
|
110
|
+
|
|
111
|
+
async def on_call_tool(
|
|
112
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
113
|
+
) -> Any:
|
|
114
|
+
"""Time tool execution."""
|
|
115
|
+
tool_name = getattr(context.message, "name", "unknown")
|
|
116
|
+
return await self._time_operation(context, call_next, f"Tool '{tool_name}'")
|
|
117
|
+
|
|
118
|
+
async def on_read_resource(
|
|
119
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
120
|
+
) -> Any:
|
|
121
|
+
"""Time resource reading."""
|
|
122
|
+
resource_uri = getattr(context.message, "uri", "unknown")
|
|
123
|
+
return await self._time_operation(
|
|
124
|
+
context, call_next, f"Resource '{resource_uri}'"
|
|
125
|
+
)
|
|
126
|
+
|
|
127
|
+
async def on_get_prompt(
|
|
128
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
129
|
+
) -> Any:
|
|
130
|
+
"""Time prompt retrieval."""
|
|
131
|
+
prompt_name = getattr(context.message, "name", "unknown")
|
|
132
|
+
return await self._time_operation(context, call_next, f"Prompt '{prompt_name}'")
|
|
133
|
+
|
|
134
|
+
async def on_list_tools(
|
|
135
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
136
|
+
) -> Any:
|
|
137
|
+
"""Time tool listing."""
|
|
138
|
+
return await self._time_operation(context, call_next, "List tools")
|
|
139
|
+
|
|
140
|
+
async def on_list_resources(
|
|
141
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
142
|
+
) -> Any:
|
|
143
|
+
"""Time resource listing."""
|
|
144
|
+
return await self._time_operation(context, call_next, "List resources")
|
|
145
|
+
|
|
146
|
+
async def on_list_resource_templates(
|
|
147
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
148
|
+
) -> Any:
|
|
149
|
+
"""Time resource template listing."""
|
|
150
|
+
return await self._time_operation(context, call_next, "List resource templates")
|
|
151
|
+
|
|
152
|
+
async def on_list_prompts(
|
|
153
|
+
self, context: MiddlewareContext, call_next: CallNext
|
|
154
|
+
) -> Any:
|
|
155
|
+
"""Time prompt listing."""
|
|
156
|
+
return await self._time_operation(context, call_next, "List prompts")
|