hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +70 -5
- hud/agents/base.py +238 -500
- hud/agents/claude.py +236 -247
- hud/agents/gateway.py +42 -0
- hud/agents/gemini.py +264 -0
- hud/agents/gemini_cua.py +324 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +48 -36
- hud/agents/openai.py +282 -296
- hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
- hud/agents/operator.py +199 -0
- hud/agents/resolver.py +70 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +381 -214
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +377 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_resolver.py +192 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/agents/types.py +148 -0
- hud/cli/__init__.py +493 -546
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +699 -113
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +889 -732
- hud/cli/eval.py +793 -667
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/pull.py +1 -1
- hud/cli/push.py +38 -13
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +110 -8
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push.py +1 -1
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +70 -1
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +45 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +326 -0
- hud/datasets/runner.py +198 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +52 -0
- hud/environment/connection.py +258 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +137 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +835 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +263 -0
- hud/environment/scenarios.py +620 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +205 -0
- hud/environment/tests/test_environment.py +593 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +242 -0
- hud/environment/tests/test_scenarios.py +1086 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +727 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +187 -0
- hud/eval/manager.py +533 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +372 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +291 -0
- hud/eval/types.py +65 -0
- hud/eval/utils.py +194 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +308 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +165 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +18 -2
- hud/tools/agent.py +223 -0
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +36 -3
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_agent_tool.py +355 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +194 -56
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +89 -18
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.13.dist-info/METADATA +264 -0
- hud_python-0.5.13.dist-info/RECORD +305 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,835 @@
|
|
|
1
|
+
"""Environment class - unified MCP server and client."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
from collections.abc import Awaitable, Callable
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal, Self
|
|
9
|
+
|
|
10
|
+
import mcp.types as mcp_types
|
|
11
|
+
|
|
12
|
+
from hud.environment.connectors import ConnectorsMixin
|
|
13
|
+
from hud.environment.integrations import IntegrationsMixin
|
|
14
|
+
from hud.environment.mock import MockMixin
|
|
15
|
+
from hud.environment.router import ConflictResolution, ToolRouter
|
|
16
|
+
from hud.environment.scenarios import ScenarioMixin
|
|
17
|
+
from hud.server.server import MCPServer
|
|
18
|
+
from hud.types import MCPToolResult
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
import types
|
|
22
|
+
|
|
23
|
+
from hud.environment.connection import Connector
|
|
24
|
+
from hud.eval.task import Task
|
|
25
|
+
|
|
26
|
+
# Only the Environment class is part of this module's public API.
__all__ = ["Environment"]

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Quieten fastmcp's verbose loggers: only WARNING and above get through.
logging.getLogger("fastmcp.server.server").setLevel(logging.WARNING)
logging.getLogger("fastmcp.server.openapi").setLevel(logging.WARNING)

# Alias for a zero-argument callable whose return value can be awaited.
AsyncCallable = Callable[[], Awaitable[Any]]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Environment(
|
|
39
|
+
ConnectorsMixin,
|
|
40
|
+
IntegrationsMixin,
|
|
41
|
+
MockMixin,
|
|
42
|
+
ScenarioMixin,
|
|
43
|
+
MCPServer,
|
|
44
|
+
):
|
|
45
|
+
"""Unified MCP environment that acts as both server and client.
|
|
46
|
+
|
|
47
|
+
Features:
|
|
48
|
+
- Define local tools with @env.tool decorator
|
|
49
|
+
- Connect to HUD Hub, URLs, or mcp_config dicts
|
|
50
|
+
- Automatic tool routing (local vs remote)
|
|
51
|
+
- Format tools for any LLM provider
|
|
52
|
+
- Integrate with popular agent frameworks
|
|
53
|
+
- Mock mode for testing without real connections
|
|
54
|
+
|
|
55
|
+
Connector methods (connect to sources):
|
|
56
|
+
connect_hub(name) - HUD Hub environment
|
|
57
|
+
connect_url(url) - MCP server via URL
|
|
58
|
+
connect_mcp(config) - Single mcp_config server
|
|
59
|
+
connect_mcp_config(mcp_config) - Multiple mcp_config servers
|
|
60
|
+
connect_image(image) - Docker image via stdio
|
|
61
|
+
connect_fastapi(app) - Mount FastAPI app as MCP server
|
|
62
|
+
connect_openapi(spec) - Mount OpenAPI spec as MCP server
|
|
63
|
+
connect_server(server) - Mount MCPServer/FastMCP directly
|
|
64
|
+
|
|
65
|
+
Mock methods (for testing):
|
|
66
|
+
mock() - Enable mock mode, all tools return mock values
|
|
67
|
+
unmock() - Disable mock mode
|
|
68
|
+
mock_tool(name, output) - Set specific mock output for a tool
|
|
69
|
+
is_mock - Check if mock mode is enabled
|
|
70
|
+
|
|
71
|
+
OpenAI integrations:
|
|
72
|
+
as_openai_chat_tools() - Chat Completions format
|
|
73
|
+
as_openai_responses_tools() - Responses API format
|
|
74
|
+
as_openai_agent_tools() - Agents SDK (requires openai-agents)
|
|
75
|
+
|
|
76
|
+
Anthropic/Claude integrations:
|
|
77
|
+
as_claude_tools() - Claude API format
|
|
78
|
+
as_claude_programmatic_tools() - Programmatic tool use
|
|
79
|
+
as_anthropic_runner() - Tool runner (requires anthropic)
|
|
80
|
+
|
|
81
|
+
Google/Gemini integrations:
|
|
82
|
+
as_gemini_tools() - Gemini format
|
|
83
|
+
as_gemini_tool_config() - Tool execution config
|
|
84
|
+
|
|
85
|
+
LangChain integrations:
|
|
86
|
+
as_langchain_tools() - StructuredTools (requires langchain-core)
|
|
87
|
+
|
|
88
|
+
Example:
|
|
89
|
+
```python
|
|
90
|
+
env = Environment("my-env")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@env.tool
|
|
94
|
+
def greet(name: str) -> str:
|
|
95
|
+
return f"Hello, {name}!"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
env.connect_hub("browser", prefix="browser")
|
|
99
|
+
|
|
100
|
+
async with env:
|
|
101
|
+
# Get tools in any format
|
|
102
|
+
openai_tools = env.as_openai_chat_tools()
|
|
103
|
+
claude_tools = env.as_claude_tools()
|
|
104
|
+
|
|
105
|
+
# Call tools - automatically routed
|
|
106
|
+
result = await env.call_tool("greet", name="World")
|
|
107
|
+
|
|
108
|
+
# Or pass provider-specific format - auto-detected
|
|
109
|
+
result = await env.call_tool(response.choices[0].message.tool_calls[0])
|
|
110
|
+
|
|
111
|
+
# Mock mode for testing
|
|
112
|
+
env.mock()
|
|
113
|
+
env.mock_tool("browser_navigate", "Navigation successful")
|
|
114
|
+
async with env:
|
|
115
|
+
result = await env.call_tool("browser_navigate", url="https://example.com")
|
|
116
|
+
# Returns mock value instead of actually navigating
|
|
117
|
+
```
|
|
118
|
+
"""
|
|
119
|
+
|
|
120
|
+
MAX_CONCURRENT_CONNECTIONS = 10
|
|
121
|
+
|
|
122
|
+
@staticmethod
def _normalize_name(name: str) -> str:
    """Return *name* as a lowercase, hyphen-separated slug.

    Steps, in order:

    - trim surrounding whitespace and lowercase the text
    - turn every space and underscore into a hyphen
    - drop each character that is not ``a-z``, ``0-9`` or ``-``
    - squeeze runs of hyphens into one and trim hyphens from both ends

    Falls back to ``"environment"`` when nothing is left.
    """
    import re

    slug = name.strip().lower()
    for separator in (" ", "_"):
        slug = slug.replace(separator, "-")
    # Filter the alphabet first, then collapse the hyphen runs that remain.
    slug = re.sub(r"-+", "-", re.sub(r"[^a-z0-9-]", "", slug))
    slug = slug.strip("-")
    return slug if slug else "environment"
|
|
141
|
+
|
|
142
|
+
def __init__(
    self,
    name: str = "environment",
    instructions: str | None = None,
    conflict_resolution: ConflictResolution = ConflictResolution.PREFIX,
    **fastmcp_kwargs: Any,
) -> None:
    """Create an environment with no connections and empty call queues.

    Args:
        name: Environment name. It is passed through ``_normalize_name``
            before being handed to the parent constructor.
        instructions: Forwarded unchanged to the parent constructor.
        conflict_resolution: Strategy handed to the ``ToolRouter``.
        **fastmcp_kwargs: Extra keyword arguments forwarded to the parent
            constructor.
    """
    # Normalize name to prevent casing/spacing issues
    name = self._normalize_name(name)
    super().__init__(name=name, instructions=instructions, **fastmcp_kwargs)
    self._connections: dict[str, Connector] = {}
    self._router = ToolRouter(conflict_resolution=conflict_resolution)
    # Granular routing flags - only rebuild what's invalidated
    self._tool_routing_built = False
    self._prompt_routing_built = False
    self._resource_routing_built = False
    self._in_context = False

    # Tool call queues - run after connections established
    self._setup_calls: list[tuple[str, dict[str, Any]]] = []
    self._evaluate_calls: list[tuple[str, dict[str, Any]]] = []
    self._integration_test_calls: list[tuple[str, dict[str, Any]]] = []
    # Store setup tool results for append_setup_output feature
    self._setup_results: list[MCPToolResult] = []

    # Average reward of the evaluate tools. __aexit__ overwrites this value;
    # it is defined here so that reading it before the first exit yields
    # None instead of raising AttributeError.
    self._evaluate_reward: float | None = None

    # Default prompt (EvalContext has per-run prompt)
    self.prompt: str | None = None

    # Serialization support
    # _hub_config: set by connect_hub() for v5 format {"name": "hub", "include": [...]}
    # _mcp_config: set by connect_mcp_config() for v4 format {"server_name": {...}}
    self._hub_config: dict[str, Any] | None = None
    self._mcp_config: dict[str, dict[str, Any]] | None = None

    # Agent-level tool filtering (applied in as_tools(), not at connection level)
    # This allows Environment to call all tools while limiting agent visibility
    self._agent_include: list[str] | None = None
    self._agent_exclude: list[str] | None = None

    # Initialize mock state
    self._init_mock()

    # Initialize scenario state
    self._init_scenarios()
|
|
186
|
+
|
|
187
|
+
# =========================================================================
|
|
188
|
+
# Core Methods
|
|
189
|
+
# =========================================================================
|
|
190
|
+
|
|
191
|
+
def as_tools(self) -> list[mcp_types.Tool]:
    """Return the router's tools in MCP format, filtered for the agent.

    When neither ``_agent_include`` nor ``_agent_exclude`` is set, the
    router's tool list is returned as-is. Otherwise a new list is built
    that keeps a tool only if it matches at least one include pattern
    (when an include list is set) and no exclude pattern (when an exclude
    list is set). Patterns are fnmatch-style wildcards, e.g. ``"*setup*"``
    or ``"browser_*"``.
    """
    import fnmatch

    tools = self._router.tools
    include, exclude = self._agent_include, self._agent_exclude

    # No agent-level filter configured: nothing to do.
    if include is None and exclude is None:
        return tools

    def matches_any(tool_name: str, patterns: list[str]) -> bool:
        return any(fnmatch.fnmatch(tool_name, pattern) for pattern in patterns)

    return [
        tool
        for tool in tools
        if (include is None or matches_any(tool.name, include))
        and not (exclude is not None and matches_any(tool.name, exclude))
    ]
|
|
219
|
+
|
|
220
|
+
def add_tool(self, obj: Any, **kwargs: Any) -> None:
    """Register a tool through the parent class and mark tool routing stale.

    Only the tool routing flag is reset; the prompt and resource routing
    flags are left untouched.
    """
    super().add_tool(obj, **kwargs)
    # Tool routing is rebuilt the next time _build_routing() runs.
    self._tool_routing_built = False
|
|
223
|
+
|
|
224
|
+
async def call_tool(self, call: Any, /, **kwargs: Any) -> Any:
    """Run one tool call and hand back the result in the caller's format.

    The call may be given in any of these shapes:
        - String with kwargs: call_tool("navigate", url="...")
        - Tuple: call_tool(("navigate", {"url": "..."}))
        - MCPToolCall: call_tool(MCPToolCall(name="navigate", ...))
        - OpenAI: call_tool(response.choices[0].message.tool_calls[0])
        - Claude: call_tool(response.content[0])  # tool_use block
        - Gemini: call_tool(response.candidates[0].content.parts[0])

    Returns:
        The result formatted to match the input format
        (OpenAI -> OpenAI tool message, etc.)
    """
    from hud.environment.utils import format_result, parse_tool_call

    # kwargs are forwarded to the parser together with the call itself.
    tool_call, source_format = parse_tool_call(call, **kwargs)
    arguments = tool_call.arguments if tool_call.arguments else {}
    raw_result = await self._execute_tool(tool_call.name, arguments)
    return format_result(raw_result, tool_call, source_format)
|
|
244
|
+
|
|
245
|
+
def _connections_with_tool(self, tool_name: str) -> set[str]:
    """Return the names of the connections whose cached tools include *tool_name*.

    Availability is decided from each connector's ``cached_tools`` only.
    """
    return {
        conn_name
        for conn_name, connector in self._connections.items()
        if any(tool.name == tool_name for tool in connector.cached_tools)
    }
|
|
256
|
+
|
|
257
|
+
async def _broadcast_tool(
    self,
    tool_name: str,
    **kwargs: Any,
) -> dict[str, Any]:
    """Call *tool_name* concurrently on every connection that should receive it.

    Regular tools go only to the connections reported by
    ``_connections_with_tool``. Tools whose name starts with ``_`` are
    internal: they are hidden from list_tools() and won't be in
    cached_tools, so they are tried on ALL connections.

    Args:
        tool_name: Name of the tool to call
        **kwargs: Arguments to pass to the tool (sent as one dict)

    Returns:
        Dict mapping connection name to the call's result, or to the
        exception it raised. Connections without a client are skipped and
        do not appear in the dict.
    """
    target_names = (
        set(self._connections.keys())
        if tool_name.startswith("_")
        else self._connections_with_tool(tool_name)
    )
    outcomes: dict[str, Any] = {}

    async def dispatch(conn_name: str) -> None:
        connector = self._connections.get(conn_name)
        if not (connector and connector.client):
            return
        try:
            # connector.call_tool expects the arguments as a dict
            outcome = await connector.call_tool(tool_name, kwargs)
        except Exception as exc:
            outcomes[conn_name] = exc
            logger.debug("Broadcast '%s' to '%s' failed: %s", tool_name, conn_name, exc)
        else:
            outcomes[conn_name] = outcome
            logger.debug("Broadcast '%s' to '%s' succeeded", tool_name, conn_name)

    await asyncio.gather(
        *(dispatch(conn_name) for conn_name in target_names),
        return_exceptions=True,
    )
    return outcomes
|
|
303
|
+
|
|
304
|
+
async def call_tools(self, calls: Any) -> list[Any]:
    """Run several tool calls and return their results, each in its own format.

    ``None`` yields an empty list and any non-list value is treated as a
    single call. For a list, entries whose ``type`` (dict key or attribute)
    is set to something other than ``"tool_use"`` or ``"function"`` are
    skipped (text blocks, etc.); the remaining calls run concurrently.
    """
    if calls is None:
        return []
    if not isinstance(calls, list):
        single_result = await self.call_tool(calls)
        return [single_result]

    def block_type(item: Any) -> Any:
        if isinstance(item, dict):
            return item.get("type")
        return getattr(item, "type", None)

    runnable = []
    for item in calls:
        kind = block_type(item)
        if kind is None or kind in ("tool_use", "function"):
            runnable.append(item)

    return await asyncio.gather(*(self.call_tool(item) for item in runnable))
|
|
319
|
+
|
|
320
|
+
# =========================================================================
|
|
321
|
+
# Lifecycle Configuration
|
|
322
|
+
# =========================================================================
|
|
323
|
+
|
|
324
|
+
def setup_tool(self, call: Any, /, **kwargs: Any) -> Environment:
    """Queue a tool call to run once connections are established.

    A string name given together with keyword arguments is queued as-is.
    Every other input is handed to ``parse_tool_call`` and queued as its
    parsed name and arguments. Returns ``self`` so calls can be chained.
    """
    from hud.environment.utils import parse_tool_call

    if isinstance(call, str) and kwargs:
        queued = (call, kwargs)
    else:
        parsed, _ = parse_tool_call(call)
        queued = (parsed.name, parsed.arguments or {})
    self._setup_calls.append(queued)
    return self
|
|
334
|
+
|
|
335
|
+
def evaluate_tool(self, call: Any, /, **kwargs: Any) -> Environment:
    """Queue a tool call to run before disconnecting.

    A string name given together with keyword arguments is queued as-is.
    Every other input is handed to ``parse_tool_call`` and queued as its
    parsed name and arguments. Returns ``self`` so calls can be chained.
    """
    from hud.environment.utils import parse_tool_call

    if not (isinstance(call, str) and kwargs):
        parsed, _ = parse_tool_call(call)
        call, kwargs = parsed.name, parsed.arguments or {}
    self._evaluate_calls.append((call, kwargs))
    return self
|
|
345
|
+
|
|
346
|
+
# =========================================================================
|
|
347
|
+
# Context Manager
|
|
348
|
+
# =========================================================================
|
|
349
|
+
|
|
350
|
+
async def __aenter__(self) -> Self:
    """Connect all connectors, build routing, run setup tools.

    Connections are opened concurrently, at most
    ``MAX_CONCURRENT_CONNECTIONS`` at a time, and each one lists its tools,
    prompts and resources right after connecting. Routing is then built and
    the queued setup tool calls run in order; their results are kept in
    ``_setup_results``.

    If anything fails (a connection, routing, a setup tool, or cancellation),
    every connected connector is disconnected and ``_in_context`` is reset
    before the error propagates, because ``__aexit__`` does not run when
    ``__aenter__`` raises.

    Returns:
        This environment.

    Raises:
        ConnectionError: A connector failed to connect or to list its
            primitives (only the first failure is reported).
        RuntimeError: A setup tool returned an error result.
    """
    self._in_context = True
    try:
        # Connect to all servers and fetch tools/prompts/resources in parallel
        sem = asyncio.Semaphore(self.MAX_CONCURRENT_CONNECTIONS)
        errors: list[tuple[str, Exception]] = []

        async def connect_one(name: str, conn: Connector) -> None:
            async with sem:
                try:
                    await conn.connect()
                    # Batch fetch all MCP primitives in parallel for performance
                    await asyncio.gather(
                        conn.list_tools(),
                        conn.list_prompts(),
                        conn.list_resources(),
                    )
                except Exception as e:
                    errors.append((name, e))

        if self._connections:
            await asyncio.gather(
                *(connect_one(n, c) for n, c in self._connections.items())
            )
            if errors:
                name, err = errors[0]
                # Strip the prefix FastMCP puts in front of the real cause
                str_err = str(err).replace("Client failed to connect: ", "")
                raise ConnectionError(f"Failed to connect to {name}: {str_err}") from err

        await self._build_routing()

        # Setup tool calls (after connections) - abort if any setup tool fails
        # Store results for append_setup_output feature
        self._setup_results = []
        for name, args in self._setup_calls:
            result = await self._execute_tool(name, args)
            self._setup_results.append(result)
            if result.isError:
                # Use the first text block of the result as the message
                error_msg = next(
                    (
                        block.text
                        for block in result.content or []
                        if isinstance(block, mcp_types.TextContent)
                    ),
                    "Setup tool failed",
                )
                raise RuntimeError(f"Setup tool '{name}' failed: {error_msg}")
    except BaseException:
        # Roll back before re-raising (since __aexit__ won't be called).
        for conn in self._connections.values():
            if not conn.is_connected:
                continue
            try:
                await conn.disconnect()
            except Exception:
                # A failing disconnect must not hide the error that
                # brought us here.
                logger.debug("Disconnect during failed entry raised", exc_info=True)
        self._in_context = False
        raise

    return self
|
|
404
|
+
|
|
405
|
+
async def __aexit__(
    self,
    exc_type: type[BaseException] | None,
    exc_val: BaseException | None,
    exc_tb: types.TracebackType | None,
) -> None:
    """Run evaluate tools, then disconnect and reset routing state.

    Each queued evaluate call is executed and scored with ``find_reward``;
    a call that raises is logged and scored 0.0. The mean of the scores is
    stored in ``_evaluate_reward`` (``None`` when nothing was queued).

    All connections are then disconnected concurrently. The router, the
    routing flags and ``_active_session`` are reset even when a disconnect
    raises; that exception still propagates to the caller.

    The exception arguments are not inspected.
    """
    # Evaluate tool calls and collect rewards
    rewards: list[float] = []
    if self._evaluate_calls:
        # Imported lazily, and only when there is something to score.
        from hud.agents.base import find_reward

        for name, args in self._evaluate_calls:
            try:
                result = await self._execute_tool(name, args)
                rewards.append(find_reward(result))
            except Exception as e:
                logger.warning("Evaluate tool %s failed: %s", name, e)
                # Record 0.0 for failed evaluate tools so they affect the average
                rewards.append(0.0)

    # Store average reward from evaluate tools
    self._evaluate_reward: float | None = (
        sum(rewards) / len(rewards) if rewards else None
    )

    self._in_context = False
    try:
        if self._connections:
            await asyncio.gather(*[c.disconnect() for c in self._connections.values()])
    finally:
        # Always drop routing state, so a failed disconnect cannot leave
        # stale routes behind for the next time the context is entered.
        self._router.clear()
        self._tool_routing_built = False
        self._prompt_routing_built = False
        self._resource_routing_built = False
        self._active_session = None  # Clear stale scenario state
|
|
438
|
+
|
|
439
|
+
async def run_async(
    self,
    transport: Literal["stdio", "http", "sse"] | None = None,
    show_banner: bool = True,
    **transport_kwargs: Any,
) -> None:
    """Serve as an MCP server with all connectors connected for the duration.

    The environment enters its own async context first, so tools from
    external MCP servers (via connect_mcp_config) are discovered and
    available when the server starts. All arguments are forwarded to the
    parent ``run_async``.
    """
    # Entering the context connects every connector via __aenter__.
    async with self:
        await super().run_async(
            transport=transport,
            show_banner=show_banner,
            **transport_kwargs,
        )
|
|
454
|
+
|
|
455
|
+
async def _build_routing(self) -> None:
    """Rebuild tool, prompt and resource routing where it is stale.

    A routing kind is rebuilt only when its ``*_routing_built`` flag is
    false; the stale ones are rebuilt concurrently.
    """
    candidates = (
        (self._tool_routing_built, self._build_tool_routing),
        (self._prompt_routing_built, self._build_prompt_routing),
        (self._resource_routing_built, self._build_resource_routing),
    )
    stale_builders = [builder for is_built, builder in candidates if not is_built]
    if stale_builders:
        await asyncio.gather(*(builder() for builder in stale_builders))
|
|
469
|
+
|
|
470
|
+
async def _build_tool_routing(self) -> None:
    """Rebuild the router's tool table and mark tool routing as built.

    Local tools come from the tool manager and are converted with
    ``to_mcp_tool()``; connections are handed to the router in the
    iteration order of ``self._connections``. After the router is
    rebuilt, mock schemas are repopulated.
    """
    registered = await self._tool_manager.get_tools()
    mcp_tools = [tool.to_mcp_tool() for tool in registered.values()]
    self._router.build(
        local_tools=mcp_tools,
        connections=self._connections,
        connection_order=list(self._connections),
    )
    # Mock values are auto-generated from these schemas.
    self._populate_mock_schemas()
    self._tool_routing_built = True
|
|
482
|
+
|
|
483
|
+
async def _build_prompt_routing(self) -> None:
    """Rebuild the router's prompt table and mark prompt routing as built.

    Local prompts are fetched from the prompt manager, converted with
    ``to_mcp_prompt()``, and passed to the router together with the
    current connections.
    """
    registered = await self._prompt_manager.get_prompts()
    mcp_prompts = []
    for prompt in registered.values():
        mcp_prompts.append(prompt.to_mcp_prompt())
    self._router.build_prompts(mcp_prompts, self._connections)
    self._prompt_routing_built = True
|
|
489
|
+
|
|
490
|
+
async def _build_resource_routing(self) -> None:
    """Rebuild the router's resource table and mark resource routing as built.

    Local resources are fetched from the resource manager, converted with
    ``to_mcp_resource()``, and passed to the router together with the
    current connections.
    """
    registered = await self._resource_manager.get_resources()
    mcp_resources = []
    for resource in registered.values():
        mcp_resources.append(resource.to_mcp_resource())
    self._router.build_resources(mcp_resources, self._connections)
    self._resource_routing_built = True
|
|
496
|
+
|
|
497
|
+
# =========================================================================
|
|
498
|
+
# MCP Protocol Overrides - Include connector tools in MCP responses
|
|
499
|
+
# =========================================================================
|
|
500
|
+
|
|
501
|
+
def _setup_handlers(self) -> None:
    """Install the standard handlers, then swap in the tool handlers.

    The parent sets up all of its handlers first. The ``list_tools`` and
    ``call_tool`` registrations are then repeated with this class's own
    implementations, which replaces the parent's registrations for them.
    """
    super()._setup_handlers()

    server = self._mcp_server
    # Registering again overwrites what the parent just installed.
    register_list = server.list_tools()
    register_list(self._env_list_tools)
    register_call = server.call_tool()
    register_call(self._env_call_tool)
|
|
508
|
+
|
|
509
|
+
async def _env_list_tools(self) -> list[mcp_types.Tool]:
    """Return the router's tool list, building tool routing first if stale."""
    if self._tool_routing_built:
        return self._router.tools
    await self._build_tool_routing()
    return self._router.tools
