hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +11 -5
- hud/agents/base.py +220 -500
- hud/agents/claude.py +200 -240
- hud/agents/gemini.py +275 -0
- hud/agents/gemini_cua.py +335 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +41 -36
- hud/agents/openai.py +291 -292
- hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
- hud/agents/operator.py +211 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +379 -210
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +376 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/cli/__init__.py +461 -545
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +664 -110
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +882 -734
- hud/cli/eval.py +782 -668
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/push.py +29 -11
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +108 -6
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +69 -0
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +40 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +327 -0
- hud/datasets/runner.py +192 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +50 -0
- hud/environment/connection.py +206 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +109 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +694 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +112 -0
- hud/environment/scenarios.py +493 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +218 -0
- hud/environment/tests/test_environment.py +161 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +201 -0
- hud/environment/tests/test_scenarios.py +280 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +674 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +185 -0
- hud/eval/manager.py +466 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +340 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +145 -0
- hud/eval/types.py +63 -0
- hud/eval/utils.py +183 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +151 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +158 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +16 -2
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +4 -0
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +167 -57
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +61 -3
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.1.dist-info/METADATA +264 -0
- hud_python-0.5.1.dist-info/RECORD +299 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,14 +3,15 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
|
-
from typing import Any
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
8
|
from mcp import ErrorData, McpError
|
|
9
9
|
from mcp.types import INVALID_PARAMS, ContentBlock
|
|
10
10
|
|
|
11
|
-
from hud.clients.base import AgentMCPClient # noqa: TC001
|
|
12
11
|
from hud.tools.grounding.grounder import Grounder # noqa: TC001
|
|
13
|
-
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from hud.environment import Environment
|
|
14
15
|
|
|
15
16
|
logger = logging.getLogger(__name__)
|
|
16
17
|
|
|
@@ -33,18 +34,18 @@ class GroundedComputerTool:
|
|
|
33
34
|
self,
|
|
34
35
|
*,
|
|
35
36
|
grounder: Grounder,
|
|
36
|
-
|
|
37
|
+
ctx: Environment,
|
|
37
38
|
computer_tool_name: str = "computer",
|
|
38
39
|
) -> None:
|
|
39
40
|
"""Initialize the grounded computer tool.
|
|
40
41
|
|
|
41
42
|
Args:
|
|
42
43
|
grounder: Grounder instance for visual grounding
|
|
43
|
-
|
|
44
|
+
ctx: Environment or EvalContext to call tools through
|
|
44
45
|
computer_tool_name: Name of the computer tool in the environment
|
|
45
46
|
"""
|
|
46
47
|
self._grounder = grounder
|
|
47
|
-
self.
|
|
48
|
+
self._ctx = ctx
|
|
48
49
|
self._computer_tool_name = computer_tool_name
|
|
49
50
|
|
|
50
51
|
def get_openai_tool_schema(self) -> dict:
|
|
@@ -172,10 +173,8 @@ class GroundedComputerTool:
|
|
|
172
173
|
if keys is not None:
|
|
173
174
|
computer_args["keys"] = keys
|
|
174
175
|
|
|
175
|
-
result = await self.
|
|
176
|
-
|
|
177
|
-
name=self._computer_tool_name, arguments={**computer_args, **kwargs}
|
|
178
|
-
)
|
|
176
|
+
result = await self._ctx.call_tool(
|
|
177
|
+
(self._computer_tool_name, {**computer_args, **kwargs})
|
|
179
178
|
)
|
|
180
179
|
return result.content
|
|
181
180
|
|
|
@@ -224,10 +223,8 @@ class GroundedComputerTool:
|
|
|
224
223
|
if scroll_y is not None:
|
|
225
224
|
computer_args["scroll_y"] = scroll_y
|
|
226
225
|
|
|
227
|
-
result = await self.
|
|
228
|
-
|
|
229
|
-
name=self._computer_tool_name, arguments={**computer_args, **kwargs}
|
|
230
|
-
)
|
|
226
|
+
result = await self._ctx.call_tool(
|
|
227
|
+
(self._computer_tool_name, {**computer_args, **kwargs})
|
|
231
228
|
)
|
|
232
229
|
return result.content
|
|
233
230
|
|
|
@@ -292,10 +289,8 @@ class GroundedComputerTool:
|
|
|
292
289
|
if button:
|
|
293
290
|
computer_args["button"] = button
|
|
294
291
|
|
|
295
|
-
result = await self.
|
|
296
|
-
|
|
297
|
-
name=self._computer_tool_name, arguments={**computer_args, **kwargs}
|
|
298
|
-
)
|
|
292
|
+
result = await self._ctx.call_tool(
|
|
293
|
+
(self._computer_tool_name, {**computer_args, **kwargs})
|
|
299
294
|
)
|
|
300
295
|
return result.content
|
|
301
296
|
|
hud/tools/grounding/grounder.py
CHANGED
|
@@ -4,15 +4,15 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import base64
|
|
6
6
|
import io
|
|
7
|
-
import
|
|
7
|
+
import logging
|
|
8
8
|
import re
|
|
9
9
|
|
|
10
10
|
from openai import AsyncOpenAI
|
|
11
|
-
from opentelemetry import trace
|
|
12
11
|
|
|
13
|
-
from hud import instrument
|
|
14
12
|
from hud.tools.grounding.config import GrounderConfig # noqa: TC001
|
|
15
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
16
|
|
|
17
17
|
class Grounder:
|
|
18
18
|
"""Grounder that uses AsyncOpenAI to call vLLM or other model endpoints for visual grounding.
|
|
@@ -181,12 +181,6 @@ class Grounder:
|
|
|
181
181
|
|
|
182
182
|
return (final_x, final_y)
|
|
183
183
|
|
|
184
|
-
@instrument(
|
|
185
|
-
name="Grounding.predict_click",
|
|
186
|
-
span_type="agent",
|
|
187
|
-
record_args=True,
|
|
188
|
-
record_result=True,
|
|
189
|
-
)
|
|
190
184
|
async def predict_click(
|
|
191
185
|
self, *, image_b64: str, instruction: str, max_retries: int = 3
|
|
192
186
|
) -> tuple[int, int] | None:
|
|
@@ -247,12 +241,7 @@ class Grounder:
|
|
|
247
241
|
|
|
248
242
|
# Extract response text
|
|
249
243
|
response_text = response.choices[0].message.content
|
|
250
|
-
|
|
251
|
-
# Manually record the raw response in the span
|
|
252
|
-
span = trace.get_current_span()
|
|
253
|
-
if span and span.is_recording():
|
|
254
|
-
span.set_attribute("grounder.raw_response", json.dumps(response.model_dump()))
|
|
255
|
-
span.set_attribute("grounder.attempt", attempt + 1)
|
|
244
|
+
logger.debug("Grounder attempt %d response: %s", attempt + 1, response_text)
|
|
256
245
|
|
|
257
246
|
# Parse coordinates from response
|
|
258
247
|
if response_text is None:
|
|
@@ -277,26 +266,16 @@ class Grounder:
|
|
|
277
266
|
y = max(0, min(y, original_size[1] - 1))
|
|
278
267
|
pixel_coords = (x, y)
|
|
279
268
|
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
"grounder.final_coords", f"{pixel_coords[0]},{pixel_coords[1]}"
|
|
286
|
-
)
|
|
287
|
-
span.set_attribute("grounder.total_attempts", attempt + 1)
|
|
288
|
-
|
|
269
|
+
logger.debug(
|
|
270
|
+
"Grounder success: coords=%s after %d attempts",
|
|
271
|
+
pixel_coords,
|
|
272
|
+
attempt + 1,
|
|
273
|
+
)
|
|
289
274
|
return pixel_coords
|
|
290
275
|
|
|
291
276
|
except Exception:
|
|
292
277
|
if attempt < max_retries - 1:
|
|
293
278
|
continue
|
|
294
279
|
|
|
295
|
-
|
|
296
|
-
span = trace.get_current_span()
|
|
297
|
-
if span and span.is_recording():
|
|
298
|
-
span.set_attribute("grounder.success", False)
|
|
299
|
-
span.set_attribute("grounder.total_attempts", max_retries)
|
|
300
|
-
span.set_attribute("grounder.failure_reason", "All attempts exhausted")
|
|
301
|
-
|
|
280
|
+
logger.debug("Grounder failed after %d attempts", max_retries)
|
|
302
281
|
return None
|
|
@@ -7,7 +7,7 @@ import mcp.types as types
|
|
|
7
7
|
import pytest
|
|
8
8
|
|
|
9
9
|
from hud.tools.grounding.grounded_tool import GroundedComputerTool
|
|
10
|
-
from hud.types import
|
|
10
|
+
from hud.types import MCPToolResult
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
@dataclass
|
|
@@ -17,36 +17,18 @@ class FakeResult:
|
|
|
17
17
|
structuredContent: dict | None = None
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class
|
|
21
|
-
"""Fake
|
|
22
|
-
|
|
23
|
-
_initialized: bool
|
|
20
|
+
class FakeEnvironment:
|
|
21
|
+
"""Fake Environment that implements the call_tool interface."""
|
|
24
22
|
|
|
25
23
|
def __init__(self) -> None:
|
|
26
24
|
self.calls: list[tuple[str, dict[str, Any]]] = []
|
|
27
|
-
self._initialized = False
|
|
28
|
-
|
|
29
|
-
@property
|
|
30
|
-
def mcp_config(self) -> dict[str, dict[str, Any]]:
|
|
31
|
-
return {"test": {"command": "echo", "args": ["test"]}}
|
|
32
|
-
|
|
33
|
-
@property
|
|
34
|
-
def is_connected(self) -> bool:
|
|
35
|
-
return self._initialized
|
|
36
25
|
|
|
37
|
-
async def
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
return [types.Tool(name="computer", description="Test tool", inputSchema={})]
|
|
42
|
-
|
|
43
|
-
async def call_tool(self, tool_call: MCPToolCall) -> MCPToolResult:
|
|
44
|
-
self.calls.append((tool_call.name, tool_call.arguments or {}))
|
|
26
|
+
async def call_tool(self, call: tuple[str, dict[str, Any]], /, **kwargs: Any) -> MCPToolResult:
|
|
27
|
+
"""Record the tool call and return a fake result."""
|
|
28
|
+
tool_name, tool_args = call
|
|
29
|
+
self.calls.append((tool_name, tool_args))
|
|
45
30
|
return MCPToolResult(content=[types.TextContent(text="ok", type="text")], isError=False)
|
|
46
31
|
|
|
47
|
-
async def shutdown(self) -> None:
|
|
48
|
-
self._initialized = False
|
|
49
|
-
|
|
50
32
|
|
|
51
33
|
class FakeGrounder:
|
|
52
34
|
"""Fake grounder that implements Grounder interface."""
|
|
@@ -72,9 +54,9 @@ def _png_b64() -> str:
|
|
|
72
54
|
|
|
73
55
|
@pytest.mark.asyncio
|
|
74
56
|
async def test_click_action_grounds_and_calls_mcp() -> None:
|
|
75
|
-
|
|
57
|
+
ctx = FakeEnvironment()
|
|
76
58
|
grounder = FakeGrounder(coords=(123, 456))
|
|
77
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
59
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
78
60
|
|
|
79
61
|
blocks = await tool(
|
|
80
62
|
action="click",
|
|
@@ -87,14 +69,14 @@ async def test_click_action_grounds_and_calls_mcp() -> None:
|
|
|
87
69
|
# Grounder called once
|
|
88
70
|
assert len(grounder.calls) == 1
|
|
89
71
|
# MCP called with resolved coordinates
|
|
90
|
-
assert
|
|
72
|
+
assert ctx.calls == [("computer", {"action": "click", "x": 123, "y": 456, "button": "left"})]
|
|
91
73
|
|
|
92
74
|
|
|
93
75
|
@pytest.mark.asyncio
|
|
94
76
|
async def test_move_and_scroll_require_element_description_and_screenshot() -> None:
|
|
95
|
-
|
|
77
|
+
ctx = FakeEnvironment()
|
|
96
78
|
grounder = FakeGrounder(coords=(5, 6))
|
|
97
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
79
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
98
80
|
|
|
99
81
|
# Missing element_description
|
|
100
82
|
with pytest.raises(Exception) as ei:
|
|
@@ -109,9 +91,9 @@ async def test_move_and_scroll_require_element_description_and_screenshot() -> N
|
|
|
109
91
|
|
|
110
92
|
@pytest.mark.asyncio
|
|
111
93
|
async def test_drag_grounds_both_points_and_calls_mcp() -> None:
|
|
112
|
-
|
|
94
|
+
ctx = FakeEnvironment()
|
|
113
95
|
grounder = FakeGrounder(coords=(10, 20))
|
|
114
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
96
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
115
97
|
|
|
116
98
|
await tool(
|
|
117
99
|
action="drag",
|
|
@@ -124,7 +106,7 @@ async def test_drag_grounds_both_points_and_calls_mcp() -> None:
|
|
|
124
106
|
# Two grounding calls (start and end)
|
|
125
107
|
assert len(grounder.calls) == 2
|
|
126
108
|
# Drag path contains two points, same coords from fake grounder
|
|
127
|
-
name, args =
|
|
109
|
+
name, args = ctx.calls[0]
|
|
128
110
|
assert name == "computer"
|
|
129
111
|
assert args["action"] == "drag"
|
|
130
112
|
assert args["button"] == "left"
|
|
@@ -133,9 +115,9 @@ async def test_drag_grounds_both_points_and_calls_mcp() -> None:
|
|
|
133
115
|
|
|
134
116
|
@pytest.mark.asyncio
|
|
135
117
|
async def test_drag_requires_both_descriptions_and_screenshot() -> None:
|
|
136
|
-
|
|
118
|
+
ctx = FakeEnvironment()
|
|
137
119
|
grounder = FakeGrounder()
|
|
138
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
120
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
139
121
|
|
|
140
122
|
with pytest.raises(Exception) as ei:
|
|
141
123
|
await tool(action="drag", start_element_description="a", screenshot_b64=_png_b64())
|
|
@@ -152,9 +134,9 @@ async def test_drag_requires_both_descriptions_and_screenshot() -> None:
|
|
|
152
134
|
|
|
153
135
|
@pytest.mark.asyncio
|
|
154
136
|
async def test_direct_actions_bypass_grounding_and_call_mcp() -> None:
|
|
155
|
-
|
|
137
|
+
ctx = FakeEnvironment()
|
|
156
138
|
grounder = FakeGrounder()
|
|
157
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
139
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
158
140
|
|
|
159
141
|
# Actions that bypass grounding
|
|
160
142
|
for action, extra in [
|
|
@@ -166,19 +148,19 @@ async def test_direct_actions_bypass_grounding_and_call_mcp() -> None:
|
|
|
166
148
|
("get_dimensions", {}),
|
|
167
149
|
("get_environment", {}),
|
|
168
150
|
]:
|
|
169
|
-
|
|
151
|
+
ctx.calls.clear()
|
|
170
152
|
_ = await tool(action=action, **extra)
|
|
171
|
-
assert
|
|
172
|
-
assert
|
|
153
|
+
assert ctx.calls and ctx.calls[0][0] == "computer"
|
|
154
|
+
assert ctx.calls[0][1]["action"] == action
|
|
173
155
|
# Grounder not invoked for these
|
|
174
156
|
assert grounder.calls == []
|
|
175
157
|
|
|
176
158
|
|
|
177
159
|
@pytest.mark.asyncio
|
|
178
160
|
async def test_unsupported_action_raises() -> None:
|
|
179
|
-
|
|
161
|
+
ctx = FakeEnvironment()
|
|
180
162
|
grounder = FakeGrounder()
|
|
181
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
163
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
182
164
|
|
|
183
165
|
with pytest.raises(Exception) as ei:
|
|
184
166
|
await tool(action="zoom")
|
|
@@ -187,9 +169,9 @@ async def test_unsupported_action_raises() -> None:
|
|
|
187
169
|
|
|
188
170
|
@pytest.mark.asyncio
|
|
189
171
|
async def test_grounding_failure_propagates_as_error() -> None:
|
|
190
|
-
|
|
172
|
+
ctx = FakeEnvironment()
|
|
191
173
|
grounder = FakeGrounder(coords=None)
|
|
192
|
-
tool = GroundedComputerTool(grounder=grounder,
|
|
174
|
+
tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
|
|
193
175
|
|
|
194
176
|
with pytest.raises(Exception) as ei:
|
|
195
177
|
await tool(action="click", element_description="x", screenshot_b64=_png_b64())
|
hud/tools/jupyter.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
"""Jupyter execution tool.
|
|
2
|
+
|
|
3
|
+
Requires the [agents] extra: pip install hud-python[agents]
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import logging
|
|
10
|
+
import re
|
|
11
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
12
|
+
from uuid import uuid4
|
|
13
|
+
|
|
14
|
+
from hud.tools.base import BaseTool
|
|
15
|
+
from hud.tools.types import ContentResult, ToolError
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from mcp.types import ContentBlock
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Matches ANSI SGR (colour/style) sequences: ESC [ <params> m, where <params>
# is any number of ';'-separated decimal fields, possibly empty (e.g. "ESC[m").
# Compiled once at import time instead of on every call.
_ANSI_SGR_RE = re.compile(r"\x1B\[[0-9;]*m")


def strip_ansi(output: str) -> str:
    """Remove ANSI colour/style (SGR) escape sequences from ``output``.

    Unlike a pattern limited to one-to-three numeric parameters, this also
    strips the bare reset ``ESC[m`` and longer sequences such as the
    true-colour form ``ESC[38;2;R;G;Bm``.

    Args:
        output: Text that may contain ANSI SGR escape sequences.

    Returns:
        ``output`` with every SGR sequence removed; all other text is unchanged.
    """
    return _ANSI_SGR_RE.sub("", output)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class JupyterTool(BaseTool):
    """Execute Python code in a Jupyter kernel.

    The tool talks to a kernel gateway over HTTP (to start, interrupt and
    delete kernels) and over a WebSocket (to send ``execute_request``
    messages and read the kernel's replies). The connection is created
    lazily on the first call. tornado is imported inside the methods that
    need it, so importing this module does not require tornado.
    """

    # Class-level kernel registry for sharing kernels: registry name -> kernel id
    _kernel_registry: ClassVar[dict[str, str]] = {}

    @classmethod
    def register_shared_kernel(cls, registry_name: str, kernel_id: str) -> None:
        """Register a kernel_id with a name for reuse.

        Args:
            registry_name: Name to register the kernel under
            kernel_id: The kernel ID to register
        """
        cls._kernel_registry[registry_name] = kernel_id
        logger.info("Registered kernel '%s': %s", registry_name, kernel_id)

    @classmethod
    def from_shared_kernel(cls, registry_name: str, **kwargs: Any) -> JupyterTool:
        """Connect to a kernel using its registry name.

        Args:
            registry_name: Name of the registered kernel
            **kwargs: Additional parameters for JupyterTool (url_suffix, kernel_name)

        Returns:
            JupyterTool instance connected to the registered kernel

        Raises:
            ValueError: If no kernel is registered under ``registry_name``.
        """
        kernel_id = cls._kernel_registry.get(registry_name)
        if not kernel_id:
            raise ValueError(f"No kernel registered with name '{registry_name}'")

        logger.info("Connecting to registered kernel '%s': %s", registry_name, kernel_id)
        return cls(kernel_id=kernel_id, **kwargs)

    def __init__(
        self,
        url_suffix: str = "localhost:8888",
        kernel_name: str = "python3",
        kernel_id: str = "",
    ) -> None:
        """Initialize JupyterTool with connection parameters.

        No network connection is made here; the kernel is contacted on the
        first call.

        Args:
            url_suffix: (Optional) Kernel gateway host:port (default: localhost:8888)
            kernel_name: (Optional) Kernel name to use (default: python3)
            kernel_id: (Optional) If set, connect to the existing kernel with this id.
                If empty, create a new kernel

        Raises:
            ImportError: If tornado is not installed.
        """
        # Check tornado is available
        try:
            import tornado  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "JupyterTool requires the [agents] extra. "
                "Install with: pip install hud-python[agents]"
            ) from e

        super().__init__(
            env=None,
            name="jupyter",
            title="Jupyter Code Execution",
            description="Execute Python code in a Jupyter kernel",
        )

        # Connection parameters
        self._base_url = f"http://{url_suffix}"
        self._base_ws_url = f"ws://{url_suffix}"
        self._kernel_name = kernel_name

        # Kernel state (reuse existing or create new)
        self._kernel_id = kernel_id
        self._ws: Any = None
        self._initialized = False

        # WebSocket heartbeat
        self._heartbeat_interval = 10000  # 10 seconds
        self._heartbeat_callback: Any = None

    async def __call__(self, code: str, execution_timeout: int = 15) -> list[ContentBlock]:
        """Execute Python code in the Jupyter kernel.

        Args:
            code: Python code to execute
            execution_timeout: Execution timeout in seconds (default: 15)

        Returns:
            List of ContentBlock with execution results. A timed-out execution
            is returned as an error result rather than raised.

        Raises:
            ToolError: If connecting to the kernel or executing the code fails.
        """
        try:
            # Ensure kernel is ready (lazy initialization)
            await self._ensure_kernel()

            # Execute code
            result = await self._execute(code, execution_timeout)

            # Check for timeout (_execute reports it with this marker prefix)
            if result.startswith("[Execution timed out"):
                return ContentResult(error=result).to_content_blocks()

            # Return result
            output = result if result.strip() else "Code executed successfully (no output)"
            return ContentResult(output=output).to_content_blocks()

        except Exception as e:
            logger.error("Jupyter execution error: %s", e)
            raise ToolError(f"Execution failed: {e!s}") from e

    async def _ensure_kernel(self) -> None:
        """Ensure kernel is initialized and connected."""
        if not self._initialized:
            logger.info("Initializing Jupyter kernel connection")
            await self._connect()
            self._initialized = True
            logger.info("Jupyter kernel connected successfully")

    async def _connect(self) -> None:
        """Connect to Jupyter kernel via WebSocket.

        Starts a new kernel through the gateway when no kernel id is set
        (up to five attempts, one second apart), opens the WebSocket channel
        to the kernel and (re)installs the periodic heartbeat.

        Raises:
            ConnectionRefusedError: If a new kernel could not be started.
        """
        import tornado.iostream
        from tornado.escape import json_decode, json_encode, url_escape
        from tornado.httpclient import AsyncHTTPClient, HTTPRequest
        from tornado.ioloop import PeriodicCallback
        from tornado.websocket import websocket_connect

        # Drop any previous socket before reconnecting
        if self._ws:
            self._ws.close()
            self._ws = None

        client = AsyncHTTPClient()
        if not self._kernel_id:
            # Start a new kernel
            n_tries = 5
            while n_tries > 0:
                try:
                    response = await client.fetch(
                        f"{self._base_url}/api/kernels",
                        method="POST",
                        body=json_encode({"name": self._kernel_name}),
                    )
                    kernel = json_decode(response.body)
                    self._kernel_id = kernel["id"]
                    logger.info("Kernel started with ID: %s", self._kernel_id)
                    break
                except Exception as e:
                    logger.warning("Kernel connection attempt failed: %s", e)
                    n_tries -= 1
                    await asyncio.sleep(1)

            # The counter only reaches zero when every attempt failed
            if n_tries == 0:
                raise ConnectionRefusedError("Failed to connect to kernel gateway")

        # Connect WebSocket to kernel
        ws_req = HTTPRequest(
            url=f"{self._base_ws_url}/api/kernels/{url_escape(self._kernel_id)}/channels"
        )
        self._ws = await websocket_connect(ws_req)
        logger.info("WebSocket connected to kernel")

        # Setup heartbeat to keep connection alive
        if self._heartbeat_callback:
            self._heartbeat_callback.stop()

        async def heartbeat() -> None:
            if not self._ws:
                return
            try:
                self._ws.ping()
            except tornado.iostream.StreamClosedError:
                # Socket is gone: try to re-establish it, but never raise
                # out of the periodic callback.
                try:
                    await self._connect()
                except ConnectionRefusedError:
                    logger.warning(
                        "Failed to reconnect to kernel websocket - Is the kernel still running?"
                    )

        self._heartbeat_callback = PeriodicCallback(heartbeat, self._heartbeat_interval)
        self._heartbeat_callback.start()

    async def _execute(self, code: str, execution_timeout: int = 15) -> str:
        """Execute code in Jupyter kernel and return output.

        Args:
            code: Python code to execute
            execution_timeout: Execution timeout in seconds

        Returns:
            String output from the kernel with ANSI colour codes removed, or
            a "[Execution timed out ...]" marker if the timeout elapsed (the
            kernel is interrupted in that case).

        Raises:
            ConnectionError: If the WebSocket is closed while waiting for the
                kernel's reply.
        """
        from tornado.escape import json_decode, json_encode
        from tornado.httpclient import AsyncHTTPClient

        if not self._ws:
            await self._connect()

        msg_id = uuid4().hex
        self._ws.write_message(
            json_encode(
                {
                    "header": {
                        "username": "",
                        "version": "5.0",
                        "session": "",
                        "msg_id": msg_id,
                        "msg_type": "execute_request",
                    },
                    "parent_header": {},
                    "channel": "shell",
                    "content": {
                        "code": code,
                        "silent": False,
                        "store_history": False,
                        "user_expressions": {},
                        "allow_stdin": False,
                    },
                    "metadata": {},
                    "buffers": {},
                }
            )
        )

        outputs: list[str] = []

        async def wait_for_messages() -> bool:
            execution_done = False
            while not execution_done:
                raw = await self._ws.read_message()
                if raw is None:
                    # read_message() yields None once the connection is closed.
                    # Forget the dead socket so the next call reconnects, and
                    # fail with a clear error instead of a JSON decoding one.
                    self._ws = None
                    raise ConnectionError(
                        "WebSocket connection to Jupyter kernel closed during execution"
                    )
                msg = json_decode(raw)
                msg_type = msg["msg_type"]
                parent_msg_id = msg["parent_header"].get("msg_id", None)

                # Ignore traffic that does not belong to this request
                if parent_msg_id != msg_id:
                    continue

                if msg_type == "error":
                    traceback = "\n\n\n\n".join(msg["content"]["traceback"])
                    outputs.append(traceback)
                    execution_done = True
                elif msg_type == "stream":
                    outputs.append(msg["content"]["text"])
                elif msg_type in ["execute_result", "display_data"]:
                    data = msg["content"]["data"]
                    # A result may carry no plain-text representation
                    # (e.g. image-only output); do not fail on it.
                    outputs.append(data.get("text/plain", ""))
                    # Handle image outputs
                    if "image/png" in data:
                        # NOTE(review): this f-string is empty in the reviewed
                        # text, so nothing is appended for images; the embedded
                        # image payload was presumably lost when the source was
                        # rendered - confirm against upstream.
                        outputs.append(
                            f""
                        )
                elif msg_type == "execute_reply":
                    execution_done = True
            return execution_done

        async def interrupt_kernel() -> None:
            client = AsyncHTTPClient()
            interrupt_response = await client.fetch(
                f"{self._base_url}/api/kernels/{self._kernel_id}/interrupt",
                method="POST",
                body=json_encode({"kernel_id": self._kernel_id}),
            )
            logger.info("Kernel interrupted: %s", interrupt_response)

        try:
            await asyncio.wait_for(wait_for_messages(), execution_timeout)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is the builtin TimeoutError on Python 3.11+
            # and a separate class before that; naming it explicitly makes the
            # interrupt happen on both.
            await interrupt_kernel()
            return f"[Execution timed out ({execution_timeout} seconds).]"

        ret = "".join(outputs)

        # Remove ANSI escape sequences
        return strip_ansi(ret)

    async def shutdown(self) -> None:
        """Shutdown the kernel connection.

        Asks the gateway to delete the kernel (errors are logged, not
        raised), stops the heartbeat, closes the WebSocket and resets the
        tool so a later call starts from scratch.
        """
        from tornado.httpclient import AsyncHTTPClient

        if self._kernel_id:
            client = AsyncHTTPClient()
            try:
                await client.fetch(
                    f"{self._base_url}/api/kernels/{self._kernel_id}",
                    method="DELETE",
                )
                logger.info("Kernel %s shut down", self._kernel_id)
            except Exception as e:
                # Best effort: local cleanup below must still run
                logger.warning("Error shutting down kernel: %s", e)

        self._kernel_id = ""

        if self._heartbeat_callback:
            self._heartbeat_callback.stop()
            self._heartbeat_callback = None

        if self._ws:
            self._ws.close()
            self._ws = None

        self._initialized = False

    def get_kernel_id(self) -> str:
        """Get the jupyter kernel id (empty string if no kernel is attached)."""
        return self._kernel_id
|