openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/training/trainer.py
@@ -0,0 +1,2446 @@
from __future__ import annotations

import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset

from openadapt_ml.models.base_adapter import BaseVLMAdapter
from openadapt_ml.schemas.sessions import Episode
from openadapt_ml.training.shared_ui import (
    get_shared_header_css as _get_shared_header_css,
    generate_shared_header_html as _generate_shared_header_html,
    build_nav_links as _build_nav_links,
)
from openadapt_ml.training.viewer import (
    generate_unified_viewer_from_output_dir,
)


def setup_job_directory(base_dir: str | Path, job_id: str) -> Path:
    """Set up job-scoped directory structure with symlink.

    Creates:
        {base_dir}/{job_id}/ - Job-specific directory
        {base_dir}/current - Symlink to current job directory

    Args:
        base_dir: Base output directory (e.g., "training_output")
        job_id: Unique job identifier (e.g., "20251214_200417")

    Returns:
        Path to the job-specific directory
    """
    base_dir = Path(base_dir)
    job_dir = base_dir / job_id
    current_link = base_dir / "current"

    # Create base and job directories
    base_dir.mkdir(parents=True, exist_ok=True)
    job_dir.mkdir(parents=True, exist_ok=True)

    # Atomically update the 'current' symlink
    # Use a temp link then rename for atomic operation
    temp_link = base_dir / f".current_temp_{job_id}"
    try:
        # Remove temp link if it exists from a previous failed attempt
        if temp_link.exists() or temp_link.is_symlink():
            temp_link.unlink()

        # Create temp symlink pointing to job_id (relative path)
        temp_link.symlink_to(job_id)

        # Atomically replace current with temp
        temp_link.rename(current_link)
    except Exception as e:
        # Clean up temp link on failure
        if temp_link.exists() or temp_link.is_symlink():
            temp_link.unlink()
        raise RuntimeError(f"Failed to create current symlink: {e}")

    return job_dir


def get_current_job_directory(base_dir: str | Path) -> Path | None:
    """Get the current job directory from symlink.

    Returns:
        Path to current job directory, or None if no current symlink
    """
    base_dir = Path(base_dir)
    current_link = base_dir / "current"

    if current_link.is_symlink():
        return current_link.resolve()
    return None
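
A minimal usage sketch for the two helpers above (illustrative only, not part of the packaged file; the base directory and job id are arbitrary examples):

    from pathlib import Path
    from openadapt_ml.training.trainer import setup_job_directory, get_current_job_directory

    job_dir = setup_job_directory("training_output", "20251214_200417")
    # training_output/20251214_200417/ now exists, and
    # training_output/current is a relative symlink pointing at 20251214_200417
    assert job_dir == Path("training_output") / "20251214_200417"

    current = get_current_job_directory("training_output")
    # Resolved (absolute) path of the active job directory, or None if no
    # "current" symlink exists yet.
    print(current)
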
@dataclass
class TrainingConfig:
    # Model / LoRA-related fields are handled elsewhere; this covers loop hyperparams.
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 1
    learning_rate: float = 2e-4
    warmup_ratio: float = 0.03
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    logging_steps: int = 10
    # Learning rate scheduler
    lr_scheduler_type: str = "linear"  # Options: linear, cosine, constant, none
    # Early stopping: stop when loss is below threshold for patience consecutive steps
    early_stop_loss: float = 1e-4
    early_stop_patience: int = 10
    # Output directory for logs and visualizations
    output_dir: str = "training_output"
    # Checkpoint saving
    save_checkpoint_every_epoch: bool = True
    checkpoint_dir: str = "checkpoints"
    # Evaluation during training
    eval_every_epoch: bool = True
    eval_samples: int = 3  # Number of samples to evaluate per epoch
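
As a quick illustration (not part of the packaged file), a config for a short fine-tuning run might be constructed like this; every value below is an arbitrary example, only the field names come from the dataclass above:

    from openadapt_ml.training.trainer import TrainingConfig

    config = TrainingConfig(
        num_train_epochs=3,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,   # effective batch size of 8
        learning_rate=2e-4,
        lr_scheduler_type="cosine",      # one of: linear, cosine, constant, none
        early_stop_loss=1e-4,            # stop once loss stays below this...
        early_stop_patience=10,          # ...for 10 consecutive steps
        output_dir="training_output",
    )
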
@dataclass
class TrainingState:
    """Tracks training progress for visualization."""
    # Job identification
    job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
    hostname: str = field(default_factory=lambda: __import__('socket').gethostname())
    capture_path: str = ""
    config_path: str = ""
    goal: str = ""  # Task goal/description for the training run
    # Training progress
    epoch: int = 0
    step: int = 0
    total_steps: int = 0
    total_epochs: int = 1  # Set by logger from config
    loss: float = 0.0
    learning_rate: float = 0.0
    samples_seen: int = 0
    start_time: float = field(default_factory=time.time)
    elapsed_time: float = 0.0  # For historical data loaded from JSON
    losses: List[Dict[str, Any]] = field(default_factory=list)
    evaluations: List[Dict[str, Any]] = field(default_factory=list)
    # Cloud info (optional)
    instance_type: str = ""
    instance_ip: str = ""
    # Cloud provider info (for dashboard link)
    cloud_provider: str = ""  # e.g. "lambda", "azure"
    cloud_dashboard_url: str = ""  # e.g. "https://cloud.lambda.ai/instances"
    cloud_instance_id: str = ""  # Provider-specific instance ID
    # Setup status tracking
    setup_status: str = ""  # e.g. "booting", "installing", "training", "complete"
    setup_logs: List[str] = field(default_factory=list)  # Setup progress messages
    # Termination tracking
    termination_status: str = ""  # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
    termination_message: str = ""  # Human-readable termination reason

    def log_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
        """Log a training step."""
        self.epoch = epoch
        self.step = step
        self.loss = loss
        self.learning_rate = lr
        self.losses.append({
            "epoch": epoch,
            "step": step,
            "loss": loss,
            "lr": lr,
            "time": time.time() - self.start_time,
        })

    def log_evaluation(self, epoch: int, sample_idx: int, image_path: str,
                       human_action: Dict, predicted_action: Dict) -> None:
        """Log an evaluation sample."""
        # Calculate distance for click actions
        distance = 0.0
        if human_action.get("type") == "click" and predicted_action.get("type") == "click":
            hx, hy = human_action.get("x", 0), human_action.get("y", 0)
            px, py = predicted_action.get("x", 0), predicted_action.get("y", 0)
            distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5

        self.evaluations.append({
            "epoch": epoch,
            "sample_idx": sample_idx,
            "image_path": image_path,
            "human_action": human_action,
            "predicted_action": predicted_action,
            "distance": distance,
            "correct": distance < 50,  # Within 50 pixels is "correct"
        })

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to serializable dict."""
        return {
            # Job metadata
            "job_id": self.job_id,
            "hostname": self.hostname,
            "capture_path": self.capture_path,
            "config_path": self.config_path,
            "goal": self.goal,
            "instance_type": self.instance_type,
            "instance_ip": self.instance_ip,
            "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)),
            # Cloud provider info
            "cloud_provider": self.cloud_provider,
            "cloud_dashboard_url": self.cloud_dashboard_url,
            "cloud_instance_id": self.cloud_instance_id,
            "setup_status": self.setup_status,
            "setup_logs": self.setup_logs,
            # Training progress
            "epoch": self.epoch,
            "step": self.step,
            "total_steps": self.total_steps,
            "total_epochs": self.total_epochs,
            "loss": self.loss,
            "learning_rate": self.learning_rate,
            "samples_seen": self.samples_seen,
            "elapsed_time": time.time() - self.start_time,
            "losses": self.losses,
            "evaluations": self.evaluations,
            # Termination tracking
            "termination_status": self.termination_status,
            "termination_message": self.termination_message,
        }
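
A sketch of how TrainingState accumulates data (illustrative values, not part of the packaged file): each log_step appends to losses, and log_evaluation scores a predicted click as correct only when it lands within 50 px of the human click.

    import json
    from openadapt_ml.training.trainer import TrainingState

    state = TrainingState(job_id="demo", goal="log in to the app")

    state.log_step(epoch=0, step=1, loss=2.31, lr=2e-4)
    state.log_step(epoch=0, step=2, loss=1.87, lr=2e-4)

    state.log_evaluation(
        epoch=0,
        sample_idx=0,
        image_path="screenshots/step_000.png",
        human_action={"type": "click", "x": 100, "y": 120},
        predicted_action={"type": "click", "x": 130, "y": 160},
    )
    # distance = sqrt(30**2 + 40**2) = 50.0, so "correct" is False (threshold is < 50)

    print(json.dumps(state.to_dict(), indent=2)[:200])  # serializable snapshot
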
class TrainingLogger:
    """Logs training progress and generates visualization."""

    def __init__(
        self,
        output_dir: str | Path,
        config: TrainingConfig,
        capture_path: str = "",
        config_path: str = "",
        goal: str = "",
        instance_ip: str = "",
        instance_type: str = "",
        cloud_provider: str = "",
        cloud_dashboard_url: str = "",
        cloud_instance_id: str = "",
        job_id: str = "",
    ):
        # Generate job_id if not provided
        if not job_id:
            job_id = time.strftime("%Y%m%d_%H%M%S")

        # Set up job-scoped directory with symlink
        base_dir = Path(output_dir)
        self.base_dir = base_dir
        self.output_dir = setup_job_directory(base_dir, job_id)
        self.config = config
        self.state = TrainingState(
            job_id=job_id,
            capture_path=capture_path,
            config_path=config_path,
            goal=goal,
            instance_ip=instance_ip,
            instance_type=instance_type,
            total_epochs=config.num_train_epochs,
            cloud_provider=cloud_provider,
            cloud_dashboard_url=cloud_dashboard_url,
            cloud_instance_id=cloud_instance_id,
        )
        self.log_file = self.output_dir / "training_log.json"
        self.terminal_log_file = self.output_dir / "training.log"
        self.terminal_log_handle = None

        # Save config snapshot
        self._save_config_snapshot()

    def _log_to_terminal(self, message: str):
        """Write message to training.log file.

        Args:
            message: Message to log
        """
        from datetime import datetime

        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_line = f"[{timestamp}] {message}"

        # Open file on first write (line buffered)
        if self.terminal_log_handle is None:
            self.terminal_log_handle = open(self.terminal_log_file, "w", buffering=1)

        self.terminal_log_handle.write(log_line + "\n")
        self.terminal_log_handle.flush()

    def on_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
        """Called after each training step."""
        self.state.log_step(epoch, step, loss, lr)
        self._save_log()

    def on_epoch_end(self, epoch: int) -> None:
        """Called at the end of each epoch."""
        self.state.epoch = epoch
        self._save_log()
        self._generate_dashboard()

    def on_train_end(self) -> None:
        """Called at the end of training."""
        self._save_log()
        self._generate_dashboard()
        print(f"Training dashboard: {self.output_dir / 'dashboard.html'}")

        # Close terminal log file
        if self.terminal_log_handle:
            self.terminal_log_handle.close()
            self.terminal_log_handle = None

    def _save_config_snapshot(self) -> None:
        """Save training config snapshot to JSON."""
        from dataclasses import asdict
        config_file = self.output_dir / "config.json"
        config_dict = asdict(self.config)
        with open(config_file, "w") as f:
            json.dump(config_dict, f, indent=2)

    def _save_log(self) -> None:
        """Save training log to JSON."""
        with open(self.log_file, "w") as f:
            json.dump(self.state.to_dict(), f, indent=2)

    def _generate_dashboard(self) -> None:
        """Generate HTML training dashboard."""
        dashboard_path = self.output_dir / "dashboard.html"
        html = generate_training_dashboard(self.state, self.config)
        dashboard_path.write_text(html)
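
Putting the logger together (illustrative sketch, not part of the packaged file; the inner loop is a stand-in for a real training step): the logger writes config.json, training_log.json, and dashboard.html under training_output/<job_id>/, with training_output/current pointing at the active job.

    from openadapt_ml.training.trainer import TrainingConfig, TrainingLogger

    config = TrainingConfig(num_train_epochs=2, output_dir="training_output")
    logger = TrainingLogger("training_output", config, goal="demo run")

    step = 0
    for epoch in range(config.num_train_epochs):
        for batch_loss in (1.9, 1.4, 0.9):  # placeholder losses
            step += 1
            logger.on_step(epoch, step, loss=batch_loss, lr=config.learning_rate)
        logger.on_epoch_end(epoch)      # refreshes dashboard.html
    logger.on_train_end()               # final dashboard + closes training.log handle
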
def _generate_termination_status_html(state: TrainingState, is_training_complete: bool) -> str:
    """Generate HTML for termination status section."""
    # Check if we have termination info
    if state.termination_status:
        # Map termination status to colors and icons
        status_styles = {
            "auto_complete": {"color": "#22c55e", "icon": "✓", "label": "Training Complete"},
            "auto_low_loss": {"color": "#22c55e", "icon": "✓", "label": "Auto-Stopped (Low Loss)"},
            "user_stop": {"color": "#f59e0b", "icon": "■", "label": "Stopped by User"},
        }
        style = status_styles.get(state.termination_status, {"color": "#22c55e", "icon": "✓", "label": "Complete"})

        return f'''<div style="display: flex; flex-direction: column; gap: 8px;">
            <div style="display: flex; align-items: center; gap: 8px; color: {style['color']};">
                <span style="font-size: 1.2rem;">{style['icon']}</span>
                <span style="font-weight: 600;">{style['label']}</span>
            </div>
            {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ''}
        </div>'''
    elif is_training_complete:
        return '''<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
            <span style="font-size: 1.2rem;">✓</span>
            <span style="font-weight: 600;">Training Complete</span>
        </div>'''
    else:
        return '''<button id="stop-training-btn" onclick="stopTraining()" style="
            background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
            color: white;
            border: none;
            padding: 12px 24px;
            border-radius: 8px;
            font-size: 0.9rem;
            font-weight: 600;
            cursor: pointer;
            display: flex;
            align-items: center;
            gap: 8px;
            transition: all 0.2s;
        ">
            <span style="font-size: 1.1rem;">■</span> Stop Training
        </button>
        <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>'''


def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
    """Generate an HTML dashboard for training visualization."""
    losses_json = json.dumps(state.losses)
    # Use stored elapsed_time if available (historical data), otherwise calculate
    elapsed = state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
    elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s"

    # Calculate stats
    if state.losses:
        min_loss = min(l["loss"] for l in state.losses)
        avg_loss = sum(l["loss"] for l in state.losses) / len(state.losses)
        recent_losses = state.losses[-10:] if len(state.losses) >= 10 else state.losses
        recent_avg = sum(l["loss"] for l in recent_losses) / len(recent_losses)
        # Calculate step times
        step_times = []
        for i in range(1, len(state.losses)):
            step_times.append(state.losses[i]["time"] - state.losses[i-1]["time"])
        avg_step_time = sum(step_times) / len(step_times) if step_times else 0
        # Loss by epoch
        epoch_losses: dict = {}
        for l in state.losses:
            ep = l["epoch"]
            if ep not in epoch_losses:
                epoch_losses[ep] = []
            epoch_losses[ep].append(l["loss"])
        epoch_avg = {ep: sum(losses)/len(losses) for ep, losses in epoch_losses.items()}
        # Estimate ETA
        # Steps per epoch = steps in completed epochs / completed epochs
        completed_epochs = state.epoch
        steps_in_completed = sum(1 for l in state.losses if l["epoch"] < completed_epochs)
        if completed_epochs > 0 and steps_in_completed > 0:
            steps_per_epoch = steps_in_completed / completed_epochs
        else:
            # Estimate from current epoch progress
            steps_per_epoch = len(state.losses) / (state.epoch + 1) if state.epoch >= 0 else len(state.losses)

        total_epochs = state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
        total_steps_estimate = steps_per_epoch * total_epochs
        remaining_steps = max(0, total_steps_estimate - len(state.losses))
        eta_seconds = remaining_steps * avg_step_time if avg_step_time > 0 else 0
        # Check if training is complete (all steps done)
        is_training_complete = remaining_steps == 0 and len(state.losses) > 0
    else:
        min_loss = avg_loss = recent_avg = avg_step_time = 0.0
        epoch_avg = {}
        eta_seconds = 0
        steps_per_epoch = 0
        total_steps_estimate = 0
        remaining_steps = 0
        is_training_complete = False
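    # Worked example of the ETA estimate above (annotation with hypothetical
    # figures): with 2 completed epochs, 40 logged steps inside those epochs and
    # 60 steps logged overall, steps_per_epoch = 40 / 2 = 20; for total_epochs = 5,
    # total_steps_estimate = 100, remaining_steps = max(0, 100 - 60) = 40, and with
    # avg_step_time = 1.5s the ETA is 40 * 1.5 = 60 seconds.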
    epoch_avg_json = json.dumps(list(epoch_avg.items()))

    # Generate comparison viewer preview if capture path available
    comparison_viewer_path = ""
    if state.capture_path:
        try:
            from openadapt_ml.scripts.compare import generate_comparison_html, generate_comparison_data
            from openadapt_ml.ingest.capture import capture_to_episode

            capture_path = Path(state.capture_path)
            if capture_path.exists():
                # Load episode from capture
                episode = capture_to_episode(capture_path)

                # Generate comparison data with null predictions (shows "— No prediction")
                comparison_data = []
                for i, step in enumerate(episode.steps):
                    step_data = {
                        "index": i,
                        "time": step.t,
                        "image_path": step.observation.image_path,
                        "human_action": {
                            "type": step.action.type,
                            "x": step.action.x,
                            "y": step.action.y,
                            "text": step.action.text,
                        },
                        "predicted_action": None,  # Shows "— No prediction" in viewer
                        "match": None,
                    }
                    comparison_data.append(step_data)

                # Generate comparison HTML
                output_dir = Path(config.output_dir) if hasattr(config, 'output_dir') else Path("training_output")
                output_dir.mkdir(parents=True, exist_ok=True)
                comparison_output = output_dir / "comparison_preview.html"
                generate_comparison_html(capture_path, episode, comparison_data, comparison_output)
                comparison_viewer_path = str(comparison_output.name)  # Relative path
        except Exception as e:
            pass  # Fail silently if comparison viewer can't be generated

    html = f'''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Training Dashboard</title>
    <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
    <style>
        :root {{
            --bg-primary: #0a0a0f;
            --bg-secondary: #12121a;
            --bg-tertiary: #1a1a24;
            --border-color: rgba(255, 255, 255, 0.06);
            --text-primary: #f0f0f0;
            --text-secondary: #888;
            --accent: #00d4aa;
            --accent-secondary: #a78bfa;
        }}
        * {{ box-sizing: border-box; margin: 0; padding: 0; }}
        body {{
            font-family: -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
            background: var(--bg-primary);
            color: var(--text-primary);
            min-height: 100vh;
        }}
        .container {{ max-width: 1400px; margin: 0 auto; padding: 24px; }}
        header {{
            display: flex;
            justify-content: space-between;
            align-items: center;
            padding: 20px 24px;
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            margin-bottom: 24px;
        }}
        header h1 {{ font-size: 1.3rem; font-weight: 600; }}
        .job-info {{
            display: flex;
            gap: 16px;
            margin-top: 4px;
            font-size: 0.75rem;
            color: var(--text-secondary);
        }}
        .job-id {{
            font-family: "SF Mono", Monaco, monospace;
            color: var(--accent);
        }}
        .job-host {{
            font-family: "SF Mono", Monaco, monospace;
        }}
        .job-config {{
            font-family: "SF Mono", Monaco, monospace;
            opacity: 0.7;
        }}
        .cloud-link {{
            display: inline-flex;
            align-items: center;
            gap: 6px;
            padding: 6px 12px;
            background: var(--bg-tertiary);
            border: 1px solid var(--border-color);
            border-radius: 6px;
            font-size: 0.75rem;
            color: var(--text-primary);
            text-decoration: none;
            transition: all 0.2s;
        }}
        .cloud-link:hover {{
            border-color: var(--accent);
            background: rgba(0, 212, 170, 0.1);
        }}
        .cloud-link svg {{
            width: 14px;
            height: 14px;
        }}
        .cloud-badge {{
            background: linear-gradient(135deg, rgba(167, 139, 250, 0.2), rgba(0, 212, 170, 0.1));
            border-color: rgba(167, 139, 250, 0.3);
            margin-left: 12px;
        }}
        .setup-panel {{
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
            margin-bottom: 24px;
        }}
        .setup-panel.hidden {{
            display: none;
        }}
        .setup-header {{
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 12px;
        }}
        .setup-header h2 {{
            font-size: 0.9rem;
        }}
        .setup-status-badge {{
            display: inline-flex;
            align-items: center;
            gap: 6px;
            padding: 4px 10px;
            border-radius: 12px;
            font-size: 0.7rem;
            text-transform: uppercase;
            letter-spacing: 0.05em;
            font-weight: 600;
        }}
        .setup-status-badge.booting {{
            background: rgba(255, 149, 0, 0.2);
            color: #ff9500;
        }}
        .setup-status-badge.installing {{
            background: rgba(167, 139, 250, 0.2);
            color: #a78bfa;
        }}
        .setup-status-badge.training {{
            background: rgba(0, 212, 170, 0.2);
            color: #00d4aa;
        }}
        .setup-status-badge.complete {{
            background: rgba(52, 211, 153, 0.2);
            color: #34d399;
        }}
        .setup-logs {{
            background: var(--bg-tertiary);
            border-radius: 8px;
            padding: 12px;
            max-height: 200px;
            overflow-y: auto;
            font-family: "SF Mono", Monaco, monospace;
            font-size: 0.7rem;
            line-height: 1.6;
        }}
        .setup-log-line {{
            color: var(--text-secondary);
            padding: 2px 0;
        }}
        .setup-log-line.current {{
            color: var(--accent);
        }}
        .status {{
            display: flex;
            align-items: center;
            gap: 8px;
            color: var(--accent);
        }}
        .status-dot {{
            width: 10px;
            height: 10px;
            background: var(--accent);
            border-radius: 50%;
            animation: pulse 2s infinite;
        }}
        .status.complete .status-dot {{
            animation: none;
            background: #34d399;
        }}
        .status.stale {{
            color: #ff9500;
        }}
        .status.stale .status-dot {{
            animation: none;
            background: #ff9500;
        }}
        .stale-warning {{
            font-size: 0.7rem;
            color: #ff9500;
            margin-top: 2px;
        }}
        @keyframes pulse {{
            0%, 100% {{ opacity: 1; }}
            50% {{ opacity: 0.4; }}
        }}
        .stats-grid {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(160px, 1fr));
            gap: 16px;
            margin-bottom: 24px;
        }}
        .stat-card {{
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
            transition: all 0.3s ease;
        }}
        .stat-card.updating {{
            border-color: var(--accent);
            box-shadow: 0 0 20px rgba(0, 212, 170, 0.1);
        }}
        .stat-label {{
            font-size: 0.75rem;
            color: var(--text-secondary);
            text-transform: uppercase;
            letter-spacing: 0.05em;
            margin-bottom: 8px;
        }}
        .stat-detail {{
            font-size: 0.65rem;
            color: var(--text-secondary);
            margin-top: 4px;
        }}
        .eta-card {{
            background: linear-gradient(135deg, rgba(167, 139, 250, 0.1), rgba(0, 212, 170, 0.05));
            border-color: rgba(167, 139, 250, 0.3);
        }}
        .stat-value {{
            font-size: 1.6rem;
            font-weight: 600;
            font-family: "SF Mono", Monaco, monospace;
            transition: all 0.3s ease;
        }}
        .stat-value.accent {{ color: var(--accent); }}
        .stat-delta {{
            font-size: 0.75rem;
            margin-top: 4px;
            font-family: "SF Mono", Monaco, monospace;
        }}
        .stat-delta.positive {{ color: #34d399; }}
        .stat-delta.negative {{ color: #ff5f5f; }}
        .charts-grid {{
            display: grid;
            grid-template-columns: 2fr 1fr;
            gap: 16px;
            margin-bottom: 24px;
        }}
        @media (max-width: 900px) {{
            .charts-grid {{ grid-template-columns: 1fr; }}
        }}
        .chart-container {{
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 24px;
        }}
        .chart-title {{
            font-size: 0.9rem;
            font-weight: 600;
            margin-bottom: 16px;
            display: flex;
            justify-content: space-between;
            align-items: center;
        }}
        .chart-subtitle {{
            font-size: 0.75rem;
            color: var(--text-secondary);
            font-weight: normal;
        }}
        .chart-wrapper {{
            height: 300px;
            position: relative;
        }}
        .config-panel {{
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
        }}
        .config-panel h2 {{
            font-size: 0.9rem;
            margin-bottom: 16px;
        }}
        .config-grid {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(140px, 1fr));
            gap: 12px;
        }}
        .config-item {{
            font-size: 0.8rem;
        }}
        .config-item .key {{
            color: var(--text-secondary);
        }}
        .config-item .value {{
            font-family: "SF Mono", Monaco, monospace;
            color: var(--accent);
        }}
        .progress-bar {{
            height: 4px;
            background: var(--bg-tertiary);
            border-radius: 2px;
            margin-top: 8px;
            overflow: hidden;
        }}
        .progress-fill {{
            height: 100%;
            background: linear-gradient(90deg, var(--accent), var(--accent-secondary));
            border-radius: 2px;
            transition: width 0.5s ease;
        }}
        .update-indicator {{
            font-size: 0.7rem;
            color: var(--text-secondary);
            text-align: right;
            margin-top: 16px;
        }}
        /* Shared header styles (injected from _get_shared_header_css) */
        {_get_shared_header_css()}
        .eval-panel {{
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
            margin-top: 16px;
        }}
        .eval-panel h2 {{
            font-size: 0.9rem;
            margin-bottom: 16px;
        }}
        .eval-metrics {{
            display: flex;
            gap: 24px;
            margin-bottom: 16px;
            font-size: 0.85rem;
        }}
        .eval-metrics .metric {{
            display: flex;
            flex-direction: column;
        }}
        .eval-metrics .metric-value {{
            font-size: 1.2rem;
            font-weight: 600;
            color: var(--accent);
        }}
        .eval-filters {{
            display: flex;
            gap: 16px;
            margin-bottom: 16px;
            align-items: center;
            flex-wrap: wrap;
        }}
        .eval-filters .filter-group {{
            display: flex;
            align-items: center;
            gap: 8px;
        }}
        .eval-filters label {{
            font-size: 0.75rem;
            color: var(--text-secondary);
            text-transform: uppercase;
            letter-spacing: 0.05em;
        }}
        .eval-filters select {{
            padding: 8px 32px 8px 12px;
            border-radius: 8px;
            font-size: 0.85rem;
            background: rgba(0,0,0,0.4);
            color: var(--text-primary);
            border: 1px solid rgba(255,255,255,0.1);
            cursor: pointer;
            appearance: none;
            background-image: url('data:image/svg+xml,%3Csvg xmlns=%27http://www.w3.org/2000/svg%27 width=%2712%27 height=%278%27%3E%3Cpath fill=%27%23888%27 d=%27M0 0l6 8 6-8z%27/%3E%3C/svg%3E');
            background-repeat: no-repeat;
            background-position: right 10px center;
            transition: all 0.2s;
        }}
        .eval-filters select:hover {{
            border-color: var(--accent);
            background-color: rgba(0,212,170,0.1);
        }}
        .eval-filters select:focus {{
            outline: none;
            border-color: var(--accent);
            box-shadow: 0 0 0 2px rgba(0,212,170,0.2);
        }}
        .eval-gallery {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(350px, 1fr));
            gap: 20px;
        }}
        .eval-sample {{
            background: var(--bg-tertiary);
            border-radius: 8px;
            padding: 0;
            position: relative;
            overflow: hidden;
            border: 1px solid var(--border-color);
        }}
        .eval-sample.hidden {{
            display: none;
        }}
        .eval-sample .image-container {{
            position: relative;
            background: #000;
            min-height: 200px;
            display: flex;
            align-items: center;
            justify-content: center;
        }}
        .eval-sample img {{
            width: 100%;
            height: auto;
            display: block;
            max-height: 400px;
            object-fit: contain;
        }}
        .eval-sample .overlay {{
            position: absolute;
            top: 0;
            left: 0;
            right: 0;
            bottom: 0;
            pointer-events: none;
        }}
        .eval-sample .marker {{
            position: absolute;
            width: 24px;
            height: 24px;
            border-radius: 50%;
            transform: translate(-50%, -50%);
            border: 3px solid white;
            display: flex;
            align-items: center;
            justify-content: center;
            font-size: 10px;
            font-weight: bold;
            color: white;
            z-index: 10;
        }}
        .eval-sample .marker.human {{
            background: rgba(0, 212, 170, 0.4);
            border-color: #00d4aa;
        }}
        .eval-sample .marker.human::after {{
            content: 'H';
            color: #00d4aa;
        }}
        .eval-sample .marker.predicted {{
            background: rgba(167, 139, 250, 0.4);
            border-color: #a78bfa;
        }}
        .eval-sample .marker.predicted::after {{
            content: 'AI';
            font-size: 9px;
            color: #a78bfa;
        }}
        .eval-sample .line {{
            position: absolute;
            height: 2px;
            background: rgba(255, 255, 255, 0.5);
            transform-origin: left center;
        }}
        .eval-sample .content {{
            padding: 12px;
        }}
        .eval-sample .info {{
            font-size: 0.75rem;
            color: var(--text-secondary);
            margin-bottom: 8px;
            padding-bottom: 8px;
            border-bottom: 1px solid var(--border-color);
        }}
        .eval-sample .info .correct {{
            color: #34d399;
            font-weight: 600;
        }}
        .eval-sample .info .incorrect {{
            color: #ff5f5f;
            font-weight: 600;
        }}
        .eval-sample .details {{
            font-size: 0.7rem;
            color: var(--text-secondary);
        }}
        .eval-sample .coords {{
            display: flex;
            flex-direction: column;
            gap: 4px;
            margin-bottom: 8px;
        }}
        .eval-sample .coords .human-coord {{
            color: #34d399;
        }}
        .eval-sample .coords .pred-coord {{
            color: #a78bfa;
        }}
        .eval-sample .thinking {{
            margin-top: 8px;
            padding: 8px;
            background: rgba(0,0,0,0.3);
            border-radius: 4px;
            font-size: 0.65rem;
            color: var(--text-secondary);
            max-height: 150px;
            overflow-y: auto;
            white-space: pre-wrap;
            word-break: break-word;
            font-family: "SF Mono", Monaco, monospace;
            line-height: 1.4;
        }}
        .eval-sample .thinking.collapsed {{
            max-height: 60px;
            overflow: hidden;
            position: relative;
        }}
        .eval-sample .thinking.collapsed::after {{
            content: '';
            position: absolute;
            bottom: 0;
            left: 0;
            right: 0;
            height: 30px;
            background: linear-gradient(to bottom, transparent, rgba(0,0,0,0.5));
        }}
        .eval-sample .thinking-toggle {{
            cursor: pointer;
            color: var(--accent);
            font-size: 0.7rem;
            margin-top: 4px;
            display: inline-block;
        }}
        .eval-sample .thinking-toggle:hover {{
            text-decoration: underline;
        }}
        .terminal-panel {{
            background: var(--bg-secondary);
            border: 1px solid var(--border-color);
            border-radius: 12px;
            padding: 20px;
            margin-top: 16px;
        }}
        .terminal-header {{
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 12px;
        }}
        .terminal-header h2 {{
            font-size: 0.9rem;
            margin: 0;
        }}
        .terminal-toggle {{
            background: var(--bg-tertiary);
            border: 1px solid var(--border-color);
            color: var(--text-primary);
            padding: 6px 12px;
            border-radius: 6px;
            font-size: 0.75rem;
            cursor: pointer;
            transition: all 0.2s;
        }}
        .terminal-toggle:hover {{
            border-color: var(--accent);
            background: rgba(0, 212, 170, 0.1);
        }}
        .terminal-container {{
            background: #000;
            border: 1px solid rgba(255, 255, 255, 0.1);
            border-radius: 8px;
            padding: 16px;
            max-height: 400px;
            overflow-y: auto;
            font-family: "SF Mono", Monaco, "Courier New", monospace;
            font-size: 0.7rem;
            line-height: 1.5;
            color: #0f0;
            position: relative;
        }}
        .terminal-container.collapsed {{
            max-height: 200px;
        }}
        .terminal-output {{
            white-space: pre-wrap;
            word-break: break-word;
        }}
        .terminal-line {{
            padding: 2px 0;
        }}
        .terminal-line.timestamp {{
            color: #888;
        }}
        .terminal-line.error {{
            color: #ff5f5f;
        }}
        .terminal-line.success {{
            color: #34d399;
        }}
        .terminal-line.warning {{
            color: #ff9500;
        }}
        .terminal-empty {{
            color: #888;
            font-style: italic;
            text-align: center;
            padding: 40px;
        }}
        .terminal-controls {{
            display: flex;
            gap: 8px;
            margin-bottom: 8px;
            font-size: 0.7rem;
        }}
        .terminal-control-btn {{
            background: var(--bg-tertiary);
            border: 1px solid var(--border-color);
            color: var(--text-secondary);
            padding: 4px 8px;
            border-radius: 4px;
            cursor: pointer;
            transition: all 0.2s;
        }}
        .terminal-control-btn:hover {{
            border-color: var(--accent);
            color: var(--text-primary);
        }}
        .terminal-control-btn.active {{
            border-color: var(--accent);
            color: var(--accent);
        }}
    </style>
</head>
<body>
    {_generate_shared_header_html("training", meta_html=f"Job: {state.job_id}")}

    <div class="container">
        <header>
            <div>
                <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ''}</h1>
                <div class="job-info" id="job-info">
                    <span class="job-host">{state.hostname or 'stub-local'} @ {state.instance_ip or '127.0.0.1'}</span>
                    {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ''}
                </div>
            </div>
            <div class="status" id="status">
                <div class="status-dot"></div>
                <span id="status-text">Training in progress</span>
            </div>
        </header>

        <div class="setup-panel{' hidden' if not state.setup_logs else ''}" id="setup-panel">
            <div class="setup-header">
                <h2>Setup Progress</h2>
                <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or 'initializing'}</span>
            </div>
            <div class="setup-logs" id="setup-logs">
                {''.join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
            </div>
        </div>

        <div class="stats-grid">
            <div class="stat-card" id="card-epoch">
                <div class="stat-label">Epoch Progress</div>
                <div class="stat-value" id="stat-epoch">{min(state.epoch + 1, config.num_train_epochs)} / {config.num_train_epochs}</div>
                <div class="progress-bar"><div class="progress-fill" id="epoch-progress" style="width: {(min(state.epoch + 1, config.num_train_epochs) / config.num_train_epochs) * 100}%"></div></div>
            </div>
            <div class="stat-card" id="card-step">
                <div class="stat-label">Steps</div>
                <div class="stat-value" id="stat-step">{state.step}</div>
            </div>
            <div class="stat-card" id="card-loss">
                <div class="stat-label">Current Loss</div>
                <div class="stat-value accent" id="stat-loss">{state.loss:.4f}</div>
                <div class="stat-delta" id="loss-delta"></div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Min Loss</div>
                <div class="stat-value" id="stat-min-loss">{min_loss:.4f}</div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Avg (last 10)</div>
                <div class="stat-value" id="stat-avg-loss">{recent_avg:.4f}</div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Avg Step Time</div>
                <div class="stat-value" id="stat-step-time">{avg_step_time:.1f}s</div>
            </div>
            <div class="stat-card">
                <div class="stat-label">Elapsed</div>
                <div class="stat-value" id="stat-elapsed">{elapsed_str}</div>
            </div>
            <div class="stat-card eta-card">
                <div class="stat-label">ETA</div>
                <div class="stat-value" id="stat-eta">{f"{int(eta_seconds // 60)}m {int(eta_seconds % 60)}s" if eta_seconds > 0 else ("Complete" if is_training_complete else "calculating...")}</div>
                <div class="stat-detail" id="eta-detail">{f"~{int(remaining_steps)} steps @ {avg_step_time:.1f}s/step" if remaining_steps > 0 else ""}</div>
            </div>
            <div class="stat-card" id="card-cost" style="background: linear-gradient(135deg, rgba(239, 68, 68, 0.1), rgba(220, 38, 38, 0.05)); border-color: rgba(239, 68, 68, 0.3);">
                <div class="stat-label">Cloud Cost</div>
                <div class="stat-value" id="stat-running-cost" style="color: #ef4444;">$0.00</div>
                <div class="stat-detail" id="stat-est-total">Est. Total: $0.00</div>
            </div>
        </div>

        <div class="charts-grid">
            <div class="chart-container">
                <div class="chart-title">
                    Loss Curve
                    <span class="chart-subtitle" id="loss-trend"></span>
                </div>
                <div class="chart-wrapper">
                    <canvas id="lossChart"></canvas>
                </div>
            </div>
            <div class="chart-container">
                <div class="chart-title">Loss by Epoch</div>
                <div class="chart-wrapper">
                    <canvas id="epochChart"></canvas>
                </div>
            </div>
        </div>

        <div class="config-panel">
            <h2>Training Configuration</h2>
            <div class="config-grid">
                <div class="config-item"><span class="key">Epochs:</span> <span class="value">{config.num_train_epochs}</span></div>
                <div class="config-item"><span class="key">Batch size:</span> <span class="value">{config.per_device_train_batch_size}</span></div>
                <div class="config-item"><span class="key">Learning rate:</span> <span class="value">{config.learning_rate}</span></div>
                <div class="config-item"><span class="key">Grad accum:</span> <span class="value">{config.gradient_accumulation_steps}</span></div>
                <div class="config-item"><span class="key">Max grad norm:</span> <span class="value">{config.max_grad_norm}</span></div>
                <div class="config-item"><span class="key">Early stop:</span> <span class="value">{config.early_stop_loss}</span></div>
            </div>
            <div id="stop-training-section" class="stop-training-section" style="margin-top: 16px; padding-top: 16px; border-top: 1px solid var(--border-color);">
                {_generate_termination_status_html(state, is_training_complete)}
            </div>
        </div>

        <div class="eval-panel" id="eval-panel" style="display: none;">
            <h2>Evaluation Samples</h2>
            <div class="eval-metrics" id="eval-metrics"></div>
            <div class="eval-filters">
                <div class="filter-group">
                    <label for="epoch-filter">Epoch:</label>
                    <select id="epoch-filter">
                        <option value="all">All Epochs</option>
                    </select>
                </div>
                <div class="filter-group">
                    <label for="correctness-filter">Status:</label>
                    <select id="correctness-filter">
                        <option value="all">All</option>
                        <option value="correct">Correct Only</option>
                        <option value="incorrect">Incorrect Only</option>
                    </select>
                </div>
                <div style="margin-left: auto; font-size: 0.75rem; color: var(--text-muted);">
                    <span id="filter-count"></span>
                </div>
            </div>
            <div class="eval-gallery" id="eval-gallery"></div>
        </div>

        <div class="terminal-panel" id="terminal-panel">
            <div class="terminal-header">
                <h2>Training Output</h2>
                <button class="terminal-toggle" id="terminal-toggle" onclick="toggleTerminal()">
                    <span id="terminal-toggle-text">Collapse</span>
                </button>
            </div>
            <div class="terminal-controls">
                <button class="terminal-control-btn active" id="auto-scroll-btn" onclick="toggleAutoScroll()">Auto-scroll</button>
                <button class="terminal-control-btn" id="wrap-btn" onclick="toggleWrap()">Wrap text</button>
                <span style="margin-left: auto; color: var(--text-secondary); font-size: 0.7rem;">
                    <span id="terminal-line-count">0</span> lines
                </span>
            </div>
            <div class="terminal-container" id="terminal-container">
                <div class="terminal-output" id="terminal-output">
                    <div class="terminal-empty">Waiting for training output...</div>
                </div>
            </div>
        </div>

        <div class="update-indicator" id="update-indicator">Last updated: just now</div>
    </div>

    <script>
        let losses = {losses_json};
        let epochAvg = {epoch_avg_json};
        let lossChart, epochChart;
        let lastStep = {state.step};
        let lastLoss = {state.loss};

        // Cloud cost tracking
        const instanceType = '{state.instance_type}';
        const COST_RATES = {{
            'gpu_1x_a10': 0.75, // Lambda Labs A10
            'gpu_8x_a100': 1.29, // Lambda Labs A100 (per GPU)
            'a10': 0.75, // Generic A10
            'a100': 1.29, // Generic A100
        }};

        function getHourlyRate(instanceType) {{
            // Try exact match first
            if (COST_RATES[instanceType.toLowerCase()]) {{
                return COST_RATES[instanceType.toLowerCase()];
            }}
            // Try partial match
            const typeStr = instanceType.toLowerCase();
            if (typeStr.includes('a100')) return COST_RATES['a100'];
            if (typeStr.includes('a10')) return COST_RATES['a10'];
            // Default to A10 rate
            return COST_RATES['a10'];
        }}

        function updateCostDisplay() {{
            // Only show costs for actual cloud training (not stub/local)
            if (!instanceType || instanceType === '' || instanceType === 'stub') {{
                document.getElementById('card-cost').style.display = 'none';
                return;
            }}

            const hourlyRate = getHourlyRate(instanceType);

            // Calculate running cost based on elapsed time
            const timeSinceSync = (Date.now() - lastSyncTime) / 1000;
            const liveElapsed = baseElapsedTime + timeSinceSync;
            const elapsedHours = liveElapsed / 3600;
            const runningCost = elapsedHours * hourlyRate;

            // Calculate estimated total cost
            let estimatedTotal = runningCost;
            if (etaSeconds > 0) {{
                const totalTimeSeconds = liveElapsed + etaSeconds;
                const totalHours = totalTimeSeconds / 3600;
                estimatedTotal = totalHours * hourlyRate;
            }}

            // Update display
            document.getElementById('stat-running-cost').textContent = `$${{runningCost.toFixed(2)}}`;
            document.getElementById('stat-est-total').textContent = `Est. Total: $${{estimatedTotal.toFixed(2)}}`;
        }}

        async function stopTraining() {{
            const btn = document.getElementById('stop-training-btn');
            const status = document.getElementById('stop-status');

            btn.disabled = true;
            btn.innerHTML = '<span style="font-size: 1.1rem;">⏳</span> Stopping...';
            btn.style.background = '#666';

            try {{
                // Try to create stop signal via API
                const response = await fetch('/api/stop', {{
                    method: 'POST',
                    headers: {{ 'Content-Type': 'application/json' }}
                }});

                if (response.ok) {{
                    btn.innerHTML = '<span style="font-size: 1.1rem;">✓</span> Stop Signal Sent';
                    btn.style.background = '#22c55e';
                    status.textContent = 'Training will stop after current step. Checkpoints will be downloaded.';
                    status.style.color = '#22c55e';
                }} else {{
                    throw new Error('Server returned ' + response.status);
                }}
            }} catch (e) {{
                // Fallback: show manual command
                btn.innerHTML = '<span style="font-size: 1.1rem;">!</span> Manual Stop Required';
                btn.style.background = '#f59e0b';
                status.innerHTML = 'Run this command to stop training:<br><code style="background: #1a1a24; padding: 4px 8px; border-radius: 4px; font-family: monospace;">touch training_output/STOP_TRAINING</code>';
                status.style.color = '#f59e0b';
            }}
        }}

        function updateTerminationStatus(data) {{
            const stopSection = document.getElementById('stop-training-section');
            if (!stopSection) return;

            const termStatus = data.termination_status || 'auto_complete';
            const termMessage = data.termination_message || '';

            const statusStyles = {{
                'auto_complete': {{ color: '#22c55e', icon: '✓', label: 'Training Complete' }},
                'auto_low_loss': {{ color: '#22c55e', icon: '✓', label: 'Auto-Stopped (Low Loss)' }},
                'user_stop': {{ color: '#f59e0b', icon: '■', label: 'Stopped by User' }},
            }};

            const style = statusStyles[termStatus] || statusStyles['auto_complete'];

            let html = `<div style="display: flex; flex-direction: column; gap: 8px;">
                <div style="display: flex; align-items: center; gap: 8px; color: ${{style.color}};">
                    <span style="font-size: 1.2rem;">${{style.icon}}</span>
                    <span style="font-weight: 600;">${{style.label}}</span>
                </div>`;

            if (termMessage) {{
                html += `<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">${{termMessage}}</div>`;
            }}

            html += '</div>';
            stopSection.innerHTML = html;
        }}

        function initCharts() {{
            const lossCtx = document.getElementById('lossChart').getContext('2d');
            lossChart = new Chart(lossCtx, {{
                type: 'line',
                data: {{
                    labels: losses.map(l => l.step),
                    datasets: [{{
                        label: 'Loss',
                        data: losses.map(l => l.loss),
                        borderColor: '#00d4aa',
                        backgroundColor: 'rgba(0, 212, 170, 0.1)',
                        fill: true,
                        tension: 0.3,
                        pointRadius: losses.length > 50 ? 0 : 3,
                    }}]
                }},
                options: {{
                    responsive: true,
                    maintainAspectRatio: false,
                    animation: {{ duration: 500 }},
                    scales: {{
                        x: {{
                            title: {{ display: true, text: 'Step', color: '#888' }},
                            grid: {{ color: 'rgba(255,255,255,0.05)' }},
                            ticks: {{ color: '#888' }}
                        }},
                        y: {{
                            title: {{ display: true, text: 'Loss', color: '#888' }},
                            grid: {{ color: 'rgba(255,255,255,0.05)' }},
                            ticks: {{ color: '#888' }}
                        }}
                    }},
                    plugins: {{ legend: {{ display: false }} }}
                }}
            }});

            const epochCtx = document.getElementById('epochChart').getContext('2d');
            epochChart = new Chart(epochCtx, {{
                type: 'bar',
                data: {{
                    labels: epochAvg.map(e => `Epoch ${{e[0] + 1}}`),
                    datasets: [{{
                        label: 'Avg Loss',
                        data: epochAvg.map(e => e[1]),
                        backgroundColor: 'rgba(167, 139, 250, 0.6)',
                        borderColor: '#a78bfa',
                        borderWidth: 1,
                    }}]
                }},
                options: {{
                    responsive: true,
                    maintainAspectRatio: false,
                    animation: {{ duration: 500 }},
                    scales: {{
                        y: {{
                            beginAtZero: false,
                            grid: {{ color: 'rgba(255,255,255,0.05)' }},
                            ticks: {{ color: '#888' }}
                        }},
                        x: {{
                            grid: {{ display: false }},
                            ticks: {{ color: '#888' }}
                        }}
                    }},
                    plugins: {{ legend: {{ display: false }} }}
                }}
            }});

            updateTrend();
        }}

        function updateTrend() {{
            if (losses.length >= 10) {{
                const recent = losses.slice(-10);
                const first = recent[0].loss;
                const last = recent[recent.length - 1].loss;
                const change = ((last - first) / first * 100).toFixed(1);
                const trendEl = document.getElementById('loss-trend');
                if (change < 0) {{
                    trendEl.textContent = `↓ ${{Math.abs(change)}}% (last 10)`;
                    trendEl.style.color = '#34d399';
                }} else {{
                    trendEl.textContent = `↑ ${{change}}% (last 10)`;
                    trendEl.style.color = '#ff5f5f';
                }}
            }}
        }}

        // Live elapsed timer variables
        let baseElapsedTime = {elapsed}; // Last known elapsed time from server
|
|
1431
|
+
let lastSyncTime = Date.now(); // When we last synced with server
|
|
1432
|
+
let lastSuccessfulFetch = Date.now(); // When we last got a successful response
|
|
1433
|
+
let currentJobId = '{state.job_id}'; // Current job ID
|
|
1434
|
+
const STALE_THRESHOLD_SECONDS = 30; // Consider stale after 30s without updates
|
|
1435
|
+
|
|
1436
|
+
// ETA tracking
|
|
1437
|
+
let etaSeconds = {eta_seconds};
|
|
1438
|
+
let avgStepTime = {avg_step_time};
|
|
1439
|
+
let remainingSteps = {remaining_steps};
|
|
1440
|
+
let isTrainingComplete = {'true' if is_training_complete else 'false'};
|
|
1441
|
+
|
|
1442
|
+
// Auto-stop when loss <= threshold (INVARIANT: training should stop when loss <= 1.0)
|
|
1443
|
+
const AUTO_STOP_LOSS_THRESHOLD = 1.0;
|
|
1444
|
+
let autoStopTriggered = false;
|
|
1445
|
+
|
|
1446
|
+
function updateElapsedDisplay() {{
|
|
1447
|
+
// Don't update elapsed if training is complete
|
|
1448
|
+
if (isTrainingComplete) {{
|
|
1449
|
+
return;
|
|
1450
|
+
}}
|
|
1451
|
+
|
|
1452
|
+
// Calculate live elapsed: base time + time since last sync
|
|
1453
|
+
const timeSinceSync = (Date.now() - lastSyncTime) / 1000;
|
|
1454
|
+
const liveElapsed = baseElapsedTime + timeSinceSync;
|
|
1455
|
+
const mins = Math.floor(liveElapsed / 60);
|
|
1456
|
+
const secs = Math.floor(liveElapsed % 60);
|
|
1457
|
+
document.getElementById('stat-elapsed').textContent = `${{mins}}m ${{secs}}s`;
|
|
1458
|
+
|
|
1459
|
+
// Update ETA countdown
|
|
1460
|
+
if (etaSeconds > 0) {{
|
|
1461
|
+
const liveEta = Math.max(0, etaSeconds - timeSinceSync);
|
|
1462
|
+
const etaMins = Math.floor(liveEta / 60);
|
|
1463
|
+
const etaSecs = Math.floor(liveEta % 60);
|
|
1464
|
+
document.getElementById('stat-eta').textContent = `${{etaMins}}m ${{etaSecs}}s`;
|
|
1465
|
+
}}
|
|
1466
|
+
|
|
1467
|
+
// Update cost display
|
|
1468
|
+
updateCostDisplay();
|
|
1469
|
+
}}
|
|
1470
|
+
|
|
1471
|
+
function updateStatusIndicator() {{
|
|
1472
|
+
const timeSinceUpdate = (Date.now() - lastSuccessfulFetch) / 1000;
|
|
1473
|
+
const statusEl = document.getElementById('status');
|
|
1474
|
+
const statusText = document.getElementById('status-text');
|
|
1475
|
+
|
|
1476
|
+
if (timeSinceUpdate > STALE_THRESHOLD_SECONDS) {{
|
|
1477
|
+
statusEl.className = 'status stale';
|
|
1478
|
+
const staleMins = Math.floor(timeSinceUpdate / 60);
|
|
1479
|
+
const staleSecs = Math.floor(timeSinceUpdate % 60);
|
|
1480
|
+
if (staleMins > 0) {{
|
|
1481
|
+
statusText.innerHTML = `STALE <span class="stale-warning">(no update for ${{staleMins}}m ${{staleSecs}}s)</span>`;
|
|
1482
|
+
}} else {{
|
|
1483
|
+
statusText.innerHTML = `STALE <span class="stale-warning">(no update for ${{staleSecs}}s)</span>`;
|
|
1484
|
+
}}
|
|
1485
|
+
}} else {{
|
|
1486
|
+
statusEl.className = 'status';
|
|
1487
|
+
statusText.textContent = 'LIVE';
|
|
1488
|
+
}}
|
|
1489
|
+
}}
|
|
1490
|
+
|
|
1491
|
+
async function fetchAndUpdate() {{
|
|
1492
|
+
try {{
|
|
1493
|
+
const response = await fetch('training_log.json?t=' + Date.now());
|
|
1494
|
+
if (!response.ok) return;
|
|
1495
|
+
|
|
1496
|
+
const data = await response.json();
|
|
1497
|
+
lastSuccessfulFetch = Date.now();
|
|
1498
|
+
|
|
1499
|
+
// Check if job_id has changed - if so, reload to get fresh data
|
|
1500
|
+
if (data.job_id && data.job_id !== currentJobId) {{
|
|
1501
|
+
console.log(`Job changed from ${{currentJobId}} to ${{data.job_id}}, reloading...`);
|
|
1502
|
+
location.reload();
|
|
1503
|
+
return;
|
|
1504
|
+
}}
|
|
1505
|
+
|
|
1506
|
+
// Update job info display
|
|
1507
|
+
if (data.job_id) {{
|
|
1508
|
+
const jobIdEl = document.querySelector('.job-id');
|
|
1509
|
+
const jobHostEl = document.querySelector('.job-host');
|
|
1510
|
+
const jobConfigEl = document.querySelector('.job-config');
|
|
1511
|
+
if (jobIdEl) jobIdEl.textContent = `Job: ${{data.job_id}}`;
|
|
1512
|
+
if (jobHostEl) {{
|
|
1513
|
+
let hostText = data.hostname || 'local';
|
|
1514
|
+
if (data.instance_ip) hostText += ` @ ${{data.instance_ip}}`;
|
|
1515
|
+
jobHostEl.textContent = hostText;
|
|
1516
|
+
}}
|
|
1517
|
+
if (jobConfigEl && data.config_path) {{
|
|
1518
|
+
jobConfigEl.textContent = data.config_path;
|
|
1519
|
+
}}
|
|
1520
|
+
}}
|
|
1521
|
+
|
|
1522
|
+
// Update setup panel if setup logs present
|
|
1523
|
+
if (data.setup_logs && data.setup_logs.length > 0) {{
|
|
1524
|
+
const setupPanel = document.getElementById('setup-panel');
|
|
1525
|
+
const setupLogs = document.getElementById('setup-logs');
|
|
1526
|
+
const setupBadge = document.getElementById('setup-status-badge');
|
|
1527
|
+
|
|
1528
|
+
setupPanel.classList.remove('hidden');
|
|
1529
|
+
|
|
1530
|
+
// Update status badge
|
|
1531
|
+
if (data.setup_status) {{
|
|
1532
|
+
setupBadge.textContent = data.setup_status;
|
|
1533
|
+
setupBadge.className = `setup-status-badge ${{data.setup_status}}`;
|
|
1534
|
+
}}
|
|
1535
|
+
|
|
1536
|
+
// Update logs
|
|
1537
|
+
setupLogs.innerHTML = data.setup_logs.map((log, i) =>
|
|
1538
|
+
`<div class="setup-log-line${{i === data.setup_logs.length - 1 ? ' current' : ''}}">${{log}}</div>`
|
|
1539
|
+
).join('');
|
|
1540
|
+
|
|
1541
|
+
// Auto-scroll to bottom
|
|
1542
|
+
setupLogs.scrollTop = setupLogs.scrollHeight;
|
|
1543
|
+
|
|
1544
|
+
// Hide setup panel when training starts
|
|
1545
|
+
if (data.setup_status === 'training' || data.setup_status === 'complete') {{
|
|
1546
|
+
setTimeout(() => setupPanel.classList.add('hidden'), 3000);
|
|
1547
|
+
}}
|
|
1548
|
+
}}
|
|
1549
|
+
|
|
1550
|
+
// Always update elapsed time base
|
|
1551
|
+
if (data.elapsed_time) {{
|
|
1552
|
+
baseElapsedTime = data.elapsed_time;
|
|
1553
|
+
lastSyncTime = Date.now();
|
|
1554
|
+
}}
|
|
1555
|
+
|
|
1556
|
+
// Check for termination status (handles completed/stopped states)
|
|
1557
|
+
if (data.termination_status && !isTrainingComplete) {{
|
|
1558
|
+
isTrainingComplete = true;
|
|
1559
|
+
document.getElementById('stat-eta').textContent = 'Complete';
|
|
1560
|
+
document.getElementById('eta-detail').textContent = '';
|
|
1561
|
+
updateTerminationStatus(data);
|
|
1562
|
+
updateCostDisplay();
|
|
1563
|
+
}}
|
|
1564
|
+
|
|
1565
|
+
// Only update other stats if step changed
|
|
1566
|
+
if (data.step !== lastStep) {{
|
|
1567
|
+
// Update with animation
|
|
1568
|
+
const cards = document.querySelectorAll('.stat-card');
|
|
1569
|
+
cards.forEach(c => c.classList.add('updating'));
|
|
1570
|
+
setTimeout(() => cards.forEach(c => c.classList.remove('updating')), 300);
|
|
1571
|
+
|
|
1572
|
+
// Update stats
|
|
1573
|
+
const totalEpochs = data.total_epochs || {config.num_train_epochs};
|
|
1574
|
+
const displayEpoch = Math.min(data.epoch + 1, totalEpochs); // Cap at max
|
|
1575
|
+
document.getElementById('stat-epoch').textContent = `${{displayEpoch}} / ${{totalEpochs}}`;
|
|
1576
|
+
document.getElementById('epoch-progress').style.width = `${{(displayEpoch / totalEpochs) * 100}}%`;
|
|
1577
|
+
document.getElementById('stat-step').textContent = data.step;
|
|
1578
|
+
document.getElementById('stat-loss').textContent = data.loss.toFixed(4);
|
|
1579
|
+
|
|
1580
|
+
// Loss delta
|
|
1581
|
+
const delta = data.loss - lastLoss;
|
|
1582
|
+
const deltaEl = document.getElementById('loss-delta');
|
|
1583
|
+
if (delta < 0) {{
|
|
1584
|
+
deltaEl.textContent = `↓ ${{Math.abs(delta).toFixed(4)}}`;
|
|
1585
|
+
deltaEl.className = 'stat-delta positive';
|
|
1586
|
+
}} else {{
|
|
1587
|
+
deltaEl.textContent = `↑ ${{delta.toFixed(4)}}`;
|
|
1588
|
+
deltaEl.className = 'stat-delta negative';
|
|
1589
|
+
}}
|
|
1590
|
+
|
|
1591
|
+
// AUTO-STOP: Trigger stop when loss <= threshold and training is running
|
|
1592
|
+
if (!autoStopTriggered && !isTrainingComplete && data.loss <= AUTO_STOP_LOSS_THRESHOLD) {{
|
|
1593
|
+
autoStopTriggered = true;
|
|
1594
|
+
console.log(`Auto-stop triggered: loss ${{data.loss.toFixed(4)}} <= threshold ${{AUTO_STOP_LOSS_THRESHOLD}}`);
|
|
1595
|
+
|
|
1596
|
+
// Show notification
|
|
1597
|
+
const notif = document.createElement('div');
|
|
1598
|
+
notif.className = 'auto-stop-notification';
|
|
1599
|
+
notif.innerHTML = `
|
|
1600
|
+
<strong>Auto-Stop Triggered</strong><br>
|
|
1601
|
+
Loss ${{data.loss.toFixed(4)}} ≤ ${{AUTO_STOP_LOSS_THRESHOLD}} threshold.<br>
|
|
1602
|
+
Stopping training...
|
|
1603
|
+
`;
|
|
1604
|
+
notif.style.cssText = `
|
|
1605
|
+
position: fixed; top: 20px; right: 20px; z-index: 9999;
|
|
1606
|
+
background: #2d4a3e; color: #4ade80; padding: 15px 20px;
|
|
1607
|
+
border-radius: 8px; border: 1px solid #4ade80;
|
|
1608
|
+
box-shadow: 0 4px 12px rgba(0,0,0,0.3);
|
|
1609
|
+
animation: slideIn 0.3s ease;
|
|
1610
|
+
`;
|
|
1611
|
+
document.body.appendChild(notif);
|
|
1612
|
+
|
|
1613
|
+
// Call stop endpoint
|
|
1614
|
+
fetch('/api/stop', {{ method: 'POST' }})
|
|
1615
|
+
.then(r => r.json())
|
|
1616
|
+
.then(result => {{
|
|
1617
|
+
console.log('Stop result:', result);
|
|
1618
|
+
setTimeout(() => notif.remove(), 5000);
|
|
1619
|
+
}})
|
|
1620
|
+
.catch(err => {{
|
|
1621
|
+
console.error('Stop failed:', err);
|
|
1622
|
+
notif.innerHTML += '<br><span style="color:#f87171">Stop request failed</span>';
|
|
1623
|
+
}});
|
|
1624
|
+
}}
|
|
1625
|
+
|
|
1626
|
+
// Other stats
|
|
1627
|
+
if (data.losses && data.losses.length > 0) {{
|
|
1628
|
+
const minLoss = Math.min(...data.losses.map(l => l.loss));
|
|
1629
|
+
document.getElementById('stat-min-loss').textContent = minLoss.toFixed(4);
|
|
1630
|
+
|
|
1631
|
+
const recentLosses = data.losses.slice(-10);
|
|
1632
|
+
const avgLoss = recentLosses.reduce((a, b) => a + b.loss, 0) / recentLosses.length;
|
|
1633
|
+
document.getElementById('stat-avg-loss').textContent = avgLoss.toFixed(4);
|
|
1634
|
+
|
|
1635
|
+
// Calculate avg step time and update ETA
|
|
1636
|
+
if (data.losses.length > 1) {{
|
|
1637
|
+
let stepTimes = [];
|
|
1638
|
+
for (let i = 1; i < data.losses.length; i++) {{
|
|
1639
|
+
stepTimes.push(data.losses[i].time - data.losses[i-1].time);
|
|
1640
|
+
}}
|
|
1641
|
+
avgStepTime = stepTimes.reduce((a,b) => a+b, 0) / stepTimes.length;
|
|
1642
|
+
document.getElementById('stat-step-time').textContent = avgStepTime.toFixed(1) + 's';
|
|
1643
|
+
|
|
1644
|
+
// Recalculate ETA
|
|
1645
|
+
const totalEpochs = data.total_epochs || {config.num_train_epochs};
|
|
1646
|
+
const currentEpoch = data.epoch;
|
|
1647
|
+
const stepsInCompletedEpochs = data.losses.filter(l => l.epoch < currentEpoch).length;
|
|
1648
|
+
const stepsPerEpoch = currentEpoch > 0 && stepsInCompletedEpochs > 0
|
|
1649
|
+
? stepsInCompletedEpochs / currentEpoch
|
|
1650
|
+
: data.losses.length / (currentEpoch + 1);
|
|
1651
|
+
const totalStepsEstimate = stepsPerEpoch * totalEpochs;
|
|
1652
|
+
remainingSteps = Math.max(0, totalStepsEstimate - data.losses.length);
|
|
1653
|
+
etaSeconds = remainingSteps * avgStepTime;
|
|
1654
|
+
|
|
1655
|
+
// Update ETA display
|
|
1656
|
+
if (etaSeconds > 0) {{
|
|
1657
|
+
const etaMins = Math.floor(etaSeconds / 60);
|
|
1658
|
+
const etaSecs = Math.floor(etaSeconds % 60);
|
|
1659
|
+
document.getElementById('stat-eta').textContent = `${{etaMins}}m ${{etaSecs}}s`;
|
|
1660
|
+
document.getElementById('eta-detail').textContent = `~${{Math.round(remainingSteps)}} steps @ ${{avgStepTime.toFixed(1)}}s/step`;
|
|
1661
|
+
}} else if (data.losses.length > 0) {{
|
|
1662
|
+
// Training complete - stop elapsed timer and update UI
|
|
1663
|
+
isTrainingComplete = true;
|
|
1664
|
+
document.getElementById('stat-eta').textContent = 'Complete';
|
|
1665
|
+
document.getElementById('eta-detail').textContent = '';
|
|
1666
|
+
// Update cost display one final time
|
|
1667
|
+
updateCostDisplay();
|
|
1668
|
+
// Replace stop button with termination status
|
|
1669
|
+
updateTerminationStatus(data);
|
|
1670
|
+
}} else {{
|
|
1671
|
+
// No data yet
|
|
1672
|
+
document.getElementById('stat-eta').textContent = 'calculating...';
|
|
1673
|
+
}}
|
|
1674
|
+
}}
|
|
1675
|
+
|
|
1676
|
+
// Update charts
|
|
1677
|
+
losses = data.losses;
|
|
1678
|
+
lossChart.data.labels = losses.map(l => l.step);
|
|
1679
|
+
lossChart.data.datasets[0].data = losses.map(l => l.loss);
|
|
1680
|
+
lossChart.data.datasets[0].pointRadius = losses.length > 50 ? 0 : 3;
|
|
1681
|
+
lossChart.update('none');
|
|
1682
|
+
|
|
1683
|
+
// Recalculate epoch averages
|
|
1684
|
+
const epochLosses = {{}};
|
|
1685
|
+
losses.forEach(l => {{
|
|
1686
|
+
if (!epochLosses[l.epoch]) epochLosses[l.epoch] = [];
|
|
1687
|
+
epochLosses[l.epoch].push(l.loss);
|
|
1688
|
+
}});
|
|
1689
|
+
epochAvg = Object.entries(epochLosses).map(([ep, arr]) => [parseInt(ep), arr.reduce((a,b) => a+b, 0) / arr.length]);
|
|
1690
|
+
epochChart.data.labels = epochAvg.map(e => `Epoch ${{e[0] + 1}}`);
|
|
1691
|
+
epochChart.data.datasets[0].data = epochAvg.map(e => e[1]);
|
|
1692
|
+
epochChart.update('none');
|
|
1693
|
+
|
|
1694
|
+
updateTrend();
|
|
1695
|
+
}}
|
|
1696
|
+
|
|
1697
|
+
lastStep = data.step;
|
|
1698
|
+
lastLoss = data.loss;
|
|
1699
|
+
}}
|
|
1700
|
+
|
|
1701
|
+
// Update evaluations if present
|
|
1702
|
+
if (data.evaluations && data.evaluations.length > 0) {{
|
|
1703
|
+
renderEvaluations(data.evaluations);
|
|
1704
|
+
}}
|
|
1705
|
+
|
|
1706
|
+
document.getElementById('update-indicator').textContent = 'Last updated: just now';
|
|
1707
|
+
}} catch (e) {{
|
|
1708
|
+
console.log('Update failed:', e);
|
|
1709
|
+
}}
|
|
1710
|
+
}}
|
|
1711
|
+
|
|
1712
|
+
function renderEvaluations(evaluations) {{
|
|
1713
|
+
const panel = document.getElementById('eval-panel');
|
|
1714
|
+
const gallery = document.getElementById('eval-gallery');
|
|
1715
|
+
const metrics = document.getElementById('eval-metrics');
|
|
1716
|
+
|
|
1717
|
+
if (evaluations.length === 0) {{
|
|
1718
|
+
panel.style.display = 'none';
|
|
1719
|
+
return;
|
|
1720
|
+
}}
|
|
1721
|
+
|
|
1722
|
+
panel.style.display = 'block';
|
|
1723
|
+
|
|
1724
|
+
// Calculate metrics
|
|
1725
|
+
const correctCount = evaluations.filter(e => e.correct).length;
|
|
1726
|
+
const avgDistance = evaluations.reduce((a, e) => a + e.distance, 0) / evaluations.length;
|
|
1727
|
+
const accuracy = (correctCount / evaluations.length * 100).toFixed(1);
|
|
1728
|
+
|
|
1729
|
+
metrics.innerHTML = `
|
|
1730
|
+
<div class="metric">
|
|
1731
|
+
<span class="metric-label">Accuracy</span>
|
|
1732
|
+
<span class="metric-value">${{accuracy}}%</span>
|
|
1733
|
+
</div>
|
|
1734
|
+
<div class="metric">
|
|
1735
|
+
<span class="metric-label">Avg Distance</span>
|
|
1736
|
+
<span class="metric-value">${{avgDistance.toFixed(1)}}px</span>
|
|
1737
|
+
</div>
|
|
1738
|
+
<div class="metric">
|
|
1739
|
+
<span class="metric-label">Samples</span>
|
|
1740
|
+
<span class="metric-value">${{evaluations.length}}</span>
|
|
1741
|
+
</div>
|
|
1742
|
+
<div class="legend" style="display: flex; gap: 16px; margin-left: auto; font-size: 0.75rem; align-items: center;">
|
|
1743
|
+
<span style="display: flex; align-items: center; gap: 4px;">
|
|
1744
|
+
<span style="width: 12px; height: 12px; border-radius: 50%; background: rgba(52, 211, 153, 0.8);"></span>
|
|
1745
|
+
Human
|
|
1746
|
+
</span>
|
|
1747
|
+
<span style="display: flex; align-items: center; gap: 4px;">
|
|
1748
|
+
<span style="width: 12px; height: 12px; border-radius: 50%; background: rgba(167, 139, 250, 0.8);"></span>
|
|
1749
|
+
Predicted
|
|
1750
|
+
</span>
|
|
1751
|
+
</div>
|
|
1752
|
+
`;
|
|
1753
|
+
|
|
1754
|
+
// Render gallery (show last 9 evaluations)
|
|
1755
|
+
const recentEvals = evaluations.slice(-9);
|
|
1756
|
+
gallery.innerHTML = recentEvals.map((ev, i) => {{
|
|
1757
|
+
const statusClass = ev.correct ? 'correct' : 'incorrect';
|
|
1758
|
+
const statusText = ev.correct ? '✓ Correct' : '✗ Off by ' + (ev.distance * 100).toFixed(1) + '%';
|
|
1759
|
+
const humanX = (ev.human_action.x || 0).toFixed(3);
|
|
1760
|
+
const humanY = (ev.human_action.y || 0).toFixed(3);
|
|
1761
|
+
const predX = (ev.predicted_action.x || 0).toFixed(3);
|
|
1762
|
+
const predY = (ev.predicted_action.y || 0).toFixed(3);
|
|
1763
|
+
const rawOutput = ev.predicted_action.raw_output || '';
|
|
1764
|
+
const thoughtMatch = rawOutput.match(/Thought:([\\s\\S]*?)(?:Action:|$)/);
|
|
1765
|
+
const thought = thoughtMatch ? thoughtMatch[1].trim().substring(0, 200) : '';
|
|
1766
|
+
const sampleId = 'eval-' + ev.epoch + '-' + ev.sample_idx;
|
|
1767
|
+
return `
|
|
1768
|
+
<div class="eval-sample">
|
|
1769
|
+
<div style="position: relative;">
|
|
1770
|
+
<img src="${{ev.image_path}}" alt="Sample ${{ev.sample_idx}}" onerror="this.style.display='none'">
|
|
1771
|
+
<div class="overlay" style="width: 100%; height: 100%;">
|
|
1772
|
+
<div class="marker human" style="left: ${{(ev.human_action.x || 0) * 100}}%; top: ${{(ev.human_action.y || 0) * 100}}%;" title="Human"></div>
|
|
1773
|
+
<div class="marker predicted" style="left: ${{(ev.predicted_action.x || 0) * 100}}%; top: ${{(ev.predicted_action.y || 0) * 100}}%;" title="Predicted"></div>
|
|
1774
|
+
</div>
|
|
1775
|
+
</div>
|
|
1776
|
+
<div class="info">
|
|
1777
|
+
<span class="${{statusClass}}">${{statusText}}</span>
|
|
1778
|
+
<span> | Epoch ${{ev.epoch + 1}}</span>
|
|
1779
|
+
</div>
|
|
1780
|
+
<div class="details">
|
|
1781
|
+
<div class="coords">
|
|
1782
|
+
<span class="human-coord">Human: (${{humanX}}, ${{humanY}})</span>
|
|
1783
|
+
<span class="pred-coord">Pred: (${{predX}}, ${{predY}})</span>
|
|
1784
|
+
</div>
|
|
1785
|
+
</div>
|
|
1786
|
+
${{thought ? `
|
|
1787
|
+
<div class="thinking">${{thought}}${{thought.length >= 200 ? '...' : ''}}</div>
|
|
1788
|
+
` : ''}}
|
|
1789
|
+
</div>
|
|
1790
|
+
`;
|
|
1791
|
+
}}).join('');
|
|
1792
|
+
}}
|
|
1793
|
+
|
|
1794
|
+
// Terminal output management
|
|
1795
|
+
let terminalAutoScroll = true;
|
|
1796
|
+
let terminalWrap = false;
|
|
1797
|
+
let terminalCollapsed = false;
|
|
1798
|
+
let lastTerminalSize = 0;
|
|
1799
|
+
const MAX_TERMINAL_LINES = 500;
|
|
1800
|
+
|
|
1801
|
+
function toggleTerminal() {{
|
|
1802
|
+
const container = document.getElementById('terminal-container');
|
|
1803
|
+
const toggleText = document.getElementById('terminal-toggle-text');
|
|
1804
|
+
terminalCollapsed = !terminalCollapsed;
|
|
1805
|
+
|
|
1806
|
+
if (terminalCollapsed) {{
|
|
1807
|
+
container.classList.add('collapsed');
|
|
1808
|
+
toggleText.textContent = 'Expand';
|
|
1809
|
+
}} else {{
|
|
1810
|
+
container.classList.remove('collapsed');
|
|
1811
|
+
toggleText.textContent = 'Collapse';
|
|
1812
|
+
}}
|
|
1813
|
+
}}
|
|
1814
|
+
|
|
1815
|
+
function toggleAutoScroll() {{
|
|
1816
|
+
terminalAutoScroll = !terminalAutoScroll;
|
|
1817
|
+
const btn = document.getElementById('auto-scroll-btn');
|
|
1818
|
+
if (terminalAutoScroll) {{
|
|
1819
|
+
btn.classList.add('active');
|
|
1820
|
+
scrollTerminalToBottom();
|
|
1821
|
+
}} else {{
|
|
1822
|
+
btn.classList.remove('active');
|
|
1823
|
+
}}
|
|
1824
|
+
}}
|
|
1825
|
+
|
|
1826
|
+
function toggleWrap() {{
|
|
1827
|
+
terminalWrap = !terminalWrap;
|
|
1828
|
+
const btn = document.getElementById('wrap-btn');
|
|
1829
|
+
const output = document.getElementById('terminal-output');
|
|
1830
|
+
if (terminalWrap) {{
|
|
1831
|
+
btn.classList.add('active');
|
|
1832
|
+
output.style.whiteSpace = 'pre-wrap';
|
|
1833
|
+
}} else {{
|
|
1834
|
+
btn.classList.remove('active');
|
|
1835
|
+
output.style.whiteSpace = 'pre';
|
|
1836
|
+
}}
|
|
1837
|
+
}}
|
|
1838
|
+
|
|
1839
|
+
function scrollTerminalToBottom() {{
|
|
1840
|
+
const container = document.getElementById('terminal-container');
|
|
1841
|
+
container.scrollTop = container.scrollHeight;
|
|
1842
|
+
}}
|
|
1843
|
+
|
|
1844
|
+
async function fetchTerminalOutput() {{
|
|
1845
|
+
try {{
|
|
1846
|
+
const response = await fetch('training.log?t=' + Date.now());
|
|
1847
|
+
if (!response.ok) {{
|
|
1848
|
+
// File doesn't exist yet
|
|
1849
|
+
return;
|
|
1850
|
+
}}
|
|
1851
|
+
|
|
1852
|
+
const text = await response.text();
|
|
1853
|
+
const lines = text.trim().split('\\n');
|
|
1854
|
+
|
|
1855
|
+
// Keep only last MAX_TERMINAL_LINES
|
|
1856
|
+
const displayLines = lines.slice(-MAX_TERMINAL_LINES);
|
|
1857
|
+
|
|
1858
|
+
const output = document.getElementById('terminal-output');
|
|
1859
|
+
const lineCount = document.getElementById('terminal-line-count');
|
|
1860
|
+
|
|
1861
|
+
// Update line count
|
|
1862
|
+
lineCount.textContent = lines.length;
|
|
1863
|
+
|
|
1864
|
+
// Only update if content changed
|
|
1865
|
+
if (displayLines.length === 0) {{
|
|
1866
|
+
output.innerHTML = '<div class="terminal-empty">Waiting for training output...</div>';
|
|
1867
|
+
return;
|
|
1868
|
+
}}
|
|
1869
|
+
|
|
1870
|
+
// Format lines with basic syntax highlighting
|
|
1871
|
+
const formattedLines = displayLines.map(line => {{
|
|
1872
|
+
let className = 'terminal-line';
|
|
1873
|
+
|
|
1874
|
+
// Detect line type
|
|
1875
|
+
if (line.match(/^\\d{{4}}-\\d{{2}}-\\d{{2}}/)) {{
|
|
1876
|
+
className += ' timestamp';
|
|
1877
|
+
}} else if (line.toLowerCase().includes('error') || line.toLowerCase().includes('failed')) {{
|
|
1878
|
+
className += ' error';
|
|
1879
|
+
}} else if (line.toLowerCase().includes('success') || line.toLowerCase().includes('complete')) {{
|
|
1880
|
+
className += ' success';
|
|
1881
|
+
}} else if (line.toLowerCase().includes('warning')) {{
|
|
1882
|
+
className += ' warning';
|
|
1883
|
+
}}
|
|
1884
|
+
|
|
1885
|
+
// Escape HTML
|
|
1886
|
+
const escaped = line
|
|
1887
|
+
.replace(/&/g, '&amp;')
|
|
1888
|
+
.replace(/</g, '&lt;')
|
|
1889
|
+
.replace(/>/g, '&gt;');
|
|
1890
|
+
|
|
1891
|
+
return `<div class="${{className}}">${{escaped}}</div>`;
|
|
1892
|
+
}}).join('');
|
|
1893
|
+
|
|
1894
|
+
output.innerHTML = formattedLines;
|
|
1895
|
+
|
|
1896
|
+
// Auto-scroll if enabled and new content arrived
|
|
1897
|
+
if (terminalAutoScroll && lines.length > lastTerminalSize) {{
|
|
1898
|
+
scrollTerminalToBottom();
|
|
1899
|
+
}}
|
|
1900
|
+
|
|
1901
|
+
lastTerminalSize = lines.length;
|
|
1902
|
+
}} catch (err) {{
|
|
1903
|
+
console.error('Failed to fetch terminal output:', err);
|
|
1904
|
+
}}
|
|
1905
|
+
}}
|
|
1906
|
+
|
|
1907
|
+
initCharts();
|
|
1908
|
+
updateCostDisplay(); // Initialize cost display
|
|
1909
|
+
fetchAndUpdate(); // Initial fetch on page load
|
|
1910
|
+
fetchTerminalOutput(); // Initial terminal fetch
|
|
1911
|
+
setInterval(fetchAndUpdate, 3000);
|
|
1912
|
+
setInterval(fetchTerminalOutput, 2000); // Poll terminal output every 2 seconds
|
|
1913
|
+
setInterval(updateElapsedDisplay, 1000); // Update elapsed time every second
|
|
1914
|
+
setInterval(updateStatusIndicator, 1000); // Update LIVE/STALE indicator every second
|
|
1915
|
+
</script>
|
|
1916
|
+
</body>
|
|
1917
|
+
</html>'''
|
|
1918
|
+
return html
|
|
1919
|
+
|
|
1920
|
+
|
|
1921
|
+
def regenerate_all_dashboards(output_dir: str | Path) -> list[Path]:
|
|
1922
|
+
"""Regenerate all dashboards in a directory with static navigation.
|
|
1923
|
+
|
|
1924
|
+
This updates dashboard.html and generates the unified viewer.html.
|
|
1925
|
+
Old comparison_*.html files are left in place but no longer linked.
|
|
1926
|
+
|
|
1927
|
+
Args:
|
|
1928
|
+
output_dir: Directory containing dashboard files
|
|
1929
|
+
|
|
1930
|
+
Returns:
|
|
1931
|
+
List of paths to regenerated files
|
|
1932
|
+
"""
|
|
1933
|
+
output_dir = Path(output_dir)
|
|
1934
|
+
regenerated = []
|
|
1935
|
+
|
|
1936
|
+
# Nav links are now fixed (Training + Viewer)
|
|
1937
|
+
nav_links = _build_nav_links()
|
|
1938
|
+
|
|
1939
|
+
# Regenerate main dashboard
|
|
1940
|
+
if (output_dir / "training_log.json").exists():
|
|
1941
|
+
try:
|
|
1942
|
+
path = regenerate_local_dashboard(output_dir, nav_links=nav_links)
|
|
1943
|
+
regenerated.append(path)
|
|
1944
|
+
except Exception as e:
|
|
1945
|
+
print(f"Warning: Failed to regenerate dashboard: {e}")
|
|
1946
|
+
|
|
1947
|
+
# Generate unified viewer if we have capture data
|
|
1948
|
+
try:
|
|
1949
|
+
viewer_path = generate_unified_viewer_from_output_dir(output_dir)
|
|
1950
|
+
if viewer_path:
|
|
1951
|
+
regenerated.append(viewer_path)
|
|
1952
|
+
except Exception as e:
|
|
1953
|
+
print(f"Warning: Failed to generate unified viewer: {e}")
|
|
1954
|
+
import traceback
|
|
1955
|
+
traceback.print_exc()
|
|
1956
|
+
|
|
1957
|
+
return regenerated
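
A minimal usage sketch for the helper above, assuming training artifacts (training_log.json, dashboard.html) have already been downloaded into a local directory; the directory name below is illustrative, not part of this package:

from pathlib import Path
from openadapt_ml.training.trainer import regenerate_all_dashboards

output_dir = Path("training_output")  # hypothetical local results directory
for page in regenerate_all_dashboards(output_dir):
    print(f"regenerated: {page}")
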
|
|
1958
|
+
|
|
1959
|
+
|
|
1960
|
+
def regenerate_local_dashboard(
|
|
1961
|
+
output_dir: str | Path,
|
|
1962
|
+
capture_path: str | Path | None = None,
|
|
1963
|
+
checkpoint_path: str | Path | None = None,
|
|
1964
|
+
nav_links: list[tuple[str, str]] | None = None,
|
|
1965
|
+
) -> Path:
|
|
1966
|
+
"""Regenerate dashboard.html with correct local paths and static navigation.
|
|
1967
|
+
|
|
1968
|
+
This should be called after downloading training results from a remote instance.
|
|
1969
|
+
It fixes:
|
|
1970
|
+
- Training status (COMPLETED/STOPPED instead of always LIVE)
|
|
1971
|
+
- Navigation links to sibling dashboards (comparison, viewer)
|
|
1972
|
+
- Local capture path for comparison preview
|
|
1973
|
+
|
|
1974
|
+
Args:
|
|
1975
|
+
output_dir: Directory containing training_log.json and dashboard files
|
|
1976
|
+
capture_path: Local path to capture directory (for comparison preview)
|
|
1977
|
+
checkpoint_path: Local path to checkpoint directory
|
|
1978
|
+
nav_links: Pre-built list of (filename, label) tuples for consistency
|
|
1979
|
+
|
|
1980
|
+
Returns:
|
|
1981
|
+
Path to generated dashboard.html
|
|
1982
|
+
"""
|
|
1983
|
+
output_dir = Path(output_dir)
|
|
1984
|
+
log_file = output_dir / "training_log.json"
|
|
1985
|
+
|
|
1986
|
+
if not log_file.exists():
|
|
1987
|
+
raise FileNotFoundError(f"No training_log.json found in {output_dir}")
|
|
1988
|
+
|
|
1989
|
+
# Load training state from log
|
|
1990
|
+
with open(log_file) as f:
|
|
1991
|
+
data = json.load(f)
|
|
1992
|
+
|
|
1993
|
+
# Create state from log data
|
|
1994
|
+
state = TrainingState(
|
|
1995
|
+
job_id=data.get("job_id", "unknown"),
|
|
1996
|
+
hostname=data.get("hostname", ""),
|
|
1997
|
+
capture_path=str(capture_path) if capture_path else data.get("capture_path", ""),
|
|
1998
|
+
config_path=data.get("config_path", ""),
|
|
1999
|
+
epoch=data.get("epoch", 0),
|
|
2000
|
+
step=data.get("step", 0),
|
|
2001
|
+
loss=data.get("loss", 0),
|
|
2002
|
+
learning_rate=data.get("learning_rate", 0),
|
|
2003
|
+
total_epochs=data.get("total_epochs", 5),
|
|
2004
|
+
instance_type=data.get("instance_type", ""),
|
|
2005
|
+
instance_ip=data.get("instance_ip", ""),
|
|
2006
|
+
elapsed_time=data.get("elapsed_time", 0.0),
|
|
2007
|
+
cloud_provider=data.get("cloud_provider", ""),
|
|
2008
|
+
)
|
|
2009
|
+
state.losses = data.get("losses", [])
|
|
2010
|
+
state.evaluations = data.get("evaluations", [])
|
|
2011
|
+
|
|
2012
|
+
# Determine training status
|
|
2013
|
+
total_epochs = data.get("total_epochs", 5)
|
|
2014
|
+
current_epoch = data.get("epoch", 0)
|
|
2015
|
+
|
|
2016
|
+
if current_epoch + 1 >= total_epochs:
|
|
2017
|
+
training_status = "COMPLETED"
|
|
2018
|
+
elif len(state.losses) > 0:
|
|
2019
|
+
training_status = "STOPPED"
|
|
2020
|
+
else:
|
|
2021
|
+
training_status = "NOT_STARTED"
|
|
2022
|
+
|
|
2023
|
+
# Use provided nav_links or build them
|
|
2024
|
+
if nav_links is None:
|
|
2025
|
+
nav_links = _build_nav_links()
|
|
2026
|
+
|
|
2027
|
+
# Create config
|
|
2028
|
+
config = TrainingConfig(
|
|
2029
|
+
num_train_epochs=total_epochs,
|
|
2030
|
+
learning_rate=data.get("learning_rate", 5e-5),
|
|
2031
|
+
)
|
|
2032
|
+
|
|
2033
|
+
# Generate dashboard HTML with modifications
|
|
2034
|
+
html = generate_training_dashboard(state, config)
|
|
2035
|
+
|
|
2036
|
+
# Replace dynamic status with static status
|
|
2037
|
+
if training_status == "COMPLETED":
|
|
2038
|
+
html = html.replace(
|
|
2039
|
+
'<div class="status" id="status">',
|
|
2040
|
+
'<div class="status complete" id="status">'
|
|
2041
|
+
)
|
|
2042
|
+
html = html.replace(
|
|
2043
|
+
'<span id="status-text">Training in progress</span>',
|
|
2044
|
+
'<span id="status-text">COMPLETED</span>'
|
|
2045
|
+
)
|
|
2046
|
+
elif training_status == "STOPPED":
|
|
2047
|
+
html = html.replace(
|
|
2048
|
+
'<div class="status" id="status">',
|
|
2049
|
+
'<div class="status stale" id="status">'
|
|
2050
|
+
)
|
|
2051
|
+
html = html.replace(
|
|
2052
|
+
'<span id="status-text">Training in progress</span>',
|
|
2053
|
+
'<span id="status-text">STOPPED (Epoch {}/{})'.format(current_epoch + 1, total_epochs) + '</span>'
|
|
2054
|
+
)
|
|
2055
|
+
|
|
2056
|
+
# Fix ETA display for completed/stopped training
|
|
2057
|
+
import re
|
|
2058
|
+
if training_status in ("COMPLETED", "STOPPED"):
|
|
2059
|
+
# Replace "calculating..." with appropriate status
|
|
2060
|
+
html = re.sub(
|
|
2061
|
+
r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
|
|
2062
|
+
r'\1—\2' if training_status == "STOPPED" else r'\1complete\2',
|
|
2063
|
+
html
|
|
2064
|
+
)
|
|
2065
|
+
|
|
2066
|
+
# Replace dynamic nav with static unified header
|
|
2067
|
+
# The dashboard now uses the shared unified-header, so we just need to ensure
|
|
2068
|
+
# the header HTML is present (it's already generated by generate_training_dashboard)
|
|
2069
|
+
|
|
2070
|
+
# Disable the JS polling and dynamic discovery (training is done, no need to fetch updates)
|
|
2071
|
+
# This is critical for file:// protocol where fetch() doesn't work
|
|
2072
|
+
html = html.replace(
|
|
2073
|
+
"setInterval(fetchAndUpdate, 3000);",
|
|
2074
|
+
"// fetchAndUpdate disabled for static dashboard"
|
|
2075
|
+
)
|
|
2076
|
+
html = html.replace(
|
|
2077
|
+
"setInterval(updateElapsedDisplay, 1000);",
|
|
2078
|
+
"// updateElapsedDisplay disabled for static dashboard"
|
|
2079
|
+
)
|
|
2080
|
+
html = html.replace(
|
|
2081
|
+
"setInterval(updateStatusIndicator, 1000);",
|
|
2082
|
+
"// updateStatusIndicator disabled for static dashboard"
|
|
2083
|
+
)
|
|
2084
|
+
# CRITICAL: Disable discoverDashboards() - it overwrites static nav on file:// protocol
|
|
2085
|
+
html = html.replace(
|
|
2086
|
+
"discoverDashboards();",
|
|
2087
|
+
"// discoverDashboards disabled - using static nav for file:// protocol"
|
|
2088
|
+
)
|
|
2089
|
+
|
|
2090
|
+
# Write output
|
|
2091
|
+
dashboard_path = output_dir / "dashboard.html"
|
|
2092
|
+
dashboard_path.write_text(html)
|
|
2093
|
+
print(f"Regenerated dashboard: {dashboard_path}")
|
|
2094
|
+
|
|
2095
|
+
return dashboard_path
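
A hedged sketch of calling this helper directly after pulling results from a remote instance; both paths are placeholders, and the keyword arguments mirror the signature above:

from openadapt_ml.training.trainer import regenerate_local_dashboard

dashboard_path = regenerate_local_dashboard(
    "training_output",                        # contains training_log.json
    capture_path="captures/example_capture",  # hypothetical local capture dir
)
print(dashboard_path)  # -> training_output/dashboard.html
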
|
|
2096
|
+
|
|
2097
|
+
|
|
2098
|
+
def run_epoch_evaluation(
|
|
2099
|
+
adapter: BaseVLMAdapter,
|
|
2100
|
+
episode: Episode,
|
|
2101
|
+
epoch: int,
|
|
2102
|
+
config: TrainingConfig,
|
|
2103
|
+
logger: "TrainingLogger",
|
|
2104
|
+
sample_indices: Optional[List[int]] = None,
|
|
2105
|
+
) -> Path:
|
|
2106
|
+
"""Run inference evaluation on sample steps after an epoch.
|
|
2107
|
+
|
|
2108
|
+
This generates a comparison_epoch{N}.html file showing human vs predicted actions.
|
|
2109
|
+
|
|
2110
|
+
Args:
|
|
2111
|
+
adapter: Trained adapter to use for inference
|
|
2112
|
+
episode: Episode with steps to evaluate
|
|
2113
|
+
epoch: Current epoch number
|
|
2114
|
+
config: Training configuration
|
|
2115
|
+
logger: Training logger for state tracking
|
|
2116
|
+
sample_indices: Specific step indices to evaluate (default: evenly spaced)
|
|
2117
|
+
|
|
2118
|
+
Returns:
|
|
2119
|
+
Path to generated comparison HTML file
|
|
2120
|
+
"""
|
|
2121
|
+
from openadapt_ml.scripts.compare import generate_comparison_html, predict_action, format_action
|
|
2122
|
+
|
|
2123
|
+
output_dir = Path(config.output_dir)
|
|
2124
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
2125
|
+
|
|
2126
|
+
# Select sample indices if not provided
|
|
2127
|
+
num_samples = min(config.eval_samples, len(episode.steps))
|
|
2128
|
+
if sample_indices is None:
|
|
2129
|
+
if num_samples >= len(episode.steps):
|
|
2130
|
+
sample_indices = list(range(len(episode.steps)))
|
|
2131
|
+
else:
|
|
2132
|
+
# Evenly space samples across the episode
|
|
2133
|
+
step_size = len(episode.steps) // num_samples
|
|
2134
|
+
sample_indices = [i * step_size for i in range(num_samples)]
|
|
2135
|
+
|
|
2136
|
+
print(f" Running inference on {len(sample_indices)} sample steps...")
|
|
2137
|
+
|
|
2138
|
+
# Switch adapter to eval mode
|
|
2139
|
+
adapter.eval()
|
|
2140
|
+
|
|
2141
|
+
comparison_data = []
|
|
2142
|
+
action_history: List[str] = []
|
|
2143
|
+
total_steps = len(episode.steps)
|
|
2144
|
+
|
|
2145
|
+
for i, step in enumerate(episode.steps):
|
|
2146
|
+
step_data = {
|
|
2147
|
+
"index": i,
|
|
2148
|
+
"time": step.t,
|
|
2149
|
+
"image_path": step.observation.image_path,
|
|
2150
|
+
"human_action": {
|
|
2151
|
+
"type": step.action.type,
|
|
2152
|
+
"x": step.action.x,
|
|
2153
|
+
"y": step.action.y,
|
|
2154
|
+
"text": step.action.text,
|
|
2155
|
+
},
|
|
2156
|
+
"predicted_action": None,
|
|
2157
|
+
"match": None,
|
|
2158
|
+
}
|
|
2159
|
+
|
|
2160
|
+
# Only run inference on selected samples (for speed)
|
|
2161
|
+
if i in sample_indices and step.observation.image_path:
|
|
2162
|
+
try:
|
|
2163
|
+
predicted = predict_action(
|
|
2164
|
+
adapter,
|
|
2165
|
+
step.observation.image_path,
|
|
2166
|
+
episode.goal,
|
|
2167
|
+
step_index=i,
|
|
2168
|
+
total_steps=total_steps,
|
|
2169
|
+
action_history=action_history.copy(),
|
|
2170
|
+
)
|
|
2171
|
+
step_data["predicted_action"] = predicted
|
|
2172
|
+
|
|
2173
|
+
# Check match and calculate distance
|
|
2174
|
+
if predicted and predicted.get("type") == step.action.type:
|
|
2175
|
+
step_data["match"] = True
|
|
2176
|
+
|
|
2177
|
+
# Calculate distance for click actions
|
|
2178
|
+
if step.action.type == "click":
|
|
2179
|
+
hx, hy = step.action.x or 0, step.action.y or 0
|
|
2180
|
+
px, py = predicted.get("x", 0), predicted.get("y", 0)
|
|
2181
|
+
distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
|
|
2182
|
+
|
|
2183
|
+
# Log evaluation to training state
|
|
2184
|
+
logger.state.log_evaluation(
|
|
2185
|
+
epoch=epoch,
|
|
2186
|
+
sample_idx=i,
|
|
2187
|
+
image_path=step.observation.image_path,
|
|
2188
|
+
human_action=step_data["human_action"],
|
|
2189
|
+
predicted_action=predicted,
|
|
2190
|
+
)
|
|
2191
|
+
else:
|
|
2192
|
+
step_data["match"] = False
|
|
2193
|
+
|
|
2194
|
+
print(f" Step {i}: {step.action.type} -> {predicted.get('type') if predicted else 'none'}")
|
|
2195
|
+
|
|
2196
|
+
except Exception as e:
|
|
2197
|
+
print(f" Step {i}: inference failed - {e}")
|
|
2198
|
+
|
|
2199
|
+
# Build action history for context
|
|
2200
|
+
action_history.append(format_action(step.action, use_som=False))
|
|
2201
|
+
comparison_data.append(step_data)
|
|
2202
|
+
|
|
2203
|
+
# Switch back to train mode
|
|
2204
|
+
adapter.train()
|
|
2205
|
+
|
|
2206
|
+
# Generate comparison HTML
|
|
2207
|
+
output_path = output_dir / f"comparison_epoch{epoch}.html"
|
|
2208
|
+
capture_path = Path(logger.state.capture_path) if logger.state.capture_path else Path(".")
|
|
2209
|
+
|
|
2210
|
+
generate_comparison_html(capture_path, episode, comparison_data, output_path)
|
|
2211
|
+
print(f" Comparison saved: {output_path}")
|
|
2212
|
+
|
|
2213
|
+
# Also regenerate all dashboards to update navigation
|
|
2214
|
+
regenerate_all_dashboards(output_dir)
|
|
2215
|
+
|
|
2216
|
+
return output_path
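
An illustrative manual invocation, assuming a trained adapter, a loaded Episode, a TrainingConfig, and a TrainingLogger already exist in scope; this mirrors how train_supervised calls the function after each epoch:

# adapter, episode, config, and logger are assumed to be set up elsewhere.
comparison_path = run_epoch_evaluation(
    adapter=adapter,
    episode=episode,
    epoch=0,
    config=config,
    logger=logger,
)
print(f"comparison written to {comparison_path}")
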
|
|
2217
|
+
|
|
2218
|
+
|
|
2219
|
+
def _create_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
|
|
2220
|
+
# Use an identity collate_fn so that each batch is a List[Dict], matching
|
|
2221
|
+
# the expectations of adapters that operate on SFT-style samples.
|
|
2222
|
+
return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
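
A small sketch, using a toy in-memory dataset, showing that the identity collate_fn yields each batch as a List[Dict] rather than collated tensors:

from torch.utils.data import Dataset

class ToySFTDataset(Dataset):
    """Illustrative stand-in for an SFT-style dataset of dicts."""

    def __init__(self) -> None:
        self.samples = [{"prompt": "click Submit", "target": "CLICK(0.5, 0.5)"}] * 4

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> dict:
        return self.samples[idx]

loader = _create_dataloader(ToySFTDataset(), batch_size=2)
batch = next(iter(loader))
assert isinstance(batch, list) and isinstance(batch[0], dict)
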
|
|
2223
|
+
|
|
2224
|
+
|
|
2225
|
+
def _create_lr_scheduler(
|
|
2226
|
+
optimizer: Optimizer,
|
|
2227
|
+
config: TrainingConfig,
|
|
2228
|
+
num_training_steps: int,
|
|
2229
|
+
) -> Optional[LambdaLR]:
|
|
2230
|
+
"""Create learning rate scheduler based on config.
|
|
2231
|
+
|
|
2232
|
+
Args:
|
|
2233
|
+
optimizer: The optimizer to schedule.
|
|
2234
|
+
config: Training configuration with lr_scheduler_type and warmup_ratio.
|
|
2235
|
+
num_training_steps: Total number of training steps.
|
|
2236
|
+
|
|
2237
|
+
Returns:
|
|
2238
|
+
LambdaLR scheduler or None if scheduler_type is "none" or "constant".
|
|
2239
|
+
"""
|
|
2240
|
+
scheduler_type = config.lr_scheduler_type.lower()
|
|
2241
|
+
|
|
2242
|
+
if scheduler_type in ("none", "constant"):
|
|
2243
|
+
return None
|
|
2244
|
+
|
|
2245
|
+
num_warmup_steps = int(num_training_steps * config.warmup_ratio)
|
|
2246
|
+
|
|
2247
|
+
if scheduler_type == "linear":
|
|
2248
|
+
def lr_lambda(current_step: int) -> float:
|
|
2249
|
+
if current_step < num_warmup_steps:
|
|
2250
|
+
# Linear warmup
|
|
2251
|
+
return float(current_step) / float(max(1, num_warmup_steps))
|
|
2252
|
+
# Linear decay
|
|
2253
|
+
return max(
|
|
2254
|
+
0.0,
|
|
2255
|
+
float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
|
2256
|
+
)
|
|
2257
|
+
elif scheduler_type == "cosine":
|
|
2258
|
+
import math
|
|
2259
|
+
def lr_lambda(current_step: int) -> float:
|
|
2260
|
+
if current_step < num_warmup_steps:
|
|
2261
|
+
# Linear warmup
|
|
2262
|
+
return float(current_step) / float(max(1, num_warmup_steps))
|
|
2263
|
+
# Cosine decay
|
|
2264
|
+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
|
2265
|
+
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
|
2266
|
+
else:
|
|
2267
|
+
raise ValueError(f"Unknown lr_scheduler_type: {scheduler_type}. Use 'linear', 'cosine', 'constant', or 'none'.")
|
|
2268
|
+
|
|
2269
|
+
return LambdaLR(optimizer, lr_lambda)
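
A quick sketch of inspecting the warmup-plus-cosine schedule on a throwaway parameter; it assumes TrainingConfig accepts lr_scheduler_type and warmup_ratio as constructor keywords with defaults for its remaining fields:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-4)
cfg = TrainingConfig(lr_scheduler_type="cosine", warmup_ratio=0.1)  # assumed kwargs
sched = _create_lr_scheduler(opt, cfg, num_training_steps=100)

for step in range(100):
    opt.step()
    sched.step()
    if step in (0, 9, 50, 99):
        # LR ramps up over the first 10 steps, then decays toward 0.
        print(step, opt.param_groups[0]["lr"])
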
|
|
2270
|
+
|
|
2271
|
+
|
|
2272
|
+
def train_supervised(
|
|
2273
|
+
adapter: BaseVLMAdapter,
|
|
2274
|
+
dataset: Dataset,
|
|
2275
|
+
config: TrainingConfig,
|
|
2276
|
+
optimizer: Optional[Optimizer] = None,
|
|
2277
|
+
logger: Optional[TrainingLogger] = None,
|
|
2278
|
+
episode: Optional[Episode] = None,
|
|
2279
|
+
) -> bool:
|
|
2280
|
+
"""Minimal supervised training loop skeleton.
|
|
2281
|
+
|
|
2282
|
+
This assumes that `adapter.prepare_inputs` and `adapter.compute_loss` are
|
|
2283
|
+
implemented. It will raise if those methods are not implemented.
|
|
2284
|
+
|
|
2285
|
+
Args:
|
|
2286
|
+
adapter: VLM adapter to train.
|
|
2287
|
+
dataset: Training dataset.
|
|
2288
|
+
config: Training configuration.
|
|
2289
|
+
optimizer: Optional optimizer (default: AdamW).
|
|
2290
|
+
logger: Optional training logger for visualization.
|
|
2291
|
+
episode: Optional episode for periodic evaluation (generates comparison_epoch{N}.html).
|
|
2292
|
+
|
|
2293
|
+
Returns:
|
|
2294
|
+
True if training completed successfully, False if aborted due to NaN/Inf loss.
|
|
2295
|
+
"""
|
|
2296
|
+
|
|
2297
|
+
device = adapter.device # type: ignore[attr-defined]
|
|
2298
|
+
dataloader = _create_dataloader(dataset, batch_size=config.per_device_train_batch_size)
|
|
2299
|
+
|
|
2300
|
+
if optimizer is None:
|
|
2301
|
+
optimizer = torch.optim.AdamW(
|
|
2302
|
+
adapter.model.parameters(), # type: ignore[arg-type]
|
|
2303
|
+
lr=config.learning_rate,
|
|
2304
|
+
weight_decay=config.weight_decay,
|
|
2305
|
+
)
|
|
2306
|
+
|
|
2307
|
+
# Create logger if not provided
|
|
2308
|
+
if logger is None:
|
|
2309
|
+
logger = TrainingLogger(config.output_dir, config)
|
|
2310
|
+
|
|
2311
|
+
# Calculate total training steps for scheduler
|
|
2312
|
+
num_training_steps = len(dataloader) * config.num_train_epochs // config.gradient_accumulation_steps
|
|
2313
|
+
|
|
2314
|
+
# Create learning rate scheduler
|
|
2315
|
+
lr_scheduler = _create_lr_scheduler(optimizer, config, num_training_steps)
|
|
2316
|
+
|
|
2317
|
+
total_steps = 0
|
|
2318
|
+
adapter.train()
|
|
2319
|
+
|
|
2320
|
+
# Early stopping tracking
|
|
2321
|
+
consecutive_low_loss = 0
|
|
2322
|
+
early_stopped = False
|
|
2323
|
+
user_stopped = False
|
|
2324
|
+
|
|
2325
|
+
for epoch in range(config.num_train_epochs):
|
|
2326
|
+
if early_stopped or user_stopped:
|
|
2327
|
+
break
|
|
2328
|
+
|
|
2329
|
+
for _, batch in enumerate(dataloader):
|
|
2330
|
+
# Check for stop signal from dashboard
|
|
2331
|
+
stop_file = Path(config.output_dir) / "STOP_TRAINING"
|
|
2332
|
+
if stop_file.exists():
|
|
2333
|
+
msg = "Stop signal received from dashboard. Stopping training..."
|
|
2334
|
+
print(msg)
|
|
2335
|
+
logger._log_to_terminal(msg)
|
|
2336
|
+
# Set termination status for dashboard
|
|
2337
|
+
logger.state.termination_status = "user_stop"
|
|
2338
|
+
logger.state.termination_message = "Training stopped by user via dashboard"
|
|
2339
|
+
logger.save()
|
|
2340
|
+
user_stopped = True
|
|
2341
|
+
stop_file.unlink() # Remove signal file
|
|
2342
|
+
break
|
|
2343
|
+
|
|
2344
|
+
# Batch is a List[Dict[str, Any]] of SFT-style samples; adapter is
|
|
2345
|
+
# responsible for converting it into model inputs.
|
|
2346
|
+
samples: List[Dict[str, Any]] = batch
|
|
2347
|
+
|
|
2348
|
+
inputs = adapter.prepare_inputs(samples)
|
|
2349
|
+
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
|
|
2350
|
+
|
|
2351
|
+
loss = adapter.compute_loss(inputs)
|
|
2352
|
+
|
|
2353
|
+
# Guard against invalid losses to avoid propagating NaNs/Infs
|
|
2354
|
+
if torch.isnan(loss) or torch.isinf(loss):
|
|
2355
|
+
msg = f"Encountered invalid loss at epoch={epoch} step={total_steps + 1}: {loss.item()}"
|
|
2356
|
+
print(msg)
|
|
2357
|
+
logger._log_to_terminal(msg)
|
|
2358
|
+
logger.on_train_end()
|
|
2359
|
+
return False
|
|
2360
|
+
|
|
2361
|
+
loss.backward()
|
|
2362
|
+
|
|
2363
|
+
if (total_steps + 1) % config.gradient_accumulation_steps == 0:
|
|
2364
|
+
torch.nn.utils.clip_grad_norm_(adapter.model.parameters(), config.max_grad_norm) # type: ignore[arg-type]
|
|
2365
|
+
optimizer.step()
|
|
2366
|
+
if lr_scheduler is not None:
|
|
2367
|
+
lr_scheduler.step()
|
|
2368
|
+
optimizer.zero_grad()
|
|
2369
|
+
|
|
2370
|
+
total_steps += 1
|
|
2371
|
+
loss_val = loss.item()
|
|
2372
|
+
|
|
2373
|
+
# Get current learning rate from optimizer
|
|
2374
|
+
current_lr = optimizer.param_groups[0]['lr']
|
|
2375
|
+
|
|
2376
|
+
# Log step
|
|
2377
|
+
logger.on_step(epoch, total_steps, loss_val, current_lr)
|
|
2378
|
+
|
|
2379
|
+
if config.logging_steps and total_steps % config.logging_steps == 0:
|
|
2380
|
+
msg = f"epoch={epoch} step={total_steps} loss={loss_val:.4f} lr={current_lr:.6f}"
|
|
2381
|
+
print(msg)
|
|
2382
|
+
logger._log_to_terminal(msg)
|
|
2383
|
+
|
|
2384
|
+
# Early stopping check
|
|
2385
|
+
if loss_val < config.early_stop_loss:
|
|
2386
|
+
consecutive_low_loss += 1
|
|
2387
|
+
if consecutive_low_loss >= config.early_stop_patience:
|
|
2388
|
+
msg = (
|
|
2389
|
+
f"Early stopping: loss ({loss_val:.6f}) below threshold "
|
|
2390
|
+
f"({config.early_stop_loss}) for {config.early_stop_patience} consecutive steps"
|
|
2391
|
+
)
|
|
2392
|
+
print(msg)
|
|
2393
|
+
logger._log_to_terminal(msg)
|
|
2394
|
+
# Set termination status for dashboard
|
|
2395
|
+
logger.state.termination_status = "auto_low_loss"
|
|
2396
|
+
logger.state.termination_message = (
|
|
2397
|
+
f"Loss reached {loss_val:.6f} (< {config.early_stop_loss}) "
|
|
2398
|
+
f"for {config.early_stop_patience} consecutive steps"
|
|
2399
|
+
)
|
|
2400
|
+
logger.save()
|
|
2401
|
+
early_stopped = True
|
|
2402
|
+
break
|
|
2403
|
+
else:
|
|
2404
|
+
consecutive_low_loss = 0
|
|
2405
|
+
|
|
2406
|
+
# End of epoch
|
|
2407
|
+
logger.on_epoch_end(epoch)
|
|
2408
|
+
|
|
2409
|
+
# Save checkpoint at end of each epoch
|
|
2410
|
+
if config.save_checkpoint_every_epoch:
|
|
2411
|
+
checkpoint_path = Path(config.checkpoint_dir) / f"epoch_{epoch}"
|
|
2412
|
+
checkpoint_path.mkdir(parents=True, exist_ok=True)
|
|
2413
|
+
try:
|
|
2414
|
+
adapter.save_checkpoint(str(checkpoint_path))
|
|
2415
|
+
msg = f"Checkpoint saved to {checkpoint_path}"
|
|
2416
|
+
print(msg)
|
|
2417
|
+
logger._log_to_terminal(msg)
|
|
2418
|
+
except Exception as e:
|
|
2419
|
+
msg = f"Warning: Failed to save checkpoint: {e}"
|
|
2420
|
+
print(msg)
|
|
2421
|
+
logger._log_to_terminal(msg)
|
|
2422
|
+
|
|
2423
|
+
# Run evaluation after each epoch (generates comparison_epoch{N}.html)
|
|
2424
|
+
if config.eval_every_epoch and episode is not None:
|
|
2425
|
+
try:
|
|
2426
|
+
print(f"Running epoch {epoch} evaluation...")
|
|
2427
|
+
run_epoch_evaluation(
|
|
2428
|
+
adapter=adapter,
|
|
2429
|
+
episode=episode,
|
|
2430
|
+
epoch=epoch,
|
|
2431
|
+
config=config,
|
|
2432
|
+
logger=logger,
|
|
2433
|
+
)
|
|
2434
|
+
except Exception as e:
|
|
2435
|
+
print(f"Warning: Epoch evaluation failed: {e}")
|
|
2436
|
+
import traceback
|
|
2437
|
+
traceback.print_exc()
|
|
2438
|
+
|
|
2439
|
+
# Set termination status if not already set (normal completion)
|
|
2440
|
+
if not logger.state.termination_status:
|
|
2441
|
+
logger.state.termination_status = "auto_complete"
|
|
2442
|
+
logger.state.termination_message = f"Training completed all {config.num_train_epochs} epochs"
|
|
2443
|
+
logger.save()
|
|
2444
|
+
|
|
2445
|
+
logger.on_train_end()
|
|
2446
|
+
return True
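
An end-to-end sketch of wiring the loop together; adapter, dataset, and episode construction are project-specific, so those names are placeholders rather than package APIs:

from openadapt_ml.training.trainer import TrainingConfig, TrainingLogger, train_supervised

config = TrainingConfig(num_train_epochs=3, learning_rate=5e-5)
logger = TrainingLogger(config.output_dir, config)

completed = train_supervised(
    adapter=adapter,    # any BaseVLMAdapter implementing prepare_inputs/compute_loss
    dataset=dataset,    # Dataset yielding SFT-style dicts
    config=config,
    logger=logger,
    episode=episode,    # optional: enables per-epoch comparison_epoch{N}.html
)
print("training completed" if completed else "training aborted on invalid loss")
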
|