|
|
514
|
+
|
|
515
|
+
async def _env_call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> list[Any]:
    """Run a tool through ``_execute_tool`` and return only its content.

    Args:
        name: Tool name to execute.
        arguments: Tool arguments; a falsy value is replaced by an empty dict.

    Returns:
        The result's ``content``, or an empty list when it is falsy.
    """
    call_args = arguments if arguments else {}
    outcome = await self._execute_tool(name, call_args)
    # NOTE(review): only ``content`` is forwarded; any error flag or
    # structured content on the result is not returned from here.
    payload = outcome.content
    return payload if payload else []
|
|
519
|
+
|
|
520
|
+
# =========================================================================
|
|
521
|
+
# Tool Operations
|
|
522
|
+
# =========================================================================
|
|
523
|
+
|
|
524
|
+
async def list_tools(self) -> list[mcp_types.Tool]:
    """Refresh tools on every connection, rebuild tool routing, return the tools.

    Each connection's ``list_tools()`` is awaited together via
    ``asyncio.gather`` (skipped when there are no connections). Tool
    routing is then rebuilt unconditionally.
    """
    connectors = self._connections
    if connectors:
        refreshes = [connector.list_tools() for connector in connectors.values()]
        await asyncio.gather(*refreshes)
    await self._build_tool_routing()
    return self._router.tools
