aixtools 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aixtools might be problematic. Click here for more details.
- aixtools/__init__.py +5 -0
- aixtools/a2a/__init__.py +5 -0
- aixtools/a2a/app.py +126 -0
- aixtools/a2a/utils.py +115 -0
- aixtools/agents/__init__.py +12 -0
- aixtools/agents/agent.py +164 -0
- aixtools/agents/agent_batch.py +74 -0
- aixtools/app.py +143 -0
- aixtools/context.py +12 -0
- aixtools/db/__init__.py +17 -0
- aixtools/db/database.py +110 -0
- aixtools/db/vector_db.py +115 -0
- aixtools/log_view/__init__.py +17 -0
- aixtools/log_view/app.py +195 -0
- aixtools/log_view/display.py +285 -0
- aixtools/log_view/export.py +51 -0
- aixtools/log_view/filters.py +41 -0
- aixtools/log_view/log_utils.py +26 -0
- aixtools/log_view/node_summary.py +229 -0
- aixtools/logfilters/__init__.py +7 -0
- aixtools/logfilters/context_filter.py +67 -0
- aixtools/logging/__init__.py +30 -0
- aixtools/logging/log_objects.py +227 -0
- aixtools/logging/logging_config.py +116 -0
- aixtools/logging/mcp_log_models.py +102 -0
- aixtools/logging/mcp_logger.py +172 -0
- aixtools/logging/model_patch_logging.py +87 -0
- aixtools/logging/open_telemetry.py +36 -0
- aixtools/mcp/__init__.py +9 -0
- aixtools/mcp/example_client.py +30 -0
- aixtools/mcp/example_server.py +22 -0
- aixtools/mcp/fast_mcp_log.py +31 -0
- aixtools/mcp/faulty_mcp.py +320 -0
- aixtools/model_patch/model_patch.py +65 -0
- aixtools/server/__init__.py +23 -0
- aixtools/server/app_mounter.py +90 -0
- aixtools/server/path.py +72 -0
- aixtools/server/utils.py +70 -0
- aixtools/testing/__init__.py +9 -0
- aixtools/testing/aix_test_model.py +147 -0
- aixtools/testing/mock_tool.py +66 -0
- aixtools/testing/model_patch_cache.py +279 -0
- aixtools/tools/doctor/__init__.py +3 -0
- aixtools/tools/doctor/tool_doctor.py +61 -0
- aixtools/tools/doctor/tool_recommendation.py +44 -0
- aixtools/utils/__init__.py +35 -0
- aixtools/utils/chainlit/cl_agent_show.py +82 -0
- aixtools/utils/chainlit/cl_utils.py +168 -0
- aixtools/utils/config.py +118 -0
- aixtools/utils/config_util.py +69 -0
- aixtools/utils/enum_with_description.py +37 -0
- aixtools/utils/persisted_dict.py +99 -0
- aixtools/utils/utils.py +160 -0
- aixtools-0.1.0.dist-info/METADATA +355 -0
- aixtools-0.1.0.dist-info/RECORD +58 -0
- aixtools-0.1.0.dist-info/WHEEL +5 -0
- aixtools-0.1.0.dist-info/entry_points.txt +2 -0
- aixtools-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,320 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Faulty MCP Server for Testing MCP Errors
|
|
4
|
+
- Simulates 404 errors for specific MCP requests
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import logging.config
|
|
11
|
+
import os
|
|
12
|
+
from dataclasses import dataclass
|
|
13
|
+
from random import choice, random
|
|
14
|
+
|
|
15
|
+
from fastapi import HTTPException, status
|
|
16
|
+
from fastmcp import FastMCP
|
|
17
|
+
from fastmcp.exceptions import (
|
|
18
|
+
ClientError,
|
|
19
|
+
DisabledError,
|
|
20
|
+
InvalidSignature,
|
|
21
|
+
NotFoundError,
|
|
22
|
+
PromptError,
|
|
23
|
+
ResourceError,
|
|
24
|
+
ToolError,
|
|
25
|
+
ValidationError,
|
|
26
|
+
)
|
|
27
|
+
from fastmcp.server.middleware import Middleware as McpMiddleware
|
|
28
|
+
from fastmcp.server.middleware import MiddlewareContext
|
|
29
|
+
from fastmcp.server.middleware.logging import LoggingMiddleware
|
|
30
|
+
from starlette.middleware import Middleware as StarletteMiddleware
|
|
31
|
+
from starlette.types import Receive, Scope, Send
|
|
32
|
+
|
|
33
|
+
from aixtools.logging.logging_config import DEFAULT_LOGGING_CONFIG
|
|
34
|
+
from aixtools.utils import get_logger
|
|
35
|
+
|
|
36
|
+
# Shorten console log lines by replacing the "color" formatter's format string
# with one that carries no user/session ID fields.
_color_formatter = DEFAULT_LOGGING_CONFIG["formatters"]["color"]
_color_formatter["format"] = "%(log_color)s%(asctime)s.%(msecs)03d %(levelname)-8s%(reset)s %(message)s"
logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)

