cua-agent 0.3.2__py3-none-any.whl → 0.4.0b1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +15 -51
- agent/__main__.py +21 -0
- agent/adapters/__init__.py +9 -0
- agent/adapters/huggingfacelocal_adapter.py +216 -0
- agent/agent.py +577 -0
- agent/callbacks/__init__.py +17 -0
- agent/callbacks/base.py +153 -0
- agent/callbacks/budget_manager.py +44 -0
- agent/callbacks/image_retention.py +139 -0
- agent/callbacks/logging.py +247 -0
- agent/callbacks/pii_anonymization.py +259 -0
- agent/callbacks/trajectory_saver.py +305 -0
- agent/cli.py +290 -0
- agent/computer_handler.py +107 -0
- agent/decorators.py +90 -0
- agent/loops/__init__.py +11 -0
- agent/loops/anthropic.py +728 -0
- agent/loops/omniparser.py +339 -0
- agent/loops/openai.py +95 -0
- agent/loops/uitars.py +688 -0
- agent/responses.py +207 -0
- agent/types.py +79 -0
- agent/ui/__init__.py +7 -1
- agent/ui/gradio/__init__.py +6 -19
- agent/ui/gradio/app.py +80 -1299
- agent/ui/gradio/ui_components.py +703 -0
- cua_agent-0.4.0b1.dist-info/METADATA +424 -0
- cua_agent-0.4.0b1.dist-info/RECORD +30 -0
- agent/core/__init__.py +0 -27
- agent/core/agent.py +0 -210
- agent/core/base.py +0 -217
- agent/core/callbacks.py +0 -200
- agent/core/experiment.py +0 -249
- agent/core/factory.py +0 -122
- agent/core/messages.py +0 -332
- agent/core/provider_config.py +0 -21
- agent/core/telemetry.py +0 -142
- agent/core/tools/__init__.py +0 -21
- agent/core/tools/base.py +0 -74
- agent/core/tools/bash.py +0 -52
- agent/core/tools/collection.py +0 -46
- agent/core/tools/computer.py +0 -113
- agent/core/tools/edit.py +0 -67
- agent/core/tools/manager.py +0 -56
- agent/core/tools.py +0 -32
- agent/core/types.py +0 -88
- agent/core/visualization.py +0 -197
- agent/providers/__init__.py +0 -4
- agent/providers/anthropic/__init__.py +0 -6
- agent/providers/anthropic/api/client.py +0 -360
- agent/providers/anthropic/api/logging.py +0 -150
- agent/providers/anthropic/api_handler.py +0 -140
- agent/providers/anthropic/callbacks/__init__.py +0 -5
- agent/providers/anthropic/callbacks/manager.py +0 -65
- agent/providers/anthropic/loop.py +0 -568
- agent/providers/anthropic/prompts.py +0 -23
- agent/providers/anthropic/response_handler.py +0 -226
- agent/providers/anthropic/tools/__init__.py +0 -33
- agent/providers/anthropic/tools/base.py +0 -88
- agent/providers/anthropic/tools/bash.py +0 -66
- agent/providers/anthropic/tools/collection.py +0 -34
- agent/providers/anthropic/tools/computer.py +0 -396
- agent/providers/anthropic/tools/edit.py +0 -326
- agent/providers/anthropic/tools/manager.py +0 -54
- agent/providers/anthropic/tools/run.py +0 -42
- agent/providers/anthropic/types.py +0 -16
- agent/providers/anthropic/utils.py +0 -381
- agent/providers/omni/__init__.py +0 -8
- agent/providers/omni/api_handler.py +0 -42
- agent/providers/omni/clients/anthropic.py +0 -103
- agent/providers/omni/clients/base.py +0 -35
- agent/providers/omni/clients/oaicompat.py +0 -195
- agent/providers/omni/clients/ollama.py +0 -122
- agent/providers/omni/clients/openai.py +0 -155
- agent/providers/omni/clients/utils.py +0 -25
- agent/providers/omni/image_utils.py +0 -34
- agent/providers/omni/loop.py +0 -990
- agent/providers/omni/parser.py +0 -307
- agent/providers/omni/prompts.py +0 -64
- agent/providers/omni/tools/__init__.py +0 -30
- agent/providers/omni/tools/base.py +0 -29
- agent/providers/omni/tools/bash.py +0 -74
- agent/providers/omni/tools/computer.py +0 -179
- agent/providers/omni/tools/manager.py +0 -61
- agent/providers/omni/utils.py +0 -236
- agent/providers/openai/__init__.py +0 -6
- agent/providers/openai/api_handler.py +0 -456
- agent/providers/openai/loop.py +0 -472
- agent/providers/openai/response_handler.py +0 -205
- agent/providers/openai/tools/__init__.py +0 -15
- agent/providers/openai/tools/base.py +0 -79
- agent/providers/openai/tools/computer.py +0 -326
- agent/providers/openai/tools/manager.py +0 -106
- agent/providers/openai/types.py +0 -36
- agent/providers/openai/utils.py +0 -98
- agent/providers/uitars/__init__.py +0 -1
- agent/providers/uitars/clients/base.py +0 -35
- agent/providers/uitars/clients/mlxvlm.py +0 -263
- agent/providers/uitars/clients/oaicompat.py +0 -214
- agent/providers/uitars/loop.py +0 -660
- agent/providers/uitars/prompts.py +0 -63
- agent/providers/uitars/tools/__init__.py +0 -1
- agent/providers/uitars/tools/computer.py +0 -283
- agent/providers/uitars/tools/manager.py +0 -60
- agent/providers/uitars/utils.py +0 -264
- agent/telemetry.py +0 -21
- agent/ui/__main__.py +0 -15
- cua_agent-0.3.2.dist-info/METADATA +0 -295
- cua_agent-0.3.2.dist-info/RECORD +0 -87
- {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b1.dist-info}/WHEEL +0 -0
- {cua_agent-0.3.2.dist-info → cua_agent-0.4.0b1.dist-info}/entry_points.txt +0 -0
agent/core/base.py
DELETED
|
@@ -1,217 +0,0 @@
|
|
|
1
|
-
"""Base loop definitions."""
|
|
2
|
-
|
|
3
|
-
import logging
|
|
4
|
-
import asyncio
|
|
5
|
-
from abc import ABC, abstractmethod
|
|
6
|
-
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
7
|
-
|
|
8
|
-
from computer import Computer
|
|
9
|
-
from .messages import StandardMessageManager, ImageRetentionConfig
|
|
10
|
-
from .types import AgentResponse
|
|
11
|
-
from .experiment import ExperimentManager
|
|
12
|
-
from .callbacks import CallbackManager, CallbackHandler
|
|
13
|
-
|
|
14
|
-
logger = logging.getLogger(__name__)
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class BaseLoop(ABC):
    """Base class for agent loops that handle message processing and tool execution.

    Subclasses provide the provider-specific pieces (``initialize_client``,
    ``run`` and ``cancel``). This base class stores the shared configuration,
    creates the message, experiment and callback managers, and offers thin
    helpers that forward to the experiment manager when one is configured.
    """

    def __init__(
        self,
        computer: Computer,
        model: str,
        api_key: str,
        max_tokens: int = 4096,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        base_dir: Optional[str] = "trajectories",
        save_trajectory: bool = True,
        only_n_most_recent_images: Optional[int] = 2,
        callback_handlers: Optional[List[CallbackHandler]] = None,
        **kwargs,
    ):
        """Initialize base agent loop.

        Args:
            computer: Computer instance to control
            model: Model name to use
            api_key: API key for provider
            max_tokens: Maximum tokens to generate
            max_retries: Maximum number of initialization attempts
            retry_delay: Delay between retries in seconds
            base_dir: Base directory for saving experiment data
            save_trajectory: Whether to save trajectory data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            callback_handlers: Optional handlers registered with the callback manager
            **kwargs: Additional provider-specific arguments
        """
        self.computer = computer
        self.model = model
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.base_dir = base_dir
        self.save_trajectory = save_trajectory
        self.only_n_most_recent_images = only_n_most_recent_images
        self._kwargs = kwargs

        # Initialize message manager
        self.message_manager = StandardMessageManager(
            config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
        )

        # Initialize experiment manager (only when trajectories are saved and
        # a base directory is available)
        if self.save_trajectory and self.base_dir:
            self.experiment_manager = ExperimentManager(
                base_dir=self.base_dir,
                only_n_most_recent_images=only_n_most_recent_images,
            )
            # Track directories for convenience
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir
        else:
            self.experiment_manager = None
            self.run_dir = None
            self.current_turn_dir = None

        # Initialize basic tracking
        self.turn_count = 0

        # Initialize callback manager
        self.callback_manager = CallbackManager(handlers=callback_handlers or [])

    async def initialize(self) -> None:
        """Initialize the API client, retrying on failure.

        Raises:
            RuntimeError: If every attempt fails; the last underlying
                exception is attached as ``__cause__``.
        """
        # Always make at least one attempt: with max_retries <= 0 the loop
        # body would never run and the method would return as if the client
        # had been initialized.
        attempts = max(1, self.max_retries)
        for attempt in range(attempts):
            try:
                logger.info(
                    f"Starting initialization (attempt {attempt + 1}/{attempts})..."
                )

                # Initialize API client
                await self.initialize_client()

                logger.info("Initialization complete.")
                return
            except Exception as e:
                if attempt < attempts - 1:
                    logger.warning(
                        f"Initialization failed (attempt {attempt + 1}/{attempts}): {str(e)}. Retrying..."
                    )
                    await asyncio.sleep(self.retry_delay)
                else:
                    logger.error(
                        f"Initialization failed after {attempts} attempts: {str(e)}"
                    )
                    # Chain the cause so the original traceback is preserved.
                    raise RuntimeError(f"Failed to initialize: {str(e)}") from e

    ###########################################
    # ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
    ###########################################

    @abstractmethod
    async def initialize_client(self) -> None:
        """Initialize the API client and any provider-specific components.

        This method must be implemented by subclasses to set up
        provider-specific clients and tools.
        """
        raise NotImplementedError

    @abstractmethod
    def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
        """Run the agent loop with provided messages.

        Args:
            messages: List of message objects

        Returns:
            An async generator that yields agent responses
        """
        raise NotImplementedError

    @abstractmethod
    async def cancel(self) -> None:
        """Cancel the currently running agent loop task.

        This method should stop any ongoing processing in the agent loop
        and clean up resources appropriately.
        """
        raise NotImplementedError

    ###########################################
    # EXPERIMENT AND TRAJECTORY MANAGEMENT
    ###########################################

    def _setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure (no-op without an experiment manager)."""
        if self.experiment_manager:
            # Use the experiment manager to set up directories
            self.experiment_manager.setup_experiment_dirs()

            # Update local tracking variables
            self.run_dir = self.experiment_manager.run_dir
            self.current_turn_dir = self.experiment_manager.current_turn_dir

    def _create_turn_dir(self) -> None:
        """Create a new directory for the current turn (no-op without an experiment manager)."""
        if self.experiment_manager:
            # Use the experiment manager to create the turn directory
            self.experiment_manager.create_turn_dir()

            # Update local tracking variables
            self.current_turn_dir = self.experiment_manager.current_turn_dir
            self.turn_count = self.experiment_manager.turn_count

    def _log_api_call(
        self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
    ) -> None:
        """Log API call details to file.

        Preserves provider-specific formats for requests and responses to ensure
        accurate logging for debugging and analysis purposes.

        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data in provider-specific format
            response: Optional API response data in provider-specific format
            error: Optional error information
        """
        if self.experiment_manager:
            # Subclasses may define ``provider``; fall back to "unknown" when
            # the attribute is missing or falsy.
            provider = getattr(self, "provider", "unknown")
            provider_str = str(provider) if provider else "unknown"

            self.experiment_manager.log_api_call(
                call_type=call_type,
                request=request,
                provider=provider_str,
                model=self.model,
                response=response,
                error=error,
            )

    def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
        """
        if self.experiment_manager:
            self.experiment_manager.save_screenshot(img_base64, action_type)

    ###########################################
    # EVENT HOOKS / CALLBACKS
    ###########################################

    async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
        """Process a screenshot through callback managers

        Args:
            screenshot_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
            parsed_screen: Optional output from parsing the screenshot
        """
        # Guard against subclasses that did not run BaseLoop.__init__.
        if hasattr(self, 'callback_manager'):
            await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