|
|
530
|
+
|
|
531
|
+
async def _execute_tool(self, name: str, arguments: dict[str, Any]) -> MCPToolResult:
    """Run the named tool locally or on the connection that owns it.

    In mock mode the tool is not executed; the mock result is returned
    instead. Otherwise tool routing is rebuilt if it has been invalidated
    and the router decides where the call goes.

    Args:
        name: Tool name to execute.
        arguments: Arguments passed through to the tool.

    Returns:
        An ``MCPToolResult`` built from the local or remote result.

    Raises:
        ValueError: If the router knows no local tool or connection for ``name``.
    """
    # Mock mode short-circuits everything else.
    if self._mock_mode:
        logger.debug("Mock mode: returning mock result for tool %s", name)
        return self._get_mock_result(name, arguments)

    # Routing may be stale, e.g. after add_tool.
    if not self._tool_routing_built:
        await self._build_tool_routing()

    if self._router.is_local(name):
        # The tool manager is called directly to avoid the FastMCP context requirement.
        local_result = await self._tool_manager.call_tool(name, arguments)
        return MCPToolResult(
            content=local_result.content,
            structuredContent=local_result.structured_content,
        )

    owner = self._router.get_connection(name)
    if not owner:
        raise ValueError(f"Tool not found: {name}")

    remote_result = await self._connections[owner].call_tool(name, arguments)
    return MCPToolResult(
        content=remote_result.content,
        isError=remote_result.isError,
        structuredContent=remote_result.structuredContent,
    )
