openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/experiments/representation_shootout/conditions.py
@@ -0,0 +1,708 @@
"""Experimental conditions for the Representation Shootout.

Defines the three conditions:
- Condition A: Raw Coordinates - Direct coordinate regression
- Condition B: Coordinates + Visual Cues - Enhanced with markers and zoom
- Condition C: Marks (Element IDs) - Element classification using SoM

Each condition implements:
1. Input preparation (screenshot augmentation, prompt construction)
2. Output parsing (model response to action dict)
3. Loss computation (for training)
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from openadapt_ml.experiments.representation_shootout.config import (
    ConditionConfig,
    ConditionName,
    MarksConfig,
    OutputFormat,
    VisualCuesConfig,
)

logger = logging.getLogger(__name__)


@dataclass
class UIElement:
    """UI element with ID and bounding box for marks condition."""

    element_id: str
    role: str
    name: str | None
    bbox: tuple[float, float, float, float]  # (x1, y1, x2, y2) normalized or pixels

    @property
    def center(self) -> tuple[float, float]:
        """Get center point of element."""
        x1, y1, x2, y2 = self.bbox
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def contains_point(self, x: float, y: float) -> bool:
        """Check if point is within this element's bbox."""
        x1, y1, x2, y2 = self.bbox
        return x1 <= x <= x2 and y1 <= y <= y2


@dataclass
class UIElementGraph:
    """Collection of UI elements for marks condition."""

    elements: list[UIElement]

    def get_element(self, element_id: str) -> UIElement | None:
        """Get element by ID."""
        for el in self.elements:
            if el.element_id == element_id:
                return el
        return None

    def find_element_at(self, x: float, y: float) -> UIElement | None:
        """Find element containing the given point."""
        for el in self.elements:
            if el.contains_point(x, y):
                return el
        return None

    def to_prompt_text(self, max_elements: int | None = None) -> str:
        """Format elements for text prompt.

        Returns:
            Formatted string like:
            [e1] button "Submit" at (0.4, 0.8)-(0.6, 0.85)
            [e17] textfield "Username" at (0.3, 0.4)-(0.7, 0.45)
        """
        lines = []
        elements_to_show = (
            self.elements[:max_elements] if max_elements else self.elements
        )
        for el in elements_to_show:
            name_part = f' "{el.name}"' if el.name else ""
            x1, y1, x2, y2 = el.bbox
            lines.append(
                f"[{el.element_id}] {el.role}{name_part} at ({x1:.2f}, {y1:.2f})-({x2:.2f}, {y2:.2f})"
            )
        return "\n".join(lines)


@dataclass
class Observation:
    """Observation data for input preparation.

    This is a simplified observation structure for the experiment.
    In production, use openadapt_ml.benchmarks.base.BenchmarkObservation.
    """

    screenshot_path: str | None = None
    screenshot_bytes: bytes | None = None
    screen_size: tuple[int, int] | None = None  # (width, height)
    ui_elements: UIElementGraph | None = None
    window_title: str | None = None
    url: str | None = None


@dataclass
class ActionHistory:
    """History of previous actions."""

    actions: list[dict[str, Any]]  # List of action dicts

    def to_prompt_text(self, max_steps: int = 5) -> str:
        """Format history for text prompt."""
        lines = []
        for i, action in enumerate(self.actions[-max_steps:], 1):
            action_type = action.get("type", "unknown").upper()
            if action_type == "CLICK":
                if "element_id" in action:
                    lines.append(f"{i}. CLICK([{action['element_id']}])")
                elif "x" in action and "y" in action:
                    lines.append(f"{i}. CLICK({action['x']:.3f}, {action['y']:.3f})")
                else:
                    lines.append(f"{i}. CLICK()")
            elif action_type == "TYPE":
                text = action.get("text", "")
                lines.append(f'{i}. TYPE("{text}")')
            elif action_type == "KEY":
                key = action.get("key", "")
                lines.append(f"{i}. KEY({key})")
            elif action_type == "SCROLL":
                direction = action.get("direction", "down")
                lines.append(f"{i}. SCROLL({direction})")
            else:
                lines.append(f"{i}. {action_type}()")
        return "\n".join(lines)