agent/core/callbacks.py
DELETED
|
@@ -1,200 +0,0 @@
|
|
|
1
|
-
"""Callback handlers for agent."""
|
|
2
|
-
|
|
3
|
-
import json
|
|
4
|
-
import logging
|
|
5
|
-
from abc import ABC, abstractmethod
|
|
6
|
-
from datetime import datetime
|
|
7
|
-
from typing import Any, Dict, List, Optional, Protocol
|
|
8
|
-
|
|
9
|
-
logger = logging.getLogger(__name__)
|
|
10
|
-
|
|
11
|
-
class ContentCallback(Protocol):
    """Structural type for callables that receive a content dictionary."""

    def __call__(self, content: Dict[str, Any]) -> None:
        ...
|
|
14
|
-
|
|
15
|
-
class ToolCallback(Protocol):
    """Structural type for callables that receive a tool result and its tool id."""

    def __call__(self, result: Any, tool_id: str) -> None:
        ...
|
|
18
|
-
|
|
19
|
-
class APICallback(Protocol):
    """Structural type for callables that receive an API request, its response and an optional error."""

    def __call__(
        self,
        request: Any,
        response: Any,
        error: Optional[Exception] = None,
    ) -> None:
        ...
|
|
22
|
-
|
|
23
|
-
class ScreenshotCallback(Protocol):
    """Structural type for callables that receive a base64 screenshot and may return a string."""

    def __call__(
        self,
        screenshot_base64: str,
        action_type: str = "",
    ) -> Optional[str]:
        ...