|
|
563
|
+
|
|
564
|
+
# =========================================================================
|
|
565
|
+
# Resource Operations
|
|
566
|
+
# =========================================================================
|
|
567
|
+
|
|
568
|
+
async def list_resources(self) -> list[mcp_types.Resource]:
    """Refresh resources on every connection, rebuild resource routing, return them.

    Each connection's ``list_resources()`` is awaited together via
    ``asyncio.gather`` (skipped when there are no connections). Resource
    routing is then rebuilt unconditionally.
    """
    connectors = self._connections
    if connectors:
        refreshes = [connector.list_resources() for connector in connectors.values()]
        await asyncio.gather(*refreshes)
    await self._build_resource_routing()
    return self._router.resources
|
|
574
|
+
|
|
575
|
+
async def read_resource(
    self, uri: str
) -> list[mcp_types.TextResourceContents | mcp_types.BlobResourceContents]:
    """Read a resource by URI, from a connection or from the local manager.

    Resource routing is built first if needed. When the router maps the
    URI to a connection, the read is delegated to that connection.
    Otherwise the local resource manager is asked: a ``str`` result
    becomes text contents, anything else is base64-encoded into blob
    contents.

    Args:
        uri: URI of the resource to read.

    Returns:
        The connection's result for remote resources, or a one-element
        list of text/blob contents for local ones.

    Raises:
        ValueError: If the routed connection is missing, or if the local
            read (or its conversion) fails for any reason.
    """
    import base64

    from pydantic import AnyUrl

    if not self._resource_routing_built:
        await self._build_resource_routing()

    owner = self._router.get_resource_connection(uri)

    if owner is not None:
        # Remote resource: delegate to the owning connection.
        remote = self._connections.get(owner)
        if remote is None:
            raise ValueError(f"Connection '{owner}' not found for resource '{uri}'")
        return await remote.read_resource(uri)

    # Local resource: any failure below is reported as "not found".
    try:
        payload = await self._resource_manager.read_resource(uri)
        parsed_uri = AnyUrl(uri)
        if isinstance(payload, str):
            contents = mcp_types.TextResourceContents(uri=parsed_uri, text=payload)
        else:
            encoded = base64.b64encode(payload).decode()
            contents = mcp_types.BlobResourceContents(uri=parsed_uri, blob=encoded)
        return [contents]
    except Exception as e:
        logger.debug("Local resource read failed for %s: %s", uri, e)
        raise ValueError(f"Resource not found: {uri}") from e
