hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (274) hide show
  1. hud/__init__.py +27 -7
  2. hud/agents/__init__.py +11 -5
  3. hud/agents/base.py +220 -500
  4. hud/agents/claude.py +200 -240
  5. hud/agents/gemini.py +275 -0
  6. hud/agents/gemini_cua.py +335 -0
  7. hud/agents/grounded_openai.py +98 -100
  8. hud/agents/misc/integration_test_agent.py +51 -20
  9. hud/agents/misc/response_agent.py +41 -36
  10. hud/agents/openai.py +291 -292
  11. hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
  12. hud/agents/operator.py +211 -0
  13. hud/agents/tests/conftest.py +133 -0
  14. hud/agents/tests/test_base.py +300 -622
  15. hud/agents/tests/test_base_runtime.py +233 -0
  16. hud/agents/tests/test_claude.py +379 -210
  17. hud/agents/tests/test_client.py +9 -10
  18. hud/agents/tests/test_gemini.py +369 -0
  19. hud/agents/tests/test_grounded_openai_agent.py +65 -50
  20. hud/agents/tests/test_openai.py +376 -140
  21. hud/agents/tests/test_operator.py +362 -0
  22. hud/agents/tests/test_run_eval.py +179 -0
  23. hud/cli/__init__.py +461 -545
  24. hud/cli/analyze.py +43 -5
  25. hud/cli/build.py +664 -110
  26. hud/cli/debug.py +8 -5
  27. hud/cli/dev.py +882 -734
  28. hud/cli/eval.py +782 -668
  29. hud/cli/flows/dev.py +167 -0
  30. hud/cli/flows/init.py +191 -0
  31. hud/cli/flows/tasks.py +153 -56
  32. hud/cli/flows/templates.py +151 -0
  33. hud/cli/flows/tests/__init__.py +1 -0
  34. hud/cli/flows/tests/test_dev.py +126 -0
  35. hud/cli/init.py +60 -58
  36. hud/cli/push.py +29 -11
  37. hud/cli/rft.py +311 -0
  38. hud/cli/rft_status.py +145 -0
  39. hud/cli/tests/test_analyze.py +5 -5
  40. hud/cli/tests/test_analyze_metadata.py +3 -2
  41. hud/cli/tests/test_analyze_module.py +120 -0
  42. hud/cli/tests/test_build.py +108 -6
  43. hud/cli/tests/test_build_failure.py +41 -0
  44. hud/cli/tests/test_build_module.py +50 -0
  45. hud/cli/tests/test_cli_init.py +6 -1
  46. hud/cli/tests/test_cli_more_wrappers.py +30 -0
  47. hud/cli/tests/test_cli_root.py +140 -0
  48. hud/cli/tests/test_convert.py +361 -0
  49. hud/cli/tests/test_debug.py +12 -10
  50. hud/cli/tests/test_dev.py +197 -0
  51. hud/cli/tests/test_eval.py +251 -0
  52. hud/cli/tests/test_eval_bedrock.py +51 -0
  53. hud/cli/tests/test_init.py +124 -0
  54. hud/cli/tests/test_main_module.py +11 -5
  55. hud/cli/tests/test_mcp_server.py +12 -100
  56. hud/cli/tests/test_push_happy.py +74 -0
  57. hud/cli/tests/test_push_wrapper.py +23 -0
  58. hud/cli/tests/test_registry.py +1 -1
  59. hud/cli/tests/test_utils.py +1 -1
  60. hud/cli/{rl → utils}/celebrate.py +14 -12
  61. hud/cli/utils/config.py +18 -1
  62. hud/cli/utils/docker.py +130 -4
  63. hud/cli/utils/env_check.py +9 -9
  64. hud/cli/utils/git.py +136 -0
  65. hud/cli/utils/interactive.py +39 -5
  66. hud/cli/utils/metadata.py +69 -0
  67. hud/cli/utils/runner.py +1 -1
  68. hud/cli/utils/server.py +2 -2
  69. hud/cli/utils/source_hash.py +3 -3
  70. hud/cli/utils/tasks.py +4 -1
  71. hud/cli/utils/tests/__init__.py +0 -0
  72. hud/cli/utils/tests/test_config.py +58 -0
  73. hud/cli/utils/tests/test_docker.py +93 -0
  74. hud/cli/utils/tests/test_docker_hints.py +71 -0
  75. hud/cli/utils/tests/test_env_check.py +74 -0
  76. hud/cli/utils/tests/test_environment.py +42 -0
  77. hud/cli/utils/tests/test_git.py +142 -0
  78. hud/cli/utils/tests/test_interactive_module.py +60 -0
  79. hud/cli/utils/tests/test_local_runner.py +50 -0
  80. hud/cli/utils/tests/test_logging_utils.py +23 -0
  81. hud/cli/utils/tests/test_metadata.py +49 -0
  82. hud/cli/utils/tests/test_package_runner.py +35 -0
  83. hud/cli/utils/tests/test_registry_utils.py +49 -0
  84. hud/cli/utils/tests/test_remote_runner.py +25 -0
  85. hud/cli/utils/tests/test_runner_modules.py +52 -0
  86. hud/cli/utils/tests/test_source_hash.py +36 -0
  87. hud/cli/utils/tests/test_tasks.py +80 -0
  88. hud/cli/utils/version_check.py +258 -0
  89. hud/cli/{rl → utils}/viewer.py +2 -2
  90. hud/clients/README.md +12 -11
  91. hud/clients/__init__.py +4 -3
  92. hud/clients/base.py +166 -26
  93. hud/clients/environment.py +51 -0
  94. hud/clients/fastmcp.py +13 -6
  95. hud/clients/mcp_use.py +40 -15
  96. hud/clients/tests/test_analyze_scenarios.py +206 -0
  97. hud/clients/tests/test_protocol.py +9 -3
  98. hud/datasets/__init__.py +23 -20
  99. hud/datasets/loader.py +327 -0
  100. hud/datasets/runner.py +192 -105
  101. hud/datasets/tests/__init__.py +0 -0
  102. hud/datasets/tests/test_loader.py +221 -0
  103. hud/datasets/tests/test_utils.py +315 -0
  104. hud/datasets/utils.py +270 -90
  105. hud/environment/__init__.py +50 -0
  106. hud/environment/connection.py +206 -0
  107. hud/environment/connectors/__init__.py +33 -0
  108. hud/environment/connectors/base.py +68 -0
  109. hud/environment/connectors/local.py +177 -0
  110. hud/environment/connectors/mcp_config.py +109 -0
  111. hud/environment/connectors/openai.py +101 -0
  112. hud/environment/connectors/remote.py +172 -0
  113. hud/environment/environment.py +694 -0
  114. hud/environment/integrations/__init__.py +45 -0
  115. hud/environment/integrations/adk.py +67 -0
  116. hud/environment/integrations/anthropic.py +196 -0
  117. hud/environment/integrations/gemini.py +92 -0
  118. hud/environment/integrations/langchain.py +82 -0
  119. hud/environment/integrations/llamaindex.py +68 -0
  120. hud/environment/integrations/openai.py +238 -0
  121. hud/environment/mock.py +306 -0
  122. hud/environment/router.py +112 -0
  123. hud/environment/scenarios.py +493 -0
  124. hud/environment/tests/__init__.py +1 -0
  125. hud/environment/tests/test_connection.py +317 -0
  126. hud/environment/tests/test_connectors.py +218 -0
  127. hud/environment/tests/test_environment.py +161 -0
  128. hud/environment/tests/test_integrations.py +257 -0
  129. hud/environment/tests/test_local_connectors.py +201 -0
  130. hud/environment/tests/test_scenarios.py +280 -0
  131. hud/environment/tests/test_tools.py +208 -0
  132. hud/environment/types.py +23 -0
  133. hud/environment/utils/__init__.py +35 -0
  134. hud/environment/utils/formats.py +215 -0
  135. hud/environment/utils/schema.py +171 -0
  136. hud/environment/utils/tool_wrappers.py +113 -0
  137. hud/eval/__init__.py +67 -0
  138. hud/eval/context.py +674 -0
  139. hud/eval/display.py +299 -0
  140. hud/eval/instrument.py +185 -0
  141. hud/eval/manager.py +466 -0
  142. hud/eval/parallel.py +268 -0
  143. hud/eval/task.py +340 -0
  144. hud/eval/tests/__init__.py +1 -0
  145. hud/eval/tests/test_context.py +178 -0
  146. hud/eval/tests/test_eval.py +210 -0
  147. hud/eval/tests/test_manager.py +152 -0
  148. hud/eval/tests/test_parallel.py +168 -0
  149. hud/eval/tests/test_task.py +145 -0
  150. hud/eval/types.py +63 -0
  151. hud/eval/utils.py +183 -0
  152. hud/patches/__init__.py +19 -0
  153. hud/patches/mcp_patches.py +151 -0
  154. hud/patches/warnings.py +54 -0
  155. hud/samples/browser.py +4 -4
  156. hud/server/__init__.py +2 -1
  157. hud/server/low_level.py +2 -1
  158. hud/server/router.py +164 -0
  159. hud/server/server.py +567 -80
  160. hud/server/tests/test_mcp_server_integration.py +11 -11
  161. hud/server/tests/test_mcp_server_more.py +1 -1
  162. hud/server/tests/test_server_extra.py +2 -0
  163. hud/settings.py +45 -3
  164. hud/shared/exceptions.py +36 -10
  165. hud/shared/hints.py +26 -1
  166. hud/shared/requests.py +15 -3
  167. hud/shared/tests/test_exceptions.py +40 -31
  168. hud/shared/tests/test_hints.py +167 -0
  169. hud/telemetry/__init__.py +20 -19
  170. hud/telemetry/exporter.py +201 -0
  171. hud/telemetry/instrument.py +158 -253
  172. hud/telemetry/tests/test_eval_telemetry.py +356 -0
  173. hud/telemetry/tests/test_exporter.py +258 -0
  174. hud/telemetry/tests/test_instrument.py +401 -0
  175. hud/tools/__init__.py +16 -2
  176. hud/tools/apply_patch.py +639 -0
  177. hud/tools/base.py +54 -4
  178. hud/tools/bash.py +2 -2
  179. hud/tools/computer/__init__.py +4 -0
  180. hud/tools/computer/anthropic.py +2 -2
  181. hud/tools/computer/gemini.py +385 -0
  182. hud/tools/computer/hud.py +23 -6
  183. hud/tools/computer/openai.py +20 -21
  184. hud/tools/computer/qwen.py +434 -0
  185. hud/tools/computer/settings.py +37 -0
  186. hud/tools/edit.py +3 -7
  187. hud/tools/executors/base.py +4 -2
  188. hud/tools/executors/pyautogui.py +1 -1
  189. hud/tools/grounding/grounded_tool.py +13 -18
  190. hud/tools/grounding/grounder.py +10 -31
  191. hud/tools/grounding/tests/test_grounded_tool.py +26 -44
  192. hud/tools/jupyter.py +330 -0
  193. hud/tools/playwright.py +18 -3
  194. hud/tools/shell.py +308 -0
  195. hud/tools/tests/test_apply_patch.py +718 -0
  196. hud/tools/tests/test_computer.py +4 -9
  197. hud/tools/tests/test_computer_actions.py +24 -2
  198. hud/tools/tests/test_jupyter_tool.py +181 -0
  199. hud/tools/tests/test_shell.py +596 -0
  200. hud/tools/tests/test_submit.py +85 -0
  201. hud/tools/tests/test_types.py +193 -0
  202. hud/tools/types.py +21 -1
  203. hud/types.py +167 -57
  204. hud/utils/__init__.py +2 -0
  205. hud/utils/env.py +67 -0
  206. hud/utils/hud_console.py +61 -3
  207. hud/utils/mcp.py +15 -58
  208. hud/utils/strict_schema.py +162 -0
  209. hud/utils/tests/test_init.py +1 -2
  210. hud/utils/tests/test_mcp.py +1 -28
  211. hud/utils/tests/test_pretty_errors.py +186 -0
  212. hud/utils/tests/test_tool_shorthand.py +154 -0
  213. hud/utils/tests/test_version.py +1 -1
  214. hud/utils/types.py +20 -0
  215. hud/version.py +1 -1
  216. hud_python-0.5.1.dist-info/METADATA +264 -0
  217. hud_python-0.5.1.dist-info/RECORD +299 -0
  218. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
  219. hud/agents/langchain.py +0 -261
  220. hud/agents/lite_llm.py +0 -72
  221. hud/cli/rl/__init__.py +0 -180
  222. hud/cli/rl/config.py +0 -101
  223. hud/cli/rl/display.py +0 -133
  224. hud/cli/rl/gpu.py +0 -63
  225. hud/cli/rl/gpu_utils.py +0 -321
  226. hud/cli/rl/local_runner.py +0 -595
  227. hud/cli/rl/presets.py +0 -96
  228. hud/cli/rl/remote_runner.py +0 -463
  229. hud/cli/rl/rl_api.py +0 -150
  230. hud/cli/rl/vllm.py +0 -177
  231. hud/cli/rl/wait_utils.py +0 -89
  232. hud/datasets/parallel.py +0 -687
  233. hud/misc/__init__.py +0 -1
  234. hud/misc/claude_plays_pokemon.py +0 -292
  235. hud/otel/__init__.py +0 -35
  236. hud/otel/collector.py +0 -142
  237. hud/otel/config.py +0 -181
  238. hud/otel/context.py +0 -570
  239. hud/otel/exporters.py +0 -369
  240. hud/otel/instrumentation.py +0 -135
  241. hud/otel/processors.py +0 -121
  242. hud/otel/tests/__init__.py +0 -1
  243. hud/otel/tests/test_processors.py +0 -197
  244. hud/rl/README.md +0 -30
  245. hud/rl/__init__.py +0 -1
  246. hud/rl/actor.py +0 -176
  247. hud/rl/buffer.py +0 -405
  248. hud/rl/chat_template.jinja +0 -101
  249. hud/rl/config.py +0 -192
  250. hud/rl/distributed.py +0 -132
  251. hud/rl/learner.py +0 -637
  252. hud/rl/tests/__init__.py +0 -1
  253. hud/rl/tests/test_learner.py +0 -186
  254. hud/rl/train.py +0 -382
  255. hud/rl/types.py +0 -101
  256. hud/rl/utils/start_vllm_server.sh +0 -30
  257. hud/rl/utils.py +0 -524
  258. hud/rl/vllm_adapter.py +0 -143
  259. hud/telemetry/job.py +0 -352
  260. hud/telemetry/replay.py +0 -74
  261. hud/telemetry/tests/test_replay.py +0 -40
  262. hud/telemetry/tests/test_trace.py +0 -63
  263. hud/telemetry/trace.py +0 -158
  264. hud/utils/agent_factories.py +0 -86
  265. hud/utils/async_utils.py +0 -65
  266. hud/utils/group_eval.py +0 -223
  267. hud/utils/progress.py +0 -149
  268. hud/utils/tasks.py +0 -127
  269. hud/utils/tests/test_async_utils.py +0 -173
  270. hud/utils/tests/test_progress.py +0 -261
  271. hud_python-0.4.45.dist-info/METADATA +0 -552
  272. hud_python-0.4.45.dist-info/RECORD +0 -228
  273. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
  274. {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
hud/agents/gemini.py ADDED
@@ -0,0 +1,275 @@
1
+ """Gemini MCP Agent implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, ClassVar, cast
7
+
8
+ import mcp.types as types
9
+ from google import genai
10
+ from google.genai import types as genai_types
11
+ from pydantic import ConfigDict
12
+
13
+ from hud.settings import settings
14
+ from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult
15
+ from hud.utils.hud_console import HUDConsole
16
+ from hud.utils.types import with_signature
17
+
18
+ from .base import BaseCreateParams, MCPAgent
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class GeminiConfig(BaseAgentConfig):
    """Settings consumed by `GeminiAgent`.

    The sampling fields are copied onto the agent in ``GeminiAgent.__init__`` and
    forwarded to ``GenerateContentConfig`` on every request.
    """

    # Needed so the non-pydantic ``genai.Client`` can be used as a field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # NOTE(review): presumably a human-readable label; not read in this module.
    model_name: str = "Gemini"
    # Model identifier passed as ``model=`` to ``generate_content``.
    model: str = "gemini-3-pro-preview"
    # Optional pre-built client; when None the agent builds one from
    # ``settings.gemini_api_key``.
    model_client: genai.Client | None = None

    # Sampling parameters.
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192

    # When True, the agent issues a list-models request at construction time and
    # raises ``ValueError`` if it fails.
    validate_api_key: bool = True
36
+
37
+
38
class GeminiCreateParams(BaseCreateParams, GeminiConfig):
    """Union of `BaseCreateParams` and `GeminiConfig`; adds no fields of its own.

    Used as the signature template for ``GeminiAgent.create``.
    """
40
+
41
+
42
class GeminiAgent(MCPAgent):
    """Gemini agent that runs its tool calls through MCP servers.

    Gemini's native function-calling interface decides which tool to invoke;
    the invocation itself is delegated to MCP rather than implemented here.
    """

    metadata: ClassVar[dict[str, Any] | None] = None
    config_cls: ClassVar[type[BaseAgentConfig]] = GeminiConfig

    @with_signature(GeminiCreateParams)
    @classmethod
    def create(cls, **kwargs: Any) -> GeminiAgent:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Build an agent by delegating to ``MCPAgent.create`` with this class as ``cls``."""
        return MCPAgent.create.__func__(cls, **kwargs)  # type: ignore[return-value]

    def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> None:
        """Resolve the Gemini client, copy sampling settings and reset tool state.

        Raises:
            ValueError: If no client was configured and ``settings.gemini_api_key``
                is empty, or if key validation is enabled and the validation
                request fails.
        """
        super().__init__(params, **kwargs)
        self.config: GeminiConfig

        model_client = self.config.model_client
        if model_client is None:
            api_key = settings.gemini_api_key
            if not api_key:
                raise ValueError("Gemini API key not found. Set GEMINI_API_KEY.")
            model_client = genai.Client(api_key=api_key)

        if self.config.validate_api_key:
            try:
                # ``models.list`` performs the request; pull at most one item.
                # Exhausting the pager (e.g. via ``list(...)``) with page_size=1
                # would request every model page one by one.
                first_page = model_client.models.list(
                    config=genai_types.ListModelsConfig(page_size=1)
                )
                next(iter(first_page), None)
            except Exception as e:
                raise ValueError(f"Gemini API key is invalid: {e}") from e

        self.gemini_client = model_client
        self.temperature = self.config.temperature
        self.top_p = self.config.top_p
        self.top_k = self.config.top_k
        self.max_output_tokens = self.config.max_output_tokens
        self.hud_console = HUDConsole(logger=logger)

        # Gemini function name -> MCP tool name; rebuilt in _convert_tools_for_gemini.
        self._gemini_to_mcp_tool_map: dict[str, str] = {}
        self.gemini_tools: genai_types.ToolListUnion = []

    def _on_tools_ready(self) -> None:
        """Rebuild the Gemini tool list and name mapping from the available tools."""
        self._convert_tools_for_gemini()

    async def get_system_messages(self) -> list[genai_types.Content]:
        """Return no messages; the system prompt is sent as ``system_instruction``
        in `get_response` instead."""
        return []

    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_types.Content]:
        """Convert MCP content blocks into a single Gemini user message.

        Text blocks become text parts and image blocks are base64-decoded into
        byte parts. Any other block type is skipped with a warning.
        """
        import base64  # local import kept from the original, hoisted out of the loop

        gemini_parts: list[genai_types.Part] = []
        for block in blocks:
            if isinstance(block, types.TextContent):
                gemini_parts.append(genai_types.Part(text=block.text))
            elif isinstance(block, types.ImageContent):
                # MCP carries image data as a base64 string; Gemini wants raw bytes.
                gemini_parts.append(
                    genai_types.Part.from_bytes(
                        data=base64.b64decode(block.data),
                        mime_type=block.mimeType,
                    )
                )
            else:
                self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning")

        return [genai_types.Content(role="user", parts=gemini_parts)]

    async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse:
        """Query Gemini and translate the first candidate into an `AgentResponse`.

        ``messages`` is mutated: the candidate's content (including any function
        calls) is appended so that later function responses follow the call they
        answer.

        Returns:
            A response whose ``done`` flag is False only when at least one tool
            call was extracted. Thought parts are collected into ``reasoning``
            (newline-separated); other text parts are concatenated into ``content``.
        """
        generate_config = genai_types.GenerateContentConfig(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_output_tokens=self.max_output_tokens,
            tools=self.gemini_tools,
            system_instruction=self.system_prompt,
        )

        # Async client call so the event loop is not blocked.
        response = await self.gemini_client.aio.models.generate_content(
            model=self.config.model,
            contents=cast("Any", messages),
            config=generate_config,
        )

        result = AgentResponse(content="", tool_calls=[], done=True)

        if not response.candidates:
            self.hud_console.warning("Response has no candidates")
            return result

        candidate = response.candidates[0]
        if candidate.content:
            messages.append(candidate.content)

        tool_calls: list[MCPToolCall] = []
        text_chunks: list[str] = []
        thought_chunks: list[str] = []

        parts = candidate.content.parts if candidate.content else None
        for part in parts or []:
            if part.function_call:
                tool_call = self._extract_tool_call(part)
                if tool_call is not None:
                    tool_calls.append(tool_call)
            elif part.thought is True and part.text:
                thought_chunks.append(part.text)
            elif part.text:
                text_chunks.append(part.text)

        if tool_calls:
            result.tool_calls = tool_calls
            result.done = False

        result.content = "".join(text_chunks)
        if thought_chunks:
            result.reasoning = "\n".join(thought_chunks)

        return result

    def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None:
        """Turn a function-call part into an `MCPToolCall`, or None if it has no call.

        The Gemini function name is translated through the tool-name map, falling
        back to the name itself. Subclasses may override this to reshape calls
        (for example, computer-use normalisation).
        """
        call = part.function_call
        if not call:
            return None

        func_name = call.name or ""
        return MCPToolCall(
            name=self._gemini_to_mcp_tool_map.get(func_name, func_name),
            arguments=dict(call.args) if call.args else {},
        )

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[genai_types.Content]:
        """Wrap tool results as Gemini function responses in one user message.

        For each call/result pair only the first text block is used: as the
        ``error`` value for failed results (default "Tool execution failed"), or
        as ``output`` next to ``success: True`` otherwise.

        Raises:
            ValueError: If the two lists differ in length (``zip(strict=True)``).
        """
        function_responses: list[genai_types.FunctionResponse] = []

        for tool_call, result in zip(tool_calls, tool_results, strict=True):
            # Prefer the Gemini-side name when the call carries one.
            gemini_name = getattr(tool_call, "gemini_name", tool_call.name)

            first_text = next(
                (c.text for c in result.content if isinstance(c, types.TextContent)),
                None,
            )

            response_dict: dict[str, Any] = {}
            if result.isError:
                response_dict["error"] = (
                    first_text if first_text is not None else "Tool execution failed"
                )
            else:
                response_dict["success"] = True
                if first_text is not None:
                    response_dict["output"] = first_text

            function_responses.append(
                genai_types.FunctionResponse(name=gemini_name, response=response_dict)
            )

        return [
            genai_types.Content(
                role="user",
                parts=[genai_types.Part(function_response=fr) for fr in function_responses],
            )
        ]

    async def create_user_message(self, text: str) -> genai_types.Content:
        """Wrap ``text`` in a Gemini user-role message."""
        return genai_types.Content(role="user", parts=[genai_types.Part(text=text)])

    def _convert_tools_for_gemini(self) -> genai_types.ToolListUnion:
        """Rebuild ``self.gemini_tools`` and the name map from the available MCP tools.

        Tools for which `_to_gemini_tool` returns None are left out.
        """
        self._gemini_to_mcp_tool_map = {}
        self.gemini_tools = []

        for tool in self.get_available_tools():
            gemini_tool = self._to_gemini_tool(tool)
            if gemini_tool is None:
                continue

            self._gemini_to_mcp_tool_map[tool.name] = tool.name
            self.gemini_tools.append(gemini_tool)

        return self.gemini_tools

    def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool | None:
        """Convert one MCP tool into a Gemini function-declaration tool.

        Subclasses may override this to customise conversion (for example, for
        computer use).

        Raises:
            ValueError: If the tool lacks a description or an input schema.
        """
        if tool.description is None or tool.inputSchema is None:
            raise ValueError(f"MCP tool {tool.name} requires both a description and inputSchema.")

        declaration = genai_types.FunctionDeclaration(
            name=tool.name,
            description=tool.description,
            parameters_json_schema=tool.inputSchema,
        )
        return genai_types.Tool(function_declarations=[declaration])
