aixtools 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aixtools might be problematic. Click here for more details.

Files changed (88):
  1. aixtools/.chainlit/config.toml +113 -0
  2. aixtools/.chainlit/translations/bn.json +214 -0
  3. aixtools/.chainlit/translations/en-US.json +214 -0
  4. aixtools/.chainlit/translations/gu.json +214 -0
  5. aixtools/.chainlit/translations/he-IL.json +214 -0
  6. aixtools/.chainlit/translations/hi.json +214 -0
  7. aixtools/.chainlit/translations/ja.json +214 -0
  8. aixtools/.chainlit/translations/kn.json +214 -0
  9. aixtools/.chainlit/translations/ml.json +214 -0
  10. aixtools/.chainlit/translations/mr.json +214 -0
  11. aixtools/.chainlit/translations/nl.json +214 -0
  12. aixtools/.chainlit/translations/ta.json +214 -0
  13. aixtools/.chainlit/translations/te.json +214 -0
  14. aixtools/.chainlit/translations/zh-CN.json +214 -0
  15. aixtools/__init__.py +11 -0
  16. aixtools/_version.py +34 -0
  17. aixtools/a2a/app.py +126 -0
  18. aixtools/a2a/google_sdk/__init__.py +0 -0
  19. aixtools/a2a/google_sdk/card.py +27 -0
  20. aixtools/a2a/google_sdk/pydantic_ai_adapter/agent_executor.py +199 -0
  21. aixtools/a2a/google_sdk/pydantic_ai_adapter/storage.py +26 -0
  22. aixtools/a2a/google_sdk/remote_agent_connection.py +88 -0
  23. aixtools/a2a/google_sdk/utils.py +59 -0
  24. aixtools/a2a/utils.py +115 -0
  25. aixtools/agents/__init__.py +12 -0
  26. aixtools/agents/agent.py +164 -0
  27. aixtools/agents/agent_batch.py +71 -0
  28. aixtools/agents/prompt.py +97 -0
  29. aixtools/app.py +143 -0
  30. aixtools/chainlit.md +14 -0
  31. aixtools/compliance/__init__.py +9 -0
  32. aixtools/compliance/private_data.py +138 -0
  33. aixtools/context.py +17 -0
  34. aixtools/db/__init__.py +17 -0
  35. aixtools/db/database.py +110 -0
  36. aixtools/db/vector_db.py +115 -0
  37. aixtools/google/client.py +25 -0
  38. aixtools/log_view/__init__.py +17 -0
  39. aixtools/log_view/app.py +195 -0
  40. aixtools/log_view/display.py +285 -0
  41. aixtools/log_view/export.py +51 -0
  42. aixtools/log_view/filters.py +41 -0
  43. aixtools/log_view/log_utils.py +26 -0
  44. aixtools/log_view/node_summary.py +229 -0
  45. aixtools/logfilters/__init__.py +7 -0
  46. aixtools/logfilters/context_filter.py +67 -0
  47. aixtools/logging/__init__.py +30 -0
  48. aixtools/logging/log_objects.py +227 -0
  49. aixtools/logging/logging_config.py +161 -0
  50. aixtools/logging/mcp_log_models.py +102 -0
  51. aixtools/logging/mcp_logger.py +172 -0
  52. aixtools/logging/model_patch_logging.py +87 -0
  53. aixtools/logging/open_telemetry.py +36 -0
  54. aixtools/mcp/__init__.py +9 -0
  55. aixtools/mcp/client.py +375 -0
  56. aixtools/mcp/example_client.py +30 -0
  57. aixtools/mcp/example_server.py +22 -0
  58. aixtools/mcp/fast_mcp_log.py +31 -0
  59. aixtools/mcp/faulty_mcp.py +319 -0
  60. aixtools/model_patch/model_patch.py +63 -0
  61. aixtools/server/__init__.py +29 -0
  62. aixtools/server/app_mounter.py +90 -0
  63. aixtools/server/path.py +72 -0
  64. aixtools/server/utils.py +70 -0
  65. aixtools/server/workspace_privacy.py +65 -0
  66. aixtools/testing/__init__.py +9 -0
  67. aixtools/testing/aix_test_model.py +149 -0
  68. aixtools/testing/mock_tool.py +66 -0
  69. aixtools/testing/model_patch_cache.py +279 -0
  70. aixtools/tools/doctor/__init__.py +3 -0
  71. aixtools/tools/doctor/tool_doctor.py +61 -0
  72. aixtools/tools/doctor/tool_recommendation.py +44 -0
  73. aixtools/utils/__init__.py +35 -0
  74. aixtools/utils/chainlit/cl_agent_show.py +82 -0
  75. aixtools/utils/chainlit/cl_utils.py +168 -0
  76. aixtools/utils/config.py +131 -0
  77. aixtools/utils/config_util.py +69 -0
  78. aixtools/utils/enum_with_description.py +37 -0
  79. aixtools/utils/files.py +17 -0
  80. aixtools/utils/persisted_dict.py +99 -0
  81. aixtools/utils/utils.py +167 -0
  82. aixtools/vault/__init__.py +7 -0
  83. aixtools/vault/vault.py +137 -0
  84. aixtools-0.0.0.dist-info/METADATA +669 -0
  85. aixtools-0.0.0.dist-info/RECORD +88 -0
  86. aixtools-0.0.0.dist-info/WHEEL +5 -0
  87. aixtools-0.0.0.dist-info/entry_points.txt +2 -0
  88. aixtools-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,319 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Faulty MCP Server for Testing MCP Errors