|
|
611
|
+
|
|
612
|
+
# =========================================================================
|
|
613
|
+
# Prompt Operations
|
|
614
|
+
# =========================================================================
|
|
615
|
+
|
|
616
|
+
async def list_prompts(self) -> list[mcp_types.Prompt]:
    """Refresh prompts on every connection, rebuild prompt routing, return them.

    Each connection's ``list_prompts()`` is awaited together via
    ``asyncio.gather`` (skipped when there are no connections). Prompt
    routing is then rebuilt unconditionally.
    """
    connectors = self._connections
    if connectors:
        refreshes = [connector.list_prompts() for connector in connectors.values()]
        await asyncio.gather(*refreshes)
    await self._build_prompt_routing()
    return self._router.prompts
|
|
622
|
+
|
|
623
|
+
async def get_prompt(
    self, name: str, arguments: dict[str, Any] | None = None
) -> mcp_types.GetPromptResult:
    """Get a prompt by name, from a connection or from the local manager.

    Prompt routing is built first if needed. When the router maps the
    name to a connection, the request is delegated to that connection
    with ``arguments`` passed through unchanged. Otherwise the local
    prompt manager renders it, with a falsy ``arguments`` replaced by an
    empty dict.

    Args:
        name: Name of the prompt to fetch.
        arguments: Optional prompt arguments.

    Returns:
        The rendered prompt result.

    Raises:
        ValueError: If the routed connection is missing, or if local
            rendering fails for any reason (the original exception is
            chained as the cause).
    """
    # Ensure prompt routing is built
    if not self._prompt_routing_built:
        await self._build_prompt_routing()

    # Use router to find which connection has this prompt
    conn_name = self._router.get_prompt_connection(name)

    if conn_name is None:
        # Local prompt
        try:
            return await self._prompt_manager.render_prompt(name, arguments or {})
        except Exception as e:
            # Log the real failure before it is reported as "not found",
            # matching how local resource read failures are logged.
            logger.debug("Local prompt render failed for %s: %s", name, e)
            raise ValueError(f"Prompt not found: {name}") from e

    # Remote prompt
    conn = self._connections.get(conn_name)
    if conn is None:
        raise ValueError(f"Connection '{conn_name}' not found for prompt '{name}'")
    return await conn.get_prompt(name, arguments)
