openadapt-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/evals/trajectory_matching.py
@@ -0,0 +1,486 @@
+from __future__ import annotations
+
+import math
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional
+
+from openadapt_ml.runtime.policy import AgentPolicy
+from openadapt_ml.schemas.sessions import Action, Episode
+
+
+@dataclass
+class MilestoneSpec:
+    """Defines a semantic milestone for weak episode success evaluation.
+
+    A milestone is achieved when, at a specific step, the predicted action
+    matches certain criteria (type match + optional coord threshold).
+    """
+    name: str
+    step_index: int  # Which step in the episode (0-indexed)
+    expected_type: str  # Expected ground truth action type ("click", "type", "done", etc.)
+    coord_threshold: Optional[float] = None  # If set, coord error must be < this for clicks
+
+
+# Predefined milestone specs per scenario
+# Updated for 6-step episode (no spurious WAIT):
+# Step 0: click username, Step 1: type username, Step 2: click password,
+# Step 3: type password, Step 4: click login, Step 5: done
+LOGIN_MILESTONES = [
+    MilestoneSpec("typed_username", step_index=1, expected_type="type"),
+    MilestoneSpec("typed_password", step_index=3, expected_type="type"),
+    MilestoneSpec("clicked_login", step_index=4, expected_type="click", coord_threshold=0.10),
+    MilestoneSpec("emitted_done", step_index=5, expected_type="done"),
+]
+
+SETTINGS_MILESTONES = [
+    # Placeholder - to be defined when settings scenario is implemented
+    # MilestoneSpec("toggled_setting", step_index=..., expected_type="click", coord_threshold=0.10),
+    # MilestoneSpec("clicked_save", step_index=..., expected_type="click", coord_threshold=0.10),
+    # MilestoneSpec("emitted_done", step_index=..., expected_type="done"),
+]
+
+
+def get_milestones_for_scenario(scenario: str = "login") -> List[MilestoneSpec]:
+    """Return milestone specs for a given scenario."""
+    if scenario == "login":
+        return LOGIN_MILESTONES
+    elif scenario == "settings":
+        return SETTINGS_MILESTONES
+    else:
+        return []  # Unknown scenario - no semantic milestones
+
+
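Since `MilestoneSpec` is a plain dataclass and `evaluate_episode` (below) accepts any list of specs via its `milestones` argument, scenario authors can define their own milestone lists. A minimal sketch for a hypothetical three-step "search" scenario; the milestone names and step indices here are invented for illustration:

```python
from openadapt_ml.evals.trajectory_matching import MilestoneSpec

# Hypothetical scenario: step 0 clicks the search box, step 1 types the
# query, step 2 emits DONE. The click milestone reuses the 10% coord
# threshold that LOGIN_MILESTONES uses.
SEARCH_MILESTONES = [
    MilestoneSpec("clicked_search_box", step_index=0, expected_type="click", coord_threshold=0.10),
    MilestoneSpec("typed_query", step_index=1, expected_type="type"),
    MilestoneSpec("emitted_done", step_index=2, expected_type="done"),
]
```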
+@dataclass
+class EpisodeMetrics:
+    episode_id: str
+    step_matches: int
+    step_total: int
+    coord_errors: List[float]
+    success_pred: bool  # Strict: all steps must match
+    success_gt: Optional[bool]
+    click_hits: int  # Point-based: click within 5% of target center
+    click_total: int
+    # Semantic goal milestones (scenario-agnostic)
+    milestones_achieved: Dict[str, bool] = field(default_factory=dict)
+    # Full step correctness (type match + click hit when applicable)
+    full_step_correct: int = 0
+    # State-based weak success: from model's State: {"success": true/false}
+    state_success: Optional[bool] = None
+    # Bbox-based click evaluation: click anywhere within element bounds
+    bbox_hits: int = 0
+    bbox_total: int = 0
+    # SoM element index accuracy: predicted index == GT index
+    element_hits: int = 0
+    element_total: int = 0
+
+
+@dataclass
+class AggregateMetrics:
+    num_episodes: int
+    num_steps: int
+    action_type_accuracy: float
+    mean_coord_error: Optional[float]
+    coord_error_count: int
+    episode_success_rate: Optional[float]  # Strict: all steps must match (renamed from success_pred)
+    click_hit_rate: Optional[float]  # Point-based: within 5% of center
+    mean_episode_progress: Optional[float]  # Partial credit: avg(step_matches/step_total)
+    # New partial-credit metrics
+    mean_episode_step_score: Optional[float]  # Strict partial: avg(full_step_correct/step_total)
+    weak_episode_success_rate: Optional[float]  # Semantic milestones all achieved
+    state_success_rate: Optional[float] = None  # From model's State: {"success": true}
+    bbox_hit_rate: Optional[float] = None  # Bbox-based: click anywhere in element bounds
+    element_accuracy: Optional[float] = None  # SoM element index accuracy
+
+
+def compute_coordinate_error(pred_action: Action, gt_action: Action) -> Optional[float]:
+    """Compute the L2 distance between predicted and ground-truth coordinates
+    (both in normalized [0, 1] screen units).
+
+    Returns None if either action is missing coordinates.
+    """
+
+    if (
+        pred_action.x is None
+        or pred_action.y is None
+        or gt_action.x is None
+        or gt_action.y is None
+    ):
+        return None
+
+    dx = pred_action.x - gt_action.x
+    dy = pred_action.y - gt_action.y
+    return math.sqrt(dx * dx + dy * dy)
+
+
+def is_click_in_bbox(pred_action: Action, gt_action: Action) -> Optional[bool]:
+    """Check if the predicted click falls within the ground-truth bounding box.
+
+    Returns:
+        - True if the prediction is inside the bbox
+        - False if the prediction is outside the bbox
+        - None if no bbox is available (fall back to coord distance)
+    """
+    if gt_action.bbox is None:
+        return None
+
+    if pred_action.x is None or pred_action.y is None:
+        return False
+
+    x_min, y_min, x_max, y_max = gt_action.bbox
+    return (x_min <= pred_action.x <= x_max) and (y_min <= pred_action.y <= y_max)
+
+
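A quick worked check of the two geometry helpers, assuming normalized coordinates and that `Action` accepts these fields as keyword arguments (the helpers above only read them as attributes):

```python
from openadapt_ml.evals.trajectory_matching import compute_coordinate_error, is_click_in_bbox
from openadapt_ml.schemas.sessions import Action

# Hypothetical actions in normalized screen coordinates.
pred = Action(type="click", x=0.52, y=0.50)
gt = Action(type="click", x=0.50, y=0.50, bbox=(0.45, 0.47, 0.55, 0.53))

compute_coordinate_error(pred, gt)  # sqrt(0.02**2 + 0.0**2) = 0.02 -> a click hit (< 0.05)
is_click_in_bbox(pred, gt)          # True: 0.45 <= 0.52 <= 0.55 and 0.47 <= 0.50 <= 0.53
```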
+def evaluate_episode(
+    policy: AgentPolicy,
+    episode: Episode,
+    samples: List[Dict[str, Any]],
+    start_idx: int,
+    log_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
+    log_limit: Optional[int] = None,
+    logged_count: int = 0,
+    milestones: Optional[List[MilestoneSpec]] = None,
+    use_som: bool = False,
+) -> tuple[EpisodeMetrics, int, int]:
+    """Evaluate a single episode offline using pre-built SFT samples.
+
+    We assume `samples` were created by iterating episodes and steps in the
+    same order as here (see `build_next_action_sft_samples`). `start_idx` is
+    the index of the first sample corresponding to this episode's first step.
+    Returns the episode metrics, the next sample index after this episode,
+    and the updated logged-step count.
+
+    Args:
+        milestones: Optional list of MilestoneSpec to track for weak success.
+            If None, defaults to LOGIN_MILESTONES for backward compat.
+        use_som: If True, evaluate using Set-of-Marks element index matching
+            instead of coordinate-based evaluation.
+    """
+    if milestones is None:
+        milestones = LOGIN_MILESTONES
+
+    step_matches = 0
+    step_total = 0
+    coord_errors: List[float] = []
+    success_pred = True
+    click_hits = 0  # Point-based (5% threshold)
+    click_total = 0
+    bbox_hits = 0  # Bbox-based (anywhere in element)
+    bbox_total = 0
+    element_hits = 0  # SoM element index match
+    element_total = 0
+    # Generic milestone tracking
+    milestones_achieved: Dict[str, bool] = {m.name: False for m in milestones}
+    full_step_correct = 0
+    # Track the last state's success flag
+    last_state_success: Optional[bool] = None
+
+    sample_idx = start_idx
+
+    for step_idx, step in enumerate(episode.steps):
+        # Skip steps without an image; the dataset builder does the same.
+        if not step.observation.image_path:
+            continue
+
+        if sample_idx >= len(samples):
+            break
+
+        sample = samples[sample_idx]
+        sample_idx += 1
+
+        pred_action, _thought, pred_state, raw_text = policy.predict_action_from_sample(sample)
+        gt_action = step.action
+
+        # Track state-based success from the final step
+        if pred_state and isinstance(pred_state, dict):
+            success_val = pred_state.get("success")
+            if isinstance(success_val, bool):
+                last_state_success = success_val
+
+        type_match = pred_action.type == gt_action.type
+        if type_match:
+            step_matches += 1
+        else:
+            success_pred = False
+
+        coord_error: Optional[float] = None
+        click_hit = False
+        bbox_hit = False
+        element_hit = False
+
+        # SoM mode: evaluate by element index for click/drag/type actions
+        if use_som and gt_action.type in {"click", "drag", "type"}:
+            if gt_action.element_index is not None:
+                element_total += 1
+                if pred_action.element_index == gt_action.element_index:
+                    element_hits += 1
+                    element_hit = True
+        elif gt_action.type in {"click", "drag"}:
+            # Coordinate mode: evaluate by coordinate distance
+            coord_error = compute_coordinate_error(pred_action, gt_action)
+            if coord_error is not None:
+                coord_errors.append(coord_error)
+                click_total += 1
+                if coord_error < 0.05:
+                    click_hits += 1
+                    click_hit = True
+
+            # Bbox-based evaluation (more lenient)
+            in_bbox = is_click_in_bbox(pred_action, gt_action)
+            if in_bbox is not None:
+                bbox_total += 1
+                if in_bbox:
+                    bbox_hits += 1
+                    bbox_hit = True
+
+        # Full step correctness: type matches AND element/coord match for relevant actions
+        if type_match:
+            if use_som and gt_action.type in {"click", "drag", "type"}:
+                # SoM mode: require element index match
+                if element_hit:
+                    full_step_correct += 1
+            elif gt_action.type in {"click", "drag"}:
+                # Coordinate mode: require click hit
+                if click_hit:
+                    full_step_correct += 1
+            else:
+                # Non-targeting actions (wait, done): type match is sufficient
+                full_step_correct += 1
+
+        # Track semantic milestones using the milestone spec
+        for milestone in milestones:
+            if step_idx == milestone.step_index and gt_action.type == milestone.expected_type:
+                if pred_action.type == milestone.expected_type:
+                    # Check coord threshold if specified (for click actions)
+                    if milestone.coord_threshold is not None:
+                        if coord_error is not None and coord_error < milestone.coord_threshold:
+                            milestones_achieved[milestone.name] = True
+                    else:
+                        # No coord threshold - type match is sufficient
+                        milestones_achieved[milestone.name] = True
+
+        # Ensure DONE is correct at the DONE step.
+        if gt_action.type == "done" and pred_action.type != "done":
+            success_pred = False
+
+        # Optional logging of this step.
+        if log_fn is not None and (log_limit is None or logged_count < log_limit):
+            messages = sample.get("messages", [])
+            system_prompt = None
+            user_prompt = None
+            for m in messages:
+                if m.get("role") == "system" and system_prompt is None:
+                    system_prompt = m.get("content")
+                if m.get("role") == "user" and user_prompt is None:
+                    user_prompt = m.get("content")
+
+            record: Dict[str, Any] = {
+                "episode_id": episode.id,
+                "step_index": step_idx,
+                "goal": episode.goal,
+                "system_prompt": system_prompt,
+                "user_prompt": user_prompt,
+                "model_output_raw": raw_text,
+                "pred_action": {
+                    "type": pred_action.type,
+                    "x": pred_action.x,
+                    "y": pred_action.y,
+                    "text": pred_action.text,
+                    "element_index": pred_action.element_index,
+                },
+                "ground_truth_action": {
+                    "type": gt_action.type,
+                    "x": gt_action.x,
+                    "y": gt_action.y,
+                    "text": gt_action.text,
+                    "element_index": gt_action.element_index,
+                },
+                "correct_type": pred_action.type == gt_action.type,
+                "coord_error_norm": coord_error,
+                "element_match": pred_action.element_index == gt_action.element_index
+                if gt_action.element_index is not None
+                else None,
+            }
+
+            log_fn(record)
+            logged_count += 1
+
+        step_total += 1
+
+    metrics = EpisodeMetrics(
+        episode_id=episode.id,
+        step_matches=step_matches,
+        step_total=step_total,
+        coord_errors=coord_errors,
+        success_pred=success_pred,
+        success_gt=episode.success,
+        click_hits=click_hits,
+        click_total=click_total,
+        milestones_achieved=milestones_achieved,
+        full_step_correct=full_step_correct,
+        state_success=last_state_success,
+        bbox_hits=bbox_hits,
+        bbox_total=bbox_total,
+        element_hits=element_hits,
+        element_total=element_total,
+    )
+    return metrics, sample_idx, logged_count
+
+
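The `log_fn` hook receives one flat dict per evaluated step (prompts, raw model output, predicted vs. ground-truth action, per-step errors), so any callable sink works. A minimal JSONL logger sketch; the file path is an arbitrary example:

```python
import json
from typing import Any, Dict

def make_jsonl_logger(path: str):
    """Return a log_fn that appends each step record as one JSON line."""
    fh = open(path, "a", encoding="utf-8")

    def log_fn(record: Dict[str, Any]) -> None:
        fh.write(json.dumps(record) + "\n")
        fh.flush()  # keep the file inspectable while a long eval runs

    return log_fn
```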
+def aggregate_metrics(episodes_metrics: List[EpisodeMetrics]) -> AggregateMetrics:
+    """Aggregate per-episode metrics into global metrics.
+
+    Three-tier episodic success metrics (from least to most strict):
+
+    1. **weak_episode_success_rate**: Semantic goal completion. For the login
+       flow, requires: typed username, typed password, clicked login button
+       (within 10% coord error), and emitted DONE. Allows intermediate mistakes.
+
+    2. **mean_episode_step_score**: Strict partial credit. Average of
+       (full_step_correct / step_total) per episode. A step is "full correct"
+       if action type matches AND (not a click OR click within 5% threshold).
+
+    3. **episode_success_rate**: Hard metric. All steps must match exactly
+       (action type correct AND click hits where applicable). This is the
+       long-horizon metric that only becomes meaningful at high step accuracy.
+
+    Also computes:
+    - action_type_accuracy: total correct types / total steps.
+    - mean_coord_error: mean of all collected coordinate errors.
+    - mean_episode_progress: avg(step_matches / step_total) - type matches only.
+    """
+
+    num_episodes = len(episodes_metrics)
+    num_steps = sum(m.step_total for m in episodes_metrics)
+
+    total_matches = sum(m.step_matches for m in episodes_metrics)
+    action_type_accuracy = (total_matches / num_steps) if num_steps > 0 else 0.0
+
+    all_coord_errors: List[float] = []
+    for m in episodes_metrics:
+        all_coord_errors.extend(m.coord_errors)
+
+    mean_coord_error: Optional[float]
+    if all_coord_errors:
+        mean_coord_error = sum(all_coord_errors) / len(all_coord_errors)
+    else:
+        mean_coord_error = None
+
+    eval_episodes = [m for m in episodes_metrics if m.step_total > 0]
+    if eval_episodes:
+        success_count = sum(1 for m in eval_episodes if m.success_pred)
+        episode_success_rate = success_count / len(eval_episodes)
+    else:
+        episode_success_rate = None
+
+    total_click_hits = sum(m.click_hits for m in episodes_metrics)
+    total_click_total = sum(m.click_total for m in episodes_metrics)
+    if total_click_total > 0:
+        click_hit_rate: Optional[float] = total_click_hits / total_click_total
+    else:
+        click_hit_rate = None
+
+    # Partial credit: average episode progress (step_matches / step_total per episode)
+    if eval_episodes:
+        episode_progress_scores = [
+            m.step_matches / m.step_total for m in eval_episodes
+        ]
+        mean_episode_progress = sum(episode_progress_scores) / len(episode_progress_scores)
+    else:
+        mean_episode_progress = None
+
+    # Strict partial: avg(full_step_correct / step_total) - requires type match + click hit
+    if eval_episodes:
+        step_scores = [
+            m.full_step_correct / m.step_total for m in eval_episodes
+        ]
+        mean_episode_step_score = sum(step_scores) / len(step_scores)
+    else:
+        mean_episode_step_score = None
+
+    # Weak episode success: all milestones achieved
+    if eval_episodes:
+        weak_success_count = sum(
+            1 for m in eval_episodes
+            if m.milestones_achieved and all(m.milestones_achieved.values())
+        )
+        weak_episode_success_rate = weak_success_count / len(eval_episodes)
+    else:
+        weak_episode_success_rate = None
+
+    # State-based success: from model's State: {"success": true}
+    episodes_with_state = [m for m in eval_episodes if m.state_success is not None]
+    if episodes_with_state:
+        state_success_count = sum(1 for m in episodes_with_state if m.state_success)
+        state_success_rate = state_success_count / len(episodes_with_state)
+    else:
+        state_success_rate = None
+
+    # Bbox-based click evaluation (more lenient than point-based)
+    total_bbox_hits = sum(m.bbox_hits for m in episodes_metrics)
+    total_bbox_total = sum(m.bbox_total for m in episodes_metrics)
+    if total_bbox_total > 0:
+        bbox_hit_rate: Optional[float] = total_bbox_hits / total_bbox_total
+    else:
+        bbox_hit_rate = None
+
+    # SoM element index accuracy
+    total_element_hits = sum(m.element_hits for m in episodes_metrics)
+    total_element_total = sum(m.element_total for m in episodes_metrics)
+    if total_element_total > 0:
+        element_accuracy: Optional[float] = total_element_hits / total_element_total
+    else:
+        element_accuracy = None
+
+    return AggregateMetrics(
+        num_episodes=num_episodes,
+        num_steps=num_steps,
+        action_type_accuracy=action_type_accuracy,
+        mean_coord_error=mean_coord_error,
+        coord_error_count=len(all_coord_errors),
+        episode_success_rate=episode_success_rate,
+        click_hit_rate=click_hit_rate,
+        mean_episode_progress=mean_episode_progress,
+        mean_episode_step_score=mean_episode_step_score,
+        weak_episode_success_rate=weak_episode_success_rate,
+        state_success_rate=state_success_rate,
+        bbox_hit_rate=bbox_hit_rate,
+        element_accuracy=element_accuracy,
+    )
+
+
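A toy aggregation (all numbers invented) makes the metric tiers concrete: episode "a" is fully correct, while episode "b" matches half the action types and fully completes a third of its steps:

```python
a = EpisodeMetrics(episode_id="a", step_matches=6, step_total=6, coord_errors=[0.01],
                   success_pred=True, success_gt=None, click_hits=1, click_total=1,
                   milestones_achieved={"emitted_done": True}, full_step_correct=6)
b = EpisodeMetrics(episode_id="b", step_matches=3, step_total=6, coord_errors=[0.20],
                   success_pred=False, success_gt=None, click_hits=0, click_total=1,
                   milestones_achieved={"emitted_done": False}, full_step_correct=2)

agg = aggregate_metrics([a, b])
# agg.action_type_accuracy      == (6 + 3) / 12     == 0.75
# agg.episode_success_rate      == 1 / 2            (only "a" is strictly correct)
# agg.mean_episode_progress     == (6/6 + 3/6) / 2  == 0.75
# agg.mean_episode_step_score   == (6/6 + 2/6) / 2  ~= 0.667
# agg.weak_episode_success_rate == 1 / 2            (all of "a"'s milestones hit)
```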
+def evaluate_policy_on_episodes(
+    policy: AgentPolicy,
+    episodes: List[Episode],
+    samples: List[Dict[str, Any]],
+    log_fn: Optional[Callable[[Dict[str, Any]], None]] = None,
+    log_limit: Optional[int] = None,
+    use_som: bool = False,
+) -> AggregateMetrics:
+    """Evaluate a policy on a list of episodes given corresponding SFT samples.
+
+    The `samples` list must have been produced from `episodes` using
+    `build_next_action_sft_samples`, so that iterating episodes/steps in order
+    aligns with iterating over `samples`.
+
+    Args:
+        use_som: If True, evaluate using Set-of-Marks element index matching
+            instead of coordinate-based evaluation.
+    """
+
+    episodes_metrics: List[EpisodeMetrics] = []
+    sample_idx = 0
+    logged_count = 0
+
+    for episode in episodes:
+        metrics, sample_idx, logged_count = evaluate_episode(
+            policy,
+            episode,
+            samples,
+            sample_idx,
+            log_fn=log_fn,
+            log_limit=log_limit,
+            logged_count=logged_count,
+            use_som=use_som,
+        )
+        episodes_metrics.append(metrics)
+
+    return aggregate_metrics(episodes_metrics)
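End-to-end, evaluation is two calls: build the SFT samples from the episodes, then hand both lists to the evaluator. A hedged sketch; the import path and single-argument signature of `build_next_action_sft_samples` are inferred from the package layout (`openadapt_ml/datasets/next_action.py`) and may differ, and `policy`/`episodes` are assumed to already exist:

```python
from openadapt_ml.datasets.next_action import build_next_action_sft_samples  # assumed location
from openadapt_ml.evals.trajectory_matching import evaluate_policy_on_episodes

samples = build_next_action_sft_samples(episodes)  # must preserve episode/step order
metrics = evaluate_policy_on_episodes(policy, episodes, samples,
                                      log_fn=make_jsonl_logger("eval_steps.jsonl"),  # sketched above
                                      log_limit=50)
print(metrics.action_type_accuracy, metrics.episode_success_rate)
```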
openadapt_ml/grounding/__init__.py
@@ -0,0 +1,45 @@
+"""Grounding modules for visual element localization.
+
+This package provides strategies for grounding natural language target descriptions
+to specific regions on screenshots. The grounding/policy separation enables:
+
+- Training policy and grounding separately or jointly
+- Swapping grounding strategies without retraining policy
+- Evaluating each layer independently
+- Composing different grounding modules per platform
+
+Available grounding strategies:
+
+- OracleGrounder: Uses ground-truth bboxes (for evaluation)
+- GeminiGrounder: Uses Google Gemini vision API for element detection
+- DetectorGrounder: Generic detector wrapper with backend selection
+- SoMGrounder: Element index selection from Set-of-Marks overlay (coming soon)
+- AttentionGrounder: GUI-Actor style attention-based selection (coming soon)
+
+Functions:
+
+- extract_ui_elements: Extract all interactive UI elements from a screenshot
+- overlay_element_marks: Overlay numbered labels (Set-of-Marks) on elements
+"""
+
+from openadapt_ml.grounding.base import (
+    GroundingModule,
+    OracleGrounder,
+    RegionCandidate,
+)
+from openadapt_ml.grounding.detector import (
+    DetectorGrounder,
+    GeminiGrounder,
+    extract_ui_elements,
+    overlay_element_marks,
+)
+
+__all__ = [
+    "GroundingModule",
+    "OracleGrounder",
+    "RegionCandidate",
+    "DetectorGrounder",
+    "GeminiGrounder",
+    "extract_ui_elements",
+    "overlay_element_marks",
+]
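The package docstring promises swappable grounding backends behind a common `GroundingModule` interface. A hedged factory sketch of that swap pattern; constructor arguments are omitted because the actual signatures live in `grounding/base.py` and `grounding/detector.py`, which this diff section does not show:

```python
from openadapt_ml.grounding import GeminiGrounder, GroundingModule, OracleGrounder

def pick_grounder(name: str) -> GroundingModule:
    """Illustrative only: choose a grounding backend without touching the policy."""
    if name == "oracle":
        return OracleGrounder()  # assumed no-arg construction
    if name == "gemini":
        return GeminiGrounder()  # assumed no-arg construction; likely needs API config
    raise ValueError(f"unknown grounder: {name!r}")
```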