|
|
26
|
-
|
|
27
|
-
class BaseCallbackManager(ABC):
    """Abstract holder for the content, tool and API callbacks.

    Concrete managers must implement ``on_content``, ``on_tool_result`` and
    ``on_api_interaction``.
    """

    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Store the three callbacks on the instance.

        Args:
            content_callback: Callback for content updates
            tool_callback: Callback for tool execution results
            api_callback: Callback for API interactions
        """
        self.content_callback, self.tool_callback, self.api_callback = (
            content_callback,
            tool_callback,
            api_callback,
        )

    @abstractmethod
    def on_content(self, content: Any) -> None:
        """Handle a content update; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def on_tool_result(self, result: Any, tool_id: str) -> None:
        """Handle a tool execution result; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def on_api_interaction(
        self,
        request: Any,
        response: Any,
        error: Optional[Exception] = None
    ) -> None:
        """Handle an API interaction; must be overridden."""
        raise NotImplementedError
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
class CallbackManager:
    """Fans each event out to every registered callback handler, in order."""

    def __init__(self, handlers: Optional[List["CallbackHandler"]] = None):
        """Initialize with optional handlers.

        Args:
            handlers: List of callback handlers; a non-empty list is stored
                as-is (not copied), anything falsy becomes a new empty list
        """
        self.handlers = handlers if handlers else []

    def add_handler(self, handler: "CallbackHandler") -> None:
        """Register one more handler at the end of the dispatch order.

        Args:
            handler: Callback handler to add
        """
        self.handlers.append(handler)

    async def _dispatch(self, hook: str, *args, **kwargs) -> None:
        """Await the method named ``hook`` on each handler, one after another."""
        for registered in self.handlers:
            await getattr(registered, hook)(*args, **kwargs)

    async def on_action_start(self, action: str, **kwargs) -> None:
        """Notify every handler that an action is starting.

        Args:
            action: Action name
            **kwargs: Additional data
        """
        await self._dispatch("on_action_start", action, **kwargs)

    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Notify every handler that an action has finished.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data
        """
        await self._dispatch("on_action_end", action, success, **kwargs)

    async def on_error(self, error: Exception, **kwargs) -> None:
        """Notify every handler that an error occurred.

        Args:
            error: Exception that occurred
            **kwargs: Additional data
        """
        await self._dispatch("on_error", error, **kwargs)

    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
        """Notify every handler that a screenshot was taken.

        The three arguments are passed to each handler positionally; nothing
        is returned.

        Args:
            screenshot_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
            parsed_screen: Optional output from parsing the screenshot
        """
        await self._dispatch("on_screenshot", screenshot_base64, action_type, parsed_screen)
