openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0

openadapt_ml/experiments/representation_shootout/evaluator.py (new file)
@@ -0,0 +1,659 @@
"""Evaluation under drift conditions for the Representation Shootout.
|
|
2
|
+
|
|
3
|
+
This module implements:
|
|
4
|
+
1. Drift transformations (resolution, translation, theme, scroll)
|
|
5
|
+
2. Metrics computation (click-hit rate, grounding accuracy, etc.)
|
|
6
|
+
3. Decision rule for recommendation
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import math
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
from openadapt_ml.experiments.representation_shootout.conditions import (
|
|
17
|
+
ConditionBase,
|
|
18
|
+
Observation,
|
|
19
|
+
ParsedAction,
|
|
20
|
+
UIElement,
|
|
21
|
+
UIElementGraph,
|
|
22
|
+
)
|
|
23
|
+
from openadapt_ml.experiments.representation_shootout.config import (
|
|
24
|
+
ConditionName,
|
|
25
|
+
DriftConfig,
|
|
26
|
+
DriftType,
|
|
27
|
+
MetricName,
|
|
28
|
+
ResolutionDriftParams,
|
|
29
|
+
ScrollDriftParams,
|
|
30
|
+
ThemeDriftParams,
|
|
31
|
+
TranslationDriftParams,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class Sample:
|
|
39
|
+
"""A single evaluation sample.
|
|
40
|
+
|
|
41
|
+
Attributes:
|
|
42
|
+
sample_id: Unique identifier for this sample.
|
|
43
|
+
observation: Observation data (screenshot, UI elements).
|
|
44
|
+
goal: Task instruction.
|
|
45
|
+
ground_truth: Ground truth action dict.
|
|
46
|
+
drift_config: Applied drift configuration.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
sample_id: str
|
|
50
|
+
observation: Observation
|
|
51
|
+
goal: str
|
|
52
|
+
ground_truth: dict[str, Any]
|
|
53
|
+
drift_config: DriftConfig | None = None
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
|
|
57
|
+
class SampleResult:
|
|
58
|
+
"""Result of evaluating a single sample.
|
|
59
|
+
|
|
60
|
+
Attributes:
|
|
61
|
+
sample_id: Sample identifier.
|
|
62
|
+
condition: Condition that was evaluated.
|
|
63
|
+
drift: Drift configuration applied.
|
|
64
|
+
prediction: Parsed prediction from model.
|
|
65
|
+
ground_truth: Ground truth action.
|
|
66
|
+
metrics: Computed metrics for this sample.
|
|
67
|
+
"""
|
|
68
|
+
|
|
69
|
+
sample_id: str
|
|
70
|
+
condition: ConditionName
|
|
71
|
+
drift: str
|
|
72
|
+
prediction: ParsedAction
|
|
73
|
+
ground_truth: dict[str, Any]
|
|
74
|
+
metrics: dict[str, float]
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
|
|
78
|
+
class EvaluationResult:
|
|
79
|
+
"""Aggregated evaluation results for a condition under a drift.
|
|
80
|
+
|
|
81
|
+
Attributes:
|
|
82
|
+
condition: Condition evaluated.
|
|
83
|
+
drift: Drift configuration.
|
|
84
|
+
num_samples: Number of samples evaluated.
|
|
85
|
+
metrics: Aggregated metrics (averages).
|
|
86
|
+
sample_results: Individual sample results.
|
|
87
|
+
"""
|
|
88
|
+
|
|
89
|
+
condition: ConditionName
|
|
90
|
+
drift: str
|
|
91
|
+
num_samples: int
|
|
92
|
+
metrics: dict[str, float]
|
|
93
|
+
sample_results: list[SampleResult] = field(default_factory=list)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass
|
|
97
|
+
class Recommendation:
|
|
98
|
+
"""Final recommendation from the experiment.
|
|
99
|
+
|
|
100
|
+
Attributes:
|
|
101
|
+
recommended: Recommended approach ("COORDINATES" or "MARKS").
|
|
102
|
+
reason: Explanation for the recommendation.
|
|
103
|
+
coords_cues_avg: Average performance of Coords+Cues across drifts.
|
|
104
|
+
marks_avg: Average performance of Marks across drifts.
|
|
105
|
+
tolerance: Tolerance threshold used for decision.
|
|
106
|
+
detailed_comparison: Per-drift comparison data.
|
|
107
|
+
"""
|
|
108
|
+
|
|
109
|
+
recommended: str # "COORDINATES" or "MARKS"
|
|
110
|
+
reason: str
|
|
111
|
+
coords_cues_avg: float
|
|
112
|
+
marks_avg: float
|
|
113
|
+
tolerance: float
|
|
114
|
+
detailed_comparison: dict[str, dict[str, float]] = field(default_factory=dict)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class DriftTransformer:
|
|
118
|
+
"""Applies drift transformations to samples."""
|
|
119
|
+
|
|
120
|
+
@staticmethod
|
|
121
|
+
def apply_drift(
|
|
122
|
+
observation: Observation,
|
|
123
|
+
ground_truth: dict[str, Any],
|
|
124
|
+
drift_config: DriftConfig,
|
|
125
|
+
) -> tuple[Observation, dict[str, Any]]:
|
|
126
|
+
"""Apply drift transformation to observation and ground truth.
|
|
127
|
+
|
|
128
|
+
Args:
|
|
129
|
+
observation: Original observation.
|
|
130
|
+
ground_truth: Original ground truth action.
|
|
131
|
+
drift_config: Drift to apply.
|
|
132
|
+
|
|
133
|
+
Returns:
|
|
134
|
+
Tuple of (transformed_observation, transformed_ground_truth).
|
|
135
|
+
"""
|
|
136
|
+
if drift_config.is_canonical:
|
|
137
|
+
return observation, ground_truth
|
|
138
|
+
|
|
139
|
+
if drift_config.drift_type == DriftType.RESOLUTION:
|
|
140
|
+
return DriftTransformer._apply_resolution_drift(
|
|
141
|
+
observation,
|
|
142
|
+
ground_truth,
|
|
143
|
+
drift_config.params, # type: ignore
|
|
144
|
+
)
|
|
145
|
+
elif drift_config.drift_type == DriftType.TRANSLATION:
|
|
146
|
+
return DriftTransformer._apply_translation_drift(
|
|
147
|
+
observation,
|
|
148
|
+
ground_truth,
|
|
149
|
+
drift_config.params, # type: ignore
|
|
150
|
+
)
|
|
151
|
+
elif drift_config.drift_type == DriftType.THEME:
|
|
152
|
+
return DriftTransformer._apply_theme_drift(
|
|
153
|
+
observation,
|
|
154
|
+
ground_truth,
|
|
155
|
+
drift_config.params, # type: ignore
|
|
156
|
+
)
|
|
157
|
+
elif drift_config.drift_type == DriftType.SCROLL:
|
|
158
|
+
return DriftTransformer._apply_scroll_drift(
|
|
159
|
+
observation,
|
|
160
|
+
ground_truth,
|
|
161
|
+
drift_config.params, # type: ignore
|
|
162
|
+
)
|
|
163
|
+
else:
|
|
164
|
+
logger.warning(f"Unknown drift type: {drift_config.drift_type}")
|
|
165
|
+
return observation, ground_truth
|
|
166
|
+
|
|
167
|
+
@staticmethod
|
|
168
|
+
def _apply_resolution_drift(
|
|
169
|
+
observation: Observation,
|
|
170
|
+
ground_truth: dict[str, Any],
|
|
171
|
+
params: ResolutionDriftParams,
|
|
172
|
+
) -> tuple[Observation, dict[str, Any]]:
|
|
173
|
+
"""Apply resolution scaling.
|
|
174
|
+
|
|
175
|
+
For normalized coordinates, no transformation is needed (they scale automatically).
|
|
176
|
+
For pixel coordinates, scale by the factor.
|
|
177
|
+
For UI elements, scale bounding boxes.
|
|
178
|
+
"""
|
|
179
|
+
scale = params.scale
|
|
180
|
+
|
|
181
|
+
# Create new observation with scaled screen size
|
|
182
|
+
new_screen_size = None
|
|
183
|
+
if observation.screen_size:
|
|
184
|
+
w, h = observation.screen_size
|
|
185
|
+
new_screen_size = (int(w * scale), int(h * scale))
|
|
186
|
+
|
|
187
|
+
# Scale UI element bboxes if they are in pixels
|
|
188
|
+
new_ui_elements = None
|
|
189
|
+
if observation.ui_elements:
|
|
190
|
+
new_elements = []
|
|
191
|
+
for el in observation.ui_elements.elements:
|
|
192
|
+
# Assuming bboxes are normalized (0-1), no scaling needed
|
|
193
|
+
# If they were pixels, we would scale them here
|
|
194
|
+
new_elements.append(el)
|
|
195
|
+
new_ui_elements = UIElementGraph(elements=new_elements)
|
|
196
|
+
|
|
197
|
+
new_observation = Observation(
|
|
198
|
+
screenshot_path=observation.screenshot_path, # Would need actual resize
|
|
199
|
+
screenshot_bytes=observation.screenshot_bytes,
|
|
200
|
+
screen_size=new_screen_size,
|
|
201
|
+
ui_elements=new_ui_elements,
|
|
202
|
+
window_title=observation.window_title,
|
|
203
|
+
url=observation.url,
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
# Ground truth coordinates are normalized, so no change needed
|
|
207
|
+
# If they were pixels, we would scale them
|
|
208
|
+
new_ground_truth = ground_truth.copy()
|
|
209
|
+
|
|
210
|
+
logger.debug(f"Applied resolution drift {scale}x: {new_screen_size}")
|
|
211
|
+
return new_observation, new_ground_truth
|
|
212
|
+
|
|
213
|
+
@staticmethod
|
|
214
|
+
def _apply_translation_drift(
|
|
215
|
+
observation: Observation,
|
|
216
|
+
ground_truth: dict[str, Any],
|
|
217
|
+
params: TranslationDriftParams,
|
|
218
|
+
) -> tuple[Observation, dict[str, Any]]:
|
|
219
|
+
"""Apply window translation.
|
|
220
|
+
|
|
221
|
+
This shifts the window position while keeping the UI elements
|
|
222
|
+
in their relative positions within the window.
|
|
223
|
+
"""
|
|
224
|
+
offset_x = params.offset_x
|
|
225
|
+
offset_y = params.offset_y
|
|
226
|
+
|
|
227
|
+
# For normalized coordinates within the window, no change is needed
|
|
228
|
+
# The translation affects where the window is on screen, but not
|
|
229
|
+
# the relative positions within the window
|
|
230
|
+
|
|
231
|
+
# However, if coordinates are screen-absolute, we need to adjust
|
|
232
|
+
# For this experiment, we assume window-relative normalized coords
|
|
233
|
+
|
|
234
|
+
new_ground_truth = ground_truth.copy()
|
|
235
|
+
|
|
236
|
+
# If ground truth has screen-absolute coordinates, adjust them
|
|
237
|
+
if "screen_x" in ground_truth and "screen_y" in ground_truth:
|
|
238
|
+
# Convert pixel offset to normalized offset
|
|
239
|
+
if observation.screen_size:
|
|
240
|
+
w, h = observation.screen_size
|
|
241
|
+
norm_offset_x = offset_x / w
|
|
242
|
+
norm_offset_y = offset_y / h
|
|
243
|
+
new_ground_truth["screen_x"] = ground_truth["screen_x"] + norm_offset_x
|
|
244
|
+
new_ground_truth["screen_y"] = ground_truth["screen_y"] + norm_offset_y
|
|
245
|
+
|
|
246
|
+
logger.debug(f"Applied translation drift: ({offset_x}, {offset_y})")
|
|
247
|
+
return observation, new_ground_truth
|
|
248
|
+
|
|
249
|
+
@staticmethod
|
|
250
|
+
def _apply_theme_drift(
|
|
251
|
+
observation: Observation,
|
|
252
|
+
ground_truth: dict[str, Any],
|
|
253
|
+
params: ThemeDriftParams,
|
|
254
|
+
) -> tuple[Observation, dict[str, Any]]:
|
|
255
|
+
"""Apply theme change.
|
|
256
|
+
|
|
257
|
+
Theme changes affect visual appearance but not coordinates.
|
|
258
|
+
Full implementation would load theme-variant screenshots.
|
|
259
|
+
"""
|
|
260
|
+
theme = params.theme
|
|
261
|
+
|
|
262
|
+
# For scaffolding, we don't transform the screenshot
|
|
263
|
+
# Full implementation would:
|
|
264
|
+
# 1. Load a pre-recorded screenshot in the target theme, OR
|
|
265
|
+
# 2. Apply synthetic color transformations
|
|
266
|
+
|
|
267
|
+
logger.debug(f"Applied theme drift: {theme}")
|
|
268
|
+
return observation, ground_truth
|
|
269
|
+
|
|
270
|
+
@staticmethod
|
|
271
|
+
def _apply_scroll_drift(
|
|
272
|
+
observation: Observation,
|
|
273
|
+
ground_truth: dict[str, Any],
|
|
274
|
+
params: ScrollDriftParams,
|
|
275
|
+
) -> tuple[Observation, dict[str, Any]]:
|
|
276
|
+
"""Apply scroll offset.
|
|
277
|
+
|
|
278
|
+
Scroll changes the visible portion of the page, affecting
|
|
279
|
+
which elements are visible and their y-coordinates.
|
|
280
|
+
"""
|
|
281
|
+
offset_y = params.offset_y
|
|
282
|
+
|
|
283
|
+
# Adjust UI element bboxes for scroll
|
|
284
|
+
new_ui_elements = None
|
|
285
|
+
if observation.ui_elements and observation.screen_size:
|
|
286
|
+
_, screen_h = observation.screen_size
|
|
287
|
+
norm_offset = offset_y / screen_h
|
|
288
|
+
|
|
289
|
+
new_elements = []
|
|
290
|
+
for el in observation.ui_elements.elements:
|
|
291
|
+
x1, y1, x2, y2 = el.bbox
|
|
292
|
+
# Shift y coordinates up by scroll amount
|
|
293
|
+
new_y1 = y1 - norm_offset
|
|
294
|
+
new_y2 = y2 - norm_offset
|
|
295
|
+
|
|
296
|
+
# Only include elements still visible on screen
|
|
297
|
+
if new_y2 > 0 and new_y1 < 1:
|
|
298
|
+
new_elements.append(
|
|
299
|
+
UIElement(
|
|
300
|
+
element_id=el.element_id,
|
|
301
|
+
role=el.role,
|
|
302
|
+
name=el.name,
|
|
303
|
+
bbox=(x1, max(0, new_y1), x2, min(1, new_y2)),
|
|
304
|
+
)
|
|
305
|
+
)
|
|
306
|
+
|
|
307
|
+
new_ui_elements = UIElementGraph(elements=new_elements)
|
|
308
|
+
|
|
309
|
+
new_observation = Observation(
|
|
310
|
+
screenshot_path=observation.screenshot_path, # Would need scroll-shifted image
|
|
311
|
+
screenshot_bytes=observation.screenshot_bytes,
|
|
312
|
+
screen_size=observation.screen_size,
|
|
313
|
+
ui_elements=new_ui_elements,
|
|
314
|
+
window_title=observation.window_title,
|
|
315
|
+
url=observation.url,
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
# Adjust ground truth coordinates
|
|
319
|
+
new_ground_truth = ground_truth.copy()
|
|
320
|
+
if "y" in ground_truth and observation.screen_size:
|
|
321
|
+
_, screen_h = observation.screen_size
|
|
322
|
+
norm_offset = offset_y / screen_h
|
|
323
|
+
new_ground_truth["y"] = ground_truth["y"] - norm_offset
|
|
324
|
+
|
|
325
|
+
logger.debug(f"Applied scroll drift: {offset_y}px")
|
|
326
|
+
return new_observation, new_ground_truth
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def compute_metrics(
|
|
330
|
+
prediction: ParsedAction,
|
|
331
|
+
ground_truth: dict[str, Any],
|
|
332
|
+
ui_elements: UIElementGraph | None = None,
|
|
333
|
+
) -> dict[str, float]:
|
|
334
|
+
"""Compute all metrics for a single prediction.
|
|
335
|
+
|
|
336
|
+
Args:
|
|
337
|
+
prediction: Parsed prediction from model.
|
|
338
|
+
ground_truth: Ground truth action dict with coordinates/element_id.
|
|
339
|
+
ui_elements: UI elements (needed for click-hit computation).
|
|
340
|
+
|
|
341
|
+
Returns:
|
|
342
|
+
Dict of metric name to value.
|
|
343
|
+
"""
|
|
344
|
+
metrics: dict[str, float] = {}
|
|
345
|
+
|
|
346
|
+
# Click-Hit Rate: Is predicted coordinate within target element bbox?
|
|
347
|
+
if prediction.type == "click":
|
|
348
|
+
hit = 0.0
|
|
349
|
+
|
|
350
|
+
if prediction.x is not None and prediction.y is not None:
|
|
351
|
+
# Coordinate-based prediction
|
|
352
|
+
target_bbox = ground_truth.get("target_bbox")
|
|
353
|
+
if target_bbox:
|
|
354
|
+
x1, y1, x2, y2 = target_bbox
|
|
355
|
+
if x1 <= prediction.x <= x2 and y1 <= prediction.y <= y2:
|
|
356
|
+
hit = 1.0
|
|
357
|
+
|
|
358
|
+
# Also check if coordinates are within the target element from ui_elements
|
|
359
|
+
elif ui_elements and ground_truth.get("element_id"):
|
|
360
|
+
target_el = ui_elements.get_element(ground_truth["element_id"])
|
|
361
|
+
if target_el and target_el.contains_point(prediction.x, prediction.y):
|
|
362
|
+
hit = 1.0
|
|
363
|
+
|
|
364
|
+
elif prediction.element_id is not None and ui_elements:
|
|
365
|
+
# Element-based prediction - find element and check if it matches target
|
|
366
|
+
pred_el = ui_elements.get_element(prediction.element_id)
|
|
367
|
+
gt_el_id = ground_truth.get("element_id")
|
|
368
|
+
if pred_el and gt_el_id:
|
|
369
|
+
# Normalize IDs for comparison
|
|
370
|
+
pred_id = prediction.element_id.lower().replace("e", "")
|
|
371
|
+
gt_id = str(gt_el_id).lower().replace("e", "")
|
|
372
|
+
if pred_id == gt_id:
|
|
373
|
+
hit = 1.0
|
|
374
|
+
|
|
375
|
+
metrics[MetricName.CLICK_HIT_RATE.value] = hit
|
|
376
|
+
|
|
377
|
+
# Grounding Top-1 Accuracy: Is predicted element ID correct?
|
|
378
|
+
if prediction.element_id is not None:
|
|
379
|
+
gt_el_id = ground_truth.get("element_id")
|
|
380
|
+
if gt_el_id:
|
|
381
|
+
pred_id = prediction.element_id.lower().replace("e", "")
|
|
382
|
+
gt_id = str(gt_el_id).lower().replace("e", "")
|
|
383
|
+
metrics[MetricName.GROUNDING_TOP1_ACCURACY.value] = (
|
|
384
|
+
1.0 if pred_id == gt_id else 0.0
|
|
385
|
+
)
|
|
386
|
+
else:
|
|
387
|
+
metrics[MetricName.GROUNDING_TOP1_ACCURACY.value] = 0.0
|
|
388
|
+
|
|
389
|
+
# Coordinate Distance: L2 distance to target (normalized)
|
|
390
|
+
gt_x = ground_truth.get("x")
|
|
391
|
+
gt_y = ground_truth.get("y")
|
|
392
|
+
|
|
393
|
+
if gt_x is not None and gt_y is not None:
|
|
394
|
+
if prediction.x is not None and prediction.y is not None:
|
|
395
|
+
distance = math.sqrt(
|
|
396
|
+
(prediction.x - gt_x) ** 2 + (prediction.y - gt_y) ** 2
|
|
397
|
+
)
|
|
398
|
+
else:
|
|
399
|
+
# If prediction failed or is element-based, compute distance from element center
|
|
400
|
+
if prediction.element_id and ui_elements:
|
|
401
|
+
pred_el = ui_elements.get_element(prediction.element_id)
|
|
402
|
+
if pred_el:
|
|
403
|
+
cx, cy = pred_el.center
|
|
404
|
+
distance = math.sqrt((cx - gt_x) ** 2 + (cy - gt_y) ** 2)
|
|
405
|
+
else:
|
|
406
|
+
distance = math.sqrt(2) # Max normalized distance
|
|
407
|
+
else:
|
|
408
|
+
distance = math.sqrt(2) # Max normalized distance
|
|
409
|
+
|
|
410
|
+
metrics[MetricName.COORD_DISTANCE.value] = distance
|
|
411
|
+
|
|
412
|
+
return metrics
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def aggregate_metrics(sample_results: list[SampleResult]) -> dict[str, float]:
|
|
416
|
+
"""Aggregate metrics across multiple samples.
|
|
417
|
+
|
|
418
|
+
Args:
|
|
419
|
+
sample_results: List of individual sample results.
|
|
420
|
+
|
|
421
|
+
Returns:
|
|
422
|
+
Dict of metric name to averaged value.
|
|
423
|
+
"""
|
|
424
|
+
if not sample_results:
|
|
425
|
+
return {}
|
|
426
|
+
|
|
427
|
+
# Collect all metrics
|
|
428
|
+
all_metrics: dict[str, list[float]] = {}
|
|
429
|
+
for result in sample_results:
|
|
430
|
+
for metric_name, value in result.metrics.items():
|
|
431
|
+
if metric_name not in all_metrics:
|
|
432
|
+
all_metrics[metric_name] = []
|
|
433
|
+
all_metrics[metric_name].append(value)
|
|
434
|
+
|
|
435
|
+
# Compute averages
|
|
436
|
+
aggregated = {}
|
|
437
|
+
for metric_name, values in all_metrics.items():
|
|
438
|
+
aggregated[metric_name] = sum(values) / len(values)
|
|
439
|
+
|
|
440
|
+
return aggregated
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class DriftEvaluator:
|
|
444
|
+
"""Evaluates conditions under drift conditions.
|
|
445
|
+
|
|
446
|
+
This class orchestrates the evaluation process:
|
|
447
|
+
1. Apply drift transformations to samples
|
|
448
|
+
2. Generate predictions using conditions
|
|
449
|
+
3. Compute metrics
|
|
450
|
+
4. Aggregate results
|
|
451
|
+
"""
|
|
452
|
+
|
|
453
|
+
def __init__(
|
|
454
|
+
self,
|
|
455
|
+
conditions: dict[ConditionName, ConditionBase],
|
|
456
|
+
drift_configs: list[DriftConfig],
|
|
457
|
+
):
|
|
458
|
+
"""Initialize evaluator.
|
|
459
|
+
|
|
460
|
+
Args:
|
|
461
|
+
conditions: Map of condition name to condition instance.
|
|
462
|
+
drift_configs: List of drift configurations to test.
|
|
463
|
+
"""
|
|
464
|
+
self.conditions = conditions
|
|
465
|
+
self.drift_configs = drift_configs
|
|
466
|
+
self._canonical_results: dict[ConditionName, dict[str, float]] = {}
|
|
467
|
+
|
|
468
|
+
def evaluate_sample(
|
|
469
|
+
self,
|
|
470
|
+
condition: ConditionBase,
|
|
471
|
+
sample: Sample,
|
|
472
|
+
drift_config: DriftConfig,
|
|
473
|
+
model_output: str,
|
|
474
|
+
) -> SampleResult:
|
|
475
|
+
"""Evaluate a single sample under a drift condition.
|
|
476
|
+
|
|
477
|
+
Args:
|
|
478
|
+
condition: Condition to use for evaluation.
|
|
479
|
+
sample: Sample to evaluate.
|
|
480
|
+
drift_config: Drift to apply.
|
|
481
|
+
model_output: Raw model output to parse.
|
|
482
|
+
|
|
483
|
+
Returns:
|
|
484
|
+
SampleResult with metrics.
|
|
485
|
+
"""
|
|
486
|
+
# Apply drift
|
|
487
|
+
transformed_obs, transformed_gt = DriftTransformer.apply_drift(
|
|
488
|
+
sample.observation, sample.ground_truth, drift_config
|
|
489
|
+
)
|
|
490
|
+
|
|
491
|
+
# Parse model output
|
|
492
|
+
prediction = condition.parse_output(model_output)
|
|
493
|
+
|
|
494
|
+
# Compute metrics
|
|
495
|
+
metrics = compute_metrics(
|
|
496
|
+
prediction, transformed_gt, transformed_obs.ui_elements
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
return SampleResult(
|
|
500
|
+
sample_id=sample.sample_id,
|
|
501
|
+
condition=condition.name,
|
|
502
|
+
drift=drift_config.name,
|
|
503
|
+
prediction=prediction,
|
|
504
|
+
ground_truth=transformed_gt,
|
|
505
|
+
metrics=metrics,
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
def evaluate_condition_under_drift(
|
|
509
|
+
self,
|
|
510
|
+
condition: ConditionBase,
|
|
511
|
+
samples: list[Sample],
|
|
512
|
+
drift_config: DriftConfig,
|
|
513
|
+
model_outputs: list[str],
|
|
514
|
+
) -> EvaluationResult:
|
|
515
|
+
"""Evaluate a condition on all samples under a drift.
|
|
516
|
+
|
|
517
|
+
Args:
|
|
518
|
+
condition: Condition to evaluate.
|
|
519
|
+
samples: Samples to evaluate.
|
|
520
|
+
drift_config: Drift to apply.
|
|
521
|
+
model_outputs: Model outputs corresponding to samples.
|
|
522
|
+
|
|
523
|
+
Returns:
|
|
524
|
+
EvaluationResult with aggregated metrics.
|
|
525
|
+
"""
|
|
526
|
+
sample_results = []
|
|
527
|
+
for sample, output in zip(samples, model_outputs):
|
|
528
|
+
result = self.evaluate_sample(condition, sample, drift_config, output)
|
|
529
|
+
sample_results.append(result)
|
|
530
|
+
|
|
531
|
+
aggregated = aggregate_metrics(sample_results)
|
|
532
|
+
|
|
533
|
+
return EvaluationResult(
|
|
534
|
+
condition=condition.name,
|
|
535
|
+
drift=drift_config.name,
|
|
536
|
+
num_samples=len(samples),
|
|
537
|
+
metrics=aggregated,
|
|
538
|
+
sample_results=sample_results,
|
|
539
|
+
)
|
|
540
|
+
|
|
541
|
+
def compute_robustness_scores(
|
|
542
|
+
self,
|
|
543
|
+
results: list[EvaluationResult],
|
|
544
|
+
primary_metric: str = MetricName.CLICK_HIT_RATE.value,
|
|
545
|
+
) -> dict[ConditionName, dict[str, float]]:
|
|
546
|
+
"""Compute robustness scores relative to canonical baseline.
|
|
547
|
+
|
|
548
|
+
Args:
|
|
549
|
+
results: Evaluation results across conditions and drifts.
|
|
550
|
+
primary_metric: Metric to use for robustness computation.
|
|
551
|
+
|
|
552
|
+
Returns:
|
|
553
|
+
Dict mapping condition to dict of drift to robustness score.
|
|
554
|
+
"""
|
|
555
|
+
# Group results by condition
|
|
556
|
+
by_condition: dict[ConditionName, list[EvaluationResult]] = {}
|
|
557
|
+
for r in results:
|
|
558
|
+
if r.condition not in by_condition:
|
|
559
|
+
by_condition[r.condition] = []
|
|
560
|
+
by_condition[r.condition].append(r)
|
|
561
|
+
|
|
562
|
+
robustness_scores: dict[ConditionName, dict[str, float]] = {}
|
|
563
|
+
|
|
564
|
+
for condition, cond_results in by_condition.items():
|
|
565
|
+
# Find canonical result
|
|
566
|
+
canonical_result = next(
|
|
567
|
+
(r for r in cond_results if r.drift == "canonical"), None
|
|
568
|
+
)
|
|
569
|
+
if not canonical_result:
|
|
570
|
+
logger.warning(f"No canonical result for condition {condition}")
|
|
571
|
+
continue
|
|
572
|
+
|
|
573
|
+
canonical_value = canonical_result.metrics.get(primary_metric, 0)
|
|
574
|
+
if canonical_value == 0:
|
|
575
|
+
canonical_value = 1e-6 # Avoid division by zero
|
|
576
|
+
|
|
577
|
+
robustness_scores[condition] = {}
|
|
578
|
+
for r in cond_results:
|
|
579
|
+
if r.drift == "canonical":
|
|
580
|
+
robustness_scores[condition][r.drift] = 1.0
|
|
581
|
+
else:
|
|
582
|
+
drift_value = r.metrics.get(primary_metric, 0)
|
|
583
|
+
robustness_scores[condition][r.drift] = (
|
|
584
|
+
drift_value / canonical_value
|
|
585
|
+
)
|
|
586
|
+
|
|
587
|
+
return robustness_scores
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
def make_recommendation(
|
|
591
|
+
results: list[EvaluationResult],
|
|
592
|
+
tolerance: float = 0.05,
|
|
593
|
+
primary_metric: str = MetricName.CLICK_HIT_RATE.value,
|
|
594
|
+
) -> Recommendation:
|
|
595
|
+
"""Make recommendation based on evaluation results.
|
|
596
|
+
|
|
597
|
+
Decision rule (from design doc):
|
|
598
|
+
- If Coords+Cues within 5% of Marks under drift -> choose Coordinates
|
|
599
|
+
- Otherwise -> choose Marks
|
|
600
|
+
|
|
601
|
+
Args:
|
|
602
|
+
results: Evaluation results across all conditions and drifts.
|
|
603
|
+
tolerance: Tolerance threshold for decision (default 5%).
|
|
604
|
+
primary_metric: Metric to use for comparison.
|
|
605
|
+
|
|
606
|
+
Returns:
|
|
607
|
+
Recommendation with explanation.
|
|
608
|
+
"""
|
|
609
|
+
# Group results by condition and compute averages across drifts
|
|
610
|
+
by_condition: dict[ConditionName, list[float]] = {}
|
|
611
|
+
detailed_comparison: dict[str, dict[str, float]] = {}
|
|
612
|
+
|
|
613
|
+
for r in results:
|
|
614
|
+
if r.condition not in by_condition:
|
|
615
|
+
by_condition[r.condition] = []
|
|
616
|
+
|
|
617
|
+
metric_value = r.metrics.get(primary_metric, 0)
|
|
618
|
+
by_condition[r.condition].append(metric_value)
|
|
619
|
+
|
|
620
|
+
# Track detailed comparison
|
|
621
|
+
drift_key = r.drift
|
|
622
|
+
if drift_key not in detailed_comparison:
|
|
623
|
+
detailed_comparison[drift_key] = {}
|
|
624
|
+
detailed_comparison[drift_key][r.condition.value] = metric_value
|
|
625
|
+
|
|
626
|
+
# Compute averages
|
|
627
|
+
condition_averages: dict[ConditionName, float] = {}
|
|
628
|
+
for condition, values in by_condition.items():
|
|
629
|
+
condition_averages[condition] = sum(values) / len(values) if values else 0
|
|
630
|
+
|
|
631
|
+
# Get averages for decision
|
|
632
|
+
coords_cues_avg = condition_averages.get(ConditionName.COORDS_CUES, 0)
|
|
633
|
+
marks_avg = condition_averages.get(ConditionName.MARKS, 0)
|
|
634
|
+
|
|
635
|
+
# Apply decision rule
|
|
636
|
+
if coords_cues_avg >= marks_avg - tolerance:
|
|
637
|
+
recommended = "COORDINATES"
|
|
638
|
+
reason = (
|
|
639
|
+
f"Coords+Cues ({coords_cues_avg:.1%}) is within {tolerance * 100}% of "
|
|
640
|
+
f"Marks ({marks_avg:.1%}) under drift. Coordinates approach is simpler "
|
|
641
|
+
"and doesn't require element detection pipeline."
|
|
642
|
+
)
|
|
643
|
+
else:
|
|
644
|
+
recommended = "MARKS"
|
|
645
|
+
gap = marks_avg - coords_cues_avg
|
|
646
|
+
reason = (
|
|
647
|
+
f"Marks ({marks_avg:.1%}) outperforms Coords+Cues ({coords_cues_avg:.1%}) "
|
|
648
|
+
f"by {gap:.1%} (>{tolerance * 100}%) under drift. Element-based approach "
|
|
649
|
+
"provides better robustness to UI changes."
|
|
650
|
+
)
|
|
651
|
+
|
|
652
|
+
return Recommendation(
|
|
653
|
+
recommended=recommended,
|
|
654
|
+
reason=reason,
|
|
655
|
+
coords_cues_avg=coords_cues_avg,
|
|
656
|
+
marks_avg=marks_avg,
|
|
657
|
+
tolerance=tolerance,
|
|
658
|
+
detailed_comparison=detailed_comparison,
|
|
659
|
+
)
|
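
The new evaluator wires drift transforms, per-sample metrics, and an averaging-based decision rule into one module. The snippet below is a minimal, hypothetical sketch of that decision rule only; it is not code shipped in the wheel. The drift names and metric values are invented for illustration, and it assumes the ConditionName and MetricName enums exported by the new config module contain the members referenced in the diff above (COORDS_CUES, MARKS, CLICK_HIT_RATE).

# Hypothetical usage sketch (not part of the wheel); values are illustrative only.
from openadapt_ml.experiments.representation_shootout.config import (
    ConditionName,
    MetricName,
)
from openadapt_ml.experiments.representation_shootout.evaluator import (
    EvaluationResult,
    make_recommendation,
)

hit = MetricName.CLICK_HIT_RATE.value

# Pretend each condition was scored on 50 samples, canonical plus one drift.
results = [
    EvaluationResult(ConditionName.COORDS_CUES, "canonical", 50, {hit: 0.94}),
    EvaluationResult(ConditionName.COORDS_CUES, "resolution_1.5x", 50, {hit: 0.88}),
    EvaluationResult(ConditionName.MARKS, "canonical", 50, {hit: 0.95}),
    EvaluationResult(ConditionName.MARKS, "resolution_1.5x", 50, {hit: 0.90}),
]

rec = make_recommendation(results, tolerance=0.05)
print(rec.recommended)  # COORDINATES: the 0.91 average is within 5% of 0.925
print(rec.reason)

With these numbers the averaged Coords+Cues click-hit rate (0.91) sits within the 5% tolerance of the Marks average (0.925), so the rule recommends COORDINATES; a gap larger than the tolerance would flip the recommendation to MARKS.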