openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""Grounding-specific evaluation metrics.
|
|
2
|
+
|
|
3
|
+
This module provides metrics for evaluating grounding accuracy independent
|
|
4
|
+
of policy performance, as described in the architecture document.
|
|
5
|
+
|
|
6
|
+
Metrics:
|
|
7
|
+
- bbox_iou: Intersection over Union with ground-truth element bbox
|
|
8
|
+
- centroid_hit_rate: Whether click point lands inside correct element
|
|
9
|
+
- oracle_hit_rate@k: Any of top-k candidates correct
|
|
10
|
+
- grounding_latency: Time per grounding call
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import time
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from typing import TYPE_CHECKING
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from PIL import Image
|
|
21
|
+
|
|
22
|
+
from openadapt_ml.grounding.base import GroundingModule, RegionCandidate
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class GroundingResult:
    """Result of a single grounding evaluation.

    All metric fields are computed once in ``__post_init__`` from
    ``predicted_candidates`` and ``ground_truth_bbox``. When either of those is
    missing/empty the metric fields keep their defaults (``0.0``, ``False``,
    ``{}``).

    Attributes:
        target_description: Text description that was grounded.
        ground_truth_bbox: ``(x1, y1, x2, y2)`` of the correct element, or None.
        predicted_candidates: Candidates returned by the grounder, in the
            order it returned them (presumably best-first — confirm against the
            grounder). Each must expose ``bbox`` ``(x1, y1, x2, y2)`` and
            ``centroid`` ``(x, y)``.
        latency_ms: Duration of the grounding call in milliseconds.
        best_iou: Highest IoU of any candidate with the ground-truth bbox.
        centroid_hit: True if ANY candidate's centroid lies inside the
            ground-truth bbox (edges inclusive). NOTE(review): this is not
            restricted to the top-1 candidate — confirm that is intended.
        oracle_hit_at_k: Maps rank ``k`` -> True for every ``k`` from the rank
            of the first hitting candidate up to ``max(len(candidates), 5)``.
            Ranks before the first hit are absent.
    """

    target_description: str
    ground_truth_bbox: tuple[float, float, float, float] | None
    predicted_candidates: list["RegionCandidate"]
    latency_ms: float

    # Computed metrics
    best_iou: float = 0.0
    centroid_hit: bool = False
    oracle_hit_at_k: dict[int, bool] = field(default_factory=dict)

    # A candidate "hits" when its IoU with the ground truth exceeds this value
    # or its centroid falls inside the ground-truth bbox. Plain class attribute
    # (no annotation), so it is not a dataclass field.
    IOU_HIT_THRESHOLD = 0.5

    def __post_init__(self) -> None:
        """Compute best IoU, centroid hit and oracle hits from the candidates."""
        if not self.ground_truth_bbox or not self.predicted_candidates:
            return

        gt_x1, gt_y1, gt_x2, gt_y2 = self.ground_truth_bbox
        first_hit_rank: int | None = None

        for rank, candidate in enumerate(self.predicted_candidates, start=1):
            iou = self._compute_iou(candidate.bbox, self.ground_truth_bbox)
            if iou > self.best_iou:
                self.best_iou = iou

            cx, cy = candidate.centroid
            centroid_inside = gt_x1 <= cx <= gt_x2 and gt_y1 <= cy <= gt_y2
            if centroid_inside:
                self.centroid_hit = True

            if first_hit_rank is None and (
                iou > self.IOU_HIT_THRESHOLD or centroid_inside
            ):
                first_hit_rank = rank

        if first_hit_rank is not None:
            # A hit at rank r implies a hit for every k >= r. Ranks are filled
            # up to at least 5 so the common @1/@3/@5 lookups are populated.
            last_rank = max(len(self.predicted_candidates), 5)
            for k in range(first_hit_rank, last_rank + 1):
                self.oracle_hit_at_k[k] = True

    def _compute_iou(
        self,
        bbox1: tuple[float, float, float, float],
        bbox2: tuple[float, float, float, float],
    ) -> float:
        """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

        Returns 0.0 when the boxes do not overlap with positive area, or when
        the union area is not positive.
        """
        x1, y1, x2, y2 = bbox1
        ox1, oy1, ox2, oy2 = bbox2

        # Intersection rectangle
        ix1 = max(x1, ox1)
        iy1 = max(y1, oy1)
        ix2 = min(x2, ox2)
        iy2 = min(y2, oy2)

        if ix1 >= ix2 or iy1 >= iy2:
            return 0.0

        intersection = (ix2 - ix1) * (iy2 - iy1)
        area1 = (x2 - x1) * (y2 - y1)
        area2 = (ox2 - ox1) * (oy2 - oy1)
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class GroundingMetrics:
    """Aggregated grounding metrics across multiple evaluations.

    Every mean/rate accessor returns 0.0 when ``results`` is empty.
    """

    results: list[GroundingResult] = field(default_factory=list)

    @property
    def count(self) -> int:
        """Number of evaluated samples."""
        return len(self.results)

    @property
    def mean_iou(self) -> float:
        """Mean of each sample's ``best_iou``."""
        if not self.results:
            return 0.0
        return sum(r.best_iou for r in self.results) / len(self.results)

    @property
    def centroid_hit_rate(self) -> float:
        """Fraction of samples whose ``centroid_hit`` flag is set."""
        if not self.results:
            return 0.0
        return sum(1 for r in self.results if r.centroid_hit) / len(self.results)

    @staticmethod
    def _hit_within(result: GroundingResult, k: int) -> bool:
        """True if ``result`` recorded a hit at any rank <= ``k``."""
        return any(hit and rank <= k for rank, hit in result.oracle_hit_at_k.items())

    def oracle_hit_rate(self, k: int = 1) -> float:
        """Fraction of samples where any of top-k candidates hit.

        A sample counts as a hit at ``k`` when its first hitting candidate has
        rank <= ``k``. This is derived from the ranks stored in
        ``GroundingResult.oracle_hit_at_k`` rather than a direct lookup of
        ``k``. That dict is only populated up to ``max(len(candidates), 5)``, so
        a direct lookup would wrongly report a miss for larger ``k``. Deriving
        it this way keeps the rate non-decreasing in ``k``.

        Args:
            k: Number of candidates to consider.

        Returns:
            Hit rate in [0, 1] (0.0 when there are no results or ``k < 1``).
        """
        if not self.results:
            return 0.0
        hits = sum(1 for r in self.results if self._hit_within(r, k))
        return hits / len(self.results)

    @property
    def mean_latency_ms(self) -> float:
        """Mean grounding latency in milliseconds."""
        if not self.results:
            return 0.0
        return sum(r.latency_ms for r in self.results) / len(self.results)

    def summary(self) -> dict:
        """Return summary dict of all metrics."""
        return {
            "count": self.count,
            "mean_iou": self.mean_iou,
            "centroid_hit_rate": self.centroid_hit_rate,
            "oracle_hit_rate@1": self.oracle_hit_rate(1),
            "oracle_hit_rate@3": self.oracle_hit_rate(3),
            "oracle_hit_rate@5": self.oracle_hit_rate(5),
            "mean_latency_ms": self.mean_latency_ms,
        }

    def __str__(self) -> str:
        """Pretty-print metrics summary."""
        s = self.summary()
        return (
            f"Grounding Metrics (n={s['count']}):\n"
            f"  Mean IoU:          {s['mean_iou']:.3f}\n"
            f"  Centroid Hit Rate: {s['centroid_hit_rate']:.3f}\n"
            f"  Oracle Hit @1:     {s['oracle_hit_rate@1']:.3f}\n"
            f"  Oracle Hit @3:     {s['oracle_hit_rate@3']:.3f}\n"
            f"  Oracle Hit @5:     {s['oracle_hit_rate@5']:.3f}\n"
            f"  Mean Latency:      {s['mean_latency_ms']:.1f}ms"
        )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def evaluate_grounder(
    grounder: "GroundingModule",
    test_cases: list[tuple["Image", str, tuple[float, float, float, float]]],
    k: int = 5,
) -> GroundingMetrics:
    """Run a grounding module over labelled test cases and aggregate metrics.

    Each case is grounded once with ``grounder.ground(image, description, k=k)``.
    The wall-clock duration of that single call (measured with
    ``time.perf_counter``) is recorded as the case's latency.

    Args:
        grounder: GroundingModule to evaluate.
        test_cases: ``(image, target_description, ground_truth_bbox)`` tuples.
        k: Number of candidates to request from the grounder per case.

    Returns:
        GroundingMetrics holding one GroundingResult per test case, in input
        order.
    """
    collected = GroundingMetrics()

    for screenshot, description, truth_bbox in test_cases:
        t0 = time.perf_counter()
        ranked = grounder.ground(screenshot, description, k=k)
        elapsed_ms = (time.perf_counter() - t0) * 1000

        collected.results.append(
            GroundingResult(
                target_description=description,
                ground_truth_bbox=truth_bbox,
                predicted_candidates=ranked,
                latency_ms=elapsed_ms,
            )
        )

    return collected
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _describe_click(action) -> str:
    """Build a fallback target description from a click action's coordinates.

    When the action has no ``x``/``y`` (None), the center of ``action.bbox`` is
    used instead. This avoids a TypeError from the ``:.2f`` format spec.
    """
    x, y = action.x, action.y
    if x is None or y is None:
        x1, y1, x2, y2 = action.bbox
        x, y = (x1 + x2) / 2, (y1 + y2) / 2
    return f"element at ({x:.2f}, {y:.2f})"


