cua-agent 0.3.1__py3-none-any.whl → 0.4.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +15 -51
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +216 -0
- agent/agent.py +577 -0
- agent/callbacks/__init__.py +17 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +290 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +80 -1299
- agent/ui/gradio/ui_components.py +703 -0
- cua_agent-0.4.0b1.dist-info/METADATA +424 -0
- cua_agent-0.4.0b1.dist-info/RECORD +30 -0
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0b1.dist-info}/WHEEL +1 -1
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -367
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- agent/telemetry.py +0 -21
- agent/ui/__main__.py +0 -15
- cua_agent-0.3.1.dist-info/METADATA +0 -295
- cua_agent-0.3.1.dist-info/RECORD +0 -87
- {cua_agent-0.3.1.dist-info → cua_agent-0.4.0b1.dist-info}/entry_points.txt +0 -0
agent/providers/omni/loop.py
DELETED
|
@@ -1,990 +0,0 @@
|
|
|
1
|
-
"""Omni-specific agent loop implementation."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator
|
|
5
|
-
import json
|
|
6
|
-
import re
|
|
7
|
-
import os
|
|
8
|
-
import asyncio
|
|
9
|
-
from httpx import ConnectError, ReadTimeout
|
|
10
|
-
from typing import cast
|
|
11
|
-
|
|
12
|
-
from .parser import OmniParser, ParseResult
|
|
13
|
-
from ...core.base import BaseLoop
|
|
14
|
-
from ...core.visualization import VisualizationHelper
|
|
15
|
-
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
|
16
|
-
from .utils import to_openai_agent_response_format
|
|
17
|
-
from ...core.types import AgentResponse
|
|
18
|
-
from computer import Computer
|
|
19
|
-
from ...core.types import LLMProvider
|
|
20
|
-
from .clients.openai import OpenAIClient
|
|
21
|
-
from .clients.anthropic import AnthropicClient
|
|
22
|
-
from .clients.ollama import OllamaClient
|
|
23
|
-
from .clients.oaicompat import OAICompatClient
|
|
24
|
-
from .prompts import SYSTEM_PROMPT
|
|
25
|
-
from .api_handler import OmniAPIHandler
|
|
26
|
-
from .tools.manager import ToolManager
|
|
27
|
-
from .tools import ToolResult
|
|
28
|
-
|
|
29
|
-
logger = logging.getLogger(__name__)
|
|
30
|
-
|
|
31
|
-
def extract_data(input_string: str, data_type: str) -> str:
    """Extract the body of the first fenced code block tagged ``data_type``.

    Args:
        input_string: Text that may contain a fenced block opened by three
            backticks immediately followed by ``data_type``.
        data_type: Fence tag to look for (for example ``"json"``). The tag is
            matched literally, so regex metacharacters in it are safe.

    Returns:
        The stripped text between the opening fence and the closing fence
        (or the end of the string when the fence is never closed). When no
        fence with that tag is present, ``input_string`` is returned as is.
    """
    # Escape the tag: an unescaped "c++" or "a.b" would otherwise be read as
    # regex syntax and either fail to compile or match the wrong text.
    pattern = "```" + re.escape(data_type) + r"(.*?)(?:```|$)"
    match = re.search(pattern, input_string, re.DOTALL)
    return match.group(1).strip() if match else input_string
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
class OmniLoop(BaseLoop):
|
|
39
|
-
"""Omni-specific implementation of the agent loop.
|
|
40
|
-
|
|
41
|
-
This class extends BaseLoop to provide support for multimodal models
|
|
42
|
-
from various providers (OpenAI, Anthropic, etc.) with UI parsing
|
|
43
|
-
and desktop automation capabilities.
|
|
44
|
-
"""
|
|
45
|
-
|
|
46
|
-
###########################################
|
|
47
|
-
# INITIALIZATION AND CONFIGURATION
|
|
48
|
-
###########################################
|
|
49
|
-
|
|
50
|
-
def __init__(
    self,
    parser: OmniParser,
    provider: LLMProvider,
    api_key: str,
    model: str,
    computer: Computer,
    only_n_most_recent_images: Optional[int] = 2,
    base_dir: Optional[str] = "trajectories",
    max_retries: int = 3,
    retry_delay: float = 1.0,
    save_trajectory: bool = True,
    provider_base_url: Optional[str] = None,
    **kwargs,
):
    """Set up the Omni loop and its collaborators.

    Args:
        parser: Parser instance stored on the loop for screen parsing.
        provider: API provider selecting which client is created later.
        api_key: API key forwarded to the base class.
        model: Model name forwarded to the base class.
        computer: Computer instance shared with the base class and the
            tool manager.
        only_n_most_recent_images: Number of images the message manager
            is configured to keep; also forwarded to the base class.
        base_dir: Base directory for saving experiment data.
        max_retries: Maximum number of retries for API calls.
        retry_delay: Delay between retries in seconds.
        save_trajectory: Whether to save trajectory data.
        provider_base_url: Base URL for the API provider (used for OAICOMPAT).
        **kwargs: Extra options passed through to the base class.
    """
    # Parser/provider settings are stored before the base constructor runs.
    self.parser = parser
    self.provider = provider
    self.provider_base_url = provider_base_url

    # Message history with an image retention policy.
    retention_config = ImageRetentionConfig(
        num_images_to_keep=only_n_most_recent_images
    )
    self.message_manager = StandardMessageManager(config=retention_config)

    # Base class initialisation.
    super().__init__(
        computer=computer,
        model=model,
        api_key=api_key,
        max_retries=max_retries,
        retry_delay=retry_delay,
        base_dir=base_dir,
        save_trajectory=save_trajectory,
        only_n_most_recent_images=only_n_most_recent_images,
        **kwargs,
    )

    # The provider client is created later; the loop task is kept so it can
    # be cancelled.
    self.client = None
    self.retry_count = 0
    self.loop_task = None

    # Handlers and tool manager.
    self.api_handler = OmniAPIHandler(loop=self)
    self.viz_helper = VisualizationHelper(agent=self)
    self.tool_manager = ToolManager(computer=computer, provider=provider)

    logger.info("OmniLoop initialized with StandardMessageManager")
|
|
116
|
-
|
|
117
|
-
async def initialize(self) -> None:
    """Initialize the base loop, the tool manager and the provider client.

    A failure while initializing the tool manager is logged and tolerated.
    An unknown provider is not.

    Raises:
        ValueError: If ``self.provider`` is not one of the supported providers.
    """
    await super().initialize()

    # Tool manager start-up is best effort.
    try:
        logger.info("Initializing tool manager...")
        await self.tool_manager.initialize()
        logger.info("Tool manager initialized successfully.")
    except Exception as e:
        logger.error(f"Error initializing tool manager: {str(e)}")
        logger.warning("Will attempt to initialize tools on first use.")

    # Pick the client that matches the configured provider.
    selected = self.provider

    if selected == LLMProvider.ANTHROPIC:
        self.client = AnthropicClient(api_key=self.api_key, model=self.model)
        return

    if selected == LLMProvider.OPENAI:
        self.client = OpenAIClient(api_key=self.api_key, model=self.model)
        return

    if selected == LLMProvider.OLLAMA:
        self.client = OllamaClient(api_key=self.api_key, model=self.model)
        return

    if selected == LLMProvider.OAICOMPAT:
        # Local endpoints typically don't require an API key
        effective_key = self.api_key or "EMPTY"
        self.client = OAICompatClient(
            api_key=effective_key,
            model=self.model,
            provider_base_url=self.provider_base_url,
        )
        return

    raise ValueError(f"Unsupported provider: {self.provider}")
|
|
155
|
-
|
|
156
|
-
###########################################
|
|
157
|
-
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
|
158
|
-
###########################################
|
|
159
|
-
|
|
160
|
-
async def initialize_client(self) -> None:
    """Create the API client matching ``self.provider``.

    On success ``self.client`` holds the new client. On any failure
    (including an unsupported provider) ``self.client`` is reset to ``None``.

    Raises:
        RuntimeError: If the client could not be created. The original
            exception is attached as ``__cause__``.
    """
    try:
        logger.info(f"Initializing {self.provider} client with model {self.model}...")

        if self.provider == LLMProvider.OPENAI:
            self.client = OpenAIClient(api_key=self.api_key, model=self.model)
        elif self.provider == LLMProvider.ANTHROPIC:
            self.client = AnthropicClient(
                api_key=self.api_key,
                model=self.model,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
            )
        elif self.provider == LLMProvider.OLLAMA:
            self.client = OllamaClient(
                api_key=self.api_key,
                model=self.model,
            )
        elif self.provider == LLMProvider.OAICOMPAT:
            self.client = OAICompatClient(
                # Local endpoints typically don't require an API key
                api_key=self.api_key or "EMPTY",
                model=self.model,
                provider_base_url=self.provider_base_url,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        logger.info(f"Initialized {self.provider} client with model {self.model}")
    except Exception as e:
        logger.error(f"Error initializing client: {str(e)}")
        # Leave the loop in a well-defined "no client" state so callers that
        # check ``self.client is None`` retry initialization.
        self.client = None
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise RuntimeError(f"Failed to initialize client: {str(e)}") from e
|
|
197
|
-
|
|
198
|
-
###########################################
|
|
199
|
-
# API CALL HANDLING
|
|
200
|
-
###########################################
|
|
201
|
-
|
|
202
|
-
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
    """Send ``messages`` to the provider, retrying on failure.

    Args:
        messages: Conversation messages; a shallow copy is loaded into the
            message manager before every attempt.
        system_prompt: System prompt to send. For Anthropic, a system prompt
            produced by the message conversion takes precedence when present.

    Returns:
        The provider response returned by ``self.client.run_interleaved``.

    Raises:
        RuntimeError: When every one of ``self.max_retries`` attempts failed.
            The last underlying error is attached as ``__cause__``.
    """
    # Create new turn directory for this API call
    self._create_turn_dir()

    request_data = None
    last_error = None

    for attempt in range(self.max_retries):
        try:
            # Ensure client is initialized
            if self.client is None:
                logger.info(
                    f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
                )
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")

            # Get messages in standard format from the message manager
            self.message_manager.messages = messages.copy()
            call_messages = self.message_manager.get_messages()
            call_system = system_prompt

            if self.provider == LLMProvider.ANTHROPIC:
                # Anthropic needs its own message layout and only accepts
                # user/assistant roles in the message list.
                anthropic_messages, anthropic_system = self.message_manager.to_anthropic_format(
                    call_messages
                )
                call_messages = [
                    msg
                    for msg in anthropic_messages
                    if msg.get("role") in ("user", "assistant")
                ]

                # Ensure there's at least one message for Anthropic
                if not call_messages:
                    logger.warning(
                        "No valid messages found for Anthropic API call. Adding a default user message."
                    )
                    call_messages = [
                        {
                            "role": "user",
                            "content": [
                                {"type": "text", "text": "Please help with this task."}
                            ],
                        }
                    ]

                # Prefer the converted system prompt, fall back to the argument
                call_system = anthropic_system or system_prompt

            # The same request/log/call sequence serves every provider.
            request_data = {
                "messages": call_messages,
                "max_tokens": self.max_tokens,
                "system": call_system,
            }
            self._log_api_call("request", request_data)

            response = await self.client.run_interleaved(
                messages=call_messages,
                system=call_system,
                max_tokens=self.max_tokens,
            )

            # Log success response
            self._log_api_call("response", request_data, response)
            return response

        except (ConnectError, ReadTimeout) as e:
            last_error = e
            logger.warning(
                f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Linear backoff: the wait grows by retry_delay per attempt
                await asyncio.sleep(self.retry_delay * (attempt + 1))
                # Reset client on connection errors to force re-initialization
                self.client = None

        except RuntimeError as e:
            # Handle client initialization errors specifically
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(
                f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
            )
            if attempt < self.max_retries - 1:
                # Reset client to force re-initialization
                self.client = None
                await asyncio.sleep(self.retry_delay)

        except Exception as e:
            # Log unexpected error
            last_error = e
            self._log_api_call("error", request_data, error=e)
            logger.error(f"Unexpected error in API call: {str(e)}")
            if attempt < self.max_retries - 1:
                await asyncio.sleep(self.retry_delay)

    # If we get here, all retries failed
    error_message = f"API call failed after {self.max_retries} attempts"
    if last_error:
        error_message += f": {str(last_error)}"

    logger.error(error_message)
    # Chain to the last failure so its traceback is not lost.
    raise RuntimeError(error_message) from last_error
|
|
334
|
-
|
|
335
|
-
###########################################
|
|
336
|
-
# RESPONSE AND ACTION HANDLING
|
|
337
|
-
###########################################
|
|
338
|
-
|
|
339
|
-
async def _handle_response(
    self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
) -> Tuple[bool, bool]:
    """Handle API response.

    The response is normalized to a list of content blocks, appended to the
    message history, searched for a JSON action description and, when one is
    found, executed through ``_execute_action_with_tools``.

    Args:
        response: API response
        messages: List of messages to update (the history itself is kept in
            ``self.message_manager``)
        parsed_screen: Current parsed screen information

    Returns:
        Tuple of (should_continue, action_screenshot_saved). A response that
        is malformed or carries no usable JSON object yields
        ``(True, False)``; a failed action or an ``"Action": "None"`` reply
        yields ``should_continue == False``.

    Raises:
        Exception: Any unexpected error is logged, recorded as an assistant
            message and re-raised.
    """
    action_screenshot_saved = False

    try:
        # Step 1: Normalize response to standard format based on provider
        normalized = self._normalize_provider_response(response)
        if normalized is None:
            return True, action_screenshot_saved
        standard_content, raw_text = normalized

        # Step 2: Add the normalized response to message history
        self.message_manager.add_assistant_message(standard_content)
        logger.info("Added structured assistant message")

        # Step 3: Extract JSON from the content for action execution
        parsed_content = self._extract_action_json(raw_text) if raw_text else None
        if not parsed_content:
            return True, action_screenshot_saved

        # Valid JSON that is not an object (list, number, bool, ...) cannot
        # describe an action; treat it like any other unparseable reply
        # instead of failing on the dict operations below.
        if not isinstance(parsed_content, dict):
            logger.error(
                f"Parsed JSON is not an object (got {type(parsed_content).__name__}); ignoring"
            )
            return True, action_screenshot_saved

        # Step 4: Process the parsed content
        # Clean up Box ID format
        if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str):
            parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "")

        # Add any explanatory text as reasoning if not present
        if "Explanation" not in parsed_content and raw_text:
            # Extract any text before the JSON as reasoning
            text_before_json = raw_text.split("{")[0].strip()
            if text_before_json:
                parsed_content["Explanation"] = text_before_json

        # Log the parsed content for debugging
        logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")

        # Step 5: Execute the action
        try:
            should_continue, action_screenshot_saved = (
                await self._execute_action_with_tools(
                    parsed_content, cast(ParseResult, parsed_screen)
                )
            )

            # Check if task is complete
            if parsed_content.get("Action") == "None":
                return False, action_screenshot_saved
            return should_continue, action_screenshot_saved
        except Exception as e:
            logger.error(f"Error executing action: {str(e)}")
            # Append an assistant message describing the failure
            error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
            self.message_manager.add_assistant_message(error_message)
            return False, action_screenshot_saved

    except Exception as e:
        logger.error(f"Error handling response: {str(e)}")
        # Add error message using the message manager
        error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
        self.message_manager.add_assistant_message(error_message)
        raise


def _normalize_provider_response(self, response: Any):
    """Convert a provider response into ``(standard_content, raw_text)``.

    Returns ``None`` (after logging) when the response does not have the
    layout expected for ``self.provider``.
    """
    if self.provider == LLMProvider.ANTHROPIC:
        if not (hasattr(response, "content") and isinstance(response.content, list)):
            logger.warning("Invalid Anthropic response format")
            return None

        standard_content = []
        raw_text = None
        for block in response.content:
            if not hasattr(block, "type"):
                continue
            if block.type == "text":
                standard_content.append({"type": "text", "text": block.text})
                # Accumulate the text blocks for JSON parsing
                if raw_text is None:
                    raw_text = block.text
                else:
                    raw_text += "\n" + block.text
            else:
                # Keep other block types as plain dicts of public attributes
                standard_content.append(
                    {
                        key: value
                        for key, value in vars(block).items()
                        if not key.startswith("_")
                    }
                )
        return standard_content, raw_text

    try:
        if self.provider == LLMProvider.OLLAMA:
            raw_text = response["message"]["content"]
        else:
            # OpenAI and OpenAI-compatible response format
            raw_text = response["choices"][0]["message"]["content"]
    except (KeyError, TypeError, IndexError) as e:
        logger.error(f"Invalid response format: {str(e)}")
        return None
    return [{"type": "text", "text": raw_text}], raw_text


def _extract_action_json(self, raw_text: str):
    """Try progressively looser strategies to get JSON out of ``raw_text``.

    Returns the decoded value, or ``None`` (after logging) when nothing in
    the text could be parsed.
    """
    # 1. The whole reply is JSON.
    try:
        parsed = json.loads(raw_text)
        logger.info("Successfully parsed whole content as JSON")
        return parsed
    except json.JSONDecodeError:
        pass

    # 2. A fenced json code block.
    try:
        parsed = json.loads(extract_data(raw_text, "json"))
        logger.info("Successfully parsed JSON from code block")
        return parsed
    except (json.JSONDecodeError, IndexError):
        pass

    # 3. The first flat {...} snippet in the text.
    json_match = re.search(r"\{[^}]+\}", raw_text)
    if not json_match:
        logger.error("No JSON found in content")
        return None

    json_str = json_match.group(0)
    try:
        parsed = json.loads(json_str)
        logger.info("Successfully parsed JSON from text")
        return parsed
    except json.JSONDecodeError as e:
        first_error = e

    # 4. Strip control characters and retry: first the snippet that just
    #    failed, then the complete text.
    for candidate in (json_str, raw_text):
        sanitized_text = re.sub(r"[\x00-\x1F\x7F]", "", candidate)
        try:
            parsed = json.loads(sanitized_text)
        except json.JSONDecodeError:
            continue
        logger.info("Successfully parsed JSON after sanitizing control characters")
        return parsed

    logger.error(f"Failed to parse JSON from text: {str(first_error)}")
    return None
|
|
517
|
-
|
|
518
|
-
###########################################
|
|
519
|
-
# SCREEN PARSING - IMPLEMENTING ABSTRACT METHOD
|
|
520
|
-
###########################################
|
|
521
|
-
|
|
522
|
-
async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
    """Parse the current screen with the configured parser.

    Args:
        save_screenshot: When true (and trajectory saving is enabled and the
            parse result carries an annotated image), the image is passed to
            ``handle_screenshot`` and ``_save_screenshot`` with the
            ``"state"`` action type.

    Returns:
        The ParseResult produced by ``self.parser.parse_screen``.

    Raises:
        Exception: Errors from parsing are logged and re-raised. Errors
            while saving the screenshot are only logged.
    """
    try:
        # The parser takes the screenshot itself.
        screen = await self.parser.parse_screen(computer=self.computer)

        element_count = len(screen.elements) if screen.elements else 0
        logger.info(f"Parsed screen with {element_count} elements")

        should_store = (
            save_screenshot and self.save_trajectory and screen.annotated_image_base64
        )
        if should_store:
            try:
                # Drop a "data:image/png;base64," style prefix if present.
                encoded_image = screen.annotated_image_base64
                if "," in encoded_image:
                    encoded_image = encoded_image.split(",")[1]

                # Run the screenshot hooks, then store the image as the
                # current screen state.
                await self.handle_screenshot(
                    encoded_image, action_type="state", parsed_screen=screen
                )
                self._save_screenshot(encoded_image, action_type="state")
            except Exception as e:
                logger.error(f"Error saving screenshot: {str(e)}")

        return screen

    except Exception as e:
        logger.error(f"Error getting parsed screen: {str(e)}")
        raise
|
|
564
|
-
|
|
565
|
-
def _get_system_prompt(self) -> str:
    """Return the module's ``SYSTEM_PROMPT`` constant for use with the model."""
    prompt: str = SYSTEM_PROMPT
    return prompt
