openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,15 +4,9 @@ import json
4
4
  import time
5
5
  from dataclasses import dataclass, field
6
6
  from pathlib import Path
7
- from typing import Any, Callable, Dict, List, Optional
7
+ from typing import Any, Dict, List
8
8
 
9
- import torch
10
- from torch.optim import Optimizer
11
- from torch.optim.lr_scheduler import LambdaLR
12
- from torch.utils.data import DataLoader, Dataset
13
-
14
- from openadapt_ml.models.base_adapter import BaseVLMAdapter
15
- from openadapt_ml.schemas.sessions import Episode
9
+ from openadapt_ml.schema import ActionType
16
10
  from openadapt_ml.training.shared_ui import (
17
11
  get_shared_header_css as _get_shared_header_css,
18
12
  generate_shared_header_html as _generate_shared_header_html,
@@ -21,6 +15,10 @@ from openadapt_ml.training.shared_ui import (
21
15
  from openadapt_ml.training.viewer import (
22
16
  generate_unified_viewer_from_output_dir,
23
17
  )
18
+ from openadapt_ml.training.benchmark_viewer import (
19
+ _get_azure_jobs_panel_css,
20
+ _get_azure_jobs_panel_html,
21
+ )
24
22
 
25
23
 
26
24
  def setup_job_directory(base_dir: str | Path, job_id: str) -> Path:
@@ -110,12 +108,18 @@ class TrainingConfig:
110
108
  @dataclass
111
109
  class TrainingState:
112
110
  """Tracks training progress for visualization."""
111
+
113
112
  # Job identification
114
113
  job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
115
- hostname: str = field(default_factory=lambda: __import__('socket').gethostname())
114
+ hostname: str = field(default_factory=lambda: __import__("socket").gethostname())
116
115
  capture_path: str = ""
117
116
  config_path: str = ""
118
117
  goal: str = "" # Task goal/description for the training run
118
+ # Model configuration
119
+ model_name: str = "" # e.g. "Qwen/Qwen3-VL-2B-Instruct"
120
+ lora_r: int = 0 # LoRA rank
121
+ lora_alpha: int = 0 # LoRA alpha
122
+ load_in_4bit: bool = False # Quantization
119
123
  # Training progress
120
124
  epoch: int = 0
121
125
  step: int = 0
@@ -139,7 +143,9 @@ class TrainingState:
139
143
  setup_status: str = "" # e.g. "booting", "installing", "training", "complete"
140
144
  setup_logs: List[str] = field(default_factory=list) # Setup progress messages
141
145
  # Termination tracking
142
- termination_status: str = "" # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
146
+ termination_status: str = (
147
+ "" # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
148
+ )
143
149
  termination_message: str = "" # Human-readable termination reason
144
150
 
145
151
  def log_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
@@ -148,33 +154,46 @@ class TrainingState:
148
154
  self.step = step
149
155
  self.loss = loss
150
156
  self.learning_rate = lr
151
- self.losses.append({
152
- "epoch": epoch,
153
- "step": step,
154
- "loss": loss,
155
- "lr": lr,
156
- "time": time.time() - self.start_time,
157
- })
158
-
159
- def log_evaluation(self, epoch: int, sample_idx: int, image_path: str,
160
- human_action: Dict, predicted_action: Dict) -> None:
157
+ self.losses.append(
158
+ {
159
+ "epoch": epoch,
160
+ "step": step,
161
+ "loss": loss,
162
+ "lr": lr,
163
+ "time": time.time() - self.start_time,
164
+ }
165
+ )
166
+
167
+ def log_evaluation(
168
+ self,
169
+ epoch: int,
170
+ sample_idx: int,
171
+ image_path: str,
172
+ human_action: Dict,
173
+ predicted_action: Dict,
174
+ ) -> None:
161
175
  """Log an evaluation sample."""
162
176
  # Calculate distance for click actions
163
177
  distance = 0.0
164
- if human_action.get("type") == "click" and predicted_action.get("type") == "click":
178
+ if (
179
+ human_action.get("type") == "click"
180
+ and predicted_action.get("type") == "click"
181
+ ):
165
182
  hx, hy = human_action.get("x", 0), human_action.get("y", 0)
166
183
  px, py = predicted_action.get("x", 0), predicted_action.get("y", 0)
167
184
  distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
168
185
 
169
- self.evaluations.append({
170
- "epoch": epoch,
171
- "sample_idx": sample_idx,
172
- "image_path": image_path,
173
- "human_action": human_action,
174
- "predicted_action": predicted_action,
175
- "distance": distance,
176
- "correct": distance < 50, # Within 50 pixels is "correct"
177
- })
186
+ self.evaluations.append(
187
+ {
188
+ "epoch": epoch,
189
+ "sample_idx": sample_idx,
190
+ "image_path": image_path,
191
+ "human_action": human_action,
192
+ "predicted_action": predicted_action,
193
+ "distance": distance,
194
+ "correct": distance < 50, # Within 50 pixels is "correct"
195
+ }
196
+ )
178
197
 
179
198
  def to_dict(self) -> Dict[str, Any]:
180
199
  """Convert state to serializable dict."""
@@ -185,9 +204,16 @@ class TrainingState:
185
204
  "capture_path": self.capture_path,
186
205
  "config_path": self.config_path,
187
206
  "goal": self.goal,
207
+ # Model configuration
208
+ "model_name": self.model_name,
209
+ "lora_r": self.lora_r,
210
+ "lora_alpha": self.lora_alpha,
211
+ "load_in_4bit": self.load_in_4bit,
188
212
  "instance_type": self.instance_type,
189
213
  "instance_ip": self.instance_ip,
190
- "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)),
214
+ "started_at": time.strftime(
215
+ "%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)
216
+ ),
191
217
  # Cloud provider info
192
218
  "cloud_provider": self.cloud_provider,
193
219
  "cloud_dashboard_url": self.cloud_dashboard_url,
@@ -227,6 +253,11 @@ class TrainingLogger:
227
253
  cloud_dashboard_url: str = "",
228
254
  cloud_instance_id: str = "",
229
255
  job_id: str = "",
256
+ # Model configuration
257
+ model_name: str = "",
258
+ lora_r: int = 0,
259
+ lora_alpha: int = 0,
260
+ load_in_4bit: bool = False,
230
261
  ):
231
262
  # Generate job_id if not provided
232
263
  if not job_id:
@@ -242,6 +273,10 @@ class TrainingLogger:
242
273
  capture_path=capture_path,
243
274
  config_path=config_path,
244
275
  goal=goal,
276
+ model_name=model_name,
277
+ lora_r=lora_r,
278
+ lora_alpha=lora_alpha,
279
+ load_in_4bit=load_in_4bit,
245
280
  instance_ip=instance_ip,
246
281
  instance_type=instance_type,
247
282
  total_epochs=config.num_train_epochs,
@@ -299,6 +334,7 @@ class TrainingLogger:
299
334
  def _save_config_snapshot(self) -> None:
300
335
  """Save training config snapshot to JSON."""
301
336
  from dataclasses import asdict
337
+
302
338
  config_file = self.output_dir / "config.json"
303
339
  config_dict = asdict(self.config)
304
340
  with open(config_file, "w") as f:
@@ -316,32 +352,45 @@ class TrainingLogger:
316
352
  dashboard_path.write_text(html)
317
353
 
318
354
 
319
- def _generate_termination_status_html(state: TrainingState, is_training_complete: bool) -> str:
355
+ def _generate_termination_status_html(
356
+ state: TrainingState, is_training_complete: bool
357
+ ) -> str:
320
358
  """Generate HTML for termination status section."""
321
359
  # Check if we have termination info
322
360
  if state.termination_status:
323
361
  # Map termination status to colors and icons
324
362
  status_styles = {
325
- "auto_complete": {"color": "#22c55e", "icon": "✓", "label": "Training Complete"},
326
- "auto_low_loss": {"color": "#22c55e", "icon": "✓", "label": "Auto-Stopped (Low Loss)"},
363
+ "auto_complete": {
364
+ "color": "#22c55e",
365
+ "icon": "✓",
366
+ "label": "Training Complete",
367
+ },
368
+ "auto_low_loss": {
369
+ "color": "#22c55e",
370
+ "icon": "✓",
371
+ "label": "Auto-Stopped (Low Loss)",
372
+ },
327
373
  "user_stop": {"color": "#f59e0b", "icon": "■", "label": "Stopped by User"},
328
374
  }
329
- style = status_styles.get(state.termination_status, {"color": "#22c55e", "icon": "✓", "label": "Complete"})
375
+ style = status_styles.get(
376
+ state.termination_status,
377
+ {"color": "#22c55e", "icon": "✓", "label": "Complete"},
378
+ )
330
379
 
331
- return f'''<div style="display: flex; flex-direction: column; gap: 8px;">
332
- <div style="display: flex; align-items: center; gap: 8px; color: {style['color']};">
333
- <span style="font-size: 1.2rem;">{style['icon']}</span>
334
- <span style="font-weight: 600;">{style['label']}</span>
380
+ return f"""<div style="display: flex; flex-direction: column; gap: 8px;">
381
+ <div style="display: flex; align-items: center; gap: 8px; color: {style["color"]};">
382
+ <span style="font-size: 1.2rem;">{style["icon"]}</span>
383
+ <span style="font-weight: 600;">{style["label"]}</span>
335
384
  </div>
336
- {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ''}
337
- </div>'''
385
+ {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ""}
386
+ </div>"""
338
387
  elif is_training_complete:
339
- return '''<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
388
+ return """<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
340
389
  <span style="font-size: 1.2rem;">✓</span>
341
390
  <span style="font-weight: 600;">Training Complete</span>
342
- </div>'''
391
+ </div>"""
343
392
  else:
344
- return '''<button id="stop-training-btn" onclick="stopTraining()" style="
393
+ return """<button id="stop-training-btn" onclick="stopTraining()" style="
345
394
  background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
346
395
  color: white;
347
396
  border: none;
@@ -357,53 +406,65 @@ def _generate_termination_status_html(state: TrainingState, is_training_complete
357
406
  ">
358
407
  <span style="font-size: 1.1rem;">■</span> Stop Training
359
408
  </button>
360
- <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>'''
409
+ <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>"""
361
410
 
362
411
 
363
412
  def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
364
413
  """Generate an HTML dashboard for training visualization."""
365
414
  losses_json = json.dumps(state.losses)
366
415
  # Use stored elapsed_time if available (historical data), otherwise calculate
367
- elapsed = state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
416
+ elapsed = (
417
+ state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
418
+ )
368
419
  elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s"
369
420
 
370
421
  # Calculate stats
371
422
  if state.losses:
372
- min_loss = min(l["loss"] for l in state.losses)
373
- avg_loss = sum(l["loss"] for l in state.losses) / len(state.losses)
423
+ min_loss = min(loss["loss"] for loss in state.losses)
424
+ sum(loss["loss"] for loss in state.losses) / len(state.losses)
374
425
  recent_losses = state.losses[-10:] if len(state.losses) >= 10 else state.losses
375
- recent_avg = sum(l["loss"] for l in recent_losses) / len(recent_losses)
426
+ recent_avg = sum(loss["loss"] for loss in recent_losses) / len(recent_losses)
376
427
  # Calculate step times
377
428
  step_times = []
378
429
  for i in range(1, len(state.losses)):
379
- step_times.append(state.losses[i]["time"] - state.losses[i-1]["time"])
430
+ step_times.append(state.losses[i]["time"] - state.losses[i - 1]["time"])
380
431
  avg_step_time = sum(step_times) / len(step_times) if step_times else 0
381
432
  # Loss by epoch
382
433
  epoch_losses: dict = {}
383
- for l in state.losses:
384
- ep = l["epoch"]
434
+ for loss in state.losses:
435
+ ep = loss["epoch"]
385
436
  if ep not in epoch_losses:
386
437
  epoch_losses[ep] = []
387
- epoch_losses[ep].append(l["loss"])
388
- epoch_avg = {ep: sum(losses)/len(losses) for ep, losses in epoch_losses.items()}
438
+ epoch_losses[ep].append(loss["loss"])
439
+ epoch_avg = {
440
+ ep: sum(losses) / len(losses) for ep, losses in epoch_losses.items()
441
+ }
389
442
  # Estimate ETA
390
443
  # Steps per epoch = steps in completed epochs / completed epochs
391
444
  completed_epochs = state.epoch
392
- steps_in_completed = sum(1 for l in state.losses if l["epoch"] < completed_epochs)
445
+ steps_in_completed = sum(
446
+ 1 for loss in state.losses if loss["epoch"] < completed_epochs
447
+ )
393
448
  if completed_epochs > 0 and steps_in_completed > 0:
394
449
  steps_per_epoch = steps_in_completed / completed_epochs
395
450
  else:
396
451
  # Estimate from current epoch progress
397
- steps_per_epoch = len(state.losses) / (state.epoch + 1) if state.epoch >= 0 else len(state.losses)
452
+ steps_per_epoch = (
453
+ len(state.losses) / (state.epoch + 1)
454
+ if state.epoch >= 0
455
+ else len(state.losses)
456
+ )
398
457
 
399
- total_epochs = state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
458
+ total_epochs = (
459
+ state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
460
+ )
400
461
  total_steps_estimate = steps_per_epoch * total_epochs
401
462
  remaining_steps = max(0, total_steps_estimate - len(state.losses))
402
463
  eta_seconds = remaining_steps * avg_step_time if avg_step_time > 0 else 0
403
464
  # Check if training is complete (all steps done)
404
465
  is_training_complete = remaining_steps == 0 and len(state.losses) > 0
405
466
  else:
406
- min_loss = avg_loss = recent_avg = avg_step_time = 0.0
467
+ min_loss = recent_avg = avg_step_time = 0.0
407
468
  epoch_avg = {}
408
469
  eta_seconds = 0
409
470
  steps_per_epoch = 0
@@ -414,10 +475,9 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
414
475
  epoch_avg_json = json.dumps(list(epoch_avg.items()))
415
476
 
416
477
  # Generate comparison viewer preview if capture path available
417
- comparison_viewer_path = ""
418
478
  if state.capture_path:
419
479
  try:
420
- from openadapt_ml.scripts.compare import generate_comparison_html, generate_comparison_data
480
+ from openadapt_ml.scripts.compare import generate_comparison_html
421
481
  from openadapt_ml.ingest.capture import capture_to_episode
422
482
 
423
483
  capture_path = Path(state.capture_path)
@@ -428,14 +488,20 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
428
488
  # Generate comparison data with null predictions (shows "— No prediction")
429
489
  comparison_data = []
430
490
  for i, step in enumerate(episode.steps):
491
+ # Extract normalized coordinates if available
492
+ action_x, action_y = None, None
493
+ if step.action.normalized_coordinates:
494
+ action_x, action_y = step.action.normalized_coordinates
431
495
  step_data = {
432
496
  "index": i,
433
- "time": step.t,
434
- "image_path": step.observation.image_path,
497
+ "time": step.step_index,
498
+ "image_path": step.observation.screenshot_path,
435
499
  "human_action": {
436
- "type": step.action.type,
437
- "x": step.action.x,
438
- "y": step.action.y,
500
+ "type": step.action.type.value
501
+ if isinstance(step.action.type, ActionType)
502
+ else step.action.type,
503
+ "x": action_x,
504
+ "y": action_y,
439
505
  "text": step.action.text,
440
506
  },
441
507
  "predicted_action": None, # Shows "— No prediction" in viewer
@@ -444,15 +510,21 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
444
510
  comparison_data.append(step_data)
445
511
 
446
512
  # Generate comparison HTML
447
- output_dir = Path(config.output_dir) if hasattr(config, 'output_dir') else Path("training_output")
513
+ output_dir = (
514
+ Path(config.output_dir)
515
+ if hasattr(config, "output_dir")
516
+ else Path("training_output")
517
+ )
448
518
  output_dir.mkdir(parents=True, exist_ok=True)
449
519
  comparison_output = output_dir / "comparison_preview.html"
450
- generate_comparison_html(capture_path, episode, comparison_data, comparison_output)
451
- comparison_viewer_path = str(comparison_output.name) # Relative path
452
- except Exception as e:
520
+ generate_comparison_html(
521
+ capture_path, episode, comparison_data, comparison_output
522
+ )
523
+ str(comparison_output.name) # Relative path
524
+ except Exception:
453
525
  pass # Fail silently if comparison viewer can't be generated
454
526
 
455
- html = f'''<!DOCTYPE html>
527
+ html = f"""<!DOCTYPE html>
456
528
  <html lang="en">
457
529
  <head>
458
530
  <meta charset="UTF-8">
@@ -596,6 +668,42 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
596
668
  .setup-log-line.current {{
597
669
  color: var(--accent);
598
670
  }}
671
+ .config-panel {{
672
+ background: var(--bg-secondary);
673
+ border: 1px solid var(--border-color);
674
+ border-radius: 12px;
675
+ padding: 16px 20px;
676
+ margin-bottom: 24px;
677
+ }}
678
+ .config-grid {{
679
+ display: grid;
680
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
681
+ gap: 16px;
682
+ }}
683
+ .config-item {{
684
+ display: flex;
685
+ flex-direction: column;
686
+ gap: 4px;
687
+ }}
688
+ .config-label {{
689
+ font-size: 0.7rem;
690
+ color: var(--text-secondary);
691
+ text-transform: uppercase;
692
+ letter-spacing: 0.5px;
693
+ }}
694
+ .config-value {{
695
+ font-family: "SF Mono", Monaco, monospace;
696
+ font-size: 0.85rem;
697
+ color: var(--text-primary);
698
+ }}
699
+ .config-value.model {{
700
+ color: var(--accent);
701
+ }}
702
+ .config-value.goal {{
703
+ font-family: -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
704
+ font-size: 0.8rem;
705
+ opacity: 0.9;
706
+ }}
599
707
  .status {{
600
708
  display: flex;
601
709
  align-items: center;
@@ -754,6 +862,8 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
754
862
  }}
755
863
  /* Shared header styles (injected from _get_shared_header_css) */
756
864
  {_get_shared_header_css()}
865
+ /* Azure ML Jobs panel styles */
866
+ {_get_azure_jobs_panel_css()}
757
867
  .eval-panel {{
758
868
  background: var(--bg-secondary);
759
869
  border: 1px solid var(--border-color);
@@ -1073,10 +1183,10 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1073
1183
  <div class="container">
1074
1184
  <header>
1075
1185
  <div>
1076
- <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ''}</h1>
1186
+ <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ""}</h1>
1077
1187
  <div class="job-info" id="job-info">
1078
- <span class="job-host">{state.hostname or 'stub-local'} @ {state.instance_ip or '127.0.0.1'}</span>
1079
- {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ''}
1188
+ <span class="job-host">{state.hostname or "stub-local"} @ {state.instance_ip or "127.0.0.1"}</span>
1189
+ {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ""}
1080
1190
  </div>
1081
1191
  </div>
1082
1192
  <div class="status" id="status">
@@ -1085,13 +1195,40 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1085
1195
  </div>
1086
1196
  </header>
1087
1197
 
1088
- <div class="setup-panel{' hidden' if not state.setup_logs else ''}" id="setup-panel">
1198
+ <div class="setup-panel{" hidden" if not state.setup_logs else ""}" id="setup-panel">
1089
1199
  <div class="setup-header">
1090
1200
  <h2>Setup Progress</h2>
1091
- <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or 'initializing'}</span>
1201
+ <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or "initializing"}</span>
1092
1202
  </div>
1093
1203
  <div class="setup-logs" id="setup-logs">
1094
- {''.join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
1204
+ {"".join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
1205
+ </div>
1206
+ </div>
1207
+
1208
+ {_get_azure_jobs_panel_html()}
1209
+
1210
+ <div class="config-panel" id="config-panel">
1211
+ <div class="config-grid">
1212
+ <div class="config-item">
1213
+ <span class="config-label">Model</span>
1214
+ <span class="config-value model" id="config-model">{state.model_name or "Not specified"}</span>
1215
+ </div>
1216
+ <div class="config-item">
1217
+ <span class="config-label">Goal</span>
1218
+ <span class="config-value goal" id="config-goal">{state.goal or "Not specified"}</span>
1219
+ </div>
1220
+ <div class="config-item">
1221
+ <span class="config-label">LoRA</span>
1222
+ <span class="config-value" id="config-lora">{f"r={state.lora_r}, α={state.lora_alpha}" if state.lora_r else "Not specified"}</span>
1223
+ </div>
1224
+ <div class="config-item">
1225
+ <span class="config-label">Quantization</span>
1226
+ <span class="config-value" id="config-quant">{"4-bit" if state.load_in_4bit else "None"}</span>
1227
+ </div>
1228
+ <div class="config-item">
1229
+ <span class="config-label">Config</span>
1230
+ <span class="config-value" id="config-path">{state.config_path or "Not specified"}</span>
1231
+ </div>
1095
1232
  </div>
1096
1233
  </div>
1097
1234
 
@@ -1437,7 +1574,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1437
1574
  let etaSeconds = {eta_seconds};
1438
1575
  let avgStepTime = {avg_step_time};
1439
1576
  let remainingSteps = {remaining_steps};
1440
- let isTrainingComplete = {'true' if is_training_complete else 'false'};
1577
+ let isTrainingComplete = {"true" if is_training_complete else "false"};
1441
1578
 
1442
1579
  // Auto-stop when loss <= threshold (INVARIANT: training should stop when loss <= 1.0)
1443
1580
  const AUTO_STOP_LOSS_THRESHOLD = 1.0;
@@ -1519,6 +1656,28 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1519
1656
  }}
1520
1657
  }}
1521
1658
 
1659
+ // Update config panel
1660
+ const configModel = document.getElementById('config-model');
1661
+ const configGoal = document.getElementById('config-goal');
1662
+ const configLora = document.getElementById('config-lora');
1663
+ const configQuant = document.getElementById('config-quant');
1664
+ const configPath = document.getElementById('config-path');
1665
+ if (configModel && data.model_name) {{
1666
+ configModel.textContent = data.model_name;
1667
+ }}
1668
+ if (configGoal && data.goal) {{
1669
+ configGoal.textContent = data.goal;
1670
+ }}
1671
+ if (configLora && (data.lora_r || data.lora_alpha)) {{
1672
+ configLora.textContent = `r=${{data.lora_r || 0}}, α=${{data.lora_alpha || 0}}`;
1673
+ }}
1674
+ if (configQuant) {{
1675
+ configQuant.textContent = data.load_in_4bit ? '4-bit' : 'None';
1676
+ }}
1677
+ if (configPath && data.config_path) {{
1678
+ configPath.textContent = data.config_path;
1679
+ }}
1680
+
1522
1681
  // Update setup panel if setup logs present
1523
1682
  if (data.setup_logs && data.setup_logs.length > 0) {{
1524
1683
  const setupPanel = document.getElementById('setup-panel');
@@ -1914,7 +2073,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1914
2073
  setInterval(updateStatusIndicator, 1000); // Update LIVE/STALE indicator every second
1915
2074
  </script>
1916
2075
  </body>
1917
- </html>'''
2076
+ </html>"""
1918
2077
  return html
1919
2078
 
1920
2079
 
@@ -1952,6 +2111,7 @@ def regenerate_all_dashboards(output_dir: str | Path) -> list[Path]:
1952
2111
  except Exception as e:
1953
2112
  print(f"Warning: Failed to generate unified viewer: {e}")
1954
2113
  import traceback
2114
+
1955
2115
  traceback.print_exc()
1956
2116
 
1957
2117
  return regenerated
@@ -1994,7 +2154,9 @@ def regenerate_local_dashboard(
1994
2154
  state = TrainingState(
1995
2155
  job_id=data.get("job_id", "unknown"),
1996
2156
  hostname=data.get("hostname", ""),
1997
- capture_path=str(capture_path) if capture_path else data.get("capture_path", ""),
2157
+ capture_path=str(capture_path)
2158
+ if capture_path
2159
+ else data.get("capture_path", ""),
1998
2160
  config_path=data.get("config_path", ""),
1999
2161
  epoch=data.get("epoch", 0),
2000
2162
  step=data.get("step", 0),
@@ -2037,30 +2199,33 @@ def regenerate_local_dashboard(
2037
2199
  if training_status == "COMPLETED":
2038
2200
  html = html.replace(
2039
2201
  '<div class="status" id="status">',
2040
- '<div class="status complete" id="status">'
2202
+ '<div class="status complete" id="status">',
2041
2203
  )
2042
2204
  html = html.replace(
2043
2205
  '<span id="status-text">Training in progress</span>',
2044
- '<span id="status-text">COMPLETED</span>'
2206
+ '<span id="status-text">COMPLETED</span>',
2045
2207
  )
2046
2208
  elif training_status == "STOPPED":
2047
2209
  html = html.replace(
2048
- '<div class="status" id="status">',
2049
- '<div class="status stale" id="status">'
2210
+ '<div class="status" id="status">', '<div class="status stale" id="status">'
2050
2211
  )
2051
2212
  html = html.replace(
2052
2213
  '<span id="status-text">Training in progress</span>',
2053
- '<span id="status-text">STOPPED (Epoch {}/{})'.format(current_epoch + 1, total_epochs) + '</span>'
2214
+ '<span id="status-text">STOPPED (Epoch {}/{})'.format(
2215
+ current_epoch + 1, total_epochs
2216
+ )
2217
+ + "</span>",
2054
2218
  )
2055
2219
 
2056
2220
  # Fix ETA display for completed/stopped training
2057
2221
  import re
2222
+
2058
2223
  if training_status in ("COMPLETED", "STOPPED"):
2059
2224
  # Replace "calculating..." with appropriate status
2060
2225
  html = re.sub(
2061
2226
  r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
2062
- r'\1—\2' if training_status == "STOPPED" else r'\1complete\2',
2063
- html
2227
+ r"\1—\2" if training_status == "STOPPED" else r"\1complete\2",
2228
+ html,
2064
2229
  )
2065
2230
 
2066
2231
  # Replace dynamic nav with static unified header
@@ -2071,20 +2236,20 @@ def regenerate_local_dashboard(
2071
2236
  # This is critical for file:// protocol where fetch() doesn't work
2072
2237
  html = html.replace(
2073
2238
  "setInterval(fetchAndUpdate, 3000);",
2074
- "// fetchAndUpdate disabled for static dashboard"
2239
+ "// fetchAndUpdate disabled for static dashboard",
2075
2240
  )
2076
2241
  html = html.replace(
2077
2242
  "setInterval(updateElapsedDisplay, 1000);",
2078
- "// updateElapsedDisplay disabled for static dashboard"
2243
+ "// updateElapsedDisplay disabled for static dashboard",
2079
2244
  )
2080
2245
  html = html.replace(
2081
2246
  "setInterval(updateStatusIndicator, 1000);",
2082
- "// updateStatusIndicator disabled for static dashboard"
2247
+ "// updateStatusIndicator disabled for static dashboard",
2083
2248
  )
2084
2249
  # CRITICAL: Disable discoverDashboards() - it overwrites static nav on file:// protocol
2085
2250
  html = html.replace(
2086
2251
  "discoverDashboards();",
2087
- "// discoverDashboards disabled - using static nav for file:// protocol"
2252
+ "// discoverDashboards disabled - using static nav for file:// protocol",
2088
2253
  )
2089
2254
 
2090
2255
  # Write output
@@ -2093,354 +2258,3 @@ def regenerate_local_dashboard(
2093
2258
  print(f"Regenerated dashboard: {dashboard_path}")
2094
2259
 
2095
2260
  return dashboard_path
2096
-
2097
-
2098
- def run_epoch_evaluation(
2099
- adapter: BaseVLMAdapter,
2100
- episode: Episode,
2101
- epoch: int,
2102
- config: TrainingConfig,
2103
- logger: "TrainingLogger",
2104
- sample_indices: Optional[List[int]] = None,
2105
- ) -> Path:
2106
- """Run inference evaluation on sample steps after an epoch.
2107
-
2108
- This generates a comparison_epoch{N}.html file showing human vs predicted actions.
2109
-
2110
- Args:
2111
- adapter: Trained adapter to use for inference
2112
- episode: Episode with steps to evaluate
2113
- epoch: Current epoch number
2114
- config: Training configuration
2115
- logger: Training logger for state tracking
2116
- sample_indices: Specific step indices to evaluate (default: evenly spaced)
2117
-
2118
- Returns:
2119
- Path to generated comparison HTML file
2120
- """
2121
- from openadapt_ml.scripts.compare import generate_comparison_html, predict_action, format_action
2122
-
2123
- output_dir = Path(config.output_dir)
2124
- output_dir.mkdir(parents=True, exist_ok=True)
2125
-
2126
- # Select sample indices if not provided
2127
- num_samples = min(config.eval_samples, len(episode.steps))
2128
- if sample_indices is None:
2129
- if num_samples >= len(episode.steps):
2130
- sample_indices = list(range(len(episode.steps)))
2131
- else:
2132
- # Evenly space samples across the episode
2133
- step_size = len(episode.steps) // num_samples
2134
- sample_indices = [i * step_size for i in range(num_samples)]
2135
-
2136
- print(f" Running inference on {len(sample_indices)} sample steps...")
2137
-
2138
- # Switch adapter to eval mode
2139
- adapter.eval()
2140
-
2141
- comparison_data = []
2142
- action_history: List[str] = []
2143
- total_steps = len(episode.steps)
2144
-
2145
- for i, step in enumerate(episode.steps):
2146
- step_data = {
2147
- "index": i,
2148
- "time": step.t,
2149
- "image_path": step.observation.image_path,
2150
- "human_action": {
2151
- "type": step.action.type,
2152
- "x": step.action.x,
2153
- "y": step.action.y,
2154
- "text": step.action.text,
2155
- },
2156
- "predicted_action": None,
2157
- "match": None,
2158
- }
2159
-
2160
- # Only run inference on selected samples (for speed)
2161
- if i in sample_indices and step.observation.image_path:
2162
- try:
2163
- predicted = predict_action(
2164
- adapter,
2165
- step.observation.image_path,
2166
- episode.goal,
2167
- step_index=i,
2168
- total_steps=total_steps,
2169
- action_history=action_history.copy(),
2170
- )
2171
- step_data["predicted_action"] = predicted
2172
-
2173
- # Check match and calculate distance
2174
- if predicted and predicted.get("type") == step.action.type:
2175
- step_data["match"] = True
2176
-
2177
- # Calculate distance for click actions
2178
- if step.action.type == "click":
2179
- hx, hy = step.action.x or 0, step.action.y or 0
2180
- px, py = predicted.get("x", 0), predicted.get("y", 0)
2181
- distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
2182
-
2183
- # Log evaluation to training state
2184
- logger.state.log_evaluation(
2185
- epoch=epoch,
2186
- sample_idx=i,
2187
- image_path=step.observation.image_path,
2188
- human_action=step_data["human_action"],
2189
- predicted_action=predicted,
2190
- )
2191
- else:
2192
- step_data["match"] = False
2193
-
2194
- print(f" Step {i}: {step.action.type} -> {predicted.get('type') if predicted else 'none'}")
2195
-
2196
- except Exception as e:
2197
- print(f" Step {i}: inference failed - {e}")
2198
-
2199
- # Build action history for context
2200
- action_history.append(format_action(step.action, use_som=False))
2201
- comparison_data.append(step_data)
2202
-
2203
- # Switch back to train mode
2204
- adapter.train()
2205
-
2206
- # Generate comparison HTML
2207
- output_path = output_dir / f"comparison_epoch{epoch}.html"
2208
- capture_path = Path(logger.state.capture_path) if logger.state.capture_path else Path(".")
2209
-
2210
- generate_comparison_html(capture_path, episode, comparison_data, output_path)
2211
- print(f" Comparison saved: {output_path}")
2212
-
2213
- # Also regenerate all dashboards to update navigation
2214
- regenerate_all_dashboards(output_dir)
2215
-
2216
- return output_path
2217
-
2218
-
2219
- def _create_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
2220
- # Use an identity collate_fn so that each batch is a List[Dict], matching
2221
- # the expectations of adapters that operate on SFT-style samples.
2222
- return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
2223
-
2224
-
2225
- def _create_lr_scheduler(
2226
- optimizer: Optimizer,
2227
- config: TrainingConfig,
2228
- num_training_steps: int,
2229
- ) -> Optional[LambdaLR]:
2230
- """Create learning rate scheduler based on config.
2231
-
2232
- Args:
2233
- optimizer: The optimizer to schedule.
2234
- config: Training configuration with lr_scheduler_type and warmup_ratio.
2235
- num_training_steps: Total number of training steps.
2236
-
2237
- Returns:
2238
- LambdaLR scheduler or None if scheduler_type is "none" or "constant".
2239
- """
2240
- scheduler_type = config.lr_scheduler_type.lower()
2241
-
2242
- if scheduler_type in ("none", "constant"):
2243
- return None
2244
-
2245
- num_warmup_steps = int(num_training_steps * config.warmup_ratio)
2246
-
2247
- if scheduler_type == "linear":
2248
- def lr_lambda(current_step: int) -> float:
2249
- if current_step < num_warmup_steps:
2250
- # Linear warmup
2251
- return float(current_step) / float(max(1, num_warmup_steps))
2252
- # Linear decay
2253
- return max(
2254
- 0.0,
2255
- float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
2256
- )
2257
- elif scheduler_type == "cosine":
2258
- import math
2259
- def lr_lambda(current_step: int) -> float:
2260
- if current_step < num_warmup_steps:
2261
- # Linear warmup
2262
- return float(current_step) / float(max(1, num_warmup_steps))
2263
- # Cosine decay
2264
- progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
2265
- return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
2266
- else:
2267
- raise ValueError(f"Unknown lr_scheduler_type: {scheduler_type}. Use 'linear', 'cosine', 'constant', or 'none'.")
2268
-
2269
- return LambdaLR(optimizer, lr_lambda)
2270
-
2271
-
2272
- def train_supervised(
2273
- adapter: BaseVLMAdapter,
2274
- dataset: Dataset,
2275
- config: TrainingConfig,
2276
- optimizer: Optional[Optimizer] = None,
2277
- logger: Optional[TrainingLogger] = None,
2278
- episode: Optional[Episode] = None,
2279
- ) -> bool:
2280
- """Minimal supervised training loop skeleton.
2281
-
2282
- This assumes that `adapter.prepare_inputs` and `adapter.compute_loss` are
2283
- implemented. It will raise if those methods are not implemented.
2284
-
2285
- Args:
2286
- adapter: VLM adapter to train.
2287
- dataset: Training dataset.
2288
- config: Training configuration.
2289
- optimizer: Optional optimizer (default: AdamW).
2290
- logger: Optional training logger for visualization.
2291
- episode: Optional episode for periodic evaluation (generates comparison_epoch{N}.html).
2292
-
2293
- Returns:
2294
- True if training completed successfully, False if aborted due to NaN/Inf loss.
2295
- """
2296
-
2297
- device = adapter.device # type: ignore[attr-defined]
2298
- dataloader = _create_dataloader(dataset, batch_size=config.per_device_train_batch_size)
2299
-
2300
- if optimizer is None:
2301
- optimizer = torch.optim.AdamW(
2302
- adapter.model.parameters(), # type: ignore[arg-type]
2303
- lr=config.learning_rate,
2304
- weight_decay=config.weight_decay,
2305
- )
2306
-
2307
- # Create logger if not provided
2308
- if logger is None:
2309
- logger = TrainingLogger(config.output_dir, config)
2310
-
2311
- # Calculate total training steps for scheduler
2312
- num_training_steps = len(dataloader) * config.num_train_epochs // config.gradient_accumulation_steps
2313
-
2314
- # Create learning rate scheduler
2315
- lr_scheduler = _create_lr_scheduler(optimizer, config, num_training_steps)
2316
-
2317
- total_steps = 0
2318
- adapter.train()
2319
-
2320
- # Early stopping tracking
2321
- consecutive_low_loss = 0
2322
- early_stopped = False
2323
- user_stopped = False
2324
-
2325
- for epoch in range(config.num_train_epochs):
2326
- if early_stopped or user_stopped:
2327
- break
2328
-
2329
- for _, batch in enumerate(dataloader):
2330
- # Check for stop signal from dashboard
2331
- stop_file = Path(config.output_dir) / "STOP_TRAINING"
2332
- if stop_file.exists():
2333
- msg = "Stop signal received from dashboard. Stopping training..."
2334
- print(msg)
2335
- logger._log_to_terminal(msg)
2336
- # Set termination status for dashboard
2337
- logger.state.termination_status = "user_stop"
2338
- logger.state.termination_message = "Training stopped by user via dashboard"
2339
- logger.save()
2340
- user_stopped = True
2341
- stop_file.unlink() # Remove signal file
2342
- break
2343
-
2344
- # Batch is a List[Dict[str, Any]] of SFT-style samples; adapter is
2345
- # responsible for converting it into model inputs.
2346
- samples: List[Dict[str, Any]] = batch
2347
-
2348
- inputs = adapter.prepare_inputs(samples)
2349
- inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
2350
-
2351
- loss = adapter.compute_loss(inputs)
2352
-
2353
- # Guard against invalid losses to avoid propagating NaNs/Infs
2354
- if torch.isnan(loss) or torch.isinf(loss):
2355
- msg = f"Encountered invalid loss at epoch={epoch} step={total_steps + 1}: {loss.item()}"
2356
- print(msg)
2357
- logger._log_to_terminal(msg)
2358
- logger.on_train_end()
2359
- return False
2360
-
2361
- loss.backward()
2362
-
2363
- if (total_steps + 1) % config.gradient_accumulation_steps == 0:
2364
- torch.nn.utils.clip_grad_norm_(adapter.model.parameters(), config.max_grad_norm) # type: ignore[arg-type]
2365
- optimizer.step()
2366
- if lr_scheduler is not None:
2367
- lr_scheduler.step()
2368
- optimizer.zero_grad()
2369
-
2370
- total_steps += 1
2371
- loss_val = loss.item()
2372
-
2373
- # Get current learning rate from optimizer
2374
- current_lr = optimizer.param_groups[0]['lr']
2375
-
2376
- # Log step
2377
- logger.on_step(epoch, total_steps, loss_val, current_lr)
2378
-
2379
- if config.logging_steps and total_steps % config.logging_steps == 0:
2380
- msg = f"epoch={epoch} step={total_steps} loss={loss_val:.4f} lr={current_lr:.6f}"
2381
- print(msg)
2382
- logger._log_to_terminal(msg)
2383
-
2384
- # Early stopping check
2385
- if loss_val < config.early_stop_loss:
2386
- consecutive_low_loss += 1
2387
- if consecutive_low_loss >= config.early_stop_patience:
2388
- msg = (
2389
- f"Early stopping: loss ({loss_val:.6f}) below threshold "
2390
- f"({config.early_stop_loss}) for {config.early_stop_patience} consecutive steps"
2391
- )
2392
- print(msg)
2393
- logger._log_to_terminal(msg)
2394
- # Set termination status for dashboard
2395
- logger.state.termination_status = "auto_low_loss"
2396
- logger.state.termination_message = (
2397
- f"Loss reached {loss_val:.6f} (< {config.early_stop_loss}) "
2398
- f"for {config.early_stop_patience} consecutive steps"
2399
- )
2400
- logger.save()
2401
- early_stopped = True
2402
- break
2403
- else:
2404
- consecutive_low_loss = 0
2405
-
2406
- # End of epoch
2407
- logger.on_epoch_end(epoch)
2408
-
2409
- # Save checkpoint at end of each epoch
2410
- if config.save_checkpoint_every_epoch:
2411
- checkpoint_path = Path(config.checkpoint_dir) / f"epoch_{epoch}"
2412
- checkpoint_path.mkdir(parents=True, exist_ok=True)
2413
- try:
2414
- adapter.save_checkpoint(str(checkpoint_path))
2415
- msg = f"Checkpoint saved to {checkpoint_path}"
2416
- print(msg)
2417
- logger._log_to_terminal(msg)
2418
- except Exception as e:
2419
- msg = f"Warning: Failed to save checkpoint: {e}"
2420
- print(msg)
2421
- logger._log_to_terminal(msg)
2422
-
2423
- # Run evaluation after each epoch (generates comparison_epoch{N}.html)
2424
- if config.eval_every_epoch and episode is not None:
2425
- try:
2426
- print(f"Running epoch {epoch} evaluation...")
2427
- run_epoch_evaluation(
2428
- adapter=adapter,
2429
- episode=episode,
2430
- epoch=epoch,
2431
- config=config,
2432
- logger=logger,
2433
- )
2434
- except Exception as e:
2435
- print(f"Warning: Epoch evaluation failed: {e}")
2436
- import traceback
2437
- traceback.print_exc()
2438
-
2439
- # Set termination status if not already set (normal completion)
2440
- if not logger.state.termination_status:
2441
- logger.state.termination_status = "auto_complete"
2442
- logger.state.termination_message = f"Training completed all {config.num_train_epochs} epochs"
2443
- logger.save()
2444
-
2445
- logger.on_train_end()
2446
- return True