@@ -0,0 +1,335 @@
1
+ """Gemini Computer Use Agent implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Any, ClassVar
7
+
8
+ import mcp.types as types
9
+ from google.genai import types as genai_types
10
+ from pydantic import ConfigDict, Field
11
+
12
+ from hud.tools.computer.settings import computer_settings
13
+ from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult
14
+ from hud.utils.types import with_signature
15
+
16
+ from .base import BaseCreateParams, MCPAgent
17
+ from .gemini import GeminiAgent, GeminiConfig
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
# Names of Gemini's predefined computer-use functions. Calls to any of these are
# rerouted to the computer tool by GeminiCUAAgent._extract_tool_call, and function
# responses carrying these names are the ones whose screenshots get trimmed.
PREDEFINED_COMPUTER_USE_FUNCTIONS: list[str] = [
    "open_web_browser",
    "click_at",
    "hover_at",
    "type_text_at",
    "scroll_document",
    "scroll_at",
    "wait_5_seconds",
    "go_back",
    "go_forward",
    "search",
    "navigate",
    "key_combination",
    "drag_and_drop",
]
37
+
38
# Standing instructions for the computer-use agent. GeminiCUAAgent.__init__ appends
# this text to an existing system prompt (separated by a blank line) or uses it as
# the whole system prompt when none is set. Surrounding whitespace is stripped.
GEMINI_CUA_INSTRUCTIONS = """
You are an autonomous computer-using agent. Follow these guidelines:

