cua-agent 0.1.5__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +3 -4
- agent/core/__init__.py +3 -10
- agent/core/computer_agent.py +207 -32
- agent/core/experiment.py +20 -3
- agent/core/loop.py +78 -120
- agent/core/messages.py +279 -125
- agent/core/telemetry.py +44 -32
- agent/core/types.py +35 -0
- agent/core/visualization.py +197 -0
- agent/providers/anthropic/api/client.py +142 -1
- agent/providers/anthropic/api_handler.py +140 -0
- agent/providers/anthropic/callbacks/__init__.py +5 -0
- agent/providers/anthropic/loop.py +224 -209
- agent/providers/anthropic/messages/manager.py +3 -1
- agent/providers/anthropic/response_handler.py +229 -0
- agent/providers/anthropic/tools/base.py +1 -1
- agent/providers/anthropic/tools/bash.py +0 -97
- agent/providers/anthropic/tools/collection.py +2 -2
- agent/providers/anthropic/tools/computer.py +34 -24
- agent/providers/anthropic/tools/manager.py +2 -2
- agent/providers/anthropic/utils.py +370 -0
- agent/providers/omni/__init__.py +1 -20
- agent/providers/omni/api_handler.py +42 -0
- agent/providers/omni/clients/anthropic.py +4 -0
- agent/providers/omni/image_utils.py +0 -72
- agent/providers/omni/loop.py +497 -607
- agent/providers/omni/parser.py +60 -5
- agent/providers/omni/tools/__init__.py +25 -8
- agent/providers/omni/tools/base.py +29 -0
- agent/providers/omni/tools/bash.py +43 -38
- agent/providers/omni/tools/computer.py +144 -181
- agent/providers/omni/tools/manager.py +26 -48
- agent/providers/omni/types.py +0 -4
- agent/providers/omni/utils.py +225 -144
- {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/METADATA +6 -36
- cua_agent-0.1.17.dist-info/RECORD +63 -0
- agent/core/agent.py +0 -252
- agent/core/base_agent.py +0 -164
- agent/core/factory.py +0 -102
- agent/providers/omni/callbacks.py +0 -78
- agent/providers/omni/clients/groq.py +0 -101
- agent/providers/omni/experiment.py +0 -273
- agent/providers/omni/messages.py +0 -171
- agent/providers/omni/tool_manager.py +0 -91
- agent/providers/omni/visualization.py +0 -130
- agent/types/__init__.py +0 -26
- agent/types/base.py +0 -53
- agent/types/messages.py +0 -36
- cua_agent-0.1.5.dist-info/RECORD +0 -67
- agent/{types → core}/tools.py +0 -0
- {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/WHEEL +0 -0
- {cua_agent-0.1.5.dist-info → cua_agent-0.1.17.dist-info}/entry_points.txt +0 -0
agent/__init__.py
CHANGED
|
@@ -48,9 +48,8 @@ except Exception as e:
|
|
|
48
48
|
# Other issues with telemetry
|
|
49
49
|
logger.warning(f"Error initializing telemetry: {e}")
|
|
50
50
|
|
|
51
|
-
from .core.factory import AgentFactory
|
|
52
|
-
from .core.agent import ComputerAgent
|
|
53
51
|
from .providers.omni.types import LLMProvider, LLM
|
|
54
|
-
from .
|
|
52
|
+
from .core.loop import AgentLoop
|
|
53
|
+
from .core.computer_agent import ComputerAgent
|
|
55
54
|
|
|
56
|
-
__all__ = ["
|
|
55
|
+
__all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"]
|
agent/core/__init__.py
CHANGED
|
@@ -1,18 +1,12 @@
|
|
|
1
1
|
"""Core agent components."""
|
|
2
2
|
|
|
3
|
-
from .base_agent import BaseComputerAgent
|
|
4
3
|
from .loop import BaseLoop
|
|
5
4
|
from .messages import (
|
|
6
|
-
create_user_message,
|
|
7
|
-
create_assistant_message,
|
|
8
|
-
create_system_message,
|
|
9
|
-
create_image_message,
|
|
10
|
-
create_screen_message,
|
|
11
5
|
BaseMessageManager,
|
|
12
6
|
ImageRetentionConfig,
|
|
13
7
|
)
|
|
14
8
|
from .callbacks import (
|
|
15
|
-
CallbackManager,
|
|
9
|
+
CallbackManager,
|
|
16
10
|
CallbackHandler,
|
|
17
11
|
BaseCallbackManager,
|
|
18
12
|
ContentCallback,
|
|
@@ -21,9 +15,8 @@ from .callbacks import (
|
|
|
21
15
|
)
|
|
22
16
|
|
|
23
17
|
__all__ = [
|
|
24
|
-
"
|
|
25
|
-
"
|
|
26
|
-
"CallbackManager",
|
|
18
|
+
"BaseLoop",
|
|
19
|
+
"CallbackManager",
|
|
27
20
|
"CallbackHandler",
|
|
28
21
|
"BaseMessageManager",
|
|
29
22
|
"ImageRetentionConfig",
|
agent/core/computer_agent.py
CHANGED
|
@@ -1,69 +1,244 @@
|
|
|
1
1
|
"""Main entry point for computer agents."""
|
|
2
2
|
|
|
3
|
+
import asyncio
|
|
3
4
|
import logging
|
|
4
|
-
|
|
5
|
+
import os
|
|
6
|
+
from typing import Any, AsyncGenerator, Dict, Optional, cast, List
|
|
5
7
|
|
|
6
8
|
from computer import Computer
|
|
7
|
-
from ..
|
|
8
|
-
from .
|
|
9
|
+
from ..providers.anthropic.loop import AnthropicLoop
|
|
10
|
+
from ..providers.omni.loop import OmniLoop
|
|
11
|
+
from ..providers.omni.parser import OmniParser
|
|
12
|
+
from ..providers.omni.types import LLMProvider, LLM
|
|
13
|
+
from .. import AgentLoop
|
|
14
|
+
from .messages import StandardMessageManager, ImageRetentionConfig
|
|
15
|
+
from .types import AgentResponse
|
|
9
16
|
|
|
10
17
|
logging.basicConfig(level=logging.INFO)
|
|
11
18
|
logger = logging.getLogger(__name__)
|
|
12
19
|
|
|
20
|
+
# Default models for different providers
|
|
21
|
+
DEFAULT_MODELS = {
|
|
22
|
+
LLMProvider.OPENAI: "gpt-4o",
|
|
23
|
+
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
|
24
|
+
}
|
|
25
|
+
|
|
26
|
+
# Map providers to their environment variable names
|
|
27
|
+
ENV_VARS = {
|
|
28
|
+
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
|
29
|
+
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
|
30
|
+
}
|
|
31
|
+
|
|
13
32
|
|
|
14
33
|
class ComputerAgent:
|
|
15
34
|
"""A computer agent that can perform automated tasks using natural language instructions."""
|
|
16
35
|
|
|
17
|
-
def __init__(
|
|
36
|
+
def __init__(
|
|
37
|
+
self,
|
|
38
|
+
computer: Computer,
|
|
39
|
+
model: LLM,
|
|
40
|
+
loop: AgentLoop,
|
|
41
|
+
max_retries: int = 3,
|
|
42
|
+
screenshot_dir: Optional[str] = None,
|
|
43
|
+
log_dir: Optional[str] = None,
|
|
44
|
+
api_key: Optional[str] = None,
|
|
45
|
+
save_trajectory: bool = True,
|
|
46
|
+
trajectory_dir: str = "trajectories",
|
|
47
|
+
only_n_most_recent_images: Optional[int] = None,
|
|
48
|
+
verbosity: int = logging.INFO,
|
|
49
|
+
):
|
|
18
50
|
"""Initialize the ComputerAgent.
|
|
19
51
|
|
|
20
52
|
Args:
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
53
|
+
computer: Computer instance. If not provided, one will be created with default settings.
|
|
54
|
+
max_retries: Maximum number of retry attempts.
|
|
55
|
+
screenshot_dir: Directory to save screenshots.
|
|
56
|
+
log_dir: Directory to save logs (set to None to disable logging to files).
|
|
57
|
+
model: LLM object containing provider and model name. Takes precedence over provider/model_name.
|
|
58
|
+
provider: The AI provider to use (e.g., LLMProvider.ANTHROPIC). Only used if model is None.
|
|
59
|
+
api_key: The API key for the provider. If not provided, will look for environment variable.
|
|
60
|
+
model_name: The model name to use. Only used if model is None.
|
|
61
|
+
save_trajectory: Whether to save the trajectory.
|
|
62
|
+
trajectory_dir: Directory to save the trajectory.
|
|
63
|
+
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
|
|
64
|
+
verbosity: Logging level.
|
|
24
65
|
"""
|
|
25
|
-
|
|
26
|
-
self.
|
|
27
|
-
self.
|
|
28
|
-
self.
|
|
66
|
+
# Basic agent configuration
|
|
67
|
+
self.max_retries = max_retries
|
|
68
|
+
self.computer = computer
|
|
69
|
+
self.queue = asyncio.Queue()
|
|
70
|
+
self.screenshot_dir = screenshot_dir
|
|
71
|
+
self.log_dir = log_dir
|
|
72
|
+
self._retry_count = 0
|
|
29
73
|
self._initialized = False
|
|
30
74
|
self._in_context = False
|
|
31
75
|
|
|
32
|
-
#
|
|
33
|
-
|
|
76
|
+
# Set logging level
|
|
77
|
+
logger.setLevel(verbosity)
|
|
78
|
+
|
|
79
|
+
# Setup logging
|
|
80
|
+
if self.log_dir:
|
|
81
|
+
os.makedirs(self.log_dir, exist_ok=True)
|
|
82
|
+
logger.info(f"Created logs directory: {self.log_dir}")
|
|
83
|
+
|
|
84
|
+
# Setup screenshots directory
|
|
85
|
+
if self.screenshot_dir:
|
|
86
|
+
os.makedirs(self.screenshot_dir, exist_ok=True)
|
|
87
|
+
logger.info(f"Created screenshots directory: {self.screenshot_dir}")
|
|
88
|
+
|
|
89
|
+
# Use the provided LLM object
|
|
90
|
+
self.provider = model.provider
|
|
91
|
+
actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
|
|
92
|
+
|
|
93
|
+
# Ensure we have a valid model name
|
|
94
|
+
if not actual_model_name:
|
|
95
|
+
actual_model_name = DEFAULT_MODELS.get(self.provider, "")
|
|
96
|
+
if not actual_model_name:
|
|
97
|
+
raise ValueError(
|
|
98
|
+
f"No model specified for provider {self.provider} and no default found"
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
# Ensure computer is properly cast for typing purposes
|
|
102
|
+
computer_instance = self.computer
|
|
103
|
+
|
|
104
|
+
# Get API key from environment if not provided
|
|
105
|
+
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
|
|
106
|
+
if not actual_api_key:
|
|
107
|
+
raise ValueError(f"No API key provided for {self.provider}")
|
|
108
|
+
|
|
109
|
+
# Initialize the appropriate loop based on the loop parameter
|
|
110
|
+
if loop == AgentLoop.ANTHROPIC:
|
|
111
|
+
self._loop = AnthropicLoop(
|
|
112
|
+
api_key=actual_api_key,
|
|
113
|
+
model=actual_model_name,
|
|
114
|
+
computer=computer_instance,
|
|
115
|
+
save_trajectory=save_trajectory,
|
|
116
|
+
base_dir=trajectory_dir,
|
|
117
|
+
only_n_most_recent_images=only_n_most_recent_images,
|
|
118
|
+
)
|
|
119
|
+
else:
|
|
120
|
+
self._loop = OmniLoop(
|
|
121
|
+
provider=self.provider,
|
|
122
|
+
api_key=actual_api_key,
|
|
123
|
+
model=actual_model_name,
|
|
124
|
+
computer=computer_instance,
|
|
125
|
+
save_trajectory=save_trajectory,
|
|
126
|
+
base_dir=trajectory_dir,
|
|
127
|
+
only_n_most_recent_images=only_n_most_recent_images,
|
|
128
|
+
parser=OmniParser(),
|
|
129
|
+
)
|
|
130
|
+
|
|
131
|
+
# Initialize the message manager from the loop
|
|
132
|
+
self.message_manager = self._loop.message_manager
|
|
133
|
+
|
|
134
|
+
logger.info(
|
|
135
|
+
f"ComputerAgent initialized with provider: {self.provider}, model: {actual_model_name}"
|
|
136
|
+
)
|
|
34
137
|
|
|
35
138
|
async def __aenter__(self):
|
|
36
|
-
"""
|
|
139
|
+
"""Initialize the agent when used as a context manager."""
|
|
140
|
+
logger.info("Entering ComputerAgent context")
|
|
37
141
|
self._in_context = True
|
|
142
|
+
|
|
143
|
+
# In case the computer wasn't initialized
|
|
144
|
+
try:
|
|
145
|
+
# Initialize the computer only if not already initialized
|
|
146
|
+
logger.info("Checking if computer is already initialized...")
|
|
147
|
+
if not self.computer._initialized:
|
|
148
|
+
logger.info("Initializing computer in __aenter__...")
|
|
149
|
+
# Use the computer's __aenter__ directly instead of calling run()
|
|
150
|
+
await self.computer.__aenter__()
|
|
151
|
+
logger.info("Computer initialized in __aenter__")
|
|
152
|
+
else:
|
|
153
|
+
logger.info("Computer already initialized, skipping initialization")
|
|
154
|
+
|
|
155
|
+
# Take a test screenshot to verify the computer is working
|
|
156
|
+
logger.info("Testing computer with a screenshot...")
|
|
157
|
+
try:
|
|
158
|
+
test_screenshot = await self.computer.interface.screenshot()
|
|
159
|
+
# Determine the screenshot size based on its type
|
|
160
|
+
if isinstance(test_screenshot, (bytes, bytearray, memoryview)):
|
|
161
|
+
size = len(test_screenshot)
|
|
162
|
+
elif hasattr(test_screenshot, "base64_image"):
|
|
163
|
+
size = len(test_screenshot.base64_image)
|
|
164
|
+
else:
|
|
165
|
+
size = "unknown"
|
|
166
|
+
logger.info(f"Screenshot test successful, size: {size}")
|
|
167
|
+
except Exception as e:
|
|
168
|
+
logger.error(f"Screenshot test failed: {str(e)}")
|
|
169
|
+
# Even though screenshot failed, we continue since some tests might not need it
|
|
170
|
+
except Exception as e:
|
|
171
|
+
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
|
|
172
|
+
raise
|
|
173
|
+
|
|
38
174
|
await self.initialize()
|
|
39
175
|
return self
|
|
40
176
|
|
|
41
177
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
42
|
-
"""
|
|
178
|
+
"""Cleanup agent resources if needed."""
|
|
179
|
+
logger.info("Cleaning up agent resources")
|
|
43
180
|
self._in_context = False
|
|
44
181
|
|
|
182
|
+
# Do any necessary cleanup
|
|
183
|
+
# We're not shutting down the computer here as it might be shared
|
|
184
|
+
# Just log that we're exiting
|
|
185
|
+
if exc_type:
|
|
186
|
+
logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
|
|
187
|
+
else:
|
|
188
|
+
logger.info("Exiting agent context normally")
|
|
189
|
+
|
|
190
|
+
# If we have a queue, make sure to signal it's done
|
|
191
|
+
if hasattr(self, "queue") and self.queue:
|
|
192
|
+
await self.queue.put(None) # Signal that we're done
|
|
193
|
+
|
|
45
194
|
async def initialize(self) -> None:
|
|
46
195
|
"""Initialize the agent and its components."""
|
|
47
196
|
if not self._initialized:
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
await self.
|
|
197
|
+
# Always initialize the computer if available
|
|
198
|
+
if self.computer and not self.computer._initialized:
|
|
199
|
+
await self.computer.run()
|
|
51
200
|
self._initialized = True
|
|
52
201
|
|
|
53
|
-
async def run(self, task: str) -> AsyncGenerator[
|
|
54
|
-
"""Run
|
|
55
|
-
|
|
56
|
-
|
|
202
|
+
async def run(self, task: str) -> AsyncGenerator[AgentResponse, None]:
|
|
203
|
+
"""Run a task using the computer agent.
|
|
204
|
+
|
|
205
|
+
Args:
|
|
206
|
+
task: Task description
|
|
207
|
+
|
|
208
|
+
Yields:
|
|
209
|
+
Agent response format
|
|
210
|
+
"""
|
|
211
|
+
try:
|
|
212
|
+
logger.info(f"Running task: {task}")
|
|
213
|
+
logger.info(
|
|
214
|
+
f"Message history before task has {len(self.message_manager.messages)} messages"
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
# Initialize the computer if needed
|
|
218
|
+
if not self._initialized:
|
|
219
|
+
await self.initialize()
|
|
220
|
+
|
|
221
|
+
# Add task as a user message using the message manager
|
|
222
|
+
self.message_manager.add_user_message([{"type": "text", "text": task}])
|
|
223
|
+
logger.info(
|
|
224
|
+
f"Added task message. Message history now has {len(self.message_manager.messages)} messages"
|
|
225
|
+
)
|
|
57
226
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
227
|
+
# Pass properly formatted messages to the loop
|
|
228
|
+
if self._loop is None:
|
|
229
|
+
logger.error("Loop not initialized properly")
|
|
230
|
+
yield {"error": "Loop not initialized properly"}
|
|
231
|
+
return
|
|
62
232
|
|
|
63
|
-
|
|
64
|
-
|
|
233
|
+
# Execute the task and yield results
|
|
234
|
+
async for result in self._loop.run(self.message_manager.messages):
|
|
235
|
+
# Yield the result to the caller
|
|
236
|
+
yield result
|
|
65
237
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
238
|
+
except Exception as e:
|
|
239
|
+
logger.error(f"Error in agent run method: {str(e)}")
|
|
240
|
+
yield {
|
|
241
|
+
"role": "assistant",
|
|
242
|
+
"content": f"Error: {str(e)}",
|
|
243
|
+
"metadata": {"title": "❌ Error"},
|
|
244
|
+
}
|
agent/core/experiment.py
CHANGED
|
@@ -84,7 +84,21 @@ class ExperimentManager:
|
|
|
84
84
|
if isinstance(data, dict):
|
|
85
85
|
result = {}
|
|
86
86
|
for k, v in data.items():
|
|
87
|
-
|
|
87
|
+
# Special handling for 'data' field in Anthropic message source
|
|
88
|
+
if k == "data" and isinstance(v, str) and len(v) > 1000:
|
|
89
|
+
result[k] = f"[BASE64_DATA_LENGTH_{len(v)}]"
|
|
90
|
+
# Special handling for the 'media_type' key which indicates we're in an image block
|
|
91
|
+
elif k == "media_type" and "image" in str(v):
|
|
92
|
+
result[k] = v
|
|
93
|
+
# If we're in an image block, look for a sibling 'data' field with base64 content
|
|
94
|
+
if (
|
|
95
|
+
"data" in result
|
|
96
|
+
and isinstance(result["data"], str)
|
|
97
|
+
and len(result["data"]) > 1000
|
|
98
|
+
):
|
|
99
|
+
result["data"] = f"[BASE64_DATA_LENGTH_{len(result['data'])}]"
|
|
100
|
+
else:
|
|
101
|
+
result[k] = self.sanitize_log_data(v)
|
|
88
102
|
return result
|
|
89
103
|
elif isinstance(data, list):
|
|
90
104
|
return [self.sanitize_log_data(item) for item in data]
|
|
@@ -93,15 +107,18 @@ class ExperimentManager:
|
|
|
93
107
|
else:
|
|
94
108
|
return data
|
|
95
109
|
|
|
96
|
-
def save_screenshot(self, img_base64: str, action_type: str = "") ->
|
|
110
|
+
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
|
|
97
111
|
"""Save a screenshot to the experiment directory.
|
|
98
112
|
|
|
99
113
|
Args:
|
|
100
114
|
img_base64: Base64 encoded screenshot
|
|
101
115
|
action_type: Type of action that triggered the screenshot
|
|
116
|
+
|
|
117
|
+
Returns:
|
|
118
|
+
Path to the saved screenshot or None if there was an error
|
|
102
119
|
"""
|
|
103
120
|
if not self.current_turn_dir:
|
|
104
|
-
return
|
|
121
|
+
return None
|
|
105
122
|
|
|
106
123
|
try:
|
|
107
124
|
# Increment screenshot counter
|
agent/core/loop.py
CHANGED
|
@@ -2,22 +2,34 @@
|
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
4
|
import asyncio
|
|
5
|
-
import json
|
|
6
|
-
import os
|
|
7
5
|
from abc import ABC, abstractmethod
|
|
6
|
+
from enum import Enum, auto
|
|
8
7
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
|
9
8
|
from datetime import datetime
|
|
10
|
-
import base64
|
|
11
9
|
|
|
12
10
|
from computer import Computer
|
|
13
11
|
from .experiment import ExperimentManager
|
|
12
|
+
from .messages import StandardMessageManager, ImageRetentionConfig
|
|
13
|
+
from .types import AgentResponse
|
|
14
14
|
|
|
15
15
|
logger = logging.getLogger(__name__)
|
|
16
16
|
|
|
17
17
|
|
|
18
|
+
class AgentLoop(Enum):
|
|
19
|
+
"""Enumeration of available loop types."""
|
|
20
|
+
|
|
21
|
+
ANTHROPIC = auto() # Anthropic implementation
|
|
22
|
+
OMNI = auto() # OmniLoop implementation
|
|
23
|
+
# Add more loop types as needed
|
|
24
|
+
|
|
25
|
+
|
|
18
26
|
class BaseLoop(ABC):
|
|
19
27
|
"""Base class for agent loops that handle message processing and tool execution."""
|
|
20
28
|
|
|
29
|
+
###########################################
|
|
30
|
+
# INITIALIZATION AND CONFIGURATION
|
|
31
|
+
###########################################
|
|
32
|
+
|
|
21
33
|
def __init__(
|
|
22
34
|
self,
|
|
23
35
|
computer: Computer,
|
|
@@ -55,8 +67,6 @@ class BaseLoop(ABC):
|
|
|
55
67
|
self.save_trajectory = save_trajectory
|
|
56
68
|
self.only_n_most_recent_images = only_n_most_recent_images
|
|
57
69
|
self._kwargs = kwargs
|
|
58
|
-
self.message_history = []
|
|
59
|
-
# self.tool_manager = BaseToolManager(computer)
|
|
60
70
|
|
|
61
71
|
# Initialize experiment manager
|
|
62
72
|
if self.save_trajectory and self.base_dir:
|
|
@@ -75,6 +85,64 @@ class BaseLoop(ABC):
|
|
|
75
85
|
# Initialize basic tracking
|
|
76
86
|
self.turn_count = 0
|
|
77
87
|
|
|
88
|
+
async def initialize(self) -> None:
|
|
89
|
+
"""Initialize both the API client and computer interface with retries."""
|
|
90
|
+
for attempt in range(self.max_retries):
|
|
91
|
+
try:
|
|
92
|
+
logger.info(
|
|
93
|
+
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
# Initialize API client
|
|
97
|
+
await self.initialize_client()
|
|
98
|
+
|
|
99
|
+
logger.info("Initialization complete.")
|
|
100
|
+
return
|
|
101
|
+
except Exception as e:
|
|
102
|
+
if attempt < self.max_retries - 1:
|
|
103
|
+
logger.warning(
|
|
104
|
+
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
|
|
105
|
+
)
|
|
106
|
+
await asyncio.sleep(self.retry_delay)
|
|
107
|
+
else:
|
|
108
|
+
logger.error(
|
|
109
|
+
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
|
|
110
|
+
)
|
|
111
|
+
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
|
112
|
+
|
|
113
|
+
###########################################
|
|
114
|
+
|
|
115
|
+
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
|
|
116
|
+
###########################################
|
|
117
|
+
|
|
118
|
+
@abstractmethod
|
|
119
|
+
async def initialize_client(self) -> None:
|
|
120
|
+
"""Initialize the API client and any provider-specific components.
|
|
121
|
+
|
|
122
|
+
This method must be implemented by subclasses to set up
|
|
123
|
+
provider-specific clients and tools.
|
|
124
|
+
"""
|
|
125
|
+
raise NotImplementedError
|
|
126
|
+
|
|
127
|
+
@abstractmethod
|
|
128
|
+
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
|
129
|
+
"""Run the agent loop with provided messages.
|
|
130
|
+
|
|
131
|
+
This method handles the main agent loop including message processing,
|
|
132
|
+
API calls, response handling, and action execution.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
messages: List of message objects
|
|
136
|
+
|
|
137
|
+
Yields:
|
|
138
|
+
Agent response format
|
|
139
|
+
"""
|
|
140
|
+
raise NotImplementedError
|
|
141
|
+
|
|
142
|
+
###########################################
|
|
143
|
+
# EXPERIMENT AND TRAJECTORY MANAGEMENT
|
|
144
|
+
###########################################
|
|
145
|
+
|
|
78
146
|
def _setup_experiment_dirs(self) -> None:
|
|
79
147
|
"""Setup the experiment directory structure."""
|
|
80
148
|
if self.experiment_manager:
|
|
@@ -100,10 +168,13 @@ class BaseLoop(ABC):
|
|
|
100
168
|
) -> None:
|
|
101
169
|
"""Log API call details to file.
|
|
102
170
|
|
|
171
|
+
Preserves provider-specific formats for requests and responses to ensure
|
|
172
|
+
accurate logging for debugging and analysis purposes.
|
|
173
|
+
|
|
103
174
|
Args:
|
|
104
175
|
call_type: Type of API call (e.g., 'request', 'response', 'error')
|
|
105
|
-
request: The API request data
|
|
106
|
-
response: Optional API response data
|
|
176
|
+
request: The API request data in provider-specific format
|
|
177
|
+
response: Optional API response data in provider-specific format
|
|
107
178
|
error: Optional error information
|
|
108
179
|
"""
|
|
109
180
|
if self.experiment_manager:
|
|
@@ -129,116 +200,3 @@ class BaseLoop(ABC):
|
|
|
129
200
|
"""
|
|
130
201
|
if self.experiment_manager:
|
|
131
202
|
self.experiment_manager.save_screenshot(img_base64, action_type)
|
|
132
|
-
|
|
133
|
-
async def initialize(self) -> None:
|
|
134
|
-
"""Initialize both the API client and computer interface with retries."""
|
|
135
|
-
for attempt in range(self.max_retries):
|
|
136
|
-
try:
|
|
137
|
-
logger.info(
|
|
138
|
-
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
|
|
139
|
-
)
|
|
140
|
-
|
|
141
|
-
# Initialize API client
|
|
142
|
-
await self.initialize_client()
|
|
143
|
-
|
|
144
|
-
# Initialize computer
|
|
145
|
-
await self.computer.initialize()
|
|
146
|
-
|
|
147
|
-
logger.info("Initialization complete.")
|
|
148
|
-
return
|
|
149
|
-
except Exception as e:
|
|
150
|
-
if attempt < self.max_retries - 1:
|
|
151
|
-
logger.warning(
|
|
152
|
-
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
|
|
153
|
-
)
|
|
154
|
-
await asyncio.sleep(self.retry_delay)
|
|
155
|
-
else:
|
|
156
|
-
logger.error(
|
|
157
|
-
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
|
|
158
|
-
)
|
|
159
|
-
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
|
160
|
-
|
|
161
|
-
async def _get_parsed_screen_som(self) -> Dict[str, Any]:
|
|
162
|
-
"""Get parsed screen information.
|
|
163
|
-
|
|
164
|
-
Returns:
|
|
165
|
-
Dict containing screen information
|
|
166
|
-
"""
|
|
167
|
-
try:
|
|
168
|
-
# Take screenshot
|
|
169
|
-
screenshot = await self.computer.interface.screenshot()
|
|
170
|
-
|
|
171
|
-
# Initialize with default values
|
|
172
|
-
width, height = 1024, 768
|
|
173
|
-
base64_image = ""
|
|
174
|
-
|
|
175
|
-
# Handle different types of screenshot returns
|
|
176
|
-
if isinstance(screenshot, bytes):
|
|
177
|
-
# Raw bytes screenshot
|
|
178
|
-
base64_image = base64.b64encode(screenshot).decode("utf-8")
|
|
179
|
-
elif hasattr(screenshot, "base64_image"):
|
|
180
|
-
# Object-style screenshot with attributes
|
|
181
|
-
base64_image = screenshot.base64_image
|
|
182
|
-
if hasattr(screenshot, "width") and hasattr(screenshot, "height"):
|
|
183
|
-
width = screenshot.width
|
|
184
|
-
height = screenshot.height
|
|
185
|
-
|
|
186
|
-
# Create parsed screen data
|
|
187
|
-
parsed_screen = {
|
|
188
|
-
"width": width,
|
|
189
|
-
"height": height,
|
|
190
|
-
"parsed_content_list": [],
|
|
191
|
-
"timestamp": datetime.now().isoformat(),
|
|
192
|
-
"screenshot_base64": base64_image,
|
|
193
|
-
}
|
|
194
|
-
|
|
195
|
-
# Save screenshot if requested
|
|
196
|
-
if self.save_trajectory and self.experiment_manager:
|
|
197
|
-
try:
|
|
198
|
-
img_data = base64_image
|
|
199
|
-
if "," in img_data:
|
|
200
|
-
img_data = img_data.split(",")[1]
|
|
201
|
-
self._save_screenshot(img_data, action_type="state")
|
|
202
|
-
except Exception as e:
|
|
203
|
-
logger.error(f"Error saving screenshot: {str(e)}")
|
|
204
|
-
|
|
205
|
-
return parsed_screen
|
|
206
|
-
except Exception as e:
|
|
207
|
-
logger.error(f"Error taking screenshot: {str(e)}")
|
|
208
|
-
return {
|
|
209
|
-
"width": 1024,
|
|
210
|
-
"height": 768,
|
|
211
|
-
"parsed_content_list": [],
|
|
212
|
-
"timestamp": datetime.now().isoformat(),
|
|
213
|
-
"error": f"Error taking screenshot: {str(e)}",
|
|
214
|
-
"screenshot_base64": "",
|
|
215
|
-
}
|
|
216
|
-
|
|
217
|
-
@abstractmethod
|
|
218
|
-
async def initialize_client(self) -> None:
|
|
219
|
-
"""Initialize the API client and any provider-specific components."""
|
|
220
|
-
raise NotImplementedError
|
|
221
|
-
|
|
222
|
-
@abstractmethod
|
|
223
|
-
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
|
|
224
|
-
"""Run the agent loop with provided messages.
|
|
225
|
-
|
|
226
|
-
Args:
|
|
227
|
-
messages: List of message objects
|
|
228
|
-
|
|
229
|
-
Yields:
|
|
230
|
-
Dict containing response data
|
|
231
|
-
"""
|
|
232
|
-
raise NotImplementedError
|
|
233
|
-
|
|
234
|
-
@abstractmethod
|
|
235
|
-
async def _process_screen(
|
|
236
|
-
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
|
|
237
|
-
) -> None:
|
|
238
|
-
"""Process screen information and add to messages.
|
|
239
|
-
|
|
240
|
-
Args:
|
|
241
|
-
parsed_screen: Dictionary containing parsed screen info
|
|
242
|
-
messages: List of messages to update
|
|
243
|
-
"""
|
|
244
|
-
raise NotImplementedError
|