openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,708 @@
+"""Experimental conditions for the Representation Shootout.
+
+Defines the three conditions:
+- Condition A: Raw Coordinates - Direct coordinate regression
+- Condition B: Coordinates + Visual Cues - Enhanced with markers and zoom
+- Condition C: Marks (Element IDs) - Element classification using SoM
+
+Each condition implements:
+1. Input preparation (screenshot augmentation, prompt construction)
+2. Output parsing (model response to action dict)
+3. Loss computation (for training)
+"""
+
+from __future__ import annotations
+
+import logging
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from openadapt_ml.experiments.representation_shootout.config import (
+    ConditionConfig,
+    ConditionName,
+    MarksConfig,
+    OutputFormat,
+    VisualCuesConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UIElement:
+    """UI element with ID and bounding box for marks condition."""
+
+    element_id: str
+    role: str
+    name: str | None
+    bbox: tuple[float, float, float, float]  # (x1, y1, x2, y2) normalized or pixels
+
+    @property
+    def center(self) -> tuple[float, float]:
+        """Get center point of element."""
+        x1, y1, x2, y2 = self.bbox
+        return ((x1 + x2) / 2, (y1 + y2) / 2)
+
+    def contains_point(self, x: float, y: float) -> bool:
+        """Check if point is within this element's bbox."""
+        x1, y1, x2, y2 = self.bbox
+        return x1 <= x <= x2 and y1 <= y <= y2
+
+
+@dataclass
+class UIElementGraph:
+    """Collection of UI elements for marks condition."""
+
+    elements: list[UIElement]
+
+    def get_element(self, element_id: str) -> UIElement | None:
+        """Get element by ID."""
+        for el in self.elements:
+            if el.element_id == element_id:
+                return el
+        return None
+
+    def find_element_at(self, x: float, y: float) -> UIElement | None:
+        """Find element containing the given point."""
+        for el in self.elements:
+            if el.contains_point(x, y):
+                return el
+        return None
+
+    def to_prompt_text(self, max_elements: int | None = None) -> str:
+        """Format elements for text prompt.
+
+        Returns:
+            Formatted string like:
+                [e1] button "Submit" at (0.4, 0.8)-(0.6, 0.85)
+                [e17] textfield "Username" at (0.3, 0.4)-(0.7, 0.45)
+        """
+        lines = []
+        elements_to_show = (
+            self.elements[:max_elements] if max_elements else self.elements
+        )
+        for el in elements_to_show:
+            name_part = f' "{el.name}"' if el.name else ""
+            x1, y1, x2, y2 = el.bbox
+            lines.append(
+                f"[{el.element_id}] {el.role}{name_part} at ({x1:.2f}, {y1:.2f})-({x2:.2f}, {y2:.2f})"
+            )
+        return "\n".join(lines)
+
+
+@dataclass
+class Observation:
+    """Observation data for input preparation.
+
+    This is a simplified observation structure for the experiment.
+    In production, use openadapt_ml.benchmarks.base.BenchmarkObservation.
+    """
+
+    screenshot_path: str | None = None
+    screenshot_bytes: bytes | None = None
+    screen_size: tuple[int, int] | None = None  # (width, height)
+    ui_elements: UIElementGraph | None = None
+    window_title: str | None = None
+    url: str | None = None
+
+
+@dataclass
+class ActionHistory:
+    """History of previous actions."""
+
+    actions: list[dict[str, Any]]  # List of action dicts
+
+    def to_prompt_text(self, max_steps: int = 5) -> str:
+        """Format history for text prompt."""
+        lines = []
+        for i, action in enumerate(self.actions[-max_steps:], 1):
+            action_type = action.get("type", "unknown").upper()
+            if action_type == "CLICK":
+                if "element_id" in action:
+                    lines.append(f"{i}. CLICK([{action['element_id']}])")
+                elif "x" in action and "y" in action:
+                    lines.append(f"{i}. CLICK({action['x']:.3f}, {action['y']:.3f})")
+                else:
+                    lines.append(f"{i}. CLICK()")
+            elif action_type == "TYPE":
+                text = action.get("text", "")
+                lines.append(f'{i}. TYPE("{text}")')
+            elif action_type == "KEY":
+                key = action.get("key", "")
+                lines.append(f"{i}. KEY({key})")
+            elif action_type == "SCROLL":
+                direction = action.get("direction", "down")
+                lines.append(f"{i}. SCROLL({direction})")
+            else:
+                lines.append(f"{i}. {action_type}()")
+        return "\n".join(lines)
+
+
+@dataclass
+class PreparedInput:
+    """Prepared input for the model.
+
+    Attributes:
+        screenshot_path: Path to (possibly augmented) screenshot.
+        prompt: Text prompt for the model.
+        metadata: Additional metadata for debugging/analysis.
+    """
+
+    screenshot_path: str | None
+    prompt: str
+    metadata: dict[str, Any] | None = None
+
+
+@dataclass
+class ParsedAction:
+    """Parsed action from model output.
+
+    Attributes:
+        type: Action type (e.g., "click", "type", "scroll", "done").
+        x: X coordinate (for coordinate-based outputs).
+        y: Y coordinate (for coordinate-based outputs).
+        element_id: Element ID (for marks-based outputs).
+        text: Text to type (for type actions).
+        raw_output: Original model output string.
+        parse_error: Error message if parsing failed.
+    """
+
+    type: str
+    x: float | None = None
+    y: float | None = None
+    element_id: str | None = None
+    text: str | None = None
+    raw_output: str | None = None
+    parse_error: str | None = None
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to action dictionary."""
+        result: dict[str, Any] = {"type": self.type}
+        if self.x is not None:
+            result["x"] = self.x
+        if self.y is not None:
+            result["y"] = self.y
+        if self.element_id is not None:
+            result["element_id"] = self.element_id
+        if self.text is not None:
+            result["text"] = self.text
+        return result
+
+
+class ConditionBase(ABC):
+    """Abstract base class for experimental conditions.
+
+    Each condition defines how to:
+    1. Prepare input from observations (possibly augmenting screenshots)
+    2. Parse model output into structured actions
+    3. Compute training loss
+    """
+
+    def __init__(self, config: ConditionConfig):
+        """Initialize condition.
+
+        Args:
+            config: Condition-specific configuration.
+        """
+        self.config = config
+
+    @property
+    def name(self) -> ConditionName:
+        """Condition name."""
+        return self.config.name
+
+    @property
+    def output_format(self) -> OutputFormat:
+        """Expected output format."""
+        return self.config.output_format
+
+    @abstractmethod
+    def prepare_input(
+        self,
+        observation: Observation,
+        goal: str,
+        history: ActionHistory | None = None,
+    ) -> PreparedInput:
+        """Prepare model input from observation.
+
+        Args:
+            observation: Current observation with screenshot and UI elements.
+            goal: Task goal/instruction.
+            history: Optional history of previous actions.
+
+        Returns:
+            PreparedInput with (possibly augmented) screenshot and prompt.
+        """
+        pass
+
+    @abstractmethod
+    def parse_output(self, model_output: str) -> ParsedAction:
+        """Parse model output to structured action.
+
+        Args:
+            model_output: Raw string output from model.
+
+        Returns:
+            ParsedAction with extracted action information.
+        """
+        pass
+
+    @abstractmethod
+    def compute_loss(
+        self,
+        prediction: ParsedAction,
+        ground_truth: dict[str, Any],
+    ) -> float:
+        """Compute training loss for a single sample.
+
+        Args:
+            prediction: Parsed prediction from model.
+            ground_truth: Ground truth action dict with coordinates/element_id.
+
+        Returns:
+            Loss value (lower is better).
+        """
+        pass
+
+    def _build_base_prompt(
+        self,
+        goal: str,
+        history: ActionHistory | None = None,
+    ) -> str:
+        """Build the base prompt text (shared across conditions)."""
+        parts = [f"GOAL: {goal}"]
+
+        if self.config.include_history and history and history.actions:
+            history_text = history.to_prompt_text(self.config.max_history_steps)
+            parts.append(f"\nPREVIOUS ACTIONS:\n{history_text}")
+
+        return "\n".join(parts)
+
+
+class RawCoordsCondition(ConditionBase):
+    """Condition A: Raw Coordinates.
+
+    Input: Screenshot (unmodified) + goal + history
+    Output: {"type": "CLICK", "x": float, "y": float}
+    Training: Coordinate regression (MSE loss)
+    """
+
+    def prepare_input(
+        self,
+        observation: Observation,
+        goal: str,
+        history: ActionHistory | None = None,
+    ) -> PreparedInput:
+        """Prepare input without any screenshot augmentation."""
+        prompt = self._build_base_prompt(goal, history)
+        prompt += "\n\nAnalyze the screenshot and provide the next action."
+        prompt += "\nRespond with: ACTION: CLICK(x, y) where x and y are normalized coordinates (0.0-1.0)"
+
+        return PreparedInput(
+            screenshot_path=observation.screenshot_path,
+            prompt=prompt,
+            metadata={"condition": "raw_coords"},
+        )
+
+    def parse_output(self, model_output: str) -> ParsedAction:
+        """Parse coordinate output from model."""
+        import re
+
+        # Look for ACTION: CLICK(x, y) pattern
+        action_match = re.search(
+            r"ACTION:\s*CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
+            model_output,
+            re.IGNORECASE,
+        )
+        if action_match:
+            try:
+                x = float(action_match.group(1))
+                y = float(action_match.group(2))
+                return ParsedAction(type="click", x=x, y=y, raw_output=model_output)
+            except ValueError as e:
+                return ParsedAction(
+                    type="click",
+                    raw_output=model_output,
+                    parse_error=f"Invalid coordinates: {e}",
+                )
+
+        # Try looser patterns
+        coord_match = re.search(
+            r"CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
+            model_output,
+            re.IGNORECASE,
+        )
+        if coord_match:
+            try:
+                x = float(coord_match.group(1))
+                y = float(coord_match.group(2))
+                return ParsedAction(type="click", x=x, y=y, raw_output=model_output)
+            except ValueError:
+                pass
+
+        # Check for TYPE action
+        type_match = re.search(
+            r'TYPE\s*\(\s*["\'](.+?)["\']\s*\)', model_output, re.IGNORECASE
+        )
+        if type_match:
+            return ParsedAction(
+                type="type", text=type_match.group(1), raw_output=model_output
+            )
+
+        # Check for DONE action
+        if re.search(r"DONE\s*\(\s*\)", model_output, re.IGNORECASE):
+            return ParsedAction(type="done", raw_output=model_output)
+
+        return ParsedAction(
+            type="unknown",
+            raw_output=model_output,
+            parse_error="No action pattern found",
+        )
+
+    def compute_loss(
+        self,
+        prediction: ParsedAction,
+        ground_truth: dict[str, Any],
+    ) -> float:
+        """Compute MSE loss between predicted and ground truth coordinates."""
+        gt_x = ground_truth.get("x")
+        gt_y = ground_truth.get("y")
+
+        if gt_x is None or gt_y is None:
+            logger.warning("Ground truth missing coordinates, returning max loss")
+            return 1.0
+
+        if prediction.x is None or prediction.y is None:
+            # Prediction failed, return max loss
+            return 1.0
+
+        # MSE in normalized coordinate space
+        mse = (prediction.x - gt_x) ** 2 + (prediction.y - gt_y) ** 2
+        return mse
+
+
+class CoordsCuesCondition(ConditionBase):
+    """Condition B: Coordinates + Visual Cues.
+
+    Input: Screenshot with red marker at click target + zoomed patch + goal
+    Output: {"type": "CLICK", "x": float, "y": float}
+    Training: Enhanced coordinate regression (MSE loss)
+
+    Note: Visual cues are only added during training. At test time,
+    the model must predict without the cues.
+    """
+
+    def __init__(self, config: ConditionConfig):
+        super().__init__(config)
+        self.visual_cues = config.visual_cues or VisualCuesConfig()
+
+    def prepare_input(
+        self,
+        observation: Observation,
+        goal: str,
+        history: ActionHistory | None = None,
+        target_coords: tuple[float, float] | None = None,
+        is_training: bool = False,
+    ) -> PreparedInput:
+        """Prepare input with visual cues during training.
+
+        Args:
+            observation: Current observation.
+            goal: Task goal.
+            history: Action history.
+            target_coords: Target (x, y) for training augmentation.
+            is_training: Whether this is for training (add cues) or eval (no cues).
+        """
+        prompt = self._build_base_prompt(goal, history)
+
+        augmented_path = observation.screenshot_path
+        metadata: dict[str, Any] = {
+            "condition": "coords_cues",
+            "is_training": is_training,
+        }
+
+        if is_training and target_coords:
+            # Add visual cues for training
+            prompt += (
+                "\n\nThe red marker and zoomed inset show the target click location."
+            )
+            prompt += "\nLearn to identify this location based on the UI context."
+
+            # Augment screenshot (placeholder - actual implementation would use PIL/cv2)
+            augmented_path = self._augment_screenshot(
+                observation.screenshot_path,
+                target_coords,
+                observation.screen_size,
+            )
+            metadata["target_coords"] = target_coords
+            metadata["augmented"] = True
+        else:
+            prompt += "\n\nAnalyze the screenshot and provide the next action."
+
+        prompt += "\nRespond with: ACTION: CLICK(x, y) where x and y are normalized coordinates (0.0-1.0)"
+
+        return PreparedInput(
+            screenshot_path=augmented_path,
+            prompt=prompt,
+            metadata=metadata,
+        )
+
+    def _augment_screenshot(
+        self,
+        screenshot_path: str | None,
+        target_coords: tuple[float, float],
+        screen_size: tuple[int, int] | None,
+    ) -> str | None:
+        """Add visual cues to screenshot.
+
+        This is a scaffolding implementation. Full implementation would:
+        1. Draw red marker at target location
+        2. Extract and magnify patch around target
+        3. Overlay patch in corner opposite to target
+
+        Args:
+            screenshot_path: Path to original screenshot.
+            target_coords: Normalized (x, y) target coordinates.
+            screen_size: Screen dimensions for pixel conversion.
+
+        Returns:
+            Path to augmented screenshot.
+        """
+        if not screenshot_path:
+            return None
+
+        # Placeholder: In full implementation, would use PIL/cv2 to:
+        # 1. Load image
+        # 2. Draw red circle at target_coords
+        # 3. Extract zoom patch
+        # 4. Place patch in corner
+        # 5. Save augmented image
+
+        logger.debug(
+            f"Would augment screenshot {screenshot_path} with marker at {target_coords}"
+        )
+
+        # For scaffolding, return original path
+        # TODO: Implement actual augmentation
+        return screenshot_path
+
+    def parse_output(self, model_output: str) -> ParsedAction:
+        """Parse coordinate output (same as RawCoords)."""
+        # Reuse RawCoords parsing logic
+        return RawCoordsCondition(self.config).parse_output(model_output)
+
+    def compute_loss(
+        self,
+        prediction: ParsedAction,
+        ground_truth: dict[str, Any],
+    ) -> float:
+        """Compute MSE loss (same as RawCoords)."""
+        return RawCoordsCondition(self.config).compute_loss(prediction, ground_truth)
+
+
+class MarksCondition(ConditionBase):
+    """Condition C: Marks (Element IDs).
+
+    Input: Screenshot with SoM overlay + UIElementGraph + goal
+    Output: {"type": "CLICK", "element_id": "e17"}
+    Training: Element classification (cross-entropy loss)
+    """
+
+    def __init__(self, config: ConditionConfig):
+        super().__init__(config)
+        self.marks_config = config.marks or MarksConfig()
+
+    def prepare_input(
+        self,
+        observation: Observation,
+        goal: str,
+        history: ActionHistory | None = None,
+    ) -> PreparedInput:
+        """Prepare input with element marks overlay and UIElementGraph."""
+        prompt = self._build_base_prompt(goal, history)
+
+        # Add UIElementGraph text representation
+        if observation.ui_elements:
+            elements_text = observation.ui_elements.to_prompt_text(
+                self.marks_config.max_elements
+            )
+            prompt += f"\n\nUI ELEMENTS:\n{elements_text}"
+        else:
+            prompt += "\n\nNo UI elements detected."
+
+        prompt += "\n\nWhich element should be clicked?"
+        prompt += (
+            "\nRespond with: ACTION: CLICK([element_id]) e.g., ACTION: CLICK([e17])"
+        )
+
+        # Augment screenshot with marks overlay
+        augmented_path = self._add_marks_overlay(
+            observation.screenshot_path,
+            observation.ui_elements,
+            observation.screen_size,
+        )
+
+        return PreparedInput(
+            screenshot_path=augmented_path,
+            prompt=prompt,
+            metadata={
+                "condition": "marks",
+                "num_elements": len(observation.ui_elements.elements)
+                if observation.ui_elements
+                else 0,
+            },
+        )
+
+    def _add_marks_overlay(
+        self,
+        screenshot_path: str | None,
+        ui_elements: UIElementGraph | None,
+        screen_size: tuple[int, int] | None,
+    ) -> str | None:
+        """Add SoM-style marks overlay to screenshot.
+
+        This is a scaffolding implementation. Full implementation would:
+        1. Draw numbered labels on each element's bounding box
+        2. Use consistent styling (font, colors)
+
+        Args:
+            screenshot_path: Path to original screenshot.
+            ui_elements: UI elements to mark.
+            screen_size: Screen dimensions.
+
+        Returns:
+            Path to screenshot with marks overlay.
+        """
+        if not screenshot_path:
+            return None
+
+        if not ui_elements:
+            return screenshot_path
+
+        # Placeholder: In full implementation, would use PIL to:
+        # 1. Load image
+        # 2. Draw colored box around each element
+        # 3. Add element ID label
+        # 4. Save marked image
+
+        logger.debug(
+            f"Would add marks overlay to {screenshot_path} "
+            f"with {len(ui_elements.elements)} elements"
+        )
+
+        # For scaffolding, return original path
+        # TODO: Implement actual overlay
+        return screenshot_path
+
+    def parse_output(self, model_output: str) -> ParsedAction:
+        """Parse element ID output from model."""
+        import re
+
+        # Look for ACTION: CLICK([element_id]) pattern
+        action_match = re.search(
+            r"ACTION:\s*CLICK\s*\(\s*\[?\s*([a-zA-Z]?\d+)\s*\]?\s*\)",
+            model_output,
+            re.IGNORECASE,
+        )
+        if action_match:
+            element_id = action_match.group(1)
+            # Normalize element ID format
+            if not element_id.startswith("e"):
+                element_id = f"e{element_id}"
+            return ParsedAction(
+                type="click", element_id=element_id, raw_output=model_output
+            )
+
+        # Try looser patterns
+        element_match = re.search(
+            r"CLICK\s*\(\s*\[?\s*([a-zA-Z]?\d+)\s*\]?\s*\)",
+            model_output,
+            re.IGNORECASE,
+        )
+        if element_match:
+            element_id = element_match.group(1)
+            if not element_id.startswith("e"):
+                element_id = f"e{element_id}"
+            return ParsedAction(
+                type="click", element_id=element_id, raw_output=model_output
+            )
+
+        # Check for element mentioned in text (e.g., "click element e17")
+        text_match = re.search(r"\b[eE](\d+)\b", model_output)
+        if text_match:
+            return ParsedAction(
+                type="click",
+                element_id=f"e{text_match.group(1)}",
+                raw_output=model_output,
+            )
+
+        # Check for TYPE action
+        type_match = re.search(
+            r'TYPE\s*\(\s*["\'](.+?)["\']\s*\)', model_output, re.IGNORECASE
+        )
+        if type_match:
+            return ParsedAction(
+                type="type", text=type_match.group(1), raw_output=model_output
+            )
+
+        # Check for DONE action
+        if re.search(r"DONE\s*\(\s*\)", model_output, re.IGNORECASE):
+            return ParsedAction(type="done", raw_output=model_output)
+
+        return ParsedAction(
+            type="unknown",
+            raw_output=model_output,
+            parse_error="No element ID pattern found",
+        )
+
+    def compute_loss(
+        self,
+        prediction: ParsedAction,
+        ground_truth: dict[str, Any],
+    ) -> float:
+        """Compute classification loss.
+
+        For scaffolding, this returns 0 for correct, 1 for incorrect.
+        Full implementation would return proper cross-entropy loss.
+        """
+        gt_element_id = ground_truth.get("element_id")
+
+        if gt_element_id is None:
+            logger.warning("Ground truth missing element_id, returning max loss")
+            return 1.0
+
+        if prediction.element_id is None:
+            # Prediction failed
+            return 1.0
+
+        # Normalize both IDs for comparison
+        pred_id = prediction.element_id.lower().replace("e", "")
+        gt_id = str(gt_element_id).lower().replace("e", "")
+
+        return 0.0 if pred_id == gt_id else 1.0
+
+
+def create_condition(config: ConditionConfig) -> ConditionBase:
+    """Factory function to create condition from config.
+
+    Args:
+        config: Condition configuration.
+
+    Returns:
+        Appropriate ConditionBase subclass instance.
+
+    Raises:
+        ValueError: If condition name is unknown.
+    """
+    condition_map = {
+        ConditionName.RAW_COORDS: RawCoordsCondition,
+        ConditionName.COORDS_CUES: CoordsCuesCondition,
+        ConditionName.MARKS: MarksCondition,
+    }
+
+    condition_cls = condition_map.get(config.name)
+    if condition_cls is None:
+        raise ValueError(f"Unknown condition name: {config.name}")
+
+    return condition_cls(config)
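
For orientation, the sketch below exercises the data structures this new module defines, using only classes shown in the diff above; the element IDs, names, and coordinates are illustrative values, not package data:

    from openadapt_ml.experiments.representation_shootout.conditions import (
        ActionHistory,
        UIElement,
        UIElementGraph,
    )

    elements = UIElementGraph(
        elements=[
            UIElement(element_id="e1", role="button", name="Submit", bbox=(0.4, 0.8, 0.6, 0.85)),
            UIElement(element_id="e17", role="textfield", name="Username", bbox=(0.3, 0.4, 0.7, 0.45)),
        ]
    )
    print(elements.to_prompt_text())
    # [e1] button "Submit" at (0.40, 0.80)-(0.60, 0.85)
    # [e17] textfield "Username" at (0.30, 0.40)-(0.70, 0.45)
    print(elements.find_element_at(0.5, 0.42).element_id)  # e17

    history = ActionHistory(actions=[
        {"type": "click", "element_id": "e17"},
        {"type": "type", "text": "alice"},
    ])
    print(history.to_prompt_text())
    # 1. CLICK([e17])
    # 2. TYPE("alice")

Note that the _augment_screenshot and _add_marks_overlay helpers ship as scaffolding in 0.2.1: they log a debug message and return the original path. A minimal Pillow-based sketch of the cues augmentation that the CoordsCuesCondition docstring describes could look like the following; the function name, marker radius, patch size, and output path are assumptions for illustration, not part of the package:

    from PIL import Image, ImageDraw

    def augment_with_cues(screenshot_path: str, target: tuple[float, float],
                          radius: int = 12, patch: int = 60, zoom: int = 3) -> str:
        """Draw a red marker at the target and paste a zoomed inset in the opposite corner."""
        img = Image.open(screenshot_path).convert("RGB")
        w, h = img.size
        cx, cy = int(target[0] * w), int(target[1] * h)

        # 1. Red marker at the target location.
        draw = ImageDraw.Draw(img)
        draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius),
                     outline=(255, 0, 0), width=3)

        # 2. Extract and magnify a patch around the target.
        box = (max(cx - patch, 0), max(cy - patch, 0), min(cx + patch, w), min(cy + patch, h))
        inset = img.crop(box).resize((2 * patch * zoom, 2 * patch * zoom))

        # 3. Overlay the inset in the corner opposite the target.
        px = 0 if cx > w // 2 else w - inset.width
        py = 0 if cy > h // 2 else h - inset.height
        img.paste(inset, (px, py))

        out_path = screenshot_path + ".cues.png"
        img.save(out_path)
        return out_path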