hud-python 0.4.45__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282) hide show
  1. hud/__init__.py +27 -7
  2. hud/agents/__init__.py +70 -5
  3. hud/agents/base.py +238 -500
  4. hud/agents/claude.py +236 -247
  5. hud/agents/gateway.py +42 -0
  6. hud/agents/gemini.py +264 -0
  7. hud/agents/gemini_cua.py +324 -0
  8. hud/agents/grounded_openai.py +98 -100
  9. hud/agents/misc/integration_test_agent.py +51 -20
  10. hud/agents/misc/response_agent.py +48 -36
  11. hud/agents/openai.py +282 -296
  12. hud/agents/{openai_chat_generic.py → openai_chat.py} +63 -33
  13. hud/agents/operator.py +199 -0
  14. hud/agents/resolver.py +70 -0
  15. hud/agents/tests/conftest.py +133 -0
  16. hud/agents/tests/test_base.py +300 -622
  17. hud/agents/tests/test_base_runtime.py +233 -0
  18. hud/agents/tests/test_claude.py +381 -214
  19. hud/agents/tests/test_client.py +9 -10
  20. hud/agents/tests/test_gemini.py +369 -0
  21. hud/agents/tests/test_grounded_openai_agent.py +65 -50
  22. hud/agents/tests/test_openai.py +377 -140
  23. hud/agents/tests/test_operator.py +362 -0
  24. hud/agents/tests/test_resolver.py +192 -0
  25. hud/agents/tests/test_run_eval.py +179 -0
  26. hud/agents/types.py +148 -0
  27. hud/cli/__init__.py +493 -546
  28. hud/cli/analyze.py +43 -5
  29. hud/cli/build.py +699 -113
  30. hud/cli/debug.py +8 -5
  31. hud/cli/dev.py +889 -732
  32. hud/cli/eval.py +793 -667
  33. hud/cli/flows/dev.py +167 -0
  34. hud/cli/flows/init.py +191 -0
  35. hud/cli/flows/tasks.py +153 -56
  36. hud/cli/flows/templates.py +151 -0
  37. hud/cli/flows/tests/__init__.py +1 -0
  38. hud/cli/flows/tests/test_dev.py +126 -0
  39. hud/cli/init.py +60 -58
  40. hud/cli/pull.py +1 -1
  41. hud/cli/push.py +38 -13
  42. hud/cli/rft.py +311 -0
  43. hud/cli/rft_status.py +145 -0
  44. hud/cli/tests/test_analyze.py +5 -5
  45. hud/cli/tests/test_analyze_metadata.py +3 -2
  46. hud/cli/tests/test_analyze_module.py +120 -0
  47. hud/cli/tests/test_build.py +110 -8
  48. hud/cli/tests/test_build_failure.py +41 -0
  49. hud/cli/tests/test_build_module.py +50 -0
  50. hud/cli/tests/test_cli_init.py +6 -1
  51. hud/cli/tests/test_cli_more_wrappers.py +30 -0
  52. hud/cli/tests/test_cli_root.py +140 -0
  53. hud/cli/tests/test_convert.py +361 -0
  54. hud/cli/tests/test_debug.py +12 -10
  55. hud/cli/tests/test_dev.py +197 -0
  56. hud/cli/tests/test_eval.py +251 -0
  57. hud/cli/tests/test_eval_bedrock.py +51 -0
  58. hud/cli/tests/test_init.py +124 -0
  59. hud/cli/tests/test_main_module.py +11 -5
  60. hud/cli/tests/test_mcp_server.py +12 -100
  61. hud/cli/tests/test_push.py +1 -1
  62. hud/cli/tests/test_push_happy.py +74 -0
  63. hud/cli/tests/test_push_wrapper.py +23 -0
  64. hud/cli/tests/test_registry.py +1 -1
  65. hud/cli/tests/test_utils.py +1 -1
  66. hud/cli/{rl → utils}/celebrate.py +14 -12
  67. hud/cli/utils/config.py +18 -1
  68. hud/cli/utils/docker.py +130 -4
  69. hud/cli/utils/env_check.py +9 -9
  70. hud/cli/utils/git.py +136 -0
  71. hud/cli/utils/interactive.py +39 -5
  72. hud/cli/utils/metadata.py +70 -1
  73. hud/cli/utils/runner.py +1 -1
  74. hud/cli/utils/server.py +2 -2
  75. hud/cli/utils/source_hash.py +3 -3
  76. hud/cli/utils/tasks.py +4 -1
  77. hud/cli/utils/tests/__init__.py +0 -0
  78. hud/cli/utils/tests/test_config.py +58 -0
  79. hud/cli/utils/tests/test_docker.py +93 -0
  80. hud/cli/utils/tests/test_docker_hints.py +71 -0
  81. hud/cli/utils/tests/test_env_check.py +74 -0
  82. hud/cli/utils/tests/test_environment.py +42 -0
  83. hud/cli/utils/tests/test_git.py +142 -0
  84. hud/cli/utils/tests/test_interactive_module.py +60 -0
  85. hud/cli/utils/tests/test_local_runner.py +50 -0
  86. hud/cli/utils/tests/test_logging_utils.py +23 -0
  87. hud/cli/utils/tests/test_metadata.py +49 -0
  88. hud/cli/utils/tests/test_package_runner.py +35 -0
  89. hud/cli/utils/tests/test_registry_utils.py +49 -0
  90. hud/cli/utils/tests/test_remote_runner.py +25 -0
  91. hud/cli/utils/tests/test_runner_modules.py +52 -0
  92. hud/cli/utils/tests/test_source_hash.py +36 -0
  93. hud/cli/utils/tests/test_tasks.py +80 -0
  94. hud/cli/utils/version_check.py +258 -0
  95. hud/cli/{rl → utils}/viewer.py +2 -2
  96. hud/clients/README.md +12 -11
  97. hud/clients/__init__.py +4 -3
  98. hud/clients/base.py +166 -26
  99. hud/clients/environment.py +51 -0
  100. hud/clients/fastmcp.py +13 -6
  101. hud/clients/mcp_use.py +45 -15
  102. hud/clients/tests/test_analyze_scenarios.py +206 -0
  103. hud/clients/tests/test_protocol.py +9 -3
  104. hud/datasets/__init__.py +23 -20
  105. hud/datasets/loader.py +326 -0
  106. hud/datasets/runner.py +198 -105
  107. hud/datasets/tests/__init__.py +0 -0
  108. hud/datasets/tests/test_loader.py +221 -0
  109. hud/datasets/tests/test_utils.py +315 -0
  110. hud/datasets/utils.py +270 -90
  111. hud/environment/__init__.py +52 -0
  112. hud/environment/connection.py +258 -0
  113. hud/environment/connectors/__init__.py +33 -0
  114. hud/environment/connectors/base.py +68 -0
  115. hud/environment/connectors/local.py +177 -0
  116. hud/environment/connectors/mcp_config.py +137 -0
  117. hud/environment/connectors/openai.py +101 -0
  118. hud/environment/connectors/remote.py +172 -0
  119. hud/environment/environment.py +835 -0
  120. hud/environment/integrations/__init__.py +45 -0
  121. hud/environment/integrations/adk.py +67 -0
  122. hud/environment/integrations/anthropic.py +196 -0
  123. hud/environment/integrations/gemini.py +92 -0
  124. hud/environment/integrations/langchain.py +82 -0
  125. hud/environment/integrations/llamaindex.py +68 -0
  126. hud/environment/integrations/openai.py +238 -0
  127. hud/environment/mock.py +306 -0
  128. hud/environment/router.py +263 -0
  129. hud/environment/scenarios.py +620 -0
  130. hud/environment/tests/__init__.py +1 -0
  131. hud/environment/tests/test_connection.py +317 -0
  132. hud/environment/tests/test_connectors.py +205 -0
  133. hud/environment/tests/test_environment.py +593 -0
  134. hud/environment/tests/test_integrations.py +257 -0
  135. hud/environment/tests/test_local_connectors.py +242 -0
  136. hud/environment/tests/test_scenarios.py +1086 -0
  137. hud/environment/tests/test_tools.py +208 -0
  138. hud/environment/types.py +23 -0
  139. hud/environment/utils/__init__.py +35 -0
  140. hud/environment/utils/formats.py +215 -0
  141. hud/environment/utils/schema.py +171 -0
  142. hud/environment/utils/tool_wrappers.py +113 -0
  143. hud/eval/__init__.py +67 -0
  144. hud/eval/context.py +727 -0
  145. hud/eval/display.py +299 -0
  146. hud/eval/instrument.py +187 -0
  147. hud/eval/manager.py +533 -0
  148. hud/eval/parallel.py +268 -0
  149. hud/eval/task.py +372 -0
  150. hud/eval/tests/__init__.py +1 -0
  151. hud/eval/tests/test_context.py +178 -0
  152. hud/eval/tests/test_eval.py +210 -0
  153. hud/eval/tests/test_manager.py +152 -0
  154. hud/eval/tests/test_parallel.py +168 -0
  155. hud/eval/tests/test_task.py +291 -0
  156. hud/eval/types.py +65 -0
  157. hud/eval/utils.py +194 -0
  158. hud/patches/__init__.py +19 -0
  159. hud/patches/mcp_patches.py +308 -0
  160. hud/patches/warnings.py +54 -0
  161. hud/samples/browser.py +4 -4
  162. hud/server/__init__.py +2 -1
  163. hud/server/low_level.py +2 -1
  164. hud/server/router.py +164 -0
  165. hud/server/server.py +567 -80
  166. hud/server/tests/test_mcp_server_integration.py +11 -11
  167. hud/server/tests/test_mcp_server_more.py +1 -1
  168. hud/server/tests/test_server_extra.py +2 -0
  169. hud/settings.py +45 -3
  170. hud/shared/exceptions.py +36 -10
  171. hud/shared/hints.py +26 -1
  172. hud/shared/requests.py +15 -3
  173. hud/shared/tests/test_exceptions.py +40 -31
  174. hud/shared/tests/test_hints.py +167 -0
  175. hud/telemetry/__init__.py +20 -19
  176. hud/telemetry/exporter.py +201 -0
  177. hud/telemetry/instrument.py +165 -253
  178. hud/telemetry/tests/test_eval_telemetry.py +356 -0
  179. hud/telemetry/tests/test_exporter.py +258 -0
  180. hud/telemetry/tests/test_instrument.py +401 -0
  181. hud/tools/__init__.py +18 -2
  182. hud/tools/agent.py +223 -0
  183. hud/tools/apply_patch.py +639 -0
  184. hud/tools/base.py +54 -4
  185. hud/tools/bash.py +2 -2
  186. hud/tools/computer/__init__.py +36 -3
  187. hud/tools/computer/anthropic.py +2 -2
  188. hud/tools/computer/gemini.py +385 -0
  189. hud/tools/computer/hud.py +23 -6
  190. hud/tools/computer/openai.py +20 -21
  191. hud/tools/computer/qwen.py +434 -0
  192. hud/tools/computer/settings.py +37 -0
  193. hud/tools/edit.py +3 -7
  194. hud/tools/executors/base.py +4 -2
  195. hud/tools/executors/pyautogui.py +1 -1
  196. hud/tools/grounding/grounded_tool.py +13 -18
  197. hud/tools/grounding/grounder.py +10 -31
  198. hud/tools/grounding/tests/test_grounded_tool.py +26 -44
  199. hud/tools/jupyter.py +330 -0
  200. hud/tools/playwright.py +18 -3
  201. hud/tools/shell.py +308 -0
  202. hud/tools/tests/test_agent_tool.py +355 -0
  203. hud/tools/tests/test_apply_patch.py +718 -0
  204. hud/tools/tests/test_computer.py +4 -9
  205. hud/tools/tests/test_computer_actions.py +24 -2
  206. hud/tools/tests/test_jupyter_tool.py +181 -0
  207. hud/tools/tests/test_shell.py +596 -0
  208. hud/tools/tests/test_submit.py +85 -0
  209. hud/tools/tests/test_types.py +193 -0
  210. hud/tools/types.py +21 -1
  211. hud/types.py +194 -56
  212. hud/utils/__init__.py +2 -0
  213. hud/utils/env.py +67 -0
  214. hud/utils/hud_console.py +89 -18
  215. hud/utils/mcp.py +15 -58
  216. hud/utils/strict_schema.py +162 -0
  217. hud/utils/tests/test_init.py +1 -2
  218. hud/utils/tests/test_mcp.py +1 -28
  219. hud/utils/tests/test_pretty_errors.py +186 -0
  220. hud/utils/tests/test_tool_shorthand.py +154 -0
  221. hud/utils/tests/test_version.py +1 -1
  222. hud/utils/types.py +20 -0
  223. hud/version.py +1 -1
  224. hud_python-0.5.13.dist-info/METADATA +264 -0
  225. hud_python-0.5.13.dist-info/RECORD +305 -0
  226. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/WHEEL +1 -1
  227. hud/agents/langchain.py +0 -261
  228. hud/agents/lite_llm.py +0 -72
  229. hud/cli/rl/__init__.py +0 -180
  230. hud/cli/rl/config.py +0 -101
  231. hud/cli/rl/display.py +0 -133
  232. hud/cli/rl/gpu.py +0 -63
  233. hud/cli/rl/gpu_utils.py +0 -321
  234. hud/cli/rl/local_runner.py +0 -595
  235. hud/cli/rl/presets.py +0 -96
  236. hud/cli/rl/remote_runner.py +0 -463
  237. hud/cli/rl/rl_api.py +0 -150
  238. hud/cli/rl/vllm.py +0 -177
  239. hud/cli/rl/wait_utils.py +0 -89
  240. hud/datasets/parallel.py +0 -687
  241. hud/misc/__init__.py +0 -1
  242. hud/misc/claude_plays_pokemon.py +0 -292
  243. hud/otel/__init__.py +0 -35
  244. hud/otel/collector.py +0 -142
  245. hud/otel/config.py +0 -181
  246. hud/otel/context.py +0 -570
  247. hud/otel/exporters.py +0 -369
  248. hud/otel/instrumentation.py +0 -135
  249. hud/otel/processors.py +0 -121
  250. hud/otel/tests/__init__.py +0 -1
  251. hud/otel/tests/test_processors.py +0 -197
  252. hud/rl/README.md +0 -30
  253. hud/rl/__init__.py +0 -1
  254. hud/rl/actor.py +0 -176
  255. hud/rl/buffer.py +0 -405
  256. hud/rl/chat_template.jinja +0 -101
  257. hud/rl/config.py +0 -192
  258. hud/rl/distributed.py +0 -132
  259. hud/rl/learner.py +0 -637
  260. hud/rl/tests/__init__.py +0 -1
  261. hud/rl/tests/test_learner.py +0 -186
  262. hud/rl/train.py +0 -382
  263. hud/rl/types.py +0 -101
  264. hud/rl/utils/start_vllm_server.sh +0 -30
  265. hud/rl/utils.py +0 -524
  266. hud/rl/vllm_adapter.py +0 -143
  267. hud/telemetry/job.py +0 -352
  268. hud/telemetry/replay.py +0 -74
  269. hud/telemetry/tests/test_replay.py +0 -40
  270. hud/telemetry/tests/test_trace.py +0 -63
  271. hud/telemetry/trace.py +0 -158
  272. hud/utils/agent_factories.py +0 -86
  273. hud/utils/async_utils.py +0 -65
  274. hud/utils/group_eval.py +0 -223
  275. hud/utils/progress.py +0 -149
  276. hud/utils/tasks.py +0 -127
  277. hud/utils/tests/test_async_utils.py +0 -173
  278. hud/utils/tests/test_progress.py +0 -261
  279. hud_python-0.4.45.dist-info/METADATA +0 -552
  280. hud_python-0.4.45.dist-info/RECORD +0 -228
  281. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/entry_points.txt +0 -0
  282. {hud_python-0.4.45.dist-info → hud_python-0.5.13.dist-info}/licenses/LICENSE +0 -0