1. NEVER ask for confirmation. Complete all tasks autonomously.
2. Do NOT send messages like "I need to confirm before..." or "Do you want me to
continue?" - just proceed.
3. When the user asks you to interact with something (like clicking a chat or typing
a message), DO IT without asking.
4. Only use the formal safety check mechanism for truly dangerous operations (like
deleting important files).
5. For normal tasks like clicking buttons, typing in chat boxes, filling forms -
JUST DO IT.
6. The user has already given you permission by running this agent. No further
confirmation is needed.
7. Be decisive and action-oriented. Complete the requested task fully.

Remember: You are expected to complete tasks autonomously. The user trusts you to do
what they asked.
""".strip()
57
+
58
+
59
class GeminiCUAConfig(GeminiConfig):
    """Settings consumed by `GeminiCUAAgent`; extends `GeminiConfig`."""

    # Allows non-pydantic field types (restated from the parent config).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # NOTE(review): presumably a human-readable label; not read in this module.
    model_name: str = "GeminiCUA"
    # Model identifier used for computer-use requests.
    model: str = "gemini-2.5-computer-use-preview-10-2025"
    # Forwarded to ``ComputerUse(excluded_predefined_functions=...)`` when the
    # computer tool is converted; empty by default.
    excluded_predefined_functions: list[str] = Field(default_factory=list)
67
+
68
+
69
class GeminiCUACreateParams(BaseCreateParams, GeminiCUAConfig):
    """Union of `BaseCreateParams` and `GeminiCUAConfig`; adds no fields of its own.

    Used as the signature template for ``GeminiCUAAgent.create``.
    """
71
+
72
+
73
class GeminiCUAAgent(GeminiAgent):
    """Gemini agent with computer-use support layered on top of `GeminiAgent`.

    Gemini's native computer-use functions are exposed to the model, while the
    resulting actions are executed through the MCP computer tool.
    """

    metadata: ClassVar[dict[str, Any] | None] = {
        "display_width": computer_settings.GEMINI_COMPUTER_WIDTH,
        "display_height": computer_settings.GEMINI_COMPUTER_HEIGHT,
    }
    required_tools: ClassVar[list[str]] = ["gemini_computer"]
    config_cls: ClassVar[type[BaseAgentConfig]] = GeminiCUAConfig

    # Prefix marking a text block that carries the current page URL.
    _URL_PREFIX: ClassVar[str] = "__URL__:"
    # Gemini argument names copied unchanged onto the MCP call when present.
    _PASSTHROUGH_ARG_KEYS: ClassVar[tuple[str, ...]] = (
        "text",
        "press_enter",
        "clear_before_typing",
        "safety_decision",
        "direction",
        "magnitude",
        "url",
        "keys",
        "x",
        "y",
        "destination_x",
        "destination_y",
    )

    @with_signature(GeminiCUACreateParams)
    @classmethod
    def create(cls, **kwargs: Any) -> GeminiCUAAgent:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Build an agent by delegating to ``MCPAgent.create`` with this class as ``cls``."""
        return MCPAgent.create.__func__(cls, **kwargs)  # type: ignore[return-value]

    def __init__(self, params: GeminiCUACreateParams | None = None, **kwargs: Any) -> None:
        """Initialise the base agent, then add computer-use state and instructions."""
        super().__init__(params, **kwargs)  # type: ignore[arg-type]
        self.config: GeminiCUAConfig  # type: ignore[assignment]

        self._computer_tool_name = "gemini_computer"
        self.excluded_predefined_functions = list(self.config.excluded_predefined_functions)

        # Context management: number of most recent screenshot-bearing turns whose
        # images are kept. Configurable via the
        # GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS environment variable.
        self.max_recent_turn_with_screenshots = (
            computer_settings.GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS
        )

        # Append the computer-use instructions to any existing system prompt.
        self.system_prompt = (
            f"{self.system_prompt}\n\n{GEMINI_CUA_INSTRUCTIONS}"
            if self.system_prompt
            else GEMINI_CUA_INSTRUCTIONS
        )

    def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool | None:
        """Convert one MCP tool into a Gemini tool.

        The computer tool is replaced by Gemini's native ``ComputerUse`` tool
        (browser environment); every other tool uses the parent conversion.
        """
        if tool.name != self._computer_tool_name:
            return super()._to_gemini_tool(tool)

        return genai_types.Tool(
            computer_use=genai_types.ComputerUse(
                environment=genai_types.Environment.ENVIRONMENT_BROWSER,
                excluded_predefined_functions=self.excluded_predefined_functions,
            )
        )

    async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse:
        """Drop screenshots from older turns in ``messages``, then query Gemini."""
        self._remove_old_screenshots(messages)
        return await super().get_response(messages)

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[genai_types.Content]:
        """Wrap tool results as Gemini function responses in one user message.

        Computer-tool results additionally carry the page URL (taken from a
        ``__URL__:``-prefixed text block, default ``about:blank``), any image
        blocks as inline screenshot parts, and a safety acknowledgement when the
        call included a ``safety_decision`` argument. Failed results always
        carry a URL as well.

        Raises:
            ValueError: If the two lists differ in length (``zip(strict=True)``).
        """
        import base64  # local import kept from the original, hoisted out of the loops

        function_responses: list[genai_types.FunctionResponse] = []

        for tool_call, result in zip(tool_calls, tool_results, strict=True):
            # Prefer the Gemini-side name when the call carries one.
            gemini_name = getattr(tool_call, "gemini_name", tool_call.name)
            is_computer_call = tool_call.name == self._computer_tool_name

            response_dict: dict[str, Any] = {}
            # Reset per result so an earlier call's images can never leak into
            # this response and the name is always bound.
            screenshot_parts: list[genai_types.FunctionResponsePart] = []
            url: str | None = None

            if result.isError:
                error_msg = "Tool execution failed"
                for content in result.content:
                    if not isinstance(content, types.TextContent):
                        continue
                    if content.text.startswith(self._URL_PREFIX):
                        # Strip only the leading marker; a URL may itself
                        # contain the marker text.
                        url = content.text.removeprefix(self._URL_PREFIX)
                    else:
                        error_msg = content.text
                        break
                response_dict["error"] = error_msg
                # If the model calls a nonexistent computer function, the call
                # does not technically count as a computer tool call, but a URL
                # must still be returned.
                response_dict["url"] = url if url else "about:blank"
            else:
                response_dict["success"] = True

            if is_computer_call:
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        if content.text.startswith(self._URL_PREFIX):
                            url = content.text.removeprefix(self._URL_PREFIX)
                    elif isinstance(content, types.ImageContent):
                        # FunctionResponseBlob takes raw bytes, not base64 text.
                        screenshot_parts.append(
                            genai_types.FunctionResponsePart(
                                inline_data=genai_types.FunctionResponseBlob(
                                    mime_type=content.mimeType or "image/png",
                                    data=base64.b64decode(content.data),
                                )
                            )
                        )

                # The URL must always be present for the computer-use model.
                response_dict["url"] = url if url else "about:blank"

                # Acknowledge a safety decision whenever the call carried one.
                if tool_call.arguments and tool_call.arguments.get("safety_decision"):
                    response_dict["safety_acknowledgement"] = True
            else:
                # Non-computer tools: report the first text block as output.
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        response_dict["output"] = content.text
                        break

            function_responses.append(
                genai_types.FunctionResponse(
                    name=gemini_name,
                    response=response_dict,
                    parts=screenshot_parts if screenshot_parts else None,
                )
            )

        return [
            genai_types.Content(
                role="user",
                parts=[genai_types.Part(function_response=fr) for fr in function_responses],
            )
        ]

    @staticmethod
    def _as_int_pair(value: Any) -> tuple[int, int] | None:
        """Return the first two items of a list/tuple as ints, or None if not possible."""
        if not isinstance(value, list | tuple) or len(value) < 2:
            return None
        try:
            return int(value[0]), int(value[1])
        except (TypeError, ValueError):
            return None

    def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None:
        """Turn a function-call part into an `MCPToolCall`.

        Predefined computer-use functions are routed to the computer tool with
        the function name as ``action``; coordinate arrays are flattened to
        ``x``/``y`` and ``destination_x``/``destination_y``, and recognised
        argument names are copied through (overriding flattened values). All
        other functions use the parent implementation.
        """
        call = part.function_call
        if not call:
            return None

        func_name = call.name or ""
        if func_name not in PREDEFINED_COMPUTER_USE_FUNCTIONS:
            return super()._extract_tool_call(part)

        raw_args = dict(call.args) if call.args else {}
        normalized_args: dict[str, Any] = {"action": func_name}

        # Both values are converted before either is stored, so a bad second
        # element cannot leave a lone "x" behind.
        point = self._as_int_pair(raw_args.get("coordinate") or raw_args.get("coordinates"))
        if point is not None:
            normalized_args["x"], normalized_args["y"] = point

        destination = self._as_int_pair(
            raw_args.get("destination")
            or raw_args.get("destination_coordinate")
            or raw_args.get("destinationCoordinate")
        )
        if destination is not None:
            normalized_args["destination_x"], normalized_args["destination_y"] = destination

        for key in self._PASSTHROUGH_ARG_KEYS:
            if key in raw_args:
                normalized_args[key] = raw_args[key]

        return MCPToolCall(
            name=self._computer_tool_name,
            arguments=normalized_args,
            gemini_name=func_name,  # type: ignore[arg-type]
        )

    @staticmethod
    def _carries_screenshot(part: genai_types.Part) -> bool:
        """True if ``part`` is a computer-use function response with attached parts."""
        response = part.function_response
        return bool(
            response
            and response.parts
            and response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS
        )

    def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None:
        """Strip screenshots from all but the most recent screenshot-bearing turns.

        Walks ``messages`` from newest to oldest and counts user messages that
        contain a computer-use function response with attached parts. Once more
        than ``self.max_recent_turn_with_screenshots`` have been seen, the
        attached parts of older ones are cleared in place.
        """
        turns_seen = 0

        for content in reversed(messages):
            if content.role != "user" or not content.parts:
                continue

            screenshot_holders = [p for p in content.parts if self._carries_screenshot(p)]
            if not screenshot_holders:
                continue

            turns_seen += 1
            if turns_seen <= self.max_recent_turn_with_screenshots:
                continue

            for part in screenshot_holders:
                if part.function_response:
                    part.function_response.parts = None