|
|
131
|
-
|
|
132
|
-
class CallbackHandler(ABC):
    """Abstract interface for objects that receive agent events.

    All four hooks are abstract coroutines; a concrete handler must define
    every one of them before it can be instantiated.
    """

    @abstractmethod
    async def on_action_start(self, action: str, **kwargs) -> None:
        """Hook invoked when an action starts.

        Args:
            action: Action name
            **kwargs: Additional data
        """

    @abstractmethod
    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Hook invoked when an action ends.

        Args:
            action: Action name
            success: Whether the action was successful
            **kwargs: Additional data
        """

    @abstractmethod
    async def on_error(self, error: Exception, **kwargs) -> None:
        """Hook invoked when an error occurs.

        Args:
            error: Exception that occurred
            **kwargs: Additional data
        """

    @abstractmethod
    async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
        """Hook invoked when a screenshot is taken.

        Args:
            screenshot_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
            parsed_screen: Optional output from parsing the screenshot
        """
|
|
178
|
-
|
|
179
|
-
class DefaultCallbackHandler(CallbackHandler):
    """Default implementation of CallbackHandler with no-op methods.

    This class implements all abstract methods from CallbackHandler,
    allowing subclasses to override only the methods they need.
    """

    async def on_action_start(self, action: str, **kwargs) -> None:
        """Default no-op implementation."""
        pass

    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        """Default no-op implementation."""
        pass

    async def on_error(self, error: Exception, **kwargs) -> None:
        """Default no-op implementation."""
        pass

    # The signature must accept ``parsed_screen``: the abstract base declares
    # it and CallbackManager.on_screenshot passes it as a third positional
    # argument, so omitting it raises TypeError on every dispatch.
    async def on_screenshot(
        self,
        screenshot_base64: str,
        action_type: str = "",
        parsed_screen: Optional[dict] = None,
    ) -> None:
        """Default no-op implementation.

        Args:
            screenshot_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
            parsed_screen: Optional output from parsing the screenshot
        """
        pass
|
agent/core/experiment.py
DELETED
|
@@ -1,249 +0,0 @@
|
|
|
1
|
-
"""Core experiment management for agents."""
|
|
2
|
-
|
|
3
|
-
import os
|
|
4
|
-
import logging
|
|
5
|
-
import base64
|
|
6
|
-
from io import BytesIO
|
|
7
|
-
from datetime import datetime
|
|
8
|
-
from typing import Any, Dict, List, Optional
|
|
9
|
-
from PIL import Image
|
|
10
|
-
import json
|
|
11
|
-
import re
|
|
12
|
-
|
|
13
|
-
logger = logging.getLogger(__name__)
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    Layout on disk: ``<base_dir>/<YYYYmmdd_HHMMSS>/turn_NNN/`` holding the
    screenshots, action visualizations and API-call JSON logs of each turn.
    """

    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.

        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        self.run_dir = None
        self.current_turn_dir = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Track all screenshots for potential API request inclusion
        self.screenshot_paths = []

        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()

    @staticmethod
    def _safe_filename_part(text: str) -> str:
        """Replace characters that are not safe in filenames with underscores."""
        return re.sub(r'[\\/*?:"<>|]', "_", text)

    def setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure.

        Creates the base directory, a timestamped run directory and the first
        turn directory. Does nothing when no base directory is configured.
        """
        if not self.base_dir:
            return

        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)

        # Create timestamped run directory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.run_dir = os.path.join(self.base_dir, timestamp)
        os.makedirs(self.run_dir, exist_ok=True)
        logger.info(f"Created run directory: {self.run_dir}")

        # Create first turn directory
        self.create_turn_dir()

    def create_turn_dir(self) -> None:
        """Create a new directory for the current turn and advance the turn counter."""
        if not self.run_dir:
            logger.warning("Cannot create turn directory: run_dir not set")
            return

        # Increment turn counter
        self.turn_count += 1

        # Create turn directory with padded number
        turn_name = f"turn_{self.turn_count:03d}"
        self.current_turn_dir = os.path.join(self.run_dir, turn_name)
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")

    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize log data by replacing large binary data with placeholders.

        Dicts and lists are rebuilt recursively; every other value is returned
        as-is unless it is a long string that mentions "base64".

        Args:
            data: Data to sanitize

        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            result = {}
            for k, v in data.items():
                # Special handling for 'data' field in Anthropic message source
                if k == "data" and isinstance(v, str) and len(v) > 1000:
                    result[k] = f"[BASE64_DATA_LENGTH_{len(v)}]"
                # An image 'media_type' value is kept verbatim. A sibling
                # 'data' field needs no second look here: the branch above
                # replaces any long 'data' string wherever it sits in the dict.
                elif k == "media_type" and "image" in str(v):
                    result[k] = v
                else:
                    result[k] = self.sanitize_log_data(v)
            return result
        elif isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        elif isinstance(data, str) and len(data) > 1000 and "base64" in data.lower():
            return f"[BASE64_DATA_LENGTH_{len(data)}]"
        else:
            return data

    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.

        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot

        Returns:
            Path to the saved screenshot or None if there was an error
        """
        if not self.current_turn_dir:
            return None

        try:
            # Increment screenshot counter
            self.screenshot_count += 1

            # Sanitize action_type to ensure a valid filename, and limit the
            # length to avoid excessively long filenames
            sanitized_action = ""
            if action_type:
                sanitized_action = self._safe_filename_part(action_type)[:50]

            # Create a descriptive filename
            timestamp = int(datetime.now().timestamp() * 1000)
            action_suffix = f"_{sanitized_action}" if sanitized_action else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"

            # Save directly to the turn directory
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the screenshot
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)

            # Keep track of the file path
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving screenshot: {str(e)}")
            return None

    def save_action_visualization(
        self, img: "Image.Image", action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Args:
            img: Image to save
            action_name: Name of the action
            details: Additional details about the action

        Returns:
            Path to the saved image, or an empty string if nothing was saved
        """
        if not self.current_turn_dir:
            return ""

        try:
            # Create a descriptive filename. Both free-text parts are
            # sanitized the same way save_screenshot sanitizes action_type;
            # otherwise a '/' or ':' would point the path at a directory that
            # does not exist and the save would fail.
            timestamp = int(datetime.now().timestamp() * 1000)
            safe_action = self._safe_filename_part(action_name)
            details_suffix = f"_{self._safe_filename_part(details)}" if details else ""
            filename = f"vis_{safe_action}{details_suffix}_{timestamp}.png"

            # Save directly to the turn directory
            filepath = os.path.join(self.current_turn_dir, filename)

            # Save the image
            img.save(filepath)

            # Keep track of the file path
            self.screenshot_paths.append(filepath)

            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""

    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str = "unknown",
        model: str = "unknown",
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to file.

        Args:
            call_type: Type of API call (request, response, error)
            request: Request data
            provider: API provider name
            model: Model name
            response: Response data (for response logs)
            error: Error information (for error logs)
        """
        if not self.current_turn_dir:
            logger.warning("Cannot log API call: current_turn_dir not set")
            return

        try:
            # Create a timestamp for the log file
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Create filename based on log type
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = os.path.join(self.current_turn_dir, filename)

            # The timestamp has one-second resolution, so two calls of the
            # same type within a second would overwrite each other. Keep the
            # usual name when it is free and add a counter otherwise.
            duplicate = 1
            while os.path.exists(filepath):
                duplicate += 1
                filepath = os.path.join(
                    self.current_turn_dir,
                    f"api_call_{timestamp}_{call_type}_{duplicate}.json",
                )

            # Sanitize data before logging
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None

            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }

            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)

            # Write to file; default=str stringifies anything JSON cannot encode
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)

            logger.info(f"Logged API {call_type} to {filepath}")

        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")
|