@@ -3,14 +3,15 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import logging
6
- from typing import Any
6
+ from typing import TYPE_CHECKING, Any
7
7
 
8
8
  from mcp import ErrorData, McpError
9
9
  from mcp.types import INVALID_PARAMS, ContentBlock
10
10
 
11
- from hud.clients.base import AgentMCPClient # noqa: TC001
12
11
  from hud.tools.grounding.grounder import Grounder # noqa: TC001
13
- from hud.types import MCPToolCall
12
+
13
+ if TYPE_CHECKING:
14
+ from hud.environment import Environment
14
15
 
15
16
  logger = logging.getLogger(__name__)
16
17
 
@@ -33,18 +34,18 @@ class GroundedComputerTool:
33
34
  self,
34
35
  *,
35
36
  grounder: Grounder,
36
- mcp_client: AgentMCPClient,
37
+ ctx: Environment,
37
38
  computer_tool_name: str = "computer",
38
39
  ) -> None:
39
40
  """Initialize the grounded computer tool.
40
41
 
41
42
  Args:
42
43
  grounder: Grounder instance for visual grounding
43
- mcp_client: MCP client to call the environment's computer tool
44
+ ctx: Environment or EvalContext to call tools through
44
45
  computer_tool_name: Name of the computer tool in the environment
45
46
  """