4
+ - Simulates 404 errors for specific MCP requests
5
+ """
6
+
7
+ import argparse
8
+ import asyncio
9
+ import json
10
+ import logging.config
11
+ import os
12
+ from dataclasses import dataclass
13
+ from random import choice, random
14
+
15
+ from fastapi import HTTPException, status
16
+ from fastmcp import FastMCP
17
+ from fastmcp.exceptions import (
18
+ ClientError,
19
+ DisabledError,
20
+ InvalidSignature,
21
+ NotFoundError,
22
+ PromptError,
23
+ ResourceError,
24
+ ToolError,
25
+ ValidationError,
26
+ )
27
+ from fastmcp.server.middleware import Middleware as McpMiddleware
28
+ from fastmcp.server.middleware import MiddlewareContext
29
+ from fastmcp.server.middleware.logging import LoggingMiddleware
30
+ from starlette.middleware import Middleware as StarletteMiddleware
31
+ from starlette.types import Receive, Scope, Send
32
+
33
+ from aixtools.logging.logging_config import DEFAULT_LOGGING_CONFIG
34
+ from aixtools.utils import get_logger
35
+
36
# Shorten log lines: the color formatter's format string drops the user/session ID fields.
# NOTE(review): this edits DEFAULT_LOGGING_CONFIG in place, so every other importer of that
# shared dict sees the shortened format too.
DEFAULT_LOGGING_CONFIG["formatters"]["color"].update(
    format="%(log_color)s%(asctime)s.%(msecs)03d %(levelname)-8s%(reset)s %(message)s"
)
logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)

# Module logger used by both middlewares below.
logger = get_logger(__name__)
44
+
45
+
46
@dataclass
class Config:
    """Runtime knobs for the faulty MCP server.

    Each ``prob_*`` value is compared against ``random.random()``, so 0 disables that
    fault. When the module runs as a script, the values can be overridden from the CLI.
    """

    # TCP port the HTTP server listens on.
    port: int = 9999
    # Chance of replacing the response with a 404 for requests carrying a JSON body (MCP POSTs).
    prob_on_post_404: float = 0.5
    # Chance of killing the whole process when a GET request is received.
    prob_on_get_crash: float = 0.3
    # Chance of replacing the response to a DELETE request with a 404.
    prob_on_delete_404: float = 0.5
    # Chance of raising a FastMCP exception while handling `tools/list`.
    prob_in_list_tools_throw: float = 0.5
    # Chance of killing the whole process while handling `tools/list`.
    prob_in_list_tools_crash: float = 0.3


# Process-wide configuration instance shared by the middlewares and the CLI entry point.
config = Config()
60
+
61
+
62
class McpErrorMiddleware(McpMiddleware):
    """MCP-level middleware that randomly kills the process or raises during ``tools/list``.

    Every MCP message is logged. For ``tools/list`` two independent draws are made against
    the global ``config``:

    * with probability ``prob_in_list_tools_crash`` the process is killed (signal 9);
    * otherwise, with probability ``prob_in_list_tools_throw``, a randomly chosen FastMCP
      exception is raised.

    All other methods are passed straight through to ``call_next``.
    """

    async def __call__(self, context: MiddlewareContext, call_next):
        # This method receives ALL messages regardless of type
        logger.info("[McpErrorMiddleware] processing: %s", context.method)

        if context.method == "tools/list":
            crash_roll = random()
            logger.info("[McpErrorMiddleware] random number: %f", crash_roll)
            if crash_roll < config.prob_in_list_tools_crash:
                logger.warning("[McpErrorMiddleware] Simulating server crash!")
                os.kill(os.getpid(), 9)

            # Independent draw, so prob_in_list_tools_throw is the real chance of raising.
            # Reusing the crash draw would shrink it to (throw - crash), and to zero whenever
            # throw <= crash.
            throw_roll = random()
            logger.info("[McpErrorMiddleware] random number: %f", throw_roll)
            if throw_roll < config.prob_in_list_tools_throw:
                exception_class = choice(
                    [
                        ValidationError,
                        ResourceError,
                        ToolError,
                        PromptError,
                        InvalidSignature,
                        ClientError,
                        NotFoundError,
                        DisabledError,
                    ]
                )
                logger.warning("[McpErrorMiddleware] throwing %s for: %s", exception_class.__name__, context.method)
                raise exception_class(f"[McpErrorMiddleware] {exception_class.__name__}.")

        result = await call_next(context)
        logger.info("[McpErrorMiddleware] completed: %s", context.method)
        return result
95
+
96
+
97
class StarletteErrorMiddleware:  # pylint: disable=too-few-public-methods
    """Raw ASGI middleware that logs every exchange and injects transport-level faults.

    Fault decisions compare ``random()`` against the global ``config``:

    * GET: when an ``http.request`` message is received, the process may be killed
      (signal 9) with probability ``prob_on_get_crash``.
    * DELETE: the response may be replaced by a plain-text 404 with probability
      ``prob_on_delete_404``. This is decided before the request reaches the app.
    * Any request whose complete body is a JSON object: the response may be replaced by a
      plain-text 404 with probability ``prob_on_post_404``. An ``initialize`` request
      without ``params.capabilities`` is treated as a health check and is never given a 404.
    """

    def __init__(self, app):
        """Wrap the downstream ASGI ``app`` and log the active fault probabilities."""
        self.app = app
        logger.info("[StarletteErrorMiddleware] Middleware initialized!")
        logger.info("Current configuration:")
        logger.info("HTTP 404 on POST request probability: %f", config.prob_on_post_404)
        logger.info("Terminate on GET request probability: %f", config.prob_on_get_crash)
        logger.info("HTTP 404 on DELETE request probability: %f", config.prob_on_delete_404)
        logger.info("Exception in list tools handling probability: %f", config.prob_in_list_tools_throw)
        logger.info("Terminate in list handle probability: %f", config.prob_in_list_tools_crash)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):  # noqa: PLR0915 # pylint: disable=too-many-statements
        # Log all the condition variables for debugging
        logger.info("[StarletteErrorMiddleware] >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        logger.info("[StarletteErrorMiddleware] scope['type']: %s", scope.get("type", "unknown"))
        logger.info("[StarletteErrorMiddleware] scope['path']: %s", scope.get("path", "unknown"))
        logger.info("[StarletteErrorMiddleware] HTTP method: %s", http_method := scope.get("method", "unknown"))
        # NOTE(review): headers are logged verbatim, including any user/session/auth headers.
        logger.info("[StarletteErrorMiddleware] Headers: %s", str(dict(scope.get("headers", []))))

        # Wrap receive to log body content without breaking the flow
        body_parts = []
        should_inject_404 = False
        # True once the 404 status has replaced the app's response start. Body chunks are
        # only rewritten after that, so the status and the body always agree.
        replaced_response_start = False
        # True once the single replacement body chunk (which ends the response) was sent.
        replaced_response_body = False

        if http_method == "DELETE":
            random_number = random()
            logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
            if random_number < config.prob_on_delete_404:
                logger.info("[StarletteErrorMiddleware] Simulating 404 error on DELETE request")
                should_inject_404 = True

        async def logging_receive():
            nonlocal should_inject_404
            message = await receive()
            logger.info("[StarletteErrorMiddleware] Received message: %s", str(message))

            if message["type"] == "http.request":  # pylint: disable=too-many-nested-blocks
                if http_method == "GET":
                    random_number = random()
                    logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
                    if random_number < config.prob_on_get_crash:
                        logger.warning("[StarletteErrorMiddleware] Simulating server crash on GET request!")
                        os.kill(os.getpid(), 9)

                body = message.get("body", b"")
                if body:
                    body_parts.append(body)

                # Log when we have the complete body
                if not message.get("more_body", False):
                    complete_body = b"".join(body_parts)
                    if complete_body:
                        try:
                            body_str = complete_body.decode("utf-8")
                            logger.info("[StarletteErrorMiddleware] Request body: %s", body_str)

                            json_data = json.loads(body_str)
                            if isinstance(json_data, dict):
                                method_name = json_data.get("method", "unknown")
                                if method_name == "initialize" and not json_data.get("params", {}).get("capabilities"):
                                    logger.info("Detected initial health check from Navari, skipping 404 injection.")
                                    return message
                                logger.info("[StarletteErrorMiddleware] MCP method: %s", method_name)

                                # Check if we should inject 404
                                random_number = random()
                                logger.info("[StarletteErrorMiddleware] random number: %f", random_number)
                                if random_number < config.prob_on_post_404:
                                    should_inject_404 = True
                                    logger.info("[StarletteErrorMiddleware] %s - will inject 404!", method_name)
                        except (UnicodeDecodeError, json.JSONDecodeError) as e:
                            logger.exception("[StarletteErrorMiddleware] Error parsing body: %s", e)
                    else:
                        logger.info("[StarletteErrorMiddleware] Request body: (empty)")

            return message

        async def intercepting_send(message):
            nonlocal replaced_response_start, replaced_response_body
            logger.info("[StarletteErrorMiddleware] Sending message: %s", str(message))
            if message["type"] == "http.response.start" and should_inject_404:
                logger.warning("[StarletteErrorMiddleware] Injecting 404!")
                replaced_response_start = True
                # Replace the response with 404
                message = {
                    "type": "http.response.start",
                    "status": 404,
                    "headers": [[b"content-type", b"text/plain"]],
                }
            elif message["type"] == "http.response.body" and replaced_response_start:
                if replaced_response_body:
                    # The replacement body already completed the response. Drop further
                    # chunks of a streamed (e.g. SSE) reply instead of sending them after
                    # the response has ended.
                    return
                replaced_response_body = True
                # Replace body with 404 message; more_body=False ends the response here.
                message = {
                    "type": "http.response.body",
                    "body": b"Simulated 404 error",
                    "more_body": False,
                }

            await send(message)

        logger.info("[StarletteErrorMiddleware] <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        await self.app(scope, logging_receive, intercepting_send)
198
+
199
+
200
# The FastMCP server instance. Its MCP middleware chain is payload logging plus the
# random-failure McpErrorMiddleware; the tools below are registered on it via @mcp.tool.
mcp = FastMCP(
    middleware=[LoggingMiddleware(include_payloads=True), McpErrorMiddleware()],
    name="Faulty MCP Server",
    instructions="A simple test server for reproducing MCP errors.",
)
206
+
207
+
208
@mcp.tool
def add(a: float, b: float) -> float:
    """Add two numbers together."""
    # Well-behaved tool: plain addition, no injected faults.
    total = a + b
    return total
212
+
213
+
214
@mcp.tool
def always_error() -> None:
    """Always throw an exception to simulate errors."""
    # Unconditional failure: every call surfaces as a tool error to the MCP client.
    simulated = ValueError("Simulated error")
    raise simulated
218
+
219
+
220
@mcp.tool
async def freeze_server(seconds: int = 60) -> str:
    """Simulate a server freeze for testing purposes."""
    # asyncio.sleep suspends only this tool call; the client sees a response delayed by
    # `seconds`, while the event loop itself stays free.
    await asyncio.sleep(seconds)
    reply = f"Server frozen for {seconds} seconds"
    return reply
225
+
226
+
227
@mcp.tool
def multiply(a: float, b: float) -> float:
    """Multiply two numbers together."""
    # Well-behaved tool: plain multiplication, no injected faults.
    product = a * b
    return product
231
+
232
+
233
@mcp.tool
def random_throw_exception(a: float, b: float, prob: float = 0.5) -> float:
    """Randomly throw an exception to simulate errors."""
    # Fails with probability `prob`; otherwise behaves like `multiply`.
    should_fail = random() < prob
    if should_fail:
        raise ValueError("Simulated error")
    return a * b
239
+
240
+
241
@mcp.tool
def throw_404_exception() -> None:
    """Always throw an HTTP 404 exception to simulate errors."""
    # The docstring above is the tool description sent to MCP clients. This tool raises
    # unconditionally (nothing random), so the description must not say "Randomly".
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Throwing a 404 error for testing purposes.")
245
+
246
+
247
def run_server_on_port():
    """Serve ``mcp`` over streamable HTTP and block until the server stops.

    Host is fixed to ``localhost`` and the endpoint path to ``/mcp/``; the port comes from
    the global ``config``. ``StarletteErrorMiddleware`` is installed on the HTTP app.
    """

    async def _serve():
        port = config.port
        print(f"[Port {port}] Starting MCP server on http://localhost:{port}/mcp/")
        http_options = {
            "transport": "streamable-http",
            "host": "localhost",
            "port": port,
            "path": "/mcp/",
            "middleware": [StarletteMiddleware(StarletteErrorMiddleware)],
        }
        await mcp.run_http_async(**http_options)

    asyncio.run(_serve())
261
+
262
+
263
if __name__ == "__main__":

    def _probability(value: str) -> float:
        """argparse ``type`` callback: parse ``value`` as a float within [0, 1].

        Out-of-range values (or NaN) would silently act like 0 or 1 in the
        ``random() < p`` checks, so they are rejected with a normal argparse usage error.
        """
        try:
            prob = float(value)
        except ValueError as exc:
            raise argparse.ArgumentTypeError(f"invalid probability: {value!r}") from exc
        if not 0.0 <= prob <= 1.0:
            raise argparse.ArgumentTypeError(f"probability must be between 0 and 1, got {value!r}")
        return prob

    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Run a faulty MCP server for testing error handling")
    parser.add_argument(
        "--safe-mode",
        action="store_true",
        help="Set all error probabilities to 0 by default, only use explicitly provided values",
    )
    parser.add_argument(
        "--port",
        type=int,
        help=f"Port to run the server on (default: {config.port})",
    )
    parser.add_argument(
        "--prob-on-post-404",
        type=_probability,
        help=f"Probability of injecting a 404 error for POST requests (default: {config.prob_on_post_404})",
    )
    parser.add_argument(
        "--prob-on-get-crash",
        type=_probability,
        help=f"Probability of terminating on GET request (default: {config.prob_on_get_crash})",
    )
    parser.add_argument(
        "--prob-on-delete-404",
        type=_probability,
        help=f"Probability of injecting a 404 error for DELETE requests (default: {config.prob_on_delete_404})",
    )
    parser.add_argument(
        "--prob-in-list-tools-throw",
        type=_probability,
        help=f"Probability of exception in list tools handling (default: {config.prob_in_list_tools_throw})",
    )
    parser.add_argument(
        "--prob-in-list-tools-crash",
        type=_probability,
        help=f"Probability of terminating in list tools handling (default: {config.prob_in_list_tools_crash})",
    )

    args = parser.parse_args()

    def _update_config_value(attr_name: str):
        """Apply safe mode (reset to 0.0) and then any explicitly given CLI value for ``attr_name``."""
        if args.safe_mode:
            setattr(config, attr_name, 0.0)
        if (given_value := getattr(args, attr_name)) is not None:
            setattr(config, attr_name, given_value)

    # Update the global configuration with command line arguments.
    # Compare with None rather than truthiness, so an explicit `--port 0` is honoured.
    if args.port is not None:
        config.port = args.port
    for _attr_name in (
        "prob_on_post_404",
        "prob_on_get_crash",
        "prob_on_delete_404",
        "prob_in_list_tools_throw",
        "prob_in_list_tools_crash",
    ):
        _update_config_value(_attr_name)

    # Run the server
    run_server_on_port()
@@ -0,0 +1,63 @@
1
+ from typing import Any
2
+
3
+ from pydantic import BaseModel, ConfigDict
4
+
5
+ from aixtools.logging.logging_config import get_logger
6
+
7
# Module-level logger; model_patch() uses it to warn when a model is already patched.
logger = get_logger(__name__)
8
+
9
+
10
class ModelRawRequest(BaseModel):
    """One call to a model method: its name, request ID and the arguments it received."""

    # Name of the model method that was called.
    method_name: str
    # Unique ID of this request.
    request_id: str
    # Positional arguments passed to the method.
    args: tuple
    # Keyword arguments passed to the method.
    kwargs: dict
15
+
16
+
17
class ModelRawRequestResult(BaseModel):
    """The value returned by a model method call, tagged with the method name and request ID."""

    # Let `result` hold arbitrary (non-pydantic) object types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Name of the model method that produced the result.
    method_name: str
    # Unique ID of the request this result belongs to.
    request_id: str
    # Whatever the method returned.
    result: Any
23
+
24
+
25
class ModelRawRequestYieldItem(BaseModel):
    """A single item yielded by a streaming model method, with its position in the stream."""

    # Let `item` hold arbitrary (non-pydantic) object types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Name of the model method that yielded the item.
    method_name: str
    # Unique ID of the request this item belongs to.
    request_id: str
    # Sequence number of the item within the stream.
    item_num: int
    # The yielded value itself.
    item: Any
32
+
33
+
34
def get_request_fn(model):
    """Return the model's original (unpatched) ``request`` method.

    ``model_patch`` stores the original under ``_request_ori``. If that attribute is present,
    it is returned; otherwise ``model.request`` is still the original and is returned.
    Checking the specific attribute, rather than "is any method patched", avoids an
    AttributeError on a model where only ``_request_stream_ori`` was set.
    """
    original = getattr(model, "_request_ori", None)
    return original if original is not None else model.request
39
+
40
+
41
def get_request_stream_fn(model):
    """Return the model's original (unpatched) ``request_stream`` method.

    ``model_patch`` stores the original under ``_request_stream_ori``. If that attribute is
    present, it is returned; otherwise ``model.request_stream`` is still the original.
    Checking the specific attribute avoids an AttributeError on a model where only
    ``_request_ori`` was set.
    """
    original = getattr(model, "_request_stream_ori", None)
    return original if original is not None else model.request_stream
46
+
47
+
48
def is_patched(model):
    """Return True if ``model`` carries either saved original method left by ``model_patch``."""
    return any(hasattr(model, saved) for saved in ("_request_ori", "_request_stream_ori"))
50
+
51
+
52
def model_patch(model, request_method, request_stream_method):
    """Replace ``model.request`` and ``model.request_stream`` with the given callables.

    The original methods are saved as ``model._request_ori`` and
    ``model._request_stream_ori``, so ``get_request_fn`` / ``get_request_stream_fn`` can
    still reach them. A model that is already patched is left untouched and a warning is
    logged, so calling this twice does not overwrite the saved originals.

    Args:
        model: Object exposing ``request`` and ``request_stream`` attributes.
        request_method: Replacement for ``model.request``.
        request_stream_method: Replacement for ``model.request_stream``.

    Returns:
        The same ``model`` object, patched in place.
    """
    if is_patched(model):
        logger.warning("Model %s is already patched. Skipping patching.", model.__class__.__name__)
        return model
    # Save original methods
    model._request_ori = model.request
    model._request_stream_ori = model.request_stream
    # Patch methods
    model.request = request_method
    model.request_stream = request_stream_method
    return model
@@ -0,0 +1,29 @@
1
"""
FastMCP server utilities for:
- extracting user and session IDs from the request context
- building per-user/per-session workspace paths and converting paths between host and sandbox container
- session privacy helpers (`set_session_private`, `is_session_private`)
- running blocking MCP tool functions in a separate thread.
"""

from .path import (
    container_to_host_path,
    get_workspace_path,
    host_to_container_path,
)
from .utils import (
    get_session_id_tuple,
    run_in_thread,
)
from .workspace_privacy import (
    is_session_private,
    set_session_private,
)

# Public API of the package: everything re-exported from the submodules imported above.
__all__ = [
    "get_workspace_path",
    "get_session_id_tuple",
    "container_to_host_path",
    "host_to_container_path",
    "run_in_thread",
    "is_session_private",
    "set_session_private",
]
@@ -0,0 +1,90 @@
1
+ """Utility for mounting sub-applications with lifespan management in Starlette/FastAPI."""
2
+
3
+ from contextlib import asynccontextmanager
4
+
5
+ from starlette.applications import Starlette
6
+ from starlette.types import ASGIApp
7
+
8
+
9
class AppMounter:  # pylint: disable=too-few-public-methods
    """
    A utility class for mounting sub-applications with their lifespans.

    On construction it wraps the parent application's lifespan. On startup the parent
    lifespan is entered first, followed by the lifespan of every app registered through
    ``mount_with_lifespan`` (in registration order). On shutdown they are exited in
    reverse order.
    """

    def __init__(self, parent_app: Starlette):
        """
        Initialize the AppMounter with a parent application.

        The combined lifespan is installed on ``parent_app.router`` immediately.

        Args:
            parent_app: The parent Starlette/FastAPI application
        """
        self.parent_app = parent_app
        self.mounted_apps: list[tuple[str, ASGIApp]] = []
        self._original_lifespan = parent_app.router.lifespan_context
        self._setup_combined_lifespan()

    def mount_with_lifespan(self, path: str, app: ASGIApp, name: str | None = None) -> None:
        """
        Mount a sub-application and ensure its lifespan is managed.

        The list of mounted apps is read when the parent lifespan starts. Apps mounted after
        the parent application has started will therefore not have their lifespans run.

        Args:
            path: The path to mount the application at
            app: The ASGI application to mount
            name: Optional name for the mounted application
        """
        # Mount the app using the parent app's mount method
        self.parent_app.mount(path, app=app, name=name)

        # Store the mounted app for lifespan management
        self.mounted_apps.append((path, app))

    def _setup_combined_lifespan(self) -> None:
        """
        Set up a combined lifespan that manages both the parent app and all mounted sub-apps.
        """

        @asynccontextmanager
        async def combined_lifespan(app):
            # First enter the parent app's lifespan. Its yielded state (if any) is passed on,
            # so wrapping the lifespan does not drop lifespan state the parent provides.
            async with self._original_lifespan(app) as state:
                # Then enter each mounted app's lifespan in order
                async with self._create_nested_lifespans():
                    yield state

        # Replace the parent app's lifespan with our combined one
        self.parent_app.router.lifespan_context = combined_lifespan

    @asynccontextmanager
    async def _create_nested_lifespans(self):
        """
        Enter the lifespans of all mounted apps, behaving like nested ``async with`` blocks.

        ``AsyncExitStack`` enters them in registration order and exits them in reverse order.
        It also propagates any exception raised while they are active into each lifespan's
        exit. The previous recursive async-generator approach left the inner generators
        unclosed on the error path, so their cleanup depended on garbage collection.
        Apps without ``router.lifespan_context`` are skipped.
        """
        from contextlib import AsyncExitStack  # pylint: disable=import-outside-toplevel

        async with AsyncExitStack() as stack:
            for _, app in self.mounted_apps:
                if hasattr(app, "router") and hasattr(app.router, "lifespan_context"):
                    await stack.enter_async_context(app.router.lifespan_context(app))
            yield
@@ -0,0 +1,72 @@
1
+ """
2
+ Workspace path handling for user sessions.
3
+ """
4
+
5
+ from pathlib import Path, PurePath, PurePosixPath
6
+
7
+ from fastmcp import Context
8
+
9
+ from ..utils.config import DATA_DIR
10
+ from .utils import get_session_id_tuple
11
+
12
# Host directory holding every workspace, laid out as <user_id>/<session_id>/<service_name>.
WORKSPACES_ROOT_DIR = DATA_DIR.joinpath("workspaces")
# Mount point of the session workspace inside the sandbox container.
CONTAINER_WORKSPACE_PATH = PurePosixPath("/", "workspace")
14
+
15
+
16
def _check_path_component(value, label: str):
    """Return ``value`` unchanged if it is safe to use as one directory level of a workspace path.

    Raises:
        ValueError: if ``value`` is "." / ".." or contains "/", "\\" or a NUL byte. Such an ID
            would escape or collapse the per-user/per-session directory layout.
    """
    text = str(value)
    if text in (".", "..") or any(bad in text for bad in ("/", "\\", "\x00")):
        raise ValueError(f"Invalid {label} for a workspace path: {text!r}")
    return value


def get_workspace_path(
    service_name: str | None = None, *, in_sandbox: bool = False, ctx: Context | tuple | None = None
) -> PurePath:
    """
    Get the workspace path for a specific service (e.g. MCP server).

    If `service_name` is None, the path to the entire workspace folder (as mounted to a sandbox) is returned.
    If `in_sandbox` is True, it returns a path in the sandbox, e.g.: `/workspace/mcp_repl`.
    If `in_sandbox` is False, it returns the path based on user and session IDs in the format:
    `<DATA_DIR>/workspaces/<user_id>/<session_id>/<service_name>`, where `DATA_DIR` should come from
    the environment variables, e.g.:
    `/data/workspaces/foo-user/bar-session/mcp_repl`.
    The user and session IDs come from `ctx`. It can be a `(user_id, session_id)` tuple, or a
    FastMCP `Context` whose HTTP headers are read. If `ctx` is None, the current FastMCP request
    HTTP headers are used.

    Args:
        service_name: The name of the service (e.g. "mcp_server").
        in_sandbox: If True, use a sandbox path; otherwise, use user/session-based path.
        ctx: A FastMCP context, a `(user_id, session_id)` tuple, or None (current request).

    Returns: The workspace path as a PurePath object.

    Raises:
        ValueError: if (when `in_sandbox` is False) the user or session ID is "." / ".." or contains
            a path separator or NUL byte. These IDs may come from HTTP headers and must not be
            able to escape the workspaces root.
    """
    if in_sandbox:
        path = CONTAINER_WORKSPACE_PATH
    else:
        user_id, session_id = ctx if isinstance(ctx, tuple) else get_session_id_tuple(ctx)
        path = (
            WORKSPACES_ROOT_DIR
            / _check_path_component(user_id, "user ID")
            / _check_path_component(session_id, "session ID")
        )
    if service_name:
        path = path / service_name
    return path
43
+
44
+
45
def container_to_host_path(path: PurePosixPath, *, ctx: Context | tuple | None = None) -> Path:
    """
    Convert a path in a sandbox container to the corresponding host path.

    Args:
        path: Path inside the container; must be CONTAINER_WORKSPACE_PATH or below it.
        ctx: FastMCP context or `(user_id, session_id)` tuple used to locate the host
            workspace (see `get_workspace_path`); None means the current request.

    Returns:
        Path to the file on the host, inside the session's workspace directory.

    Raises:
        ValueError: if `path` is not under CONTAINER_WORKSPACE_PATH, or contains a ".."
            component. `relative_to` is purely lexical, so ".." would escape the host workspace.
    """
    old_root = CONTAINER_WORKSPACE_PATH
    new_root = get_workspace_path(ctx=ctx)
    try:
        relative = PurePosixPath(path).relative_to(old_root)
    except ValueError as e:
        raise ValueError(f"Container path must be a subdir of '{old_root}', got '{path}' instead") from e
    if ".." in relative.parts:
        raise ValueError(f"Container path must not contain '..' components, got '{path}' instead")
    return new_root / relative
63
+
64
+
65
def host_to_container_path(path: Path, *, ctx: Context | tuple | None = None) -> PurePosixPath:
    """
    Convert a host path to a path in a sandbox container.

    Args:
        path: Host path inside the session's workspace directory (see `get_workspace_path`).
        ctx: FastMCP context or `(user_id, session_id)` tuple used to locate the host
            workspace; None means the current request.

    Returns:
        The matching path under CONTAINER_WORKSPACE_PATH.

    Raises:
        ValueError: if `path` is not inside the host workspace, or contains a ".." component.
            `relative_to` is purely lexical, so ".." would point outside the container workspace.
    """
    old_root = get_workspace_path(ctx=ctx)
    new_root = CONTAINER_WORKSPACE_PATH
    try:
        relative = Path(path).relative_to(old_root)
    except ValueError as exc:
        raise ValueError(f"Host path must be a subdir of '{old_root}', got '{path}' instead") from exc
    if ".." in relative.parts:
        raise ValueError(f"Host path must not contain '..' components, got '{path}' instead")
    return new_root / relative
@@ -0,0 +1,70 @@
1
+ """
2
+ FastMCP server utilities for handling user context and threading.
3
+ """
4
+
5
+ import asyncio
6
+ from functools import wraps
7
+
8
+ from fastmcp import Context
9
+ from fastmcp.server import dependencies
10
+
11
+ from ..context import DEFAULT_SESSION_ID, DEFAULT_USER_ID, session_id_var, user_id_var
12
+
13
+
14
def get_session_id_tuple(ctx: Context | None = None) -> tuple[str, str]:
    """
    Resolve the (user_id, session_id) pair for the current request.

    Each ID is taken from the HTTP request headers first (via `ctx`, or the current FastMCP
    request when `ctx` is None). If a header is missing or empty, the value falls back to the
    corresponding context variable and then to its default.

    Returns: Tuple of (user_id, session_id).
    """
    return (
        get_user_id_from_request(ctx) or user_id_var.get(DEFAULT_USER_ID),
        get_session_id_from_request(ctx) or session_id_var.get(DEFAULT_SESSION_ID),
    )
25
+
26
+
27
def get_session_id_from_request(ctx: Context | None = None) -> str | None:
    """
    Get the session ID from the `session-id` HTTP request header.
    If `ctx` is None, the current FastMCP request HTTP headers are used.

    Returns:
        str | None: The header value as sent, or None if the header is missing or the
        request cannot be accessed.
    """
    try:
        return (ctx or dependencies).get_http_request().headers.get("session-id")
    except (ValueError, RuntimeError, AttributeError):
        # Same failure set as get_user_id_from_request, so both lookups degrade to None
        # (and hence to the context-var defaults) under the same conditions.
        return None
36
+
37
+
38
def get_user_id_from_request(ctx: Context | None = None) -> str | None:
    """
    Read the `user-id` HTTP request header and normalise it to lowercase.

    The headers come from `ctx`, or from the current FastMCP request when `ctx` is None.

    Returns:
        str | None: The lowercased user ID. Returns None if the header is missing or empty,
        or if reading the request raises ValueError, RuntimeError or AttributeError.
    """
    try:
        request = (ctx or dependencies).get_http_request()
        raw_user_id = request.headers.get("user-id")
        if not raw_user_id:
            return None
        return raw_user_id.lower()
    except (ValueError, RuntimeError, AttributeError):
        return None
52
+
53
+
54
def get_session_id_str(ctx: Context | None = None) -> str:
    """
    Build the combined "<user_id>:<session_id>" key for the current user and session.
    If `ctx` is None, the current FastMCP request HTTP headers are used.
    """
    uid, sid = get_session_id_tuple(ctx)
    combined = f"{uid}:{sid}"
    return combined
61
+
62
+
63
def run_in_thread(func):
    """Decorator: turn a blocking callable into a coroutine function.

    Each call forwards all arguments to ``func`` through ``asyncio.to_thread``, so the
    function runs in a worker thread. ``functools.wraps`` preserves ``func``'s name,
    docstring and other metadata on the wrapper.
    """

    async def _threaded(*args, **kwargs):
        return await asyncio.to_thread(func, *args, **kwargs)

    return wraps(func)(_threaded)