hud-python 0.2.10__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

Files changed (64)
  1. hud/__init__.py +14 -5
  2. hud/env/docker_client.py +1 -1
  3. hud/env/environment.py +10 -7
  4. hud/env/local_docker_client.py +1 -1
  5. hud/env/remote_client.py +1 -1
  6. hud/env/remote_docker_client.py +2 -2
  7. hud/exceptions.py +2 -1
  8. hud/mcp_agent/__init__.py +15 -0
  9. hud/mcp_agent/base.py +723 -0
  10. hud/mcp_agent/claude.py +316 -0
  11. hud/mcp_agent/langchain.py +231 -0
  12. hud/mcp_agent/openai.py +318 -0
  13. hud/mcp_agent/tests/__init__.py +1 -0
  14. hud/mcp_agent/tests/test_base.py +437 -0
  15. hud/settings.py +14 -2
  16. hud/task.py +4 -0
  17. hud/telemetry/__init__.py +11 -7
  18. hud/telemetry/_trace.py +82 -71
  19. hud/telemetry/context.py +9 -27
  20. hud/telemetry/exporter.py +6 -5
  21. hud/telemetry/instrumentation/mcp.py +174 -410
  22. hud/telemetry/mcp_models.py +13 -74
  23. hud/telemetry/tests/test_context.py +9 -6
  24. hud/telemetry/tests/test_trace.py +92 -61
  25. hud/tools/__init__.py +21 -0
  26. hud/tools/base.py +65 -0
  27. hud/tools/bash.py +137 -0
  28. hud/tools/computer/__init__.py +13 -0
  29. hud/tools/computer/anthropic.py +411 -0
  30. hud/tools/computer/hud.py +315 -0
  31. hud/tools/computer/openai.py +283 -0
  32. hud/tools/edit.py +290 -0
  33. hud/tools/executors/__init__.py +13 -0
  34. hud/tools/executors/base.py +331 -0
  35. hud/tools/executors/pyautogui.py +585 -0
  36. hud/tools/executors/tests/__init__.py +1 -0
  37. hud/tools/executors/tests/test_base_executor.py +338 -0
  38. hud/tools/executors/tests/test_pyautogui_executor.py +162 -0
  39. hud/tools/executors/xdo.py +503 -0
  40. hud/tools/helper/README.md +56 -0
  41. hud/tools/helper/__init__.py +9 -0
  42. hud/tools/helper/mcp_server.py +78 -0
  43. hud/tools/helper/server_initialization.py +115 -0
  44. hud/tools/helper/utils.py +58 -0
  45. hud/tools/playwright_tool.py +373 -0
  46. hud/tools/tests/__init__.py +3 -0
  47. hud/tools/tests/test_bash.py +152 -0
  48. hud/tools/tests/test_computer.py +52 -0
  49. hud/tools/tests/test_computer_actions.py +34 -0
  50. hud/tools/tests/test_edit.py +233 -0
  51. hud/tools/tests/test_init.py +27 -0
  52. hud/tools/tests/test_playwright_tool.py +183 -0
  53. hud/tools/tests/test_tools.py +154 -0
  54. hud/tools/tests/test_utils.py +156 -0
  55. hud/tools/utils.py +50 -0
  56. hud/types.py +10 -1
  57. hud/utils/tests/test_init.py +21 -0
  58. hud/utils/tests/test_version.py +1 -1
  59. hud/version.py +1 -1
  60. {hud_python-0.2.10.dist-info → hud_python-0.3.0.dist-info}/METADATA +9 -6
  61. hud_python-0.3.0.dist-info/RECORD +124 -0
  62. hud_python-0.2.10.dist-info/RECORD +0 -85
  63. {hud_python-0.2.10.dist-info → hud_python-0.3.0.dist-info}/WHEEL +0 -0
  64. {hud_python-0.2.10.dist-info → hud_python-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,115 @@
1
+ """Helper for MCP server initialization with progress notifications.
2
+
3
+ Example:
4
+ ```python
5
+ from hud.tools.helper import mcp_intialize_wrapper
6
+
7
+
8
+ @mcp_intialize_wrapper
9
+ async def initialize_environment(session=None, progress_token=None):
10
+ # Send progress if available
11
+ if session and progress_token:
12
+ await session.send_progress_notification(
13
+ progress_token=progress_token, progress=0, total=100, message="Starting services..."
14
+ )
15
+
16
+ # Your initialization code works with or without session
17
+ start_services()
18
+
19
+
20
+ # Create and run server - initialization happens automatically
21
+ mcp = FastMCP("My Server")
22
+ mcp.run()
23
+ ```
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ from typing import TYPE_CHECKING
29
+
30
+ import mcp.types as types
31
+ from mcp.server.session import ServerSession
32
+
33
+ if TYPE_CHECKING:
34
+ from collections.abc import Awaitable, Callable
35
+
36
+ from mcp.shared.session import RequestResponder
37
+
38
# Store the original _received_request method so the patched handler can delegate to it
# and so the patch can be undone once the init hook has succeeded.
_original_received_request = ServerSession._received_request
# Init hook registered through mcp_intialize_wrapper; None until a hook is registered.
_init_function: Callable | None = None
# Flipped to True once the patched handler has delegated a request to the original one;
# while True the init hook is not run again.
_initialized = False
42
+
43
+
44
async def _patched_received_request(
    self: ServerSession, responder: RequestResponder[types.ClientRequest, types.ServerResult]
) -> types.ServerResult | None:
    """Intercept initialization to run custom setup with progress notifications.

    If the incoming request is an ``InitializeRequest`` and an init hook is registered
    and initialization has not been recorded yet, the hook is awaited *before* the
    original handler answers the request.
    - On success, the monkey patch is removed, so later requests go straight to the
      original handler.
    - On failure, a progress notification describing the error is sent (only when the
      client supplied a progress token) and the exception is re-raised. The request is
      then not forwarded to the original handler.

    Every request that gets past that point is delegated to the original
    ``ServerSession._received_request``, and initialization is marked as done.

    Args:
        self: The server session handling the request.
        responder: Wrapper around the incoming client request.

    Returns:
        Whatever the original ``_received_request`` returns.
    """
    global _initialized

    request = responder.request.root
    if isinstance(request, types.InitializeRequest):
        # The client opts into progress updates by sending params.meta.progressToken.
        meta = getattr(request.params, "meta", None)
        progress_token = getattr(meta, "progressToken", None) if meta else None

        if _init_function is not None and not _initialized:
            try:
                await _init_function(session=self, progress_token=progress_token)
                # Setup succeeded: restore the original handler for all later requests.
                ServerSession._received_request = _original_received_request
            except Exception as e:
                # MCP progress tokens may be strings or integers, so 0 is a valid token;
                # test against None rather than truthiness.
                if progress_token is not None:
                    await self.send_progress_notification(
                        progress_token=progress_token,
                        progress=0,
                        total=100,
                        message=f"Initialization failed: {e!s}",
                    )
                raise

    # Call the original handler (for init requests this sends the InitializeResult).
    result = await _original_received_request(self, responder)
    _initialized = True

    return result
78
+
79
+
80
def mcp_intialize_wrapper(
    init_function: Callable[[ServerSession | None, str | None], Awaitable[None]] | None = None,
) -> Callable:
    """Decorator to enable progress notifications during MCP server initialization.

    Usable both as a bare decorator (``@mcp_intialize_wrapper``) and as a direct call
    (``mcp_intialize_wrapper(func)``). Registering a function does two things:
    - It stores the function as the module-level init hook. Only one hook is kept, so
      registering again replaces the previous hook.
    - It monkey-patches ``ServerSession._received_request`` with
      ``_patched_received_request`` (if not already patched). The hook is then awaited
      when the client's ``initialize`` request arrives, before the server replies.

    The hook is called with the keyword arguments ``session`` and ``progress_token``.
    When a progress token is present, use them to send progress updates. When it is
    not, the function still works.

    Usage:
        @mcp_intialize_wrapper
        async def initialize(session=None, progress_token=None):
            if session and progress_token:
                await session.send_progress_notification(...)
            # Your init code here

    Must be applied before creating FastMCP instance or calling mcp.run().

    Args:
        init_function: Hook to register immediately, or ``None`` to get a decorator.

    Returns:
        The registered function unchanged when ``init_function`` is given; otherwise a
        decorator that registers and returns its argument.
    """

    def decorator(func: Callable[[ServerSession | None, str | None], Awaitable[None]]) -> Callable:
        global _init_function
        # Store the initialization function
        _init_function = func

        # Apply the monkey patch if not already applied; these are plain function
        # objects, so compare by identity.
        if ServerSession._received_request is not _patched_received_request:
            ServerSession._received_request = _patched_received_request  # type: ignore[assignment]

        return func

    # If called with a function directly
    if init_function is not None:
        return decorator(init_function)

    # If used as @decorator
    return decorator
@@ -0,0 +1,58 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import inspect
5
+ from functools import wraps
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Callable
10
+
11
+ from mcp.server.fastmcp import FastMCP
12
+
13
+
14
def register_instance_tool(mcp: FastMCP, name: str, instance: Any) -> Callable[..., Any]:
    """Register ``instance.__call__`` as a FastMCP tool.

    The tool's public signature is derived from ``instance.__call__``:
    - ``*args`` / ``**kwargs`` are removed, so Pydantic does not treat them as required
      fields.
    - Every annotation (and the return annotation) is replaced by ``Any``.
    - Positional-only parameters are exposed as positional-or-keyword.
    - Keyword-only parameters stay keyword-only. This keeps the signature valid for
      cases such as ``def __call__(self, a=1, *, b)``.

    Parameters
    ----------
    mcp:
        A :class:`mcp.server.fastmcp.FastMCP` instance.
    name:
        Public tool name.
    instance:
        Object with an ``async def __call__`` (or sync) implementing the tool.

    Returns
    -------
    The result of applying ``mcp.tool(name=name)`` to an async wrapper around
    ``instance.__call__``.

    Raises
    ------
    TypeError
        If ``instance`` is a class instead of an instance of one.
    """
    if inspect.isclass(instance):
        class_name = instance.__name__
        raise TypeError(
            f"register_instance_tool() expects an instance, but got class '{class_name}'. "
            f"Use: register_instance_tool(mcp, '{name}', {class_name}()) "
            f"Not: register_instance_tool(mcp, '{name}', {class_name})"
        )

    call_fn = instance.__call__
    sig = inspect.signature(call_fn)

    public_params = []
    for param in sig.parameters.values():
        # Remove *args/**kwargs so Pydantic doesn't treat them as required fields.
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # Keep keyword-only params keyword-only. Turning a required keyword-only param
        # into a positional one after a defaulted param makes inspect.Signature raise
        # "non-default argument follows default argument".
        kind = param.kind if param.kind is param.KEYWORD_ONLY else param.POSITIONAL_OR_KEYWORD
        public_params.append(param.replace(kind=kind, annotation=Any))

    public_sig = inspect.Signature(parameters=public_params, return_annotation=Any)

    @wraps(call_fn)
    async def _wrapper(*args: Any, **kwargs: Any) -> Any:
        result = call_fn(*args, **kwargs)
        # Support both sync and async __call__ implementations.
        if asyncio.iscoroutine(result):
            result = await result
        return result

    _wrapper.__signature__ = public_sig  # type: ignore[attr-defined]

    return mcp.tool(name=name)(_wrapper)
@@ -0,0 +1,373 @@
1
+ """Playwright web automation tool for HUD."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import TYPE_CHECKING, Any, Literal
8
+
9
+ from mcp import ErrorData, McpError
10
+ from mcp.types import INVALID_PARAMS, ImageContent, TextContent
11
+ from pydantic import Field
12
+
13
+ from hud.tools.base import ToolResult, tool_result_to_content_blocks
14
+
15
+ if TYPE_CHECKING:
16
+ from playwright.async_api import Browser, BrowserContext, Page
17
+
18
# Module-level logger used by PlaywrightTool for action tracing and error reports.
logger = logging.getLogger(__name__)
19
+
20
+
21
class PlaywrightTool:
    """Playwright tool for web automation.

    On first use, lazily launches a headed Chromium browser with a single context and
    page. ``__call__`` exposes the actions navigate, screenshot, click, type,
    get_page_info and wait_for_element. The action helpers return plain dicts with a
    ``success`` flag, which ``__call__`` converts into MCP content blocks.
    """

    def __init__(self) -> None:
        super().__init__()
        # Handle returned by async_playwright().start(); None until first launch.
        self._playwright = None
        self._browser: Browser | None = None
        self._context: BrowserContext | None = None
        self._page: Page | None = None

    @property
    def page(self) -> Page:
        """Get the current page, raising an error if not initialized."""
        if self._page is None:
            # NOTE(review): this class has no ensure_browser_launched(); the launcher is
            # _ensure_browser(). Message kept as-is in case callers match on it.
            raise RuntimeError("Browser page is not initialized. Call ensure_browser_launched().")
        return self._page

    async def __call__(
        self,
        action: str = Field(
            ...,
            description="The action to perform (navigate, screenshot, click, type, get_page_info, wait_for_element)",  # noqa: E501
        ),
        url: str | None = Field(None, description="URL to navigate to (for navigate action)"),
        selector: str | None = Field(
            None, description="CSS selector for element (for click, type, wait_for_element actions)"
        ),
        text: str | None = Field(None, description="Text to type (for type action)"),
        path: str | None = Field(
            None, description="File path to save screenshot (for screenshot action)"
        ),
        wait_for_load_state: Literal["commit", "domcontentloaded", "load", "networkidle"]
        | None = Field(
            None,
            description="State to wait for: commit, domcontentloaded, load, networkidle (default: networkidle)",  # noqa: E501
        ),
    ) -> list[ImageContent | TextContent]:
        """Execute a Playwright web automation action.

        Returns:
            List of MCP content blocks built from the action's result.

        Raises:
            McpError: With code ``INVALID_PARAMS`` for an unknown action or a missing
                required parameter. Any other exception raised while running the action
                is also wrapped in an ``INVALID_PARAMS`` McpError.
        """
        logger.info("PlaywrightTool executing action: %s", action)

        try:
            if action == "navigate":
                if url is None:
                    raise self._invalid_params("url parameter is required for navigate")
                result = await self.navigate(url, wait_for_load_state or "networkidle")

            elif action == "screenshot":
                result = await self.screenshot(path)

            elif action == "click":
                if selector is None:
                    raise self._invalid_params("selector parameter is required for click")
                result = await self.click(selector)

            elif action == "type":
                if selector is None:
                    raise self._invalid_params("selector parameter is required for type")
                if text is None:
                    raise self._invalid_params("text parameter is required for type")
                result = await self.type_text(selector, text)

            elif action == "get_page_info":
                result = await self.get_page_info()

            elif action == "wait_for_element":
                if selector is None:
                    raise self._invalid_params(
                        "selector parameter is required for wait_for_element"
                    )
                result = await self.wait_for_element(selector)

            else:
                raise self._invalid_params(f"Unknown action: {action}")

            # Convert result to content blocks
            return tool_result_to_content_blocks(self._to_tool_result(result))

        except McpError:
            raise
        except Exception as e:
            logger.error("PlaywrightTool error: %s", e)
            raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Playwright error: {e}")) from e

    @staticmethod
    def _invalid_params(message: str) -> McpError:
        """Build an ``INVALID_PARAMS`` McpError carrying ``message``."""
        return McpError(ErrorData(code=INVALID_PARAMS, message=message))

    @staticmethod
    def _to_tool_result(result: Any) -> Any:
        """Convert an action helper's result dict into a ``ToolResult``.

        A successful dict becomes ``ToolResult(output=message)``, and additionally gets
        ``base64_image`` when a ``screenshot`` key is present. A failed dict becomes
        ``ToolResult(error=...)``. Non-dict values are returned unchanged.
        """
        if not isinstance(result, dict):
            return result
        if not result.get("success"):
            return ToolResult(error=result.get("error", "Unknown error"))
        if "screenshot" in result:
            # Return screenshot as image content
            return ToolResult(output=result.get("message", ""), base64_image=result["screenshot"])
        return ToolResult(output=result.get("message", ""))

    async def _ensure_browser(self) -> None:
        """Ensure browser is launched and ready.

        Does nothing while a connected browser exists. Otherwise it starts Playwright
        if needed, launches Chromium, and creates a fresh 1920x1080 context and page.

        Raises:
            ImportError: If the ``playwright`` package is not installed.
            RuntimeError: If the browser or its context could not be created.
        """
        if self._browser is not None and self._browser.is_connected():
            return

        logger.info("Launching Playwright browser...")

        # The browser is launched with headless=False, so make sure DISPLAY is set
        # (default :1) without overriding an existing value.
        os.environ.setdefault("DISPLAY", ":1")

        if self._playwright is None:
            try:
                from playwright.async_api import async_playwright

                self._playwright = await async_playwright().start()
            except ImportError:
                raise ImportError(
                    "Playwright is not installed. Please install with: pip install playwright"
                ) from None

        self._browser = await self._playwright.chromium.launch(
            headless=False,
            args=[
                "--no-sandbox",
                "--disable-dev-shm-usage",
                "--disable-gpu",
                "--disable-web-security",
                # Chromium keeps only the last value of a repeated switch, so every
                # feature to disable must be listed in one --disable-features flag.
                "--disable-features=IsolateOrigins,site-per-process,TranslateUI",
                "--disable-blink-features=AutomationControlled",
                "--window-size=1920,1080",
                "--window-position=0,0",
                "--start-maximized",
                "--disable-background-timer-throttling",
                "--disable-backgrounding-occluded-windows",
                "--disable-renderer-backgrounding",
                "--disable-ipc-flooding-protection",
                "--disable-default-apps",
                "--no-first-run",
                "--disable-sync",
                "--no-default-browser-check",
            ],
        )

        if self._browser is None:
            raise RuntimeError("Browser failed to initialize")

        self._context = await self._browser.new_context(
            viewport={"width": 1920, "height": 1080},
            ignore_https_errors=True,
        )

        if self._context is None:
            raise RuntimeError("Browser context failed to initialize")

        self._page = await self._context.new_page()
        logger.info("Playwright browser launched successfully")

    async def navigate(
        self,
        url: str,
        wait_for_load_state: Literal[
            "commit", "domcontentloaded", "load", "networkidle"
        ] = "networkidle",
    ) -> dict[str, Any]:
        """Navigate to a URL.

        Args:
            url: URL to navigate to.
            wait_for_load_state: Load state to wait for (commit, domcontentloaded, load,
                networkidle).

        Returns:
            Dict with ``success``. On success it also holds ``url``, ``title`` and
            ``message``; on failure it holds ``error`` and ``message``.
        """
        await self._ensure_browser()

        logger.info("Navigating to %s", url)
        try:
            await self.page.goto(url, wait_until=wait_for_load_state)
            current_url = self.page.url
            title = await self.page.title()

            return {
                "success": True,
                "url": current_url,
                "title": title,
                "message": f"Successfully navigated to {url}",
            }
        except Exception as e:
            logger.error("Navigation failed: %s", e)
            return {
                "success": False,
                "error": str(e),
                "message": f"Failed to navigate to {url}: {e}",
            }

    async def screenshot(self, path: str | None = None) -> dict[str, Any]:
        """Take a full-page screenshot of the current page.

        Args:
            path: Optional file path to save the screenshot to. When omitted, the image
                bytes are returned base64-encoded under the ``screenshot`` key.

        Returns:
            Dict with screenshot result.
        """
        await self._ensure_browser()

        try:
            if path:
                # NOTE(review): path is written as given, with no sanitization; confirm
                # callers are trusted.
                await self.page.screenshot(path=path, full_page=True)
                return {"success": True, "path": path, "message": f"Screenshot saved to {path}"}

            # Return base64 encoded screenshot
            screenshot_bytes = await self.page.screenshot(full_page=True)
            import base64

            screenshot_b64 = base64.b64encode(screenshot_bytes).decode()
            return {
                "success": True,
                "screenshot": screenshot_b64,
                "message": "Screenshot captured",
            }
        except Exception as e:
            logger.error("Screenshot failed: %s", e)
            return {"success": False, "error": str(e), "message": f"Failed to take screenshot: {e}"}

    async def click(self, selector: str) -> dict[str, Any]:
        """Click an element by selector.

        Args:
            selector: CSS selector for element to click.

        Returns:
            Dict with click result.
        """
        await self._ensure_browser()

        try:
            await self.page.click(selector)
            return {"success": True, "message": f"Clicked element: {selector}"}
        except Exception as e:
            logger.error("Click failed: %s", e)
            return {
                "success": False,
                "error": str(e),
                "message": f"Failed to click {selector}: {e}",
            }

    async def type_text(self, selector: str, text: str) -> dict[str, Any]:
        """Fill an element with text (uses ``page.fill``).

        Args:
            selector: CSS selector for input element.
            text: Text to type.

        Returns:
            Dict with type result.
        """
        await self._ensure_browser()

        try:
            await self.page.fill(selector, text)
            return {"success": True, "message": f"Typed '{text}' into {selector}"}
        except Exception as e:
            logger.error("Type failed: %s", e)
            return {
                "success": False,
                "error": str(e),
                "message": f"Failed to type into {selector}: {e}",
            }

    async def get_page_info(self) -> dict[str, Any]:
        """Get the current page URL and title.

        Returns:
            Dict with page info.
        """
        await self._ensure_browser()

        try:
            url = self.page.url
            title = await self.page.title()
            return {
                "success": True,
                "url": url,
                "title": title,
                "message": f"Current page: {title} ({url})",
            }
        except Exception as e:
            logger.error("Get page info failed: %s", e)
            return {"success": False, "error": str(e), "message": f"Failed to get page info: {e}"}

    async def wait_for_element(self, selector: str, timeout: float = 30000) -> dict[str, Any]:
        """Wait for an element to appear.

        Args:
            selector: CSS selector for element.
            timeout: Maximum wait in milliseconds (default 30000).

        Returns:
            Dict with wait result.
        """
        await self._ensure_browser()

        try:
            await self.page.wait_for_selector(selector, timeout=timeout)
            return {"success": True, "message": f"Element {selector} appeared"}
        except Exception as e:
            logger.error("Wait for element failed: %s", e)
            return {
                "success": False,
                "error": str(e),
                "message": f"Element {selector} did not appear within {timeout}ms: {e}",
            }

    async def close(self) -> None:
        """Close browser and cleanup.

        Errors from closing the browser or stopping Playwright are logged, not raised.
        All handles are reset afterwards, so a later action launches everything anew.
        """
        if self._browser:
            try:
                await self._browser.close()
                logger.info("Browser closed")
            except Exception as e:
                logger.error("Error closing browser: %s", e)

        if self._playwright:
            try:
                await self._playwright.stop()
            except Exception as e:
                logger.error("Error stopping playwright: %s", e)

        self._browser = None
        self._context = None
        self._page = None
        self._playwright = None
@@ -0,0 +1,3 @@
1
+ from __future__ import annotations
2
+
3
# Explicitly export no public names from this package.
__all__: list[str] = []
@@ -0,0 +1,152 @@
1
+ """Tests for bash tool."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from unittest.mock import AsyncMock, MagicMock, patch
6
+
7
+ import pytest
8
+
9
+ from hud.tools.base import ToolResult
10
+ from hud.tools.bash import BashTool, ToolError, _BashSession
11
+
12
+
13
class TestBashSession:
    """Unit tests covering the private ``_BashSession`` helper."""

    @pytest.mark.asyncio
    async def test_session_start(self):
        """Test starting a bash session."""
        bash_session = _BashSession()
        assert bash_session._started is False

        fake_process = MagicMock()
        with patch("asyncio.create_subprocess_shell") as create_shell:
            create_shell.return_value = fake_process
            await bash_session.start()

            assert bash_session._started is True
            assert bash_session._process is fake_process
            create_shell.assert_called_once()

    def test_session_stop_not_started(self):
        """Test stopping a session that hasn't started."""
        idle_session = _BashSession()

        with pytest.raises(ToolError, match="Session has not started"):
            idle_session.stop()

    @pytest.mark.asyncio
    async def test_session_run_not_started(self):
        """Test running command on a session that hasn't started."""
        idle_session = _BashSession()

        with pytest.raises(ToolError, match="Session has not started"):
            await idle_session.run("echo test")

    @pytest.mark.asyncio
    async def test_session_run_success(self):
        """Test successful command execution."""
        bash_session = _BashSession()
        bash_session._started = True

        # Fake subprocess whose stdout yields the command output followed by the sentinel.
        fake_stdin = MagicMock(write=MagicMock(), drain=AsyncMock())
        fake_stdout = MagicMock(readuntil=AsyncMock(return_value=b"Hello World\n<<exit>>\n"))
        fake_stderr = MagicMock(read=AsyncMock(return_value=b""))
        bash_session._process = MagicMock(
            returncode=None, stdin=fake_stdin, stdout=fake_stdout, stderr=fake_stderr
        )

        outcome = await bash_session.run("echo Hello World")

        assert outcome.output == "Hello World\n"
        assert outcome.error == ""
74
+
75
+
76
class TestBashTool:
    """Unit tests covering the public ``BashTool`` wrapper."""

    def test_bash_tool_init(self):
        """Test BashTool initialization."""
        fresh_tool = BashTool()
        assert fresh_tool._session is None

    @pytest.mark.asyncio
    async def test_call_with_command(self):
        """Test calling tool with a command."""
        tool = BashTool()

        # Stand-in session returned when BashTool constructs _BashSession.
        fake_session = MagicMock()
        fake_session.run = AsyncMock(return_value=ToolResult(output="test output"))
        fake_session.start = AsyncMock()

        with patch("hud.tools.bash._BashSession", return_value=fake_session):
            outcome = await tool(command="echo test")

        assert isinstance(outcome, ToolResult)
        assert outcome.output == "test output"
        fake_session.start.assert_called_once()
        fake_session.run.assert_called_once_with("echo test")

    @pytest.mark.asyncio
    async def test_call_restart(self):
        """Test restarting the tool."""
        tool = BashTool()

        previous_session = MagicMock(stop=MagicMock())
        tool._session = previous_session
        replacement_session = MagicMock(start=AsyncMock())

        with patch("hud.tools.bash._BashSession", return_value=replacement_session):
            outcome = await tool(restart=True)

        assert isinstance(outcome, ToolResult)
        assert outcome.system == "tool has been restarted."
        previous_session.stop.assert_called_once()
        replacement_session.start.assert_called_once()
        assert tool._session is replacement_session

    @pytest.mark.asyncio
    async def test_call_no_command_error(self):
        """Test calling without command raises error."""
        tool = BashTool()

        with pytest.raises(ToolError, match="no command provided"):
            await tool()

    @pytest.mark.asyncio
    async def test_call_with_existing_session(self):
        """Test calling with an existing session."""
        tool = BashTool()

        current_session = MagicMock(run=AsyncMock(return_value=ToolResult(output="result")))
        tool._session = current_session

        outcome = await tool(command="ls")

        assert isinstance(outcome, ToolResult)
        assert outcome.output == "result"
        current_session.run.assert_called_once_with("ls")