aixtools-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of aixtools might be problematic.

Files changed (58)
  1. aixtools/__init__.py +5 -0
  2. aixtools/a2a/__init__.py +5 -0
  3. aixtools/a2a/app.py +126 -0
  4. aixtools/a2a/utils.py +115 -0
  5. aixtools/agents/__init__.py +12 -0
  6. aixtools/agents/agent.py +164 -0
  7. aixtools/agents/agent_batch.py +74 -0
  8. aixtools/app.py +143 -0
  9. aixtools/context.py +12 -0
  10. aixtools/db/__init__.py +17 -0
  11. aixtools/db/database.py +110 -0
  12. aixtools/db/vector_db.py +115 -0
  13. aixtools/log_view/__init__.py +17 -0
  14. aixtools/log_view/app.py +195 -0
  15. aixtools/log_view/display.py +285 -0
  16. aixtools/log_view/export.py +51 -0
  17. aixtools/log_view/filters.py +41 -0
  18. aixtools/log_view/log_utils.py +26 -0
  19. aixtools/log_view/node_summary.py +229 -0
  20. aixtools/logfilters/__init__.py +7 -0
  21. aixtools/logfilters/context_filter.py +67 -0
  22. aixtools/logging/__init__.py +30 -0
  23. aixtools/logging/log_objects.py +227 -0
  24. aixtools/logging/logging_config.py +116 -0
  25. aixtools/logging/mcp_log_models.py +102 -0
  26. aixtools/logging/mcp_logger.py +172 -0
  27. aixtools/logging/model_patch_logging.py +87 -0
  28. aixtools/logging/open_telemetry.py +36 -0
  29. aixtools/mcp/__init__.py +9 -0
  30. aixtools/mcp/example_client.py +30 -0
  31. aixtools/mcp/example_server.py +22 -0
  32. aixtools/mcp/fast_mcp_log.py +31 -0
  33. aixtools/mcp/faulty_mcp.py +320 -0
  34. aixtools/model_patch/model_patch.py +65 -0
  35. aixtools/server/__init__.py +23 -0
  36. aixtools/server/app_mounter.py +90 -0
  37. aixtools/server/path.py +72 -0
  38. aixtools/server/utils.py +70 -0
  39. aixtools/testing/__init__.py +9 -0
  40. aixtools/testing/aix_test_model.py +147 -0
  41. aixtools/testing/mock_tool.py +66 -0
  42. aixtools/testing/model_patch_cache.py +279 -0
  43. aixtools/tools/doctor/__init__.py +3 -0
  44. aixtools/tools/doctor/tool_doctor.py +61 -0
  45. aixtools/tools/doctor/tool_recommendation.py +44 -0
  46. aixtools/utils/__init__.py +35 -0
  47. aixtools/utils/chainlit/cl_agent_show.py +82 -0
  48. aixtools/utils/chainlit/cl_utils.py +168 -0
  49. aixtools/utils/config.py +118 -0
  50. aixtools/utils/config_util.py +69 -0
  51. aixtools/utils/enum_with_description.py +37 -0
  52. aixtools/utils/persisted_dict.py +99 -0
  53. aixtools/utils/utils.py +160 -0
  54. aixtools-0.1.0.dist-info/METADATA +355 -0
  55. aixtools-0.1.0.dist-info/RECORD +58 -0
  56. aixtools-0.1.0.dist-info/WHEEL +5 -0
  57. aixtools-0.1.0.dist-info/entry_points.txt +2 -0
  58. aixtools-0.1.0.dist-info/top_level.txt +1 -0

aixtools/mcp/faulty_mcp.py
@@ -0,0 +1,320 @@
+ #!/usr/bin/env python3
+ """
+ Faulty MCP Server for Testing MCP Errors
+ - Simulates 404 errors for specific MCP requests
+ """
+
+ import argparse
+ import asyncio
+ import json
+ import logging.config
+ import os
+ from dataclasses import dataclass
+ from random import choice, random
+
+ from fastapi import HTTPException, status
+ from fastmcp import FastMCP
+ from fastmcp.exceptions import (
+     ClientError,
+     DisabledError,
+     InvalidSignature,
+     NotFoundError,
+     PromptError,
+     ResourceError,
+     ToolError,
+     ValidationError,
+ )
+ from fastmcp.server.middleware import Middleware as McpMiddleware
+ from fastmcp.server.middleware import MiddlewareContext
+ from fastmcp.server.middleware.logging import LoggingMiddleware
+ from starlette.middleware import Middleware as StarletteMiddleware
+ from starlette.types import Receive, Scope, Send
+
+ from aixtools.logging.logging_config import DEFAULT_LOGGING_CONFIG
+ from aixtools.utils import get_logger
+
+ # Remove the user/session ID from logger line to shorten it
+ DEFAULT_LOGGING_CONFIG["formatters"]["color"]["format"] = (
+     "%(log_color)s%(asctime)s.%(msecs)03d %(levelname)-8s%(reset)s %(message)s"
+ )
+ logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
+
+ # Get the logger
+ logger = get_logger(__name__)
+
+
+ @dataclass
+ class Config:
+     """Global configuration for the faulty MCP server."""
+
+     port: int = 9999
+     prob_throw_in_list_handle: float = 0.5  # Probability of throwing an exception in list tools handling
+     prob_delete_404: float = 0.5  # Probability of injecting a 404 error for DELETE requests
+     prob_general_404: float = 0.5  # Probability of injecting a 404 error for other requests
+     prob_terminate_on_empty_request: float = 0.3  # Probability of terminating the process on empty request
+     prob_terminate_in_list_handle: float = 0.3  # Probability of terminating the process in list tools handling
+
+
+ # Global configuration
+ config = Config()
+
+
+ class McpErrorMiddleware(McpMiddleware):
+     """Custom middleware to simulate errors in MCP requests."""
+
+     async def __call__(self, context: MiddlewareContext, call_next):
+         # This method receives ALL messages regardless of type
+         logger.info("[McpErrorMiddleware] processing: %s", context.method)
+
+         if context.method == "tools/list":
+             random_number = random()
+             logger.info("[McpErrorMiddleware] random number: %f", random_number)
+             if random_number < config.prob_terminate_in_list_handle:
+                 logger.warning("[McpErrorMiddleware] Simulating server crash!")
+                 os.kill(os.getpid(), 9)
+
+             if random_number < config.prob_throw_in_list_handle:
+                 exception_class = choice(
+                     [
+                         ValidationError,
+                         ResourceError,
+                         ToolError,
+                         PromptError,
+                         InvalidSignature,
+                         ClientError,
+                         NotFoundError,
+                         DisabledError,
+                     ]
+                 )
+                 logger.warning("[McpErrorMiddleware] throwing %s for: %s", exception_class.__name__, context.method)
+                 raise exception_class(f"[McpErrorMiddleware] {exception_class.__name__}.")
+
+         result = await call_next(context)
+         logger.info("[McpErrorMiddleware] completed: %s", context.method)
+         return result
+
+
+ class StarletteErrorMiddleware:  # pylint: disable=too-few-public-methods
+     """Custom Starlette middleware to log and inject errors."""
+
+     def __init__(self, app):
+         """Initialize middleware."""
+         self.app = app
+         logger.info("[StarletteErrorMiddleware] Middleware initialized!")
+         logger.info("Current configuration:")
+         logger.info("Exception in list tools handling probability: %f", config.prob_throw_in_list_handle)
+         logger.info("DELETE 404 probability: %f", config.prob_delete_404)
+         logger.info("General 404 probability: %f", config.prob_general_404)
+         logger.info("Terminate on empty request probability: %f", config.prob_terminate_on_empty_request)
+         logger.info("Terminate in list handle probability: %f", config.prob_terminate_in_list_handle)
+
+     async def __call__(self, scope: Scope, receive: Receive, send: Send):  # noqa: PLR0915 # pylint: disable=too-many-statements
+         # Log all the condition variables for debugging
+
+         logger.info("[StarletteErrorMiddleware] >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
+         logger.info("[StarletteErrorMiddleware] scope['type']: %s", scope.get("type", "unknown"))
+         logger.info("[StarletteErrorMiddleware] scope['path']: %s", scope.get("path", "unknown"))
+         logger.info("[StarletteErrorMiddleware] HTTP method: %s", http_method := scope.get("method", "unknown"))
+         logger.info("[StarletteErrorMiddleware] Headers: %s", str(dict(scope.get("headers", []))))
+
+         # Wrap receive to log body content without breaking the flow
+         body_parts = []
+         should_inject_404 = False
+
+         if http_method == "DELETE":
+             random_number = random()
+             logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
+             if random_number < config.prob_delete_404:
+                 logger.info("[StarletteErrorMiddleware] Simulating 404 error for DELETE request")
+                 should_inject_404 = True
+
+         async def logging_receive():
+             nonlocal should_inject_404
+             message = await receive()
+             logger.info("[StarletteErrorMiddleware] Received message: %s", str(message))
+
+             # Check for empty request and possibly terminate the process
+             if message["type"] == "http.request" and message["body"] == b"" and not message.get("more_body", False):
+                 random_number = random()
+                 logger.info("[StarletteErrorMiddleware] Empty request received, random number: %f", random_number)
+                 if random_number < config.prob_terminate_on_empty_request:
+                     logger.warning("[StarletteErrorMiddleware] Simulating server crash on empty request!")
+                     os.kill(os.getpid(), 9)
+
+             if message["type"] == "http.request":  # pylint: disable=too-many-nested-blocks
+                 body = message.get("body", b"")
+                 if body:
+                     body_parts.append(body)
+
+                 # Log when we have the complete body
+                 if not message.get("more_body", False):
+                     complete_body = b"".join(body_parts)
+                     if complete_body:
+                         try:
+                             body_str = complete_body.decode("utf-8")
+                             logger.info("[StarletteErrorMiddleware] Request body: %s", body_str)
+
+                             json_data = json.loads(body_str)
+                             if isinstance(json_data, dict):
+                                 method_name = json_data.get("method", "unknown")
+                                 if method_name == "initialize" and not json_data.get("params", {}).get("capabilities"):
+                                     logger.info("Detected initial health check from Navari, skipping 404 injection.")
+                                     return message
+                                 logger.info("[StarletteErrorMiddleware] MCP method: %s", method_name)
+
+                                 # Check if we should inject 404
+                                 random_number = random()
+                                 logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
+                                 if random_number < config.prob_general_404:
+                                     should_inject_404 = True
+                                     logger.info("[StarletteErrorMiddleware] %s - will inject 404!", method_name)
+                         except (UnicodeDecodeError, json.JSONDecodeError) as e:
+                             logger.exception("[StarletteErrorMiddleware] Error parsing body: %s", e)
+                     else:
+                         logger.info("[StarletteErrorMiddleware] Request body: (empty)")
+
+             return message
+
+         async def intercepting_send(message):
+             logger.info("[StarletteErrorMiddleware] Sending message: %s", str(message))
+             if message["type"] == "http.response.start" and should_inject_404:
+                 logger.warning("[StarletteErrorMiddleware] Injecting 404!")
+                 # Replace the response with 404
+                 message = {
+                     "type": "http.response.start",
+                     "status": 404,
+                     "headers": [[b"content-type", b"text/plain"]],
+                 }
+             elif message["type"] == "http.response.body" and should_inject_404:
+                 # Replace body with 404 message
+                 message = {
+                     "type": "http.response.body",
+                     "body": b"Simulated 404 error",
+                 }
+
+             await send(message)
+
+         logger.info("[StarletteErrorMiddleware] <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
+         await self.app(scope, logging_receive, intercepting_send)
+
+
+ # Create the MCP server
+ mcp = FastMCP(
+     name="Faulty MCP Server",
+     instructions="A simple test server for reproducing MCP errors.",
+     middleware=[LoggingMiddleware(include_payloads=True), McpErrorMiddleware()],
+ )
+
+
+ @mcp.tool
+ def add(a: float, b: float) -> float:
+     """Add two numbers together."""
+     return a + b
+
+
+ @mcp.tool
+ def always_error() -> None:
+     """Always throw an exception to simulate errors."""
+     raise ValueError("Simulated error")
+
+
+ @mcp.tool
+ async def freeze_server(seconds: int = 60) -> str:
+     """Simulate a server freeze for testing purposes."""
+     await asyncio.sleep(seconds)
+     return f"Server frozen for {seconds} seconds"
+
+
+ @mcp.tool
+ def multiply(a: float, b: float) -> float:
+     """Multiply two numbers together."""
+     return a * b
+
+
+ @mcp.tool
+ def random_throw_exception(a: float, b: float, prob: float = 0.5) -> float:
+     """Randomly throw an exception to simulate errors."""
+     if random() < prob:
+         raise ValueError("Simulated error")
+     return a * b
+
+
+ @mcp.tool
+ def throw_404_exception() -> None:
+     """Always throw a 404 HTTPException to simulate errors."""
+     raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Throwing a 404 error for testing purposes.")
+
+
+ def run_server_on_port():
+     """Run a single MCP server using the global configuration."""
+
+     async def run_async():
+         print(f"[Port {config.port}] Starting MCP server on http://localhost:{config.port}/mcp/")
+         await mcp.run_http_async(
+             transport="streamable-http",
+             host="localhost",
+             port=config.port,
+             path="/mcp/",
+             middleware=[StarletteMiddleware(StarletteErrorMiddleware)],
+         )
+
+     asyncio.run(run_async())
+
+
+ if __name__ == "__main__":
+     # Parse command line arguments
+     parser = argparse.ArgumentParser(description="Run a faulty MCP server for testing error handling")
+     parser.add_argument(
+         "--safe-mode",
+         action="store_true",
+         help="Set all error probabilities to 0 by default, only use explicitly provided values",
+     )
+     parser.add_argument(
+         "--port",
+         type=int,
+         help=f"Port to run the server on (default: {config.port})",
+     )
+     parser.add_argument(
+         "--prob-throw-in-list-handle",
+         type=float,
+         help=f"Probability of exception in list tools handling (default: {config.prob_throw_in_list_handle})",
+     )
+     parser.add_argument(
+         "--prob-delete-404",
+         type=float,
+         help=f"Probability of injecting a 404 error for DELETE requests (default: {config.prob_delete_404})",
+     )
+     parser.add_argument(
+         "--prob-general-404",
+         type=float,
+         help=f"Probability of injecting a 404 error for other requests (default: {config.prob_general_404})",
+     )
+     parser.add_argument(
+         "--prob-terminate-on-empty-request",
+         type=float,
+         help=f"Probability of terminating on empty request (default: {config.prob_terminate_on_empty_request})",
+     )
+     parser.add_argument(
+         "--prob-terminate-in-list-handle",
+         type=float,
+         help=f"Probability of terminating in list tools handling (default: {config.prob_terminate_in_list_handle})",
+     )
+
+     args = parser.parse_args()
+
+     def _update_config_value(attr_name: str):
+         if args.safe_mode:
+             setattr(config, attr_name, 0)
+         if (given_value := getattr(args, attr_name)) is not None:
+             setattr(config, attr_name, given_value)
+
+     # Update the global configuration with command line arguments
+     config.port = args.port or config.port
+     _update_config_value("prob_throw_in_list_handle")
+     _update_config_value("prob_delete_404")
+     _update_config_value("prob_general_404")
+     _update_config_value("prob_terminate_on_empty_request")
+     _update_config_value("prob_terminate_in_list_handle")
+
+     # Run the server
+     run_server_on_port()
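
The server above is started from the command line with the flags defined in its argument parser, e.g. `python -m aixtools.mcp.faulty_mcp --safe-mode --prob-throw-in-list-handle 0.5`, and then exercised with any MCP client. The driver below is a minimal sketch and not part of the package: it assumes the fastmcp 2.x `Client` API, and the `probe` helper and URL are illustrative.

    import asyncio

    from fastmcp import Client


    async def probe(url: str = "http://localhost:9999/mcp/", attempts: int = 20) -> None:
        """Repeatedly call tools/list against the faulty server and count injected failures."""
        failures = 0
        for i in range(attempts):
            try:
                # Open a fresh connection per attempt, since the server may 404 or crash at any point
                async with Client(url) as client:
                    tools = await client.list_tools()
                    print(f"attempt {i}: listed {len(tools)} tools")
            except Exception as exc:  # noqa: BLE001 - injected errors are the point of this server
                failures += 1
                print(f"attempt {i}: {type(exc).__name__}: {exc}")
        print(f"{failures}/{attempts} attempts failed")


    if __name__ == "__main__":
        asyncio.run(probe())
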
aixtools/model_patch/model_patch.py
@@ -0,0 +1,65 @@
+ from typing import Any
+
+ from pydantic import BaseModel
+
+ from aixtools.logging.logging_config import get_logger
+
+ logger = get_logger(__name__)
+
+
+ class ModelRawRequest(BaseModel):
+     method_name: str  # Model method name
+     request_id: str  # Unique request ID
+     args: tuple  # Method arguments
+     kwargs: dict  # Method keyword arguments
+
+
+ class ModelRawRequestResult(BaseModel):
+     method_name: str  # Model method name
+     request_id: str  # Unique request ID
+     result: Any  # Method's result
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ class ModelRawRequestYieldItem(BaseModel):
+     method_name: str  # Model method name
+     request_id: str  # Unique request ID
+     item_num: int  # Item number in the stream
+     item: Any  # Yielded item
+
+     class Config:
+         arbitrary_types_allowed = True
+
+
+ def get_request_fn(model):
+     """Get the original request method"""
+     if is_patched(model):
+         return model._request_ori
+     return model.request
+
+
+ def get_request_stream_fn(model):
+     """Get the original request_stream method"""
+     if is_patched(model):
+         return model._request_stream_ori
+     return model.request_stream
+
+
+ def is_patched(model):
+     return hasattr(model, "_request_ori") or hasattr(model, "_request_stream_ori")
+
+
+ def model_patch(model, request_method, request_stream_method):
+     """Replace model.request and model.request_stream with logging versions"""
+     if is_patched(model):
+         logger.warning(f"Model {model.__class__.__name__} is already patched. Skipping patching.")
+         return model
+     # Save original methods
+     model._request_ori = model.request
+     model._request_stream_ori = model.request_stream
+     # Patch methods
+     model.request = request_method
+     model.request_stream = request_stream_method
+     return model
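
`model_patch` swaps `model.request` and `model.request_stream` for wrapped versions while keeping the originals reachable through `get_request_fn` / `get_request_stream_fn`. Below is a minimal sketch of that flow; `DummyModel` and the logging wrappers are hypothetical stand-ins (the package's real wrappers live in aixtools.logging.model_patch_logging, not shown in full here), not the library's own usage.

    from aixtools.model_patch.model_patch import get_request_fn, is_patched, model_patch


    class DummyModel:
        """Hypothetical model with the request/request_stream interface the patcher expects."""

        def request(self, prompt: str) -> str:
            return f"response to {prompt!r}"

        def request_stream(self, prompt: str):
            yield from f"response to {prompt!r}".split()


    model = DummyModel()
    original_request = get_request_fn(model)  # bound model.request (not yet patched)
    original_stream = model.request_stream    # bound model.request_stream


    def logged_request(prompt: str) -> str:
        print(f"request: {prompt}")
        return original_request(prompt)


    def logged_request_stream(prompt: str):
        print(f"request_stream: {prompt}")
        yield from original_stream(prompt)


    model_patch(model, logged_request, logged_request_stream)
    assert is_patched(model)          # originals saved as _request_ori / _request_stream_ori
    print(model.request("hello"))     # now goes through logged_request
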
aixtools/server/__init__.py
@@ -0,0 +1,23 @@
+ """
+ FastMCP utilities for:
+ - extracting user metadata from context
+ - running mcp tools tasks in a separate thread.
+ """
+
+ from .path import (
+     container_to_host_path,
+     get_workspace_path,
+     host_to_container_path,
+ )
+ from .utils import (
+     get_session_id_tuple,
+     run_in_thread,
+ )
+
+ __all__ = [
+     "get_workspace_path",
+     "get_session_id_tuple",
+     "container_to_host_path",
+     "host_to_container_path",
+     "run_in_thread",
+ ]

aixtools/server/app_mounter.py
@@ -0,0 +1,90 @@
+ """Utility for mounting sub-applications with lifespan management in Starlette/FastAPI."""
+
+ from contextlib import asynccontextmanager
+
+ from starlette.applications import Starlette
+ from starlette.types import ASGIApp
+
+
+ class AppMounter:  # pylint: disable=too-few-public-methods
+     """
+     A utility class for mounting sub-applications with their lifespans.
+
+     This class handles the complexity of ensuring that mounted sub-applications
+     have their lifespans properly managed alongside the parent application.
+     """
+
+     def __init__(self, parent_app: Starlette):
+         """
+         Initialize the AppMounter with a parent application.
+
+         Args:
+             parent_app: The parent Starlette/FastAPI application
+         """
+         self.parent_app = parent_app
+         self.mounted_apps: list[tuple[str, ASGIApp]] = []
+         self._original_lifespan = parent_app.router.lifespan_context
+         self._setup_combined_lifespan()
+
+     def mount_with_lifespan(self, path: str, app: ASGIApp, name: str = None) -> None:
+         """
+         Mount a sub-application and ensure its lifespan is managed.
+
+         Args:
+             path: The path to mount the application at
+             app: The ASGI application to mount
+             name: Optional name for the mounted application
+         """
+         # Mount the app using the parent app's mount method
+         self.parent_app.mount(path, app=app, name=name)
+
+         # Store the mounted app for lifespan management
+         self.mounted_apps.append((path, app))
+
+     def _setup_combined_lifespan(self) -> None:
+         """
+         Set up a combined lifespan that manages both the parent app and all mounted sub-apps.
+         """
+
+         @asynccontextmanager
+         async def combined_lifespan(app):
+             # First enter the parent app's lifespan
+             async with self._original_lifespan(app):
+                 # Then enter each mounted app's lifespan in order
+                 # We use nested context managers to ensure proper cleanup
+                 async with self._create_nested_lifespans():
+                     yield
+
+         # Replace the parent app's lifespan with our combined one
+         self.parent_app.router.lifespan_context = combined_lifespan
+
+     @asynccontextmanager
+     async def _create_nested_lifespans(self):
+         """
+         Create nested async context managers for all mounted apps.
+
+         This ensures that all sub-app lifespans are entered and exited in the correct order.
+         """
+         # If no apps are mounted, just yield
+         if not self.mounted_apps:
+             yield
+             return
+
+         # Otherwise, create nested context managers for each mounted app
+         async def enter_lifespans(index=0):
+             if index >= len(self.mounted_apps):
+                 yield
+                 return
+
+             _, app = self.mounted_apps[index]
+             if hasattr(app, "router") and hasattr(app.router, "lifespan_context"):
+                 async with app.router.lifespan_context(app):
+                     async for _ in enter_lifespans(index + 1):
+                         yield
+             else:
+                 # If the app doesn't have a lifespan, just move to the next one
+                 async for _ in enter_lifespans(index + 1):
+                     yield
+
+         async for _ in enter_lifespans():
+             yield
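
Starlette does not run a mounted sub-application's lifespan by default, which is the gap `AppMounter` closes by chaining the parent and sub-app lifespans. The sketch below is illustrative rather than taken from the package: the `sub_lifespan` hook and the route are made up, and the parent would normally be served with uvicorn.

    import contextlib

    from starlette.applications import Starlette
    from starlette.responses import PlainTextResponse
    from starlette.routing import Route

    from aixtools.server.app_mounter import AppMounter


    @contextlib.asynccontextmanager
    async def sub_lifespan(app):
        print("sub-app startup")  # now runs together with the parent's startup
        yield
        print("sub-app shutdown")


    async def hello(request):
        return PlainTextResponse("hello from the sub-app")


    sub_app = Starlette(routes=[Route("/hello", hello)], lifespan=sub_lifespan)
    parent_app = Starlette()

    mounter = AppMounter(parent_app)
    mounter.mount_with_lifespan("/sub", sub_app, name="sub")
    # Serving parent_app (e.g. `uvicorn module:parent_app`) enters both lifespans,
    # and GET /sub/hello is handled by the mounted app.
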
aixtools/server/path.py
@@ -0,0 +1,72 @@
+ """
+ Workspace path handling for user sessions.
+ """
+
+ from pathlib import Path, PurePath, PurePosixPath
+
+ from fastmcp import Context
+
+ from ..utils.config import DATA_DIR
+ from .utils import get_session_id_tuple
+
+ WORKSPACES_ROOT_DIR = DATA_DIR / "workspaces"  # Path on the host where workspaces are stored
+ CONTAINER_WORKSPACE_PATH = PurePosixPath("/workspace")  # Path inside the sandbox container where workspace is mounted
+
+
+ def get_workspace_path(service_name: str = None, *, in_sandbox: bool = False, ctx: Context | tuple = None) -> PurePath:
+     """
+     Get the workspace path for a specific service (e.g. an MCP server).
+     If `service_name` is None, the path to the entire workspace folder (as mounted to a sandbox) is returned.
+     If `in_sandbox` is True, a path inside the sandbox is returned, e.g. `/workspace/mcp_repl`.
+     If `in_sandbox` is False, the path is based on user and session IDs in the format
+     `<DATA_DIR>/workspaces/<user_id>/<session_id>/<service_name>`, where `DATA_DIR` should come from
+     the environment variables, e.g. `/data/workspaces/foo-user/bar-session/mcp_repl`.
+     The `ctx` is used to get the (user_id, session_id) tuple. It can be passed directly as a tuple or as a FastMCP `Context`.
+     If `ctx` is None, the current FastMCP request HTTP headers are used.
+
+     Args:
+         ctx: The FastMCP context (or a (user_id, session_id) tuple), which identifies the user session.
+         service_name: The name of the service (e.g. "mcp_server").
+         in_sandbox: If True, use a sandbox path; otherwise, use a user/session-based path.
+
+     Returns: The workspace path as a PurePath object.
+     """
+     if in_sandbox:
+         path = CONTAINER_WORKSPACE_PATH
+     else:
+         user_id, session_id = ctx if isinstance(ctx, tuple) else get_session_id_tuple(ctx)
+         path = WORKSPACES_ROOT_DIR / user_id / session_id
+     if service_name:
+         path = path / service_name
+     return path
+
+
+ def container_to_host_path(path: PurePosixPath, *, ctx: Context | tuple = None) -> Path | None:
+     """
+     Convert a path in a sandbox container to a host path.
+
+     Args:
+         path: Path inside the container (must be a subdir of CONTAINER_WORKSPACE_PATH).
+         ctx: The FastMCP context or a (user_id, session_id) tuple used to resolve the host workspace.
+
+     Returns:
+         Path to the file on the host.
+
+     Raises:
+         ValueError: If the path is not a subdir of CONTAINER_WORKSPACE_PATH.
+     """
+     old_root = CONTAINER_WORKSPACE_PATH
+     new_root = get_workspace_path(ctx=ctx)
+     try:
+         return new_root / PurePosixPath(path).relative_to(old_root)
+     except ValueError as e:
+         raise ValueError(f"Container path must be a subdir of '{old_root}', got '{path}' instead") from e
+
+
+ def host_to_container_path(path: Path, *, ctx: Context | tuple = None) -> PurePosixPath:
+     """Convert a host path to a path in a sandbox container."""
+     old_root = get_workspace_path(ctx=ctx)
+     new_root = CONTAINER_WORKSPACE_PATH
+     try:
+         return new_root / Path(path).relative_to(old_root)
+     except ValueError as exc:
+         raise ValueError(f"Host path must be a subdir of '{old_root}', got '{path}' instead") from exc
aixtools/server/utils.py
@@ -0,0 +1,70 @@
+ """
+ FastMCP server utilities for handling user context and threading.
+ """
+
+ import asyncio
+ from functools import wraps
+
+ from fastmcp import Context
+ from fastmcp.server import dependencies
+
+ from ..context import session_id_var, user_id_var
+
+
+ def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
+     """
+     Get the user and session IDs from the user session.
+     If `ctx` is None, the current FastMCP request HTTP headers are used.
+     Returns: Tuple of (user_id, session_id).
+     """
+     user_id = get_user_id_from_request(ctx)
+     user_id = user_id or user_id_var.get("default_user")
+     session_id = get_session_id_from_request(ctx)
+     session_id = session_id or session_id_var.get("default_session")
+     return user_id, session_id
+
+
+ def get_session_id_from_request(ctx: Context | None = None) -> str | None:
+     """
+     Get the session ID from the HTTP request headers.
+     If `ctx` is None, the current FastMCP request HTTP headers are used.
+     """
+     try:
+         return (ctx or dependencies).get_http_request().headers.get("session-id")
+     except (ValueError, RuntimeError):
+         return None
+
+
+ def get_user_id_from_request(ctx: Context | None = None) -> str | None:
+     """
+     Get the user ID from the HTTP request headers.
+     If `ctx` is None, the current FastMCP request HTTP headers are used.
+     The user_id is always returned as lowercase.
+
+     Returns:
+         str | None: The lowercase user ID, or None if not found or an error occurs.
+     """
+     try:
+         user_id = (ctx or dependencies).get_http_request().headers.get("user-id")
+         return user_id.lower() if user_id else None
+     except (ValueError, RuntimeError, AttributeError):
+         return None
+
+
+ def get_session_id_str(ctx: Context | None = None) -> str:
+     """
+     Combined session ID for the current user and session.
+     If `ctx` is None, the current FastMCP request HTTP headers are used.
+     """
+     user_id, session_id = get_session_id_tuple(ctx)
+     return f"{user_id}:{session_id}"
+
+
+ def run_in_thread(func):
+     """Decorator to run a blocking function with `asyncio.to_thread`."""
+
+     @wraps(func)
+     async def wrapper(*args, **kwargs):
+         return await asyncio.to_thread(func, *args, **kwargs)
+
+     return wrapper
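
`run_in_thread` turns a blocking function into an awaitable by delegating to `asyncio.to_thread`, so it can be called from async tool handlers without stalling the event loop. A small sketch with a hypothetical blocking helper (`slow_checksum` is not part of the package):

    import asyncio
    import hashlib
    from pathlib import Path

    from aixtools.server import run_in_thread


    @run_in_thread
    def slow_checksum(path: str) -> str:
        # Blocking file I/O and hashing now run in a worker thread
        return hashlib.sha256(Path(path).read_bytes()).hexdigest()


    async def main() -> None:
        digest = await slow_checksum(__file__)  # awaitable; the event loop stays responsive
        print(digest)


    asyncio.run(main())
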
aixtools/testing/__init__.py
@@ -0,0 +1,9 @@
+ """
+ Testing utilities for AI agents and model patching.
+ """
+
+ from aixtools.testing.model_patch_cache import model_patch_cache
+
+ __all__ = [
+     "model_patch_cache",
+ ]