hud-python 0.4.57__py3-none-any.whl → 0.4.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic; see the registry's advisory page for details.

hud/agents/__init__.py CHANGED
@@ -2,11 +2,13 @@ from __future__ import annotations
2
2
 
3
3
  from .base import MCPAgent
4
4
  from .claude import ClaudeAgent
5
+ from .gemini import GeminiAgent
5
6
  from .openai import OperatorAgent
6
7
  from .openai_chat_generic import GenericOpenAIChatAgent
7
8
 
8
9
  __all__ = [
9
10
  "ClaudeAgent",
11
+ "GeminiAgent",
10
12
  "GenericOpenAIChatAgent",
11
13
  "MCPAgent",
12
14
  "OperatorAgent",
hud/agents/gemini.py ADDED
@@ -0,0 +1,492 @@
1
+ """Gemini MCP Agent implementation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import TYPE_CHECKING, Any, ClassVar, cast
7
+
8
+ from google import genai
9
+ from google.genai import types as genai_types
10
+
11
+ import hud
12
+
13
+ if TYPE_CHECKING:
14
+ from hud.datasets import Task
15
+
16
+ import mcp.types as types
17
+
18
+ from hud.settings import settings
19
+ from hud.tools.computer.settings import computer_settings
20
+ from hud.types import AgentResponse, MCPToolCall, MCPToolResult
21
+ from hud.utils.hud_console import HUDConsole
22
+
23
+ from .base import MCPAgent
24
+
25
logger = logging.getLogger(__name__)

# Function names that Gemini's built-in Computer Use tool can emit. Every name
# that is not explicitly excluded gets mapped onto a single MCP computer tool.
# The order is kept stable on purpose.
PREDEFINED_COMPUTER_USE_FUNCTIONS = list(
    (
        "open_web_browser",
        "click_at",
        "hover_at",
        "type_text_at",
        "scroll_document",
        "scroll_at",
        "wait_5_seconds",
        "go_back",
        "go_forward",
        "search",
        "navigate",
        "key_combination",
        "drag_and_drop",
    )
)
43
+
44
+
45
class GeminiAgent(MCPAgent):
    """
    Gemini agent that uses MCP servers for tool execution.

    This agent uses Gemini's native computer use capabilities but executes
    tools through MCP servers instead of direct implementation.

    Calls to Gemini's predefined Computer Use functions
    (``PREDEFINED_COMPUTER_USE_FUNCTIONS``) are routed to one selected MCP
    computer tool. Every other MCP tool is exposed to Gemini as a custom
    function declaration.
    """

    metadata: ClassVar[dict[str, Any]] = {
        "display_width": computer_settings.GEMINI_COMPUTER_WIDTH,
        "display_height": computer_settings.GEMINI_COMPUTER_HEIGHT,
    }

    # MCP tool names recognized as "the computer tool", in priority order.
    # A tool matches on an exact name or on a "_<name>" suffix, which covers
    # server-prefixed tools.
    _COMPUTER_TOOL_PRIORITY: ClassVar[tuple[str, ...]] = (
        "gemini_computer",
        "computer_gemini",
        "computer",
    )

    # Computer Use argument keys copied verbatim into the MCP tool arguments.
    # Direct coordinates are included here so they override any values
    # derived from coordinate arrays.
    _PASSTHROUGH_COMPUTER_ARGS: ClassVar[tuple[str, ...]] = (
        "text",
        "press_enter",
        "clear_before_typing",
        "safety_decision",
        "direction",
        "magnitude",
        "url",
        "keys",
        "x",
        "y",
        "destination_x",
        "destination_y",
    )

    # TextContent blocks that start with this marker carry the current page
    # URL. They are not user-visible text.
    _URL_PREFIX: ClassVar[str] = "__URL__:"

    def __init__(
        self,
        model_client: genai.Client | None = None,
        model: str = "gemini-2.5-computer-use-preview-10-2025",
        temperature: float = 1.0,
        top_p: float = 0.95,
        top_k: int = 40,
        max_output_tokens: int = 8192,
        validate_api_key: bool = True,
        excluded_predefined_functions: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize Gemini MCP agent.

        Args:
            model_client: Gemini client. If not provided, one is created from
                ``settings.gemini_api_key``.
            model: Gemini model to use.
            temperature: Temperature for response generation.
            top_p: Top-p sampling parameter.
            top_k: Top-k sampling parameter.
            max_output_tokens: Maximum tokens for response.
            validate_api_key: If True, list one model to verify the client
                works. This makes a network call.
            excluded_predefined_functions: Predefined Computer Use functions
                to exclude.
            **kwargs: Additional arguments passed to MCPAgent (including
                mcp_client).

        Raises:
            ValueError: If no API key is configured, or validation fails.
        """
        super().__init__(**kwargs)

        if model_client is None:
            api_key = settings.gemini_api_key
            if not api_key:
                raise ValueError("Gemini API key not found. Set GEMINI_API_KEY.")
            model_client = genai.Client(api_key=api_key)

        if validate_api_key:
            try:
                # Cheap round-trip: list a single model.
                list(model_client.models.list(config=genai_types.ListModelsConfig(page_size=1)))
            except Exception as e:
                raise ValueError(f"Gemini API key is invalid: {e}") from e

        self.gemini_client = model_client
        self.model = model
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.max_output_tokens = max_output_tokens
        self.excluded_predefined_functions = excluded_predefined_functions or []
        self.hud_console = HUDConsole(logger=logger)

        # Maximum number of recent user turns that keep their screenshots.
        # Configured via GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS.
        self.max_recent_turn_with_screenshots = (
            computer_settings.GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS
        )

        self.model_name = self.model

        # Gemini function name -> MCP tool name. Rebuilt by
        # _convert_tools_for_gemini.
        self._gemini_to_mcp_tool_map: dict[str, str] = {}
        self.gemini_tools: list[genai_types.Tool] = []

        gemini_instructions = "\n".join(
            [
                "You are Gemini, a helpful AI assistant created by Google.",
                "You can interact with computer interfaces.",
                "",
                "When working on tasks:",
                "1. Be thorough and systematic in your approach",
                "2. Complete tasks autonomously without asking for confirmation",
                "3. Use available tools efficiently to accomplish your goals",
                "4. Verify your actions and ensure task completion",
                "5. Be precise and accurate in all operations",
                "6. Adapt to the environment and the task at hand",
                "",
                "Remember: You are expected to complete tasks autonomously.",
                "The user trusts you to accomplish what they asked.",
            ]
        )

        # Append Gemini instructions to any base system prompt.
        if self.system_prompt:
            self.system_prompt = f"{self.system_prompt}\n\n{gemini_instructions}"
        else:
            self.system_prompt = gemini_instructions

    async def initialize(self, task: str | Task | None = None) -> None:
        """Initialize the agent and build tool mappings."""
        await super().initialize(task)
        # Tools are only known after the base initialization discovers them.
        self._convert_tools_for_gemini()

    async def get_system_messages(self) -> list[Any]:
        """No system messages for Gemini; the prompt is applied in get_response."""
        return []

    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_types.Content]:
        """Convert MCP content blocks into a single Gemini user message.

        Text blocks become text parts. Image blocks (base64 strings) are
        decoded into inline byte parts. Any other block type is logged and
        skipped.
        """
        import base64

        gemini_parts: list[genai_types.Part] = []
        for block in blocks:
            if isinstance(block, types.TextContent):
                gemini_parts.append(genai_types.Part(text=block.text))
            elif isinstance(block, types.ImageContent):
                image_bytes = base64.b64decode(block.data)
                gemini_parts.append(
                    genai_types.Part.from_bytes(data=image_bytes, mime_type=block.mimeType)
                )
            else:
                self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning")

        return [genai_types.Content(role="user", parts=gemini_parts)]

    @hud.instrument(
        span_type="agent",
        record_args=False,  # Messages can be large
        record_result=True,
    )
    async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse:
        """Get a response from Gemini, including any tool calls.

        ``messages`` is mutated in place:
        - screenshots from older turns are stripped; and
        - the model's reply content, if any, is appended, so that later
          FunctionResponse parts follow their FunctionCall.

        Returns:
            An AgentResponse. ``done`` is False only when tool calls were
            produced. ``content`` is any thought summaries followed by the
            reply text.
        """
        generate_config = genai_types.GenerateContentConfig(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_output_tokens=self.max_output_tokens,
            tools=cast("Any", self.gemini_tools),
            system_instruction=self.system_prompt,
        )

        # Trim screenshots from older turns to manage context growth.
        self._remove_old_screenshots(messages)

        # Use the async client so the event loop is not blocked while waiting
        # on the model.
        response = await self.gemini_client.aio.models.generate_content(
            model=self.model,
            contents=cast("Any", messages),
            config=generate_config,
        )

        result = AgentResponse(content="", tool_calls=[], done=True)

        if not response.candidates:
            self.hud_console.warning("Response has no candidates")
            return result

        candidate = response.candidates[0]
        if candidate.content:
            messages.append(candidate.content)

        parts = candidate.content.parts if candidate.content and candidate.content.parts else []

        text_chunks: list[str] = []
        thinking_chunks: list[str] = []
        collected_tool_calls: list[MCPToolCall] = []

        for part in parts:
            if part.function_call:
                collected_tool_calls.append(self._to_mcp_tool_call(part.function_call))
            elif getattr(part, "thought", None) and part.text:
                # ``thought`` is a boolean flag; the thought summary lives in ``text``.
                thinking_chunks.append(f"Thinking: {part.text}\n")
            elif part.text:
                text_chunks.append(part.text)

        if collected_tool_calls:
            result.tool_calls = collected_tool_calls
            result.done = False

        result.content = "".join(thinking_chunks) + "".join(text_chunks)
        return result

    def _to_mcp_tool_call(self, function_call: genai_types.FunctionCall) -> MCPToolCall:
        """Convert a Gemini FunctionCall into an MCPToolCall for the mapped MCP tool.

        The original Gemini function name is kept as ``gemini_name`` so that
        format_tool_results can answer under the name Gemini used.
        """
        func_name = function_call.name or ""
        mcp_tool_name = self._gemini_to_mcp_tool_map.get(func_name, func_name)
        raw_args = dict(function_call.args) if function_call.args else {}

        if func_name in PREDEFINED_COMPUTER_USE_FUNCTIONS:
            final_args = self._normalize_computer_use_args(func_name, raw_args)
        else:
            # Non-computer tools: pass args as-is.
            final_args = raw_args

        return MCPToolCall(
            name=mcp_tool_name,
            arguments=final_args,
            gemini_name=func_name,  # type: ignore[arg-type]
        )

    @staticmethod
    def _coerce_xy(value: Any) -> tuple[int, int] | None:
        """Return the first two items of a list/tuple as ints.

        Returns None when the value is not a list/tuple of at least two
        items, or when either item cannot be cast to int.
        """
        if isinstance(value, (list, tuple)) and len(value) >= 2:
            try:
                return int(value[0]), int(value[1])
            except (TypeError, ValueError):
                return None
        return None

    @classmethod
    def _normalize_computer_use_args(
        cls, action: str, raw_args: dict[str, Any]
    ) -> dict[str, Any]:
        """Map a Gemini Computer Use call's arguments onto the MCP computer tool schema.

        The result always has ``action`` set to the Gemini function name.
        Coordinate arrays become ``x``/``y``, and destination arrays become
        ``destination_x``/``destination_y``. Supported keys in
        ``_PASSTHROUGH_COMPUTER_ARGS`` are copied last, so explicit values win.
        """
        normalized: dict[str, Any] = {"action": action}

        coord = cls._coerce_xy(raw_args.get("coordinate") or raw_args.get("coordinates"))
        if coord is not None:
            normalized["x"], normalized["y"] = coord

        dest = cls._coerce_xy(
            raw_args.get("destination")
            or raw_args.get("destination_coordinate")
            or raw_args.get("destinationCoordinate")
        )
        if dest is not None:
            normalized["destination_x"], normalized["destination_y"] = dest

        for key in cls._PASSTHROUGH_COMPUTER_ARGS:
            if key in raw_args:
                normalized[key] = raw_args[key]

        return normalized

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[genai_types.Content]:
        """Format tool results into a single Gemini user message of FunctionResponses.

        Each response carries:
        - ``error`` (for error results) or ``success`` (otherwise);
        - a ``url``, taken from a ``__URL__:`` text block and defaulting to
          ``about:blank``;
        - ``safety_acknowledgement`` for Computer Use calls that carried a
          ``safety_decision``.

        Image content from successful results is attached as inline
        screenshot parts.
        """
        import base64

        function_responses: list[genai_types.FunctionResponse] = []

        for tool_call, result in zip(tool_calls, tool_results, strict=True):
            # Answer under the name Gemini used, falling back to the MCP name.
            gemini_name = getattr(tool_call, "gemini_name", None) or tool_call.name

            response_dict: dict[str, Any] = {}
            url: str | None = None
            # Reset for every result. Error results carry no screenshots, and a
            # previous result's screenshots must never leak into this response.
            screenshot_parts: list[genai_types.FunctionResponsePart] = []

            if result.isError:
                # The first non-URL text block is the error message. Keep
                # scanning so that a URL block appearing later is still picked up.
                error_msg: str | None = None
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        if content.text.startswith(self._URL_PREFIX):
                            url = content.text.removeprefix(self._URL_PREFIX)
                        elif error_msg is None:
                            error_msg = content.text
                response_dict["error"] = error_msg or "Tool execution failed"
            else:
                response_dict["success"] = True
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        if content.text.startswith(self._URL_PREFIX):
                            url = content.text.removeprefix(self._URL_PREFIX)
                    elif isinstance(content, types.ImageContent):
                        # MCP images are base64 strings; Gemini blobs take raw bytes.
                        screenshot_parts.append(
                            genai_types.FunctionResponsePart(
                                inline_data=genai_types.FunctionResponseBlob(
                                    mime_type=content.mimeType or "image/png",
                                    data=base64.b64decode(content.data),
                                )
                            )
                        )

            # The Gemini Computer Use model requires a URL on every response.
            response_dict["url"] = url or "about:blank"

            # Acknowledge safety decisions on Computer Use actions.
            requires_ack = bool(
                tool_call.arguments and tool_call.arguments.get("safety_decision")
            )
            if requires_ack and gemini_name in PREDEFINED_COMPUTER_USE_FUNCTIONS:
                response_dict["safety_acknowledgement"] = True

            function_responses.append(
                genai_types.FunctionResponse(
                    name=gemini_name,
                    response=response_dict,
                    parts=screenshot_parts or None,
                )
            )

        return [
            genai_types.Content(
                role="user",
                parts=[genai_types.Part(function_response=fr) for fr in function_responses],
            )
        ]

    async def create_user_message(self, text: str) -> genai_types.Content:
        """Create a user message in Gemini's format."""
        return genai_types.Content(role="user", parts=[genai_types.Part(text=text)])

    @staticmethod
    def _matches_tool_name(tool_name: str, base_name: str) -> bool:
        """True if ``tool_name`` is ``base_name`` or ends with ``_<base_name>``."""
        return tool_name == base_name or tool_name.endswith(f"_{base_name}")

    def _convert_tools_for_gemini(self) -> list[genai_types.Tool]:
        """Convert MCP tools to Gemini tool format.

        Steps:
        1. Select the first MCP tool matching ``_COMPUTER_TOOL_PRIORITY``.
        2. If found, expose Gemini's native Computer Use tool and map every
           non-excluded predefined function to the selected tool.
        3. Expose every tool that does not match a computer-tool name as a
           custom function declaration.

        Rebuilds ``self._gemini_to_mcp_tool_map`` and stores the resulting
        list in ``self.gemini_tools``, which is also returned.
        """
        gemini_tools: list[genai_types.Tool] = []
        self._gemini_to_mcp_tool_map = {}
        available_tools = list(self.get_available_tools())

        # Priority order first, then tool order, as in a nested search.
        selected_computer_tool = next(
            (
                tool
                for priority_name in self._COMPUTER_TOOL_PRIORITY
                for tool in available_tools
                if self._matches_tool_name(tool.name, priority_name)
            ),
            None,
        )

        if selected_computer_tool is not None:
            gemini_tools.append(
                genai_types.Tool(
                    computer_use=genai_types.ComputerUse(
                        environment=genai_types.Environment.ENVIRONMENT_BROWSER,
                        excluded_predefined_functions=self.excluded_predefined_functions,
                    )
                )
            )
            for func_name in PREDEFINED_COMPUTER_USE_FUNCTIONS:
                if func_name not in self.excluded_predefined_functions:
                    self._gemini_to_mcp_tool_map[func_name] = selected_computer_tool.name
            self.hud_console.debug(
                f"Using {selected_computer_tool.name} as computer tool for Gemini"
            )

        for tool in available_tools:
            # Computer tools are handled via native Computer Use (or deliberately unused).
            if any(
                self._matches_tool_name(tool.name, priority_name)
                for priority_name in self._COMPUTER_TOOL_PRIORITY
            ):
                continue

            params = tool.inputSchema or {"type": "object", "properties": {}}
            try:
                function_decl = genai_types.FunctionDeclaration(
                    name=tool.name,
                    description=tool.description or f"Execute {tool.name}",
                    parameters=genai_types.Schema(**params) if isinstance(params, dict) else params,  # type: ignore
                )
            except Exception as e:
                # Best-effort: skip tools whose schema Gemini cannot represent,
                # but surface why.
                self.hud_console.warning(
                    f"Failed to convert tool {tool.name} to Gemini format: {e}"
                )
                continue

            gemini_tools.append(genai_types.Tool(function_declarations=[function_decl]))
            self._gemini_to_mcp_tool_map[tool.name] = tool.name

        self.gemini_tools = gemini_tools
        return gemini_tools

    @staticmethod
    def _has_computer_use_screenshot(part: genai_types.Part) -> bool:
        """True if ``part`` is a Computer Use FunctionResponse with attached parts."""
        fr = part.function_response
        return bool(fr and fr.parts and fr.name in PREDEFINED_COMPUTER_USE_FUNCTIONS)

    def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None:
        """
        Remove screenshots from old turns to manage context length.

        Walks ``messages`` newest-first and counts user turns that carry
        Computer Use screenshots. For every such turn beyond
        ``self.max_recent_turn_with_screenshots``, the FunctionResponse
        ``parts`` are cleared in place.
        """
        turns_with_screenshots = 0

        for content in reversed(messages):
            if content.role != "user" or not content.parts:
                continue

            screenshot_parts = [p for p in content.parts if self._has_computer_use_screenshot(p)]
            if not screenshot_parts:
                continue

            turns_with_screenshots += 1
            if turns_with_screenshots > self.max_recent_turn_with_screenshots:
                for part in screenshot_parts:
                    cast("genai_types.FunctionResponse", part.function_response).parts = None