cua-agent 0.4.33__py3-none-any.whl → 0.4.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cua-agent might be problematic. Click here for more details.
- agent/__init__.py +4 -10
- agent/__main__.py +2 -1
- agent/adapters/huggingfacelocal_adapter.py +54 -61
- agent/adapters/human_adapter.py +116 -114
- agent/adapters/mlxvlm_adapter.py +110 -99
- agent/adapters/models/__init__.py +14 -6
- agent/adapters/models/generic.py +7 -4
- agent/adapters/models/internvl.py +66 -30
- agent/adapters/models/opencua.py +23 -8
- agent/adapters/models/qwen2_5_vl.py +7 -4
- agent/agent.py +184 -158
- agent/callbacks/__init__.py +4 -4
- agent/callbacks/base.py +45 -31
- agent/callbacks/budget_manager.py +22 -10
- agent/callbacks/image_retention.py +18 -13
- agent/callbacks/logging.py +55 -42
- agent/callbacks/operator_validator.py +3 -1
- agent/callbacks/pii_anonymization.py +19 -16
- agent/callbacks/telemetry.py +67 -61
- agent/callbacks/trajectory_saver.py +90 -70
- agent/cli.py +115 -110
- agent/computers/__init__.py +13 -8
- agent/computers/base.py +26 -17
- agent/computers/cua.py +27 -23
- agent/computers/custom.py +72 -69
- agent/decorators.py +23 -14
- agent/human_tool/__init__.py +2 -7
- agent/human_tool/__main__.py +6 -2
- agent/human_tool/server.py +48 -37
- agent/human_tool/ui.py +235 -185
- agent/integrations/hud/__init__.py +15 -21
- agent/integrations/hud/agent.py +101 -83
- agent/integrations/hud/proxy.py +90 -57
- agent/loops/__init__.py +25 -21
- agent/loops/anthropic.py +537 -483
- agent/loops/base.py +13 -14
- agent/loops/composed_grounded.py +135 -149
- agent/loops/gemini.py +31 -12
- agent/loops/glm45v.py +135 -133
- agent/loops/gta1.py +47 -50
- agent/loops/holo.py +4 -2
- agent/loops/internvl.py +6 -11
- agent/loops/moondream3.py +49 -20
- agent/loops/omniparser.py +212 -209
- agent/loops/openai.py +49 -50
- agent/loops/opencua.py +29 -41
- agent/loops/qwen.py +475 -0
- agent/loops/uitars.py +237 -202
- agent/proxy/examples.py +54 -50
- agent/proxy/handlers.py +27 -34
- agent/responses.py +330 -330
- agent/types.py +11 -5
- agent/ui/__init__.py +1 -1
- agent/ui/__main__.py +1 -1
- agent/ui/gradio/app.py +23 -18
- agent/ui/gradio/ui_components.py +310 -161
- {cua_agent-0.4.33.dist-info → cua_agent-0.4.35.dist-info}/METADATA +22 -10
- cua_agent-0.4.35.dist-info/RECORD +64 -0
- cua_agent-0.4.33.dist-info/RECORD +0 -63
- {cua_agent-0.4.33.dist-info → cua_agent-0.4.35.dist-info}/WHEEL +0 -0
- {cua_agent-0.4.33.dist-info → cua_agent-0.4.35.dist-info}/entry_points.txt +0 -0
agent/callbacks/__init__.py
CHANGED
|
@@ -3,17 +3,17 @@ Callback system for ComputerAgent preprocessing and postprocessing hooks.
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from .base import AsyncCallbackHandler
|
|
6
|
+
from .budget_manager import BudgetManagerCallback
|
|
6
7
|
from .image_retention import ImageRetentionCallback
|
|
7
8
|
from .logging import LoggingCallback
|
|
8
|
-
from .trajectory_saver import TrajectorySaverCallback
|
|
9
|
-
from .budget_manager import BudgetManagerCallback
|
|
10
|
-
from .telemetry import TelemetryCallback
|
|
11
9
|
from .operator_validator import OperatorNormalizerCallback
|
|
12
10
|
from .prompt_instructions import PromptInstructionsCallback
|
|
11
|
+
from .telemetry import TelemetryCallback
|
|
12
|
+
from .trajectory_saver import TrajectorySaverCallback
|
|
13
13
|
|
|
14
14
|
__all__ = [
|
|
15
15
|
"AsyncCallbackHandler",
|
|
16
|
-
"ImageRetentionCallback",
|
|
16
|
+
"ImageRetentionCallback",
|
|
17
17
|
"LoggingCallback",
|
|
18
18
|
"TrajectorySaverCallback",
|
|
19
19
|
"BudgetManagerCallback",
|
agent/callbacks/base.py
CHANGED
|
@@ -3,7 +3,7 @@ Base callback handler interface for ComputerAgent preprocessing and postprocessi
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from abc import ABC, abstractmethod
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import Any, Dict, List, Optional, Union
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class AsyncCallbackHandler(ABC):
|
|
@@ -16,42 +16,52 @@ class AsyncCallbackHandler(ABC):
|
|
|
16
16
|
"""Called at the start of an agent run loop."""
|
|
17
17
|
pass
|
|
18
18
|
|
|
19
|
-
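async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None: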
|
|
19
|
+
async def on_run_end(
|
|
20
|
+
self,
|
|
21
|
+
kwargs: Dict[str, Any],
|
|
22
|
+
old_items: List[Dict[str, Any]],
|
|
23
|
+
new_items: List[Dict[str, Any]],
|
|
24
|
+
) -> None:
|
|
20
25
|
"""Called at the end of an agent run loop."""
|
|
21
26
|
pass
|
|
22
|
-
|
|
23
|
-
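async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool: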
|
|
27
|
+
|
|
28
|
+
async def on_run_continue(
|
|
29
|
+
self,
|
|
30
|
+
kwargs: Dict[str, Any],
|
|
31
|
+
old_items: List[Dict[str, Any]],
|
|
32
|
+
new_items: List[Dict[str, Any]],
|
|
33
|
+
) -> bool:
|
|
24
34
|
"""Called during agent run loop to determine if execution should continue.
|
|
25
|
-
|
|
35
|
+
|
|
26
36
|
Args:
|
|
27
37
|
kwargs: Run arguments
|
|
28
38
|
old_items: Original messages
|
|
29
39
|
new_items: New messages generated during run
|
|
30
|
-
|
|
40
|
+
|
|
31
41
|
Returns:
|
|
32
42
|
True to continue execution, False to stop
|
|
33
43
|
"""
|
|
34
44
|
return True
|
|
35
|
-
|
|
45
|
+
|
|
36
46
|
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
37
47
|
"""
|
|
38
48
|
Called before messages are sent to the agent loop.
|
|
39
|
-
|
|
49
|
+
|
|
40
50
|
Args:
|
|
41
51
|
messages: List of message dictionaries to preprocess
|
|
42
|
-
|
|
52
|
+
|
|
43
53
|
Returns:
|
|
44
54
|
List of preprocessed message dictionaries
|
|
45
55
|
"""
|
|
46
56
|
return messages
|
|
47
|
-
|
|
57
|
+
|
|
48
58
|
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
49
59
|
"""
|
|
50
60
|
Called after the agent loop returns output.
|
|
51
|
-
|
|
61
|
+
|
|
52
62
|
Args:
|
|
53
63
|
output: List of output message dictionaries to postprocess
|
|
54
|
-
|
|
64
|
+
|
|
55
65
|
Returns:
|
|
56
66
|
List of postprocessed output dictionaries
|
|
57
67
|
"""
|
|
@@ -60,63 +70,67 @@ class AsyncCallbackHandler(ABC):
|
|
|
60
70
|
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
|
|
61
71
|
"""
|
|
62
72
|
Called when a computer call is about to start.
|
|
63
|
-
|
|
73
|
+
|
|
64
74
|
Args:
|
|
65
75
|
item: The computer call item dictionary
|
|
66
76
|
"""
|
|
67
77
|
pass
|
|
68
|
-
|
|
69
|
-
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
|
78
|
+
|
|
79
|
+
async def on_computer_call_end(
|
|
80
|
+
self, item: Dict[str, Any], result: List[Dict[str, Any]]
|
|
81
|
+
) -> None:
|
|
70
82
|
"""
|
|
71
83
|
Called when a computer call has completed.
|
|
72
|
-
|
|
84
|
+
|
|
73
85
|
Args:
|
|
74
86
|
item: The computer call item dictionary
|
|
75
87
|
result: The result of the computer call
|
|
76
88
|
"""
|
|
77
89
|
pass
|
|
78
|
-
|
|
90
|
+
|
|
79
91
|
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
|
|
80
92
|
"""
|
|
81
93
|
Called when a function call is about to start.
|
|
82
|
-
|
|
94
|
+
|
|
83
95
|
Args:
|
|
84
96
|
item: The function call item dictionary
|
|
85
97
|
"""
|
|
86
98
|
pass
|
|
87
|
-
|
|
88
|
-
async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
|
99
|
+
|
|
100
|
+
async def on_function_call_end(
|
|
101
|
+
self, item: Dict[str, Any], result: List[Dict[str, Any]]
|
|
102
|
+
) -> None:
|
|
89
103
|
"""
|
|
90
104
|
Called when a function call has completed.
|
|
91
|
-
|
|
105
|
+
|
|
92
106
|
Args:
|
|
93
107
|
item: The function call item dictionary
|
|
94
108
|
result: The result of the function call
|
|
95
109
|
"""
|
|
96
110
|
pass
|
|
97
|
-
|
|
111
|
+
|
|
98
112
|
async def on_text(self, item: Dict[str, Any]) -> None:
|
|
99
113
|
"""
|
|
100
114
|
Called when a text message is encountered.
|
|
101
|
-
|
|
115
|
+
|
|
102
116
|
Args:
|
|
103
117
|
item: The message item dictionary
|
|
104
118
|
"""
|
|
105
119
|
pass
|
|
106
|
-
|
|
120
|
+
|
|
107
121
|
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
|
108
122
|
"""
|
|
109
123
|
Called when an API call is about to start.
|
|
110
|
-
|
|
124
|
+
|
|
111
125
|
Args:
|
|
112
126
|
kwargs: The kwargs being passed to the API call
|
|
113
127
|
"""
|
|
114
128
|
pass
|
|
115
|
-
|
|
129
|
+
|
|
116
130
|
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
|
117
131
|
"""
|
|
118
132
|
Called when an API call has completed.
|
|
119
|
-
|
|
133
|
+
|
|
120
134
|
Args:
|
|
121
135
|
kwargs: The kwargs that were passed to the API call
|
|
122
136
|
result: The result of the API call
|
|
@@ -126,7 +140,7 @@ class AsyncCallbackHandler(ABC):
|
|
|
126
140
|
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
|
127
141
|
"""
|
|
128
142
|
Called when usage information is received.
|
|
129
|
-
|
|
143
|
+
|
|
130
144
|
Args:
|
|
131
145
|
usage: The usage information
|
|
132
146
|
"""
|
|
@@ -135,7 +149,7 @@ class AsyncCallbackHandler(ABC):
|
|
|
135
149
|
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
|
|
136
150
|
"""
|
|
137
151
|
Called when a screenshot is taken.
|
|
138
|
-
|
|
152
|
+
|
|
139
153
|
Args:
|
|
140
154
|
screenshot: The screenshot image
|
|
141
155
|
name: The name of the screenshot
|
|
@@ -145,9 +159,9 @@ class AsyncCallbackHandler(ABC):
|
|
|
145
159
|
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
|
|
146
160
|
"""
|
|
147
161
|
Called when responses are received.
|
|
148
|
-
|
|
162
|
+
|
|
149
163
|
Args:
|
|
150
164
|
kwargs: The kwargs being passed to the agent loop
|
|
151
165
|
responses: The responses received
|
|
152
166
|
"""
|
|
153
|
-
pass
|
|
167
|
+
pass
|
agent/callbacks/budget_manager.py
CHANGED
|
@@ -1,17 +1,23 @@
|
|
|
1
|
-
from typing import Dict, List
|
|
1
|
+
from typing import Any, Dict, List
|
|
2
|
+
|
|
2
3
|
from .base import AsyncCallbackHandler
|
|
3
4
|
|
|
5
|
+
|
|
4
6
|
class BudgetExceededError(Exception):
|
|
5
7
|
"""Exception raised when budget is exceeded."""
|
|
8
|
+
|
|
6
9
|
pass
|
|
7
10
|
|
|
11
|
+
|
|
8
12
|
class BudgetManagerCallback(AsyncCallbackHandler):
|
|
9
13
|
"""Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
|
|
10
|
-
|
|
11
|
-
def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
|
|
14
|
+
|
|
15
|
+
def __init__(
|
|
16
|
+
self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False
|
|
17
|
+
):
|
|
12
18
|
"""
|
|
13
19
|
Initialize BudgetManagerCallback.
|
|
14
|
-
|
|
20
|
+
|
|
15
21
|
Args:
|
|
16
22
|
max_budget: Maximum budget allowed
|
|
17
23
|
reset_after_each_run: Whether to reset budget after each run
|
|
@@ -21,24 +27,30 @@ class BudgetManagerCallback(AsyncCallbackHandler):
|
|
|
21
27
|
self.reset_after_each_run = reset_after_each_run
|
|
22
28
|
self.raise_error = raise_error
|
|
23
29
|
self.total_cost = 0.0
|
|
24
|
-
|
|
30
|
+
|
|
25
31
|
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
|
26
32
|
"""Reset budget if configured to do so."""
|
|
27
33
|
if self.reset_after_each_run:
|
|
28
34
|
self.total_cost = 0.0
|
|
29
|
-
|
|
35
|
+
|
|
30
36
|
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
|
31
37
|
"""Track usage costs."""
|
|
32
38
|
if "response_cost" in usage:
|
|
33
39
|
self.total_cost += usage["response_cost"]
|
|
34
|
-
|
|
35
|
-
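async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool: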
|
|
40
|
+
|
|
41
|
+
async def on_run_continue(
|
|
42
|
+
self,
|
|
43
|
+
kwargs: Dict[str, Any],
|
|
44
|
+
old_items: List[Dict[str, Any]],
|
|
45
|
+
new_items: List[Dict[str, Any]],
|
|
46
|
+
) -> bool:
|
|
36
47
|
"""Check if budget allows continuation."""
|
|
37
48
|
if self.total_cost >= self.max_budget:
|
|
38
49
|
if self.raise_error:
|
|
39
|
-
raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
|
|
50
|
+
raise BudgetExceededError(
|
|
51
|
+
f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}"
|
|
52
|
+
)
|
|
40
53
|
else:
|
|
41
54
|
print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
|
|
42
55
|
return False
|
|
43
56
|
return True
|
|
44
|
-
|
agent/callbacks/image_retention.py
CHANGED
|
@@ -2,7 +2,8 @@
|
|
|
2
2
|
Image retention callback handler that limits the number of recent images in message history.
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Any, Dict, List, Optional
|
|
6
|
+
|
|
6
7
|
from .base import AsyncCallbackHandler
|
|
7
8
|
|
|
8
9
|
|
|
@@ -11,40 +12,40 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
|
|
11
12
|
Callback handler that applies image retention policy to limit the number
|
|
12
13
|
of recent images in message history to prevent context window overflow.
|
|
13
14
|
"""
|
|
14
|
-
|
|
15
|
+
|
|
15
16
|
def __init__(self, only_n_most_recent_images: Optional[int] = None):
|
|
16
17
|
"""
|
|
17
18
|
Initialize the image retention callback.
|
|
18
|
-
|
|
19
|
+
|
|
19
20
|
Args:
|
|
20
21
|
only_n_most_recent_images: If set, only keep the N most recent images in message history
|
|
21
22
|
"""
|
|
22
23
|
self.only_n_most_recent_images = only_n_most_recent_images
|
|
23
|
-
|
|
24
|
+
|
|
24
25
|
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
25
26
|
"""
|
|
26
27
|
Apply image retention policy to messages before sending to agent loop.
|
|
27
|
-
|
|
28
|
+
|
|
28
29
|
Args:
|
|
29
30
|
messages: List of message dictionaries
|
|
30
|
-
|
|
31
|
+
|
|
31
32
|
Returns:
|
|
32
33
|
List of messages with image retention policy applied
|
|
33
34
|
"""
|
|
34
35
|
if self.only_n_most_recent_images is None:
|
|
35
36
|
return messages
|
|
36
|
-
|
|
37
|
+
|
|
37
38
|
return self._apply_image_retention(messages)
|
|
38
|
-
|
|
39
|
+
|
|
39
40
|
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
40
41
|
"""Apply image retention policy to keep only the N most recent images.
|
|
41
|
-
|
|
42
|
+
|
|
42
43
|
Removes computer_call_output items with image_url and their corresponding computer_call items,
|
|
43
44
|
keeping only the most recent N image pairs based on only_n_most_recent_images setting.
|
|
44
|
-
|
|
45
|
+
|
|
45
46
|
Args:
|
|
46
47
|
messages: List of message dictionaries
|
|
47
|
-
|
|
48
|
+
|
|
48
49
|
Returns:
|
|
49
50
|
Filtered list of messages with image retention applied
|
|
50
51
|
"""
|
|
@@ -78,7 +79,11 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
|
|
78
79
|
# Remove the immediately preceding computer_call with matching call_id (if present)
|
|
79
80
|
call_id = messages[idx].get("call_id")
|
|
80
81
|
prev_idx = idx - 1
|
|
81
|
-
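if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id: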
|
|
82
|
+
if (
|
|
83
|
+
prev_idx >= 0
|
|
84
|
+
and messages[prev_idx].get("type") == "computer_call"
|
|
85
|
+
and messages[prev_idx].get("call_id") == call_id
|
|
86
|
+
):
|
|
82
87
|
to_remove.add(prev_idx)
|
|
83
88
|
# Check a single reasoning immediately before that computer_call
|
|
84
89
|
r_idx = prev_idx - 1
|
|
@@ -87,4 +92,4 @@ class ImageRetentionCallback(AsyncCallbackHandler):
|
|
|
87
92
|
|
|
88
93
|
# Construct filtered list
|
|
89
94
|
filtered = [m for i, m in enumerate(messages) if i not in to_remove]
|
|
90
|
-
return filtered
|
|
95
|
+
return filtered
|
agent/callbacks/logging.py
CHANGED
|
@@ -4,17 +4,18 @@ Logging callback for ComputerAgent that provides configurable logging of agent l
|
|
|
4
4
|
|
|
5
5
|
import json
|
|
6
6
|
import logging
|
|
7
|
-
from typing import Dict, List,
|
|
7
|
+
from typing import Any, Dict, List, Optional, Union
|
|
8
|
+
|
|
8
9
|
from .base import AsyncCallbackHandler
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def sanitize_image_urls(data: Any) -> Any:
|
|
12
13
|
"""
|
|
13
14
|
Recursively search for 'image_url' keys and set their values to '[omitted]'.
|
|
14
|
-
|
|
15
|
+
|
|
15
16
|
Args:
|
|
16
17
|
data: Any data structure (dict, list, or primitive type)
|
|
17
|
-
|
|
18
|
+
|
|
18
19
|
Returns:
|
|
19
20
|
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
|
|
20
21
|
"""
|
|
@@ -28,11 +29,11 @@ def sanitize_image_urls(data: Any) -> Any:
|
|
|
28
29
|
# Recursively sanitize the value
|
|
29
30
|
sanitized[key] = sanitize_image_urls(value)
|
|
30
31
|
return sanitized
|
|
31
|
-
|
|
32
|
+
|
|
32
33
|
elif isinstance(data, list):
|
|
33
34
|
# Recursively sanitize each item in the list
|
|
34
35
|
return [sanitize_image_urls(item) for item in data]
|
|
35
|
-
|
|
36
|
+
|
|
36
37
|
else:
|
|
37
38
|
# For primitive types (str, int, bool, None, etc.), return as-is
|
|
38
39
|
return data
|
|
@@ -41,37 +42,36 @@ def sanitize_image_urls(data: Any) -> Any:
|
|
|
41
42
|
class LoggingCallback(AsyncCallbackHandler):
|
|
42
43
|
"""
|
|
43
44
|
Callback handler that logs agent lifecycle events with configurable verbosity.
|
|
44
|
-
|
|
45
|
+
|
|
45
46
|
Logging levels:
|
|
46
47
|
- DEBUG: All events including API calls, message preprocessing, and detailed outputs
|
|
47
|
-
- INFO: Major lifecycle events (start/end, messages, outputs)
|
|
48
|
+
- INFO: Major lifecycle events (start/end, messages, outputs)
|
|
48
49
|
- WARNING: Only warnings and errors
|
|
49
50
|
- ERROR: Only errors
|
|
50
51
|
"""
|
|
51
|
-
|
|
52
|
+
|
|
52
53
|
def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
|
|
53
54
|
"""
|
|
54
55
|
Initialize the logging callback.
|
|
55
|
-
|
|
56
|
+
|
|
56
57
|
Args:
|
|
57
58
|
logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
|
|
58
59
|
level: Logging level (logging.DEBUG, logging.INFO, etc.)
|
|
59
60
|
"""
|
|
60
|
-
self.logger = logger or logging.getLogger(
|
|
61
|
+
self.logger = logger or logging.getLogger("agent.ComputerAgent")
|
|
61
62
|
self.level = level
|
|
62
|
-
|
|
63
|
+
|
|
63
64
|
# Set up logger if it doesn't have handlers
|
|
64
65
|
if not self.logger.handlers:
|
|
65
66
|
handler = logging.StreamHandler()
|
|
66
|
-
formatter = logging.Formatter(
|
|
67
|
-
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
68
|
-
)
|
|
67
|
+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
69
68
|
handler.setFormatter(formatter)
|
|
70
69
|
self.logger.addHandler(handler)
|
|
71
70
|
self.logger.setLevel(level)
|
|
72
|
-
|
|
71
|
+
|
|
73
72
|
def _update_usage(self, usage: Dict[str, Any]) -> None:
|
|
74
73
|
"""Update total usage statistics."""
|
|
74
|
+
|
|
75
75
|
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
|
76
76
|
for key, value in source.items():
|
|
77
77
|
if isinstance(value, dict):
|
|
@@ -82,18 +82,25 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
82
82
|
if key not in target:
|
|
83
83
|
target[key] = 0
|
|
84
84
|
target[key] += value
|
|
85
|
+
|
|
85
86
|
add_dicts(self.total_usage, usage)
|
|
86
|
-
|
|
87
|
+
|
|
87
88
|
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
|
88
89
|
"""Called before the run starts."""
|
|
89
90
|
self.total_usage = {}
|
|
90
|
-
|
|
91
|
+
|
|
91
92
|
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
|
92
93
|
"""Called when usage information is received."""
|
|
93
94
|
self._update_usage(usage)
|
|
94
95
|
|
|
95
|
-
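async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None: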
|
|
96
|
+
async def on_run_end(
|
|
97
|
+
self,
|
|
98
|
+
kwargs: Dict[str, Any],
|
|
99
|
+
old_items: List[Dict[str, Any]],
|
|
100
|
+
new_items: List[Dict[str, Any]],
|
|
101
|
+
) -> None:
|
|
96
102
|
"""Called after the run ends."""
|
|
103
|
+
|
|
97
104
|
def format_dict(d, indent=0):
|
|
98
105
|
lines = []
|
|
99
106
|
prefix = f" - {' ' * indent}"
|
|
@@ -106,10 +113,10 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
106
113
|
else:
|
|
107
114
|
lines.append(f"{prefix}{key}: {value}")
|
|
108
115
|
return lines
|
|
109
|
-
|
|
116
|
+
|
|
110
117
|
formatted_output = "\n".join(format_dict(self.total_usage))
|
|
111
118
|
self.logger.info(f"Total usage:\n{formatted_output}")
|
|
112
|
-
|
|
119
|
+
|
|
113
120
|
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
114
121
|
"""Called before LLM processing starts."""
|
|
115
122
|
if self.logger.isEnabledFor(logging.INFO):
|
|
@@ -118,27 +125,27 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
118
125
|
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
|
|
119
126
|
self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
|
|
120
127
|
return messages
|
|
121
|
-
|
|
128
|
+
|
|
122
129
|
async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
123
130
|
"""Called after LLM processing ends."""
|
|
124
131
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
125
132
|
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
|
|
126
133
|
self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
|
|
127
134
|
return messages
|
|
128
|
-
|
|
135
|
+
|
|
129
136
|
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
|
|
130
137
|
"""Called when a computer call starts."""
|
|
131
138
|
action = item.get("action", {})
|
|
132
139
|
action_type = action.get("type", "unknown")
|
|
133
140
|
action_args = {k: v for k, v in action.items() if k != "type"}
|
|
134
|
-
|
|
141
|
+
|
|
135
142
|
# INFO level logging for the action
|
|
136
143
|
self.logger.info(f"Computer: {action_type}({action_args})")
|
|
137
|
-
|
|
144
|
+
|
|
138
145
|
# DEBUG level logging for full details
|
|
139
146
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
140
147
|
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
|
|
141
|
-
|
|
148
|
+
|
|
142
149
|
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
|
143
150
|
"""Called when a computer call ends."""
|
|
144
151
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
@@ -147,48 +154,52 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
147
154
|
if result:
|
|
148
155
|
sanitized_result = sanitize_image_urls(result)
|
|
149
156
|
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
|
|
150
|
-
|
|
157
|
+
|
|
151
158
|
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
|
|
152
159
|
"""Called when a function call starts."""
|
|
153
160
|
name = item.get("name", "unknown")
|
|
154
161
|
arguments = item.get("arguments", "{}")
|
|
155
|
-
|
|
162
|
+
|
|
156
163
|
# INFO level logging for the function call
|
|
157
164
|
self.logger.info(f"Function: {name}({arguments})")
|
|
158
|
-
|
|
165
|
+
|
|
159
166
|
# DEBUG level logging for full details
|
|
160
167
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
161
168
|
self.logger.debug(f"Function call started: {name}")
|
|
162
|
-
|
|
169
|
+
|
|
163
170
|
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
|
164
171
|
"""Called when a function call ends."""
|
|
165
172
|
# INFO level logging for function output (similar to function_call_output)
|
|
166
173
|
if result:
|
|
167
174
|
# Handle both list and direct result formats
|
|
168
175
|
if isinstance(result, list) and len(result) > 0:
|
|
169
|
-
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
|
|
176
|
+
output = (
|
|
177
|
+
result[0].get("output", str(result))
|
|
178
|
+
if isinstance(result[0], dict)
|
|
179
|
+
else str(result[0])
|
|
180
|
+
)
|
|
170
181
|
else:
|
|
171
182
|
output = str(result)
|
|
172
|
-
|
|
183
|
+
|
|
173
184
|
# Truncate long outputs
|
|
174
185
|
if len(output) > 100:
|
|
175
186
|
output = output[:100] + "..."
|
|
176
|
-
|
|
187
|
+
|
|
177
188
|
self.logger.info(f"Output: {output}")
|
|
178
|
-
|
|
189
|
+
|
|
179
190
|
# DEBUG level logging for full details
|
|
180
191
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
181
192
|
name = item.get("name", "unknown")
|
|
182
193
|
self.logger.debug(f"Function call completed: {name}")
|
|
183
194
|
if result:
|
|
184
195
|
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
|
|
185
|
-
|
|
196
|
+
|
|
186
197
|
async def on_text(self, item: Dict[str, Any]) -> None:
|
|
187
198
|
"""Called when a text message is encountered."""
|
|
188
199
|
# Get the role to determine if it's Agent or User
|
|
189
200
|
role = item.get("role", "unknown")
|
|
190
201
|
content_items = item.get("content", [])
|
|
191
|
-
|
|
202
|
+
|
|
192
203
|
# Process content items to build display text
|
|
193
204
|
text_parts = []
|
|
194
205
|
for content_item in content_items:
|
|
@@ -206,10 +217,10 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
206
217
|
else:
|
|
207
218
|
# Non-text content, show as [type]
|
|
208
219
|
text_parts.append(f"[{content_type}]")
|
|
209
|
-
|
|
220
|
+
|
|
210
221
|
# Join all text parts
|
|
211
|
-
display_text =
|
|
212
|
-
|
|
222
|
+
display_text = "".join(text_parts) if text_parts else "[empty]"
|
|
223
|
+
|
|
213
224
|
# Log with appropriate level and format
|
|
214
225
|
if role == "assistant":
|
|
215
226
|
self.logger.info(f"Agent: {display_text}")
|
|
@@ -219,7 +230,7 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
219
230
|
# Fallback for unknown roles, use debug level
|
|
220
231
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
221
232
|
self.logger.debug(f"Text message ({role}): {display_text}")
|
|
222
|
-
|
|
233
|
+
|
|
223
234
|
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
|
224
235
|
"""Called when an API call is about to start."""
|
|
225
236
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
@@ -232,16 +243,18 @@ class LoggingCallback(AsyncCallbackHandler):
|
|
|
232
243
|
elif "input" in kwargs:
|
|
233
244
|
sanitized_input = sanitize_image_urls(kwargs["input"])
|
|
234
245
|
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
|
|
235
|
-
|
|
246
|
+
|
|
236
247
|
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
|
237
248
|
"""Called when an API call has completed."""
|
|
238
249
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
239
250
|
model = kwargs.get("model", "unknown")
|
|
240
251
|
self.logger.debug(f"API call completed for model: {model}")
|
|
241
|
-
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
|
|
252
|
+
self.logger.debug(
|
|
253
|
+
f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}"
|
|
254
|
+
)
|
|
242
255
|
|
|
243
256
|
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
|
|
244
257
|
"""Called when a screenshot is taken."""
|
|
245
258
|
if self.logger.isEnabledFor(logging.DEBUG):
|
|
246
259
|
image_size = len(item) / 1024
|
|
247
|
-
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
|
|
260
|
+
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
|
agent/callbacks/operator_validator.py
CHANGED
|
@@ -9,6 +9,7 @@ Ensures agent output actions conform to expected schemas by fixing common issues
|
|
|
9
9
|
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
|
|
10
10
|
The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
|
|
11
11
|
"""
|
|
12
|
+
|
|
12
13
|
from __future__ import annotations
|
|
13
14
|
|
|
14
15
|
from typing import Any, Dict, List
|
|
@@ -48,6 +49,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
|
|
|
48
49
|
action["type"] = "type"
|
|
49
50
|
|
|
50
51
|
action_type = action.get("type")
|
|
52
|
+
|
|
51
53
|
def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
|
|
52
54
|
"""Keep only the provided keys on action; delete everything else.
|
|
53
55
|
Always ensures required 'type' is present if listed in keys_to_keep.
|
|
@@ -55,6 +57,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
|
|
|
55
57
|
for key in list(action.keys()):
|
|
56
58
|
if key not in keys_to_keep:
|
|
57
59
|
del action[key]
|
|
60
|
+
|
|
58
61
|
# rename "coordinate" to "x", "y"
|
|
59
62
|
if "coordinate" in action:
|
|
60
63
|
action["x"] = action["coordinate"][0]
|
|
@@ -100,7 +103,6 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
|
|
|
100
103
|
keep = required_keys_by_type.get(action_type or "")
|
|
101
104
|
if keep:
|
|
102
105
|
_keep_keys(action, keep)
|
|
103
|
-
|
|
104
106
|
|
|
105
107
|
# # Second pass: if an assistant message is immediately followed by a computer_call,
|
|
106
108
|
# # replace the assistant message itself with a reasoning message with summary text.
|