def evaluate_grounder_on_episode(
    grounder: "GroundingModule",
    episode: "Episode",
    k: int = 5,
) -> GroundingMetrics:
    """Evaluate a grounding module on an Episode's click actions.

    Only steps are evaluated that meet all of these conditions:
    - the action type is "click" or "double_click";
    - the action has a ground-truth bbox;
    - the observation has an image path;
    - the image can be loaded.
    Steps whose image cannot be loaded are skipped and a warning is logged.

    The target description is ``step.thought`` when truthy. Otherwise it is
    built from the click coordinates.

    Args:
        grounder: GroundingModule to evaluate.
        episode: Episode with Steps containing Actions with bboxes.
        k: Number of candidates to request.

    Returns:
        GroundingMetrics for click actions with bboxes.
    """
    import logging

    from PIL import Image

    logger = logging.getLogger(__name__)
    test_cases = []

    for step in episode.steps:
        action = step.action

        # Only evaluate clicks with bboxes
        if action.type not in ("click", "double_click"):
            continue
        if action.bbox is None:
            continue
        image_path = step.observation.image_path
        if image_path is None:
            continue

        try:
            # Image.open is lazy and keeps the file handle open until the pixel
            # data is read. copy() forces the load inside the `with`, so the
            # handle is closed here. Otherwise one descriptor would stay open
            # per collected step, and a truncated file would fail later inside
            # the grounder instead of being skipped here.
            with Image.open(image_path) as opened:
                image = opened.copy()
        except Exception as exc:  # best-effort: skip unreadable screenshots
            logger.warning("Skipping step: could not load image %s: %s", image_path, exc)
            continue

        target_desc = step.thought or _describe_click(action)
        test_cases.append((image, target_desc, action.bbox))

    return evaluate_grounder(grounder, test_cases, k=k)
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
from matplotlib.patches import Patch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Metrics to plot, in subplot order. Each entry pairs the key looked up in a
# loaded metrics dict with the human-readable subplot title.
METRIC_KEYS: list[tuple[str, str]] = [
    ("action_type_accuracy",      "Action Type Accuracy"),
    ("mean_coord_error",          "Mean Coord Error"),
    ("click_hit_rate",            "Click Hit Rate"),
    ("episode_success_rate",      "Strict Episode Success"),
    ("mean_episode_progress",     "Episode Progress"),
    ("mean_episode_step_score",   "Step Score (Type+Click)"),
    ("weak_episode_success_rate", "Weak Episode Success"),
]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _load_metrics(path: Path) -> Dict[str, Any]:
    """Read a UTF-8 JSON metrics file and return its metrics mapping.

    The payload may either wrap the metrics under a top-level ``"metrics"`` key
    or be the metrics dict itself. Both shapes are accepted.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    metrics = payload.get("metrics", payload)
    return metrics
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_bar_style(label: str) -> tuple[str, str]:
    """Determine bar color and hatch pattern based on model label.

    Color encodes the model family or size. Checks run in this order:
    - Claude -> orange
    - GPT/OpenAI -> red
    - 2B -> light blue
    - 8B -> dark blue
    - anything else -> gray

    Labels marked as fine-tuned get a diagonal hatch; all others are solid.
    A label counts as fine-tuned if it has a standalone ``ft``/``sft`` token or
    a word starting with ``fine`` (``fine-tuned``, ``finetuned``, ...).

    Size and fine-tune markers are matched as tokens rather than raw
    substrings. This keeps "32B"/"72B" from being taken as 2B, and
    "Microsoft"/"left"/"draft" from being taken as fine-tuned.

    Args:
        label: Model label as shown on the x-axis (case-insensitive).

    Returns:
        (color, hatch): hex color string and matplotlib hatch pattern
        (``""`` for solid).
    """
    import re

    label_lower = label.lower()

    def has_size(size: str) -> bool:
        # Not preceded by a digit/dot (rules out "32b", "1.2b") and not
        # followed by a letter/digit.
        return re.search(rf"(?<![\d.]){size}(?![a-z\d])", label_lower) is not None

    # Determine color based on model type
    if "claude" in label_lower:
        color = "#FF6B35"  # Orange for Claude
    elif "gpt" in label_lower or "openai" in label_lower:
        color = "#C1121F"  # Red for GPT
    elif has_size("2b"):
        color = "#4A90E2"  # Light blue for 2B
    elif has_size("8b"):
        color = "#2E5C8A"  # Dark blue for 8B
    else:
        color = "#6C757D"  # Gray for unknown

    # Fine-tuned: "ft"/"sft" as a standalone token, or a word starting "fine".
    is_finetuned = re.search(r"(?<![a-z])(?:s?ft(?![a-z])|fine)", label_lower)
    hatch = "///" if is_finetuned else ""  # Diagonal lines for fine-tuned

    return color, hatch
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def plot_eval_metrics(
    metric_files: List[Path],
    labels: List[str],
    output_path: Path,
) -> None:
    """Plot a side-by-side bar-chart comparison of models across metrics.

    The figure has one subplot per entry in ``METRIC_KEYS`` and one bar per
    metrics file. Bar color/hatch come from ``_get_bar_style(label)``. A metric
    absent from a file is drawn as a zero-height bar annotated "n/a", so it is
    not mistaken for a genuine score of 0. The figure is saved to
    ``output_path`` at 150 dpi; parent directories are created if needed.

    Args:
        metric_files: JSON metric files, one per model.
        labels: Display label for each file, in the same order.
        output_path: Destination image path.

    Raises:
        ValueError: If ``metric_files`` and ``labels`` differ in length.
    """
    if len(metric_files) != len(labels):
        raise ValueError("Number of labels must match number of metric files")

    metrics_list = [_load_metrics(p) for p in metric_files]

    num_models = len(metrics_list)
    num_metrics = len(METRIC_KEYS)

    # squeeze=False always returns a 2-D array, so even a single metric yields
    # an indexable row of axes.
    fig, axes_grid = plt.subplots(
        1, num_metrics, figsize=(4 * num_metrics, 5), squeeze=False
    )
    axes = axes_grid[0]
    fig.suptitle(
        "VLM Model Comparison (Offline fine-tuned vs API models)",
        fontsize=12,
        fontweight='bold',
    )

    # Bar style depends only on the label, so compute it once for all subplots.
    styles = [_get_bar_style(label) for label in labels]
    colors = [color for color, _ in styles]
    hatches = [hatch for _, hatch in styles]
    x = range(num_models)

    for ax, (key, title) in zip(axes, METRIC_KEYS):
        raw_values = [m.get(key) for m in metrics_list]
        values = [0.0 if v is None else float(v) for v in raw_values]

        bars = ax.bar(x, values, tick_label=labels, color=colors, edgecolor='black', linewidth=1.2)

        # Apply hatch patterns
        for bar, hatch in zip(bars, hatches):
            bar.set_hatch(hatch)

        # Label missing metrics explicitly; a bare zero-height bar would look
        # like a real score of 0.
        for xi, raw in zip(x, raw_values):
            if raw is None:
                ax.text(xi, 0.0, "n/a", ha="center", va="bottom", fontsize=8, color="dimgray")

        ax.set_title(title, fontsize=11, fontweight='bold')
        ax.set_ylabel(key, fontsize=9)
        ax.set_ylim(bottom=0.0)
        # Rotate x-axis labels to prevent crowding
        ax.tick_params(axis='x', labelrotation=45, labelsize=8)
        # Align labels to the right for better readability when rotated
        for tick in ax.get_xticklabels():
            tick.set_horizontalalignment('right')

    fig.tight_layout()

    # Add legend explaining color coding and hatch patterns
    legend_elements = [
        Patch(facecolor='#4A90E2', edgecolor='black', label='Qwen3-VL-2B'),
        Patch(facecolor='#2E5C8A', edgecolor='black', label='Qwen3-VL-8B'),
        Patch(facecolor='#FF6B35', edgecolor='black', label='Claude (API)'),
        Patch(facecolor='#C1121F', edgecolor='black', label='GPT (API)'),
        Patch(facecolor='gray', edgecolor='black', hatch='///', label='Fine-tuned'),
        Patch(facecolor='gray', edgecolor='black', label='Base/Pretrained'),
    ]

    fig.legend(
        handles=legend_elements,
        loc='lower center',
        bbox_to_anchor=(0.5, -0.05),
        ncol=3,
        fontsize=9,
        frameon=True,
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def main() -> None:
    """Command-line entry point: parse arguments and write the comparison plot."""
    parser = argparse.ArgumentParser(
        description="Plot evaluation metrics (base vs fine-tuned or cross-model).",
    )
    parser.add_argument(
        "--files",
        type=str,
        nargs="+",
        required=True,
        help="Paths to one or more JSON metric files produced by eval_policy.py.",
    )
    parser.add_argument(
        "--labels",
        type=str,
        nargs="+",
        required=True,
        help="Labels for each metrics file (e.g. base ft).",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output PNG path for the plot.",
    )
    args = parser.parse_args()

    # Check here so a count mismatch gives a normal usage error (exit status 2)
    # rather than a ValueError traceback from plot_eval_metrics.
    if len(args.files) != len(args.labels):
        parser.error(
            f"--labels needs one label per --files entry "
            f"(got {len(args.labels)} labels for {len(args.files)} files)"
        )

    files = [Path(p) for p in args.files]
    labels = list(args.labels)
    output_path = Path(args.output)

    plot_eval_metrics(files, labels, output_path)


if __name__ == "__main__":
    main()
|