openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/training/trainer.py (+120 -363)

@@ -6,13 +6,7 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional
 
-import torch
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LambdaLR
-from torch.utils.data import DataLoader, Dataset
-
-from openadapt_ml.models.base_adapter import BaseVLMAdapter
-from openadapt_ml.schemas.sessions import Episode
+from openadapt_ml.schema import Episode, Step, Action, ActionType
 from openadapt_ml.training.shared_ui import (
     get_shared_header_css as _get_shared_header_css,
     generate_shared_header_html as _generate_shared_header_html,
@@ -21,6 +15,10 @@ from openadapt_ml.training.shared_ui import (
 from openadapt_ml.training.viewer import (
     generate_unified_viewer_from_output_dir,
 )
+from openadapt_ml.training.benchmark_viewer import (
+    _get_azure_jobs_panel_css,
+    _get_azure_jobs_panel_html,
+)
 
 
 def setup_job_directory(base_dir: str | Path, job_id: str) -> Path:
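
Both import hunks belong to the same migration: torch and the old flat schema (openadapt_ml.schemas.sessions) drop out of trainer.py, and the typed schema package added in this release (openadapt_ml/schema/, entries 46-48 in the file list) takes their place. A minimal read-side sketch of the new schema, using only attribute names that appear in this diff; the constructor signatures live in openadapt_ml/schema/episode.py and are not shown here:

    # Sketch under stated assumptions: attribute names are taken from the
    # hunks in this diff; anything else about the schema API is not shown.
    from openadapt_ml.schema import Episode, ActionType

    def summarize(episode: Episode) -> None:
        for step in episode.steps:
            # 0.1.0: step.t, step.observation.image_path, step.action.x / .y
            # 0.2.0: step.step_index, step.observation.screenshot_path,
            #        step.action.normalized_coordinates (a pair, or None)
            kind = (
                step.action.type.value
                if isinstance(step.action.type, ActionType)
                else step.action.type
            )
            coords = step.action.normalized_coordinates or (None, None)
            print(step.step_index, step.observation.screenshot_path, kind, coords)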
@@ -116,6 +114,11 @@ class TrainingState:
     capture_path: str = ""
     config_path: str = ""
     goal: str = ""  # Task goal/description for the training run
+    # Model configuration
+    model_name: str = ""  # e.g. "Qwen/Qwen3-VL-2B-Instruct"
+    lora_r: int = 0  # LoRA rank
+    lora_alpha: int = 0  # LoRA alpha
+    load_in_4bit: bool = False  # Quantization
     # Training progress
     epoch: int = 0
     step: int = 0
@@ -185,6 +188,11 @@ class TrainingState:
             "capture_path": self.capture_path,
             "config_path": self.config_path,
             "goal": self.goal,
+            # Model configuration
+            "model_name": self.model_name,
+            "lora_r": self.lora_r,
+            "lora_alpha": self.lora_alpha,
+            "load_in_4bit": self.load_in_4bit,
             "instance_type": self.instance_type,
             "instance_ip": self.instance_ip,
             "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)),
@@ -227,6 +235,11 @@ class TrainingLogger:
         cloud_dashboard_url: str = "",
         cloud_instance_id: str = "",
         job_id: str = "",
+        # Model configuration
+        model_name: str = "",
+        lora_r: int = 0,
+        lora_alpha: int = 0,
+        load_in_4bit: bool = False,
     ):
         # Generate job_id if not provided
         if not job_id:
@@ -242,6 +255,10 @@
             capture_path=capture_path,
             config_path=config_path,
             goal=goal,
+            model_name=model_name,
+            lora_r=lora_r,
+            lora_alpha=lora_alpha,
+            load_in_4bit=load_in_4bit,
             instance_ip=instance_ip,
             instance_type=instance_type,
             total_epochs=config.num_train_epochs,
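
Taken together, the four hunks above thread one set of model-configuration fields (model_name, lora_r, lora_alpha, load_in_4bit) from the TrainingLogger constructor into TrainingState and its serialized dict. A hedged construction sketch; the positional arguments follow the removed call TrainingLogger(config.output_dir, config) further down, and every value is an illustrative placeholder rather than a documented default:

    # Illustrative only -- values are placeholders, not package defaults.
    logger = TrainingLogger(
        config.output_dir,
        config,
        goal="Open Settings and enable dark mode",  # placeholder task goal
        model_name="Qwen/Qwen3-VL-2B-Instruct",     # example from the field comment
        lora_r=16,                                  # illustrative LoRA rank
        lora_alpha=32,                              # illustrative LoRA alpha
        load_in_4bit=True,                          # surfaces as "4-bit" in the dashboard
    )

The same four keys then appear in the polled state JSON and in the dashboard's new config panel further below.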
@@ -428,14 +445,18 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
     # Generate comparison data with null predictions (shows "— No prediction")
     comparison_data = []
     for i, step in enumerate(episode.steps):
+        # Extract normalized coordinates if available
+        action_x, action_y = None, None
+        if step.action.normalized_coordinates:
+            action_x, action_y = step.action.normalized_coordinates
         step_data = {
             "index": i,
-            "time": step.t,
-            "image_path": step.observation.image_path,
+            "time": step.step_index,
+            "image_path": step.observation.screenshot_path,
             "human_action": {
-                "type": step.action.type,
-                "x": step.action.x,
-                "y": step.action.y,
+                "type": step.action.type.value if isinstance(step.action.type, ActionType) else step.action.type,
+                "x": action_x,
+                "y": action_y,
                 "text": step.action.text,
             },
             "predicted_action": None,  # Shows "— No prediction" in viewer
@@ -596,6 +617,42 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
         .setup-log-line.current {{
             color: var(--accent);
         }}
+        .config-panel {{
+            background: var(--bg-secondary);
+            border: 1px solid var(--border-color);
+            border-radius: 12px;
+            padding: 16px 20px;
+            margin-bottom: 24px;
+        }}
+        .config-grid {{
+            display: grid;
+            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+            gap: 16px;
+        }}
+        .config-item {{
+            display: flex;
+            flex-direction: column;
+            gap: 4px;
+        }}
+        .config-label {{
+            font-size: 0.7rem;
+            color: var(--text-secondary);
+            text-transform: uppercase;
+            letter-spacing: 0.5px;
+        }}
+        .config-value {{
+            font-family: "SF Mono", Monaco, monospace;
+            font-size: 0.85rem;
+            color: var(--text-primary);
+        }}
+        .config-value.model {{
+            color: var(--accent);
+        }}
+        .config-value.goal {{
+            font-family: -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
+            font-size: 0.8rem;
+            opacity: 0.9;
+        }}
         .status {{
             display: flex;
             align-items: center;
@@ -754,6 +811,8 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
         }}
         /* Shared header styles (injected from _get_shared_header_css) */
         {_get_shared_header_css()}
+        /* Azure ML Jobs panel styles */
+        {_get_azure_jobs_panel_css()}
         .eval-panel {{
             background: var(--bg-secondary);
             border: 1px solid var(--border-color);
@@ -1095,6 +1154,33 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
             </div>
         </div>
 
+        {_get_azure_jobs_panel_html()}
+
+        <div class="config-panel" id="config-panel">
+            <div class="config-grid">
+                <div class="config-item">
+                    <span class="config-label">Model</span>
+                    <span class="config-value model" id="config-model">{state.model_name or 'Not specified'}</span>
+                </div>
+                <div class="config-item">
+                    <span class="config-label">Goal</span>
+                    <span class="config-value goal" id="config-goal">{state.goal or 'Not specified'}</span>
+                </div>
+                <div class="config-item">
+                    <span class="config-label">LoRA</span>
+                    <span class="config-value" id="config-lora">{f'r={state.lora_r}, α={state.lora_alpha}' if state.lora_r else 'Not specified'}</span>
+                </div>
+                <div class="config-item">
+                    <span class="config-label">Quantization</span>
+                    <span class="config-value" id="config-quant">{'4-bit' if state.load_in_4bit else 'None'}</span>
+                </div>
+                <div class="config-item">
+                    <span class="config-label">Config</span>
+                    <span class="config-value" id="config-path">{state.config_path or 'Not specified'}</span>
+                </div>
+            </div>
+        </div>
+
         <div class="stats-grid">
             <div class="stat-card" id="card-epoch">
                 <div class="stat-label">Epoch Progress</div>
@@ -1519,6 +1605,28 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
                 }}
             }}
 
+            // Update config panel
+            const configModel = document.getElementById('config-model');
+            const configGoal = document.getElementById('config-goal');
+            const configLora = document.getElementById('config-lora');
+            const configQuant = document.getElementById('config-quant');
+            const configPath = document.getElementById('config-path');
+            if (configModel && data.model_name) {{
+                configModel.textContent = data.model_name;
+            }}
+            if (configGoal && data.goal) {{
+                configGoal.textContent = data.goal;
+            }}
+            if (configLora && (data.lora_r || data.lora_alpha)) {{
+                configLora.textContent = `r=${{data.lora_r || 0}}, α=${{data.lora_alpha || 0}}`;
+            }}
+            if (configQuant) {{
+                configQuant.textContent = data.load_in_4bit ? '4-bit' : 'None';
+            }}
+            if (configPath && data.config_path) {{
+                configPath.textContent = data.config_path;
+            }}
+
             // Update setup panel if setup logs present
             if (data.setup_logs && data.setup_logs.length > 0) {{
                 const setupPanel = document.getElementById('setup-panel');
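
The data object this script polls is the JSON emitted from TrainingState (see the serializer hunk above), so each JS read has a matching key on the Python side. A truncated sketch of the payload, restricted to keys visible in this diff, with placeholder values:

    # Truncated sketch; placeholder values, keys from the serializer hunk above.
    state_json = {
        "goal": "Open Settings and enable dark mode",
        "model_name": "Qwen/Qwen3-VL-2B-Instruct",
        "lora_r": 16,
        "lora_alpha": 32,
        "load_in_4bit": True,
        "config_path": "configs/train.yaml",
        # ...instance, timing, and progress fields omitted
    }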
@@ -2093,354 +2201,3 @@ def regenerate_local_dashboard(
     print(f"Regenerated dashboard: {dashboard_path}")
 
     return dashboard_path
-
-
-def run_epoch_evaluation(
-    adapter: BaseVLMAdapter,
-    episode: Episode,
-    epoch: int,
-    config: TrainingConfig,
-    logger: "TrainingLogger",
-    sample_indices: Optional[List[int]] = None,
-) -> Path:
-    """Run inference evaluation on sample steps after an epoch.
-
-    This generates a comparison_epoch{N}.html file showing human vs predicted actions.
-
-    Args:
-        adapter: Trained adapter to use for inference
-        episode: Episode with steps to evaluate
-        epoch: Current epoch number
-        config: Training configuration
-        logger: Training logger for state tracking
-        sample_indices: Specific step indices to evaluate (default: evenly spaced)
-
-    Returns:
-        Path to generated comparison HTML file
-    """
-    from openadapt_ml.scripts.compare import generate_comparison_html, predict_action, format_action
-
-    output_dir = Path(config.output_dir)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    # Select sample indices if not provided
-    num_samples = min(config.eval_samples, len(episode.steps))
-    if sample_indices is None:
-        if num_samples >= len(episode.steps):
-            sample_indices = list(range(len(episode.steps)))
-        else:
-            # Evenly space samples across the episode
-            step_size = len(episode.steps) // num_samples
-            sample_indices = [i * step_size for i in range(num_samples)]
-
-    print(f"  Running inference on {len(sample_indices)} sample steps...")
-
-    # Switch adapter to eval mode
-    adapter.eval()
-
-    comparison_data = []
-    action_history: List[str] = []
-    total_steps = len(episode.steps)
-
-    for i, step in enumerate(episode.steps):
-        step_data = {
-            "index": i,
-            "time": step.t,
-            "image_path": step.observation.image_path,
-            "human_action": {
-                "type": step.action.type,
-                "x": step.action.x,
-                "y": step.action.y,
-                "text": step.action.text,
-            },
-            "predicted_action": None,
-            "match": None,
-        }
-
-        # Only run inference on selected samples (for speed)
-        if i in sample_indices and step.observation.image_path:
-            try:
-                predicted = predict_action(
-                    adapter,
-                    step.observation.image_path,
-                    episode.goal,
-                    step_index=i,
-                    total_steps=total_steps,
-                    action_history=action_history.copy(),
-                )
-                step_data["predicted_action"] = predicted
-
-                # Check match and calculate distance
-                if predicted and predicted.get("type") == step.action.type:
-                    step_data["match"] = True
-
-                    # Calculate distance for click actions
-                    if step.action.type == "click":
-                        hx, hy = step.action.x or 0, step.action.y or 0
-                        px, py = predicted.get("x", 0), predicted.get("y", 0)
-                        distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
-
-                    # Log evaluation to training state
-                    logger.state.log_evaluation(
-                        epoch=epoch,
-                        sample_idx=i,
-                        image_path=step.observation.image_path,
-                        human_action=step_data["human_action"],
-                        predicted_action=predicted,
-                    )
-                else:
-                    step_data["match"] = False
-
-                print(f"    Step {i}: {step.action.type} -> {predicted.get('type') if predicted else 'none'}")
-
-            except Exception as e:
-                print(f"    Step {i}: inference failed - {e}")
-
-        # Build action history for context
-        action_history.append(format_action(step.action, use_som=False))
-        comparison_data.append(step_data)
-
-    # Switch back to train mode
-    adapter.train()
-
-    # Generate comparison HTML
-    output_path = output_dir / f"comparison_epoch{epoch}.html"
-    capture_path = Path(logger.state.capture_path) if logger.state.capture_path else Path(".")
-
-    generate_comparison_html(capture_path, episode, comparison_data, output_path)
-    print(f"  Comparison saved: {output_path}")
-
-    # Also regenerate all dashboards to update navigation
-    regenerate_all_dashboards(output_dir)
-
-    return output_path
-
-
-def _create_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
-    # Use an identity collate_fn so that each batch is a List[Dict], matching
-    # the expectations of adapters that operate on SFT-style samples.
-    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
-
-
-def _create_lr_scheduler(
-    optimizer: Optimizer,
-    config: TrainingConfig,
-    num_training_steps: int,
-) -> Optional[LambdaLR]:
-    """Create learning rate scheduler based on config.
-
-    Args:
-        optimizer: The optimizer to schedule.
-        config: Training configuration with lr_scheduler_type and warmup_ratio.
-        num_training_steps: Total number of training steps.
-
-    Returns:
-        LambdaLR scheduler or None if scheduler_type is "none" or "constant".
-    """
-    scheduler_type = config.lr_scheduler_type.lower()
-
-    if scheduler_type in ("none", "constant"):
-        return None
-
-    num_warmup_steps = int(num_training_steps * config.warmup_ratio)
-
-    if scheduler_type == "linear":
-        def lr_lambda(current_step: int) -> float:
-            if current_step < num_warmup_steps:
-                # Linear warmup
-                return float(current_step) / float(max(1, num_warmup_steps))
-            # Linear decay
-            return max(
-                0.0,
-                float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
-            )
-    elif scheduler_type == "cosine":
-        import math
-        def lr_lambda(current_step: int) -> float:
-            if current_step < num_warmup_steps:
-                # Linear warmup
-                return float(current_step) / float(max(1, num_warmup_steps))
-            # Cosine decay
-            progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
-            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
-    else:
-        raise ValueError(f"Unknown lr_scheduler_type: {scheduler_type}. Use 'linear', 'cosine', 'constant', or 'none'.")
-
-    return LambdaLR(optimizer, lr_lambda)
-
-
-def train_supervised(
-    adapter: BaseVLMAdapter,
-    dataset: Dataset,
-    config: TrainingConfig,
-    optimizer: Optional[Optimizer] = None,
-    logger: Optional[TrainingLogger] = None,
-    episode: Optional[Episode] = None,
-) -> bool:
-    """Minimal supervised training loop skeleton.
-
-    This assumes that `adapter.prepare_inputs` and `adapter.compute_loss` are
-    implemented. It will raise if those methods are not implemented.
-
-    Args:
-        adapter: VLM adapter to train.
-        dataset: Training dataset.
-        config: Training configuration.
-        optimizer: Optional optimizer (default: AdamW).
-        logger: Optional training logger for visualization.
-        episode: Optional episode for periodic evaluation (generates comparison_epoch{N}.html).
-
-    Returns:
-        True if training completed successfully, False if aborted due to NaN/Inf loss.
-    """
-
-    device = adapter.device  # type: ignore[attr-defined]
-    dataloader = _create_dataloader(dataset, batch_size=config.per_device_train_batch_size)
-
-    if optimizer is None:
-        optimizer = torch.optim.AdamW(
-            adapter.model.parameters(),  # type: ignore[arg-type]
-            lr=config.learning_rate,
-            weight_decay=config.weight_decay,
-        )
-
-    # Create logger if not provided
-    if logger is None:
-        logger = TrainingLogger(config.output_dir, config)
-
-    # Calculate total training steps for scheduler
-    num_training_steps = len(dataloader) * config.num_train_epochs // config.gradient_accumulation_steps
-
-    # Create learning rate scheduler
-    lr_scheduler = _create_lr_scheduler(optimizer, config, num_training_steps)
-
-    total_steps = 0
-    adapter.train()
-
-    # Early stopping tracking
-    consecutive_low_loss = 0
-    early_stopped = False
-    user_stopped = False
-
-    for epoch in range(config.num_train_epochs):
-        if early_stopped or user_stopped:
-            break
-
-        for _, batch in enumerate(dataloader):
-            # Check for stop signal from dashboard
-            stop_file = Path(config.output_dir) / "STOP_TRAINING"
-            if stop_file.exists():
-                msg = "Stop signal received from dashboard. Stopping training..."
-                print(msg)
-                logger._log_to_terminal(msg)
-                # Set termination status for dashboard
-                logger.state.termination_status = "user_stop"
-                logger.state.termination_message = "Training stopped by user via dashboard"
-                logger.save()
-                user_stopped = True
-                stop_file.unlink()  # Remove signal file
-                break
-
-            # Batch is a List[Dict[str, Any]] of SFT-style samples; adapter is
-            # responsible for converting it into model inputs.
-            samples: List[Dict[str, Any]] = batch
-
-            inputs = adapter.prepare_inputs(samples)
-            inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
-
-            loss = adapter.compute_loss(inputs)
-
-            # Guard against invalid losses to avoid propagating NaNs/Infs
-            if torch.isnan(loss) or torch.isinf(loss):
-                msg = f"Encountered invalid loss at epoch={epoch} step={total_steps + 1}: {loss.item()}"
-                print(msg)
-                logger._log_to_terminal(msg)
-                logger.on_train_end()
-                return False
-
-            loss.backward()
-
-            if (total_steps + 1) % config.gradient_accumulation_steps == 0:
-                torch.nn.utils.clip_grad_norm_(adapter.model.parameters(), config.max_grad_norm)  # type: ignore[arg-type]
-                optimizer.step()
-                if lr_scheduler is not None:
-                    lr_scheduler.step()
-                optimizer.zero_grad()
-
-            total_steps += 1
-            loss_val = loss.item()
-
-            # Get current learning rate from optimizer
-            current_lr = optimizer.param_groups[0]['lr']
-
-            # Log step
-            logger.on_step(epoch, total_steps, loss_val, current_lr)
-
-            if config.logging_steps and total_steps % config.logging_steps == 0:
-                msg = f"epoch={epoch} step={total_steps} loss={loss_val:.4f} lr={current_lr:.6f}"
-                print(msg)
-                logger._log_to_terminal(msg)
-
-            # Early stopping check
-            if loss_val < config.early_stop_loss:
-                consecutive_low_loss += 1
-                if consecutive_low_loss >= config.early_stop_patience:
-                    msg = (
-                        f"Early stopping: loss ({loss_val:.6f}) below threshold "
-                        f"({config.early_stop_loss}) for {config.early_stop_patience} consecutive steps"
-                    )
-                    print(msg)
-                    logger._log_to_terminal(msg)
-                    # Set termination status for dashboard
-                    logger.state.termination_status = "auto_low_loss"
-                    logger.state.termination_message = (
-                        f"Loss reached {loss_val:.6f} (< {config.early_stop_loss}) "
-                        f"for {config.early_stop_patience} consecutive steps"
-                    )
-                    logger.save()
-                    early_stopped = True
-                    break
-            else:
-                consecutive_low_loss = 0
-
-        # End of epoch
-        logger.on_epoch_end(epoch)
-
-        # Save checkpoint at end of each epoch
-        if config.save_checkpoint_every_epoch:
-            checkpoint_path = Path(config.checkpoint_dir) / f"epoch_{epoch}"
-            checkpoint_path.mkdir(parents=True, exist_ok=True)
-            try:
-                adapter.save_checkpoint(str(checkpoint_path))
-                msg = f"Checkpoint saved to {checkpoint_path}"
-                print(msg)
-                logger._log_to_terminal(msg)
-            except Exception as e:
-                msg = f"Warning: Failed to save checkpoint: {e}"
-                print(msg)
-                logger._log_to_terminal(msg)
-
-        # Run evaluation after each epoch (generates comparison_epoch{N}.html)
-        if config.eval_every_epoch and episode is not None:
-            try:
-                print(f"Running epoch {epoch} evaluation...")
-                run_epoch_evaluation(
-                    adapter=adapter,
-                    episode=episode,
-                    epoch=epoch,
-                    config=config,
-                    logger=logger,
-                )
-            except Exception as e:
-                print(f"Warning: Epoch evaluation failed: {e}")
-                import traceback
-                traceback.print_exc()
-
-    # Set termination status if not already set (normal completion)
-    if not logger.state.termination_status:
-        logger.state.termination_status = "auto_complete"
-        logger.state.termination_message = f"Training completed all {config.num_train_epochs} epochs"
-        logger.save()
-
-    logger.on_train_end()
-    return True
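
This final hunk deletes the in-repo torch training stack (run_epoch_evaluation, _create_dataloader, _create_lr_scheduler, train_supervised) from trainer.py, consistent with the new openadapt_ml/training/trl_trainer.py (+354) in the file list; the diff itself does not show where, or whether, each helper reappears. For reference, the deleted _create_lr_scheduler applied linear warmup followed by linear or cosine decay. With T total optimizer steps and W = floor(T * warmup_ratio) warmup steps, the multiplier on the base learning rate at step s was:

    \lambda_{\text{linear}}(s) =
    \begin{cases}
      s / \max(1, W), & s < W \\
      \max\!\left(0,\ \dfrac{T - s}{\max(1,\, T - W)}\right), & s \ge W
    \end{cases}
    \qquad
    \lambda_{\text{cosine}}(s) =
    \begin{cases}
      s / \max(1, W), & s < W \\
      \max\!\left(0,\ \tfrac{1}{2}\left(1 + \cos\!\left(\pi \cdot \dfrac{s - W}{\max(1,\, T - W)}\right)\right)\right), & s \ge W
    \end{cases}

The scheduler stepped once per optimizer update (once every gradient_accumulation_steps batches), matching num_training_steps = len(dataloader) * num_train_epochs // gradient_accumulation_steps in the removed loop.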