openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,486 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
6
|
+
|
|
7
|
+
from openadapt_ml.runtime.policy import AgentPolicy
|
|
8
|
+
from openadapt_ml.schemas.sessions import Action, Episode
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class MilestoneSpec:
    """Defines a semantic milestone for weak episode success evaluation.

    A milestone is achieved when, at a specific step, the predicted action
    matches certain criteria (type match + optional coord threshold).
    Milestones are checked step-by-step inside ``evaluate_episode``; an
    episode counts as a "weak success" only when every milestone in its
    spec list has been achieved.
    """

    # Identifier used as the key in EpisodeMetrics.milestones_achieved.
    name: str
    step_index: int  # Which step in the episode (0-indexed)
    expected_type: str  # Expected ground truth action type ("click", "type", "done", etc.)
    coord_threshold: Optional[float] = None  # If set, coord error must be < this for clicks
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Predefined milestone specs per scenario
# Updated for 6-step episode (no spurious WAIT):
# Step 0: click username, Step 1: type username, Step 2: click password,
# Step 3: type password, Step 4: click login, Step 5: done
# NOTE: step_index values are tied to the synthetic login episode layout
# above; if that layout changes, these indices must change with it.
LOGIN_MILESTONES = [
    MilestoneSpec("typed_username", step_index=1, expected_type="type"),
    MilestoneSpec("typed_password", step_index=3, expected_type="type"),
    MilestoneSpec("clicked_login", step_index=4, expected_type="click", coord_threshold=0.10),
    MilestoneSpec("emitted_done", step_index=5, expected_type="done"),
]
|
|
34
|
+
|
|
35
|
+
# Empty for now: with no milestones, evaluate_episode records an empty
# milestones_achieved dict, and aggregate_metrics' truthiness check means
# such episodes can never count toward weak_episode_success_rate.
SETTINGS_MILESTONES = [
    # Placeholder - to be defined when settings scenario is implemented
    # MilestoneSpec("toggled_setting", step_index=..., expected_type="click", coord_threshold=0.10),
    # MilestoneSpec("clicked_save", step_index=..., expected_type="click", coord_threshold=0.10),
    # MilestoneSpec("emitted_done", step_index=..., expected_type="done"),
]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_milestones_for_scenario(scenario: str = "login") -> List[MilestoneSpec]:
    """Look up the milestone spec list for *scenario*.

    Unknown scenarios yield an empty list, which disables semantic
    milestone tracking for them.
    """
    if scenario == "settings":
        return SETTINGS_MILESTONES
    if scenario == "login":
        return LOGIN_MILESTONES
    return []
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class EpisodeMetrics:
    """Per-episode evaluation counters produced by ``evaluate_episode``."""

    episode_id: str
    step_matches: int  # Steps whose predicted action type equals the GT type
    step_total: int  # Steps actually evaluated (image-less steps are skipped)
    coord_errors: List[float]  # Normalized L2 errors for coord-mode click/drag steps
    success_pred: bool  # Strict: all steps must match
    success_gt: Optional[bool]  # Episode-level ground-truth success flag, if any
    click_hits: int  # Point-based: click within 5% of target center
    click_total: int
    # Semantic goal milestones (scenario-agnostic)
    milestones_achieved: Dict[str, bool] = field(default_factory=dict)
    # Full step correctness (type match + click hit when applicable)
    full_step_correct: int = 0
    # State-based weak success: from model's State: {"success": true/false}
    state_success: Optional[bool] = None
    # Bbox-based click evaluation: click anywhere within element bounds
    bbox_hits: int = 0
    bbox_total: int = 0
    # SoM element index accuracy: predicted index == GT index
    element_hits: int = 0
    element_total: int = 0
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass
class AggregateMetrics:
    """Corpus-level metrics produced by ``aggregate_metrics``.

    Ratio fields are ``None`` when no applicable events were observed
    (empty denominator).
    """

    num_episodes: int
    num_steps: int  # Total evaluated steps across all episodes
    action_type_accuracy: float  # Correct types / total steps; 0.0 when no steps
    mean_coord_error: Optional[float]
    coord_error_count: int  # How many errors went into mean_coord_error
    episode_success_rate: Optional[float]  # Strict: all steps must match (renamed from success_pred)
    click_hit_rate: Optional[float]  # Point-based: within 5% of center
    mean_episode_progress: Optional[float]  # Partial credit: avg(step_matches/step_total)
    # New partial-credit metrics
    mean_episode_step_score: Optional[float]  # Strict partial: avg(full_step_correct/step_total)
    weak_episode_success_rate: Optional[float]  # Semantic milestones all achieved
    state_success_rate: Optional[float] = None  # From model's State: {"success": true}
    bbox_hit_rate: Optional[float] = None  # Bbox-based: click anywhere in element bounds
    element_accuracy: Optional[float] = None  # SoM element index accuracy
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def compute_coordinate_error(pred_action: Action, gt_action: Action) -> Optional[float]:
    """Compute normalized L2 distance between predicted and ground-truth coords.

    Only the ``x``/``y`` fields of each action are read. Coordinates are
    assumed to already be normalized (the thresholds applied by callers,
    0.05 and 0.10, treat them as screen fractions).

    Returns:
        Euclidean distance between the two points, or None if either
        action is missing a coordinate.
    """
    if (
        pred_action.x is None
        or pred_action.y is None
        or gt_action.x is None
        or gt_action.y is None
    ):
        return None

    # math.hypot is the canonical Euclidean distance: clearer than a manual
    # sqrt(dx*dx + dy*dy) and robust against intermediate over/underflow.
    return math.hypot(pred_action.x - gt_action.x, pred_action.y - gt_action.y)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def is_click_in_bbox(pred_action: Action, gt_action: Action) -> Optional[bool]:
    """Check if predicted click falls within ground truth bounding box.

    Returns:
        - True if prediction is inside bbox (edges inclusive)
        - False if prediction is outside bbox or has no coordinates
        - None if no bbox is available (fall back to coord distance)
    """
    bbox = gt_action.bbox
    if bbox is None:
        return None

    px, py = pred_action.x, pred_action.y
    if px is None or py is None:
        return False

    left, top, right, bottom = bbox
    return left <= px <= right and top <= py <= bottom
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def evaluate_episode(
    policy: AgentPolicy,
    episode: Episode,
    samples: List[Dict[str, Any]],
    start_idx: int,
    log_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
    log_limit: Optional[int] = None,
    logged_count: int = 0,
    milestones: Optional[List[MilestoneSpec]] = None,
    use_som: bool = False,
) -> tuple[EpisodeMetrics, int, int]:
    """Evaluate a single episode offline using pre-built SFT samples.

    We assume `samples` were created by iterating episodes and steps in the
    same order as here (see `build_next_action_sft_samples`). `start_idx`
    indicates the index of the first sample corresponding to this episode's
    first step. The function returns the episode metrics and the next sample
    index after this episode.

    Args:
        policy: Policy whose ``predict_action_from_sample`` is invoked per step.
        episode: Episode providing ground-truth steps and actions.
        samples: Pre-built SFT samples aligned with the episode's steps.
        start_idx: Index into ``samples`` of this episode's first step.
        log_fn: Optional callback invoked with a per-step record dict.
        log_limit: Max total records to log across episodes (None = unlimited).
        logged_count: Records already logged before this episode (threaded
            through so the limit applies globally).
        milestones: Optional list of MilestoneSpec to track for weak success.
            If None, defaults to LOGIN_MILESTONES for backward compat.
        use_som: If True, evaluate using Set-of-Marks element index matching
            instead of coordinate-based evaluation.

    Returns:
        Tuple of (EpisodeMetrics, next sample index, updated logged_count).
    """
    if milestones is None:
        milestones = LOGIN_MILESTONES

    # Per-episode accumulators; success_pred starts optimistic and is
    # falsified by any type mismatch (or a missed DONE, below).
    step_matches = 0
    step_total = 0
    coord_errors: List[float] = []
    success_pred = True
    click_hits = 0  # Point-based (5% threshold)
    click_total = 0
    bbox_hits = 0  # Bbox-based (anywhere in element)
    bbox_total = 0
    element_hits = 0  # SoM element index match
    element_total = 0
    # Generic milestone tracking
    milestones_achieved: Dict[str, bool] = {m.name: False for m in milestones}
    full_step_correct = 0
    # Track the last state's success flag
    last_state_success: Optional[bool] = None

    sample_idx = start_idx

    for step_idx, step in enumerate(episode.steps):
        # Skip steps without an image; the dataset builder does the same.
        # NOTE(review): skipped steps still advance step_idx (enumerate) but
        # consume no sample, so MilestoneSpec.step_index refers to positions
        # in the full episode, not in the sample stream — confirm this
        # matches how milestone indices were authored.
        if not step.observation.image_path:
            continue

        if sample_idx >= len(samples):
            break

        sample = samples[sample_idx]
        sample_idx += 1

        pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(sample)
        gt_action = step.action

        # Track state-based success from final step (last boolean seen wins).
        if pred_state and isinstance(pred_state, dict):
            success_val = pred_state.get("success")
            if isinstance(success_val, bool):
                last_state_success = success_val

        type_match = pred_action.type == gt_action.type
        if type_match:
            step_matches += 1
        else:
            success_pred = False

        coord_error: Optional[float] = None
        click_hit = False
        bbox_hit = False
        element_hit = False

        # SoM mode: evaluate by element index for click/drag/type actions
        if use_som and gt_action.type in {"click", "drag", "type"}:
            # Only count steps where ground truth actually has an element index.
            if gt_action.element_index is not None:
                element_total += 1
                if pred_action.element_index == gt_action.element_index:
                    element_hits += 1
                    element_hit = True
        elif gt_action.type in {"click", "drag"}:
            # Coordinate mode: evaluate by coordinate distance
            coord_error = compute_coordinate_error(pred_action, gt_action)
            if coord_error is not None:
                coord_errors.append(coord_error)
                click_total += 1
                if coord_error < 0.05:
                    click_hits += 1
                    click_hit = True

            # Bbox-based evaluation (more lenient)
            in_bbox = is_click_in_bbox(pred_action, gt_action)
            if in_bbox is not None:
                bbox_total += 1
                if in_bbox:
                    bbox_hits += 1
                    bbox_hit = True

        # Full step correctness: type matches AND element/coord match for relevant actions
        if type_match:
            if use_som and gt_action.type in {"click", "drag", "type"}:
                # SoM mode: require element index match
                if element_hit:
                    full_step_correct += 1
            elif gt_action.type in {"click", "drag"}:
                # Coordinate mode: require click hit
                if click_hit:
                    full_step_correct += 1
            else:
                # Non-targeting actions (wait, done): type match is sufficient
                full_step_correct += 1

        # Track semantic milestones using the milestone spec
        for milestone in milestones:
            if step_idx == milestone.step_index and gt_action.type == milestone.expected_type:
                if pred_action.type == milestone.expected_type:
                    # Check coord threshold if specified (for click actions)
                    if milestone.coord_threshold is not None:
                        if coord_error is not None and coord_error < milestone.coord_threshold:
                            milestones_achieved[milestone.name] = True
                    else:
                        # No coord threshold - type match is sufficient
                        milestones_achieved[milestone.name] = True

        # Ensure DONE is correct at the DONE step.
        if gt_action.type == "done" and pred_action.type != "done":
            success_pred = False

        # Optional logging of this step.
        if log_fn is not None and (log_limit is None or logged_count < log_limit):
            # Pull the first system/user prompts out of the chat messages.
            messages = sample.get("messages", [])
            system_prompt = None
            user_prompt = None
            for m in messages:
                if m.get("role") == "system" and system_prompt is None:
                    system_prompt = m.get("content")
                if m.get("role") == "user" and user_prompt is None:
                    user_prompt = m.get("content")

            record: Dict[str, Any] = {
                "episode_id": episode.id,
                "step_index": step_idx,
                "goal": episode.goal,
                "system_prompt": system_prompt,
                "user_prompt": user_prompt,
                "model_output_raw": raw_text,
                "pred_action": {
                    "type": pred_action.type,
                    "x": pred_action.x,
                    "y": pred_action.y,
                    "text": pred_action.text,
                    "element_index": pred_action.element_index,
                },
                "ground_truth_action": {
                    "type": gt_action.type,
                    "x": gt_action.x,
                    "y": gt_action.y,
                    "text": gt_action.text,
                    "element_index": gt_action.element_index,
                },
                "correct_type": pred_action.type == gt_action.type,
                "coord_error_norm": coord_error,
                "element_match": pred_action.element_index == gt_action.element_index
                if gt_action.element_index is not None
                else None,
            }

            log_fn(record)
            logged_count += 1

        step_total += 1

    metrics = EpisodeMetrics(
        episode_id=episode.id,
        step_matches=step_matches,
        step_total=step_total,
        coord_errors=coord_errors,
        success_pred=success_pred,
        success_gt=episode.success,
        click_hits=click_hits,
        click_total=click_total,
        milestones_achieved=milestones_achieved,
        full_step_correct=full_step_correct,
        state_success=last_state_success,
        bbox_hits=bbox_hits,
        bbox_total=bbox_total,
        element_hits=element_hits,
        element_total=element_total,
    )
    return metrics, sample_idx, logged_count
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def aggregate_metrics(episodes_metrics: List[EpisodeMetrics]) -> AggregateMetrics:
    """Aggregate per-episode metrics into global metrics.

    Episode-level success is reported at three strictness tiers, from most
    lenient to most strict:

    1. **weak_episode_success_rate**: semantic goal completion — every
       milestone achieved (for login: typed username, typed password,
       clicked login within 10% coord error, emitted DONE). Intermediate
       mistakes are allowed.
    2. **mean_episode_step_score**: strict partial credit — average of
       (full_step_correct / step_total) per episode, where a step is fully
       correct when the type matches AND (not a click OR within the 5%
       threshold).
    3. **episode_success_rate**: hard metric — every step must match
       exactly; only meaningful once step accuracy is high.

    Also computes action_type_accuracy (correct types / total steps),
    mean_coord_error, mean_episode_progress (type matches only), and the
    point/bbox/element click rates. Ratio fields are None when their
    denominator is empty; episodes with zero evaluated steps are excluded
    from the per-episode rates.
    """

    def _mean(values: List[float]) -> Optional[float]:
        # Average of a possibly-empty list; None when empty.
        return sum(values) / len(values) if values else None

    def _rate(hits: int, total: int) -> Optional[float]:
        # hits / total, with None signalling an empty denominator.
        return hits / total if total > 0 else None

    num_episodes = len(episodes_metrics)
    num_steps = sum(m.step_total for m in episodes_metrics)
    total_matches = sum(m.step_matches for m in episodes_metrics)
    # Overall type accuracy is 0.0 (not None) when there are no steps at all.
    action_type_accuracy = (total_matches / num_steps) if num_steps > 0 else 0.0

    all_coord_errors: List[float] = [
        err for m in episodes_metrics for err in m.coord_errors
    ]
    mean_coord_error = _mean(all_coord_errors)

    # Episodes with no evaluated steps carry no signal for per-episode rates.
    eval_episodes = [m for m in episodes_metrics if m.step_total > 0]
    n_eval = len(eval_episodes)

    episode_success_rate = _rate(
        sum(1 for m in eval_episodes if m.success_pred), n_eval
    )

    click_hit_rate = _rate(
        sum(m.click_hits for m in episodes_metrics),
        sum(m.click_total for m in episodes_metrics),
    )

    # Partial credit: average per-episode progress (type matches only).
    mean_episode_progress = _mean(
        [m.step_matches / m.step_total for m in eval_episodes]
    )
    # Strict partial: requires type match + click hit where applicable.
    mean_episode_step_score = _mean(
        [m.full_step_correct / m.step_total for m in eval_episodes]
    )

    # Weak success needs a non-empty milestone dict with every flag True.
    weak_episode_success_rate = _rate(
        sum(
            1
            for m in eval_episodes
            if m.milestones_achieved and all(m.milestones_achieved.values())
        ),
        n_eval,
    )

    # State-based success only counts episodes whose model emitted a boolean
    # State: {"success": ...} flag.
    with_state = [m for m in eval_episodes if m.state_success is not None]
    state_success_rate = _rate(
        sum(1 for m in with_state if m.state_success), len(with_state)
    )

    # Bbox-based click evaluation (more lenient than point-based).
    bbox_hit_rate = _rate(
        sum(m.bbox_hits for m in episodes_metrics),
        sum(m.bbox_total for m in episodes_metrics),
    )

    # SoM element index accuracy.
    element_accuracy = _rate(
        sum(m.element_hits for m in episodes_metrics),
        sum(m.element_total for m in episodes_metrics),
    )

    return AggregateMetrics(
        num_episodes=num_episodes,
        num_steps=num_steps,
        action_type_accuracy=action_type_accuracy,
        mean_coord_error=mean_coord_error,
        coord_error_count=len(all_coord_errors),
        episode_success_rate=episode_success_rate,
        click_hit_rate=click_hit_rate,
        mean_episode_progress=mean_episode_progress,
        mean_episode_step_score=mean_episode_step_score,
        weak_episode_success_rate=weak_episode_success_rate,
        state_success_rate=state_success_rate,
        bbox_hit_rate=bbox_hit_rate,
        element_accuracy=element_accuracy,
    )
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def evaluate_policy_on_episodes(
    policy: AgentPolicy,
    episodes: List[Episode],
    samples: List[Dict[str, Any]],
    log_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
    log_limit: Optional[int] = None,
    use_som: bool = False,
) -> AggregateMetrics:
    """Evaluate a policy on a list of episodes given corresponding SFT samples.

    The `samples` list must have been produced from `episodes` using
    `build_next_action_sft_samples`, so that iterating episodes/steps in order
    aligns with iterating over `samples`.

    Args:
        use_som: If True, evaluate using Set-of-Marks element index matching
            instead of coordinate-based evaluation.
    """
    per_episode: List[EpisodeMetrics] = []
    # The sample cursor and global log counter are threaded through every
    # episode so alignment and the log limit hold across the whole run.
    cursor = 0
    emitted = 0

    for ep in episodes:
        ep_metrics, cursor, emitted = evaluate_episode(
            policy,
            ep,
            samples,
            cursor,
            log_fn=log_fn,
            log_limit=log_limit,
            logged_count=emitted,
            use_som=use_som,
        )
        per_episode.append(ep_metrics)

    return aggregate_metrics(per_episode)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Grounding modules for visual element localization.