@dataclass
class PreparedInput:
    """Prepared input for the model.

    Attributes:
        screenshot_path: Path to (possibly augmented) screenshot.
        prompt: Text prompt for the model.
        metadata: Additional metadata for debugging/analysis.
    """

    screenshot_path: str | None
    prompt: str
    metadata: dict[str, Any] | None = None


@dataclass
class ParsedAction:
    """Parsed action from model output.

    Attributes:
        type: Action type (e.g., "click", "type", "scroll", "done").
        x: X coordinate (for coordinate-based outputs).
        y: Y coordinate (for coordinate-based outputs).
        element_id: Element ID (for marks-based outputs).
        text: Text to type (for type actions).
        raw_output: Original model output string.
        parse_error: Error message if parsing failed.
    """

    type: str
    x: float | None = None
    y: float | None = None
    element_id: str | None = None
    text: str | None = None
    raw_output: str | None = None
    parse_error: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to action dictionary."""
        result: dict[str, Any] = {"type": self.type}
        if self.x is not None:
            result["x"] = self.x
        if self.y is not None:
            result["y"] = self.y
        if self.element_id is not None:
            result["element_id"] = self.element_id
        if self.text is not None:
            result["text"] = self.text
        return result


class ConditionBase(ABC):
    """Abstract base class for experimental conditions.

    Each condition defines how to:
    1. Prepare input from observations (possibly augmenting screenshots)
    2. Parse model output into structured actions
    3. Compute training loss
    """

    def __init__(self, config: ConditionConfig):
        """Initialize condition.

        Args:
            config: Condition-specific configuration.
        """
        self.config = config

    @property
    def name(self) -> ConditionName:
        """Condition name."""
        return self.config.name

    @property
    def output_format(self) -> OutputFormat:
        """Expected output format."""
        return self.config.output_format

    @abstractmethod
    def prepare_input(
        self,
        observation: Observation,
        goal: str,
        history: ActionHistory | None = None,
    ) -> PreparedInput:
        """Prepare model input from observation.

        Args:
            observation: Current observation with screenshot and UI elements.
            goal: Task goal/instruction.
            history: Optional history of previous actions.

        Returns:
            PreparedInput with (possibly augmented) screenshot and prompt.
        """
        pass

    @abstractmethod
    def parse_output(self, model_output: str) -> ParsedAction:
        """Parse model output to structured action.

        Args:
            model_output: Raw string output from model.

        Returns:
            ParsedAction with extracted action information.
        """
        pass

    @abstractmethod
    def compute_loss(
        self,
        prediction: ParsedAction,
        ground_truth: dict[str, Any],
    ) -> float:
        """Compute training loss for a single sample.

        Args:
            prediction: Parsed prediction from model.
            ground_truth: Ground truth action dict with coordinates/element_id.

        Returns:
            Loss value (lower is better).
        """
        pass

    def _build_base_prompt(
        self,
        goal: str,
        history: ActionHistory | None = None,
    ) -> str:
        """Build the base prompt text (shared across conditions)."""
        parts = [f"GOAL: {goal}"]

        if self.config.include_history and history and history.actions:
            history_text = history.to_prompt_text(self.config.max_history_steps)
            parts.append(f"\nPREVIOUS ACTIONS:\n{history_text}")

        return "\n".join(parts)


class RawCoordsCondition(ConditionBase):
    """Condition A: Raw Coordinates.

    Input: Screenshot (unmodified) + goal + history
    Output: {"type": "CLICK", "x": float, "y": float}
    Training: Coordinate regression (MSE loss)
    """

    def prepare_input(
        self,
        observation: Observation,
        goal: str,
        history: ActionHistory | None = None,
    ) -> PreparedInput:
        """Prepare input without any screenshot augmentation."""
        prompt = self._build_base_prompt(goal, history)
        prompt += "\n\nAnalyze the screenshot and provide the next action."
        prompt += "\nRespond with: ACTION: CLICK(x, y) where x and y are normalized coordinates (0.0-1.0)"

        return PreparedInput(
            screenshot_path=observation.screenshot_path,
            prompt=prompt,
            metadata={"condition": "raw_coords"},
        )

    def parse_output(self, model_output: str) -> ParsedAction:
        """Parse coordinate output from model."""
        import re

        # Look for ACTION: CLICK(x, y) pattern
        action_match = re.search(
            r"ACTION:\s*CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
            model_output,
            re.IGNORECASE,
        )
        if action_match:
            try:
                x = float(action_match.group(1))
                y = float(action_match.group(2))
                return ParsedAction(type="click", x=x, y=y, raw_output=model_output)
            except ValueError as e:
                return ParsedAction(
                    type="click",
                    raw_output=model_output,
                    parse_error=f"Invalid coordinates: {e}",
                )

        # Try looser patterns
        coord_match = re.search(
            r"CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
            model_output,
            re.IGNORECASE,
        )
        if coord_match:
            try:
                x = float(coord_match.group(1))
                y = float(coord_match.group(2))
                return ParsedAction(type="click", x=x, y=y, raw_output=model_output)
            except ValueError:
                pass

        # Check for TYPE action
        type_match = re.search(
            r'TYPE\s*\(\s*["\'](.+?)["\']\s*\)', model_output, re.IGNORECASE
        )
        if type_match:
            return ParsedAction(
                type="type", text=type_match.group(1), raw_output=model_output
            )

        # Check for DONE action
        if re.search(r"DONE\s*\(\s*\)", model_output, re.IGNORECASE):
            return ParsedAction(type="done", raw_output=model_output)

        return ParsedAction(
            type="unknown",
            raw_output=model_output,
            parse_error="No action pattern found",
        )

    def compute_loss(
        self,
        prediction: ParsedAction,
        ground_truth: dict[str, Any],
    ) -> float:
        """Compute MSE loss between predicted and ground truth coordinates."""
        gt_x = ground_truth.get("x")
        gt_y = ground_truth.get("y")

        if gt_x is None or gt_y is None:
            logger.warning("Ground truth missing coordinates, returning max loss")
            return 1.0

        if prediction.x is None or prediction.y is None:
            # Prediction failed, return max loss
            return 1.0

        # MSE in normalized coordinate space
        mse = (prediction.x - gt_x) ** 2 + (prediction.y - gt_y) ** 2
        return mse


class CoordsCuesCondition(ConditionBase):
    """Condition B: Coordinates + Visual Cues.

    Input: Screenshot with red marker at click target + zoomed patch + goal
    Output: {"type": "CLICK", "x": float, "y": float}
    Training: Enhanced coordinate regression (MSE loss)

    Note: Visual cues are only added during training. At test time,
    the model must predict without the cues.
    """

    def __init__(self, config: ConditionConfig):
        super().__init__(config)
        self.visual_cues = config.visual_cues or VisualCuesConfig()

    def prepare_input(
        self,
        observation: Observation,
        goal: str,
        history: ActionHistory | None = None,
        target_coords: tuple[float, float] | None = None,
        is_training: bool = False,
    ) -> PreparedInput:
        """Prepare input with visual cues during training.

        Args:
            observation: Current observation.
            goal: Task goal.
            history: Action history.
            target_coords: Target (x, y) for training augmentation.
            is_training: Whether this is for training (add cues) or eval (no cues).
        """
        prompt = self._build_base_prompt(goal, history)

        augmented_path = observation.screenshot_path
        metadata: dict[str, Any] = {
            "condition": "coords_cues",
            "is_training": is_training,
        }

        if is_training and target_coords:
            # Add visual cues for training
            prompt += (
                "\n\nThe red marker and zoomed inset show the target click location."
            )
            prompt += "\nLearn to identify this location based on the UI context."

            # Augment screenshot (placeholder - actual implementation would use PIL/cv2)
            augmented_path = self._augment_screenshot(
                observation.screenshot_path,
                target_coords,
                observation.screen_size,
            )
            metadata["target_coords"] = target_coords
            metadata["augmented"] = True
        else:
            prompt += "\n\nAnalyze the screenshot and provide the next action."

        prompt += "\nRespond with: ACTION: CLICK(x, y) where x and y are normalized coordinates (0.0-1.0)"

        return PreparedInput(
            screenshot_path=augmented_path,
            prompt=prompt,
            metadata=metadata,
        )

    def _augment_screenshot(
        self,
        screenshot_path: str | None,
        target_coords: tuple[float, float],
        screen_size: tuple[int, int] | None,
    ) -> str | None:
        """Add visual cues to screenshot.

        This is a scaffolding implementation. Full implementation would:
        1. Draw red marker at target location
        2. Extract and magnify patch around target
        3. Overlay patch in corner opposite to target

        Args:
            screenshot_path: Path to original screenshot.
            target_coords: Normalized (x, y) target coordinates.
            screen_size: Screen dimensions for pixel conversion.

        Returns:
            Path to augmented screenshot.
        """
        if not screenshot_path:
            return None

        # Placeholder: In full implementation, would use PIL/cv2 to:
        # 1. Load image
        # 2. Draw red circle at target_coords
        # 3. Extract zoom patch
        # 4. Place patch in corner
        # 5. Save augmented image

        logger.debug(
            f"Would augment screenshot {screenshot_path} with marker at {target_coords}"
        )

        # For scaffolding, return original path
        # TODO: Implement actual augmentation
        return screenshot_path

    def parse_output(self, model_output: str) -> ParsedAction:
        """Parse coordinate output (same as RawCoords)."""
        # Reuse RawCoords parsing logic
        return RawCoordsCondition(self.config).parse_output(model_output)

    def compute_loss(
        self,
        prediction: ParsedAction,
        ground_truth: dict[str, Any],
    ) -> float:
        """Compute MSE loss (same as RawCoords)."""
        return RawCoordsCondition(self.config).compute_loss(prediction, ground_truth)


class MarksCondition(ConditionBase):
    """Condition C: Marks (Element IDs).

    Input: Screenshot with SoM overlay + UIElementGraph + goal
    Output: {"type": "CLICK", "element_id": "e17"}
    Training: Element classification (cross-entropy loss)
    """

    def __init__(self, config: ConditionConfig):
        super().__init__(config)
        self.marks_config = config.marks or MarksConfig()

    def prepare_input(
        self,
        observation: Observation,
        goal: str,
        history: ActionHistory | None = None,
    ) -> PreparedInput:
        """Prepare input with element marks overlay and UIElementGraph."""
        prompt = self._build_base_prompt(goal, history)

        # Add UIElementGraph text representation
        if observation.ui_elements:
            elements_text = observation.ui_elements.to_prompt_text(
                self.marks_config.max_elements
            )
            prompt += f"\n\nUI ELEMENTS:\n{elements_text}"
        else:
            prompt += "\n\nNo UI elements detected."

        prompt += "\n\nWhich element should be clicked?"
        prompt += (
            "\nRespond with: ACTION: CLICK([element_id]) e.g., ACTION: CLICK([e17])"
        )

        # Augment screenshot with marks overlay
        augmented_path = self._add_marks_overlay(
            observation.screenshot_path,
            observation.ui_elements,
            observation.screen_size,
        )

        return PreparedInput(
            screenshot_path=augmented_path,
            prompt=prompt,
            metadata={
                "condition": "marks",
                "num_elements": len(observation.ui_elements.elements)
                if observation.ui_elements
                else 0,
            },
        )

    def _add_marks_overlay(
        self,
        screenshot_path: str | None,
        ui_elements: UIElementGraph | None,
        screen_size: tuple[int, int] | None,
    ) -> str | None:
        """Add SoM-style marks overlay to screenshot.

        This is a scaffolding implementation. Full implementation would:
        1. Draw numbered labels on each element's bounding box
        2. Use consistent styling (font, colors)

        Args:
            screenshot_path: Path to original screenshot.
            ui_elements: UI elements to mark.
            screen_size: Screen dimensions.

        Returns:
            Path to screenshot with marks overlay.
        """
        if not screenshot_path:
            return None

        if not ui_elements:
            return screenshot_path

        # Placeholder: In full implementation, would use PIL to:
        # 1. Load image
        # 2. Draw colored box around each element
        # 3. Add element ID label
        # 4. Save marked image

        logger.debug(
            f"Would add marks overlay to {screenshot_path} "
            f"with {len(ui_elements.elements)} elements"
        )

        # For scaffolding, return original path
        # TODO: Implement actual overlay
        return screenshot_path

    def parse_output(self, model_output: str) -> ParsedAction:
        """Parse element ID output from model."""
        import re

        # Look for ACTION: CLICK([element_id]) pattern
        action_match = re.search(
            r"ACTION:\s*CLICK\s*\(\s*\[?\s*([a-zA-Z]?\d+)\s*\]?\s*\)",
            model_output,
            re.IGNORECASE,
        )
        if action_match:
            element_id = action_match.group(1)
            # Normalize element ID format
            if not element_id.startswith("e"):
                element_id = f"e{element_id}"
            return ParsedAction(
                type="click", element_id=element_id, raw_output=model_output
            )

        # Try looser patterns
        element_match = re.search(
            r"CLICK\s*\(\s*\[?\s*([a-zA-Z]?\d+)\s*\]?\s*\)",
            model_output,
            re.IGNORECASE,
        )
        if element_match:
            element_id = element_match.group(1)
            if not element_id.startswith("e"):
                element_id = f"e{element_id}"
            return ParsedAction(
                type="click", element_id=element_id, raw_output=model_output
            )

        # Check for element mentioned in text (e.g., "click element e17")
        text_match = re.search(r"\b[eE](\d+)\b", model_output)
        if text_match:
            return ParsedAction(
                type="click",
                element_id=f"e{text_match.group(1)}",
                raw_output=model_output,
            )

        # Check for TYPE action
        type_match = re.search(
            r'TYPE\s*\(\s*["\'](.+?)["\']\s*\)', model_output, re.IGNORECASE
        )
        if type_match:
            return ParsedAction(
                type="type", text=type_match.group(1), raw_output=model_output
            )

        # Check for DONE action
        if re.search(r"DONE\s*\(\s*\)", model_output, re.IGNORECASE):
            return ParsedAction(type="done", raw_output=model_output)

        return ParsedAction(
            type="unknown",
            raw_output=model_output,
            parse_error="No element ID pattern found",
        )

    def compute_loss(
        self,
        prediction: ParsedAction,
        ground_truth: dict[str, Any],
    ) -> float:
        """Compute classification loss.

        For scaffolding, this returns 0 for correct, 1 for incorrect.
        Full implementation would return proper cross-entropy loss.
        """
        gt_element_id = ground_truth.get("element_id")

        if gt_element_id is None:
            logger.warning("Ground truth missing element_id, returning max loss")
            return 1.0

        if prediction.element_id is None:
            # Prediction failed
            return 1.0

        # Normalize both IDs for comparison
        pred_id = prediction.element_id.lower().replace("e", "")
        gt_id = str(gt_element_id).lower().replace("e", "")

        return 0.0 if pred_id == gt_id else 1.0


def create_condition(config: ConditionConfig) -> ConditionBase:
    """Factory function to create condition from config.

    Args:
        config: Condition configuration.

    Returns:
        Appropriate ConditionBase subclass instance.

    Raises:
        ValueError: If condition name is unknown.
    """
    condition_map = {
        ConditionName.RAW_COORDS: RawCoordsCondition,
        ConditionName.COORDS_CUES: CoordsCuesCondition,
        ConditionName.MARKS: MarksCondition,
    }

    condition_cls = condition_map.get(config.name)
    if condition_cls is None:
        raise ValueError(f"Unknown condition name: {config.name}")

    return condition_cls(config)