|
|
646
|
+
|
|
647
|
+
# =========================================================================
|
|
648
|
+
# Server Methods
|
|
649
|
+
# =========================================================================
|
|
650
|
+
|
|
651
|
+
def serve(
    self,
    transport: Literal["stdio", "sse", "streamable-http"] = "streamable-http",
    host: str = "0.0.0.0",  # noqa: S104
    port: int = 8000,
    **kwargs: Any,
) -> None:
    """Start serving as an MCP server by delegating to ``self.run``.

    Args:
        transport: Transport name passed to ``run``.
        host: Host passed to ``run``.
        port: Port passed to ``run``.
        **kwargs: Extra keyword arguments forwarded unchanged.
    """
    run_options: dict[str, Any] = {
        "transport": transport,
        "host": host,
        "port": port,
        **kwargs,
    }
    self.run(**run_options)
|
|
660
|
+
|
|
661
|
+
# =========================================================================
|
|
662
|
+
# Properties
|
|
663
|
+
# =========================================================================
|
|
664
|
+
|
|
665
|
+
@property
def connections(self) -> dict[str, Connector]:
    """The internal name-to-connector mapping (the same object, not a copy)."""
    return self._connections
|
|
668
|
+
|
|
669
|
+
@property
def is_connected(self) -> bool:
    """Current value of the ``_in_context`` flag."""
    return self._in_context
