hud-python 0.5.1__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +1 -1
- hud/agents/__init__.py +65 -6
- hud/agents/base.py +33 -15
- hud/agents/claude.py +60 -31
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +15 -26
- hud/agents/gemini_cua.py +6 -17
- hud/agents/misc/response_agent.py +7 -0
- hud/agents/openai.py +16 -29
- hud/agents/openai_chat.py +3 -19
- hud/agents/operator.py +5 -17
- hud/agents/resolver.py +70 -0
- hud/agents/tests/test_claude.py +2 -4
- hud/agents/tests/test_openai.py +2 -1
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +34 -3
- hud/cli/build.py +37 -5
- hud/cli/dev.py +11 -2
- hud/cli/eval.py +51 -39
- hud/cli/flows/init.py +1 -1
- hud/cli/pull.py +1 -1
- hud/cli/push.py +9 -2
- hud/cli/tests/test_build.py +2 -2
- hud/cli/tests/test_push.py +1 -1
- hud/cli/utils/metadata.py +1 -1
- hud/cli/utils/tests/test_metadata.py +1 -1
- hud/clients/mcp_use.py +6 -1
- hud/datasets/loader.py +17 -18
- hud/datasets/runner.py +16 -10
- hud/datasets/tests/test_loader.py +15 -15
- hud/environment/__init__.py +5 -3
- hud/environment/connection.py +58 -6
- hud/environment/connectors/mcp_config.py +29 -1
- hud/environment/environment.py +218 -77
- hud/environment/router.py +175 -24
- hud/environment/scenarios.py +313 -186
- hud/environment/tests/test_connectors.py +10 -23
- hud/environment/tests/test_environment.py +432 -0
- hud/environment/tests/test_local_connectors.py +81 -40
- hud/environment/tests/test_scenarios.py +820 -14
- hud/eval/context.py +63 -10
- hud/eval/instrument.py +4 -2
- hud/eval/manager.py +79 -12
- hud/eval/task.py +36 -4
- hud/eval/tests/test_eval.py +1 -1
- hud/eval/tests/test_task.py +147 -1
- hud/eval/types.py +2 -0
- hud/eval/utils.py +14 -3
- hud/patches/mcp_patches.py +178 -21
- hud/telemetry/instrument.py +8 -1
- hud/telemetry/tests/test_eval_telemetry.py +8 -8
- hud/tools/__init__.py +2 -0
- hud/tools/agent.py +223 -0
- hud/tools/computer/__init__.py +34 -5
- hud/tools/shell.py +3 -3
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/types.py +62 -34
- hud/utils/hud_console.py +30 -17
- hud/utils/strict_schema.py +1 -1
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/METADATA +2 -2
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/RECORD +67 -61
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/WHEEL +0 -0
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.5.1.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
hud/environment/connection.py
CHANGED
|
@@ -68,6 +68,8 @@ class Connector:
|
|
|
68
68
|
self.connection_type = connection_type
|
|
69
69
|
self.client: FastMCPClient[Any] | None = None
|
|
70
70
|
self._tools_cache: list[mcp_types.Tool] | None = None
|
|
71
|
+
self._prompts_cache: list[mcp_types.Prompt] | None = None
|
|
72
|
+
self._resources_cache: list[mcp_types.Resource] | None = None
|
|
71
73
|
|
|
72
74
|
def copy(self) -> Connector:
|
|
73
75
|
"""Create a copy of this connector with fresh (unconnected) state.
|
|
@@ -101,6 +103,14 @@ class Connector:
|
|
|
101
103
|
def cached_tools(self) -> list[mcp_types.Tool]:
|
|
102
104
|
return self._tools_cache or []
|
|
103
105
|
|
|
106
|
+
@property
|
|
107
|
+
def cached_prompts(self) -> list[mcp_types.Prompt]:
|
|
108
|
+
return self._prompts_cache or []
|
|
109
|
+
|
|
110
|
+
@property
|
|
111
|
+
def cached_resources(self) -> list[mcp_types.Resource]:
|
|
112
|
+
return self._resources_cache or []
|
|
113
|
+
|
|
104
114
|
async def connect(self) -> None:
|
|
105
115
|
"""Create FastMCP client and connect.
|
|
106
116
|
|
|
@@ -110,19 +120,27 @@ class Connector:
|
|
|
110
120
|
"""
|
|
111
121
|
from fastmcp.client import Client as FastMCPClient
|
|
112
122
|
|
|
113
|
-
|
|
114
|
-
|
|
123
|
+
self.client = FastMCPClient(
|
|
124
|
+
transport=self._transport,
|
|
125
|
+
auth=self._auth,
|
|
126
|
+
)
|
|
115
127
|
await self.client.__aenter__()
|
|
116
128
|
|
|
117
129
|
async def disconnect(self) -> None:
|
|
118
|
-
"""Disconnect and clear
|
|
130
|
+
"""Disconnect and clear all caches."""
|
|
119
131
|
if self.client is not None and self.is_connected:
|
|
120
132
|
await self.client.__aexit__(None, None, None)
|
|
121
133
|
self.client = None
|
|
122
134
|
self._tools_cache = None
|
|
135
|
+
self._prompts_cache = None
|
|
136
|
+
self._resources_cache = None
|
|
123
137
|
|
|
124
138
|
async def list_tools(self) -> list[mcp_types.Tool]:
|
|
125
|
-
"""Fetch tools from server, apply filters/transforms/prefix, and cache.
|
|
139
|
+
"""Fetch tools from server, apply filters/transforms/prefix, and cache.
|
|
140
|
+
|
|
141
|
+
Always fetches fresh data from the server (no caching check).
|
|
142
|
+
The result is cached for use by router.build() via cached_tools property.
|
|
143
|
+
"""
|
|
126
144
|
if self.client is None:
|
|
127
145
|
raise RuntimeError("Not connected - call connect() first")
|
|
128
146
|
tools = await self.client.list_tools()
|
|
@@ -178,14 +196,48 @@ class Connector:
|
|
|
178
196
|
return await self.client.call_tool_mcp(name, arguments or {})
|
|
179
197
|
|
|
180
198
|
async def list_resources(self) -> list[mcp_types.Resource]:
|
|
199
|
+
"""Fetch resources from server and cache.
|
|
200
|
+
|
|
201
|
+
Always fetches fresh data from the server (no caching check).
|
|
202
|
+
The result is cached for use by router.build_resources() via cached_resources property.
|
|
203
|
+
|
|
204
|
+
Note: resources/list is optional in the MCP spec. If the server doesn't
|
|
205
|
+
implement it, we return an empty list gracefully.
|
|
206
|
+
"""
|
|
181
207
|
if self.client is None:
|
|
182
208
|
raise RuntimeError("Not connected - call connect() first")
|
|
183
|
-
|
|
209
|
+
try:
|
|
210
|
+
self._resources_cache = await self.client.list_resources()
|
|
211
|
+
except Exception as e:
|
|
212
|
+
# Handle servers that don't implement resources/list (optional in MCP spec)
|
|
213
|
+
if "Method not found" in str(e):
|
|
214
|
+
logger.debug("Server %s does not support resources/list", self.name)
|
|
215
|
+
self._resources_cache = []
|
|
216
|
+
else:
|
|
217
|
+
raise
|
|
218
|
+
return self._resources_cache
|
|
184
219
|
|
|
185
220
|
async def list_prompts(self) -> list[mcp_types.Prompt]:
|
|
221
|
+
"""Fetch prompts from server and cache.
|
|
222
|
+
|
|
223
|
+
Always fetches fresh data from the server (no caching check).
|
|
224
|
+
The result is cached for use by router.build_prompts() via cached_prompts property.
|
|
225
|
+
|
|
226
|
+
Note: prompts/list is optional in the MCP spec. If the server doesn't
|
|
227
|
+
implement it, we return an empty list gracefully.
|
|
228
|
+
"""
|
|
186
229
|
if self.client is None:
|
|
187
230
|
raise RuntimeError("Not connected - call connect() first")
|
|
188
|
-
|
|
231
|
+
try:
|
|
232
|
+
self._prompts_cache = await self.client.list_prompts()
|
|
233
|
+
except Exception as e:
|
|
234
|
+
# Handle servers that don't implement prompts/list (optional in MCP spec)
|
|
235
|
+
if "Method not found" in str(e):
|
|
236
|
+
logger.debug("Server %s does not support prompts/list", self.name)
|
|
237
|
+
self._prompts_cache = []
|
|
238
|
+
else:
|
|
239
|
+
raise
|
|
240
|
+
return self._prompts_cache
|
|
189
241
|
|
|
190
242
|
async def read_resource(
|
|
191
243
|
self, uri: str
|
|
@@ -50,6 +50,7 @@ class MCPConfigConnectorMixin(BaseConnectorMixin):
|
|
|
50
50
|
```
|
|
51
51
|
"""
|
|
52
52
|
from hud.environment.connection import ConnectionType
|
|
53
|
+
from hud.settings import settings
|
|
53
54
|
|
|
54
55
|
name = alias or next(iter(config.keys()), "mcp")
|
|
55
56
|
server_config = next(iter(config.values()), {})
|
|
@@ -57,9 +58,20 @@ class MCPConfigConnectorMixin(BaseConnectorMixin):
|
|
|
57
58
|
is_local = "command" in server_config or "args" in server_config
|
|
58
59
|
conn_type = ConnectionType.LOCAL if is_local else ConnectionType.REMOTE
|
|
59
60
|
|
|
61
|
+
transport: Any = config
|
|
62
|
+
if not is_local and "url" in server_config:
|
|
63
|
+
max_request_timeout = 840
|
|
64
|
+
server_config.setdefault(
|
|
65
|
+
"sse_read_timeout",
|
|
66
|
+
min(settings.client_timeout, max_request_timeout)
|
|
67
|
+
if settings.client_timeout > 0
|
|
68
|
+
else max_request_timeout,
|
|
69
|
+
)
|
|
70
|
+
transport = _build_transport(server_config)
|
|
71
|
+
|
|
60
72
|
return self._add_connection(
|
|
61
73
|
name,
|
|
62
|
-
|
|
74
|
+
transport,
|
|
63
75
|
connection_type=conn_type,
|
|
64
76
|
prefix=prefix,
|
|
65
77
|
include=include,
|
|
@@ -107,3 +119,19 @@ class MCPConfigConnectorMixin(BaseConnectorMixin):
|
|
|
107
119
|
for server_name, server_config in mcp_config.items():
|
|
108
120
|
self.connect_mcp({server_name: server_config}, alias=server_name, **kwargs)
|
|
109
121
|
return self
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _build_transport(server_config: dict[str, Any]) -> Any:
|
|
125
|
+
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
|
|
126
|
+
from fastmcp.mcp_config import infer_transport_type_from_url
|
|
127
|
+
|
|
128
|
+
url = server_config["url"]
|
|
129
|
+
transport_type = server_config.get("transport") or infer_transport_type_from_url(url)
|
|
130
|
+
transport_cls = SSETransport if transport_type == "sse" else StreamableHttpTransport
|
|
131
|
+
|
|
132
|
+
return transport_cls(
|
|
133
|
+
url=url,
|
|
134
|
+
headers=server_config.get("headers"),
|
|
135
|
+
auth=server_config.get("auth"),
|
|
136
|
+
sse_read_timeout=server_config.get("sse_read_timeout"),
|
|
137
|
+
)
|
hud/environment/environment.py
CHANGED
|
@@ -119,6 +119,26 @@ class Environment(
|
|
|
119
119
|
|
|
120
120
|
MAX_CONCURRENT_CONNECTIONS = 10
|
|
121
121
|
|
|
122
|
+
@staticmethod
|
|
123
|
+
def _normalize_name(name: str) -> str:
|
|
124
|
+
"""Normalize environment name to lowercase with hyphens.
|
|
125
|
+
|
|
126
|
+
- Strips whitespace
|
|
127
|
+
- Replaces spaces and underscores with hyphens
|
|
128
|
+
- Lowercases the result
|
|
129
|
+
- Removes any non-alphanumeric characters except hyphens
|
|
130
|
+
"""
|
|
131
|
+
import re
|
|
132
|
+
|
|
133
|
+
normalized = name.strip().lower()
|
|
134
|
+
normalized = normalized.replace(" ", "-").replace("_", "-")
|
|
135
|
+
# Keep only alphanumeric and hyphens
|
|
136
|
+
normalized = re.sub(r"[^a-z0-9-]", "", normalized)
|
|
137
|
+
# Collapse multiple hyphens
|
|
138
|
+
normalized = re.sub(r"-+", "-", normalized)
|
|
139
|
+
# Strip leading/trailing hyphens
|
|
140
|
+
return normalized.strip("-") or "environment"
|
|
141
|
+
|
|
122
142
|
def __init__(
|
|
123
143
|
self,
|
|
124
144
|
name: str = "environment",
|
|
@@ -126,14 +146,23 @@ class Environment(
|
|
|
126
146
|
conflict_resolution: ConflictResolution = ConflictResolution.PREFIX,
|
|
127
147
|
**fastmcp_kwargs: Any,
|
|
128
148
|
) -> None:
|
|
149
|
+
# Normalize name to prevent casing/spacing issues
|
|
150
|
+
name = self._normalize_name(name)
|
|
129
151
|
super().__init__(name=name, instructions=instructions, **fastmcp_kwargs)
|
|
130
152
|
self._connections: dict[str, Connector] = {}
|
|
131
153
|
self._router = ToolRouter(conflict_resolution=conflict_resolution)
|
|
154
|
+
# Granular routing flags - only rebuild what's invalidated
|
|
155
|
+
self._tool_routing_built = False
|
|
156
|
+
self._prompt_routing_built = False
|
|
157
|
+
self._resource_routing_built = False
|
|
132
158
|
self._in_context = False
|
|
133
159
|
|
|
134
160
|
# Tool call queues - run after connections established
|
|
135
161
|
self._setup_calls: list[tuple[str, dict[str, Any]]] = []
|
|
136
162
|
self._evaluate_calls: list[tuple[str, dict[str, Any]]] = []
|
|
163
|
+
self._integration_test_calls: list[tuple[str, dict[str, Any]]] = []
|
|
164
|
+
# Store setup tool results for append_setup_output feature
|
|
165
|
+
self._setup_results: list[MCPToolResult] = []
|
|
137
166
|
|
|
138
167
|
# Default prompt (EvalContext has per-run prompt)
|
|
139
168
|
self.prompt: str | None = None
|
|
@@ -163,24 +192,35 @@ class Environment(
|
|
|
163
192
|
"""Return tools in MCP format (base format).
|
|
164
193
|
|
|
165
194
|
Applies agent-level include/exclude filtering if set.
|
|
195
|
+
Supports fnmatch-style wildcards (e.g., "*setup*", "browser_*").
|
|
166
196
|
"""
|
|
197
|
+
import fnmatch
|
|
198
|
+
|
|
167
199
|
tools = self._router.tools
|
|
168
200
|
|
|
169
201
|
# Apply agent-level filtering (from v4 allowed_tools/disallowed_tools)
|
|
170
202
|
if self._agent_include is not None or self._agent_exclude is not None:
|
|
171
203
|
filtered = []
|
|
172
204
|
for tool in tools:
|
|
173
|
-
# Include filter: None means include all
|
|
174
|
-
if self._agent_include is not None and
|
|
205
|
+
# Include filter: None means include all, check if matches any pattern
|
|
206
|
+
if self._agent_include is not None and not any(
|
|
207
|
+
fnmatch.fnmatch(tool.name, pattern) for pattern in self._agent_include
|
|
208
|
+
):
|
|
175
209
|
continue
|
|
176
|
-
# Exclude filter
|
|
177
|
-
if self._agent_exclude is not None and
|
|
210
|
+
# Exclude filter: skip if tool matches any exclude pattern
|
|
211
|
+
if self._agent_exclude is not None and any(
|
|
212
|
+
fnmatch.fnmatch(tool.name, pattern) for pattern in self._agent_exclude
|
|
213
|
+
):
|
|
178
214
|
continue
|
|
179
215
|
filtered.append(tool)
|
|
180
216
|
return filtered
|
|
181
217
|
|
|
182
218
|
return tools
|
|
183
219
|
|
|
220
|
+
def add_tool(self, obj: Any, **kwargs: Any) -> None:
|
|
221
|
+
super().add_tool(obj, **kwargs)
|
|
222
|
+
self._tool_routing_built = False # Only invalidate tool routing
|
|
223
|
+
|
|
184
224
|
async def call_tool(self, call: Any, /, **kwargs: Any) -> Any:
|
|
185
225
|
"""Call a tool, auto-detecting format and returning matching result format.
|
|
186
226
|
|
|
@@ -224,6 +264,9 @@ class Environment(
|
|
|
224
264
|
Automatically filters to only connections where the tool exists
|
|
225
265
|
(based on cached_tools from initial discovery).
|
|
226
266
|
|
|
267
|
+
For internal tools (starting with _), tries ALL connections since
|
|
268
|
+
internal tools are hidden from list_tools() and won't be in cached_tools.
|
|
269
|
+
|
|
227
270
|
Args:
|
|
228
271
|
tool_name: Name of the tool to call
|
|
229
272
|
**kwargs: Arguments to pass to the tool
|
|
@@ -233,10 +276,13 @@ class Environment(
|
|
|
233
276
|
"""
|
|
234
277
|
import asyncio
|
|
235
278
|
|
|
236
|
-
#
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
279
|
+
# For internal tools (underscore prefix), try ALL connections since
|
|
280
|
+
# they're hidden from list_tools() and won't appear in cached_tools.
|
|
281
|
+
# For regular tools, only try connections that advertise the tool.
|
|
282
|
+
if tool_name.startswith("_"):
|
|
283
|
+
targets = set(self._connections.keys())
|
|
284
|
+
else:
|
|
285
|
+
targets = self._connections_with_tool(tool_name)
|
|
240
286
|
|
|
241
287
|
results: dict[str, Any] = {}
|
|
242
288
|
|
|
@@ -245,7 +291,8 @@ class Environment(
|
|
|
245
291
|
if not connector or not connector.client:
|
|
246
292
|
return
|
|
247
293
|
try:
|
|
248
|
-
|
|
294
|
+
# Use connector.call_tool which expects arguments as a dict
|
|
295
|
+
results[name] = await connector.call_tool(tool_name, kwargs)
|
|
249
296
|
logger.debug("Broadcast '%s' to '%s' succeeded", tool_name, name)
|
|
250
297
|
except Exception as e:
|
|
251
298
|
results[name] = e
|
|
@@ -304,7 +351,7 @@ class Environment(
|
|
|
304
351
|
"""Connect all connectors, build routing, run setup tools."""
|
|
305
352
|
self._in_context = True
|
|
306
353
|
|
|
307
|
-
# Connect to all servers
|
|
354
|
+
# Connect to all servers and fetch tools/prompts/resources in parallel
|
|
308
355
|
sem = asyncio.Semaphore(self.MAX_CONCURRENT_CONNECTIONS)
|
|
309
356
|
errors: list[tuple[str, Exception]] = []
|
|
310
357
|
|
|
@@ -312,7 +359,12 @@ class Environment(
|
|
|
312
359
|
async with sem:
|
|
313
360
|
try:
|
|
314
361
|
await conn.connect()
|
|
315
|
-
|
|
362
|
+
# Batch fetch all MCP primitives in parallel for performance
|
|
363
|
+
await asyncio.gather(
|
|
364
|
+
conn.list_tools(),
|
|
365
|
+
conn.list_prompts(),
|
|
366
|
+
conn.list_resources(),
|
|
367
|
+
)
|
|
316
368
|
except Exception as e:
|
|
317
369
|
errors.append((name, e))
|
|
318
370
|
|
|
@@ -328,9 +380,25 @@ class Environment(
|
|
|
328
380
|
|
|
329
381
|
await self._build_routing()
|
|
330
382
|
|
|
331
|
-
# Setup tool calls (after connections)
|
|
383
|
+
# Setup tool calls (after connections) - abort if any setup tool fails
|
|
384
|
+
# Store results for append_setup_output feature
|
|
385
|
+
self._setup_results = []
|
|
332
386
|
for name, args in self._setup_calls:
|
|
333
|
-
await self._execute_tool(name, args)
|
|
387
|
+
result = await self._execute_tool(name, args)
|
|
388
|
+
self._setup_results.append(result)
|
|
389
|
+
if result.isError:
|
|
390
|
+
# Extract error message from result content
|
|
391
|
+
error_msg = "Setup tool failed"
|
|
392
|
+
if result.content:
|
|
393
|
+
for block in result.content:
|
|
394
|
+
if isinstance(block, mcp_types.TextContent):
|
|
395
|
+
error_msg = block.text
|
|
396
|
+
break
|
|
397
|
+
# Clean up connections before raising (since __aexit__ won't be called)
|
|
398
|
+
for conn in self._connections.values():
|
|
399
|
+
if conn.is_connected:
|
|
400
|
+
await conn.disconnect()
|
|
401
|
+
raise RuntimeError(f"Setup tool '{name}' failed: {error_msg}")
|
|
334
402
|
|
|
335
403
|
return self
|
|
336
404
|
|
|
@@ -351,6 +419,8 @@ class Environment(
|
|
|
351
419
|
rewards.append(find_reward(result))
|
|
352
420
|
except Exception as e:
|
|
353
421
|
logger.warning("Evaluate tool %s failed: %s", name, e)
|
|
422
|
+
# Record 0.0 for failed evaluate tools so they affect the average
|
|
423
|
+
rewards.append(0.0)
|
|
354
424
|
|
|
355
425
|
# Store average reward from evaluate tools
|
|
356
426
|
self._evaluate_reward: float | None = None
|
|
@@ -361,11 +431,44 @@ class Environment(
|
|
|
361
431
|
if self._connections:
|
|
362
432
|
await asyncio.gather(*[c.disconnect() for c in self._connections.values()])
|
|
363
433
|
self._router.clear()
|
|
434
|
+
self._tool_routing_built = False
|
|
435
|
+
self._prompt_routing_built = False
|
|
436
|
+
self._resource_routing_built = False
|
|
437
|
+
self._active_session = None # Clear stale scenario state
|
|
438
|
+
|
|
439
|
+
async def run_async(
|
|
440
|
+
self,
|
|
441
|
+
transport: Literal["stdio", "http", "sse"] | None = None,
|
|
442
|
+
show_banner: bool = True,
|
|
443
|
+
**transport_kwargs: Any,
|
|
444
|
+
) -> None:
|
|
445
|
+
"""Run the MCP server, auto-connecting all connectors first.
|
|
446
|
+
|
|
447
|
+
This ensures that tools from external MCP servers (via connect_mcp_config)
|
|
448
|
+
are discovered and available when the server starts.
|
|
449
|
+
"""
|
|
450
|
+
async with self: # Connect all connectors via __aenter__
|
|
451
|
+
await super().run_async(
|
|
452
|
+
transport=transport, show_banner=show_banner, **transport_kwargs
|
|
453
|
+
)
|
|
364
454
|
|
|
365
455
|
async def _build_routing(self) -> None:
|
|
456
|
+
"""Build routing for tools, prompts, and resources in parallel.
|
|
457
|
+
|
|
458
|
+
Only rebuilds what's actually invalidated for performance.
|
|
459
|
+
"""
|
|
460
|
+
tasks = []
|
|
461
|
+
if not self._tool_routing_built:
|
|
462
|
+
tasks.append(self._build_tool_routing())
|
|
463
|
+
if not self._prompt_routing_built:
|
|
464
|
+
tasks.append(self._build_prompt_routing())
|
|
465
|
+
if not self._resource_routing_built:
|
|
466
|
+
tasks.append(self._build_resource_routing())
|
|
467
|
+
if tasks:
|
|
468
|
+
await asyncio.gather(*tasks)
|
|
469
|
+
|
|
470
|
+
async def _build_tool_routing(self) -> None:
|
|
366
471
|
"""Build tool routing from local tools and connection caches."""
|
|
367
|
-
# Use get_tools() not list_tools() - it includes mounted servers without
|
|
368
|
-
# requiring MCP server communication (via_server=False)
|
|
369
472
|
local_tools_dict = await self._tool_manager.get_tools()
|
|
370
473
|
local_tools = list(local_tools_dict.values())
|
|
371
474
|
self._router.build(
|
|
@@ -375,16 +478,54 @@ class Environment(
|
|
|
375
478
|
)
|
|
376
479
|
# Populate mock schemas for auto-generated mock values
|
|
377
480
|
self._populate_mock_schemas()
|
|
481
|
+
self._tool_routing_built = True
|
|
482
|
+
|
|
483
|
+
async def _build_prompt_routing(self) -> None:
|
|
484
|
+
"""Build prompt routing from local prompts and connections."""
|
|
485
|
+
local_prompts_dict = await self._prompt_manager.get_prompts()
|
|
486
|
+
local_prompts = [p.to_mcp_prompt() for p in local_prompts_dict.values()]
|
|
487
|
+
self._router.build_prompts(local_prompts, self._connections)
|
|
488
|
+
self._prompt_routing_built = True
|
|
489
|
+
|
|
490
|
+
async def _build_resource_routing(self) -> None:
|
|
491
|
+
"""Build resource routing from local resources and connections."""
|
|
492
|
+
local_resources_dict = await self._resource_manager.get_resources()
|
|
493
|
+
local_resources = [r.to_mcp_resource() for r in local_resources_dict.values()]
|
|
494
|
+
self._router.build_resources(local_resources, self._connections)
|
|
495
|
+
self._resource_routing_built = True
|
|
496
|
+
|
|
497
|
+
# =========================================================================
|
|
498
|
+
# MCP Protocol Overrides - Include connector tools in MCP responses
|
|
499
|
+
# =========================================================================
|
|
500
|
+
|
|
501
|
+
def _setup_handlers(self) -> None:
|
|
502
|
+
"""Override FastMCP to register our custom handlers for tools."""
|
|
503
|
+
# Call parent to set up all standard handlers
|
|
504
|
+
super()._setup_handlers()
|
|
505
|
+
# Re-register our custom handlers (overwrites parent's registrations)
|
|
506
|
+
self._mcp_server.list_tools()(self._env_list_tools)
|
|
507
|
+
self._mcp_server.call_tool()(self._env_call_tool)
|
|
508
|
+
|
|
509
|
+
async def _env_list_tools(self) -> list[mcp_types.Tool]:
|
|
510
|
+
"""Return all tools including those from connectors."""
|
|
511
|
+
if not self._tool_routing_built:
|
|
512
|
+
await self._build_tool_routing()
|
|
513
|
+
return self._router.tools
|
|
514
|
+
|
|
515
|
+
async def _env_call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> list[Any]:
|
|
516
|
+
"""Route tool calls through our router (handles both local and connector tools)."""
|
|
517
|
+
result = await self._execute_tool(name, arguments or {})
|
|
518
|
+
return result.content or []
|
|
378
519
|
|
|
379
520
|
# =========================================================================
|
|
380
521
|
# Tool Operations
|
|
381
522
|
# =========================================================================
|
|
382
523
|
|
|
383
524
|
async def list_tools(self) -> list[mcp_types.Tool]:
|
|
384
|
-
"""Refresh tools from all connections and rebuild routing."""
|
|
525
|
+
"""Refresh tools from all connections and rebuild tool routing."""
|
|
385
526
|
if self._connections:
|
|
386
527
|
await asyncio.gather(*[c.list_tools() for c in self._connections.values()])
|
|
387
|
-
await self.
|
|
528
|
+
await self._build_tool_routing()
|
|
388
529
|
return self._router.tools
|
|
389
530
|
|
|
390
531
|
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolResult:
|
|
@@ -397,12 +538,15 @@ class Environment(
|
|
|
397
538
|
logger.debug("Mock mode: returning mock result for tool %s", name)
|
|
398
539
|
return self._get_mock_result(name, arguments)
|
|
399
540
|
|
|
541
|
+
# Rebuild tool routing if invalidated (e.g., after add_tool)
|
|
542
|
+
if not self._tool_routing_built:
|
|
543
|
+
await self._build_tool_routing()
|
|
544
|
+
|
|
400
545
|
if self._router.is_local(name):
|
|
401
546
|
# Call tool manager directly to avoid FastMCP context requirement
|
|
402
547
|
result = await self._tool_manager.call_tool(name, arguments)
|
|
403
548
|
return MCPToolResult(
|
|
404
|
-
content=result.content,
|
|
405
|
-
structuredContent=result.structured_content,
|
|
549
|
+
content=result.content, structuredContent=result.structured_content
|
|
406
550
|
)
|
|
407
551
|
|
|
408
552
|
connection_name = self._router.get_connection(name)
|
|
@@ -422,86 +566,83 @@ class Environment(
|
|
|
422
566
|
# =========================================================================
|
|
423
567
|
|
|
424
568
|
async def list_resources(self) -> list[mcp_types.Resource]:
|
|
425
|
-
"""
|
|
426
|
-
local = list((await self._resource_manager.get_resources()).values())
|
|
427
|
-
resources: list[mcp_types.Resource] = [r.to_mcp_resource() for r in local]
|
|
428
|
-
|
|
569
|
+
"""Refresh resources from all connections and rebuild resource routing."""
|
|
429
570
|
if self._connections:
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
for r in results:
|
|
434
|
-
if isinstance(r, list):
|
|
435
|
-
resources.extend(r)
|
|
436
|
-
|
|
437
|
-
return resources
|
|
571
|
+
await asyncio.gather(*[c.list_resources() for c in self._connections.values()])
|
|
572
|
+
await self._build_resource_routing()
|
|
573
|
+
return self._router.resources
|
|
438
574
|
|
|
439
575
|
async def read_resource(
|
|
440
576
|
self, uri: str
|
|
441
577
|
) -> list[mcp_types.TextResourceContents | mcp_types.BlobResourceContents]:
|
|
442
|
-
"""Read a resource by URI
|
|
578
|
+
"""Read a resource by URI using router for connection lookup."""
|
|
443
579
|
from pydantic import AnyUrl
|
|
444
580
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
if isinstance(result, str):
|
|
449
|
-
return [mcp_types.TextResourceContents(uri=resource_uri, text=result)]
|
|
450
|
-
import base64
|
|
581
|
+
# Ensure resource routing is built
|
|
582
|
+
if not self._resource_routing_built:
|
|
583
|
+
await self._build_resource_routing()
|
|
451
584
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
uri=resource_uri, blob=base64.b64encode(result).decode()
|
|
455
|
-
)
|
|
456
|
-
]
|
|
457
|
-
except Exception as e:
|
|
458
|
-
logger.debug("Local resource read failed for %s: %s", uri, e)
|
|
585
|
+
# Use router to find which connection has this resource
|
|
586
|
+
conn_name = self._router.get_resource_connection(uri)
|
|
459
587
|
|
|
460
|
-
|
|
588
|
+
if conn_name is None:
|
|
589
|
+
# Local resource
|
|
461
590
|
try:
|
|
462
|
-
|
|
591
|
+
result = await self._resource_manager.read_resource(uri)
|
|
592
|
+
resource_uri = AnyUrl(uri)
|
|
593
|
+
if isinstance(result, str):
|
|
594
|
+
return [mcp_types.TextResourceContents(uri=resource_uri, text=result)]
|
|
595
|
+
import base64
|
|
596
|
+
|
|
597
|
+
return [
|
|
598
|
+
mcp_types.BlobResourceContents(
|
|
599
|
+
uri=resource_uri, blob=base64.b64encode(result).decode()
|
|
600
|
+
)
|
|
601
|
+
]
|
|
463
602
|
except Exception as e:
|
|
464
|
-
logger.debug("
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
603
|
+
logger.debug("Local resource read failed for %s: %s", uri, e)
|
|
604
|
+
raise ValueError(f"Resource not found: {uri}") from e
|
|
605
|
+
else:
|
|
606
|
+
# Remote resource
|
|
607
|
+
conn = self._connections.get(conn_name)
|
|
608
|
+
if conn is None:
|
|
609
|
+
raise ValueError(f"Connection '{conn_name}' not found for resource '{uri}'")
|
|
610
|
+
return await conn.read_resource(uri)
|
|
468
611
|
|
|
469
612
|
# =========================================================================
|
|
470
613
|
# Prompt Operations
|
|
471
614
|
# =========================================================================
|
|
472
615
|
|
|
473
616
|
async def list_prompts(self) -> list[mcp_types.Prompt]:
|
|
474
|
-
"""
|
|
475
|
-
local = list((await self._prompt_manager.get_prompts()).values())
|
|
476
|
-
prompts: list[mcp_types.Prompt] = [p.to_mcp_prompt() for p in local]
|
|
477
|
-
|
|
617
|
+
"""Refresh prompts from all connections and rebuild prompt routing."""
|
|
478
618
|
if self._connections:
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
for r in results:
|
|
483
|
-
if isinstance(r, list):
|
|
484
|
-
prompts.extend(r)
|
|
485
|
-
|
|
486
|
-
return prompts
|
|
619
|
+
await asyncio.gather(*[c.list_prompts() for c in self._connections.values()])
|
|
620
|
+
await self._build_prompt_routing()
|
|
621
|
+
return self._router.prompts
|
|
487
622
|
|
|
488
623
|
async def get_prompt(
|
|
489
624
|
self, name: str, arguments: dict[str, Any] | None = None
|
|
490
625
|
) -> mcp_types.GetPromptResult:
|
|
491
|
-
"""Get a prompt by name
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
626
|
+
"""Get a prompt by name using router for connection lookup."""
|
|
627
|
+
# Ensure prompt routing is built
|
|
628
|
+
if not self._prompt_routing_built:
|
|
629
|
+
await self._build_prompt_routing()
|
|
630
|
+
|
|
631
|
+
# Use router to find which connection has this prompt
|
|
632
|
+
conn_name = self._router.get_prompt_connection(name)
|
|
496
633
|
|
|
497
|
-
|
|
634
|
+
if conn_name is None:
|
|
635
|
+
# Local prompt
|
|
498
636
|
try:
|
|
499
|
-
return await
|
|
637
|
+
return await self._prompt_manager.render_prompt(name, arguments or {})
|
|
500
638
|
except Exception as e:
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
639
|
+
raise ValueError(f"Prompt not found: {name}") from e
|
|
640
|
+
else:
|
|
641
|
+
# Remote prompt
|
|
642
|
+
conn = self._connections.get(conn_name)
|
|
643
|
+
if conn is None:
|
|
644
|
+
raise ValueError(f"Connection '{conn_name}' not found for prompt '{name}'")
|
|
645
|
+
return await conn.get_prompt(name, arguments)
|
|
505
646
|
|
|
506
647
|
# =========================================================================
|
|
507
648
|
# Server Methods
|
|
@@ -553,7 +694,7 @@ class Environment(
|
|
|
553
694
|
For v4 format: requires mcp_config, prompt, AND evaluate_tool
|
|
554
695
|
"""
|
|
555
696
|
# Check for local tools (registered via @env.tool)
|
|
556
|
-
if self._router.
|
|
697
|
+
if self._router._local_tool_names:
|
|
557
698
|
return False
|
|
558
699
|
# Check for local scenarios (registered via @env.scenario)
|
|
559
700
|
if getattr(self, "_scenarios", {}):
|
|
@@ -590,10 +731,10 @@ class Environment(
|
|
|
590
731
|
task.env.to_config() # {"prompt": "...", "mcp_config": {...}, ...}
|
|
591
732
|
```
|
|
592
733
|
"""
|
|
593
|
-
if self._router.
|
|
734
|
+
if self._router._local_tool_names:
|
|
594
735
|
raise ValueError(
|
|
595
736
|
f"Cannot serialize Environment with local tools: "
|
|
596
|
-
f"{list(self._router.
|
|
737
|
+
f"{list(self._router._local_tool_names)}. "
|
|
597
738
|
"Local tools require local execution. For remote submission, "
|
|
598
739
|
"use dict config or connect to a remote hub."
|
|
599
740
|
)
|