openadapt-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,2446 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from dataclasses import dataclass, field
6
+ from pathlib import Path
7
+ from typing import Any, Callable, Dict, List, Optional
8
+
9
+ import torch
10
+ from torch.optim import Optimizer
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+ from torch.utils.data import DataLoader, Dataset
13
+
14
+ from openadapt_ml.models.base_adapter import BaseVLMAdapter
15
+ from openadapt_ml.schemas.sessions import Episode
16
+ from openadapt_ml.training.shared_ui import (
17
+ get_shared_header_css as _get_shared_header_css,
18
+ generate_shared_header_html as _generate_shared_header_html,
19
+ build_nav_links as _build_nav_links,
20
+ )
21
+ from openadapt_ml.training.viewer import (
22
+ generate_unified_viewer_from_output_dir,
23
+ )
24
+
25
+
26
+ def setup_job_directory(base_dir: str | Path, job_id: str) -> Path:
27
+ """Set up job-scoped directory structure with symlink.
28
+
29
+ Creates:
30
+ {base_dir}/{job_id}/ - Job-specific directory
31
+ {base_dir}/current - Symlink to current job directory
32
+
33
+ Args:
34
+ base_dir: Base output directory (e.g., "training_output")
35
+ job_id: Unique job identifier (e.g., "20251214_200417")
36
+
37
+ Returns:
38
+ Path to the job-specific directory
39
+ """
40
+ base_dir = Path(base_dir)
41
+ job_dir = base_dir / job_id
42
+ current_link = base_dir / "current"
43
+
44
+ # Create base and job directories
45
+ base_dir.mkdir(parents=True, exist_ok=True)
46
+ job_dir.mkdir(parents=True, exist_ok=True)
47
+
48
+ # Atomically update the 'current' symlink
49
+ # Use a temp link then rename for atomic operation
50
+ temp_link = base_dir / f".current_temp_{job_id}"
51
+ try:
52
+ # Remove temp link if it exists from a previous failed attempt
53
+ if temp_link.exists() or temp_link.is_symlink():
54
+ temp_link.unlink()
55
+
56
+ # Create temp symlink pointing to job_id (relative path)
57
+ temp_link.symlink_to(job_id)
58
+
59
+ # Atomically replace current with temp
60
+ temp_link.rename(current_link)
61
+ except Exception as e:
62
+ # Clean up temp link on failure
63
+ if temp_link.exists() or temp_link.is_symlink():
64
+ temp_link.unlink()
65
+ raise RuntimeError(f"Failed to create current symlink: {e}")
66
+
67
+ return job_dir
68
+
69
+
70
+ def get_current_job_directory(base_dir: str | Path) -> Path | None:
71
+ """Get the current job directory from symlink.
72
+
73
+ Returns:
74
+ Path to current job directory, or None if no current symlink
75
+ """
76
+ base_dir = Path(base_dir)
77
+ current_link = base_dir / "current"
78
+
79
+ if current_link.is_symlink():
80
+ return current_link.resolve()
81
+ return None
82
+
83
+
84
@dataclass
class TrainingConfig:
    """Hyperparameters for the training loop.

    Model / LoRA-related settings are handled elsewhere; this dataclass
    only covers the optimization-loop knobs.
    """

    # Core loop hyperparameters.
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 1
    learning_rate: float = 2e-4
    warmup_ratio: float = 0.03
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    logging_steps: int = 10
    # Learning-rate schedule: "linear", "cosine", "constant", or "none".
    lr_scheduler_type: str = "linear"
    # Early stopping: halt when loss stays below early_stop_loss for
    # early_stop_patience consecutive steps.
    early_stop_loss: float = 1e-4
    early_stop_patience: int = 10
    # Where logs and visualizations are written.
    output_dir: str = "training_output"
    # Checkpoint saving.
    save_checkpoint_every_epoch: bool = True
    checkpoint_dir: str = "checkpoints"
    # In-training evaluation.
    eval_every_epoch: bool = True
    eval_samples: int = 3  # Number of samples to evaluate per epoch
108
+
109
+
110
def _default_hostname() -> str:
    """Return the local hostname (socket imported lazily to keep module
    import light; replaces the previous ``__import__('socket')`` hack)."""
    import socket

    return socket.gethostname()