46
47
  self._grounder = grounder
47
- self._mcp_client = mcp_client
48
+ self._ctx = ctx
48
49
  self._computer_tool_name = computer_tool_name
49
50
 
50
51
  def get_openai_tool_schema(self) -> dict:
@@ -172,10 +173,8 @@ class GroundedComputerTool:
172
173
  if keys is not None:
173
174
  computer_args["keys"] = keys
174
175
 
175
- result = await self._mcp_client.call_tool(
176
- MCPToolCall(
177
- name=self._computer_tool_name, arguments={**computer_args, **kwargs}
178
- )
176
+ result = await self._ctx.call_tool(
177
+ (self._computer_tool_name, {**computer_args, **kwargs})
179
178
  )
180
179
  return result.content
181
180
 
@@ -224,10 +223,8 @@ class GroundedComputerTool:
224
223
  if scroll_y is not None:
225
224
  computer_args["scroll_y"] = scroll_y
226
225
 
227
- result = await self._mcp_client.call_tool(
228
- MCPToolCall(
229
- name=self._computer_tool_name, arguments={**computer_args, **kwargs}
230
- )
226
+ result = await self._ctx.call_tool(
227
+ (self._computer_tool_name, {**computer_args, **kwargs})
231
228
  )
232
229
  return result.content
233
230
 
@@ -292,10 +289,8 @@ class GroundedComputerTool:
292
289
  if button:
293
290
  computer_args["button"] = button
294
291
 
295
- result = await self._mcp_client.call_tool(
296
- MCPToolCall(
297
- name=self._computer_tool_name, arguments={**computer_args, **kwargs}
298
- )
292
+ result = await self._ctx.call_tool(
293
+ (self._computer_tool_name, {**computer_args, **kwargs})
299
294
  )
300
295
  return result.content
301
296
 
@@ -4,15 +4,15 @@ from __future__ import annotations
4
4
 
5
5
  import base64
6
6
  import io
7
- import json
7
+ import logging
8
8
  import re
9
9
 
10
10
  from openai import AsyncOpenAI
11
- from opentelemetry import trace
12
11
 
13
- from hud import instrument
14
12
  from hud.tools.grounding.config import GrounderConfig # noqa: TC001
15
13
 
14
+ logger = logging.getLogger(__name__)
15
+
16
16
 
17
17
  class Grounder:
18
18
  """Grounder that uses AsyncOpenAI to call vLLM or other model endpoints for visual grounding.
@@ -181,12 +181,6 @@ class Grounder:
181
181
 
182
182
  return (final_x, final_y)
183
183
 
184
- @instrument(
185
- name="Grounding.predict_click",
186
- span_type="agent",
187
- record_args=True,
188
- record_result=True,
189
- )
190
184
  async def predict_click(
191
185
  self, *, image_b64: str, instruction: str, max_retries: int = 3
192
186
  ) -> tuple[int, int] | None:
@@ -247,12 +241,7 @@ class Grounder:
247
241
 
248
242
  # Extract response text
249
243
  response_text = response.choices[0].message.content
250
-
251
- # Manually record the raw response in the span
252
- span = trace.get_current_span()
253
- if span and span.is_recording():
254
- span.set_attribute("grounder.raw_response", json.dumps(response.model_dump()))
255
- span.set_attribute("grounder.attempt", attempt + 1)
244
+ logger.debug("Grounder attempt %d response: %s", attempt + 1, response_text)
256
245
 
257
246
  # Parse coordinates from response
258
247
  if response_text is None:
@@ -277,26 +266,16 @@ class Grounder:
277
266
  y = max(0, min(y, original_size[1] - 1))
278
267
  pixel_coords = (x, y)
279
268
 
280
- # Record successful grounding in span
281
- span = trace.get_current_span()
282
- if span and span.is_recording():
283
- span.set_attribute("grounder.success", True)
284
- span.set_attribute(
285
- "grounder.final_coords", f"{pixel_coords[0]},{pixel_coords[1]}"
286
- )
287
- span.set_attribute("grounder.total_attempts", attempt + 1)
288
-
269
+ logger.debug(
270
+ "Grounder success: coords=%s after %d attempts",
271
+ pixel_coords,
272
+ attempt + 1,
273
+ )
289
274
  return pixel_coords
290
275
 
291
276
  except Exception:
292
277
  if attempt < max_retries - 1:
293
278
  continue
294
279
 
295
- # Record failure in span
296
- span = trace.get_current_span()
297
- if span and span.is_recording():
298
- span.set_attribute("grounder.success", False)
299
- span.set_attribute("grounder.total_attempts", max_retries)
300
- span.set_attribute("grounder.failure_reason", "All attempts exhausted")
301
-
280
+ logger.debug("Grounder failed after %d attempts", max_retries)
302
281
  return None
@@ -7,7 +7,7 @@ import mcp.types as types
7
7
  import pytest
8
8
 
9
9
  from hud.tools.grounding.grounded_tool import GroundedComputerTool
10
- from hud.types import MCPToolCall, MCPToolResult
10
+ from hud.types import MCPToolResult
11
11
 
12
12
 
13
13
  @dataclass
@@ -17,36 +17,18 @@ class FakeResult:
17
17
  structuredContent: dict | None = None
18
18
 
19
19
 
20
- class FakeMCPClient:
21
- """Fake MCP client that implements AgentMCPClient protocol."""
22
-
23
- _initialized: bool
20
+ class FakeEnvironment:
21
+ """Fake Environment that implements the call_tool interface."""
24
22
 
25
23
  def __init__(self) -> None:
26
24
  self.calls: list[tuple[str, dict[str, Any]]] = []
27
- self._initialized = False
28
-
29
- @property
30
- def mcp_config(self) -> dict[str, dict[str, Any]]:
31
- return {"test": {"command": "echo", "args": ["test"]}}
32
-
33
- @property
34
- def is_connected(self) -> bool:
35
- return self._initialized
36
25
 
37
- async def initialize(self, mcp_config: dict[str, dict[str, Any]] | None = None) -> None:
38
- self._initialized = True
39
-
40
- async def list_tools(self) -> list[types.Tool]:
41
- return [types.Tool(name="computer", description="Test tool", inputSchema={})]
42
-
43
- async def call_tool(self, tool_call: MCPToolCall) -> MCPToolResult:
44
- self.calls.append((tool_call.name, tool_call.arguments or {}))
26
+ async def call_tool(self, call: tuple[str, dict[str, Any]], /, **kwargs: Any) -> MCPToolResult:
27
+ """Record the tool call and return a fake result."""
28
+ tool_name, tool_args = call
29
+ self.calls.append((tool_name, tool_args))
45
30
  return MCPToolResult(content=[types.TextContent(text="ok", type="text")], isError=False)
46
31
 
47
- async def shutdown(self) -> None:
48
- self._initialized = False
49
-
50
32
 
51
33
  class FakeGrounder:
52
34
  """Fake grounder that implements Grounder interface."""
@@ -72,9 +54,9 @@ def _png_b64() -> str:
72
54
 
73
55
  @pytest.mark.asyncio
74
56
  async def test_click_action_grounds_and_calls_mcp() -> None:
75
- client = FakeMCPClient()
57
+ ctx = FakeEnvironment()
76
58
  grounder = FakeGrounder(coords=(123, 456))
77
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
59
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
78
60
 
79
61
  blocks = await tool(
80
62
  action="click",
@@ -87,14 +69,14 @@ async def test_click_action_grounds_and_calls_mcp() -> None:
87
69
  # Grounder called once
88
70
  assert len(grounder.calls) == 1
89
71
  # MCP called with resolved coordinates
90
- assert client.calls == [("computer", {"action": "click", "x": 123, "y": 456, "button": "left"})]
72
+ assert ctx.calls == [("computer", {"action": "click", "x": 123, "y": 456, "button": "left"})]
91
73
 
92
74
 
93
75
  @pytest.mark.asyncio
94
76
  async def test_move_and_scroll_require_element_description_and_screenshot() -> None:
95
- client = FakeMCPClient()
77
+ ctx = FakeEnvironment()
96
78
  grounder = FakeGrounder(coords=(5, 6))
97
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
79
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
98
80
 
99
81
  # Missing element_description
100
82
  with pytest.raises(Exception) as ei:
@@ -109,9 +91,9 @@ async def test_move_and_scroll_require_element_description_and_screenshot() -> N
109
91
 
110
92
  @pytest.mark.asyncio
111
93
  async def test_drag_grounds_both_points_and_calls_mcp() -> None:
112
- client = FakeMCPClient()
94
+ ctx = FakeEnvironment()
113
95
  grounder = FakeGrounder(coords=(10, 20))
114
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
96
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
115
97
 
116
98
  await tool(
117
99
  action="drag",
@@ -124,7 +106,7 @@ async def test_drag_grounds_both_points_and_calls_mcp() -> None:
124
106
  # Two grounding calls (start and end)
125
107
  assert len(grounder.calls) == 2
126
108
  # Drag path contains two points, same coords from fake grounder
127
- name, args = client.calls[0]
109
+ name, args = ctx.calls[0]
128
110
  assert name == "computer"
129
111
  assert args["action"] == "drag"
130
112
  assert args["button"] == "left"
@@ -133,9 +115,9 @@ async def test_drag_grounds_both_points_and_calls_mcp() -> None:
133
115
 
134
116
  @pytest.mark.asyncio
135
117
  async def test_drag_requires_both_descriptions_and_screenshot() -> None:
136
- client = FakeMCPClient()
118
+ ctx = FakeEnvironment()
137
119
  grounder = FakeGrounder()
138
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
120
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
139
121
 
140
122
  with pytest.raises(Exception) as ei:
141
123
  await tool(action="drag", start_element_description="a", screenshot_b64=_png_b64())
@@ -152,9 +134,9 @@ async def test_drag_requires_both_descriptions_and_screenshot() -> None:
152
134
 
153
135
  @pytest.mark.asyncio
154
136
  async def test_direct_actions_bypass_grounding_and_call_mcp() -> None:
155
- client = FakeMCPClient()
137
+ ctx = FakeEnvironment()
156
138
  grounder = FakeGrounder()
157
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
139
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
158
140
 
159
141
  # Actions that bypass grounding
160
142
  for action, extra in [
@@ -166,19 +148,19 @@ async def test_direct_actions_bypass_grounding_and_call_mcp() -> None:
166
148
  ("get_dimensions", {}),
167
149
  ("get_environment", {}),
168
150
  ]:
169
- client.calls.clear()
151
+ ctx.calls.clear()
170
152
  _ = await tool(action=action, **extra)
171
- assert client.calls and client.calls[0][0] == "computer"
172
- assert client.calls[0][1]["action"] == action
153
+ assert ctx.calls and ctx.calls[0][0] == "computer"
154
+ assert ctx.calls[0][1]["action"] == action
173
155
  # Grounder not invoked for these
174
156
  assert grounder.calls == []
175
157
 
176
158
 
177
159
  @pytest.mark.asyncio
178
160
  async def test_unsupported_action_raises() -> None:
179
- client = FakeMCPClient()
161
+ ctx = FakeEnvironment()
180
162
  grounder = FakeGrounder()
181
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
163
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
182
164
 
183
165
  with pytest.raises(Exception) as ei:
184
166
  await tool(action="zoom")
@@ -187,9 +169,9 @@ async def test_unsupported_action_raises() -> None:
187
169
 
188
170
  @pytest.mark.asyncio
189
171
  async def test_grounding_failure_propagates_as_error() -> None:
190
- client = FakeMCPClient()
172
+ ctx = FakeEnvironment()
191
173
  grounder = FakeGrounder(coords=None)
192
- tool = GroundedComputerTool(grounder=grounder, mcp_client=client) # type: ignore
174
+ tool = GroundedComputerTool(grounder=grounder, ctx=ctx) # type: ignore
193
175
 
194
176
  with pytest.raises(Exception) as ei:
195
177
  await tool(action="click", element_description="x", screenshot_b64=_png_b64())
hud/tools/jupyter.py ADDED
@@ -0,0 +1,330 @@
1
+ """Jupyter execution tool.
2
+
3
+ Requires the [agents] extra: pip install hud-python[agents]
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import asyncio
9
+ import logging
10
+ import re
11
+ from typing import TYPE_CHECKING, Any, ClassVar
12
+ from uuid import uuid4
13
+
14
+ from hud.tools.base import BaseTool
15
+ from hud.tools.types import ContentResult, ToolError
16
+
17
+ if TYPE_CHECKING:
18
+ from mcp.types import ContentBlock
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
# Matches any ANSI CSI sequence: ESC '[', optional parameter bytes (0x30-0x3F),
# optional intermediate bytes (0x20-0x2F), and one final byte (0x40-0x7E).
# This covers colour/style codes with any number of parameters (including the
# bare reset "ESC[m" and truecolor "ESC[38;2;R;G;Bm") as well as cursor/erase
# sequences such as "ESC[2K" that progress bars emit.
_ANSI_CSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")


def strip_ansi(output: str) -> str:
    """Remove ANSI escape (CSI) sequences from string output.

    Args:
        output: Text that may contain terminal escape sequences.

    Returns:
        The text with every CSI escape sequence removed; text without
        escape sequences is returned unchanged.
    """
    return _ANSI_CSI_RE.sub("", output)
27
+
28
+
29
class JupyterTool(BaseTool):
    """Execute Python code in a Jupyter kernel.

    Kernels are created, interrupted and deleted through the kernel gateway's
    HTTP API (``/api/kernels``); code is executed over the kernel's WebSocket
    ``channels`` endpoint. The connection is opened lazily on the first call.
    """

    # Class-level kernel registry for sharing kernels (registry name -> kernel id)
    _kernel_registry: ClassVar[dict[str, str]] = {}

    @classmethod
    def register_shared_kernel(cls, registry_name: str, kernel_id: str) -> None:
        """Register a kernel_id with a name for reuse.

        Args:
            registry_name: Name to register the kernel under
            kernel_id: The kernel ID to register
        """
        cls._kernel_registry[registry_name] = kernel_id
        logger.info("Registered kernel '%s': %s", registry_name, kernel_id)

    @classmethod
    def from_shared_kernel(cls, registry_name: str, **kwargs: Any) -> JupyterTool:
        """Connect to a kernel using its registry name.

        Args:
            registry_name: Name of the registered kernel
            **kwargs: Additional parameters for JupyterTool (url_suffix, kernel_name)

        Returns:
            JupyterTool instance connected to the registered kernel

        Raises:
            ValueError: If no kernel is registered under ``registry_name``.
        """
        kernel_id = cls._kernel_registry.get(registry_name)
        if not kernel_id:
            raise ValueError(f"No kernel registered with name '{registry_name}'")

        logger.info("Connecting to registered kernel '%s': %s", registry_name, kernel_id)
        return cls(kernel_id=kernel_id, **kwargs)

    def __init__(
        self,
        url_suffix: str = "localhost:8888",
        kernel_name: str = "python3",
        kernel_id: str = "",
    ) -> None:
        """Initialize JupyterTool with connection parameters.

        No network traffic happens here; the kernel connection is established
        on first use.

        Args:
            url_suffix: (Optional) Kernel gateway host:port (default: localhost:8888)
            kernel_name: (Optional) Kernel name to use (default: python3)
            kernel_id: (Optional) If set, connect to the existing kernel with this id.
                If empty, a new kernel is created on first use.

        Raises:
            ImportError: If tornado (provided by the [agents] extra) is not installed.
        """
        # Check tornado is available
        try:
            import tornado  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "JupyterTool requires the [agents] extra. "
                "Install with: pip install hud-python[agents]"
            ) from e

        super().__init__(
            env=None,
            name="jupyter",
            title="Jupyter Code Execution",
            description="Execute Python code in a Jupyter kernel",
        )

        # Connection parameters
        self._base_url = f"http://{url_suffix}"
        self._base_ws_url = f"ws://{url_suffix}"
        self._kernel_name = kernel_name

        # Kernel state (reuse existing or create new)
        self._kernel_id = kernel_id
        self._ws: Any = None
        self._initialized = False

        # WebSocket heartbeat; the interval is passed to tornado's
        # PeriodicCallback, which expects milliseconds (10000 ms = 10 seconds).
        self._heartbeat_interval = 10000
        self._heartbeat_callback: Any = None

    async def __call__(self, code: str, execution_timeout: int = 15) -> list[ContentBlock]:
        """Execute Python code in the Jupyter kernel.

        Args:
            code: Python code to execute
            execution_timeout: Execution timeout in seconds (default: 15)

        Returns:
            List of ContentBlock with execution results. A timed-out execution
            is reported as an error result rather than raised.

        Raises:
            ToolError: If connecting to the kernel or executing the code fails.
        """
        try:
            # Ensure kernel is ready (lazy initialization)
            await self._ensure_kernel()

            # Execute code
            result = await self._execute(code, execution_timeout)

            # _execute reports a timeout through this message prefix
            if result.startswith("[Execution timed out"):
                return ContentResult(error=result).to_content_blocks()

            # Return result
            output = result if result.strip() else "Code executed successfully (no output)"
            return ContentResult(output=output).to_content_blocks()

        except Exception as e:
            logger.error("Jupyter execution error: %s", e)
            raise ToolError(f"Execution failed: {e!s}") from e

    async def _ensure_kernel(self) -> None:
        """Ensure kernel is initialized and connected."""
        if not self._initialized:
            logger.info("Initializing Jupyter kernel connection")
            await self._connect()
            self._initialized = True
            logger.info("Jupyter kernel connected successfully")

    async def _connect(self) -> None:
        """Connect to Jupyter kernel via WebSocket.

        Closes any existing socket, starts a new kernel when no kernel id is
        set (retrying up to five times), opens the channels WebSocket and
        (re)starts the heartbeat callback.

        Raises:
            ConnectionRefusedError: If a new kernel could not be started after
                all attempts.
        """
        import tornado.iostream
        from tornado.escape import json_decode, json_encode, url_escape
        from tornado.httpclient import AsyncHTTPClient, HTTPRequest
        from tornado.ioloop import PeriodicCallback
        from tornado.websocket import websocket_connect

        if self._ws:
            self._ws.close()
            self._ws = None

        client = AsyncHTTPClient()
        if not self._kernel_id:
            # Start a new kernel
            max_attempts = 5
            last_error: Exception | None = None
            for attempt in range(1, max_attempts + 1):
                try:
                    response = await client.fetch(
                        f"{self._base_url}/api/kernels",
                        method="POST",
                        body=json_encode({"name": self._kernel_name}),
                    )
                    kernel = json_decode(response.body)
                    self._kernel_id = kernel["id"]
                    logger.info("Kernel started with ID: %s", self._kernel_id)
                    break
                except Exception as e:
                    last_error = e
                    logger.warning("Kernel connection attempt failed: %s", e)
                    # No point in waiting once the final attempt has failed
                    if attempt < max_attempts:
                        await asyncio.sleep(1)
            else:
                # Loop finished without `break`: every attempt failed
                raise ConnectionRefusedError(
                    "Failed to connect to kernel gateway"
                ) from last_error

        # Connect WebSocket to kernel
        ws_req = HTTPRequest(
            url=f"{self._base_ws_url}/api/kernels/{url_escape(self._kernel_id)}/channels"
        )
        self._ws = await websocket_connect(ws_req)
        logger.info("WebSocket connected to kernel")

        # Setup heartbeat to keep connection alive
        if self._heartbeat_callback:
            self._heartbeat_callback.stop()

        async def heartbeat() -> None:
            if not self._ws:
                return
            try:
                self._ws.ping()
            except tornado.iostream.StreamClosedError:
                # Socket is gone: try to re-establish it
                try:
                    await self._connect()
                except ConnectionRefusedError:
                    logger.warning(
                        "Failed to reconnect to kernel websocket - Is the kernel still running?"
                    )

        self._heartbeat_callback = PeriodicCallback(heartbeat, self._heartbeat_interval)
        self._heartbeat_callback.start()

    async def _execute(self, code: str, execution_timeout: int = 15) -> str:
        """Execute code in Jupyter kernel and return output.

        Sends an ``execute_request`` and collects the replies that belong to
        it until an ``execute_reply`` or ``error`` message arrives. On timeout
        the kernel is interrupted and a "[Execution timed out ...]" message is
        returned instead of the output.

        Args:
            code: Python code to execute
            execution_timeout: Execution timeout in seconds

        Returns:
            String output from the kernel with ANSI escape sequences removed,
            or the timeout message.

        Raises:
            ConnectionError: If the WebSocket closes while waiting for output.
        """
        from tornado.escape import json_decode, json_encode, url_escape
        from tornado.httpclient import AsyncHTTPClient

        if not self._ws:
            await self._connect()

        msg_id = uuid4().hex
        self._ws.write_message(
            json_encode(
                {
                    "header": {
                        "username": "",
                        "version": "5.0",
                        "session": "",
                        "msg_id": msg_id,
                        "msg_type": "execute_request",
                    },
                    "parent_header": {},
                    "channel": "shell",
                    "content": {
                        "code": code,
                        "silent": False,
                        "store_history": False,
                        "user_expressions": {},
                        "allow_stdin": False,
                    },
                    "metadata": {},
                    "buffers": {},
                }
            )
        )

        outputs: list[str] = []

        async def wait_for_messages() -> bool:
            execution_done = False
            while not execution_done:
                raw = await self._ws.read_message()
                if raw is None:
                    # tornado yields None once the connection is closed; drop the
                    # dead socket so the next execution reconnects.
                    self._ws = None
                    raise ConnectionError(
                        "Kernel websocket closed while waiting for execution output"
                    )
                msg = json_decode(raw)
                msg_type = msg["msg_type"]
                parent_msg_id = msg["parent_header"].get("msg_id", None)

                # Ignore messages that answer a different request
                if parent_msg_id != msg_id:
                    continue

                if msg_type == "error":
                    traceback = "\n\n\n\n".join(msg["content"]["traceback"])
                    outputs.append(traceback)
                    execution_done = True
                elif msg_type == "stream":
                    outputs.append(msg["content"]["text"])
                elif msg_type in ["execute_result", "display_data"]:
                    data = msg["content"]["data"]
                    # A MIME bundle is not guaranteed to carry a plain-text entry
                    # (e.g. image- or HTML-only display_data).
                    plain_text = data.get("text/plain")
                    if plain_text is not None:
                        outputs.append(plain_text)
                    # Handle image outputs
                    if "image/png" in data:
                        outputs.append(f"![image](data:image/png;base64,{data['image/png']})")
                elif msg_type == "execute_reply":
                    execution_done = True
            return execution_done

        async def interrupt_kernel() -> None:
            client = AsyncHTTPClient()
            interrupt_response = await client.fetch(
                f"{self._base_url}/api/kernels/{url_escape(self._kernel_id)}/interrupt",
                method="POST",
                body=json_encode({"kernel_id": self._kernel_id}),
            )
            logger.info("Kernel interrupted: %s", interrupt_response)

        try:
            await asyncio.wait_for(wait_for_messages(), execution_timeout)
        except (TimeoutError, asyncio.TimeoutError):
            # asyncio.TimeoutError is only an alias of the builtin TimeoutError
            # from Python 3.11 on; catching both keeps older interpreters working.
            await interrupt_kernel()
            return f"[Execution timed out ({execution_timeout} seconds).]"

        ret = "".join(outputs)

        # Remove ANSI escape sequences
        return strip_ansi(ret)

    async def shutdown(self) -> None:
        """Shutdown the kernel connection.

        Deletes the kernel on the gateway (best effort: failures are logged,
        not raised), stops the heartbeat, closes the WebSocket and resets the
        tool so that a later call starts a fresh kernel.
        """
        from tornado.escape import url_escape
        from tornado.httpclient import AsyncHTTPClient

        if self._kernel_id:
            client = AsyncHTTPClient()
            try:
                await client.fetch(
                    f"{self._base_url}/api/kernels/{url_escape(self._kernel_id)}",
                    method="DELETE",
                )
                logger.info("Kernel %s shut down", self._kernel_id)
            except Exception as e:
                logger.warning("Error shutting down kernel: %s", e)

            self._kernel_id = ""

        if self._heartbeat_callback:
            self._heartbeat_callback.stop()
            self._heartbeat_callback = None

        if self._ws:
            self._ws.close()
            self._ws = None

        self._initialized = False

    def get_kernel_id(self) -> str:
        """Get the jupyter kernel id (empty string when no kernel is attached)."""
        return self._kernel_id