|
|
568
|
-
|
|
569
|
-
###########################################
|
|
570
|
-
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
|
571
|
-
###########################################
|
|
572
|
-
|
|
573
|
-
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
    """Run the agent loop and stream its responses.

    The loop itself runs in a background task (kept in ``self.loop_task``)
    that feeds a queue; this generator relays queue items until it
    receives ``None``.

    Args:
        messages: List of messages in standard OpenAI format

    Yields:
        Each item taken from the queue, followed by a completion message;
        or an error message if an exception escapes the streaming logic.
    """
    try:
        logger.info(f"Starting OmniLoop run with {len(messages)} messages")

        # Seed the message manager with a copy of the caller's messages.
        self.message_manager.messages = messages.copy()

        # Background task produces, this generator consumes.
        response_queue = asyncio.Queue()
        self.loop_task = asyncio.create_task(self._run_loop(response_queue, messages))

        draining = True
        while draining:
            try:
                streamed = await response_queue.get()
                if streamed is not None:
                    yield streamed
                    response_queue.task_done()
                else:
                    # ``None`` is the stop signal.
                    draining = False
            except Exception as e:
                logger.error(f"Error processing queue item: {str(e)}")

        # Wait for loop to complete
        await self.loop_task

        completion_notice = {
            "role": "assistant",
            "content": "Task completed successfully.",
            "metadata": {"title": "✅ Complete"},
        }
        yield completion_notice

    except Exception as e:
        logger.error(f"Error in run method: {str(e)}")
        failure_notice = {
            "role": "assistant",
            "content": f"Error: {str(e)}",
            "metadata": {"title": "❌ Error"},
        }
        yield failure_notice