|
|
672
|
+
|
|
673
|
+
@property
def is_parallelizable(self) -> bool:
    """True when no connection reports ``is_remote`` as falsy.

    An environment without connections (local tools only) counts as
    parallelizable.
    """
    for conn in self._connections.values():
        if not conn.is_remote:
            return False
    return True
|
|
679
|
+
|
|
680
|
+
@property
def local_connections(self) -> list[str]:
    """Names of the connections whose ``is_local`` is truthy, in mapping order."""
    local_names: list[str] = []
    for conn_name, conn in self._connections.items():
        if conn.is_local:
            local_names.append(conn_name)
    return local_names
|
|
684
|
+
|
|
685
|
+
# =========================================================================
|
|
686
|
+
# Serialization
|
|
687
|
+
# =========================================================================
|
|
688
|
+
|
|
689
|
+
@property
def is_serializable(self) -> bool:
    """Whether this environment can be serialized.

    Local tools or local scenarios always make it non-serializable.
    Otherwise:

    - v5 format: serializable when a hub config is set.
    - v4 format: serializable when an mcp_config is set AND both the
      prompt and the evaluate calls are non-empty.
    """
    # Local tools are checked first, then local scenarios.
    has_local_state = bool(self._router._local_tool_names) or bool(
        getattr(self, "_scenarios", {})
    )
    if has_local_state:
        return False
    # v5 hub format
    if self._hub_config is not None:
        return True
    # v4 format needs mcp_config + prompt + evaluate calls
    if self._mcp_config is None:
        return False
    return bool(self.prompt and self._evaluate_calls)