|
|
2
|
+
|
|
3
|
+
This package provides strategies for grounding natural language target descriptions
|
|
4
|
+
to specific regions on screenshots. The grounding/policy separation enables:
|
|
5
|
+
|
|
6
|
+
- Training policy and grounding separately or jointly
|
|
7
|
+
- Swapping grounding strategies without retraining policy
|
|
8
|
+
- Evaluating each layer independently
|
|
9
|
+
- Composing different grounding modules per platform
|
|
10
|
+
|
|
11
|
+
Available grounding strategies:
|
|
12
|
+
|
|
13
|
+
- OracleGrounder: Uses ground-truth bboxes (for evaluation)
|
|
14
|
+
- GeminiGrounder: Uses Google Gemini vision API for element detection
|
|
15
|
+
- DetectorGrounder: Generic detector wrapper with backend selection
|
|
16
|
+
- SoMGrounder: Element index selection from Set-of-Marks overlay (coming soon)
|
|
17
|
+
- AttentionGrounder: GUI-Actor style attention-based selection (coming soon)
|
|
18
|
+
|
|
19
|
+
Functions:
|
|
20
|
+
|
|
21
|
+
- extract_ui_elements: Extract all interactive UI elements from a screenshot
|
|
22
|
+
- overlay_element_marks: Overlay numbered labels (Set-of-Marks) on elements
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from openadapt_ml.grounding.base import (
|
|
26
|
+
GroundingModule,
|
|
27
|
+
OracleGrounder,
|
|
28
|
+
RegionCandidate,
|
|
29
|
+
)
|
|
30
|
+
from openadapt_ml.grounding.detector import (
|
|
31
|
+
DetectorGrounder,
|
|
32
|
+
GeminiGrounder,
|
|
33
|
+
extract_ui_elements,
|
|
34
|
+
overlay_element_marks,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
__all__ = [
|
|
38
|
+
"GroundingModule",
|
|
39
|
+
"OracleGrounder",
|
|
40
|
+
"RegionCandidate",
|
|
41
|
+
"DetectorGrounder",
|
|
42
|
+
"GeminiGrounder",
|
|
43
|
+
"extract_ui_elements",
|
|
44
|
+
"overlay_element_marks",
|
|
45
|
+
]
|