@dataclass
class TrainingState:
    """Tracks training progress for visualization."""
    # Job identification
    job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
    hostname: str = field(default_factory=_default_hostname)
    capture_path: str = ""
    config_path: str = ""
    goal: str = ""  # Task goal/description for the training run
    # Training progress
    epoch: int = 0
    step: int = 0
    total_steps: int = 0
    total_epochs: int = 1  # Set by logger from config
    loss: float = 0.0
    learning_rate: float = 0.0
    samples_seen: int = 0
    start_time: float = field(default_factory=time.time)
    elapsed_time: float = 0.0  # For historical data loaded from JSON
    losses: List[Dict[str, Any]] = field(default_factory=list)
    evaluations: List[Dict[str, Any]] = field(default_factory=list)
    # Cloud info (optional)
    instance_type: str = ""
    instance_ip: str = ""
    # Cloud provider info (for dashboard link)
    cloud_provider: str = ""  # e.g. "lambda", "azure"
    cloud_dashboard_url: str = ""  # e.g. "https://cloud.lambda.ai/instances"
    cloud_instance_id: str = ""  # Provider-specific instance ID
    # Setup status tracking
    setup_status: str = ""  # e.g. "booting", "installing", "training", "complete"
    setup_logs: List[str] = field(default_factory=list)  # Setup progress messages
    # Termination tracking
    termination_status: str = ""  # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
    termination_message: str = ""  # Human-readable termination reason

    def log_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
        """Record a training step and append it to the loss history."""
        self.epoch = epoch
        self.step = step
        self.loss = loss
        self.learning_rate = lr
        self.losses.append({
            "epoch": epoch,
            "step": step,
            "loss": loss,
            "lr": lr,
            "time": time.time() - self.start_time,
        })

    def log_evaluation(self, epoch: int, sample_idx: int, image_path: str,
                       human_action: Dict, predicted_action: Dict,
                       correct_threshold: float = 50.0) -> None:
        """Log an evaluation sample.

        Args:
            epoch: Epoch the sample was evaluated in.
            sample_idx: Index of the sample within the evaluation set.
            image_path: Screenshot the actions refer to.
            human_action: Ground-truth action dict (expects "type" and,
                for clicks, "x"/"y" pixel coordinates).
            predicted_action: Model-predicted action dict, same shape.
            correct_threshold: Maximum pixel distance for a click
                prediction to count as "correct" (default 50, matching
                the previous hard-coded behavior).
        """
        # Euclidean pixel distance is only meaningful for click/click pairs;
        # other action types get distance 0 (and therefore count as correct).
        distance = 0.0
        if human_action.get("type") == "click" and predicted_action.get("type") == "click":
            hx, hy = human_action.get("x", 0), human_action.get("y", 0)
            px, py = predicted_action.get("x", 0), predicted_action.get("y", 0)
            distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5

        self.evaluations.append({
            "epoch": epoch,
            "sample_idx": sample_idx,
            "image_path": image_path,
            "human_action": human_action,
            "predicted_action": predicted_action,
            "distance": distance,
            "correct": distance < correct_threshold,
        })

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to serializable dict."""
        return {
            # Job metadata
            "job_id": self.job_id,
            "hostname": self.hostname,
            "capture_path": self.capture_path,
            "config_path": self.config_path,
            "goal": self.goal,
            "instance_type": self.instance_type,
            "instance_ip": self.instance_ip,
            "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)),
            # Cloud provider info
            "cloud_provider": self.cloud_provider,
            "cloud_dashboard_url": self.cloud_dashboard_url,
            "cloud_instance_id": self.cloud_instance_id,
            "setup_status": self.setup_status,
            "setup_logs": self.setup_logs,
            # Training progress
            "epoch": self.epoch,
            "step": self.step,
            "total_steps": self.total_steps,
            "total_epochs": self.total_epochs,
            "loss": self.loss,
            "learning_rate": self.learning_rate,
            "samples_seen": self.samples_seen,
            "elapsed_time": time.time() - self.start_time,
            "losses": self.losses,
            "evaluations": self.evaluations,
            # Termination tracking
            "termination_status": self.termination_status,
            "termination_message": self.termination_message,
        }
212
+
213
+
214
class TrainingLogger:
    """Logs training progress and generates visualization."""

    def __init__(
        self,
        output_dir: str | Path,
        config: TrainingConfig,
        capture_path: str = "",
        config_path: str = "",
        goal: str = "",
        instance_ip: str = "",
        instance_type: str = "",
        cloud_provider: str = "",
        cloud_dashboard_url: str = "",
        cloud_instance_id: str = "",
        job_id: str = "",
    ):
        """Create a logger writing into a job-scoped output directory.

        Args:
            output_dir: Base output directory; the job directory is created
                beneath it and a 'current' symlink is pointed at it.
            config: Training hyperparameters (snapshotted to config.json).
            job_id: Explicit job id; defaults to a timestamp when empty.
            (Remaining arguments are metadata copied into TrainingState.)
        """
        # Fall back to a timestamp-based id when none was supplied.
        job_id = job_id or time.strftime("%Y%m%d_%H%M%S")

        self.base_dir = Path(output_dir)
        # Creates {base}/{job_id} and atomically repoints {base}/current.
        self.output_dir = setup_job_directory(self.base_dir, job_id)
        self.config = config
        self.state = TrainingState(
            job_id=job_id,
            capture_path=capture_path,
            config_path=config_path,
            goal=goal,
            instance_ip=instance_ip,
            instance_type=instance_type,
            total_epochs=config.num_train_epochs,
            cloud_provider=cloud_provider,
            cloud_dashboard_url=cloud_dashboard_url,
            cloud_instance_id=cloud_instance_id,
        )
        self.log_file = self.output_dir / "training_log.json"
        self.terminal_log_file = self.output_dir / "training.log"
        self.terminal_log_handle = None

        # Snapshot the config immediately for reproducibility.
        self._save_config_snapshot()

    def _log_to_terminal(self, message: str):
        """Append a timestamped line to training.log.

        Args:
            message: Message to log
        """
        from datetime import datetime

        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # The file is opened lazily on first write, line-buffered so
        # external tools (tail -f) see progress immediately.
        if self.terminal_log_handle is None:
            self.terminal_log_handle = open(self.terminal_log_file, "w", buffering=1)

        self.terminal_log_handle.write(f"[{stamp}] {message}" + "\n")
        self.terminal_log_handle.flush()

    def on_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
        """Called after each training step."""
        self.state.log_step(epoch, step, loss, lr)
        self._save_log()

    def on_epoch_end(self, epoch: int) -> None:
        """Called at the end of each epoch."""
        self.state.epoch = epoch
        self._save_log()
        self._generate_dashboard()

    def on_train_end(self) -> None:
        """Called at the end of training."""
        self._save_log()
        self._generate_dashboard()
        print(f"Training dashboard: {self.output_dir / 'dashboard.html'}")

        # Release the lazily-opened terminal log handle.
        if self.terminal_log_handle:
            self.terminal_log_handle.close()
            self.terminal_log_handle = None

    def _save_config_snapshot(self) -> None:
        """Save training config snapshot to JSON."""
        from dataclasses import asdict

        with open(self.output_dir / "config.json", "w") as f:
            json.dump(asdict(self.config), f, indent=2)

    def _save_log(self) -> None:
        """Save training log to JSON."""
        with open(self.log_file, "w") as f:
            json.dump(self.state.to_dict(), f, indent=2)

    def _generate_dashboard(self) -> None:
        """Render the HTML training dashboard for the current state."""
        (self.output_dir / "dashboard.html").write_text(
            generate_training_dashboard(self.state, self.config)
        )
317
+
318
+
319
+ def _generate_termination_status_html(state: TrainingState, is_training_complete: bool) -> str:
320
+ """Generate HTML for termination status section."""
321
+ # Check if we have termination info
322
+ if state.termination_status:
323
+ # Map termination status to colors and icons
324
+ status_styles = {
325
+ "auto_complete": {"color": "#22c55e", "icon": "✓", "label": "Training Complete"},
326
+ "auto_low_loss": {"color": "#22c55e", "icon": "✓", "label": "Auto-Stopped (Low Loss)"},
327
+ "user_stop": {"color": "#f59e0b", "icon": "■", "label": "Stopped by User"},
328
+ }
329
+ style = status_styles.get(state.termination_status, {"color": "#22c55e", "icon": "✓", "label": "Complete"})
330
+
331
+ return f'''<div style="display: flex; flex-direction: column; gap: 8px;">
332
+ <div style="display: flex; align-items: center; gap: 8px; color: {style['color']};">
333
+ <span style="font-size: 1.2rem;">{style['icon']}</span>
334
+ <span style="font-weight: 600;">{style['label']}</span>
335
+ </div>
336
+ {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ''}
337
+ </div>'''
338
+ elif is_training_complete:
339
+ return '''<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
340
+ <span style="font-size: 1.2rem;">✓</span>
341
+ <span style="font-weight: 600;">Training Complete</span>
342
+ </div>'''
343
+ else:
344
+ return '''<button id="stop-training-btn" onclick="stopTraining()" style="
345
+ background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
346
+ color: white;
347
+ border: none;
348
+ padding: 12px 24px;
349
+ border-radius: 8px;
350
+ font-size: 0.9rem;
351
+ font-weight: 600;
352
+ cursor: pointer;
353
+ display: flex;
354
+ align-items: center;
355
+ gap: 8px;
356
+ transition: all 0.2s;
357
+ ">
358
+ <span style="font-size: 1.1rem;">■</span> Stop Training
359
+ </button>
360
+ <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>'''
361
+
362
+
363
+ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
364
+ """Generate an HTML dashboard for training visualization."""
365
+ losses_json = json.dumps(state.losses)
366
+ # Use stored elapsed_time if available (historical data), otherwise calculate
367
+ elapsed = state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
368
+ elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s"
369
+
370
+ # Calculate stats
371
+ if state.losses:
372
+ min_loss = min(l["loss"] for l in state.losses)
373
+ avg_loss = sum(l["loss"] for l in state.losses) / len(state.losses)
374
+ recent_losses = state.losses[-10:] if len(state.losses) >= 10 else state.losses
375
+ recent_avg = sum(l["loss"] for l in recent_losses) / len(recent_losses)
376
+ # Calculate step times
377
+ step_times = []
378
+ for i in range(1, len(state.losses)):
379
+ step_times.append(state.losses[i]["time"] - state.losses[i-1]["time"])
380
+ avg_step_time = sum(step_times) / len(step_times) if step_times else 0
381
+ # Loss by epoch
382
+ epoch_losses: dict = {}
383
+ for l in state.losses:
384
+ ep = l["epoch"]
385
+ if ep not in epoch_losses:
386
+ epoch_losses[ep] = []
387
+ epoch_losses[ep].append(l["loss"])
388
+ epoch_avg = {ep: sum(losses)/len(losses) for ep, losses in epoch_losses.items()}
389
+ # Estimate ETA
390
+ # Steps per epoch = steps in completed epochs / completed epochs
391
+ completed_epochs = state.epoch
392
+ steps_in_completed = sum(1 for l in state.losses if l["epoch"] < completed_epochs)
393
+ if completed_epochs > 0 and steps_in_completed > 0:
394
+ steps_per_epoch = steps_in_completed / completed_epochs
395
+ else:
396
+ # Estimate from current epoch progress
397
+ steps_per_epoch = len(state.losses) / (state.epoch + 1) if state.epoch >= 0 else len(state.losses)
398
+
399
+ total_epochs = state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
400
+ total_steps_estimate = steps_per_epoch * total_epochs
401
+ remaining_steps = max(0, total_steps_estimate - len(state.losses))
402
+ eta_seconds = remaining_steps * avg_step_time if avg_step_time > 0 else 0
403
+ # Check if training is complete (all steps done)
404
+ is_training_complete = remaining_steps == 0 and len(state.losses) > 0
405
+ else:
406
+ min_loss = avg_loss = recent_avg = avg_step_time = 0.0
407
+ epoch_avg = {}
408
+ eta_seconds = 0
409
+ steps_per_epoch = 0
410
+ total_steps_estimate = 0
411
+ remaining_steps = 0
412
+ is_training_complete = False
413
+
414
+ epoch_avg_json = json.dumps(list(epoch_avg.items()))
415
+
416
+ # Generate comparison viewer preview if capture path available
417
+ comparison_viewer_path = ""
418
+ if state.capture_path:
419
+ try:
420
+ from openadapt_ml.scripts.compare import generate_comparison_html, generate_comparison_data
421
+ from openadapt_ml.ingest.capture import capture_to_episode
422
+
423
+ capture_path = Path(state.capture_path)
424
+ if capture_path.exists():
425
+ # Load episode from capture
426
+ episode = capture_to_episode(capture_path)
427
+
428
+ # Generate comparison data with null predictions (shows "— No prediction")
429
+ comparison_data = []
430
+ for i, step in enumerate(episode.steps):
431
+ step_data = {
432
+ "index": i,
433
+ "time": step.t,
434
+ "image_path": step.observation.image_path,
435
+ "human_action": {
436
+ "type": step.action.type,
437
+ "x": step.action.x,
438
+ "y": step.action.y,
439
+ "text": step.action.text,
440
+ },
441
+ "predicted_action": None, # Shows "— No prediction" in viewer
442
+ "match": None,
443
+ }
444
+ comparison_data.append(step_data)
445
+
446
+ # Generate comparison HTML
447
+ output_dir = Path(config.output_dir) if hasattr(config, 'output_dir') else Path("training_output")
448
+ output_dir.mkdir(parents=True, exist_ok=True)
449
+ comparison_output = output_dir / "comparison_preview.html"
450
+ generate_comparison_html(capture_path, episode, comparison_data, comparison_output)
451
+ comparison_viewer_path = str(comparison_output.name) # Relative path
452
+ except Exception as e:
453
+ pass # Fail silently if comparison viewer can't be generated
454
+
455
+ html = f'''<!DOCTYPE html>
456
+ <html lang="en">
457
+ <head>
458
+ <meta charset="UTF-8">
459
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
460
+ <title>Training Dashboard</title>
461
+ <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
462
+ <style>
463
+ :root {{
464
+ --bg-primary: #0a0a0f;
465
+ --bg-secondary: #12121a;
466
+ --bg-tertiary: #1a1a24;
467
+ --border-color: rgba(255, 255, 255, 0.06);
468
+ --text-primary: #f0f0f0;
469
+ --text-secondary: #888;
470
+ --accent: #00d4aa;
471
+ --accent-secondary: #a78bfa;
472
+ }}
473
+ * {{ box-sizing: border-box; margin: 0; padding: 0; }}
474
+ body {{
475
+ font-family: -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
476
+ background: var(--bg-primary);
477
+ color: var(--text-primary);
478
+ min-height: 100vh;
479
+ }}
480
+ .container {{ max-width: 1400px; margin: 0 auto; padding: 24px; }}
481
+ header {{
482
+ display: flex;
483
+ justify-content: space-between;
484
+ align-items: center;
485
+ padding: 20px 24px;
486
+ background: var(--bg-secondary);
487
+ border: 1px solid var(--border-color);
488
+ border-radius: 12px;
489
+ margin-bottom: 24px;
490
+ }}
491
+ header h1 {{ font-size: 1.3rem; font-weight: 600; }}
492
+ .job-info {{
493
+ display: flex;
494
+ gap: 16px;
495
+ margin-top: 4px;
496
+ font-size: 0.75rem;
497
+ color: var(--text-secondary);
498
+ }}
499
+ .job-id {{
500
+ font-family: "SF Mono", Monaco, monospace;
501
+ color: var(--accent);
502
+ }}
503
+ .job-host {{
504
+ font-family: "SF Mono", Monaco, monospace;
505
+ }}
506
+ .job-config {{
507
+ font-family: "SF Mono", Monaco, monospace;
508
+ opacity: 0.7;
509
+ }}
510
+ .cloud-link {{
511
+ display: inline-flex;
512
+ align-items: center;
513
+ gap: 6px;
514
+ padding: 6px 12px;
515
+ background: var(--bg-tertiary);
516
+ border: 1px solid var(--border-color);
517
+ border-radius: 6px;
518
+ font-size: 0.75rem;
519
+ color: var(--text-primary);
520
+ text-decoration: none;
521
+ transition: all 0.2s;
522
+ }}
523
+ .cloud-link:hover {{
524
+ border-color: var(--accent);
525
+ background: rgba(0, 212, 170, 0.1);
526
+ }}
527
+ .cloud-link svg {{
528
+ width: 14px;
529
+ height: 14px;
530
+ }}
531
+ .cloud-badge {{
532
+ background: linear-gradient(135deg, rgba(167, 139, 250, 0.2), rgba(0, 212, 170, 0.1));
533
+ border-color: rgba(167, 139, 250, 0.3);
534
+ margin-left: 12px;
535
+ }}
536
+ .setup-panel {{
537
+ background: var(--bg-secondary);
538
+ border: 1px solid var(--border-color);
539
+ border-radius: 12px;
540
+ padding: 20px;
541
+ margin-bottom: 24px;
542
+ }}
543
+ .setup-panel.hidden {{
544
+ display: none;
545
+ }}
546
+ .setup-header {{
547
+ display: flex;
548
+ justify-content: space-between;
549
+ align-items: center;
550
+ margin-bottom: 12px;
551
+ }}
552
+ .setup-header h2 {{
553
+ font-size: 0.9rem;
554
+ }}
555
+ .setup-status-badge {{
556
+ display: inline-flex;
557
+ align-items: center;
558
+ gap: 6px;
559
+ padding: 4px 10px;
560
+ border-radius: 12px;
561
+ font-size: 0.7rem;
562
+ text-transform: uppercase;
563
+ letter-spacing: 0.05em;
564
+ font-weight: 600;
565
+ }}
566
+ .setup-status-badge.booting {{
567
+ background: rgba(255, 149, 0, 0.2);
568
+ color: #ff9500;
569
+ }}
570
+ .setup-status-badge.installing {{
571
+ background: rgba(167, 139, 250, 0.2);
572
+ color: #a78bfa;
573
+ }}
574
+ .setup-status-badge.training {{
575
+ background: rgba(0, 212, 170, 0.2);
576
+ color: #00d4aa;
577
+ }}
578
+ .setup-status-badge.complete {{
579
+ background: rgba(52, 211, 153, 0.2);
580
+ color: #34d399;
581
+ }}
582
+ .setup-logs {{
583
+ background: var(--bg-tertiary);
584
+ border-radius: 8px;
585
+ padding: 12px;
586
+ max-height: 200px;
587
+ overflow-y: auto;
588
+ font-family: "SF Mono", Monaco, monospace;
589
+ font-size: 0.7rem;
590
+ line-height: 1.6;
591
+ }}
592
+ .setup-log-line {{
593
+ color: var(--text-secondary);
594
+ padding: 2px 0;
595
+ }}
596
+ .setup-log-line.current {{
597
+ color: var(--accent);
598
+ }}
599
+ .status {{
600
+ display: flex;
601
+ align-items: center;
602
+ gap: 8px;
603
+ color: var(--accent);
604
+ }}
605
+ .status-dot {{
606
+ width: 10px;
607
+ height: 10px;
608
+ background: var(--accent);
609
+ border-radius: 50%;
610
+ animation: pulse 2s infinite;
611
+ }}
612
+ .status.complete .status-dot {{
613
+ animation: none;
614
+ background: #34d399;
615
+ }}
616
+ .status.stale {{
617
+ color: #ff9500;
618
+ }}
619
+ .status.stale .status-dot {{
620
+ animation: none;
621
+ background: #ff9500;
622
+ }}
623
+ .stale-warning {{
624
+ font-size: 0.7rem;
625
+ color: #ff9500;
626
+ margin-top: 2px;
627
+ }}
628
+ @keyframes pulse {{
629
+ 0%, 100% {{ opacity: 1; }}
630
+ 50% {{ opacity: 0.4; }}
631
+ }}
632
+ .stats-grid {{
633
+ display: grid;
634
+ grid-template-columns: repeat(auto-fit, minmax(160px, 1fr));
635
+ gap: 16px;
636
+ margin-bottom: 24px;
637
+ }}
638
+ .stat-card {{
639
+ background: var(--bg-secondary);
640
+ border: 1px solid var(--border-color);
641
+ border-radius: 12px;
642
+ padding: 20px;
643
+ transition: all 0.3s ease;
644
+ }}
645
+ .stat-card.updating {{
646
+ border-color: var(--accent);
647
+ box-shadow: 0 0 20px rgba(0, 212, 170, 0.1);
648
+ }}
649
+ .stat-label {{
650
+ font-size: 0.75rem;
651
+ color: var(--text-secondary);
652
+ text-transform: uppercase;
653
+ letter-spacing: 0.05em;
654
+ margin-bottom: 8px;
655
+ }}
656
+ .stat-detail {{
657
+ font-size: 0.65rem;
658
+ color: var(--text-secondary);
659
+ margin-top: 4px;
660
+ }}
661
+ .eta-card {{
662
+ background: linear-gradient(135deg, rgba(167, 139, 250, 0.1), rgba(0, 212, 170, 0.05));
663
+ border-color: rgba(167, 139, 250, 0.3);
664
+ }}
665
+ .stat-value {{
666
+ font-size: 1.6rem;
667
+ font-weight: 600;
668
+ font-family: "SF Mono", Monaco, monospace;
669
+ transition: all 0.3s ease;
670
+ }}
671
+ .stat-value.accent {{ color: var(--accent); }}
672
+ .stat-delta {{
673
+ font-size: 0.75rem;
674
+ margin-top: 4px;
675
+ font-family: "SF Mono", Monaco, monospace;
676
+ }}
677
+ .stat-delta.positive {{ color: #34d399; }}
678
+ .stat-delta.negative {{ color: #ff5f5f; }}
679
+ .charts-grid {{
680
+ display: grid;
681
+ grid-template-columns: 2fr 1fr;
682
+ gap: 16px;
683
+ margin-bottom: 24px;
684
+ }}
685
+ @media (max-width: 900px) {{
686
+ .charts-grid {{ grid-template-columns: 1fr; }}
687
+ }}
688
+ .chart-container {{
689
+ background: var(--bg-secondary);
690
+ border: 1px solid var(--border-color);
691
+ border-radius: 12px;
692
+ padding: 24px;
693
+ }}
694
+ .chart-title {{
695
+ font-size: 0.9rem;
696
+ font-weight: 600;
697
+ margin-bottom: 16px;
698
+ display: flex;
699
+ justify-content: space-between;
700
+ align-items: center;
701
+ }}
702
+ .chart-subtitle {{
703
+ font-size: 0.75rem;
704
+ color: var(--text-secondary);
705
+ font-weight: normal;
706
+ }}
707
+ .chart-wrapper {{
708
+ height: 300px;
709
+ position: relative;
710
+ }}
711
+ .config-panel {{
712
+ background: var(--bg-secondary);
713
+ border: 1px solid var(--border-color);
714
+ border-radius: 12px;
715
+ padding: 20px;
716
+ }}
717
+ .config-panel h2 {{
718
+ font-size: 0.9rem;
719
+ margin-bottom: 16px;
720
+ }}
721
+ .config-grid {{
722
+ display: grid;
723
+ grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
724
+ gap: 12px;
725
+ }}
726
+ .config-item {{
727
+ font-size: 0.8rem;
728
+ }}
729
+ .config-item .key {{
730
+ color: var(--text-secondary);
731
+ }}
732
+ .config-item .value {{
733
+ font-family: "SF Mono", Monaco, monospace;
734
+ color: var(--accent);
735
+ }}
736
+ .progress-bar {{
737
+ height: 4px;
738
+ background: var(--bg-tertiary);
739
+ border-radius: 2px;
740
+ margin-top: 8px;
741
+ overflow: hidden;
742
+ }}
743
+ .progress-fill {{
744
+ height: 100%;
745
+ background: linear-gradient(90deg, var(--accent), var(--accent-secondary));
746
+ border-radius: 2px;
747
+ transition: width 0.5s ease;
748
+ }}
749
+ .update-indicator {{
750
+ font-size: 0.7rem;
751
+ color: var(--text-secondary);
752
+ text-align: right;
753
+ margin-top: 16px;
754
+ }}
755
+ /* Shared header styles (injected from _get_shared_header_css) */
756
+ {_get_shared_header_css()}
757
+ .eval-panel {{
758
+ background: var(--bg-secondary);
759
+ border: 1px solid var(--border-color);
760
+ border-radius: 12px;
761
+ padding: 20px;
762
+ margin-top: 16px;
763
+ }}
764
+ .eval-panel h2 {{
765
+ font-size: 0.9rem;
766
+ margin-bottom: 16px;
767
+ }}
768
+ .eval-metrics {{
769
+ display: flex;
770
+ gap: 24px;
771
+ margin-bottom: 16px;
772
+ font-size: 0.85rem;
773
+ }}
774
+ .eval-metrics .metric {{
775
+ display: flex;
776
+ flex-direction: column;
777
+ }}
778
+ .eval-metrics .metric-value {{
779
+ font-size: 1.2rem;
780
+ font-weight: 600;
781
+ color: var(--accent);
782
+ }}
783
+ .eval-filters {{
784
+ display: flex;
785
+ gap: 16px;
786
+ margin-bottom: 16px;
787
+ align-items: center;
788
+ flex-wrap: wrap;
789
+ }}
790
+ .eval-filters .filter-group {{
791
+ display: flex;
792
+ align-items: center;
793
+ gap: 8px;
794
+ }}
795
+ .eval-filters label {{
796
+ font-size: 0.75rem;
797
+ color: var(--text-secondary);
798
+ text-transform: uppercase;
799
+ letter-spacing: 0.05em;
800
+ }}
801
+ .eval-filters select {{
802
+ padding: 8px 32px 8px 12px;
803
+ border-radius: 8px;
804
+ font-size: 0.85rem;
805
+ background: rgba(0,0,0,0.4);
806
+ color: var(--text-primary);
807
+ border: 1px solid rgba(255,255,255,0.1);
808
+ cursor: pointer;
809
+ appearance: none;
810
+ background-image: url('data:image/svg+xml,%3Csvg xmlns=%27http://www.w3.org/2000/svg%27 width=%2712%27 height=%278%27%3E%3Cpath fill=%27%23888%27 d=%27M0 0l6 8 6-8z%27/%3E%3C/svg%3E');
811
+ background-repeat: no-repeat;
812
+ background-position: right 10px center;
813
+ transition: all 0.2s;
814
+ }}
815
+ .eval-filters select:hover {{
816
+ border-color: var(--accent);
817
+ background-color: rgba(0,212,170,0.1);
818
+ }}
819
+ .eval-filters select:focus {{
820
+ outline: none;
821
+ border-color: var(--accent);
822
+ box-shadow: 0 0 0 2px rgba(0,212,170,0.2);
823
+ }}
824
+ .eval-gallery {{
825
+ display: grid;
826
+ grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
827
+ gap: 20px;
828
+ }}
829
+ .eval-sample {{
830
+ background: var(--bg-tertiary);
831
+ border-radius: 8px;
832
+ padding: 0;
833
+ position: relative;
834
+ overflow: hidden;
835
+ border: 1px solid var(--border-color);
836
+ }}
837
+ .eval-sample.hidden {{
838
+ display: none;
839
+ }}
840
+ .eval-sample .image-container {{
841
+ position: relative;
842
+ background: #000;
843
+ min-height: 200px;
844
+ display: flex;
845
+ align-items: center;
846
+ justify-content: center;
847
+ }}
848
+ .eval-sample img {{
849
+ width: 100%;
850
+ height: auto;
851
+ display: block;
852
+ max-height: 400px;
853
+ object-fit: contain;
854
+ }}
855
+ .eval-sample .overlay {{
856
+ position: absolute;
857
+ top: 0;
858
+ left: 0;
859
+ right: 0;
860
+ bottom: 0;
861
+ pointer-events: none;
862
+ }}
863
+ .eval-sample .marker {{
864
+ position: absolute;
865
+ width: 24px;
866
+ height: 24px;
867
+ border-radius: 50%;
868
+ transform: translate(-50%, -50%);
869
+ border: 3px solid white;
870
+ display: flex;
871
+ align-items: center;
872
+ justify-content: center;
873
+ font-size: 10px;
874
+ font-weight: bold;
875
+ color: white;
876
+ z-index: 10;
877
+ }}
878
+ .eval-sample .marker.human {{
879
+ background: rgba(0, 212, 170, 0.4);
880
+ border-color: #00d4aa;
881
+ }}
882
+ .eval-sample .marker.human::after {{
883
+ content: 'H';
884
+ color: #00d4aa;
885
+ }}
886
+ .eval-sample .marker.predicted {{
887
+ background: rgba(167, 139, 250, 0.4);
888
+ border-color: #a78bfa;
889
+ }}
890
+ .eval-sample .marker.predicted::after {{
891
+ content: 'AI';
892
+ font-size: 9px;
893
+ color: #a78bfa;
894
+ }}
895
+ .eval-sample .line {{
896
+ position: absolute;
897
+ height: 2px;
898
+ background: rgba(255, 255, 255, 0.5);
899
+ transform-origin: left center;
900
+ }}
901
+ .eval-sample .content {{
902
+ padding: 12px;
903
+ }}
904
+ .eval-sample .info {{
905
+ font-size: 0.75rem;
906
+ color: var(--text-secondary);
907
+ margin-bottom: 8px;
908
+ padding-bottom: 8px;
909
+ border-bottom: 1px solid var(--border-color);
910
+ }}
911
+ .eval-sample .info .correct {{
912
+ color: #34d399;
913
+ font-weight: 600;
914
+ }}
915
+ .eval-sample .info .incorrect {{
916
+ color: #ff5f5f;
917
+ font-weight: 600;
918
+ }}
919
+ .eval-sample .details {{
920
+ font-size: 0.7rem;
921
+ color: var(--text-secondary);
922
+ }}
923
+ .eval-sample .coords {{
924
+ display: flex;
925
+ flex-direction: column;
926
+ gap: 4px;
927
+ margin-bottom: 8px;
928
+ }}
929
+ .eval-sample .coords .human-coord {{
930
+ color: #34d399;
931
+ }}
932
+ .eval-sample .coords .pred-coord {{
933
+ color: #a78bfa;
934
+ }}
935
+ .eval-sample .thinking {{
936
+ margin-top: 8px;
937
+ padding: 8px;
938
+ background: rgba(0,0,0,0.3);
939
+ border-radius: 4px;
940
+ font-size: 0.65rem;
941
+ color: var(--text-secondary);
942
+ max-height: 150px;
943
+ overflow-y: auto;
944
+ white-space: pre-wrap;
945
+ word-break: break-word;
946
+ font-family: "SF Mono", Monaco, monospace;
947
+ line-height: 1.4;
948
+ }}
949
+ .eval-sample .thinking.collapsed {{
950
+ max-height: 60px;
951
+ overflow: hidden;
952
+ position: relative;
953
+ }}
954
+ .eval-sample .thinking.collapsed::after {{
955
+ content: '';
956
+ position: absolute;
957
+ bottom: 0;
958
+ left: 0;
959
+ right: 0;
960
+ height: 30px;
961
+ background: linear-gradient(to bottom, transparent, rgba(0,0,0,0.5));
962
+ }}
963
+ .eval-sample .thinking-toggle {{
964
+ cursor: pointer;
965
+ color: var(--accent);
966
+ font-size: 0.7rem;
967
+ margin-top: 4px;
968
+ display: inline-block;
969
+ }}
970
+ .eval-sample .thinking-toggle:hover {{
971
+ text-decoration: underline;
972
+ }}
973
+ .terminal-panel {{
974
+ background: var(--bg-secondary);
975
+ border: 1px solid var(--border-color);
976
+ border-radius: 12px;
977
+ padding: 20px;
978
+ margin-top: 16px;
979
+ }}
980
+ .terminal-header {{
981
+ display: flex;
982
+ justify-content: space-between;
983
+ align-items: center;
984
+ margin-bottom: 12px;
985
+ }}
986
+ .terminal-header h2 {{
987
+ font-size: 0.9rem;
988
+ margin: 0;
989
+ }}
990
+ .terminal-toggle {{
991
+ background: var(--bg-tertiary);
992
+ border: 1px solid var(--border-color);
993
+ color: var(--text-primary);
994
+ padding: 6px 12px;
995
+ border-radius: 6px;
996
+ font-size: 0.75rem;
997
+ cursor: pointer;
998
+ transition: all 0.2s;
999
+ }}
1000
+ .terminal-toggle:hover {{
1001
+ border-color: var(--accent);
1002
+ background: rgba(0, 212, 170, 0.1);
1003
+ }}
1004
+ .terminal-container {{
1005
+ background: #000;
1006
+ border: 1px solid rgba(255, 255, 255, 0.1);
1007
+ border-radius: 8px;
1008
+ padding: 16px;
1009
+ max-height: 400px;
1010
+ overflow-y: auto;
1011
+ font-family: "SF Mono", Monaco, "Courier New", monospace;
1012
+ font-size: 0.7rem;
1013
+ line-height: 1.5;
1014
+ color: #0f0;
1015
+ position: relative;
1016
+ }}
1017
+ .terminal-container.collapsed {{
1018
+ max-height: 200px;
1019
+ }}
1020
+ .terminal-output {{
1021
+ white-space: pre-wrap;
1022
+ word-break: break-word;
1023
+ }}
1024
+ .terminal-line {{
1025
+ padding: 2px 0;
1026
+ }}
1027
+ .terminal-line.timestamp {{
1028
+ color: #888;
1029
+ }}
1030
+ .terminal-line.error {{
1031
+ color: #ff5f5f;
1032
+ }}
1033
+ .terminal-line.success {{
1034
+ color: #34d399;
1035
+ }}
1036
+ .terminal-line.warning {{
1037
+ color: #ff9500;
1038
+ }}
1039
+ .terminal-empty {{
1040
+ color: #888;
1041
+ font-style: italic;
1042
+ text-align: center;
1043
+ padding: 40px;
1044
+ }}
1045
+ .terminal-controls {{
1046
+ display: flex;
1047
+ gap: 8px;
1048
+ margin-bottom: 8px;
1049
+ font-size: 0.7rem;
1050
+ }}
1051
+ .terminal-control-btn {{
1052
+ background: var(--bg-tertiary);
1053
+ border: 1px solid var(--border-color);
1054
+ color: var(--text-secondary);
1055
+ padding: 4px 8px;
1056
+ border-radius: 4px;
1057
+ cursor: pointer;
1058
+ transition: all 0.2s;
1059
+ }}
1060
+ .terminal-control-btn:hover {{
1061
+ border-color: var(--accent);
1062
+ color: var(--text-primary);
1063
+ }}
1064
+ .terminal-control-btn.active {{
1065
+ border-color: var(--accent);
1066
+ color: var(--accent);
1067
+ }}
1068
+ </style>
1069
+ </head>
1070
+ <body>
1071
+ {_generate_shared_header_html("training", meta_html=f"Job: {state.job_id}")}
1072
+
1073
+ <div class="container">
1074
+ <header>
1075
+ <div>
1076
+ <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ''}</h1>
1077
+ <div class="job-info" id="job-info">
1078
+ <span class="job-host">{state.hostname or 'stub-local'} @ {state.instance_ip or '127.0.0.1'}</span>
1079
+ {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ''}
1080
+ </div>
1081
+ </div>
1082
+ <div class="status" id="status">
1083
+ <div class="status-dot"></div>
1084
+ <span id="status-text">Training in progress</span>
1085
+ </div>
1086
+ </header>
1087
+
1088
+ <div class="setup-panel{' hidden' if not state.setup_logs else ''}" id="setup-panel">
1089
+ <div class="setup-header">
1090
+ <h2>Setup Progress</h2>
1091
+ <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or 'initializing'}</span>
1092
+ </div>
1093
+ <div class="setup-logs" id="setup-logs">
1094
+ {''.join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
1095
+ </div>
1096
+ </div>
1097
+
1098
+ <div class="stats-grid">
1099
+ <div class="stat-card" id="card-epoch">
1100
+ <div class="stat-label">Epoch Progress</div>
1101
+ <div class="stat-value" id="stat-epoch">{min(state.epoch + 1, config.num_train_epochs)} / {config.num_train_epochs}</div>
1102
+ <div class="progress-bar"><div class="progress-fill" id="epoch-progress" style="width: {(min(state.epoch + 1, config.num_train_epochs) / config.num_train_epochs) * 100}%"></div></div>
1103
+ </div>
1104
+ <div class="stat-card" id="card-step">
1105
+ <div class="stat-label">Steps</div>
1106
+ <div class="stat-value" id="stat-step">{state.step}</div>
1107
+ </div>
1108
+ <div class="stat-card" id="card-loss">
1109
+ <div class="stat-label">Current Loss</div>
1110
+ <div class="stat-value accent" id="stat-loss">{state.loss:.4f}</div>
1111
+ <div class="stat-delta" id="loss-delta"></div>
1112
+ </div>
1113
+ <div class="stat-card">
1114
+ <div class="stat-label">Min Loss</div>
1115
+ <div class="stat-value" id="stat-min-loss">{min_loss:.4f}</div>
1116
+ </div>
1117
+ <div class="stat-card">
1118
+ <div class="stat-label">Avg (last 10)</div>
1119
+ <div class="stat-value" id="stat-avg-loss">{recent_avg:.4f}</div>
1120
+ </div>
1121
+ <div class="stat-card">
1122
+ <div class="stat-label">Avg Step Time</div>
1123
+ <div class="stat-value" id="stat-step-time">{avg_step_time:.1f}s</div>
1124
+ </div>
1125
+ <div class="stat-card">
1126
+ <div class="stat-label">Elapsed</div>
1127
+ <div class="stat-value" id="stat-elapsed">{elapsed_str}</div>
1128
+ </div>
1129
+ <div class="stat-card eta-card">
1130
+ <div class="stat-label">ETA</div>
1131
+ <div class="stat-value" id="stat-eta">{f"{int(eta_seconds // 60)}m {int(eta_seconds % 60)}s" if eta_seconds > 0 else ("Complete" if is_training_complete else "calculating...")}</div>
1132
+ <div class="stat-detail" id="eta-detail">{f"~{int(remaining_steps)} steps @ {avg_step_time:.1f}s/step" if remaining_steps > 0 else ""}</div>
1133
+ </div>
1134
+ <div class="stat-card" id="card-cost" style="background: linear-gradient(135deg, rgba(239, 68, 68, 0.1), rgba(220, 38, 38, 0.05)); border-color: rgba(239, 68, 68, 0.3);">
1135
+ <div class="stat-label">Cloud Cost</div>
1136
+ <div class="stat-value" id="stat-running-cost" style="color: #ef4444;">$0.00</div>
1137
+ <div class="stat-detail" id="stat-est-total">Est. Total: $0.00</div>
1138
+ </div>
1139
+ </div>
1140
+
1141
+ <div class="charts-grid">
1142
+ <div class="chart-container">
1143
+ <div class="chart-title">
1144
+ Loss Curve
1145
+ <span class="chart-subtitle" id="loss-trend"></span>
1146
+ </div>
1147
+ <div class="chart-wrapper">
1148
+ <canvas id="lossChart"></canvas>
1149
+ </div>
1150
+ </div>
1151
+ <div class="chart-container">
1152
+ <div class="chart-title">Loss by Epoch</div>
1153
+ <div class="chart-wrapper">
1154
+ <canvas id="epochChart"></canvas>
1155
+ </div>
1156
+ </div>
1157
+ </div>
1158
+
1159
+ <div class="config-panel">
1160
+ <h2>Training Configuration</h2>
1161
+ <div class="config-grid">
1162
+ <div class="config-item"><span class="key">Epochs:</span> <span class="value">{config.num_train_epochs}</span></div>
1163
+ <div class="config-item"><span class="key">Batch size:</span> <span class="value">{config.per_device_train_batch_size}</span></div>
1164
+ <div class="config-item"><span class="key">Learning rate:</span> <span class="value">{config.learning_rate}</span></div>
1165
+ <div class="config-item"><span class="key">Grad accum:</span> <span class="value">{config.gradient_accumulation_steps}</span></div>
1166
+ <div class="config-item"><span class="key">Max grad norm:</span> <span class="value">{config.max_grad_norm}</span></div>
1167
+ <div class="config-item"><span class="key">Early stop:</span> <span class="value">{config.early_stop_loss}</span></div>
1168
+ </div>
1169
+ <div id="stop-training-section" class="stop-training-section" style="margin-top: 16px; padding-top: 16px; border-top: 1px solid var(--border-color);">
1170
+ {_generate_termination_status_html(state, is_training_complete)}
1171
+ </div>
1172
+ </div>
1173
+
1174
+ <div class="eval-panel" id="eval-panel" style="display: none;">
1175
+ <h2>Evaluation Samples</h2>
1176
+ <div class="eval-metrics" id="eval-metrics"></div>
1177
+ <div class="eval-filters">
1178
+ <div class="filter-group">
1179
+ <label for="epoch-filter">Epoch:</label>
1180
+ <select id="epoch-filter">
1181
+ <option value="all">All Epochs</option>
1182
+ </select>
1183
+ </div>
1184
+ <div class="filter-group">
1185
+ <label for="correctness-filter">Status:</label>
1186
+ <select id="correctness-filter">
1187
+ <option value="all">All</option>
1188
+ <option value="correct">Correct Only</option>
1189
+ <option value="incorrect">Incorrect Only</option>
1190
+ </select>
1191
+ </div>
1192
+ <div style="margin-left: auto; font-size: 0.75rem; color: var(--text-muted);">
1193
+ <span id="filter-count"></span>
1194
+ </div>
1195
+ </div>
1196
+ <div class="eval-gallery" id="eval-gallery"></div>
1197
+ </div>
1198
+
1199
+ <div class="terminal-panel" id="terminal-panel">
1200
+ <div class="terminal-header">
1201
+ <h2>Training Output</h2>
1202
+ <button class="terminal-toggle" id="terminal-toggle" onclick="toggleTerminal()">
1203
+ <span id="terminal-toggle-text">Collapse</span>
1204
+ </button>
1205
+ </div>
1206
+ <div class="terminal-controls">
1207
+ <button class="terminal-control-btn active" id="auto-scroll-btn" onclick="toggleAutoScroll()">Auto-scroll</button>
1208
+ <button class="terminal-control-btn" id="wrap-btn" onclick="toggleWrap()">Wrap text</button>
1209
+ <span style="margin-left: auto; color: var(--text-secondary); font-size: 0.7rem;">
1210
+ <span id="terminal-line-count">0</span> lines
1211
+ </span>
1212
+ </div>
1213
+ <div class="terminal-container" id="terminal-container">
1214
+ <div class="terminal-output" id="terminal-output">
1215
+ <div class="terminal-empty">Waiting for training output...</div>
1216
+ </div>
1217
+ </div>
1218
+ </div>
1219
+
1220
+ <div class="update-indicator" id="update-indicator">Last updated: just now</div>
1221
+ </div>
1222
+
1223
+ <script>
1224
+ let losses = {losses_json};
1225
+ let epochAvg = {epoch_avg_json};
1226
+ let lossChart, epochChart;
1227
+ let lastStep = {state.step};
1228
+ let lastLoss = {state.loss};
1229
+
1230
+ // Cloud cost tracking
1231
+ const instanceType = '{state.instance_type}';
1232
+ const COST_RATES = {{
1233
+ 'gpu_1x_a10': 0.75, // Lambda Labs A10
1234
+ 'gpu_8x_a100': 1.29, // Lambda Labs A100 (per GPU)
1235
+ 'a10': 0.75, // Generic A10
1236
+ 'a100': 1.29, // Generic A100
1237
+ }};
1238
+
1239
+ function getHourlyRate(instanceType) {{
1240
+ // Try exact match first
1241
+ if (COST_RATES[instanceType.toLowerCase()]) {{
1242
+ return COST_RATES[instanceType.toLowerCase()];
1243
+ }}
1244
+ // Try partial match
1245
+ const typeStr = instanceType.toLowerCase();
1246
+ if (typeStr.includes('a100')) return COST_RATES['a100'];
1247
+ if (typeStr.includes('a10')) return COST_RATES['a10'];
1248
+ // Default to A10 rate
1249
+ return COST_RATES['a10'];
1250
+ }}
1251
+
1252
+ function updateCostDisplay() {{
1253
+ // Only show costs for actual cloud training (not stub/local)
1254
+ if (!instanceType || instanceType === '' || instanceType === 'stub') {{
1255
+ document.getElementById('card-cost').style.display = 'none';
1256
+ return;
1257
+ }}
1258
+
1259
+ const hourlyRate = getHourlyRate(instanceType);
1260
+
1261
+ // Calculate running cost based on elapsed time
1262
+ const timeSinceSync = (Date.now() - lastSyncTime) / 1000;
1263
+ const liveElapsed = baseElapsedTime + timeSinceSync;
1264
+ const elapsedHours = liveElapsed / 3600;
1265
+ const runningCost = elapsedHours * hourlyRate;
1266
+
1267
+ // Calculate estimated total cost
1268
+ let estimatedTotal = runningCost;
1269
+ if (etaSeconds > 0) {{
1270
+ const totalTimeSeconds = liveElapsed + etaSeconds;
1271
+ const totalHours = totalTimeSeconds / 3600;
1272
+ estimatedTotal = totalHours * hourlyRate;
1273
+ }}
1274
+
1275
+ // Update display
1276
+ document.getElementById('stat-running-cost').textContent = `$${{runningCost.toFixed(2)}}`;
1277
+ document.getElementById('stat-est-total').textContent = `Est. Total: $${{estimatedTotal.toFixed(2)}}`;
1278
+ }}
1279
+
1280
+ async function stopTraining() {{
1281
+ const btn = document.getElementById('stop-training-btn');
1282
+ const status = document.getElementById('stop-status');
1283
+
1284
+ btn.disabled = true;
1285
+ btn.innerHTML = '<span style="font-size: 1.1rem;">⏳</span> Stopping...';
1286
+ btn.style.background = '#666';
1287
+
1288
+ try {{
1289
+ // Try to create stop signal via API
1290
+ const response = await fetch('/api/stop', {{
1291
+ method: 'POST',
1292
+ headers: {{ 'Content-Type': 'application/json' }}
1293
+ }});
1294
+
1295
+ if (response.ok) {{
1296
+ btn.innerHTML = '<span style="font-size: 1.1rem;">✓</span> Stop Signal Sent';
1297
+ btn.style.background = '#22c55e';
1298
+ status.textContent = 'Training will stop after current step. Checkpoints will be downloaded.';
1299
+ status.style.color = '#22c55e';
1300
+ }} else {{
1301
+ throw new Error('Server returned ' + response.status);
1302
+ }}
1303
+ }} catch (e) {{
1304
+ // Fallback: show manual command
1305
+ btn.innerHTML = '<span style="font-size: 1.1rem;">!</span> Manual Stop Required';
1306
+ btn.style.background = '#f59e0b';
1307
+ status.innerHTML = 'Run this command to stop training:<br><code style="background: #1a1a24; padding: 4px 8px; border-radius: 4px; font-family: monospace;">touch training_output/STOP_TRAINING</code>';
1308
+ status.style.color = '#f59e0b';
1309
+ }}
1310
+ }}
1311
+
1312
+ function updateTerminationStatus(data) {{
1313
+ const stopSection = document.getElementById('stop-training-section');
1314
+ if (!stopSection) return;
1315
+
1316
+ const termStatus = data.termination_status || 'auto_complete';
1317
+ const termMessage = data.termination_message || '';
1318
+
1319
+ const statusStyles = {{
1320
+ 'auto_complete': {{ color: '#22c55e', icon: '✓', label: 'Training Complete' }},
1321
+ 'auto_low_loss': {{ color: '#22c55e', icon: '✓', label: 'Auto-Stopped (Low Loss)' }},
1322
+ 'user_stop': {{ color: '#f59e0b', icon: '■', label: 'Stopped by User' }},
1323
+ }};
1324
+
1325
+ const style = statusStyles[termStatus] || statusStyles['auto_complete'];
1326
+
1327
+ let html = `<div style="display: flex; flex-direction: column; gap: 8px;">
1328
+ <div style="display: flex; align-items: center; gap: 8px; color: ${{style.color}};">
1329
+ <span style="font-size: 1.2rem;">${{style.icon}}</span>
1330
+ <span style="font-weight: 600;">${{style.label}}</span>
1331
+ </div>`;
1332
+
1333
+ if (termMessage) {{
1334
+ html += `<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">${{termMessage}}</div>`;
1335
+ }}
1336
+
1337
+ html += '</div>';
1338
+ stopSection.innerHTML = html;
1339
+ }}
1340
+
1341
+ function initCharts() {{
1342
+ const lossCtx = document.getElementById('lossChart').getContext('2d');
1343
+ lossChart = new Chart(lossCtx, {{
1344
+ type: 'line',
1345
+ data: {{
1346
+ labels: losses.map(l => l.step),
1347
+ datasets: [{{
1348
+ label: 'Loss',
1349
+ data: losses.map(l => l.loss),
1350
+ borderColor: '#00d4aa',
1351
+ backgroundColor: 'rgba(0, 212, 170, 0.1)',
1352
+ fill: true,
1353
+ tension: 0.3,
1354
+ pointRadius: losses.length > 50 ? 0 : 3,
1355
+ }}]
1356
+ }},
1357
+ options: {{
1358
+ responsive: true,
1359
+ maintainAspectRatio: false,
1360
+ animation: {{ duration: 500 }},
1361
+ scales: {{
1362
+ x: {{
1363
+ title: {{ display: true, text: 'Step', color: '#888' }},
1364
+ grid: {{ color: 'rgba(255,255,255,0.05)' }},
1365
+ ticks: {{ color: '#888' }}
1366
+ }},
1367
+ y: {{
1368
+ title: {{ display: true, text: 'Loss', color: '#888' }},
1369
+ grid: {{ color: 'rgba(255,255,255,0.05)' }},
1370
+ ticks: {{ color: '#888' }}
1371
+ }}
1372
+ }},
1373
+ plugins: {{ legend: {{ display: false }} }}
1374
+ }}
1375
+ }});
1376
+
1377
+ const epochCtx = document.getElementById('epochChart').getContext('2d');
1378
+ epochChart = new Chart(epochCtx, {{
1379
+ type: 'bar',
1380
+ data: {{
1381
+ labels: epochAvg.map(e => `Epoch ${{e[0] + 1}}`),
1382
+ datasets: [{{
1383
+ label: 'Avg Loss',
1384
+ data: epochAvg.map(e => e[1]),
1385
+ backgroundColor: 'rgba(167, 139, 250, 0.6)',
1386
+ borderColor: '#a78bfa',
1387
+ borderWidth: 1,
1388
+ }}]
1389
+ }},
1390
+ options: {{
1391
+ responsive: true,
1392
+ maintainAspectRatio: false,
1393
+ animation: {{ duration: 500 }},
1394
+ scales: {{
1395
+ y: {{
1396
+ beginAtZero: false,
1397
+ grid: {{ color: 'rgba(255,255,255,0.05)' }},
1398
+ ticks: {{ color: '#888' }}
1399
+ }},
1400
+ x: {{
1401
+ grid: {{ display: false }},
1402
+ ticks: {{ color: '#888' }}
1403
+ }}
1404
+ }},
1405
+ plugins: {{ legend: {{ display: false }} }}
1406
+ }}
1407
+ }});
1408
+
1409
+ updateTrend();
1410
+ }}
1411
+
1412
+ function updateTrend() {{
1413
+ if (losses.length >= 10) {{
1414
+ const recent = losses.slice(-10);
1415
+ const first = recent[0].loss;
1416
+ const last = recent[recent.length - 1].loss;
1417
+ const change = ((last - first) / first * 100).toFixed(1);
1418
+ const trendEl = document.getElementById('loss-trend');
1419
+ if (change < 0) {{
1420
+ trendEl.textContent = `↓ ${{Math.abs(change)}}% (last 10)`;
1421
+ trendEl.style.color = '#34d399';
1422
+ }} else {{
1423
+ trendEl.textContent = `↑ ${{change}}% (last 10)`;
1424
+ trendEl.style.color = '#ff5f5f';
1425
+ }}
1426
+ }}
1427
+ }}
1428
+
1429
+ // Live elapsed timer variables
1430
+ let baseElapsedTime = {elapsed}; // Last known elapsed time from server
1431
+ let lastSyncTime = Date.now(); // When we last synced with server
1432
+ let lastSuccessfulFetch = Date.now(); // When we last got a successful response
1433
+ let currentJobId = '{state.job_id}'; // Current job ID
1434
+ const STALE_THRESHOLD_SECONDS = 30; // Consider stale after 30s without updates
1435
+
1436
+ // ETA tracking
1437
+ let etaSeconds = {eta_seconds};
1438
+ let avgStepTime = {avg_step_time};
1439
+ let remainingSteps = {remaining_steps};
1440
+ let isTrainingComplete = {'true' if is_training_complete else 'false'};
1441
+
1442
+ // Auto-stop when loss <= threshold (INVARIANT: training should stop when loss <= 1.0)
1443
+ const AUTO_STOP_LOSS_THRESHOLD = 1.0;
1444
+ let autoStopTriggered = false;
1445
+
1446
+ function updateElapsedDisplay() {{
1447
+ // Don't update elapsed if training is complete
1448
+ if (isTrainingComplete) {{
1449
+ return;
1450
+ }}
1451
+
1452
+ // Calculate live elapsed: base time + time since last sync
1453
+ const timeSinceSync = (Date.now() - lastSyncTime) / 1000;
1454
+ const liveElapsed = baseElapsedTime + timeSinceSync;
1455
+ const mins = Math.floor(liveElapsed / 60);
1456
+ const secs = Math.floor(liveElapsed % 60);
1457
+ document.getElementById('stat-elapsed').textContent = `${{mins}}m ${{secs}}s`;
1458
+
1459
+ // Update ETA countdown
1460
+ if (etaSeconds > 0) {{
1461
+ const liveEta = Math.max(0, etaSeconds - timeSinceSync);
1462
+ const etaMins = Math.floor(liveEta / 60);
1463
+ const etaSecs = Math.floor(liveEta % 60);
1464
+ document.getElementById('stat-eta').textContent = `${{etaMins}}m ${{etaSecs}}s`;
1465
+ }}
1466
+
1467
+ // Update cost display
1468
+ updateCostDisplay();
1469
+ }}
1470
+
1471
+ function updateStatusIndicator() {{
1472
+ const timeSinceUpdate = (Date.now() - lastSuccessfulFetch) / 1000;
1473
+ const statusEl = document.getElementById('status');
1474
+ const statusText = document.getElementById('status-text');
1475
+
1476
+ if (timeSinceUpdate > STALE_THRESHOLD_SECONDS) {{
1477
+ statusEl.className = 'status stale';
1478
+ const staleMins = Math.floor(timeSinceUpdate / 60);
1479
+ const staleSecs = Math.floor(timeSinceUpdate % 60);
1480
+ if (staleMins > 0) {{
1481
+ statusText.innerHTML = `STALE <span class="stale-warning">(no update for ${{staleMins}}m ${{staleSecs}}s)</span>`;
1482
+ }} else {{
1483
+ statusText.innerHTML = `STALE <span class="stale-warning">(no update for ${{staleSecs}}s)</span>`;
1484
+ }}
1485
+ }} else {{
1486
+ statusEl.className = 'status';
1487
+ statusText.textContent = 'LIVE';
1488
+ }}
1489
+ }}
1490
+
1491
+ async function fetchAndUpdate() {{
1492
+ try {{
1493
+ const response = await fetch('training_log.json?t=' + Date.now());
1494
+ if (!response.ok) return;
1495
+
1496
+ const data = await response.json();
1497
+ lastSuccessfulFetch = Date.now();
1498
+
1499
+ // Check if job_id has changed - if so, reload to get fresh data
1500
+ if (data.job_id && data.job_id !== currentJobId) {{
1501
+ console.log(`Job changed from ${{currentJobId}} to ${{data.job_id}}, reloading...`);
1502
+ location.reload();
1503
+ return;
1504
+ }}
1505
+
1506
+ // Update job info display
1507
+ if (data.job_id) {{
1508
+ const jobIdEl = document.querySelector('.job-id');
1509
+ const jobHostEl = document.querySelector('.job-host');
1510
+ const jobConfigEl = document.querySelector('.job-config');
1511
+ if (jobIdEl) jobIdEl.textContent = `Job: ${{data.job_id}}`;
1512
+ if (jobHostEl) {{
1513
+ let hostText = data.hostname || 'local';
1514
+ if (data.instance_ip) hostText += ` @ ${{data.instance_ip}}`;
1515
+ jobHostEl.textContent = hostText;
1516
+ }}
1517
+ if (jobConfigEl && data.config_path) {{
1518
+ jobConfigEl.textContent = data.config_path;
1519
+ }}
1520
+ }}
1521
+
1522
+ // Update setup panel if setup logs present
1523
+ if (data.setup_logs && data.setup_logs.length > 0) {{
1524
+ const setupPanel = document.getElementById('setup-panel');
1525
+ const setupLogs = document.getElementById('setup-logs');
1526
+ const setupBadge = document.getElementById('setup-status-badge');
1527
+
1528
+ setupPanel.classList.remove('hidden');
1529
+
1530
+ // Update status badge
1531
+ if (data.setup_status) {{
1532
+ setupBadge.textContent = data.setup_status;
1533
+ setupBadge.className = `setup-status-badge ${{data.setup_status}}`;
1534
+ }}
1535
+
1536
+ // Update logs
1537
+ setupLogs.innerHTML = data.setup_logs.map((log, i) =>
1538
+ `<div class="setup-log-line${{i === data.setup_logs.length - 1 ? ' current' : ''}}">${{log}}</div>`
1539
+ ).join('');
1540
+
1541
+ // Auto-scroll to bottom
1542
+ setupLogs.scrollTop = setupLogs.scrollHeight;
1543
+
1544
+ // Hide setup panel when training starts
1545
+ if (data.setup_status === 'training' || data.setup_status === 'complete') {{
1546
+ setTimeout(() => setupPanel.classList.add('hidden'), 3000);
1547
+ }}
1548
+ }}
1549
+
1550
+ // Always update elapsed time base
1551
+ if (data.elapsed_time) {{
1552
+ baseElapsedTime = data.elapsed_time;
1553
+ lastSyncTime = Date.now();
1554
+ }}
1555
+
1556
+ // Check for termination status (handles completed/stopped states)
1557
+ if (data.termination_status && !isTrainingComplete) {{
1558
+ isTrainingComplete = true;
1559
+ document.getElementById('stat-eta').textContent = 'Complete';
1560
+ document.getElementById('eta-detail').textContent = '';
1561
+ updateTerminationStatus(data);
1562
+ updateCostDisplay();
1563
+ }}
1564
+
1565
+ // Only update other stats if step changed
1566
+ if (data.step !== lastStep) {{
1567
+ // Update with animation
1568
+ const cards = document.querySelectorAll('.stat-card');
1569
+ cards.forEach(c => c.classList.add('updating'));
1570
+ setTimeout(() => cards.forEach(c => c.classList.remove('updating')), 300);
1571
+
1572
+ // Update stats
1573
+ const totalEpochs = data.total_epochs || {config.num_train_epochs};
1574
+ const displayEpoch = Math.min(data.epoch + 1, totalEpochs); // Cap at max
1575
+ document.getElementById('stat-epoch').textContent = `${{displayEpoch}} / ${{totalEpochs}}`;
1576
+ document.getElementById('epoch-progress').style.width = `${{(displayEpoch / totalEpochs) * 100}}%`;
1577
+ document.getElementById('stat-step').textContent = data.step;
1578
+ document.getElementById('stat-loss').textContent = data.loss.toFixed(4);
1579
+
1580
+ // Loss delta
1581
+ const delta = data.loss - lastLoss;
1582
+ const deltaEl = document.getElementById('loss-delta');
1583
+ if (delta < 0) {{
1584
+ deltaEl.textContent = `↓ ${{Math.abs(delta).toFixed(4)}}`;
1585
+ deltaEl.className = 'stat-delta positive';
1586
+ }} else {{
1587
+ deltaEl.textContent = `↑ ${{delta.toFixed(4)}}`;
1588
+ deltaEl.className = 'stat-delta negative';
1589
+ }}
1590
+
1591
+ // AUTO-STOP: Trigger stop when loss <= threshold and training is running
1592
+ if (!autoStopTriggered && !isTrainingComplete && data.loss <= AUTO_STOP_LOSS_THRESHOLD) {{
1593
+ autoStopTriggered = true;
1594
+ console.log(`Auto-stop triggered: loss ${{data.loss.toFixed(4)}} <= threshold ${{AUTO_STOP_LOSS_THRESHOLD}}`);
1595
+
1596
+ // Show notification
1597
+ const notif = document.createElement('div');
1598
+ notif.className = 'auto-stop-notification';
1599
+ notif.innerHTML = `
1600
+ <strong>Auto-Stop Triggered</strong><br>
1601
+ Loss ${{data.loss.toFixed(4)}} ≤ ${{AUTO_STOP_LOSS_THRESHOLD}} threshold.<br>
1602
+ Stopping training...
1603
+ `;
1604
+ notif.style.cssText = `
1605
+ position: fixed; top: 20px; right: 20px; z-index: 9999;
1606
+ background: #2d4a3e; color: #4ade80; padding: 15px 20px;
1607
+ border-radius: 8px; border: 1px solid #4ade80;
1608
+ box-shadow: 0 4px 12px rgba(0,0,0,0.3);
1609
+ animation: slideIn 0.3s ease;
1610
+ `;
1611
+ document.body.appendChild(notif);
1612
+
1613
+ // Call stop endpoint
1614
+ fetch('/api/stop', {{ method: 'POST' }})
1615
+ .then(r => r.json())
1616
+ .then(result => {{
1617
+ console.log('Stop result:', result);
1618
+ setTimeout(() => notif.remove(), 5000);
1619
+ }})
1620
+ .catch(err => {{
1621
+ console.error('Stop failed:', err);
1622
+ notif.innerHTML += '<br><span style="color:#f87171">Stop request failed</span>';
1623
+ }});
1624
+ }}
1625
+
1626
+ // Other stats
1627
+ if (data.losses && data.losses.length > 0) {{
1628
+ const minLoss = Math.min(...data.losses.map(l => l.loss));
1629
+ document.getElementById('stat-min-loss').textContent = minLoss.toFixed(4);
1630
+
1631
+ const recentLosses = data.losses.slice(-10);
1632
+ const avgLoss = recentLosses.reduce((a, b) => a + b.loss, 0) / recentLosses.length;
1633
+ document.getElementById('stat-avg-loss').textContent = avgLoss.toFixed(4);
1634
+
1635
+ // Calculate avg step time and update ETA
1636
+ if (data.losses.length > 1) {{
1637
+ let stepTimes = [];
1638
+ for (let i = 1; i < data.losses.length; i++) {{
1639
+ stepTimes.push(data.losses[i].time - data.losses[i-1].time);
1640
+ }}
1641
+ avgStepTime = stepTimes.reduce((a,b) => a+b, 0) / stepTimes.length;
1642
+ document.getElementById('stat-step-time').textContent = avgStepTime.toFixed(1) + 's';
1643
+
1644
+ // Recalculate ETA
1645
+ const totalEpochs = data.total_epochs || {config.num_train_epochs};
1646
+ const currentEpoch = data.epoch;
1647
+ const stepsInCompletedEpochs = data.losses.filter(l => l.epoch < currentEpoch).length;
1648
+ const stepsPerEpoch = currentEpoch > 0 && stepsInCompletedEpochs > 0
1649
+ ? stepsInCompletedEpochs / currentEpoch
1650
+ : data.losses.length / (currentEpoch + 1);
1651
+ const totalStepsEstimate = stepsPerEpoch * totalEpochs;
1652
+ remainingSteps = Math.max(0, totalStepsEstimate - data.losses.length);
1653
+ etaSeconds = remainingSteps * avgStepTime;
1654
+
1655
+ // Update ETA display
1656
+ if (etaSeconds > 0) {{
1657
+ const etaMins = Math.floor(etaSeconds / 60);
1658
+ const etaSecs = Math.floor(etaSeconds % 60);
1659
+ document.getElementById('stat-eta').textContent = `${{etaMins}}m ${{etaSecs}}s`;
1660
+ document.getElementById('eta-detail').textContent = `~${{Math.round(remainingSteps)}} steps @ ${{avgStepTime.toFixed(1)}}s/step`;
1661
+ }} else if (data.losses.length > 0) {{
1662
+ // Training complete - stop elapsed timer and update UI
1663
+ isTrainingComplete = true;
1664
+ document.getElementById('stat-eta').textContent = 'Complete';
1665
+ document.getElementById('eta-detail').textContent = '';
1666
+ // Update cost display one final time
1667
+ updateCostDisplay();
1668
+ // Replace stop button with termination status
1669
+ updateTerminationStatus(data);
1670
+ }} else {{
1671
+ // No data yet
1672
+ document.getElementById('stat-eta').textContent = 'calculating...';
1673
+ }}
1674
+ }}
1675
+
1676
+ // Update charts
1677
+ losses = data.losses;
1678
+ lossChart.data.labels = losses.map(l => l.step);
1679
+ lossChart.data.datasets[0].data = losses.map(l => l.loss);
1680
+ lossChart.data.datasets[0].pointRadius = losses.length > 50 ? 0 : 3;
1681
+ lossChart.update('none');
1682
+
1683
+ // Recalculate epoch averages
1684
+ const epochLosses = {{}};
1685
+ losses.forEach(l => {{
1686
+ if (!epochLosses[l.epoch]) epochLosses[l.epoch] = [];
1687
+ epochLosses[l.epoch].push(l.loss);
1688
+ }});
1689
+ epochAvg = Object.entries(epochLosses).map(([ep, arr]) => [parseInt(ep), arr.reduce((a,b) => a+b, 0) / arr.length]);
1690
+ epochChart.data.labels = epochAvg.map(e => `Epoch ${{e[0] + 1}}`);
1691
+ epochChart.data.datasets[0].data = epochAvg.map(e => e[1]);
1692
+ epochChart.update('none');
1693
+
1694
+ updateTrend();
1695
+ }}
1696
+
1697
+ lastStep = data.step;
1698
+ lastLoss = data.loss;
1699
+ }}
1700
+
1701
+ // Update evaluations if present
1702
+ if (data.evaluations && data.evaluations.length > 0) {{
1703
+ renderEvaluations(data.evaluations);
1704
+ }}
1705
+
1706
+ document.getElementById('update-indicator').textContent = 'Last updated: just now';
1707
+ }} catch (e) {{
1708
+ console.log('Update failed:', e);
1709
+ }}
1710
+ }}
1711
+
1712
+ function renderEvaluations(evaluations) {{
1713
+ const panel = document.getElementById('eval-panel');
1714
+ const gallery = document.getElementById('eval-gallery');
1715
+ const metrics = document.getElementById('eval-metrics');
1716
+
1717
+ if (evaluations.length === 0) {{
1718
+ panel.style.display = 'none';
1719
+ return;
1720
+ }}
1721
+
1722
+ panel.style.display = 'block';
1723
+
1724
+ // Calculate metrics
1725
+ const correctCount = evaluations.filter(e => e.correct).length;
1726
+ const avgDistance = evaluations.reduce((a, e) => a + e.distance, 0) / evaluations.length;
1727
+ const accuracy = (correctCount / evaluations.length * 100).toFixed(1);
1728
+
1729
+ metrics.innerHTML = `
1730
+ <div class="metric">
1731
+ <span class="metric-label">Accuracy</span>
1732
+ <span class="metric-value">${{accuracy}}%</span>
1733
+ </div>
1734
+ <div class="metric">
1735
+ <span class="metric-label">Avg Distance</span>
1736
+ <span class="metric-value">${{avgDistance.toFixed(1)}}px</span>
1737
+ </div>
1738
+ <div class="metric">
1739
+ <span class="metric-label">Samples</span>
1740
+ <span class="metric-value">${{evaluations.length}}</span>
1741
+ </div>
1742
+ <div class="legend" style="display: flex; gap: 16px; margin-left: auto; font-size: 0.75rem; align-items: center;">
1743
+ <span style="display: flex; align-items: center; gap: 4px;">
1744
+ <span style="width: 12px; height: 12px; border-radius: 50%; background: rgba(52, 211, 153, 0.8);"></span>
1745
+ Human
1746
+ </span>
1747
+ <span style="display: flex; align-items: center; gap: 4px;">
1748
+ <span style="width: 12px; height: 12px; border-radius: 50%; background: rgba(167, 139, 250, 0.8);"></span>
1749
+ Predicted
1750
+ </span>
1751
+ </div>
1752
+ `;
1753
+
1754
+ // Render gallery (show last 9 evaluations)
1755
+ const recentEvals = evaluations.slice(-9);
1756
+ gallery.innerHTML = recentEvals.map((ev, i) => {{
1757
+ const statusClass = ev.correct ? 'correct' : 'incorrect';
1758
+ const statusText = ev.correct ? '✓ Correct' : '✗ Off by ' + (ev.distance * 100).toFixed(1) + '%';
1759
+ const humanX = (ev.human_action.x || 0).toFixed(3);
1760
+ const humanY = (ev.human_action.y || 0).toFixed(3);
1761
+ const predX = (ev.predicted_action.x || 0).toFixed(3);
1762
+ const predY = (ev.predicted_action.y || 0).toFixed(3);
1763
+ const rawOutput = ev.predicted_action.raw_output || '';
1764
+ const thoughtMatch = rawOutput.match(/Thought:([\\s\\S]*?)(?:Action:|$)/);
1765
+ const thought = thoughtMatch ? thoughtMatch[1].trim().substring(0, 200) : '';
1766
+ const sampleId = 'eval-' + ev.epoch + '-' + ev.sample_idx;
1767
+ return `
1768
+ <div class="eval-sample">
1769
+ <div style="position: relative;">
1770
+ <img src="${{ev.image_path}}" alt="Sample ${{ev.sample_idx}}" onerror="this.style.display='none'">
1771
+ <div class="overlay" style="width: 100%; height: 100%;">
1772
+ <div class="marker human" style="left: ${{(ev.human_action.x || 0) * 100}}%; top: ${{(ev.human_action.y || 0) * 100}}%;" title="Human"></div>
1773
+ <div class="marker predicted" style="left: ${{(ev.predicted_action.x || 0) * 100}}%; top: ${{(ev.predicted_action.y || 0) * 100}}%;" title="Predicted"></div>
1774
+ </div>
1775
+ </div>
1776
+ <div class="info">
1777
+ <span class="${{statusClass}}">${{statusText}}</span>
1778
+ <span> | Epoch ${{ev.epoch + 1}}</span>
1779
+ </div>
1780
+ <div class="details">
1781
+ <div class="coords">
1782
+ <span class="human-coord">Human: (${{humanX}}, ${{humanY}})</span>
1783
+ <span class="pred-coord">Pred: (${{predX}}, ${{predY}})</span>
1784
+ </div>
1785
+ </div>
1786
+ ${{thought ? `
1787
+ <div class="thinking">${{thought}}${{thought.length >= 200 ? '...' : ''}}</div>
1788
+ ` : ''}}
1789
+ </div>
1790
+ `;
1791
+ }}).join('');
1792
+ }}
1793
+
1794
+ // Terminal output management
1795
+ let terminalAutoScroll = true;
1796
+ let terminalWrap = false;
1797
+ let terminalCollapsed = false;
1798
+ let lastTerminalSize = 0;
1799
+ const MAX_TERMINAL_LINES = 500;
1800
+
1801
+ function toggleTerminal() {{
1802
+ const container = document.getElementById('terminal-container');
1803
+ const toggleText = document.getElementById('terminal-toggle-text');
1804
+ terminalCollapsed = !terminalCollapsed;
1805
+
1806
+ if (terminalCollapsed) {{
1807
+ container.classList.add('collapsed');
1808
+ toggleText.textContent = 'Expand';
1809
+ }} else {{
1810
+ container.classList.remove('collapsed');
1811
+ toggleText.textContent = 'Collapse';
1812
+ }}
1813
+ }}
1814
+
1815
+ function toggleAutoScroll() {{
1816
+ terminalAutoScroll = !terminalAutoScroll;
1817
+ const btn = document.getElementById('auto-scroll-btn');
1818
+ if (terminalAutoScroll) {{
1819
+ btn.classList.add('active');
1820
+ scrollTerminalToBottom();
1821
+ }} else {{
1822
+ btn.classList.remove('active');
1823
+ }}
1824
+ }}
1825
+
1826
+ function toggleWrap() {{
1827
+ terminalWrap = !terminalWrap;
1828
+ const btn = document.getElementById('wrap-btn');
1829
+ const output = document.getElementById('terminal-output');
1830
+ if (terminalWrap) {{
1831
+ btn.classList.add('active');
1832
+ output.style.whiteSpace = 'pre-wrap';
1833
+ }} else {{
1834
+ btn.classList.remove('active');
1835
+ output.style.whiteSpace = 'pre';
1836
+ }}
1837
+ }}
1838
+
1839
+ function scrollTerminalToBottom() {{
1840
+ const container = document.getElementById('terminal-container');
1841
+ container.scrollTop = container.scrollHeight;
1842
+ }}
1843
+
1844
+ async function fetchTerminalOutput() {{
1845
+ try {{
1846
+ const response = await fetch('training.log?t=' + Date.now());
1847
+ if (!response.ok) {{
1848
+ // File doesn't exist yet
1849
+ return;
1850
+ }}
1851
+
1852
+ const text = await response.text();
1853
+ const lines = text.trim().split('\\n');
1854
+
1855
+ // Keep only last MAX_TERMINAL_LINES
1856
+ const displayLines = lines.slice(-MAX_TERMINAL_LINES);
1857
+
1858
+ const output = document.getElementById('terminal-output');
1859
+ const lineCount = document.getElementById('terminal-line-count');
1860
+
1861
+ // Update line count
1862
+ lineCount.textContent = lines.length;
1863
+
1864
+ // Only update if content changed
1865
+ if (displayLines.length === 0) {{
1866
+ output.innerHTML = '<div class="terminal-empty">Waiting for training output...</div>';
1867
+ return;
1868
+ }}
1869
+
1870
+ // Format lines with basic syntax highlighting
1871
+ const formattedLines = displayLines.map(line => {{
1872
+ let className = 'terminal-line';
1873
+
1874
+ // Detect line type
1875
+ if (line.match(/^\\d{{4}}-\\d{{2}}-\\d{{2}}/)) {{
1876
+ className += ' timestamp';
1877
+ }} else if (line.toLowerCase().includes('error') || line.toLowerCase().includes('failed')) {{
1878
+ className += ' error';
1879
+ }} else if (line.toLowerCase().includes('success') || line.toLowerCase().includes('complete')) {{
1880
+ className += ' success';
1881
+ }} else if (line.toLowerCase().includes('warning')) {{
1882
+ className += ' warning';
1883
+ }}
1884
+
1885
+ // Escape HTML
1886
+ const escaped = line
1887
+ .replace(/&/g, '&amp;')
1888
+ .replace(/</g, '&lt;')
1889
+ .replace(/>/g, '&gt;');
1890
+
1891
+ return `<div class="${{className}}">${{escaped}}</div>`;
1892
+ }}).join('');
1893
+
1894
+ output.innerHTML = formattedLines;
1895
+
1896
+ // Auto-scroll if enabled and new content arrived
1897
+ if (terminalAutoScroll && lines.length > lastTerminalSize) {{
1898
+ scrollTerminalToBottom();
1899
+ }}
1900
+
1901
+ lastTerminalSize = lines.length;
1902
+ }} catch (err) {{
1903
+ console.error('Failed to fetch terminal output:', err);
1904
+ }}
1905
+ }}
1906
+
1907
+ initCharts();
1908
+ updateCostDisplay(); // Initialize cost display
1909
+ fetchAndUpdate(); // Initial fetch on page load
1910
+ fetchTerminalOutput(); // Initial terminal fetch
1911
+ setInterval(fetchAndUpdate, 3000);
1912
+ setInterval(fetchTerminalOutput, 2000); // Poll terminal output every 2 seconds
1913
+ setInterval(updateElapsedDisplay, 1000); // Update elapsed time every second
1914
+ setInterval(updateStatusIndicator, 1000); // Update LIVE/STALE indicator every second
1915
+ </script>
1916
+ </body>
1917
+ </html>'''
1918
+ return html
1919
+
1920
+
1921
def regenerate_all_dashboards(output_dir: str | Path) -> list[Path]:
    """Regenerate every dashboard file under *output_dir* with static navigation.

    Refreshes ``dashboard.html`` (when a ``training_log.json`` is present) and
    produces the unified ``viewer.html``. Legacy ``comparison_*.html`` files are
    left on disk but are no longer linked from the navigation.

    Args:
        output_dir: Directory containing dashboard files.

    Returns:
        List of paths to the files that were successfully regenerated.
    """
    root = Path(output_dir)
    results: list[Path] = []

    # Navigation is now a fixed pair of links (Training + Viewer).
    links = _build_nav_links()

    # Refresh the main training dashboard only when log data exists.
    if (root / "training_log.json").exists():
        try:
            results.append(regenerate_local_dashboard(root, nav_links=links))
        except Exception as e:
            print(f"Warning: Failed to regenerate dashboard: {e}")

    # Build the unified viewer when capture data is available; a missing or
    # broken viewer must not prevent returning the dashboard path above.
    try:
        viewer = generate_unified_viewer_from_output_dir(root)
    except Exception as e:
        print(f"Warning: Failed to generate unified viewer: {e}")
        import traceback
        traceback.print_exc()
    else:
        if viewer:
            results.append(viewer)

    return results
1958
+
1959
+
1960
def regenerate_local_dashboard(
    output_dir: str | Path,
    capture_path: str | Path | None = None,
    checkpoint_path: str | Path | None = None,
    nav_links: list[tuple[str, str]] | None = None,
) -> Path:
    """Regenerate dashboard.html with correct local paths and static navigation.

    This should be called after downloading training results from a remote instance.
    It fixes:
    - Training status (COMPLETED/STOPPED instead of always LIVE)
    - Navigation links to sibling dashboards (comparison, viewer)
    - Local capture path for comparison preview

    Args:
        output_dir: Directory containing training_log.json and dashboard files
        capture_path: Local path to capture directory (for comparison preview).
            Overrides the capture path recorded in the log when provided.
        checkpoint_path: Local path to checkpoint directory.
            NOTE(review): accepted but never referenced in this function body —
            confirm whether it should feed into the generated HTML.
        nav_links: Pre-built list of (filename, label) tuples for consistency

    Returns:
        Path to generated dashboard.html

    Raises:
        FileNotFoundError: If output_dir does not contain training_log.json.
    """
    output_dir = Path(output_dir)
    log_file = output_dir / "training_log.json"

    if not log_file.exists():
        raise FileNotFoundError(f"No training_log.json found in {output_dir}")

    # Load training state from log
    with open(log_file) as f:
        data = json.load(f)

    # Create state from log data.  A caller-supplied capture_path takes
    # precedence over the (possibly remote) path recorded in the log.
    state = TrainingState(
        job_id=data.get("job_id", "unknown"),
        hostname=data.get("hostname", ""),
        capture_path=str(capture_path) if capture_path else data.get("capture_path", ""),
        config_path=data.get("config_path", ""),
        epoch=data.get("epoch", 0),
        step=data.get("step", 0),
        loss=data.get("loss", 0),
        learning_rate=data.get("learning_rate", 0),
        total_epochs=data.get("total_epochs", 5),
        instance_type=data.get("instance_type", ""),
        instance_ip=data.get("instance_ip", ""),
        elapsed_time=data.get("elapsed_time", 0.0),
        cloud_provider=data.get("cloud_provider", ""),
    )
    state.losses = data.get("losses", [])
    state.evaluations = data.get("evaluations", [])

    # Determine training status from the last recorded epoch:
    # all epochs reached -> COMPLETED; some losses logged -> STOPPED mid-run;
    # no losses at all -> NOT_STARTED.
    total_epochs = data.get("total_epochs", 5)
    current_epoch = data.get("epoch", 0)

    if current_epoch + 1 >= total_epochs:
        training_status = "COMPLETED"
    elif len(state.losses) > 0:
        training_status = "STOPPED"
    else:
        training_status = "NOT_STARTED"

    # Use provided nav_links or build them.
    # NOTE(review): nav_links is resolved here but not referenced again below —
    # the header emitted by generate_training_dashboard() is already static.
    if nav_links is None:
        nav_links = _build_nav_links()

    # Create config
    config = TrainingConfig(
        num_train_epochs=total_epochs,
        learning_rate=data.get("learning_rate", 5e-5),
    )

    # Generate dashboard HTML with modifications
    html = generate_training_dashboard(state, config)

    # Replace dynamic status with static status.  These .replace() calls rely
    # on the exact markup emitted by generate_training_dashboard(); keep the
    # needles in sync with that template.
    if training_status == "COMPLETED":
        html = html.replace(
            '<div class="status" id="status">',
            '<div class="status complete" id="status">'
        )
        html = html.replace(
            '<span id="status-text">Training in progress</span>',
            '<span id="status-text">COMPLETED</span>'
        )
    elif training_status == "STOPPED":
        html = html.replace(
            '<div class="status" id="status">',
            '<div class="status stale" id="status">'
        )
        html = html.replace(
            '<span id="status-text">Training in progress</span>',
            '<span id="status-text">STOPPED (Epoch {}/{})'.format(current_epoch + 1, total_epochs) + '</span>'
        )

    # Fix ETA display for completed/stopped training
    import re
    if training_status in ("COMPLETED", "STOPPED"):
        # Replace "calculating..." with appropriate status
        # (em-dash for STOPPED, "complete" for COMPLETED).
        html = re.sub(
            r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
            r'\1—\2' if training_status == "STOPPED" else r'\1complete\2',
            html
        )

    # Replace dynamic nav with static unified header
    # The dashboard now uses the shared unified-header, so we just need to ensure
    # the header HTML is present (it's already generated by generate_training_dashboard)

    # Disable the JS polling and dynamic discovery (training is done, no need to fetch updates)
    # This is critical for file:// protocol where fetch() doesn't work
    html = html.replace(
        "setInterval(fetchAndUpdate, 3000);",
        "// fetchAndUpdate disabled for static dashboard"
    )
    html = html.replace(
        "setInterval(updateElapsedDisplay, 1000);",
        "// updateElapsedDisplay disabled for static dashboard"
    )
    html = html.replace(
        "setInterval(updateStatusIndicator, 1000);",
        "// updateStatusIndicator disabled for static dashboard"
    )
    # CRITICAL: Disable discoverDashboards() - it overwrites static nav on file:// protocol
    html = html.replace(
        "discoverDashboards();",
        "// discoverDashboards disabled - using static nav for file:// protocol"
    )

    # Write output
    dashboard_path = output_dir / "dashboard.html"
    dashboard_path.write_text(html)
    print(f"Regenerated dashboard: {dashboard_path}")

    return dashboard_path
2096
+
2097
+
2098
def run_epoch_evaluation(
    adapter: BaseVLMAdapter,
    episode: Episode,
    epoch: int,
    config: TrainingConfig,
    logger: "TrainingLogger",
    sample_indices: Optional[List[int]] = None,
) -> Path:
    """Run inference evaluation on sample steps after an epoch.

    This generates a comparison_epoch{N}.html file showing human vs predicted actions.

    Args:
        adapter: Trained adapter to use for inference
        episode: Episode with steps to evaluate
        epoch: Current epoch number
        config: Training configuration
        logger: Training logger for state tracking
        sample_indices: Specific step indices to evaluate (default: evenly spaced)

    Returns:
        Path to generated comparison HTML file
    """
    from openadapt_ml.scripts.compare import generate_comparison_html, predict_action, format_action

    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Select sample indices if not provided
    num_samples = min(config.eval_samples, len(episode.steps))
    if sample_indices is None:
        if num_samples >= len(episode.steps):
            sample_indices = list(range(len(episode.steps)))
        elif num_samples <= 0:
            # FIX: config.eval_samples == 0 with a non-empty episode previously
            # raised ZeroDivisionError in the floor division below.  Skip
            # inference entirely but still emit the comparison page with the
            # human actions only.
            sample_indices = []
        else:
            # Evenly space samples across the episode
            step_size = len(episode.steps) // num_samples
            sample_indices = [i * step_size for i in range(num_samples)]

    # O(1) membership test for the per-step loop below.
    sample_lookup = set(sample_indices)

    print(f"  Running inference on {len(sample_indices)} sample steps...")

    # Switch adapter to eval mode
    adapter.eval()

    comparison_data = []
    action_history: List[str] = []
    total_steps = len(episode.steps)

    for i, step in enumerate(episode.steps):
        step_data = {
            "index": i,
            "time": step.t,
            "image_path": step.observation.image_path,
            "human_action": {
                "type": step.action.type,
                "x": step.action.x,
                "y": step.action.y,
                "text": step.action.text,
            },
            "predicted_action": None,
            "match": None,
        }

        # Only run inference on selected samples (for speed)
        if i in sample_lookup and step.observation.image_path:
            try:
                predicted = predict_action(
                    adapter,
                    step.observation.image_path,
                    episode.goal,
                    step_index=i,
                    total_steps=total_steps,
                    action_history=action_history.copy(),
                )
                step_data["predicted_action"] = predicted

                # A prediction "matches" when the action type agrees.
                if predicted and predicted.get("type") == step.action.type:
                    step_data["match"] = True

                    # Log click evaluations to training state.  Distance is
                    # derived downstream from the raw actions (the dashboard
                    # reads ev.distance from logged evaluations).
                    # FIX: the old inline distance computation here was unused
                    # and raised TypeError when a predicted coordinate was
                    # None, which spuriously marked the step as
                    # "inference failed" and skipped log_evaluation.
                    if step.action.type == "click":
                        logger.state.log_evaluation(
                            epoch=epoch,
                            sample_idx=i,
                            image_path=step.observation.image_path,
                            human_action=step_data["human_action"],
                            predicted_action=predicted,
                        )
                else:
                    step_data["match"] = False

                print(f"    Step {i}: {step.action.type} -> {predicted.get('type') if predicted else 'none'}")

            except Exception as e:
                print(f"    Step {i}: inference failed - {e}")

        # Build action history for context
        action_history.append(format_action(step.action, use_som=False))
        comparison_data.append(step_data)

    # Switch back to train mode
    adapter.train()

    # Generate comparison HTML
    output_path = output_dir / f"comparison_epoch{epoch}.html"
    capture_path = Path(logger.state.capture_path) if logger.state.capture_path else Path(".")

    generate_comparison_html(capture_path, episode, comparison_data, output_path)
    print(f"  Comparison saved: {output_path}")

    # Also regenerate all dashboards to update navigation
    regenerate_all_dashboards(output_dir)

    return output_path
2217
+
2218
+
2219
+ def _create_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
2220
+ # Use an identity collate_fn so that each batch is a List[Dict], matching
2221
+ # the expectations of adapters that operate on SFT-style samples.
2222
+ return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
2223
+
2224
+
2225
def _create_lr_scheduler(
    optimizer: Optimizer,
    config: TrainingConfig,
    num_training_steps: int,
) -> Optional[LambdaLR]:
    """Build the learning-rate schedule requested by *config*.

    Both supported schedules share a linear warmup over the first
    ``num_training_steps * config.warmup_ratio`` steps, then decay either
    linearly or with a half-cosine down to zero.

    Args:
        optimizer: The optimizer to schedule.
        config: Training configuration with lr_scheduler_type and warmup_ratio.
        num_training_steps: Total number of training steps.

    Returns:
        LambdaLR scheduler, or None when scheduler_type is "none"/"constant".

    Raises:
        ValueError: For an unrecognized lr_scheduler_type.
    """
    kind = config.lr_scheduler_type.lower()

    # Constant LR needs no scheduler at all.
    if kind in ("none", "constant"):
        return None

    warmup_steps = int(num_training_steps * config.warmup_ratio)

    def _warmup_factor(step: int) -> float:
        # Linear ramp from 0 toward 1; max(1, ...) guards warmup_steps == 0.
        return float(step) / float(max(1, warmup_steps))

    if kind == "linear":
        def schedule(step: int) -> float:
            if step < warmup_steps:
                return _warmup_factor(step)
            # Linear decay from 1.0 at warmup end to 0.0 at the final step.
            remaining = float(num_training_steps - step)
            span = float(max(1, num_training_steps - warmup_steps))
            return max(0.0, remaining / span)
    elif kind == "cosine":
        import math

        def schedule(step: int) -> float:
            if step < warmup_steps:
                return _warmup_factor(step)
            # Half-cosine decay from 1.0 to 0.0 over the post-warmup span.
            span = float(max(1, num_training_steps - warmup_steps))
            progress = float(step - warmup_steps) / span
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    else:
        raise ValueError(f"Unknown lr_scheduler_type: {kind}. Use 'linear', 'cosine', 'constant', or 'none'.")

    return LambdaLR(optimizer, schedule)
2270
+
2271
+
2272
def train_supervised(
    adapter: BaseVLMAdapter,
    dataset: Dataset,
    config: TrainingConfig,
    optimizer: Optional[Optimizer] = None,
    logger: Optional[TrainingLogger] = None,
    episode: Optional[Episode] = None,
) -> bool:
    """Minimal supervised training loop skeleton.

    This assumes that `adapter.prepare_inputs` and `adapter.compute_loss` are
    implemented. It will raise if those methods are not implemented.

    Supports gradient accumulation, optional LR scheduling, per-epoch
    checkpointing/evaluation, early stopping on sustained low loss, and a
    dashboard-driven user stop via a STOP_TRAINING file in the output dir.

    Args:
        adapter: VLM adapter to train.
        dataset: Training dataset.
        config: Training configuration.
        optimizer: Optional optimizer (default: AdamW).
        logger: Optional training logger for visualization.
        episode: Optional episode for periodic evaluation (generates comparison_epoch{N}.html).

    Returns:
        True if training finished (including early-stop and user-stop paths);
        False only when aborted due to a NaN/Inf loss.
    """

    device = adapter.device  # type: ignore[attr-defined]
    dataloader = _create_dataloader(dataset, batch_size=config.per_device_train_batch_size)

    if optimizer is None:
        optimizer = torch.optim.AdamW(
            adapter.model.parameters(),  # type: ignore[arg-type]
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

    # Create logger if not provided
    if logger is None:
        logger = TrainingLogger(config.output_dir, config)

    # Calculate total training steps for scheduler.
    # Floor division means the scheduler horizon slightly undershoots when
    # len(dataloader) * epochs is not divisible by gradient_accumulation_steps.
    num_training_steps = len(dataloader) * config.num_train_epochs // config.gradient_accumulation_steps

    # Create learning rate scheduler
    lr_scheduler = _create_lr_scheduler(optimizer, config, num_training_steps)

    total_steps = 0
    adapter.train()

    # Early stopping tracking
    consecutive_low_loss = 0
    early_stopped = False
    user_stopped = False

    for epoch in range(config.num_train_epochs):
        if early_stopped or user_stopped:
            break

        for _, batch in enumerate(dataloader):
            # Check for stop signal from dashboard.  The file is polled once
            # per batch so a stop takes effect at the next batch boundary.
            stop_file = Path(config.output_dir) / "STOP_TRAINING"
            if stop_file.exists():
                msg = "Stop signal received from dashboard. Stopping training..."
                print(msg)
                logger._log_to_terminal(msg)
                # Set termination status for dashboard
                logger.state.termination_status = "user_stop"
                logger.state.termination_message = "Training stopped by user via dashboard"
                logger.save()
                user_stopped = True
                stop_file.unlink()  # Remove signal file
                # NOTE: breaking here still runs the epoch-end checkpoint/eval
                # code below for the interrupted epoch before the outer loop exits.
                break

            # Batch is a List[Dict[str, Any]] of SFT-style samples; adapter is
            # responsible for converting it into model inputs.
            samples: List[Dict[str, Any]] = batch

            inputs = adapter.prepare_inputs(samples)
            # Move only tensor values to the training device; leave other
            # entries (strings, metadata) untouched.
            inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

            loss = adapter.compute_loss(inputs)

            # Guard against invalid losses to avoid propagating NaNs/Infs.
            # This aborts the whole run (returns False) rather than skipping
            # the batch, since subsequent gradients would be poisoned.
            if torch.isnan(loss) or torch.isinf(loss):
                msg = f"Encountered invalid loss at epoch={epoch} step={total_steps + 1}: {loss.item()}"
                print(msg)
                logger._log_to_terminal(msg)
                logger.on_train_end()
                return False

            loss.backward()

            # Step optimizer/scheduler only every gradient_accumulation_steps
            # batches.  NOTE(review): gradients accumulated in a trailing
            # partial window (run ends between steps) are never applied.
            if (total_steps + 1) % config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(adapter.model.parameters(), config.max_grad_norm)  # type: ignore[arg-type]
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()
                optimizer.zero_grad()

            total_steps += 1
            loss_val = loss.item()

            # Get current learning rate from optimizer
            current_lr = optimizer.param_groups[0]['lr']

            # Log step
            logger.on_step(epoch, total_steps, loss_val, current_lr)

            if config.logging_steps and total_steps % config.logging_steps == 0:
                msg = f"epoch={epoch} step={total_steps} loss={loss_val:.4f} lr={current_lr:.6f}"
                print(msg)
                logger._log_to_terminal(msg)

            # Early stopping check: require early_stop_patience CONSECUTIVE
            # sub-threshold losses; a single higher loss resets the counter.
            if loss_val < config.early_stop_loss:
                consecutive_low_loss += 1
                if consecutive_low_loss >= config.early_stop_patience:
                    msg = (
                        f"Early stopping: loss ({loss_val:.6f}) below threshold "
                        f"({config.early_stop_loss}) for {config.early_stop_patience} consecutive steps"
                    )
                    print(msg)
                    logger._log_to_terminal(msg)
                    # Set termination status for dashboard
                    logger.state.termination_status = "auto_low_loss"
                    logger.state.termination_message = (
                        f"Loss reached {loss_val:.6f} (< {config.early_stop_loss}) "
                        f"for {config.early_stop_patience} consecutive steps"
                    )
                    logger.save()
                    early_stopped = True
                    break
            else:
                consecutive_low_loss = 0

        # End of epoch
        logger.on_epoch_end(epoch)

        # Save checkpoint at end of each epoch (best-effort: a failed save is
        # logged as a warning and training continues).
        if config.save_checkpoint_every_epoch:
            checkpoint_path = Path(config.checkpoint_dir) / f"epoch_{epoch}"
            checkpoint_path.mkdir(parents=True, exist_ok=True)
            try:
                adapter.save_checkpoint(str(checkpoint_path))
                msg = f"Checkpoint saved to {checkpoint_path}"
                print(msg)
                logger._log_to_terminal(msg)
            except Exception as e:
                msg = f"Warning: Failed to save checkpoint: {e}"
                print(msg)
                logger._log_to_terminal(msg)

        # Run evaluation after each epoch (generates comparison_epoch{N}.html).
        # Also best-effort: evaluation errors are printed but never abort training.
        if config.eval_every_epoch and episode is not None:
            try:
                print(f"Running epoch {epoch} evaluation...")
                run_epoch_evaluation(
                    adapter=adapter,
                    episode=episode,
                    epoch=epoch,
                    config=config,
                    logger=logger,
                )
            except Exception as e:
                print(f"Warning: Epoch evaluation failed: {e}")
                import traceback
                traceback.print_exc()

    # Set termination status if not already set (normal completion);
    # user_stop / auto_low_loss paths already wrote their status above.
    if not logger.state.termination_status:
        logger.state.termination_status = "auto_complete"
        logger.state.termination_message = f"Training completed all {config.num_train_epochs} epochs"
        logger.save()

    logger.on_train_end()
    return True