openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/training/trainer.py
CHANGED
|
@@ -4,15 +4,9 @@ import json
|
|
|
4
4
|
import time
|
|
5
5
|
from dataclasses import dataclass, field
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any,
|
|
7
|
+
from typing import Any, Dict, List
|
|
8
8
|
|
|
9
|
-
import
|
|
10
|
-
from torch.optim import Optimizer
|
|
11
|
-
from torch.optim.lr_scheduler import LambdaLR
|
|
12
|
-
from torch.utils.data import DataLoader, Dataset
|
|
13
|
-
|
|
14
|
-
from openadapt_ml.models.base_adapter import BaseVLMAdapter
|
|
15
|
-
from openadapt_ml.schemas.sessions import Episode
|
|
9
|
+
from openadapt_ml.schema import ActionType
|
|
16
10
|
from openadapt_ml.training.shared_ui import (
|
|
17
11
|
get_shared_header_css as _get_shared_header_css,
|
|
18
12
|
generate_shared_header_html as _generate_shared_header_html,
|
|
@@ -21,6 +15,10 @@ from openadapt_ml.training.shared_ui import (
|
|
|
21
15
|
from openadapt_ml.training.viewer import (
|
|
22
16
|
generate_unified_viewer_from_output_dir,
|
|
23
17
|
)
|
|
18
|
+
from openadapt_ml.training.benchmark_viewer import (
|
|
19
|
+
_get_azure_jobs_panel_css,
|
|
20
|
+
_get_azure_jobs_panel_html,
|
|
21
|
+
)
|
|
24
22
|
|
|
25
23
|
|
|
26
24
|
def setup_job_directory(base_dir: str | Path, job_id: str) -> Path:
|
|
@@ -110,12 +108,18 @@ class TrainingConfig:
|
|
|
110
108
|
@dataclass
|
|
111
109
|
class TrainingState:
|
|
112
110
|
"""Tracks training progress for visualization."""
|
|
111
|
+
|
|
113
112
|
# Job identification
|
|
114
113
|
job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
|
|
115
|
-
hostname: str = field(default_factory=lambda: __import__(
|
|
114
|
+
hostname: str = field(default_factory=lambda: __import__("socket").gethostname())
|
|
116
115
|
capture_path: str = ""
|
|
117
116
|
config_path: str = ""
|
|
118
117
|
goal: str = "" # Task goal/description for the training run
|
|
118
|
+
# Model configuration
|
|
119
|
+
model_name: str = "" # e.g. "Qwen/Qwen3-VL-2B-Instruct"
|
|
120
|
+
lora_r: int = 0 # LoRA rank
|
|
121
|
+
lora_alpha: int = 0 # LoRA alpha
|
|
122
|
+
load_in_4bit: bool = False # Quantization
|
|
119
123
|
# Training progress
|
|
120
124
|
epoch: int = 0
|
|
121
125
|
step: int = 0
|
|
@@ -139,7 +143,9 @@ class TrainingState:
|
|
|
139
143
|
setup_status: str = "" # e.g. "booting", "installing", "training", "complete"
|
|
140
144
|
setup_logs: List[str] = field(default_factory=list) # Setup progress messages
|
|
141
145
|
# Termination tracking
|
|
142
|
-
termination_status: str =
|
|
146
|
+
termination_status: str = (
|
|
147
|
+
"" # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
|
|
148
|
+
)
|
|
143
149
|
termination_message: str = "" # Human-readable termination reason
|
|
144
150
|
|
|
145
151
|
def log_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
|
|
@@ -148,33 +154,46 @@ class TrainingState:
|
|
|
148
154
|
self.step = step
|
|
149
155
|
self.loss = loss
|
|
150
156
|
self.learning_rate = lr
|
|
151
|
-
self.losses.append(
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
157
|
+
self.losses.append(
|
|
158
|
+
{
|
|
159
|
+
"epoch": epoch,
|
|
160
|
+
"step": step,
|
|
161
|
+
"loss": loss,
|
|
162
|
+
"lr": lr,
|
|
163
|
+
"time": time.time() - self.start_time,
|
|
164
|
+
}
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
def log_evaluation(
|
|
168
|
+
self,
|
|
169
|
+
epoch: int,
|
|
170
|
+
sample_idx: int,
|
|
171
|
+
image_path: str,
|
|
172
|
+
human_action: Dict,
|
|
173
|
+
predicted_action: Dict,
|
|
174
|
+
) -> None:
|
|
161
175
|
"""Log an evaluation sample."""
|
|
162
176
|
# Calculate distance for click actions
|
|
163
177
|
distance = 0.0
|
|
164
|
-
if
|
|
178
|
+
if (
|
|
179
|
+
human_action.get("type") == "click"
|
|
180
|
+
and predicted_action.get("type") == "click"
|
|
181
|
+
):
|
|
165
182
|
hx, hy = human_action.get("x", 0), human_action.get("y", 0)
|
|
166
183
|
px, py = predicted_action.get("x", 0), predicted_action.get("y", 0)
|
|
167
184
|
distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
|
|
168
185
|
|
|
169
|
-
self.evaluations.append(
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
186
|
+
self.evaluations.append(
|
|
187
|
+
{
|
|
188
|
+
"epoch": epoch,
|
|
189
|
+
"sample_idx": sample_idx,
|
|
190
|
+
"image_path": image_path,
|
|
191
|
+
"human_action": human_action,
|
|
192
|
+
"predicted_action": predicted_action,
|
|
193
|
+
"distance": distance,
|
|
194
|
+
"correct": distance < 50, # Within 50 pixels is "correct"
|
|
195
|
+
}
|
|
196
|
+
)
|
|
178
197
|
|
|
179
198
|
def to_dict(self) -> Dict[str, Any]:
|
|
180
199
|
"""Convert state to serializable dict."""
|
|
@@ -185,9 +204,16 @@ class TrainingState:
|
|
|
185
204
|
"capture_path": self.capture_path,
|
|
186
205
|
"config_path": self.config_path,
|
|
187
206
|
"goal": self.goal,
|
|
207
|
+
# Model configuration
|
|
208
|
+
"model_name": self.model_name,
|
|
209
|
+
"lora_r": self.lora_r,
|
|
210
|
+
"lora_alpha": self.lora_alpha,
|
|
211
|
+
"load_in_4bit": self.load_in_4bit,
|
|
188
212
|
"instance_type": self.instance_type,
|
|
189
213
|
"instance_ip": self.instance_ip,
|
|
190
|
-
"started_at": time.strftime(
|
|
214
|
+
"started_at": time.strftime(
|
|
215
|
+
"%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)
|
|
216
|
+
),
|
|
191
217
|
# Cloud provider info
|
|
192
218
|
"cloud_provider": self.cloud_provider,
|
|
193
219
|
"cloud_dashboard_url": self.cloud_dashboard_url,
|
|
@@ -227,6 +253,11 @@ class TrainingLogger:
|
|
|
227
253
|
cloud_dashboard_url: str = "",
|
|
228
254
|
cloud_instance_id: str = "",
|
|
229
255
|
job_id: str = "",
|
|
256
|
+
# Model configuration
|
|
257
|
+
model_name: str = "",
|
|
258
|
+
lora_r: int = 0,
|
|
259
|
+
lora_alpha: int = 0,
|
|
260
|
+
load_in_4bit: bool = False,
|
|
230
261
|
):
|
|
231
262
|
# Generate job_id if not provided
|
|
232
263
|
if not job_id:
|
|
@@ -242,6 +273,10 @@ class TrainingLogger:
|
|
|
242
273
|
capture_path=capture_path,
|
|
243
274
|
config_path=config_path,
|
|
244
275
|
goal=goal,
|
|
276
|
+
model_name=model_name,
|
|
277
|
+
lora_r=lora_r,
|
|
278
|
+
lora_alpha=lora_alpha,
|
|
279
|
+
load_in_4bit=load_in_4bit,
|
|
245
280
|
instance_ip=instance_ip,
|
|
246
281
|
instance_type=instance_type,
|
|
247
282
|
total_epochs=config.num_train_epochs,
|
|
@@ -299,6 +334,7 @@ class TrainingLogger:
|
|
|
299
334
|
def _save_config_snapshot(self) -> None:
|
|
300
335
|
"""Save training config snapshot to JSON."""
|
|
301
336
|
from dataclasses import asdict
|
|
337
|
+
|
|
302
338
|
config_file = self.output_dir / "config.json"
|
|
303
339
|
config_dict = asdict(self.config)
|
|
304
340
|
with open(config_file, "w") as f:
|
|
@@ -316,32 +352,45 @@ class TrainingLogger:
|
|
|
316
352
|
dashboard_path.write_text(html)
|
|
317
353
|
|
|
318
354
|
|
|
319
|
-
def _generate_termination_status_html(
|
|
355
|
+
def _generate_termination_status_html(
|
|
356
|
+
state: TrainingState, is_training_complete: bool
|
|
357
|
+
) -> str:
|
|
320
358
|
"""Generate HTML for termination status section."""
|
|
321
359
|
# Check if we have termination info
|
|
322
360
|
if state.termination_status:
|
|
323
361
|
# Map termination status to colors and icons
|
|
324
362
|
status_styles = {
|
|
325
|
-
"auto_complete": {
|
|
326
|
-
|
|
363
|
+
"auto_complete": {
|
|
364
|
+
"color": "#22c55e",
|
|
365
|
+
"icon": "✓",
|
|
366
|
+
"label": "Training Complete",
|
|
367
|
+
},
|
|
368
|
+
"auto_low_loss": {
|
|
369
|
+
"color": "#22c55e",
|
|
370
|
+
"icon": "✓",
|
|
371
|
+
"label": "Auto-Stopped (Low Loss)",
|
|
372
|
+
},
|
|
327
373
|
"user_stop": {"color": "#f59e0b", "icon": "■", "label": "Stopped by User"},
|
|
328
374
|
}
|
|
329
|
-
style = status_styles.get(
|
|
375
|
+
style = status_styles.get(
|
|
376
|
+
state.termination_status,
|
|
377
|
+
{"color": "#22c55e", "icon": "✓", "label": "Complete"},
|
|
378
|
+
)
|
|
330
379
|
|
|
331
|
-
return f
|
|
332
|
-
<div style="display: flex; align-items: center; gap: 8px; color: {style[
|
|
333
|
-
<span style="font-size: 1.2rem;">{style[
|
|
334
|
-
<span style="font-weight: 600;">{style[
|
|
380
|
+
return f"""<div style="display: flex; flex-direction: column; gap: 8px;">
|
|
381
|
+
<div style="display: flex; align-items: center; gap: 8px; color: {style["color"]};">
|
|
382
|
+
<span style="font-size: 1.2rem;">{style["icon"]}</span>
|
|
383
|
+
<span style="font-weight: 600;">{style["label"]}</span>
|
|
335
384
|
</div>
|
|
336
|
-
{f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else
|
|
337
|
-
</div>
|
|
385
|
+
{f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ""}
|
|
386
|
+
</div>"""
|
|
338
387
|
elif is_training_complete:
|
|
339
|
-
return
|
|
388
|
+
return """<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
|
|
340
389
|
<span style="font-size: 1.2rem;">✓</span>
|
|
341
390
|
<span style="font-weight: 600;">Training Complete</span>
|
|
342
|
-
</div>
|
|
391
|
+
</div>"""
|
|
343
392
|
else:
|
|
344
|
-
return
|
|
393
|
+
return """<button id="stop-training-btn" onclick="stopTraining()" style="
|
|
345
394
|
background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
|
|
346
395
|
color: white;
|
|
347
396
|
border: none;
|
|
@@ -357,53 +406,65 @@ def _generate_termination_status_html(state: TrainingState, is_training_complete
|
|
|
357
406
|
">
|
|
358
407
|
<span style="font-size: 1.1rem;">■</span> Stop Training
|
|
359
408
|
</button>
|
|
360
|
-
<p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>
|
|
409
|
+
<p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>"""
|
|
361
410
|
|
|
362
411
|
|
|
363
412
|
def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
|
|
364
413
|
"""Generate an HTML dashboard for training visualization."""
|
|
365
414
|
losses_json = json.dumps(state.losses)
|
|
366
415
|
# Use stored elapsed_time if available (historical data), otherwise calculate
|
|
367
|
-
elapsed =
|
|
416
|
+
elapsed = (
|
|
417
|
+
state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
|
|
418
|
+
)
|
|
368
419
|
elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s"
|
|
369
420
|
|
|
370
421
|
# Calculate stats
|
|
371
422
|
if state.losses:
|
|
372
|
-
min_loss = min(
|
|
373
|
-
|
|
423
|
+
min_loss = min(loss["loss"] for loss in state.losses)
|
|
424
|
+
sum(loss["loss"] for loss in state.losses) / len(state.losses)
|
|
374
425
|
recent_losses = state.losses[-10:] if len(state.losses) >= 10 else state.losses
|
|
375
|
-
recent_avg = sum(
|
|
426
|
+
recent_avg = sum(loss["loss"] for loss in recent_losses) / len(recent_losses)
|
|
376
427
|
# Calculate step times
|
|
377
428
|
step_times = []
|
|
378
429
|
for i in range(1, len(state.losses)):
|
|
379
|
-
step_times.append(state.losses[i]["time"] - state.losses[i-1]["time"])
|
|
430
|
+
step_times.append(state.losses[i]["time"] - state.losses[i - 1]["time"])
|
|
380
431
|
avg_step_time = sum(step_times) / len(step_times) if step_times else 0
|
|
381
432
|
# Loss by epoch
|
|
382
433
|
epoch_losses: dict = {}
|
|
383
|
-
for
|
|
384
|
-
ep =
|
|
434
|
+
for loss in state.losses:
|
|
435
|
+
ep = loss["epoch"]
|
|
385
436
|
if ep not in epoch_losses:
|
|
386
437
|
epoch_losses[ep] = []
|
|
387
|
-
epoch_losses[ep].append(
|
|
388
|
-
epoch_avg = {
|
|
438
|
+
epoch_losses[ep].append(loss["loss"])
|
|
439
|
+
epoch_avg = {
|
|
440
|
+
ep: sum(losses) / len(losses) for ep, losses in epoch_losses.items()
|
|
441
|
+
}
|
|
389
442
|
# Estimate ETA
|
|
390
443
|
# Steps per epoch = steps in completed epochs / completed epochs
|
|
391
444
|
completed_epochs = state.epoch
|
|
392
|
-
steps_in_completed = sum(
|
|
445
|
+
steps_in_completed = sum(
|
|
446
|
+
1 for loss in state.losses if loss["epoch"] < completed_epochs
|
|
447
|
+
)
|
|
393
448
|
if completed_epochs > 0 and steps_in_completed > 0:
|
|
394
449
|
steps_per_epoch = steps_in_completed / completed_epochs
|
|
395
450
|
else:
|
|
396
451
|
# Estimate from current epoch progress
|
|
397
|
-
steps_per_epoch =
|
|
452
|
+
steps_per_epoch = (
|
|
453
|
+
len(state.losses) / (state.epoch + 1)
|
|
454
|
+
if state.epoch >= 0
|
|
455
|
+
else len(state.losses)
|
|
456
|
+
)
|
|
398
457
|
|
|
399
|
-
total_epochs =
|
|
458
|
+
total_epochs = (
|
|
459
|
+
state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
|
|
460
|
+
)
|
|
400
461
|
total_steps_estimate = steps_per_epoch * total_epochs
|
|
401
462
|
remaining_steps = max(0, total_steps_estimate - len(state.losses))
|
|
402
463
|
eta_seconds = remaining_steps * avg_step_time if avg_step_time > 0 else 0
|
|
403
464
|
# Check if training is complete (all steps done)
|
|
404
465
|
is_training_complete = remaining_steps == 0 and len(state.losses) > 0
|
|
405
466
|
else:
|
|
406
|
-
min_loss =
|
|
467
|
+
min_loss = recent_avg = avg_step_time = 0.0
|
|
407
468
|
epoch_avg = {}
|
|
408
469
|
eta_seconds = 0
|
|
409
470
|
steps_per_epoch = 0
|
|
@@ -414,10 +475,9 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
414
475
|
epoch_avg_json = json.dumps(list(epoch_avg.items()))
|
|
415
476
|
|
|
416
477
|
# Generate comparison viewer preview if capture path available
|
|
417
|
-
comparison_viewer_path = ""
|
|
418
478
|
if state.capture_path:
|
|
419
479
|
try:
|
|
420
|
-
from openadapt_ml.scripts.compare import generate_comparison_html
|
|
480
|
+
from openadapt_ml.scripts.compare import generate_comparison_html
|
|
421
481
|
from openadapt_ml.ingest.capture import capture_to_episode
|
|
422
482
|
|
|
423
483
|
capture_path = Path(state.capture_path)
|
|
@@ -428,14 +488,20 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
428
488
|
# Generate comparison data with null predictions (shows "— No prediction")
|
|
429
489
|
comparison_data = []
|
|
430
490
|
for i, step in enumerate(episode.steps):
|
|
491
|
+
# Extract normalized coordinates if available
|
|
492
|
+
action_x, action_y = None, None
|
|
493
|
+
if step.action.normalized_coordinates:
|
|
494
|
+
action_x, action_y = step.action.normalized_coordinates
|
|
431
495
|
step_data = {
|
|
432
496
|
"index": i,
|
|
433
|
-
"time": step.
|
|
434
|
-
"image_path": step.observation.
|
|
497
|
+
"time": step.step_index,
|
|
498
|
+
"image_path": step.observation.screenshot_path,
|
|
435
499
|
"human_action": {
|
|
436
|
-
"type": step.action.type
|
|
437
|
-
|
|
438
|
-
|
|
500
|
+
"type": step.action.type.value
|
|
501
|
+
if isinstance(step.action.type, ActionType)
|
|
502
|
+
else step.action.type,
|
|
503
|
+
"x": action_x,
|
|
504
|
+
"y": action_y,
|
|
439
505
|
"text": step.action.text,
|
|
440
506
|
},
|
|
441
507
|
"predicted_action": None, # Shows "— No prediction" in viewer
|
|
@@ -444,15 +510,21 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
444
510
|
comparison_data.append(step_data)
|
|
445
511
|
|
|
446
512
|
# Generate comparison HTML
|
|
447
|
-
output_dir =
|
|
513
|
+
output_dir = (
|
|
514
|
+
Path(config.output_dir)
|
|
515
|
+
if hasattr(config, "output_dir")
|
|
516
|
+
else Path("training_output")
|
|
517
|
+
)
|
|
448
518
|
output_dir.mkdir(parents=True, exist_ok=True)
|
|
449
519
|
comparison_output = output_dir / "comparison_preview.html"
|
|
450
|
-
generate_comparison_html(
|
|
451
|
-
|
|
452
|
-
|
|
520
|
+
generate_comparison_html(
|
|
521
|
+
capture_path, episode, comparison_data, comparison_output
|
|
522
|
+
)
|
|
523
|
+
str(comparison_output.name) # Relative path
|
|
524
|
+
except Exception:
|
|
453
525
|
pass # Fail silently if comparison viewer can't be generated
|
|
454
526
|
|
|
455
|
-
html = f
|
|
527
|
+
html = f"""<!DOCTYPE html>
|
|
456
528
|
<html lang="en">
|
|
457
529
|
<head>
|
|
458
530
|
<meta charset="UTF-8">
|
|
@@ -596,6 +668,42 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
596
668
|
.setup-log-line.current {{
|
|
597
669
|
color: var(--accent);
|
|
598
670
|
}}
|
|
671
|
+
.config-panel {{
|
|
672
|
+
background: var(--bg-secondary);
|
|
673
|
+
border: 1px solid var(--border-color);
|
|
674
|
+
border-radius: 12px;
|
|
675
|
+
padding: 16px 20px;
|
|
676
|
+
margin-bottom: 24px;
|
|
677
|
+
}}
|
|
678
|
+
.config-grid {{
|
|
679
|
+
display: grid;
|
|
680
|
+
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
|
681
|
+
gap: 16px;
|
|
682
|
+
}}
|
|
683
|
+
.config-item {{
|
|
684
|
+
display: flex;
|
|
685
|
+
flex-direction: column;
|
|
686
|
+
gap: 4px;
|
|
687
|
+
}}
|
|
688
|
+
.config-label {{
|
|
689
|
+
font-size: 0.7rem;
|
|
690
|
+
color: var(--text-secondary);
|
|
691
|
+
text-transform: uppercase;
|
|
692
|
+
letter-spacing: 0.5px;
|
|
693
|
+
}}
|
|
694
|
+
.config-value {{
|
|
695
|
+
font-family: "SF Mono", Monaco, monospace;
|
|
696
|
+
font-size: 0.85rem;
|
|
697
|
+
color: var(--text-primary);
|
|
698
|
+
}}
|
|
699
|
+
.config-value.model {{
|
|
700
|
+
color: var(--accent);
|
|
701
|
+
}}
|
|
702
|
+
.config-value.goal {{
|
|
703
|
+
font-family: -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
|
|
704
|
+
font-size: 0.8rem;
|
|
705
|
+
opacity: 0.9;
|
|
706
|
+
}}
|
|
599
707
|
.status {{
|
|
600
708
|
display: flex;
|
|
601
709
|
align-items: center;
|
|
@@ -754,6 +862,8 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
754
862
|
}}
|
|
755
863
|
/* Shared header styles (injected from _get_shared_header_css) */
|
|
756
864
|
{_get_shared_header_css()}
|
|
865
|
+
/* Azure ML Jobs panel styles */
|
|
866
|
+
{_get_azure_jobs_panel_css()}
|
|
757
867
|
.eval-panel {{
|
|
758
868
|
background: var(--bg-secondary);
|
|
759
869
|
border: 1px solid var(--border-color);
|
|
@@ -1073,10 +1183,10 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1073
1183
|
<div class="container">
|
|
1074
1184
|
<header>
|
|
1075
1185
|
<div>
|
|
1076
|
-
<h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else
|
|
1186
|
+
<h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ""}</h1>
|
|
1077
1187
|
<div class="job-info" id="job-info">
|
|
1078
|
-
<span class="job-host">{state.hostname or
|
|
1079
|
-
{f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else
|
|
1188
|
+
<span class="job-host">{state.hostname or "stub-local"} @ {state.instance_ip or "127.0.0.1"}</span>
|
|
1189
|
+
{f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ""}
|
|
1080
1190
|
</div>
|
|
1081
1191
|
</div>
|
|
1082
1192
|
<div class="status" id="status">
|
|
@@ -1085,13 +1195,40 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1085
1195
|
</div>
|
|
1086
1196
|
</header>
|
|
1087
1197
|
|
|
1088
|
-
<div class="setup-panel{
|
|
1198
|
+
<div class="setup-panel{" hidden" if not state.setup_logs else ""}" id="setup-panel">
|
|
1089
1199
|
<div class="setup-header">
|
|
1090
1200
|
<h2>Setup Progress</h2>
|
|
1091
|
-
<span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or
|
|
1201
|
+
<span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or "initializing"}</span>
|
|
1092
1202
|
</div>
|
|
1093
1203
|
<div class="setup-logs" id="setup-logs">
|
|
1094
|
-
{
|
|
1204
|
+
{"".join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
|
|
1205
|
+
</div>
|
|
1206
|
+
</div>
|
|
1207
|
+
|
|
1208
|
+
{_get_azure_jobs_panel_html()}
|
|
1209
|
+
|
|
1210
|
+
<div class="config-panel" id="config-panel">
|
|
1211
|
+
<div class="config-grid">
|
|
1212
|
+
<div class="config-item">
|
|
1213
|
+
<span class="config-label">Model</span>
|
|
1214
|
+
<span class="config-value model" id="config-model">{state.model_name or "Not specified"}</span>
|
|
1215
|
+
</div>
|
|
1216
|
+
<div class="config-item">
|
|
1217
|
+
<span class="config-label">Goal</span>
|
|
1218
|
+
<span class="config-value goal" id="config-goal">{state.goal or "Not specified"}</span>
|
|
1219
|
+
</div>
|
|
1220
|
+
<div class="config-item">
|
|
1221
|
+
<span class="config-label">LoRA</span>
|
|
1222
|
+
<span class="config-value" id="config-lora">{f"r={state.lora_r}, α={state.lora_alpha}" if state.lora_r else "Not specified"}</span>
|
|
1223
|
+
</div>
|
|
1224
|
+
<div class="config-item">
|
|
1225
|
+
<span class="config-label">Quantization</span>
|
|
1226
|
+
<span class="config-value" id="config-quant">{"4-bit" if state.load_in_4bit else "None"}</span>
|
|
1227
|
+
</div>
|
|
1228
|
+
<div class="config-item">
|
|
1229
|
+
<span class="config-label">Config</span>
|
|
1230
|
+
<span class="config-value" id="config-path">{state.config_path or "Not specified"}</span>
|
|
1231
|
+
</div>
|
|
1095
1232
|
</div>
|
|
1096
1233
|
</div>
|
|
1097
1234
|
|
|
@@ -1437,7 +1574,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1437
1574
|
let etaSeconds = {eta_seconds};
|
|
1438
1575
|
let avgStepTime = {avg_step_time};
|
|
1439
1576
|
let remainingSteps = {remaining_steps};
|
|
1440
|
-
let isTrainingComplete = {
|
|
1577
|
+
let isTrainingComplete = {"true" if is_training_complete else "false"};
|
|
1441
1578
|
|
|
1442
1579
|
// Auto-stop when loss <= threshold (INVARIANT: training should stop when loss <= 1.0)
|
|
1443
1580
|
const AUTO_STOP_LOSS_THRESHOLD = 1.0;
|
|
@@ -1519,6 +1656,28 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1519
1656
|
}}
|
|
1520
1657
|
}}
|
|
1521
1658
|
|
|
1659
|
+
// Update config panel
|
|
1660
|
+
const configModel = document.getElementById('config-model');
|
|
1661
|
+
const configGoal = document.getElementById('config-goal');
|
|
1662
|
+
const configLora = document.getElementById('config-lora');
|
|
1663
|
+
const configQuant = document.getElementById('config-quant');
|
|
1664
|
+
const configPath = document.getElementById('config-path');
|
|
1665
|
+
if (configModel && data.model_name) {{
|
|
1666
|
+
configModel.textContent = data.model_name;
|
|
1667
|
+
}}
|
|
1668
|
+
if (configGoal && data.goal) {{
|
|
1669
|
+
configGoal.textContent = data.goal;
|
|
1670
|
+
}}
|
|
1671
|
+
if (configLora && (data.lora_r || data.lora_alpha)) {{
|
|
1672
|
+
configLora.textContent = `r=${{data.lora_r || 0}}, α=${{data.lora_alpha || 0}}`;
|
|
1673
|
+
}}
|
|
1674
|
+
if (configQuant) {{
|
|
1675
|
+
configQuant.textContent = data.load_in_4bit ? '4-bit' : 'None';
|
|
1676
|
+
}}
|
|
1677
|
+
if (configPath && data.config_path) {{
|
|
1678
|
+
configPath.textContent = data.config_path;
|
|
1679
|
+
}}
|
|
1680
|
+
|
|
1522
1681
|
// Update setup panel if setup logs present
|
|
1523
1682
|
if (data.setup_logs && data.setup_logs.length > 0) {{
|
|
1524
1683
|
const setupPanel = document.getElementById('setup-panel');
|
|
@@ -1914,7 +2073,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1914
2073
|
setInterval(updateStatusIndicator, 1000); // Update LIVE/STALE indicator every second
|
|
1915
2074
|
</script>
|
|
1916
2075
|
</body>
|
|
1917
|
-
</html>
|
|
2076
|
+
</html>"""
|
|
1918
2077
|
return html
|
|
1919
2078
|
|
|
1920
2079
|
|
|
@@ -1952,6 +2111,7 @@ def regenerate_all_dashboards(output_dir: str | Path) -> list[Path]:
|
|
|
1952
2111
|
except Exception as e:
|
|
1953
2112
|
print(f"Warning: Failed to generate unified viewer: {e}")
|
|
1954
2113
|
import traceback
|
|
2114
|
+
|
|
1955
2115
|
traceback.print_exc()
|
|
1956
2116
|
|
|
1957
2117
|
return regenerated
|
|
@@ -1994,7 +2154,9 @@ def regenerate_local_dashboard(
|
|
|
1994
2154
|
state = TrainingState(
|
|
1995
2155
|
job_id=data.get("job_id", "unknown"),
|
|
1996
2156
|
hostname=data.get("hostname", ""),
|
|
1997
|
-
capture_path=str(capture_path)
|
|
2157
|
+
capture_path=str(capture_path)
|
|
2158
|
+
if capture_path
|
|
2159
|
+
else data.get("capture_path", ""),
|
|
1998
2160
|
config_path=data.get("config_path", ""),
|
|
1999
2161
|
epoch=data.get("epoch", 0),
|
|
2000
2162
|
step=data.get("step", 0),
|
|
@@ -2037,30 +2199,33 @@ def regenerate_local_dashboard(
|
|
|
2037
2199
|
if training_status == "COMPLETED":
|
|
2038
2200
|
html = html.replace(
|
|
2039
2201
|
'<div class="status" id="status">',
|
|
2040
|
-
'<div class="status complete" id="status">'
|
|
2202
|
+
'<div class="status complete" id="status">',
|
|
2041
2203
|
)
|
|
2042
2204
|
html = html.replace(
|
|
2043
2205
|
'<span id="status-text">Training in progress</span>',
|
|
2044
|
-
'<span id="status-text">COMPLETED</span>'
|
|
2206
|
+
'<span id="status-text">COMPLETED</span>',
|
|
2045
2207
|
)
|
|
2046
2208
|
elif training_status == "STOPPED":
|
|
2047
2209
|
html = html.replace(
|
|
2048
|
-
'<div class="status" id="status">',
|
|
2049
|
-
'<div class="status stale" id="status">'
|
|
2210
|
+
'<div class="status" id="status">', '<div class="status stale" id="status">'
|
|
2050
2211
|
)
|
|
2051
2212
|
html = html.replace(
|
|
2052
2213
|
'<span id="status-text">Training in progress</span>',
|
|
2053
|
-
'<span id="status-text">STOPPED (Epoch {}/{})'.format(
|
|
2214
|
+
'<span id="status-text">STOPPED (Epoch {}/{})'.format(
|
|
2215
|
+
current_epoch + 1, total_epochs
|
|
2216
|
+
)
|
|
2217
|
+
+ "</span>",
|
|
2054
2218
|
)
|
|
2055
2219
|
|
|
2056
2220
|
# Fix ETA display for completed/stopped training
|
|
2057
2221
|
import re
|
|
2222
|
+
|
|
2058
2223
|
if training_status in ("COMPLETED", "STOPPED"):
|
|
2059
2224
|
# Replace "calculating..." with appropriate status
|
|
2060
2225
|
html = re.sub(
|
|
2061
2226
|
r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
|
|
2062
|
-
r
|
|
2063
|
-
html
|
|
2227
|
+
r"\1—\2" if training_status == "STOPPED" else r"\1complete\2",
|
|
2228
|
+
html,
|
|
2064
2229
|
)
|
|
2065
2230
|
|
|
2066
2231
|
# Replace dynamic nav with static unified header
|
|
@@ -2071,20 +2236,20 @@ def regenerate_local_dashboard(
|
|
|
2071
2236
|
# This is critical for file:// protocol where fetch() doesn't work
|
|
2072
2237
|
html = html.replace(
|
|
2073
2238
|
"setInterval(fetchAndUpdate, 3000);",
|
|
2074
|
-
"// fetchAndUpdate disabled for static dashboard"
|
|
2239
|
+
"// fetchAndUpdate disabled for static dashboard",
|
|
2075
2240
|
)
|
|
2076
2241
|
html = html.replace(
|
|
2077
2242
|
"setInterval(updateElapsedDisplay, 1000);",
|
|
2078
|
-
"// updateElapsedDisplay disabled for static dashboard"
|
|
2243
|
+
"// updateElapsedDisplay disabled for static dashboard",
|
|
2079
2244
|
)
|
|
2080
2245
|
html = html.replace(
|
|
2081
2246
|
"setInterval(updateStatusIndicator, 1000);",
|
|
2082
|
-
"// updateStatusIndicator disabled for static dashboard"
|
|
2247
|
+
"// updateStatusIndicator disabled for static dashboard",
|
|
2083
2248
|
)
|
|
2084
2249
|
# CRITICAL: Disable discoverDashboards() - it overwrites static nav on file:// protocol
|
|
2085
2250
|
html = html.replace(
|
|
2086
2251
|
"discoverDashboards();",
|
|
2087
|
-
"// discoverDashboards disabled - using static nav for file:// protocol"
|
|
2252
|
+
"// discoverDashboards disabled - using static nav for file:// protocol",
|
|
2088
2253
|
)
|
|
2089
2254
|
|
|
2090
2255
|
# Write output
|
|
@@ -2093,354 +2258,3 @@ def regenerate_local_dashboard(
|
|
|
2093
2258
|
print(f"Regenerated dashboard: {dashboard_path}")
|
|
2094
2259
|
|
|
2095
2260
|
return dashboard_path
|
|
2096
|
-
|
|
2097
|
-
|
|
2098
|
-
def run_epoch_evaluation(
|
|
2099
|
-
adapter: BaseVLMAdapter,
|
|
2100
|
-
episode: Episode,
|
|
2101
|
-
epoch: int,
|
|
2102
|
-
config: TrainingConfig,
|
|
2103
|
-
logger: "TrainingLogger",
|
|
2104
|
-
sample_indices: Optional[List[int]] = None,
|
|
2105
|
-
) -> Path:
|
|
2106
|
-
"""Run inference evaluation on sample steps after an epoch.
|
|
2107
|
-
|
|
2108
|
-
This generates a comparison_epoch{N}.html file showing human vs predicted actions.
|
|
2109
|
-
|
|
2110
|
-
Args:
|
|
2111
|
-
adapter: Trained adapter to use for inference
|
|
2112
|
-
episode: Episode with steps to evaluate
|
|
2113
|
-
epoch: Current epoch number
|
|
2114
|
-
config: Training configuration
|
|
2115
|
-
logger: Training logger for state tracking
|
|
2116
|
-
sample_indices: Specific step indices to evaluate (default: evenly spaced)
|
|
2117
|
-
|
|
2118
|
-
Returns:
|
|
2119
|
-
Path to generated comparison HTML file
|
|
2120
|
-
"""
|
|
2121
|
-
from openadapt_ml.scripts.compare import generate_comparison_html, predict_action, format_action
|
|
2122
|
-
|
|
2123
|
-
output_dir = Path(config.output_dir)
|
|
2124
|
-
output_dir.mkdir(parents=True, exist_ok=True)
|
|
2125
|
-
|
|
2126
|
-
# Select sample indices if not provided
|
|
2127
|
-
num_samples = min(config.eval_samples, len(episode.steps))
|
|
2128
|
-
if sample_indices is None:
|
|
2129
|
-
if num_samples >= len(episode.steps):
|
|
2130
|
-
sample_indices = list(range(len(episode.steps)))
|
|
2131
|
-
else:
|
|
2132
|
-
# Evenly space samples across the episode
|
|
2133
|
-
step_size = len(episode.steps) // num_samples
|
|
2134
|
-
sample_indices = [i * step_size for i in range(num_samples)]
|
|
2135
|
-
|
|
2136
|
-
print(f" Running inference on {len(sample_indices)} sample steps...")
|
|
2137
|
-
|
|
2138
|
-
# Switch adapter to eval mode
|
|
2139
|
-
adapter.eval()
|
|
2140
|
-
|
|
2141
|
-
comparison_data = []
|
|
2142
|
-
action_history: List[str] = []
|
|
2143
|
-
total_steps = len(episode.steps)
|
|
2144
|
-
|
|
2145
|
-
for i, step in enumerate(episode.steps):
|
|
2146
|
-
step_data = {
|
|
2147
|
-
"index": i,
|
|
2148
|
-
"time": step.t,
|
|
2149
|
-
"image_path": step.observation.image_path,
|
|
2150
|
-
"human_action": {
|
|
2151
|
-
"type": step.action.type,
|
|
2152
|
-
"x": step.action.x,
|
|
2153
|
-
"y": step.action.y,
|
|
2154
|
-
"text": step.action.text,
|
|
2155
|
-
},
|
|
2156
|
-
"predicted_action": None,
|
|
2157
|
-
"match": None,
|
|
2158
|
-
}
|
|
2159
|
-
|
|
2160
|
-
# Only run inference on selected samples (for speed)
|
|
2161
|
-
if i in sample_indices and step.observation.image_path:
|
|
2162
|
-
try:
|
|
2163
|
-
predicted = predict_action(
|
|
2164
|
-
adapter,
|
|
2165
|
-
step.observation.image_path,
|
|
2166
|
-
episode.goal,
|
|
2167
|
-
step_index=i,
|
|
2168
|
-
total_steps=total_steps,
|
|
2169
|
-
action_history=action_history.copy(),
|
|
2170
|
-
)
|
|
2171
|
-
step_data["predicted_action"] = predicted
|
|
2172
|
-
|
|
2173
|
-
# Check match and calculate distance
|
|
2174
|
-
if predicted and predicted.get("type") == step.action.type:
|
|
2175
|
-
step_data["match"] = True
|
|
2176
|
-
|
|
2177
|
-
# Calculate distance for click actions
|
|
2178
|
-
if step.action.type == "click":
|
|
2179
|
-
hx, hy = step.action.x or 0, step.action.y or 0
|
|
2180
|
-
px, py = predicted.get("x", 0), predicted.get("y", 0)
|
|
2181
|
-
distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
|
|
2182
|
-
|
|
2183
|
-
# Log evaluation to training state
|
|
2184
|
-
logger.state.log_evaluation(
|
|
2185
|
-
epoch=epoch,
|
|
2186
|
-
sample_idx=i,
|
|
2187
|
-
image_path=step.observation.image_path,
|
|
2188
|
-
human_action=step_data["human_action"],
|
|
2189
|
-
predicted_action=predicted,
|
|
2190
|
-
)
|
|
2191
|
-
else:
|
|
2192
|
-
step_data["match"] = False
|
|
2193
|
-
|
|
2194
|
-
print(f" Step {i}: {step.action.type} -> {predicted.get('type') if predicted else 'none'}")
|
|
2195
|
-
|
|
2196
|
-
except Exception as e:
|
|
2197
|
-
print(f" Step {i}: inference failed - {e}")
|
|
2198
|
-
|
|
2199
|
-
# Build action history for context
|
|
2200
|
-
action_history.append(format_action(step.action, use_som=False))
|
|
2201
|
-
comparison_data.append(step_data)
|
|
2202
|
-
|
|
2203
|
-
# Switch back to train mode
|
|
2204
|
-
adapter.train()
|
|
2205
|
-
|
|
2206
|
-
# Generate comparison HTML
|
|
2207
|
-
output_path = output_dir / f"comparison_epoch{epoch}.html"
|
|
2208
|
-
capture_path = Path(logger.state.capture_path) if logger.state.capture_path else Path(".")
|
|
2209
|
-
|
|
2210
|
-
generate_comparison_html(capture_path, episode, comparison_data, output_path)
|
|
2211
|
-
print(f" Comparison saved: {output_path}")
|
|
2212
|
-
|
|
2213
|
-
# Also regenerate all dashboards to update navigation
|
|
2214
|
-
regenerate_all_dashboards(output_dir)
|
|
2215
|
-
|
|
2216
|
-
return output_path
|
|
2217
|
-
|
|
2218
|
-
|
|
2219
|
-
def _create_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
|
|
2220
|
-
# Use an identity collate_fn so that each batch is a List[Dict], matching
|
|
2221
|
-
# the expectations of adapters that operate on SFT-style samples.
|
|
2222
|
-
return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
def _create_lr_scheduler(
|
|
2226
|
-
optimizer: Optimizer,
|
|
2227
|
-
config: TrainingConfig,
|
|
2228
|
-
num_training_steps: int,
|
|
2229
|
-
) -> Optional[LambdaLR]:
|
|
2230
|
-
"""Create learning rate scheduler based on config.
|
|
2231
|
-
|
|
2232
|
-
Args:
|
|
2233
|
-
optimizer: The optimizer to schedule.
|
|
2234
|
-
config: Training configuration with lr_scheduler_type and warmup_ratio.
|
|
2235
|
-
num_training_steps: Total number of training steps.
|
|
2236
|
-
|
|
2237
|
-
Returns:
|
|
2238
|
-
LambdaLR scheduler or None if scheduler_type is "none" or "constant".
|
|
2239
|
-
"""
|
|
2240
|
-
scheduler_type = config.lr_scheduler_type.lower()
|
|
2241
|
-
|
|
2242
|
-
if scheduler_type in ("none", "constant"):
|
|
2243
|
-
return None
|
|
2244
|
-
|
|
2245
|
-
num_warmup_steps = int(num_training_steps * config.warmup_ratio)
|
|
2246
|
-
|
|
2247
|
-
if scheduler_type == "linear":
|
|
2248
|
-
def lr_lambda(current_step: int) -> float:
|
|
2249
|
-
if current_step < num_warmup_steps:
|
|
2250
|
-
# Linear warmup
|
|
2251
|
-
return float(current_step) / float(max(1, num_warmup_steps))
|
|
2252
|
-
# Linear decay
|
|
2253
|
-
return max(
|
|
2254
|
-
0.0,
|
|
2255
|
-
float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
|
2256
|
-
)
|
|
2257
|
-
elif scheduler_type == "cosine":
|
|
2258
|
-
import math
|
|
2259
|
-
def lr_lambda(current_step: int) -> float:
|
|
2260
|
-
if current_step < num_warmup_steps:
|
|
2261
|
-
# Linear warmup
|
|
2262
|
-
return float(current_step) / float(max(1, num_warmup_steps))
|
|
2263
|
-
# Cosine decay
|
|
2264
|
-
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
|
2265
|
-
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
|
2266
|
-
else:
|
|
2267
|
-
raise ValueError(f"Unknown lr_scheduler_type: {scheduler_type}. Use 'linear', 'cosine', 'constant', or 'none'.")
|
|
2268
|
-
|
|
2269
|
-
return LambdaLR(optimizer, lr_lambda)
|
|
2270
|
-
|
|
2271
|
-
|
|
2272
|
-
def train_supervised(
|
|
2273
|
-
adapter: BaseVLMAdapter,
|
|
2274
|
-
dataset: Dataset,
|
|
2275
|
-
config: TrainingConfig,
|
|
2276
|
-
optimizer: Optional[Optimizer] = None,
|
|
2277
|
-
logger: Optional[TrainingLogger] = None,
|
|
2278
|
-
episode: Optional[Episode] = None,
|
|
2279
|
-
) -> bool:
|
|
2280
|
-
"""Minimal supervised training loop skeleton.
|
|
2281
|
-
|
|
2282
|
-
This assumes that `adapter.prepare_inputs` and `adapter.compute_loss` are
|
|
2283
|
-
implemented. It will raise if those methods are not implemented.
|
|
2284
|
-
|
|
2285
|
-
Args:
|
|
2286
|
-
adapter: VLM adapter to train.
|
|
2287
|
-
dataset: Training dataset.
|
|
2288
|
-
config: Training configuration.
|
|
2289
|
-
optimizer: Optional optimizer (default: AdamW).
|
|
2290
|
-
logger: Optional training logger for visualization.
|
|
2291
|
-
episode: Optional episode for periodic evaluation (generates comparison_epoch{N}.html).
|
|
2292
|
-
|
|
2293
|
-
Returns:
|
|
2294
|
-
True if training completed successfully, False if aborted due to NaN/Inf loss.
|
|
2295
|
-
"""
|
|
2296
|
-
|
|
2297
|
-
device = adapter.device # type: ignore[attr-defined]
|
|
2298
|
-
dataloader = _create_dataloader(dataset, batch_size=config.per_device_train_batch_size)
|
|
2299
|
-
|
|
2300
|
-
if optimizer is None:
|
|
2301
|
-
optimizer = torch.optim.AdamW(
|
|
2302
|
-
adapter.model.parameters(), # type: ignore[arg-type]
|
|
2303
|
-
lr=config.learning_rate,
|
|
2304
|
-
weight_decay=config.weight_decay,
|
|
2305
|
-
)
|
|
2306
|
-
|
|
2307
|
-
# Create logger if not provided
|
|
2308
|
-
if logger is None:
|
|
2309
|
-
logger = TrainingLogger(config.output_dir, config)
|
|
2310
|
-
|
|
2311
|
-
# Calculate total training steps for scheduler
|
|
2312
|
-
num_training_steps = len(dataloader) * config.num_train_epochs // config.gradient_accumulation_steps
|
|
2313
|
-
|
|
2314
|
-
# Create learning rate scheduler
|
|
2315
|
-
lr_scheduler = _create_lr_scheduler(optimizer, config, num_training_steps)
|
|
2316
|
-
|
|
2317
|
-
total_steps = 0
|
|
2318
|
-
adapter.train()
|
|
2319
|
-
|
|
2320
|
-
# Early stopping tracking
|
|
2321
|
-
consecutive_low_loss = 0
|
|
2322
|
-
early_stopped = False
|
|
2323
|
-
user_stopped = False
|
|
2324
|
-
|
|
2325
|
-
for epoch in range(config.num_train_epochs):
|
|
2326
|
-
if early_stopped or user_stopped:
|
|
2327
|
-
break
|
|
2328
|
-
|
|
2329
|
-
for _, batch in enumerate(dataloader):
|
|
2330
|
-
# Check for stop signal from dashboard
|
|
2331
|
-
stop_file = Path(config.output_dir) / "STOP_TRAINING"
|
|
2332
|
-
if stop_file.exists():
|
|
2333
|
-
msg = "Stop signal received from dashboard. Stopping training..."
|
|
2334
|
-
print(msg)
|
|
2335
|
-
logger._log_to_terminal(msg)
|
|
2336
|
-
# Set termination status for dashboard
|
|
2337
|
-
logger.state.termination_status = "user_stop"
|
|
2338
|
-
logger.state.termination_message = "Training stopped by user via dashboard"
|
|
2339
|
-
logger.save()
|
|
2340
|
-
user_stopped = True
|
|
2341
|
-
stop_file.unlink() # Remove signal file
|
|
2342
|
-
break
|
|
2343
|
-
|
|
2344
|
-
# Batch is a List[Dict[str, Any]] of SFT-style samples; adapter is
|
|
2345
|
-
# responsible for converting it into model inputs.
|
|
2346
|
-
samples: List[Dict[str, Any]] = batch
|
|
2347
|
-
|
|
2348
|
-
inputs = adapter.prepare_inputs(samples)
|
|
2349
|
-
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
|
|
2350
|
-
|
|
2351
|
-
loss = adapter.compute_loss(inputs)
|
|
2352
|
-
|
|
2353
|
-
# Guard against invalid losses to avoid propagating NaNs/Infs
|
|
2354
|
-
if torch.isnan(loss) or torch.isinf(loss):
|
|
2355
|
-
msg = f"Encountered invalid loss at epoch={epoch} step={total_steps + 1}: {loss.item()}"
|
|
2356
|
-
print(msg)
|
|
2357
|
-
logger._log_to_terminal(msg)
|
|
2358
|
-
logger.on_train_end()
|
|
2359
|
-
return False
|
|
2360
|
-
|
|
2361
|
-
loss.backward()
|
|
2362
|
-
|
|
2363
|
-
if (total_steps + 1) % config.gradient_accumulation_steps == 0:
|
|
2364
|
-
torch.nn.utils.clip_grad_norm_(adapter.model.parameters(), config.max_grad_norm) # type: ignore[arg-type]
|
|
2365
|
-
optimizer.step()
|
|
2366
|
-
if lr_scheduler is not None:
|
|
2367
|
-
lr_scheduler.step()
|
|
2368
|
-
optimizer.zero_grad()
|
|
2369
|
-
|
|
2370
|
-
total_steps += 1
|
|
2371
|
-
loss_val = loss.item()
|
|
2372
|
-
|
|
2373
|
-
# Get current learning rate from optimizer
|
|
2374
|
-
current_lr = optimizer.param_groups[0]['lr']
|
|
2375
|
-
|
|
2376
|
-
# Log step
|
|
2377
|
-
logger.on_step(epoch, total_steps, loss_val, current_lr)
|
|
2378
|
-
|
|
2379
|
-
if config.logging_steps and total_steps % config.logging_steps == 0:
|
|
2380
|
-
msg = f"epoch={epoch} step={total_steps} loss={loss_val:.4f} lr={current_lr:.6f}"
|
|
2381
|
-
print(msg)
|
|
2382
|
-
logger._log_to_terminal(msg)
|
|
2383
|
-
|
|
2384
|
-
# Early stopping check
|
|
2385
|
-
if loss_val < config.early_stop_loss:
|
|
2386
|
-
consecutive_low_loss += 1
|
|
2387
|
-
if consecutive_low_loss >= config.early_stop_patience:
|
|
2388
|
-
msg = (
|
|
2389
|
-
f"Early stopping: loss ({loss_val:.6f}) below threshold "
|
|
2390
|
-
f"({config.early_stop_loss}) for {config.early_stop_patience} consecutive steps"
|
|
2391
|
-
)
|
|
2392
|
-
print(msg)
|
|
2393
|
-
logger._log_to_terminal(msg)
|
|
2394
|
-
# Set termination status for dashboard
|
|
2395
|
-
logger.state.termination_status = "auto_low_loss"
|
|
2396
|
-
logger.state.termination_message = (
|
|
2397
|
-
f"Loss reached {loss_val:.6f} (< {config.early_stop_loss}) "
|
|
2398
|
-
f"for {config.early_stop_patience} consecutive steps"
|
|
2399
|
-
)
|
|
2400
|
-
logger.save()
|
|
2401
|
-
early_stopped = True
|
|
2402
|
-
break
|
|
2403
|
-
else:
|
|
2404
|
-
consecutive_low_loss = 0
|
|
2405
|
-
|
|
2406
|
-
# End of epoch
|
|
2407
|
-
logger.on_epoch_end(epoch)
|
|
2408
|
-
|
|
2409
|
-
# Save checkpoint at end of each epoch
|
|
2410
|
-
if config.save_checkpoint_every_epoch:
|
|
2411
|
-
checkpoint_path = Path(config.checkpoint_dir) / f"epoch_{epoch}"
|
|
2412
|
-
checkpoint_path.mkdir(parents=True, exist_ok=True)
|
|
2413
|
-
try:
|
|
2414
|
-
adapter.save_checkpoint(str(checkpoint_path))
|
|
2415
|
-
msg = f"Checkpoint saved to {checkpoint_path}"
|
|
2416
|
-
print(msg)
|
|
2417
|
-
logger._log_to_terminal(msg)
|
|
2418
|
-
except Exception as e:
|
|
2419
|
-
msg = f"Warning: Failed to save checkpoint: {e}"
|
|
2420
|
-
print(msg)
|
|
2421
|
-
logger._log_to_terminal(msg)
|
|
2422
|
-
|
|
2423
|
-
# Run evaluation after each epoch (generates comparison_epoch{N}.html)
|
|
2424
|
-
if config.eval_every_epoch and episode is not None:
|
|
2425
|
-
try:
|
|
2426
|
-
print(f"Running epoch {epoch} evaluation...")
|
|
2427
|
-
run_epoch_evaluation(
|
|
2428
|
-
adapter=adapter,
|
|
2429
|
-
episode=episode,
|
|
2430
|
-
epoch=epoch,
|
|
2431
|
-
config=config,
|
|
2432
|
-
logger=logger,
|
|
2433
|
-
)
|
|
2434
|
-
except Exception as e:
|
|
2435
|
-
print(f"Warning: Epoch evaluation failed: {e}")
|
|
2436
|
-
import traceback
|
|
2437
|
-
traceback.print_exc()
|
|
2438
|
-
|
|
2439
|
-
# Set termination status if not already set (normal completion)
|
|
2440
|
-
if not logger.state.termination_status:
|
|
2441
|
-
logger.state.termination_status = "auto_complete"
|
|
2442
|
-
logger.state.termination_message = f"Training completed all {config.num_train_epochs} epochs"
|
|
2443
|
-
logger.save()
|
|
2444
|
-
|
|
2445
|
-
logger.on_train_end()
|
|
2446
|
-
return True
|