|
|
623
|
-
|
|
624
|
-
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
    """Internal method to run the agent loop with provided messages.

    Each iteration captures and parses the screen, appends the annotated
    screenshot to the message history, calls the model, lets
    ``_handle_response`` act on the reply and puts an OpenAI-compatible
    response on ``queue``. The loop ends when ``_handle_response`` reports
    that it should not continue, or after ``max_attempts`` consecutive
    failed iterations (the counter is reset after every successful one).
    When the last attempt fails, an error message is queued as well.

    A single ``None`` is always queued when the loop exits, including on
    error or cancellation, so the consumer knows the stream is finished.
    The stop signal must not be sent per iteration: the consumer stops
    reading at the first ``None`` and would miss all later responses.

    Args:
        queue: Queue to put responses into
        messages: List of messages in standard OpenAI format (not read by
            this method; the conversation comes from ``self.message_manager``)
    """
    # Continue running until explicitly told to stop
    running = True
    turn_created = False
    # Track if an action-specific screenshot has been saved this turn
    action_screenshot_saved = False

    attempt = 0
    max_attempts = 3

    try:
        while running and attempt < max_attempts:
            try:
                # Create a new turn directory if it's not already created
                if not turn_created:
                    self._create_turn_dir()
                    turn_created = True

                # Ensure client is initialized
                if self.client is None:
                    logger.info("Initializing client...")
                    await self.initialize_client()
                    if self.client is None:
                        raise RuntimeError("Failed to initialize client")
                    logger.info("Client initialized successfully")

                # Get up-to-date screen information
                parsed_screen = await self._get_parsed_screen_som()

                # Process screen info and update messages in standard format
                try:
                    # Get image from parsed screen
                    image = parsed_screen.annotated_image_base64 or None
                    if image:
                        # Save elements as JSON if we have a turn directory
                        if self.current_turn_dir and hasattr(parsed_screen, "elements"):
                            elements_path = os.path.join(self.current_turn_dir, "elements.json")
                            with open(elements_path, "w") as f:
                                # Convert elements to dicts for JSON serialization
                                elements_json = [
                                    elem.model_dump() for elem in parsed_screen.elements
                                ]
                                json.dump(elements_json, f, indent=2)
                                logger.info(f"Saved elements to {elements_path}")

                        # Remove data URL prefix if present
                        if "," in image:
                            image = image.split(",")[1]

                        # Add screenshot to message history using message manager
                        self.message_manager.add_user_message(
                            [
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:image/png;base64,{image}"},
                                }
                            ]
                        )
                        logger.info("Added screenshot to message history")
                except Exception as e:
                    logger.error(f"Error processing screen info: {str(e)}")
                    # Re-raised so the retry handler below counts this attempt
                    raise

                # Get system prompt
                system_prompt = self._get_system_prompt()

                # Make API call with retries using the APIHandler
                response = await self.api_handler.make_api_call(
                    self.message_manager.messages, system_prompt
                )

                # Handle the response (may execute actions)
                # Returns: (should_continue, action_screenshot_saved)
                should_continue, new_screenshot_saved = await self._handle_response(
                    response, self.message_manager.messages, parsed_screen
                )

                # Update whether an action screenshot was saved this turn
                action_screenshot_saved = action_screenshot_saved or new_screenshot_saved

                # Create OpenAI-compatible response format using utility function
                openai_compatible_response = await to_openai_agent_response_format(
                    response=response,
                    messages=self.message_manager.messages,
                    model=self.model,
                    parsed_screen=parsed_screen,
                    parser=self.parser
                )
                # Log standardized response for ease of parsing
                self._log_api_call("agent_response", request=None, response=openai_compatible_response)

                # Put the response in the queue
                await queue.put(openai_compatible_response)

                # Check if we should continue this conversation
                running = should_continue

                # Create a new turn directory if we're continuing
                if running:
                    turn_created = False

                # Reset attempt counter on success
                attempt = 0

            except Exception as e:
                attempt += 1
                error_msg = f"Error in _run_loop method (attempt {attempt}/{max_attempts}): {str(e)}"
                logger.error(error_msg)

                # If this is our last attempt, provide more info about the error
                if attempt >= max_attempts:
                    logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")

                    await queue.put({
                        "role": "assistant",
                        "content": f"Error: {str(e)}",
                        "metadata": {"title": "❌ Error"},
                    })

                # Create a brief delay before retrying
                await asyncio.sleep(1)
    finally:
        # Signal that we're done. This sits outside the while loop on purpose:
        # exactly one stop signal per run, sent only after the last response.
        await queue.put(None)
