openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95):
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,9 +4,9 @@ import json
4
4
  import time
5
5
  from dataclasses import dataclass, field
6
6
  from pathlib import Path
7
- from typing import Any, Callable, Dict, List, Optional
7
+ from typing import Any, Dict, List
8
8
 
9
- from openadapt_ml.schema import Episode, Step, Action, ActionType
9
+ from openadapt_ml.schema import ActionType
10
10
  from openadapt_ml.training.shared_ui import (
11
11
  get_shared_header_css as _get_shared_header_css,
12
12
  generate_shared_header_html as _generate_shared_header_html,
@@ -108,9 +108,10 @@ class TrainingConfig:
108
108
  @dataclass
109
109
  class TrainingState:
110
110
  """Tracks training progress for visualization."""
111
+
111
112
  # Job identification
112
113
  job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
113
- hostname: str = field(default_factory=lambda: __import__('socket').gethostname())
114
+ hostname: str = field(default_factory=lambda: __import__("socket").gethostname())
114
115
  capture_path: str = ""
115
116
  config_path: str = ""
116
117
  goal: str = "" # Task goal/description for the training run
@@ -142,7 +143,9 @@ class TrainingState:
142
143
  setup_status: str = "" # e.g. "booting", "installing", "training", "complete"
143
144
  setup_logs: List[str] = field(default_factory=list) # Setup progress messages
144
145
  # Termination tracking
145
- termination_status: str = "" # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
146
+ termination_status: str = (
147
+ "" # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
148
+ )
146
149
  termination_message: str = "" # Human-readable termination reason
147
150
 
148
151
  def log_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:
@@ -151,33 +154,46 @@ class TrainingState:
151
154
  self.step = step
152
155
  self.loss = loss
153
156
  self.learning_rate = lr
154
- self.losses.append({
155
- "epoch": epoch,
156
- "step": step,
157
- "loss": loss,
158
- "lr": lr,
159
- "time": time.time() - self.start_time,
160
- })
161
-
162
- def log_evaluation(self, epoch: int, sample_idx: int, image_path: str,
163
- human_action: Dict, predicted_action: Dict) -> None:
157
+ self.losses.append(
158
+ {
159
+ "epoch": epoch,
160
+ "step": step,
161
+ "loss": loss,
162
+ "lr": lr,
163
+ "time": time.time() - self.start_time,
164
+ }
165
+ )
166
+
167
+ def log_evaluation(
168
+ self,
169
+ epoch: int,
170
+ sample_idx: int,
171
+ image_path: str,
172
+ human_action: Dict,
173
+ predicted_action: Dict,
174
+ ) -> None:
164
175
  """Log an evaluation sample."""
165
176
  # Calculate distance for click actions
166
177
  distance = 0.0
167
- if human_action.get("type") == "click" and predicted_action.get("type") == "click":
178
+ if (
179
+ human_action.get("type") == "click"
180
+ and predicted_action.get("type") == "click"
181
+ ):
168
182
  hx, hy = human_action.get("x", 0), human_action.get("y", 0)
169
183
  px, py = predicted_action.get("x", 0), predicted_action.get("y", 0)
170
184
  distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
171
185
 
172
- self.evaluations.append({
173
- "epoch": epoch,
174
- "sample_idx": sample_idx,
175
- "image_path": image_path,
176
- "human_action": human_action,
177
- "predicted_action": predicted_action,
178
- "distance": distance,
179
- "correct": distance < 50, # Within 50 pixels is "correct"
180
- })
186
+ self.evaluations.append(
187
+ {
188
+ "epoch": epoch,
189
+ "sample_idx": sample_idx,
190
+ "image_path": image_path,
191
+ "human_action": human_action,
192
+ "predicted_action": predicted_action,
193
+ "distance": distance,
194
+ "correct": distance < 50, # Within 50 pixels is "correct"
195
+ }
196
+ )
181
197
 
182
198
  def to_dict(self) -> Dict[str, Any]:
183
199
  """Convert state to serializable dict."""
@@ -195,7 +211,9 @@ class TrainingState:
195
211
  "load_in_4bit": self.load_in_4bit,
196
212
  "instance_type": self.instance_type,
197
213
  "instance_ip": self.instance_ip,
198
- "started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)),
214
+ "started_at": time.strftime(
215
+ "%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)
216
+ ),
199
217
  # Cloud provider info
200
218
  "cloud_provider": self.cloud_provider,
201
219
  "cloud_dashboard_url": self.cloud_dashboard_url,
@@ -316,6 +334,7 @@ class TrainingLogger:
316
334
  def _save_config_snapshot(self) -> None:
317
335
  """Save training config snapshot to JSON."""
318
336
  from dataclasses import asdict
337
+
319
338
  config_file = self.output_dir / "config.json"
320
339
  config_dict = asdict(self.config)
321
340
  with open(config_file, "w") as f:
@@ -333,32 +352,45 @@ class TrainingLogger:
333
352
  dashboard_path.write_text(html)
334
353
 
335
354
 
336
- def _generate_termination_status_html(state: TrainingState, is_training_complete: bool) -> str:
355
+ def _generate_termination_status_html(
356
+ state: TrainingState, is_training_complete: bool
357
+ ) -> str:
337
358
  """Generate HTML for termination status section."""
338
359
  # Check if we have termination info
339
360
  if state.termination_status:
340
361
  # Map termination status to colors and icons
341
362
  status_styles = {
342
- "auto_complete": {"color": "#22c55e", "icon": "✓", "label": "Training Complete"},
343
- "auto_low_loss": {"color": "#22c55e", "icon": "✓", "label": "Auto-Stopped (Low Loss)"},
363
+ "auto_complete": {
364
+ "color": "#22c55e",
365
+ "icon": "✓",
366
+ "label": "Training Complete",
367
+ },
368
+ "auto_low_loss": {
369
+ "color": "#22c55e",
370
+ "icon": "✓",
371
+ "label": "Auto-Stopped (Low Loss)",
372
+ },
344
373
  "user_stop": {"color": "#f59e0b", "icon": "■", "label": "Stopped by User"},
345
374
  }
346
- style = status_styles.get(state.termination_status, {"color": "#22c55e", "icon": "✓", "label": "Complete"})
375
+ style = status_styles.get(
376
+ state.termination_status,
377
+ {"color": "#22c55e", "icon": "✓", "label": "Complete"},
378
+ )
347
379
 
348
- return f'''<div style="display: flex; flex-direction: column; gap: 8px;">
349
- <div style="display: flex; align-items: center; gap: 8px; color: {style['color']};">
350
- <span style="font-size: 1.2rem;">{style['icon']}</span>
351
- <span style="font-weight: 600;">{style['label']}</span>
380
+ return f"""<div style="display: flex; flex-direction: column; gap: 8px;">
381
+ <div style="display: flex; align-items: center; gap: 8px; color: {style["color"]};">
382
+ <span style="font-size: 1.2rem;">{style["icon"]}</span>
383
+ <span style="font-weight: 600;">{style["label"]}</span>
352
384
  </div>
353
- {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ''}
354
- </div>'''
385
+ {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ""}
386
+ </div>"""
355
387
  elif is_training_complete:
356
- return '''<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
388
+ return """<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
357
389
  <span style="font-size: 1.2rem;">✓</span>
358
390
  <span style="font-weight: 600;">Training Complete</span>
359
- </div>'''
391
+ </div>"""
360
392
  else:
361
- return '''<button id="stop-training-btn" onclick="stopTraining()" style="
393
+ return """<button id="stop-training-btn" onclick="stopTraining()" style="
362
394
  background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
363
395
  color: white;
364
396
  border: none;
@@ -374,53 +406,65 @@ def _generate_termination_status_html(state: TrainingState, is_training_complete
374
406
  ">
375
407
  <span style="font-size: 1.1rem;">■</span> Stop Training
376
408
  </button>
377
- <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>'''
409
+ <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>"""
378
410
 
379
411
 
380
412
  def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
381
413
  """Generate an HTML dashboard for training visualization."""
382
414
  losses_json = json.dumps(state.losses)
383
415
  # Use stored elapsed_time if available (historical data), otherwise calculate
384
- elapsed = state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
416
+ elapsed = (
417
+ state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
418
+ )
385
419
  elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s"
386
420
 
387
421
  # Calculate stats
388
422
  if state.losses:
389
- min_loss = min(l["loss"] for l in state.losses)
390
- avg_loss = sum(l["loss"] for l in state.losses) / len(state.losses)
423
+ min_loss = min(loss["loss"] for loss in state.losses)
424
+ sum(loss["loss"] for loss in state.losses) / len(state.losses)
391
425
  recent_losses = state.losses[-10:] if len(state.losses) >= 10 else state.losses
392
- recent_avg = sum(l["loss"] for l in recent_losses) / len(recent_losses)
426
+ recent_avg = sum(loss["loss"] for loss in recent_losses) / len(recent_losses)
393
427
  # Calculate step times
394
428
  step_times = []
395
429
  for i in range(1, len(state.losses)):
396
- step_times.append(state.losses[i]["time"] - state.losses[i-1]["time"])
430
+ step_times.append(state.losses[i]["time"] - state.losses[i - 1]["time"])
397
431
  avg_step_time = sum(step_times) / len(step_times) if step_times else 0
398
432
  # Loss by epoch
399
433
  epoch_losses: dict = {}
400
- for l in state.losses:
401
- ep = l["epoch"]
434
+ for loss in state.losses:
435
+ ep = loss["epoch"]
402
436
  if ep not in epoch_losses:
403
437
  epoch_losses[ep] = []
404
- epoch_losses[ep].append(l["loss"])
405
- epoch_avg = {ep: sum(losses)/len(losses) for ep, losses in epoch_losses.items()}
438
+ epoch_losses[ep].append(loss["loss"])
439
+ epoch_avg = {
440
+ ep: sum(losses) / len(losses) for ep, losses in epoch_losses.items()
441
+ }
406
442
  # Estimate ETA
407
443
  # Steps per epoch = steps in completed epochs / completed epochs
408
444
  completed_epochs = state.epoch
409
- steps_in_completed = sum(1 for l in state.losses if l["epoch"] < completed_epochs)
445
+ steps_in_completed = sum(
446
+ 1 for loss in state.losses if loss["epoch"] < completed_epochs
447
+ )
410
448
  if completed_epochs > 0 and steps_in_completed > 0:
411
449
  steps_per_epoch = steps_in_completed / completed_epochs
412
450
  else:
413
451
  # Estimate from current epoch progress
414
- steps_per_epoch = len(state.losses) / (state.epoch + 1) if state.epoch >= 0 else len(state.losses)
415
-
416
- total_epochs = state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
452
+ steps_per_epoch = (
453
+ len(state.losses) / (state.epoch + 1)
454
+ if state.epoch >= 0
455
+ else len(state.losses)
456
+ )
457
+
458
+ total_epochs = (
459
+ state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
460
+ )
417
461
  total_steps_estimate = steps_per_epoch * total_epochs
418
462
  remaining_steps = max(0, total_steps_estimate - len(state.losses))
419
463
  eta_seconds = remaining_steps * avg_step_time if avg_step_time > 0 else 0
420
464
  # Check if training is complete (all steps done)
421
465
  is_training_complete = remaining_steps == 0 and len(state.losses) > 0
422
466
  else:
423
- min_loss = avg_loss = recent_avg = avg_step_time = 0.0
467
+ min_loss = recent_avg = avg_step_time = 0.0
424
468
  epoch_avg = {}
425
469
  eta_seconds = 0
426
470
  steps_per_epoch = 0
@@ -431,10 +475,9 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
431
475
  epoch_avg_json = json.dumps(list(epoch_avg.items()))
432
476
 
433
477
  # Generate comparison viewer preview if capture path available
434
- comparison_viewer_path = ""
435
478
  if state.capture_path:
436
479
  try:
437
- from openadapt_ml.scripts.compare import generate_comparison_html, generate_comparison_data
480
+ from openadapt_ml.scripts.compare import generate_comparison_html
438
481
  from openadapt_ml.ingest.capture import capture_to_episode
439
482
 
440
483
  capture_path = Path(state.capture_path)
@@ -454,7 +497,9 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
454
497
  "time": step.step_index,
455
498
  "image_path": step.observation.screenshot_path,
456
499
  "human_action": {
457
- "type": step.action.type.value if isinstance(step.action.type, ActionType) else step.action.type,
500
+ "type": step.action.type.value
501
+ if isinstance(step.action.type, ActionType)
502
+ else step.action.type,
458
503
  "x": action_x,
459
504
  "y": action_y,
460
505
  "text": step.action.text,
@@ -465,15 +510,21 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
465
510
  comparison_data.append(step_data)
466
511
 
467
512
  # Generate comparison HTML
468
- output_dir = Path(config.output_dir) if hasattr(config, 'output_dir') else Path("training_output")
513
+ output_dir = (
514
+ Path(config.output_dir)
515
+ if hasattr(config, "output_dir")
516
+ else Path("training_output")
517
+ )
469
518
  output_dir.mkdir(parents=True, exist_ok=True)
470
519
  comparison_output = output_dir / "comparison_preview.html"
471
- generate_comparison_html(capture_path, episode, comparison_data, comparison_output)
472
- comparison_viewer_path = str(comparison_output.name) # Relative path
473
- except Exception as e:
520
+ generate_comparison_html(
521
+ capture_path, episode, comparison_data, comparison_output
522
+ )
523
+ str(comparison_output.name) # Relative path
524
+ except Exception:
474
525
  pass # Fail silently if comparison viewer can't be generated
475
526
 
476
- html = f'''<!DOCTYPE html>
527
+ html = f"""<!DOCTYPE html>
477
528
  <html lang="en">
478
529
  <head>
479
530
  <meta charset="UTF-8">
@@ -1132,10 +1183,10 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1132
1183
  <div class="container">
1133
1184
  <header>
1134
1185
  <div>
1135
- <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ''}</h1>
1186
+ <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ""}</h1>
1136
1187
  <div class="job-info" id="job-info">
1137
- <span class="job-host">{state.hostname or 'stub-local'} @ {state.instance_ip or '127.0.0.1'}</span>
1138
- {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ''}
1188
+ <span class="job-host">{state.hostname or "stub-local"} @ {state.instance_ip or "127.0.0.1"}</span>
1189
+ {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ""}
1139
1190
  </div>
1140
1191
  </div>
1141
1192
  <div class="status" id="status">
@@ -1144,13 +1195,13 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1144
1195
  </div>
1145
1196
  </header>
1146
1197
 
1147
- <div class="setup-panel{' hidden' if not state.setup_logs else ''}" id="setup-panel">
1198
+ <div class="setup-panel{" hidden" if not state.setup_logs else ""}" id="setup-panel">
1148
1199
  <div class="setup-header">
1149
1200
  <h2>Setup Progress</h2>
1150
- <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or 'initializing'}</span>
1201
+ <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or "initializing"}</span>
1151
1202
  </div>
1152
1203
  <div class="setup-logs" id="setup-logs">
1153
- {''.join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
1204
+ {"".join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
1154
1205
  </div>
1155
1206
  </div>
1156
1207
 
@@ -1160,23 +1211,23 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1160
1211
  <div class="config-grid">
1161
1212
  <div class="config-item">
1162
1213
  <span class="config-label">Model</span>
1163
- <span class="config-value model" id="config-model">{state.model_name or 'Not specified'}</span>
1214
+ <span class="config-value model" id="config-model">{state.model_name or "Not specified"}</span>
1164
1215
  </div>
1165
1216
  <div class="config-item">
1166
1217
  <span class="config-label">Goal</span>
1167
- <span class="config-value goal" id="config-goal">{state.goal or 'Not specified'}</span>
1218
+ <span class="config-value goal" id="config-goal">{state.goal or "Not specified"}</span>
1168
1219
  </div>
1169
1220
  <div class="config-item">
1170
1221
  <span class="config-label">LoRA</span>
1171
- <span class="config-value" id="config-lora">{f'r={state.lora_r}, α={state.lora_alpha}' if state.lora_r else 'Not specified'}</span>
1222
+ <span class="config-value" id="config-lora">{f"r={state.lora_r}, α={state.lora_alpha}" if state.lora_r else "Not specified"}</span>
1172
1223
  </div>
1173
1224
  <div class="config-item">
1174
1225
  <span class="config-label">Quantization</span>
1175
- <span class="config-value" id="config-quant">{'4-bit' if state.load_in_4bit else 'None'}</span>
1226
+ <span class="config-value" id="config-quant">{"4-bit" if state.load_in_4bit else "None"}</span>
1176
1227
  </div>
1177
1228
  <div class="config-item">
1178
1229
  <span class="config-label">Config</span>
1179
- <span class="config-value" id="config-path">{state.config_path or 'Not specified'}</span>
1230
+ <span class="config-value" id="config-path">{state.config_path or "Not specified"}</span>
1180
1231
  </div>
1181
1232
  </div>
1182
1233
  </div>
@@ -1523,7 +1574,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
1523
1574
  let etaSeconds = {eta_seconds};
1524
1575
  let avgStepTime = {avg_step_time};
1525
1576
  let remainingSteps = {remaining_steps};
1526
- let isTrainingComplete = {'true' if is_training_complete else 'false'};
1577
+ let isTrainingComplete = {"true" if is_training_complete else "false"};
1527
1578
 
1528
1579
  // Auto-stop when loss <= threshold (INVARIANT: training should stop when loss <= 1.0)
1529
1580
  const AUTO_STOP_LOSS_THRESHOLD = 1.0;
@@ -2022,7 +2073,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
2022
2073
  setInterval(updateStatusIndicator, 1000); // Update LIVE/STALE indicator every second
2023
2074
  </script>
2024
2075
  </body>
2025
- </html>'''
2076
+ </html>"""
2026
2077
  return html
2027
2078
 
2028
2079
 
@@ -2060,6 +2111,7 @@ def regenerate_all_dashboards(output_dir: str | Path) -> list[Path]:
2060
2111
  except Exception as e:
2061
2112
  print(f"Warning: Failed to generate unified viewer: {e}")
2062
2113
  import traceback
2114
+
2063
2115
  traceback.print_exc()
2064
2116
 
2065
2117
  return regenerated
@@ -2102,7 +2154,9 @@ def regenerate_local_dashboard(
2102
2154
  state = TrainingState(
2103
2155
  job_id=data.get("job_id", "unknown"),
2104
2156
  hostname=data.get("hostname", ""),
2105
- capture_path=str(capture_path) if capture_path else data.get("capture_path", ""),
2157
+ capture_path=str(capture_path)
2158
+ if capture_path
2159
+ else data.get("capture_path", ""),
2106
2160
  config_path=data.get("config_path", ""),
2107
2161
  epoch=data.get("epoch", 0),
2108
2162
  step=data.get("step", 0),
@@ -2145,30 +2199,33 @@ def regenerate_local_dashboard(
2145
2199
  if training_status == "COMPLETED":
2146
2200
  html = html.replace(
2147
2201
  '<div class="status" id="status">',
2148
- '<div class="status complete" id="status">'
2202
+ '<div class="status complete" id="status">',
2149
2203
  )
2150
2204
  html = html.replace(
2151
2205
  '<span id="status-text">Training in progress</span>',
2152
- '<span id="status-text">COMPLETED</span>'
2206
+ '<span id="status-text">COMPLETED</span>',
2153
2207
  )
2154
2208
  elif training_status == "STOPPED":
2155
2209
  html = html.replace(
2156
- '<div class="status" id="status">',
2157
- '<div class="status stale" id="status">'
2210
+ '<div class="status" id="status">', '<div class="status stale" id="status">'
2158
2211
  )
2159
2212
  html = html.replace(
2160
2213
  '<span id="status-text">Training in progress</span>',
2161
- '<span id="status-text">STOPPED (Epoch {}/{})'.format(current_epoch + 1, total_epochs) + '</span>'
2214
+ '<span id="status-text">STOPPED (Epoch {}/{})'.format(
2215
+ current_epoch + 1, total_epochs
2216
+ )
2217
+ + "</span>",
2162
2218
  )
2163
2219
 
2164
2220
  # Fix ETA display for completed/stopped training
2165
2221
  import re
2222
+
2166
2223
  if training_status in ("COMPLETED", "STOPPED"):
2167
2224
  # Replace "calculating..." with appropriate status
2168
2225
  html = re.sub(
2169
2226
  r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
2170
- r'\1—\2' if training_status == "STOPPED" else r'\1complete\2',
2171
- html
2227
+ r"\1—\2" if training_status == "STOPPED" else r"\1complete\2",
2228
+ html,
2172
2229
  )
2173
2230
 
2174
2231
  # Replace dynamic nav with static unified header
@@ -2179,20 +2236,20 @@ def regenerate_local_dashboard(
2179
2236
  # This is critical for file:// protocol where fetch() doesn't work
2180
2237
  html = html.replace(
2181
2238
  "setInterval(fetchAndUpdate, 3000);",
2182
- "// fetchAndUpdate disabled for static dashboard"
2239
+ "// fetchAndUpdate disabled for static dashboard",
2183
2240
  )
2184
2241
  html = html.replace(
2185
2242
  "setInterval(updateElapsedDisplay, 1000);",
2186
- "// updateElapsedDisplay disabled for static dashboard"
2243
+ "// updateElapsedDisplay disabled for static dashboard",
2187
2244
  )
2188
2245
  html = html.replace(
2189
2246
  "setInterval(updateStatusIndicator, 1000);",
2190
- "// updateStatusIndicator disabled for static dashboard"
2247
+ "// updateStatusIndicator disabled for static dashboard",
2191
2248
  )
2192
2249
  # CRITICAL: Disable discoverDashboards() - it overwrites static nav on file:// protocol
2193
2250
  html = html.replace(
2194
2251
  "discoverDashboards();",
2195
- "// discoverDashboards disabled - using static nav for file:// protocol"
2252
+ "// discoverDashboards disabled - using static nav for file:// protocol",
2196
2253
  )
2197
2254
 
2198
2255
  # Write output
@@ -91,7 +91,9 @@ def _load_unsloth_model(config: TRLTrainingConfig):
91
91
  # Enable training mode
92
92
  FastVisionModel.for_training(model)
93
93
 
94
- print(f"✓ Loaded {config.model_name} with Unsloth (4-bit: {config.load_in_4bit})")
94
+ print(
95
+ f"✓ Loaded {config.model_name} with Unsloth (4-bit: {config.load_in_4bit})"
96
+ )
95
97
  return model, tokenizer, True
96
98
 
97
99
  except ImportError:
@@ -100,26 +102,70 @@ def _load_unsloth_model(config: TRLTrainingConfig):
100
102
 
101
103
 
102
104
  def _load_standard_model(config: TRLTrainingConfig):
103
- """Fallback: Load model with standard transformers + peft."""
104
- from transformers import AutoModelForCausalLM, AutoProcessor
105
+ """Fallback: Load model with standard transformers + peft.
106
+
107
+ Automatically detects vision-language models and uses the appropriate
108
+ model class (Qwen2VLForConditionalGeneration for VL models,
109
+ AutoModelForCausalLM for text-only models).
110
+ """
111
+ from transformers import AutoConfig, AutoProcessor
105
112
  from peft import LoraConfig, get_peft_model
106
113
  import torch
107
114
 
108
- model = AutoModelForCausalLM.from_pretrained(
109
- config.model_name,
110
- torch_dtype=torch.bfloat16,
111
- device_map="auto",
112
- trust_remote_code=True,
115
+ # Check if this is a vision-language model
116
+ model_config = AutoConfig.from_pretrained(
117
+ config.model_name, trust_remote_code=True
118
+ )
119
+ is_vl_model = (
120
+ "VL" in config.model_name.upper()
121
+ or "vision" in config.model_name.lower()
122
+ or hasattr(model_config, "vision_config")
113
123
  )
124
+
125
+ if is_vl_model:
126
+ # Vision-language model - use Qwen2VLForConditionalGeneration or AutoModelForVision2Seq
127
+ try:
128
+ from transformers import Qwen2VLForConditionalGeneration
129
+
130
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
131
+ config.model_name,
132
+ torch_dtype=torch.bfloat16,
133
+ device_map="auto",
134
+ trust_remote_code=True,
135
+ )
136
+ print(" Using Qwen2VLForConditionalGeneration for VL model")
137
+ except (ImportError, ValueError, RuntimeError, TypeError):
138
+ # Fallback to AutoModelForVision2Seq for other VL models
139
+ from transformers import AutoModelForVision2Seq
140
+
141
+ model = AutoModelForVision2Seq.from_pretrained(
142
+ config.model_name,
143
+ torch_dtype=torch.bfloat16,
144
+ device_map="auto",
145
+ trust_remote_code=True,
146
+ )
147
+ print(" Using AutoModelForVision2Seq for VL model")
148
+ else:
149
+ # Text-only model
150
+ from transformers import AutoModelForCausalLM
151
+
152
+ model = AutoModelForCausalLM.from_pretrained(
153
+ config.model_name,
154
+ torch_dtype=torch.bfloat16,
155
+ device_map="auto",
156
+ trust_remote_code=True,
157
+ )
158
+ print(" Using AutoModelForCausalLM for text-only model")
159
+
114
160
  processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
115
161
 
116
- # Apply LoRA
162
+ # Apply LoRA - use SEQ_2_SEQ_LM for VL models, CAUSAL_LM for text-only
117
163
  peft_config = LoraConfig(
118
164
  r=config.lora_r,
119
165
  lora_alpha=config.lora_alpha,
120
166
  lora_dropout=config.lora_dropout,
121
167
  target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
122
- task_type="CAUSAL_LM",
168
+ task_type="SEQ_2_SEQ_LM" if is_vl_model else "CAUSAL_LM",
123
169
  )
124
170
  model = get_peft_model(model, peft_config)
125
171
 
@@ -161,10 +207,12 @@ def _convert_samples_to_trl_format(
161
207
  if not pil_images:
162
208
  continue # Skip samples with missing images
163
209
 
164
- trl_samples.append({
165
- "images": pil_images,
166
- "messages": sample["messages"],
167
- })
210
+ trl_samples.append(
211
+ {
212
+ "images": pil_images,
213
+ "messages": sample["messages"],
214
+ }
215
+ )
168
216
 
169
217
  return trl_samples
170
218
 
@@ -261,7 +309,7 @@ def train_with_trl(
261
309
  logging_steps=config.logging_steps,
262
310
  save_strategy=config.save_strategy,
263
311
  max_length=None, # Critical for VLMs
264
- assistant_only_loss=True,
312
+ assistant_only_loss=False, # Not supported for VL models yet
265
313
  )
266
314
 
267
315
  trainer = SFTTrainer(
@@ -270,15 +318,15 @@ def train_with_trl(
270
318
  args=training_args,
271
319
  )
272
320
 
273
- print(f"\n{'='*50}")
274
- print(f"Starting training:")
321
+ print(f"\n{'=' * 50}")
322
+ print("Starting training:")
275
323
  print(f" Model: {config.model_name}")
276
324
  print(f" Samples: {len(trl_samples)}")
277
325
  print(f" Epochs: {config.num_epochs}")
278
326
  print(f" Batch size: {config.batch_size}")
279
327
  print(f" Unsloth: {is_unsloth}")
280
328
  print(f" Output: {config.output_dir}")
281
- print(f"{'='*50}\n")
329
+ print(f"{'=' * 50}\n")
282
330
 
283
331
  trainer.train()
284
332
 
@@ -291,8 +339,7 @@ def train_with_trl(
291
339
 
292
340
  except ImportError as e:
293
341
  raise ImportError(
294
- f"TRL not installed. Install with: pip install trl\n"
295
- f"Original error: {e}"
342
+ f"TRL not installed. Install with: pip install trl\nOriginal error: {e}"
296
343
  )
297
344
 
298
345
 
@@ -333,7 +380,9 @@ if __name__ == "__main__":
333
380
  parser = argparse.ArgumentParser(description="Train VLM with TRL + Unsloth")
334
381
  parser.add_argument("--parquet", required=True, help="Path to parquet file")
335
382
  parser.add_argument("--output", default="checkpoints", help="Output directory")
336
- parser.add_argument("--model", default="unsloth/Qwen2.5-VL-7B-Instruct", help="Model name")
383
+ parser.add_argument(
384
+ "--model", default="unsloth/Qwen2.5-VL-7B-Instruct", help="Model name"
385
+ )
337
386
  parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
338
387
  parser.add_argument("--use-som", action="store_true", help="Use Set-of-Marks DSL")
339
388