openadapt_ml-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,241 @@
+"""Grounding-specific evaluation metrics.
+
+This module provides metrics for evaluating grounding accuracy independent
+of policy performance, as described in the architecture document.
+
+Metrics:
+- bbox_iou: Intersection over Union with ground-truth element bbox
+- centroid_hit_rate: Whether click point lands inside correct element
+- oracle_hit_rate@k: Any of top-k candidates correct
+- grounding_latency: Time per grounding call
+"""
+
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from PIL import Image
+
+    from openadapt_ml.grounding.base import GroundingModule, RegionCandidate
+
+
+@dataclass
+class GroundingResult:
+    """Result of a single grounding evaluation."""
+
+    target_description: str
+    ground_truth_bbox: tuple[float, float, float, float] | None
+    predicted_candidates: list["RegionCandidate"]
+    latency_ms: float
+
+    # Computed metrics
+    best_iou: float = 0.0
+    centroid_hit: bool = False
+    oracle_hit_at_k: dict[int, bool] = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        """Compute metrics from predictions and ground truth."""
+        if not self.ground_truth_bbox or not self.predicted_candidates:
+            return
+
+        gt_x1, gt_y1, gt_x2, gt_y2 = self.ground_truth_bbox
+
+        for k, candidate in enumerate(self.predicted_candidates, start=1):
+            # IoU
+            iou = self._compute_iou(candidate.bbox, self.ground_truth_bbox)
+            if iou > self.best_iou:
+                self.best_iou = iou
+
+            # Centroid hit
+            cx, cy = candidate.centroid
+            if gt_x1 <= cx <= gt_x2 and gt_y1 <= cy <= gt_y2:
+                if not self.centroid_hit:
+                    self.centroid_hit = True
+
+            # Oracle hit at k (if any candidate up to k is a hit)
+            hit = iou > 0.5 or (gt_x1 <= cx <= gt_x2 and gt_y1 <= cy <= gt_y2)
+            if hit:
+                # Mark all k >= current k as hits
+                for check_k in range(k, max(len(self.predicted_candidates) + 1, 6)):
+                    self.oracle_hit_at_k[check_k] = True
+
+    def _compute_iou(
+        self,
+        bbox1: tuple[float, float, float, float],
+        bbox2: tuple[float, float, float, float],
+    ) -> float:
+        """Compute IoU between two bboxes."""
+        x1, y1, x2, y2 = bbox1
+        ox1, oy1, ox2, oy2 = bbox2
+
+        # Intersection
+        ix1 = max(x1, ox1)
+        iy1 = max(y1, oy1)
+        ix2 = min(x2, ox2)
+        iy2 = min(y2, oy2)
+
+        if ix1 >= ix2 or iy1 >= iy2:
+            return 0.0
+
+        intersection = (ix2 - ix1) * (iy2 - iy1)
+        area1 = (x2 - x1) * (y2 - y1)
+        area2 = (ox2 - ox1) * (oy2 - oy1)
+        union = area1 + area2 - intersection
+
+        return intersection / union if union > 0 else 0.0
+
+
+@dataclass
+class GroundingMetrics:
+    """Aggregated grounding metrics across multiple evaluations."""
+
+    results: list[GroundingResult] = field(default_factory=list)
+
+    @property
+    def count(self) -> int:
+        """Number of evaluated samples."""
+        return len(self.results)
+
+    @property
+    def mean_iou(self) -> float:
+        """Mean IoU across all samples."""
+        if not self.results:
+            return 0.0
+        return sum(r.best_iou for r in self.results) / len(self.results)
+
+    @property
+    def centroid_hit_rate(self) -> float:
+        """Fraction of samples where centroid hit ground truth."""
+        if not self.results:
+            return 0.0
+        return sum(1 for r in self.results if r.centroid_hit) / len(self.results)
+
+    def oracle_hit_rate(self, k: int = 1) -> float:
+        """Fraction of samples where any of top-k candidates hit.
+
+        Args:
+            k: Number of candidates to consider.
+
+        Returns:
+            Hit rate in [0, 1].
+        """
+        if not self.results:
+            return 0.0
+        hits = sum(1 for r in self.results if r.oracle_hit_at_k.get(k, False))
+        return hits / len(self.results)
+
+    @property
+    def mean_latency_ms(self) -> float:
+        """Mean grounding latency in milliseconds."""
+        if not self.results:
+            return 0.0
+        return sum(r.latency_ms for r in self.results) / len(self.results)
+
+    def summary(self) -> dict:
+        """Return summary dict of all metrics."""
+        return {
+            "count": self.count,
+            "mean_iou": self.mean_iou,
+            "centroid_hit_rate": self.centroid_hit_rate,
+            "oracle_hit_rate@1": self.oracle_hit_rate(1),
+            "oracle_hit_rate@3": self.oracle_hit_rate(3),
+            "oracle_hit_rate@5": self.oracle_hit_rate(5),
+            "mean_latency_ms": self.mean_latency_ms,
+        }
+
+    def __str__(self) -> str:
+        """Pretty-print metrics summary."""
+        s = self.summary()
+        return (
+            f"Grounding Metrics (n={s['count']}):\n"
+            f"  Mean IoU: {s['mean_iou']:.3f}\n"
+            f"  Centroid Hit Rate: {s['centroid_hit_rate']:.3f}\n"
+            f"  Oracle Hit @1: {s['oracle_hit_rate@1']:.3f}\n"
+            f"  Oracle Hit @3: {s['oracle_hit_rate@3']:.3f}\n"
+            f"  Oracle Hit @5: {s['oracle_hit_rate@5']:.3f}\n"
+            f"  Mean Latency: {s['mean_latency_ms']:.1f}ms"
+        )
+
+
+def evaluate_grounder(
+    grounder: "GroundingModule",
+    test_cases: list[tuple["Image", str, tuple[float, float, float, float]]],
+    k: int = 5,
+) -> GroundingMetrics:
+    """Evaluate a grounding module on test cases.
+
+    Args:
+        grounder: GroundingModule to evaluate.
+        test_cases: List of (image, target_description, ground_truth_bbox) tuples.
+        k: Number of candidates to request from grounder.
+
+    Returns:
+        GroundingMetrics with aggregated results.
+    """
+    metrics = GroundingMetrics()
+
+    for image, target_desc, gt_bbox in test_cases:
+        start = time.perf_counter()
+        candidates = grounder.ground(image, target_desc, k=k)
+        latency_ms = (time.perf_counter() - start) * 1000
+
+        result = GroundingResult(
+            target_description=target_desc,
+            ground_truth_bbox=gt_bbox,
+            predicted_candidates=candidates,
+            latency_ms=latency_ms,
+        )
+        metrics.results.append(result)
+
+    return metrics
+
+
+def evaluate_grounder_on_episode(
+    grounder: "GroundingModule",
+    episode: "Episode",
+    k: int = 5,
+) -> GroundingMetrics:
+    """Evaluate a grounding module on an Episode's click actions.
+
+    Only evaluates steps with click actions that have ground-truth bboxes.
+
+    Args:
+        grounder: GroundingModule to evaluate.
+        episode: Episode with Steps containing Actions with bboxes.
+        k: Number of candidates to request.
+
+    Returns:
+        GroundingMetrics for click actions with bboxes.
+    """
+    from PIL import Image
+
+    from openadapt_ml.schemas.sessions import Episode
+
+    test_cases = []
+
+    for step in episode.steps:
+        action = step.action
+
+        # Only evaluate clicks with bboxes
+        if action.type not in ("click", "double_click"):
+            continue
+        if action.bbox is None:
+            continue
+        if step.observation.image_path is None:
+            continue
+
+        # Load image
+        try:
+            image = Image.open(step.observation.image_path)
+        except Exception:
+            continue
+
+        # Create target description from thought or action
+        target_desc = step.thought or f"element at ({action.x:.2f}, {action.y:.2f})"
+
+        test_cases.append((image, target_desc, action.bbox))
+
+    return evaluate_grounder(grounder, test_cases, k=k)
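
The hunk above adds openadapt_ml/evals/grounding.py (file 18 in the list). For orientation, here is a minimal usage sketch of evaluate_grounder. Everything named _Candidate, DummyGrounder, and the test-case values are hypothetical, written only to satisfy the interface the metrics code reads (a ground(image, target, k) call returning candidates with bbox and centroid attributes); the package's real grounders live under openadapt_ml/grounding/.

from dataclasses import dataclass

from openadapt_ml.evals.grounding import evaluate_grounder


@dataclass
class _Candidate:
    # Provides the two attributes the metrics code reads: bbox and centroid.
    bbox: tuple[float, float, float, float]

    @property
    def centroid(self) -> tuple[float, float]:
        x1, y1, x2, y2 = self.bbox
        return (x1 + x2) / 2, (y1 + y2) / 2


class DummyGrounder:
    # Hypothetical stand-in: always proposes the same region.
    def ground(self, image, target, k=5):
        return [_Candidate((0.10, 0.20, 0.30, 0.40))][:k]


# test_cases are (image, target_description, ground_truth_bbox) tuples;
# a real run would pass PIL images, e.g. Image.open("screenshot.png").
test_cases = [
    (None, "the Submit button", (0.12, 0.22, 0.28, 0.38)),
]
metrics = evaluate_grounder(DummyGrounder(), test_cases, k=5)
print(metrics)            # pretty summary via __str__
print(metrics.summary())  # dict with mean_iou, hit rates, latency

Because the dummy candidate overlaps the ground-truth box and its centroid falls inside it, this sketch reports a nonzero mean IoU, a centroid hit, and oracle hits at k=1 through 5.
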
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+from typing import Any, Dict, List
+
+import json
+
+import matplotlib.pyplot as plt
+from matplotlib.patches import Patch
+
+
+METRIC_KEYS = [
+    ("action_type_accuracy", "Action Type Accuracy"),
+    ("mean_coord_error", "Mean Coord Error"),
+    ("click_hit_rate", "Click Hit Rate"),
+    ("episode_success_rate", "Strict Episode Success"),
+    ("mean_episode_progress", "Episode Progress"),
+    ("mean_episode_step_score", "Step Score (Type+Click)"),
+    ("weak_episode_success_rate", "Weak Episode Success"),
+]
+
+
+def _load_metrics(path: Path) -> Dict[str, Any]:
+    with path.open("r", encoding="utf-8") as f:
+        payload = json.load(f)
+    return payload.get("metrics", payload)
+
+
+def _get_bar_style(label: str) -> tuple[str, str]:
+    """Determine bar color and hatch pattern based on model label.
+
+    Returns:
+        (color, hatch): color string and hatch pattern
+    """
+    label_lower = label.lower()
+
+    # Determine color based on model type
+    if "claude" in label_lower:
+        color = "#FF6B35"  # Orange for Claude
+    elif "gpt" in label_lower or "openai" in label_lower:
+        color = "#C1121F"  # Red for GPT
+    elif "2b" in label_lower:
+        color = "#4A90E2"  # Light blue for 2B
+    elif "8b" in label_lower:
+        color = "#2E5C8A"  # Dark blue for 8B
+    else:
+        color = "#6C757D"  # Gray for unknown
+
+    # Determine hatch pattern for fine-tuned models
+    if "ft" in label_lower or "fine" in label_lower or "finetuned" in label_lower:
+        hatch = "///"  # Diagonal lines for fine-tuned
+    else:
+        hatch = ""  # Solid for base/API models
+
+    return color, hatch
+
+
+def plot_eval_metrics(
+    metric_files: List[Path],
+    labels: List[str],
+    output_path: Path,
+) -> None:
+    if len(metric_files) != len(labels):
+        raise ValueError("Number of labels must match number of metric files")
+
+    metrics_list = [_load_metrics(p) for p in metric_files]
+
+    num_models = len(metrics_list)
+    num_metrics = len(METRIC_KEYS)
+
+    fig, axes = plt.subplots(1, num_metrics, figsize=(4 * num_metrics, 5))
+    fig.suptitle(
+        "VLM Model Comparison (Offline fine-tuned vs API models)",
+        fontsize=12,
+        fontweight='bold',
+    )
+    if num_metrics == 1:
+        axes = [axes]
+
+    for idx, (key, title) in enumerate(METRIC_KEYS):
+        ax = axes[idx]
+        values: List[float] = []
+        colors: List[str] = []
+        hatches: List[str] = []
+
+        for m, label in zip(metrics_list, labels):
+            v = m.get(key)
+            if v is None:
+                values.append(0.0)
+            else:
+                values.append(float(v))
+
+            color, hatch = _get_bar_style(label)
+            colors.append(color)
+            hatches.append(hatch)
+
+        x = range(num_models)
+        bars = ax.bar(x, values, tick_label=labels, color=colors, edgecolor='black', linewidth=1.2)
+
+        # Apply hatch patterns
+        for bar, hatch in zip(bars, hatches):
+            bar.set_hatch(hatch)
+
+        ax.set_title(title, fontsize=11, fontweight='bold')
+        ax.set_ylabel(key, fontsize=9)
+        ax.set_ylim(bottom=0.0)
+        # Rotate x-axis labels to prevent crowding
+        ax.tick_params(axis='x', labelrotation=45, labelsize=8)
+        # Align labels to the right for better readability when rotated
+        for tick in ax.get_xticklabels():
+            tick.set_horizontalalignment('right')
+
+    fig.tight_layout()
+
+    # Add legend explaining color coding and hatch patterns
+    legend_elements = [
+        Patch(facecolor='#4A90E2', edgecolor='black', label='Qwen3-VL-2B'),
+        Patch(facecolor='#2E5C8A', edgecolor='black', label='Qwen3-VL-8B'),
+        Patch(facecolor='#FF6B35', edgecolor='black', label='Claude (API)'),
+        Patch(facecolor='#C1121F', edgecolor='black', label='GPT (API)'),
+        Patch(facecolor='gray', edgecolor='black', hatch='///', label='Fine-tuned'),
+        Patch(facecolor='gray', edgecolor='black', label='Base/Pretrained'),
+    ]
+
+    fig.legend(
+        handles=legend_elements,
+        loc='lower center',
+        bbox_to_anchor=(0.5, -0.05),
+        ncol=3,
+        fontsize=9,
+        frameon=True,
+    )
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    fig.savefig(output_path, dpi=150, bbox_inches='tight')
+    plt.close(fig)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Plot evaluation metrics (base vs fine-tuned or cross-model).",
+    )
+    parser.add_argument(
+        "--files",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Paths to one or more JSON metric files produced by eval_policy.py.",
+    )
+    parser.add_argument(
+        "--labels",
+        type=str,
+        nargs="+",
+        required=True,
+        help="Labels for each metrics file (e.g. base ft).",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        required=True,
+        help="Output PNG path for the plot.",
+    )
+    args = parser.parse_args()
+
+    files = [Path(p) for p in args.files]
+    labels = list(args.labels)
+    output_path = Path(args.output)
+
+    plot_eval_metrics(files, labels, output_path)
+
+
+if __name__ == "__main__":
+    main()
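
The hunk above adds openadapt_ml/evals/plot_eval_metrics.py (file 19 in the list). A minimal usage sketch follows; the JSON file names, labels, and output path are illustrative only (the metric files would be the per-model outputs of eval_policy.py referenced in the --files help text), while the keyword arguments and CLI flags match the code above.

from pathlib import Path

from openadapt_ml.evals.plot_eval_metrics import plot_eval_metrics

# Hypothetical file names and labels; labels drive the bar color/hatch
# heuristics in _get_bar_style ("2b" -> light blue, "ft" -> hatched, etc.).
plot_eval_metrics(
    metric_files=[Path("metrics_base.json"), Path("metrics_ft.json")],
    labels=["Qwen3-VL-2B", "Qwen3-VL-2B-ft"],
    output_path=Path("plots/eval_comparison.png"),
)

# Roughly equivalent CLI invocation (the module has a __main__ entry point):
#   python -m openadapt_ml.evals.plot_eval_metrics \
#       --files metrics_base.json metrics_ft.json \
#       --labels Qwen3-VL-2B Qwen3-VL-2B-ft \
#       --output plots/eval_comparison.png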