|
|
753
|
-
|
|
754
|
-
async def cancel(self) -> None:
    """Cancel the background agent loop task, if one is still running.

    Requests cancellation of ``self.loop_task`` and waits up to two seconds
    for it to finish. A timeout, the expected ``CancelledError`` and any
    other exception raised while waiting are logged rather than propagated.
    When there is no task, or it has already finished, only a log entry is
    written.
    """
    task = self.loop_task
    if not task or task.done():
        logger.info("No active Omni loop task to cancel")
        return

    logger.info("Cancelling Omni loop task")
    task.cancel()
    try:
        # Bounded wait so a task that ignores cancellation cannot hang the caller
        await asyncio.wait_for(task, timeout=2.0)
    except asyncio.TimeoutError:
        logger.warning("Timeout while waiting for loop task to cancel")
    except asyncio.CancelledError:
        # Expected outcome: the awaited task ended by being cancelled
        logger.info("Loop task cancelled successfully")
    except Exception as exc:
        logger.error(f"Error while cancelling loop task: {str(exc)}")
    finally:
        logger.info("Omni loop task cancelled")
|
|
776
|
-
|
|
777
|
-
async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
    """Process model response to extract and execute a tool call.

    Looks for a ``"function_call": {...}`` JSON object inside
    ``response_text``, decodes it and runs the named tool through
    ``self.tool_manager``. The object is decoded with a real JSON decoder,
    so nested ``arguments`` objects are handled. A brace-matching regex
    would stop at the first ``}`` and drop such calls. If ``arguments`` is
    itself a JSON-encoded string that decodes to an object, the decoded
    object is passed to the tool.

    Args:
        response_text: Model response text

    Returns:
        ``{"tool_result": result}`` when a tool call was found and run (on
        tool failure ``result`` is a ``ToolResult`` carrying the error), or
        None if no tool call was found or it could not be decoded
    """
    # Function-scope imports, as in the original implementation
    import json
    import re

    try:
        # Ensure tools are initialized before use
        await self._ensure_tools_initialized()

        # Look for tool use in the response
        if "function_call" not in response_text and "tool_use" not in response_text:
            return None

        tool_info = None
        if "function_call" in response_text:
            try:
                # Find the opening brace of the function_call value, then let
                # the JSON decoder consume exactly one complete object there
                match = re.search(r'"function_call"\s*:\s*(?=\{)', response_text)
                if match:
                    candidate, _ = json.JSONDecoder().raw_decode(response_text, match.end())
                    if isinstance(candidate, dict):
                        tool_info = candidate
            except Exception as e:
                logger.error(f"Error extracting function call: {str(e)}")

        if tool_info:
            arguments = tool_info.get("arguments", {})
            if isinstance(arguments, str):
                # Some APIs (e.g. OpenAI-style function calling) deliver the
                # arguments as a JSON-encoded string; unwrap it when it
                # decodes to an object, otherwise pass the value through as-is
                try:
                    decoded = json.loads(arguments)
                except ValueError:
                    decoded = None
                if isinstance(decoded, dict):
                    arguments = decoded

            try:
                # Execute the tool
                result = await self.tool_manager.execute_tool(
                    name=tool_info.get("name"), tool_input=arguments
                )
                # Handle the result
                return {"tool_result": result}
            except Exception as e:
                error_msg = (
                    f"Error executing tool '{tool_info.get('name', 'unknown')}': {str(e)}"
                )
                logger.error(error_msg)
                return {"tool_result": ToolResult(error=error_msg)}
    except Exception as e:
        logger.error(f"Error processing tool call: {str(e)}")

    return None