# Module-level logger used by everything below.
logger = get_logger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class Config:
    """Global configuration for the faulty MCP server."""

    # Port the HTTP server is started on.
    port: int = 9999
    # Probability of throwing an exception in list tools handling.
    prob_throw_in_list_handle: float = 0.5
    # Probability of injecting a 404 error for DELETE requests.
    prob_delete_404: float = 0.5
    # Probability of injecting a 404 error for other requests.
    prob_general_404: float = 0.5
    # Probability of terminating the process on empty request.
    prob_terminate_on_empty_request: float = 0.3
    # Probability of terminating the process in list tools handling.
    prob_terminate_in_list_handle: float = 0.3


# Single shared instance; the middlewares read it and the CLI entry point mutates it.
config = Config()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class McpErrorMiddleware(McpMiddleware):
    """Custom middleware to simulate errors in MCP requests."""

    async def __call__(self, context: MiddlewareContext, call_next):
        """Log the MCP message and, for ``tools/list``, randomly crash or raise.

        Args:
            context: Middleware context; only ``context.method`` is read here.
            call_next: Awaitable continuation that processes the message.

        Returns:
            Whatever ``call_next(context)`` returns when no fault is injected.
        """
        # This hook is invoked for every message, whatever its type.
        logger.info("[McpErrorMiddleware] processing: %s", context.method)

        if context.method == "tools/list":
            roll = random()
            logger.info("[McpErrorMiddleware] random number: %f", roll)

            if roll < config.prob_terminate_in_list_handle:
                logger.warning("[McpErrorMiddleware] Simulating server crash!")
                # Signal 9 sent to our own PID: hard kill, no cleanup handlers run.
                os.kill(os.getpid(), 9)

            # NOTE(review): the same draw is compared against both thresholds, so the
            # two fault probabilities are not independent -- confirm this is intended.
            if roll < config.prob_throw_in_list_handle:
                candidates = [
                    ValidationError,
                    ResourceError,
                    ToolError,
                    PromptError,
                    InvalidSignature,
                    ClientError,
                    NotFoundError,
                    DisabledError,
                ]
                chosen = choice(candidates)
                logger.warning("[McpErrorMiddleware] throwing %s for: %s", chosen.__name__, context.method)
                raise chosen(f"[McpErrorMiddleware] {chosen.__name__}.")

        outcome = await call_next(context)
        logger.info("[McpErrorMiddleware] completed: %s", context.method)
        return outcome
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class StarletteErrorMiddleware:  # pylint: disable=too-few-public-methods
    """Custom Starlette middleware to log and inject errors.

    The middleware wraps the ASGI ``receive`` and ``send`` callables of each
    request so that it can:

    - log every incoming and outgoing ASGI message,
    - randomly kill the process when an empty request body arrives,
    - randomly replace the response with a simulated 404.

    All probabilities are read from the module-level ``config`` object.
    """

    def __init__(self, app):
        """Store the wrapped ASGI app and log the active fault probabilities."""
        self.app = app
        logger.info("[StarletteErrorMiddleware] Middleware initialized!")
        logger.info("Current configuration:")
        logger.info("Exception in list tools handling probability: %f", config.prob_throw_in_list_handle)
        logger.info("DELETE 404 probability: %f", config.prob_delete_404)
        logger.info("General 404 probability: %f", config.prob_general_404)
        logger.info("Terminate on empty request probability: %f", config.prob_terminate_on_empty_request)
        logger.info("Terminate in list handle probability: %f", config.prob_terminate_in_list_handle)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):  # noqa: PLR0915 # pylint: disable=too-many-statements
        """Run the wrapped app with logging/fault-injecting receive and send.

        Args:
            scope: ASGI connection scope.
            receive: ASGI receive callable of the server.
            send: ASGI send callable of the server.
        """
        # Log all the condition variables for debugging
        http_method = scope.get("method", "unknown")
        logger.info("[StarletteErrorMiddleware] >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        logger.info("[StarletteErrorMiddleware] scope['type']: %s", scope.get("type", "unknown"))
        logger.info("[StarletteErrorMiddleware] scope['path']: %s", scope.get("path", "unknown"))
        logger.info("[StarletteErrorMiddleware] HTTP method: %s", http_method)
        logger.info("[StarletteErrorMiddleware] Headers: %s", str(dict(scope.get("headers", []))))

        body_parts = []  # request body chunks collected across http.request messages
        should_inject_404 = False
        injected_body_sent = False  # True once the simulated 404 body went out

        if http_method == "DELETE":
            random_number = random()
            logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
            if random_number < config.prob_delete_404:
                logger.info("[StarletteErrorMiddleware] Simulating 404 error for DELETE request")
                should_inject_404 = True

        async def logging_receive():
            """Forward one inbound ASGI message, logging it and deciding on fault injection."""
            nonlocal should_inject_404
            message = await receive()
            logger.info("[StarletteErrorMiddleware] Received message: %s", str(message))

            if message["type"] != "http.request":
                return message

            # "body" and "more_body" are optional in ASGI, so never index them directly.
            body = message.get("body", b"")
            more_body = message.get("more_body", False)

            # Check for empty request and possibly terminate the process
            if body == b"" and not more_body:
                random_number = random()
                logger.info("[StarletteErrorMiddleware] Empty request received, random number: %f", random_number)
                if random_number < config.prob_terminate_on_empty_request:
                    logger.warning("[StarletteErrorMiddleware] Simulating server crash on empty request!")
                    os.kill(os.getpid(), 9)

            if body:
                body_parts.append(body)

            # Only inspect the payload once the complete body has arrived.
            if more_body:
                return message

            complete_body = b"".join(body_parts)
            if not complete_body:
                logger.info("[StarletteErrorMiddleware] Request body: (empty)")
                return message

            try:
                body_str = complete_body.decode("utf-8")
                logger.info("[StarletteErrorMiddleware] Request body: %s", body_str)
                json_data = json.loads(body_str)
            except (UnicodeDecodeError, json.JSONDecodeError) as e:
                logger.exception("[StarletteErrorMiddleware] Error parsing body: %s", e)
                return message

            # Non-object JSON payloads (e.g. batches) are logged but never trigger a 404.
            if not isinstance(json_data, dict):
                return message

            method_name = json_data.get("method", "unknown")
            # "params" may legally be absent or a non-object; treat both as "no capabilities".
            params = json_data.get("params")
            capabilities = params.get("capabilities") if isinstance(params, dict) else None
            if method_name == "initialize" and not capabilities:
                logger.info("Detected initial health check from Navari, skipping 404 injection.")
                return message
            logger.info("[StarletteErrorMiddleware] MCP method: %s", method_name)

            # Check if we should inject 404
            random_number = random()
            logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
            if random_number < config.prob_general_404:
                should_inject_404 = True
                logger.info("[StarletteErrorMiddleware] %s - will inject 404!", method_name)

            return message

        async def intercepting_send(message):
            """Forward one outbound ASGI message, swapping in the simulated 404 when armed."""
            nonlocal injected_body_sent
            logger.info("[StarletteErrorMiddleware] Sending message: %s", str(message))
            if should_inject_404:
                if message["type"] == "http.response.start":
                    logger.warning("[StarletteErrorMiddleware] Injecting 404!")
                    # Replace the response with 404
                    message = {
                        "type": "http.response.start",
                        "status": 404,
                        "headers": [[b"content-type", b"text/plain"]],
                    }
                elif message["type"] == "http.response.body":
                    if injected_body_sent:
                        # The replacement body already completed the response; a streamed
                        # original would otherwise emit further bodies after completion.
                        return
                    injected_body_sent = True
                    # Replace body with 404 message
                    message = {
                        "type": "http.response.body",
                        "body": b"Simulated 404 error",
                    }

            await send(message)

        logger.info("[StarletteErrorMiddleware] <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        await self.app(scope, logging_receive, intercepting_send)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
# Create the MCP server. Payload logging is listed ahead of the error-injecting
# middleware in the middleware list.
mcp = FastMCP(
    middleware=[LoggingMiddleware(include_payloads=True), McpErrorMiddleware()],
    instructions="A simple test server for reproducing MCP errors.",
    name="Faulty MCP Server",
)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
@mcp.tool
def add(a: float, b: float) -> float:
    """Add two numbers together."""
    # Well-behaved baseline tool: never fails on its own.
    total = a + b
    return total
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@mcp.tool
def always_error() -> None:
    """Always throw an exception to simulate errors."""
    # Deterministic failure: every call raises.
    failure = ValueError("Simulated error")
    raise failure
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
@mcp.tool
async def freeze_server(seconds: int = 60) -> str:
    """Simulate a server freeze for testing purposes."""
    # asyncio.sleep suspends only this coroutine for the requested duration.
    delay = asyncio.sleep(seconds)
    await delay
    return f"Server frozen for {seconds} seconds"
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@mcp.tool
def multiply(a: float, b: float) -> float:
    """Multiply two numbers together."""
    # Well-behaved baseline tool: never fails on its own.
    product = a * b
    return product
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@mcp.tool
def random_throw_exception(a: float, b: float, prob: float = 0.5) -> float:
    """Randomly throw an exception to simulate errors."""
    # Fail when the draw is below `prob`; otherwise behave like a multiplication.
    should_fail = random() < prob
    if should_fail:
        raise ValueError("Simulated error")
    product = a * b
    return product
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@mcp.tool
def throw_404_exception() -> None:
    """Always raise an HTTP 404 exception to simulate a not-found error."""
    # Unlike `random_throw_exception`, this tool fails unconditionally.
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Throwing a 404 error for testing purposes.")
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def run_server_on_port():
    """Start the MCP server over streamable HTTP on ``config.port`` and block until it exits."""

    async def _serve():
        print(f"[Port {config.port}] Starting MCP server on http://localhost:{config.port}/mcp/")
        # The HTTP-level fault injector is installed as Starlette middleware.
        http_middleware = [StarletteMiddleware(StarletteErrorMiddleware)]
        await mcp.run_http_async(
            transport="streamable-http",
            host="localhost",
            port=config.port,
            path="/mcp/",
            middleware=http_middleware,
        )

    asyncio.run(_serve())
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
if __name__ == "__main__":
    # Probability options, in the order they appear in --help. Keys are Config attribute
    # names; the CLI flag is the same name with underscores replaced by dashes.
    probability_options = {
        "prob_throw_in_list_handle": "Probability of exception in list tools handling",
        "prob_delete_404": "Probability of injecting a 404 error for DELETE requests",
        "prob_general_404": "Probability of injecting a 404 error for other requests",
        "prob_terminate_on_empty_request": "Probability of terminating on empty request",
        "prob_terminate_in_list_handle": "Probability of terminating in list tools handling",
    }

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Run a faulty MCP server for testing error handling")
    parser.add_argument(
        "--safe-mode",
        action="store_true",
        help="Set all error probabilities to 0 by default, only use explicitly provided values",
    )
    parser.add_argument(
        "--port",
        type=int,
        help=f"Port to run the server on (default: {config.port})",
    )
    for attr_name, description in probability_options.items():
        parser.add_argument(
            "--" + attr_name.replace("_", "-"),
            type=float,
            help=f"{description} (default: {getattr(config, attr_name)})",
        )

    args = parser.parse_args()

    # Update the global configuration with command line arguments
    if args.port:
        config.port = args.port
    for attr_name in probability_options:
        given_value = getattr(args, attr_name)
        if given_value is not None:
            # An explicit value always wins, even in safe mode.
            setattr(config, attr_name, given_value)
        elif args.safe_mode:
            setattr(config, attr_name, 0)

    # Run the server
    run_server_on_port()
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
from aixtools.logging.logging_config import get_logger
|
|
6
|
+
|
|
7
|
+
logger = get_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelRawRequest(BaseModel):
    # Record of one call made to a model method.

    # Model method name
    method_name: str
    # Unique request ID
    request_id: str
    # Method arguments
    args: tuple
    # Method keyword arguments
    kwargs: dict
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ModelRawRequestResult(BaseModel):
    # Record of the value returned by one call to a model method.

    # Model method name
    method_name: str
    # Unique request ID
    request_id: str
    # Method's result
    result: Any

    class Config:
        # `result` may hold objects pydantic has no validator for.
        arbitrary_types_allowed = True
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ModelRawRequestYieldItem(BaseModel):
    # Record of one item yielded by a streaming model method.

    # Model method name
    method_name: str
    # Unique request ID
    request_id: str
    # Item number in the stream
    item_num: int
    # Yielded item
    item: Any

    class Config:
        # `item` may hold objects pydantic has no validator for.
        arbitrary_types_allowed = True
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_request_fn(model):
    """Return the model's original ``request`` method.

    If the model carries a ``_request_ori`` backup, that backup is returned;
    otherwise the current ``model.request`` is returned. A model that only
    carries the streaming backup therefore no longer raises ``AttributeError``.

    Args:
        model: Object exposing ``request`` and, once patched, ``_request_ori``.

    Returns:
        The un-patched ``request`` callable.
    """
    try:
        return model._request_ori
    except AttributeError:
        # No backup stored, so the current attribute is still the original.
        return model.request
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_request_stream_fn(model):
    """Return the model's original ``request_stream`` method.

    If the model carries a ``_request_stream_ori`` backup, that backup is
    returned; otherwise the current ``model.request_stream`` is returned. A
    model that only carries the non-streaming backup therefore no longer
    raises ``AttributeError``.

    Args:
        model: Object exposing ``request_stream`` and, once patched,
            ``_request_stream_ori``.

    Returns:
        The un-patched ``request_stream`` callable.
    """
    try:
        return model._request_stream_ori
    except AttributeError:
        # No backup stored, so the current attribute is still the original.
        return model.request_stream
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def is_patched(model):
    """Return True when the model carries at least one backup of an original method."""
    backup_attrs = ("_request_ori", "_request_stream_ori")
    return any(hasattr(model, attr) for attr in backup_attrs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def model_patch(model, request_method, request_stream_method):
    """Swap ``model.request`` / ``model.request_stream`` for the given replacements.

    The originals are kept on the model as ``_request_ori`` and
    ``_request_stream_ori``. A model that is already patched is left untouched
    and a warning is logged.

    Returns:
        The same ``model`` object.
    """
    if not is_patched(model):
        # Back up both originals before overwriting either of them.
        model._request_ori = model.request
        model._request_stream_ori = model.request_stream
        model.request = request_method
        model.request_stream = request_stream_method
        return model

    logger.warning(f"Model {model.__class__.__name__} is already patched. Skipping patching.")
    return model
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FastMCP utilities for:
|
|
3
|
+
- extracting user metadata from context
|
|
4
|
+
- running mcp tools tasks in a separate thread.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .path import (
|
|
8
|
+
container_to_host_path,
|
|
9
|
+
get_workspace_path,
|
|
10
|
+
host_to_container_path,
|
|
11
|
+
)
|
|
12
|
+
from .utils import (
|
|
13
|
+
get_session_id_tuple,
|
|
14
|
+
run_in_thread,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"get_workspace_path",
|
|
19
|
+
"get_session_id_tuple",
|
|
20
|
+
"container_to_host_path",
|
|
21
|
+
"host_to_container_path",
|
|
22
|
+
"run_in_thread",
|
|
23
|
+
]
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Utility for mounting sub-applications with lifespan management in Starlette/FastAPI."""
|
|
2
|
+
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
|
|
5
|
+
from starlette.applications import Starlette
|
|
6
|
+
from starlette.types import ASGIApp
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AppMounter: # pylint: disable=too-few-public-methods
|
|
10
|
+
"""
|
|
11
|
+
A utility class for mounting sub-applications with their lifespans.
|
|
12
|
+
|
|
13
|
+
This class handles the complexity of ensuring that mounted sub-applications
|
|
14
|
+
have their lifespans properly managed alongside the parent application.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
def __init__(self, parent_app: Starlette):
|
|
18
|
+
"""
|
|
19
|
+
Initialize the SubAppMounter with a parent application.
|
|
20
|
+
|
|
21
|
+
Args:
|
|
22
|
+
parent_app: The parent Starlette/FastAPI application
|
|
23
|
+
"""
|
|
24
|
+
self.parent_app = parent_app
|
|
25
|
+
self.mounted_apps: list[tuple[str, ASGIApp]] = []
|
|
26
|
+
self._original_lifespan = parent_app.router.lifespan_context
|
|
27
|
+
self._setup_combined_lifespan()
|
|
28
|
+
|
|
29
|
+
def mount_with_lifespan(self, path: str, app: ASGIApp, name: str = None) -> None:
|
|
30
|
+
"""
|
|
31
|
+
Mount a sub-application and ensure its lifespan is managed.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
path: The path to mount the application at
|
|
35
|
+
app: The ASGI application to mount
|
|
36
|
+
name: Optional name for the mounted application
|
|
37
|
+
"""
|
|
38
|
+
# Mount the app using the parent app's mount method
|
|
39
|
+
self.parent_app.mount(path, app=app, name=name)
|
|
40
|
+
|
|
41
|
+
# Store the mounted app for lifespan management
|
|
42
|
+
self.mounted_apps.append((path, app))
|
|
43
|
+
|
|
44
|
+
def _setup_combined_lifespan(self) -> None:
|
|
45
|
+
"""
|
|
46
|
+
Set up a combined lifespan that manages both the parent app and all mounted sub-apps.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
@asynccontextmanager
|
|
50
|
+
async def combined_lifespan(app):
|
|
51
|
+
# First enter the parent app's lifespan
|
|
52
|
+
async with self._original_lifespan(app):
|
|
53
|
+
# Then enter each mounted app's lifespan in order
|
|
54
|
+
# We use nested context managers to ensure proper cleanup
|
|
55
|
+
async with self._create_nested_lifespans():
|
|
56
|
+
yield
|
|
57
|
+
|
|
58
|
+
# Replace the parent app's lifespan with our combined one
|
|
59
|
+
self.parent_app.router.lifespan_context = combined_lifespan
|
|
60
|
+
|
|
61
|
+
@asynccontextmanager
|
|
62
|
+
async def _create_nested_lifespans(self):
|
|
63
|
+
"""
|
|
64
|
+
Create nested async context managers for all mounted apps.
|
|
65
|
+
|
|
66
|
+
This ensures that all sub-app lifespans are entered and exited in the correct order.
|
|
67
|
+
"""
|
|
68
|
+
# If no apps are mounted, just yield
|
|
69
|
+
if not self.mounted_apps:
|
|
70
|
+
yield
|
|
71
|
+
return
|
|
72
|
+
|
|
73
|
+
# Otherwise, create nested context managers for each mounted app
|
|
74
|
+
async def enter_lifespans(index=0):
|
|
75
|
+
if index >= len(self.mounted_apps):
|
|
76
|
+
yield
|
|
77
|
+
return
|
|
78
|
+
|
|
79
|
+
_, app = self.mounted_apps[index]
|
|
80
|
+
if hasattr(app, "router") and hasattr(app.router, "lifespan_context"):
|
|
81
|
+
async with app.router.lifespan_context(app):
|
|
82
|
+
async for _ in enter_lifespans(index + 1):
|
|
83
|
+
yield
|
|
84
|
+
else:
|
|
85
|
+
# If the app doesn't have a lifespan, just move to the next one
|
|
86
|
+
async for _ in enter_lifespans(index + 1):
|
|
87
|
+
yield
|
|
88
|
+
|
|
89
|
+
async for _ in enter_lifespans():
|
|
90
|
+
yield
|
aixtools/server/path.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Workspace path handling for user sessions.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from pathlib import Path, PurePath, PurePosixPath
|
|
6
|
+
|
|
7
|
+
from fastmcp import Context
|
|
8
|
+
|
|
9
|
+
from ..utils.config import DATA_DIR
|
|
10
|
+
from .utils import get_session_id_tuple
|
|
11
|
+
|
|
12
|
+
WORKSPACES_ROOT_DIR = DATA_DIR / "workspaces" # Path on the host where workspaces are stored
|
|
13
|
+
CONTAINER_WORKSPACE_PATH = PurePosixPath("/workspace") # Path inside the sandbox container where workspace is mounted
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_workspace_path(
    service_name: str | None = None, *, in_sandbox: bool = False, ctx: Context | tuple | None = None
) -> PurePath:
    """
    Get the workspace path for a specific service (e.g. MCP server).
    If `service_name` is None, then the path to entire workspace folder (as mounted to a sandbox) is returned.
    If `in_sandbox` is True, it returns a path in sandbox, e.g.: `/workspace/mcp_repl`.
    If `in_sandbox` is False, it returns the path based on user and session IDs in the format:
    `<DATA_DIR>/workspaces/<user_id>/<session_id>/<service_name>`, where `DATA_DIR` should come from
    the environment variables, e.g.:
    `/data/workspaces/foo-user/bar-session/mcp_repl`.
    The `ctx` is used to get user and session IDs tuple. It can be passed directly or via HTTP headers from `Context`.
    If `ctx` is None, the current FastMCP request HTTP headers are used.

    Args:
        service_name: The name of the service (e.g. "mcp_server").
        in_sandbox: If True, use a sandbox path; otherwise, use user/session-based path.
        ctx: The FastMCP context, which contains the user session, or a `(user_id, session_id)` tuple.

    Returns: The workspace path as a PurePath object.

    Raises:
        ValueError: If the user or session ID is absolute or contains a `..` component,
            which would place the result outside the workspaces root.
    """
    if in_sandbox:
        path = CONTAINER_WORKSPACE_PATH
    else:
        user_id, session_id = ctx if isinstance(ctx, tuple) else get_session_id_tuple(ctx)
        # The IDs can originate from request headers, so refuse values that would
        # replace the root (absolute/anchored) or climb out of it ("..").
        for label, value in (("user_id", user_id), ("session_id", session_id)):
            candidate = PurePath(value)
            if candidate.anchor or ".." in candidate.parts:
                raise ValueError(f"Invalid {label} for workspace path: {value!r}")
        path = WORKSPACES_ROOT_DIR / user_id / session_id
    if service_name:
        path = path / service_name
    return path
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def container_to_host_path(path: PurePosixPath, *, ctx: Context | tuple = None) -> Path | None:
    """
    Convert a path in a sandbox container to a host path

    The path is normalized lexically first, so `.` and `..` segments are
    collapsed before it is checked against CONTAINER_WORKSPACE_PATH; a path
    that climbs out of the container workspace is rejected.

    Args:
        path: Path inside the container (must be inside CONTAINER_WORKSPACE_PATH).
        ctx: FastMCP context or `(user_id, session_id)` tuple, forwarded to `get_workspace_path`.

    Returns:
        Path to the file on the host. This function does not return None on failure; it raises.

    Raises:
        ValueError: If `path` is not located under CONTAINER_WORKSPACE_PATH.
    """
    import posixpath  # local import: container paths are always POSIX-style

    old_root = CONTAINER_WORKSPACE_PATH
    new_root = get_workspace_path(ctx=ctx)
    try:
        normalized = PurePosixPath(posixpath.normpath(path))
        relative = normalized.relative_to(old_root)
    except ValueError as e:
        raise ValueError(f"Container path must be a subdir of '{old_root}', got '{path}' instead") from e
    return new_root / relative
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def host_to_container_path(path: Path, *, ctx: Context | tuple = None) -> PurePosixPath:
    """
    Convert a host path to a path in a sandbox container.

    Args:
        path: Path on the host; must be inside the workspace returned by `get_workspace_path(ctx=ctx)`.
        ctx: FastMCP context or `(user_id, session_id)` tuple, forwarded to `get_workspace_path`.

    Returns:
        The corresponding path under CONTAINER_WORKSPACE_PATH.

    Raises:
        ValueError: If `path` is not located under the host workspace.
    """
    old_root = get_workspace_path(ctx=ctx)
    try:
        inside_workspace = Path(path).relative_to(old_root)
    except ValueError as exc:
        raise ValueError(f"Host path must be a subdir of '{old_root}', got '{path}' instead") from exc
    return CONTAINER_WORKSPACE_PATH / inside_workspace
|
aixtools/server/utils.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FastMCP server utilities for handling user context and threading.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from functools import wraps
|
|
7
|
+
|
|
8
|
+
from fastmcp import Context
|
|
9
|
+
from fastmcp.server import dependencies
|
|
10
|
+
|
|
11
|
+
from ..context import session_id_var, user_id_var
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
    """
    Resolve the user ID and session ID for the current request.

    Each ID is taken from the HTTP request headers when present; otherwise it
    falls back to the matching context variable, whose defaults here are
    "default_user" and "default_session".
    If `ctx` is None, the current FastMCP request HTTP headers are used.

    Returns: Tuple of (user_id, session_id).
    """
    user_id = get_user_id_from_request(ctx) or user_id_var.get("default_user")
    session_id = get_session_id_from_request(ctx) or session_id_var.get("default_session")
    return user_id, session_id
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_session_id_from_request(ctx: Context | None = None) -> str | None:
    """
    Get the session ID from the "session-id" HTTP request header.
    If `ctx` is None, the current FastMCP request HTTP headers are used.

    Returns:
        str | None: The header value, or None if it is missing or the request
        cannot be obtained.
    """
    try:
        return (ctx or dependencies).get_http_request().headers.get("session-id")
    except (ValueError, RuntimeError, AttributeError):
        # AttributeError is handled as well, matching get_user_id_from_request.
        return None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_user_id_from_request(ctx: Context | None = None) -> str | None:
    """
    Read the "user-id" HTTP request header and return it lowercased.
    If `ctx` is None, the current FastMCP request HTTP headers are used.

    Returns:
        str | None: The lowercase user ID, or None if not found or an error occurs.
    """
    try:
        source = ctx or dependencies
        raw_user_id = source.get_http_request().headers.get("user-id")
        if not raw_user_id:
            return None
        return raw_user_id.lower()
    except (ValueError, RuntimeError, AttributeError):
        return None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def get_session_id_str(ctx: Context | None = None) -> str:
    """
    Return the user ID and session ID joined as "<user_id>:<session_id>".
    If `ctx` is None, the current FastMCP request HTTP headers are used.
    """
    uid, sid = get_session_id_tuple(ctx)
    combined = f"{uid}:{sid}"
    return combined
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def run_in_thread(func):
    """Decorator turning a blocking callable into a coroutine function run via `asyncio.to_thread`."""

    async def threaded_call(*call_args, **call_kwargs):
        # Hand the blocking call to a worker thread and await its result.
        return await asyncio.to_thread(func, *call_args, **call_kwargs)

    # Carry over the name, docstring and other metadata of the wrapped callable.
    return wraps(func)(threaded_call)
|