|
|
709
|
+
|
|
710
|
+
def to_config(self) -> dict[str, Any]:
    """Serialize the environment config for remote submission.

    Two formats are produced:

    - v5 (hub-based): a shallow copy of the stored hub config.
    - v4 (legacy): a dict with ``prompt``, ``mcp_config`` and
      ``evaluate_tool``, plus ``setup_tool`` when setup calls exist.

    Returns:
        dict: Serializable config.

    Raises:
        ValueError: If the environment has local tools or local scenarios,
            if a v4 environment lacks a prompt or evaluate calls, or if
            neither a hub config nor an mcp_config is present.

    Example:
        ```python
        # v5 hub-based
        env = Environment("my").connect_hub("browser", include=["navigate"])
        env.to_config()  # {"name": "browser", "include": ["navigate"]}

        # v4 legacy (from Task.from_v4())
        task = Task.from_v4(legacy_task)
        task.env.to_config()  # {"prompt": "...", "mcp_config": {...}, ...}
        ```
    """

    def _as_entries(calls: Any) -> list[dict[str, Any]]:
        # Turn (name, arguments) pairs into the serialized call form.
        return [{"name": tool_name, "arguments": tool_args} for tool_name, tool_args in calls]

    if self._router._local_tool_names:
        tool_names = list(self._router._local_tool_names)
        raise ValueError(
            "Cannot serialize Environment with local tools: "
            f"{tool_names}. "
            "Local tools require local execution. For remote submission, "
            "use dict config or connect to a remote hub."
        )

    if getattr(self, "_scenarios", {}):
        scenario_names = list(self._scenarios.keys())
        raise ValueError(
            "Cannot serialize Environment with local scenarios: "
            f"{scenario_names}. "
            "Local scenarios require local execution. For remote submission, "
            "define scenarios on the remote environment."
        )

    # v5 hub-based format takes precedence over the legacy format.
    if self._hub_config is not None:
        return self._hub_config.copy()

    if self._mcp_config is None:
        raise ValueError(
            "Cannot serialize Environment without config. "
            "Use connect_hub() for v5 tasks or connect_mcp_config() for legacy tasks."
        )

    # v4 legacy format: prompt and evaluate calls are both mandatory.
    if not self.prompt:
        raise ValueError(
            "Cannot serialize v4 Environment without prompt. "
            "Set env.prompt before serializing."
        )
    if not self._evaluate_calls:
        raise ValueError(
            "Cannot serialize v4 Environment without evaluate_tool. "
            "Use env.evaluate_tool() to define evaluation criteria."
        )

    legacy_config: dict[str, Any] = {
        "prompt": self.prompt,
        "mcp_config": self._mcp_config,
        "evaluate_tool": _as_entries(self._evaluate_calls),
    }
    if self._setup_calls:
        legacy_config["setup_tool"] = _as_entries(self._setup_calls)
    return legacy_config
|
|
784
|
+
|
|
785
|
+
def __repr__(self) -> str:
    """Return ``Environment(<name>, connections=[...])`` listing connection names."""
    connection_names = list(self._connections)
    return f"Environment({self.name!r}, connections={connection_names})"
|
|
787
|
+
|
|
788
|
+
# =========================================================================
|
|
789
|
+
# Task Creation
|
|
790
|
+
# =========================================================================
|
|
791
|
+
|
|
792
|
+
def __call__(
    self,
    scenario: str | None = None,
    **args: Any,
) -> Task:
    """Build a Task bound to this environment.

    The returned Task can be passed to ``hud.eval()`` for orchestration.

    Args:
        scenario: Scenario name to run (from ``@env.scenario``). Optional
            for v4 legacy environments.
        **args: Arguments for the scenario, stored on the Task as a dict.

    Returns:
        Task: A Task created with this environment, the scenario name and
        the collected arguments.

    Example:
        ```python
        env = Environment("my-env").connect_hub("browser")


        @env.scenario()
        async def checkout(user_id: str):
            yield "Complete checkout"
            yield 1.0


        # Single task via hud.eval
        async with hud.eval(env("checkout", user_id="alice")) as ctx:
            await agent.run(ctx.prompt)

        # Multiple tasks with variants
        tasks = [env("checkout", user_id="alice"), env("checkout", user_id="bob")]
        async with hud.eval(tasks, variants={"model": ["gpt-4o"]}, group=4) as ctx:
            ...
        ```
    """
    # Imported inside the function rather than at module level.
    from hud.eval.task import Task

    task_fields: dict[str, Any] = {
        "env": self,
        "scenario": scenario,
        "args": args,
    }
    return Task(**task_fields)
|