|
|
828
|
-
|
|
829
|
-
async def process_response_with_tools(
    self, response_text: str, parsed_screen: Optional[ParseResult] = None
) -> Tuple[bool, str]:
    """Run any tool call found in a model response and summarize the outcome.

    Delegates extraction and execution to ``process_model_response`` and
    turns its result into a success flag plus a human-readable observation.

    Args:
        response_text: Model response text
        parsed_screen: Current parsed screen information (optional; not
            used by this method)

    Returns:
        Tuple of (action_taken, observation). ``action_taken`` is True only
        when a tool ran and reported no error.
    """
    logger.info("Processing response with tools")

    outcome = await self.process_model_response(response_text)

    # Nothing was executed: no tool call in the response
    if not outcome or "tool_result" not in outcome:
        return False, "No action taken - no tool call detected in response"

    executed = outcome["tool_result"]
    if executed.error:
        return False, f"ERROR: {executed.error}"

    # Fall back to a generic message when the tool produced no output
    return True, executed.output or "Tool executed successfully"
|
|
856
|
-
|
|
857
|
-
###########################################
|
|
858
|
-
# UTILITY METHODS
|
|
859
|
-
###########################################
|
|
860
|
-
|
|
861
|
-
async def _ensure_tools_initialized(self) -> None:
|
|
862
|
-
"""Ensure the tool manager and tools are initialized before use."""
|
|
863
|
-
if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
|
|
864
|
-
logger.info("Tools not initialized. Initializing now...")
|
|
865
|
-
await self.tool_manager.initialize()
|
|
866
|
-
logger.info("Tools initialized successfully.")
|
|
867
|
-
|
|
868
|
-
async def _execute_action_with_tools(
|
|
869
|
-
self, action_data: Dict[str, Any], parsed_screen: ParseResult
|
|
870
|
-
) -> Tuple[bool, bool]:
|
|
871
|
-
"""Execute an action using the tools-based approach.
|
|
872
|
-
|
|
873
|
-
Args:
|
|
874
|
-
action_data: Dictionary containing action details
|
|
875
|
-
parsed_screen: Current parsed screen information
|
|
876
|
-
|
|
877
|
-
Returns:
|
|
878
|
-
Tuple of (should_continue, action_screenshot_saved)
|
|
879
|
-
"""
|
|
880
|
-
action_screenshot_saved = False
|
|
881
|
-
action_type = None # Initialize for possible use in post-action screenshot
|
|
882
|
-
|
|
883
|
-
try:
|
|
884
|
-
# Extract the action
|
|
885
|
-
parsed_action = action_data.get("Action", "").lower()
|
|
886
|
-
|
|
887
|
-
# Only process if we have a valid action
|
|
888
|
-
if not parsed_action or parsed_action == "none":
|
|
889
|
-
return False, action_screenshot_saved
|
|
890
|
-
|
|
891
|
-
# Convert the parsed content to a format suitable for the tools system
|
|
892
|
-
tool_name = "computer" # Default to computer tool
|
|
893
|
-
tool_args = {"action": parsed_action}
|
|
894
|
-
|
|
895
|
-
# Add specific arguments based on action type
|
|
896
|
-
if parsed_action in ["left_click", "right_click", "double_click", "move_cursor"]:
|
|
897
|
-
# Calculate coordinates from Box ID using parser
|
|
898
|
-
try:
|
|
899
|
-
box_id = int(action_data["Box ID"])
|
|
900
|
-
x, y = await self.parser.calculate_click_coordinates(
|
|
901
|
-
box_id, cast(ParseResult, parsed_screen)
|
|
902
|
-
)
|
|
903
|
-
tool_args["x"] = x
|
|
904
|
-
tool_args["y"] = y
|
|
905
|
-
|
|
906
|
-
# Visualize action if screenshot is available
|
|
907
|
-
if parsed_screen and parsed_screen.annotated_image_base64:
|
|
908
|
-
img_data = parsed_screen.annotated_image_base64
|
|
909
|
-
# Remove data URL prefix if present
|
|
910
|
-
if img_data.startswith("data:image"):
|
|
911
|
-
img_data = img_data.split(",")[1]
|
|
912
|
-
# Save visualization for coordinate-based actions
|
|
913
|
-
self.viz_helper.visualize_action(x, y, img_data)
|
|
914
|
-
action_screenshot_saved = True
|
|
915
|
-
|
|
916
|
-
except (ValueError, KeyError) as e:
|
|
917
|
-
logger.error(f"Error processing Box ID: {str(e)}")
|
|
918
|
-
return False, action_screenshot_saved
|
|
919
|
-
|
|
920
|
-
elif parsed_action == "type_text":
|
|
921
|
-
tool_args["text"] = action_data.get("Value", "")
|
|
922
|
-
# For type_text, store the value in the action type for screenshot naming
|
|
923
|
-
action_type = f"type_{tool_args['text'][:20]}" # Truncate if too long
|
|
924
|
-
|
|
925
|
-
elif parsed_action == "press_key":
|
|
926
|
-
tool_args["key"] = action_data.get("Value", "")
|
|
927
|
-
action_type = f"press_{tool_args['key']}"
|
|
928
|
-
|
|
929
|
-
elif parsed_action == "hotkey":
|
|
930
|
-
value = action_data.get("Value", "")
|
|
931
|
-
if isinstance(value, list):
|
|
932
|
-
tool_args["keys"] = value
|
|
933
|
-
action_type = f"hotkey_{'_'.join(value)}"
|
|
934
|
-
else:
|
|
935
|
-
# Split string format like "command+space" into a list
|
|
936
|
-
keys = [k.strip() for k in value.lower().split("+")]
|
|
937
|
-
tool_args["keys"] = keys
|
|
938
|
-
action_type = f"hotkey_{value.replace('+', '_')}"
|
|
939
|
-
|
|
940
|
-
elif parsed_action in ["scroll_down", "scroll_up"]:
|
|
941
|
-
clicks = int(action_data.get("amount", 1))
|
|
942
|
-
tool_args["amount"] = clicks
|
|
943
|
-
action_type = f"scroll_{parsed_action.split('_')[1]}_{clicks}"
|
|
944
|
-
|
|
945
|
-
# Visualize scrolling if screenshot is available
|
|
946
|
-
if parsed_screen and parsed_screen.annotated_image_base64:
|
|
947
|
-
img_data = parsed_screen.annotated_image_base64
|
|
948
|
-
# Remove data URL prefix if present
|
|
949
|
-
if img_data.startswith("data:image"):
|
|
950
|
-
img_data = img_data.split(",")[1]
|
|
951
|
-
direction = "down" if parsed_action == "scroll_down" else "up"
|
|
952
|
-
# For scrolling, we save the visualization
|
|
953
|
-
self.viz_helper.visualize_scroll(direction, clicks, img_data)
|
|
954
|
-
action_screenshot_saved = True
|
|
955
|
-
|
|
956
|
-
# Ensure tools are initialized before use
|
|
957
|
-
await self._ensure_tools_initialized()
|
|
958
|
-
|
|
959
|
-
# Execute tool with prepared arguments
|
|
960
|
-
result = await self.tool_manager.execute_tool(name=tool_name, tool_input=tool_args)
|
|
961
|
-
|
|
962
|
-
# Take a new screenshot after the action if we haven't already saved one
|
|
963
|
-
if not action_screenshot_saved:
|
|
964
|
-
try:
|
|
965
|
-
# Get a new screenshot after the action
|
|
966
|
-
new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
|
|
967
|
-
if new_parsed_screen and new_parsed_screen.annotated_image_base64:
|
|
968
|
-
img_data = new_parsed_screen.annotated_image_base64
|
|
969
|
-
# Remove data URL prefix if present
|
|
970
|
-
if img_data.startswith("data:image"):
|
|
971
|
-
img_data = img_data.split(",")[1]
|
|
972
|
-
# Save with action type if defined, otherwise use the action name
|
|
973
|
-
if action_type:
|
|
974
|
-
self._save_screenshot(img_data, action_type=action_type)
|
|
975
|
-
else:
|
|
976
|
-
self._save_screenshot(img_data, action_type=parsed_action)
|
|
977
|
-
action_screenshot_saved = True
|
|
978
|
-
except Exception as screenshot_error:
|
|
979
|
-
logger.error(f"Error taking post-action screenshot: {str(screenshot_error)}")
|
|
980
|
-
|
|
981
|
-
# Continue the loop if the action is not "None"
|
|
982
|
-
return True, action_screenshot_saved
|
|
983
|
-
|
|
984
|
-
except Exception as e:
|
|
985
|
-
logger.error(f"Error executing action: {str(e)}")
|
|
986
|
-
# Update the last assistant message with error
|
|
987
|
-
error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
|
|
988
|
-
# Replace the last assistant message with the error
|
|
989
|
-
self.message_manager.add_assistant_message(error_message)
|
|
990
|
-
return False, action_screenshot_saved
|