openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff reflects the publicly available content of the two package versions as published to their respective public registries, and is provided for informational purposes only.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.2.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/experiments/representation_shootout/evaluator.py (new file)
@@ -0,0 +1,659 @@
+ """Evaluation under drift conditions for the Representation Shootout.
+
+ This module implements:
+ 1. Drift transformations (resolution, translation, theme, scroll)
+ 2. Metrics computation (click-hit rate, grounding accuracy, etc.)
+ 3. Decision rule for recommendation
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import math
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ from openadapt_ml.experiments.representation_shootout.conditions import (
+     ConditionBase,
+     Observation,
+     ParsedAction,
+     UIElement,
+     UIElementGraph,
+ )
+ from openadapt_ml.experiments.representation_shootout.config import (
+     ConditionName,
+     DriftConfig,
+     DriftType,
+     MetricName,
+     ResolutionDriftParams,
+     ScrollDriftParams,
+     ThemeDriftParams,
+     TranslationDriftParams,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class Sample:
+     """A single evaluation sample.
+
+     Attributes:
+         sample_id: Unique identifier for this sample.
+         observation: Observation data (screenshot, UI elements).
+         goal: Task instruction.
+         ground_truth: Ground truth action dict.
+         drift_config: Applied drift configuration.
+     """
+
+     sample_id: str
+     observation: Observation
+     goal: str
+     ground_truth: dict[str, Any]
+     drift_config: DriftConfig | None = None
+
+
+ @dataclass
+ class SampleResult:
+     """Result of evaluating a single sample.
+
+     Attributes:
+         sample_id: Sample identifier.
+         condition: Condition that was evaluated.
+         drift: Drift configuration applied.
+         prediction: Parsed prediction from model.
+         ground_truth: Ground truth action.
+         metrics: Computed metrics for this sample.
+     """
+
+     sample_id: str
+     condition: ConditionName
+     drift: str
+     prediction: ParsedAction
+     ground_truth: dict[str, Any]
+     metrics: dict[str, float]
+
+
+ @dataclass
+ class EvaluationResult:
+     """Aggregated evaluation results for a condition under a drift.
+
+     Attributes:
+         condition: Condition evaluated.
+         drift: Drift configuration.
+         num_samples: Number of samples evaluated.
+         metrics: Aggregated metrics (averages).
+         sample_results: Individual sample results.
+     """
+
+     condition: ConditionName
+     drift: str
+     num_samples: int
+     metrics: dict[str, float]
+     sample_results: list[SampleResult] = field(default_factory=list)
+
+
+ @dataclass
+ class Recommendation:
+     """Final recommendation from the experiment.
+
+     Attributes:
+         recommended: Recommended approach ("COORDINATES" or "MARKS").
+         reason: Explanation for the recommendation.
+         coords_cues_avg: Average performance of Coords+Cues across drifts.
+         marks_avg: Average performance of Marks across drifts.
+         tolerance: Tolerance threshold used for decision.
+         detailed_comparison: Per-drift comparison data.
+     """
+
+     recommended: str  # "COORDINATES" or "MARKS"
+     reason: str
+     coords_cues_avg: float
+     marks_avg: float
+     tolerance: float
+     detailed_comparison: dict[str, dict[str, float]] = field(default_factory=dict)
+
+
+ class DriftTransformer:
+     """Applies drift transformations to samples."""
+
+     @staticmethod
+     def apply_drift(
+         observation: Observation,
+         ground_truth: dict[str, Any],
+         drift_config: DriftConfig,
+     ) -> tuple[Observation, dict[str, Any]]:
+         """Apply drift transformation to observation and ground truth.
+
+         Args:
+             observation: Original observation.
+             ground_truth: Original ground truth action.
+             drift_config: Drift to apply.
+
+         Returns:
+             Tuple of (transformed_observation, transformed_ground_truth).
+         """
+         if drift_config.is_canonical:
+             return observation, ground_truth
+
+         if drift_config.drift_type == DriftType.RESOLUTION:
+             return DriftTransformer._apply_resolution_drift(
+                 observation,
+                 ground_truth,
+                 drift_config.params,  # type: ignore
+             )
+         elif drift_config.drift_type == DriftType.TRANSLATION:
+             return DriftTransformer._apply_translation_drift(
+                 observation,
+                 ground_truth,
+                 drift_config.params,  # type: ignore
+             )
+         elif drift_config.drift_type == DriftType.THEME:
+             return DriftTransformer._apply_theme_drift(
+                 observation,
+                 ground_truth,
+                 drift_config.params,  # type: ignore
+             )
+         elif drift_config.drift_type == DriftType.SCROLL:
+             return DriftTransformer._apply_scroll_drift(
+                 observation,
+                 ground_truth,
+                 drift_config.params,  # type: ignore
+             )
+         else:
+             logger.warning(f"Unknown drift type: {drift_config.drift_type}")
+             return observation, ground_truth
+
+     @staticmethod
+     def _apply_resolution_drift(
+         observation: Observation,
+         ground_truth: dict[str, Any],
+         params: ResolutionDriftParams,
+     ) -> tuple[Observation, dict[str, Any]]:
+         """Apply resolution scaling.
+
+         For normalized coordinates, no transformation is needed (they scale automatically).
+         For pixel coordinates, scale by the factor.
+         For UI elements, scale bounding boxes.
+         """
+         scale = params.scale
+
+         # Create new observation with scaled screen size
+         new_screen_size = None
+         if observation.screen_size:
+             w, h = observation.screen_size
+             new_screen_size = (int(w * scale), int(h * scale))
+
+         # Scale UI element bboxes if they are in pixels
+         new_ui_elements = None
+         if observation.ui_elements:
+             new_elements = []
+             for el in observation.ui_elements.elements:
+                 # Assuming bboxes are normalized (0-1), no scaling needed
+                 # If they were pixels, we would scale them here
+                 new_elements.append(el)
+             new_ui_elements = UIElementGraph(elements=new_elements)
+
+         new_observation = Observation(
+             screenshot_path=observation.screenshot_path,  # Would need actual resize
+             screenshot_bytes=observation.screenshot_bytes,
+             screen_size=new_screen_size,
+             ui_elements=new_ui_elements,
+             window_title=observation.window_title,
+             url=observation.url,
+         )
+
+         # Ground truth coordinates are normalized, so no change needed
+         # If they were pixels, we would scale them
+         new_ground_truth = ground_truth.copy()
+
+         logger.debug(f"Applied resolution drift {scale}x: {new_screen_size}")
+         return new_observation, new_ground_truth
+
+     @staticmethod
+     def _apply_translation_drift(
+         observation: Observation,
+         ground_truth: dict[str, Any],
+         params: TranslationDriftParams,
+     ) -> tuple[Observation, dict[str, Any]]:
+         """Apply window translation.
+
+         This shifts the window position while keeping the UI elements
+         in their relative positions within the window.
+         """
+         offset_x = params.offset_x
+         offset_y = params.offset_y
+
+         # For normalized coordinates within the window, no change is needed
+         # The translation affects where the window is on screen, but not
+         # the relative positions within the window
+
+         # However, if coordinates are screen-absolute, we need to adjust
+         # For this experiment, we assume window-relative normalized coords
+
+         new_ground_truth = ground_truth.copy()
+
+         # If ground truth has screen-absolute coordinates, adjust them
+         if "screen_x" in ground_truth and "screen_y" in ground_truth:
+             # Convert pixel offset to normalized offset
+             if observation.screen_size:
+                 w, h = observation.screen_size
+                 norm_offset_x = offset_x / w
+                 norm_offset_y = offset_y / h
+                 new_ground_truth["screen_x"] = ground_truth["screen_x"] + norm_offset_x
+                 new_ground_truth["screen_y"] = ground_truth["screen_y"] + norm_offset_y
+
+         logger.debug(f"Applied translation drift: ({offset_x}, {offset_y})")
+         return observation, new_ground_truth
+
+     @staticmethod
+     def _apply_theme_drift(
+         observation: Observation,
+         ground_truth: dict[str, Any],
+         params: ThemeDriftParams,
+     ) -> tuple[Observation, dict[str, Any]]:
+         """Apply theme change.
+
+         Theme changes affect visual appearance but not coordinates.
+         Full implementation would load theme-variant screenshots.
+         """
+         theme = params.theme
+
+         # For scaffolding, we don't transform the screenshot
+         # Full implementation would:
+         # 1. Load a pre-recorded screenshot in the target theme, OR
+         # 2. Apply synthetic color transformations
+
+         logger.debug(f"Applied theme drift: {theme}")
+         return observation, ground_truth
+
+     @staticmethod
+     def _apply_scroll_drift(
+         observation: Observation,
+         ground_truth: dict[str, Any],
+         params: ScrollDriftParams,
+     ) -> tuple[Observation, dict[str, Any]]:
+         """Apply scroll offset.
+
+         Scroll changes the visible portion of the page, affecting
+         which elements are visible and their y-coordinates.
+         """
+         offset_y = params.offset_y
+
+         # Adjust UI element bboxes for scroll
+         new_ui_elements = None
+         if observation.ui_elements and observation.screen_size:
+             _, screen_h = observation.screen_size
+             norm_offset = offset_y / screen_h
+
+             new_elements = []
+             for el in observation.ui_elements.elements:
+                 x1, y1, x2, y2 = el.bbox
+                 # Shift y coordinates up by scroll amount
+                 new_y1 = y1 - norm_offset
+                 new_y2 = y2 - norm_offset
+
+                 # Only include elements still visible on screen
+                 if new_y2 > 0 and new_y1 < 1:
+                     new_elements.append(
+                         UIElement(
+                             element_id=el.element_id,
+                             role=el.role,
+                             name=el.name,
+                             bbox=(x1, max(0, new_y1), x2, min(1, new_y2)),
+                         )
+                     )
+
+             new_ui_elements = UIElementGraph(elements=new_elements)
+
+         new_observation = Observation(
+             screenshot_path=observation.screenshot_path,  # Would need scroll-shifted image
+             screenshot_bytes=observation.screenshot_bytes,
+             screen_size=observation.screen_size,
+             ui_elements=new_ui_elements,
+             window_title=observation.window_title,
+             url=observation.url,
+         )
+
+         # Adjust ground truth coordinates
+         new_ground_truth = ground_truth.copy()
+         if "y" in ground_truth and observation.screen_size:
+             _, screen_h = observation.screen_size
+             norm_offset = offset_y / screen_h
+             new_ground_truth["y"] = ground_truth["y"] - norm_offset
+
+         logger.debug(f"Applied scroll drift: {offset_y}px")
+         return new_observation, new_ground_truth
+
+
+ def compute_metrics(
+     prediction: ParsedAction,
+     ground_truth: dict[str, Any],
+     ui_elements: UIElementGraph | None = None,
+ ) -> dict[str, float]:
+     """Compute all metrics for a single prediction.
+
+     Args:
+         prediction: Parsed prediction from model.
+         ground_truth: Ground truth action dict with coordinates/element_id.
+         ui_elements: UI elements (needed for click-hit computation).
+
+     Returns:
+         Dict of metric name to value.
+     """
+     metrics: dict[str, float] = {}
+
+     # Click-Hit Rate: Is predicted coordinate within target element bbox?
+     if prediction.type == "click":
+         hit = 0.0
+
+         if prediction.x is not None and prediction.y is not None:
+             # Coordinate-based prediction
+             target_bbox = ground_truth.get("target_bbox")
+             if target_bbox:
+                 x1, y1, x2, y2 = target_bbox
+                 if x1 <= prediction.x <= x2 and y1 <= prediction.y <= y2:
+                     hit = 1.0
+
+             # Also check if coordinates are within the target element from ui_elements
+             elif ui_elements and ground_truth.get("element_id"):
+                 target_el = ui_elements.get_element(ground_truth["element_id"])
+                 if target_el and target_el.contains_point(prediction.x, prediction.y):
+                     hit = 1.0
+
+         elif prediction.element_id is not None and ui_elements:
+             # Element-based prediction - find element and check if it matches target
+             pred_el = ui_elements.get_element(prediction.element_id)
+             gt_el_id = ground_truth.get("element_id")
+             if pred_el and gt_el_id:
+                 # Normalize IDs for comparison
+                 pred_id = prediction.element_id.lower().replace("e", "")
+                 gt_id = str(gt_el_id).lower().replace("e", "")
+                 if pred_id == gt_id:
+                     hit = 1.0
+
+         metrics[MetricName.CLICK_HIT_RATE.value] = hit
+
+     # Grounding Top-1 Accuracy: Is predicted element ID correct?
+     if prediction.element_id is not None:
+         gt_el_id = ground_truth.get("element_id")
+         if gt_el_id:
+             pred_id = prediction.element_id.lower().replace("e", "")
+             gt_id = str(gt_el_id).lower().replace("e", "")
+             metrics[MetricName.GROUNDING_TOP1_ACCURACY.value] = (
+                 1.0 if pred_id == gt_id else 0.0
+             )
+         else:
+             metrics[MetricName.GROUNDING_TOP1_ACCURACY.value] = 0.0
+
+     # Coordinate Distance: L2 distance to target (normalized)
+     gt_x = ground_truth.get("x")
+     gt_y = ground_truth.get("y")
+
+     if gt_x is not None and gt_y is not None:
+         if prediction.x is not None and prediction.y is not None:
+             distance = math.sqrt(
+                 (prediction.x - gt_x) ** 2 + (prediction.y - gt_y) ** 2
+             )
+         else:
+             # If prediction failed or is element-based, compute distance from element center
+             if prediction.element_id and ui_elements:
+                 pred_el = ui_elements.get_element(prediction.element_id)
+                 if pred_el:
+                     cx, cy = pred_el.center
+                     distance = math.sqrt((cx - gt_x) ** 2 + (cy - gt_y) ** 2)
+                 else:
+                     distance = math.sqrt(2)  # Max normalized distance
+             else:
+                 distance = math.sqrt(2)  # Max normalized distance
+
+         metrics[MetricName.COORD_DISTANCE.value] = distance
+
+     return metrics
+
+
+ def aggregate_metrics(sample_results: list[SampleResult]) -> dict[str, float]:
+     """Aggregate metrics across multiple samples.
+
+     Args:
+         sample_results: List of individual sample results.
+
+     Returns:
+         Dict of metric name to averaged value.
+     """
+     if not sample_results:
+         return {}
+
+     # Collect all metrics
+     all_metrics: dict[str, list[float]] = {}
+     for result in sample_results:
+         for metric_name, value in result.metrics.items():
+             if metric_name not in all_metrics:
+                 all_metrics[metric_name] = []
+             all_metrics[metric_name].append(value)
+
+     # Compute averages
+     aggregated = {}
+     for metric_name, values in all_metrics.items():
+         aggregated[metric_name] = sum(values) / len(values)
+
+     return aggregated
+
+
+ class DriftEvaluator:
+     """Evaluates conditions under drift conditions.
+
+     This class orchestrates the evaluation process:
+     1. Apply drift transformations to samples
+     2. Generate predictions using conditions
+     3. Compute metrics
+     4. Aggregate results
+     """
+
+     def __init__(
+         self,
+         conditions: dict[ConditionName, ConditionBase],
+         drift_configs: list[DriftConfig],
+     ):
+         """Initialize evaluator.
+
+         Args:
+             conditions: Map of condition name to condition instance.
+             drift_configs: List of drift configurations to test.
+         """
+         self.conditions = conditions
+         self.drift_configs = drift_configs
+         self._canonical_results: dict[ConditionName, dict[str, float]] = {}
+
+     def evaluate_sample(
+         self,
+         condition: ConditionBase,
+         sample: Sample,
+         drift_config: DriftConfig,
+         model_output: str,
+     ) -> SampleResult:
+         """Evaluate a single sample under a drift condition.
+
+         Args:
+             condition: Condition to use for evaluation.
+             sample: Sample to evaluate.
+             drift_config: Drift to apply.
+             model_output: Raw model output to parse.
+
+         Returns:
+             SampleResult with metrics.
+         """
+         # Apply drift
+         transformed_obs, transformed_gt = DriftTransformer.apply_drift(
+             sample.observation, sample.ground_truth, drift_config
+         )
+
+         # Parse model output
+         prediction = condition.parse_output(model_output)
+
+         # Compute metrics
+         metrics = compute_metrics(
+             prediction, transformed_gt, transformed_obs.ui_elements
+         )
+
+         return SampleResult(
+             sample_id=sample.sample_id,
+             condition=condition.name,
+             drift=drift_config.name,
+             prediction=prediction,
+             ground_truth=transformed_gt,
+             metrics=metrics,
+         )
+
+     def evaluate_condition_under_drift(
+         self,
+         condition: ConditionBase,
+         samples: list[Sample],
+         drift_config: DriftConfig,
+         model_outputs: list[str],
+     ) -> EvaluationResult:
+         """Evaluate a condition on all samples under a drift.
+
+         Args:
+             condition: Condition to evaluate.
+             samples: Samples to evaluate.
+             drift_config: Drift to apply.
+             model_outputs: Model outputs corresponding to samples.
+
+         Returns:
+             EvaluationResult with aggregated metrics.
+         """
+         sample_results = []
+         for sample, output in zip(samples, model_outputs):
+             result = self.evaluate_sample(condition, sample, drift_config, output)
+             sample_results.append(result)
+
+         aggregated = aggregate_metrics(sample_results)
+
+         return EvaluationResult(
+             condition=condition.name,
+             drift=drift_config.name,
+             num_samples=len(samples),
+             metrics=aggregated,
+             sample_results=sample_results,
+         )
+
+     def compute_robustness_scores(
+         self,
+         results: list[EvaluationResult],
+         primary_metric: str = MetricName.CLICK_HIT_RATE.value,
+     ) -> dict[ConditionName, dict[str, float]]:
+         """Compute robustness scores relative to canonical baseline.
+
+         Args:
+             results: Evaluation results across conditions and drifts.
+             primary_metric: Metric to use for robustness computation.
+
+         Returns:
+             Dict mapping condition to dict of drift to robustness score.
+         """
+         # Group results by condition
+         by_condition: dict[ConditionName, list[EvaluationResult]] = {}
+         for r in results:
+             if r.condition not in by_condition:
+                 by_condition[r.condition] = []
+             by_condition[r.condition].append(r)
+
+         robustness_scores: dict[ConditionName, dict[str, float]] = {}
+
+         for condition, cond_results in by_condition.items():
+             # Find canonical result
+             canonical_result = next(
+                 (r for r in cond_results if r.drift == "canonical"), None
+             )
+             if not canonical_result:
+                 logger.warning(f"No canonical result for condition {condition}")
+                 continue
+
+             canonical_value = canonical_result.metrics.get(primary_metric, 0)
+             if canonical_value == 0:
+                 canonical_value = 1e-6  # Avoid division by zero
+
+             robustness_scores[condition] = {}
+             for r in cond_results:
+                 if r.drift == "canonical":
+                     robustness_scores[condition][r.drift] = 1.0
+                 else:
+                     drift_value = r.metrics.get(primary_metric, 0)
+                     robustness_scores[condition][r.drift] = (
+                         drift_value / canonical_value
+                     )
+
+         return robustness_scores
+
+
+ def make_recommendation(
+     results: list[EvaluationResult],
+     tolerance: float = 0.05,
+     primary_metric: str = MetricName.CLICK_HIT_RATE.value,
+ ) -> Recommendation:
+     """Make recommendation based on evaluation results.
+
+     Decision rule (from design doc):
+     - If Coords+Cues within 5% of Marks under drift -> choose Coordinates
+     - Otherwise -> choose Marks
+
+     Args:
+         results: Evaluation results across all conditions and drifts.
+         tolerance: Tolerance threshold for decision (default 5%).
+         primary_metric: Metric to use for comparison.
+
+     Returns:
+         Recommendation with explanation.
+     """
+     # Group results by condition and compute averages across drifts
+     by_condition: dict[ConditionName, list[float]] = {}
+     detailed_comparison: dict[str, dict[str, float]] = {}
+
+     for r in results:
+         if r.condition not in by_condition:
+             by_condition[r.condition] = []
+
+         metric_value = r.metrics.get(primary_metric, 0)
+         by_condition[r.condition].append(metric_value)
+
+         # Track detailed comparison
+         drift_key = r.drift
+         if drift_key not in detailed_comparison:
+             detailed_comparison[drift_key] = {}
+         detailed_comparison[drift_key][r.condition.value] = metric_value
+
+     # Compute averages
+     condition_averages: dict[ConditionName, float] = {}
+     for condition, values in by_condition.items():
+         condition_averages[condition] = sum(values) / len(values) if values else 0
+
+     # Get averages for decision
+     coords_cues_avg = condition_averages.get(ConditionName.COORDS_CUES, 0)
+     marks_avg = condition_averages.get(ConditionName.MARKS, 0)
+
+     # Apply decision rule
+     if coords_cues_avg >= marks_avg - tolerance:
+         recommended = "COORDINATES"
+         reason = (
+             f"Coords+Cues ({coords_cues_avg:.1%}) is within {tolerance * 100}% of "
+             f"Marks ({marks_avg:.1%}) under drift. Coordinates approach is simpler "
+             "and doesn't require element detection pipeline."
+         )
+     else:
+         recommended = "MARKS"
+         gap = marks_avg - coords_cues_avg
+         reason = (
+             f"Marks ({marks_avg:.1%}) outperforms Coords+Cues ({coords_cues_avg:.1%}) "
+             f"by {gap:.1%} (>{tolerance * 100}%) under drift. Element-based approach "
+             "provides better robustness to UI changes."
+         )
+
+     return Recommendation(
+         recommended=recommended,
+         reason=reason,
+         coords_cues_avg=coords_cues_avg,
+         marks_avg=marks_avg,
+         tolerance=tolerance,
+         detailed_comparison=detailed_comparison,
+     )
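Usage note: the sketch below shows how the pieces added in this file are intended to fit together: one EvaluationResult per (condition, drift) pair, then a single decision via make_recommendation. It assumes conditions (a dict of ConditionName to ConditionBase) and drift_configs have already been built elsewhere (for example by the runner added in runner.py), and the samples list and run_model() helper are hypothetical placeholders, not APIs shipped in this package.

    from openadapt_ml.experiments.representation_shootout.evaluator import (
        DriftEvaluator,
        make_recommendation,
    )

    evaluator = DriftEvaluator(conditions, drift_configs)

    results = []
    for condition in conditions.values():
        for drift in drift_configs:
            # One raw model output string per sample for this condition/drift pair.
            outputs = [run_model(condition, sample, drift) for sample in samples]  # hypothetical
            results.append(
                evaluator.evaluate_condition_under_drift(condition, samples, drift, outputs)
            )

    recommendation = make_recommendation(results, tolerance=0.05)
    print(recommendation.recommended)
    print(recommendation.reason)

The decision rule compares drift-averaged values of the primary metric: with tolerance=0.05, coords_cues_avg=0.82 and marks_avg=0.85 yields COORDINATES (0.82 >= 0.85 - 0.05), whereas marks_avg=0.90 would yield MARKS.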