openadapt_ml-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,790 @@
+ """Local GPU training CLI.
+
+ Provides commands equivalent to lambda_labs.py but for local execution
+ on CUDA or Apple Silicon.
+
+ Usage:
+     # Train on a capture
+     uv run python -m openadapt_ml.cloud.local train --capture ~/captures/my-workflow
+
+     # Check training status
+     uv run python -m openadapt_ml.cloud.local status
+
+     # Check training health
+     uv run python -m openadapt_ml.cloud.local check
+
+     # Start dashboard server
+     uv run python -m openadapt_ml.cloud.local serve --open
+
+     # Regenerate viewer
+     uv run python -m openadapt_ml.cloud.local viewer
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import http.server
+ import json
+ import os
+ import shutil
+ import signal
+ import socketserver
+ import subprocess
+ import sys
+ import threading
+ import webbrowser
+ from pathlib import Path
+ from typing import Any
+
+ # Training output directory
+ TRAINING_OUTPUT = Path("training_output")
+
+
+ def get_current_output_dir() -> Path:
+     """Get the current job's output directory.
+
+     Returns the 'current' symlink path if it exists, otherwise falls back
+     to the base training_output directory for backward compatibility.
+     """
+     current_link = TRAINING_OUTPUT / "current"
+     if current_link.is_symlink() or current_link.exists():
+         return current_link
+     # Fallback for backward compatibility with old structure
+     return TRAINING_OUTPUT
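The "current" symlink read above is created by the training pipeline rather than by this CLI; a minimal sketch of the assumed layout (the job directory name is hypothetical):

    from pathlib import Path

    # Hypothetical job directory; the real name comes from the trainer's job id.
    job_dir = Path("training_output") / "job-20250101-120000"
    job_dir.mkdir(parents=True, exist_ok=True)

    # Point training_output/current at the newest job so get_current_output_dir() finds it.
    current = Path("training_output") / "current"
    if current.is_symlink():
        current.unlink()
    current.symlink_to(job_dir.name)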
+
+
+ def _regenerate_viewer_if_possible(output_dir: Path) -> bool:
+     """Regenerate viewer.html if comparison data exists.
+
+     Returns True if viewer was regenerated, False otherwise.
+     """
+     from openadapt_ml.training.trainer import generate_unified_viewer_from_output_dir
+
+     try:
+         viewer_path = generate_unified_viewer_from_output_dir(output_dir)
+         if viewer_path:
+             print(f"Regenerated viewer: {viewer_path}")
+             return True
+         return False
+     except Exception as e:
+         print(f"Could not regenerate viewer: {e}")
+         return False
+
+
+ def _is_mock_benchmark(benchmark_dir: Path) -> bool:
+     """Check if a benchmark run is mock/test data (not real evaluation).
+
+     Returns True if the benchmark is mock data that should be filtered out.
+
+     Note: API evaluations using the mock WAA adapter (waa-mock) are considered
+     real evaluations and should NOT be filtered out, since they represent actual
+     model performance on test tasks.
+     """
+     # Check summary.json for model_id
+     summary_path = benchmark_dir / "summary.json"
+     if summary_path.exists():
+         try:
+             with open(summary_path) as f:
+                 summary = json.load(f)
+             model_id = summary.get("model_id", "").lower()
+             # Filter out mock/test/random agent runs (but keep API models like "anthropic-api")
+             if any(term in model_id for term in ["random-agent", "scripted-agent"]):
+                 return True
+         except Exception:
+             pass
+
+     # Check metadata.json for model_id
+     metadata_path = benchmark_dir / "metadata.json"
+     if metadata_path.exists():
+         try:
+             with open(metadata_path) as f:
+                 metadata = json.load(f)
+             model_id = metadata.get("model_id", "").lower()
+             if any(term in model_id for term in ["random-agent", "scripted-agent"]):
+                 return True
+         except Exception:
+             pass
+
+     # Check for test runs (but allow waa-mock evaluations with real API models)
+     # Only filter out purely synthetic test data directories
+     if any(term in benchmark_dir.name.lower() for term in ["test_run", "test_cli", "quick_demo"]):
+         return True
+
+     return False
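As a quick illustration of the filtering rules above, only the model_id field and the directory name matter; the run names and ids below are made up:

    import json
    import tempfile
    from pathlib import Path

    from openadapt_ml.cloud.local import _is_mock_benchmark

    tmp = Path(tempfile.mkdtemp())

    real = tmp / "waa_api_eval"          # hypothetical run directory
    real.mkdir()
    (real / "summary.json").write_text(json.dumps({"model_id": "anthropic-api"}))

    mock = tmp / "test_run_demo"         # name matches the test_run filter
    mock.mkdir()
    (mock / "summary.json").write_text(json.dumps({"model_id": "random-agent"}))

    print(_is_mock_benchmark(real))      # False: kept as a real evaluation
    print(_is_mock_benchmark(mock))      # True: filtered out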
+
+
+ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
+     """Regenerate benchmark.html from all real benchmark results.
+
+     Loads all non-mock benchmark runs from the benchmark_results/ directory
+     and generates a unified benchmark viewer supporting multiple runs.
+     If no real benchmark data exists, generates an empty-state viewer with guidance.
+
+     Returns True if the benchmark viewer was regenerated, False otherwise.
+     """
+     from openadapt_ml.training.benchmark_viewer import (
+         generate_multi_run_benchmark_viewer,
+         generate_empty_benchmark_viewer,
+     )
+
+     benchmark_results_dir = Path("benchmark_results")
+
+     # Find real (non-mock) benchmark runs
+     real_benchmarks = []
+     if benchmark_results_dir.exists():
+         for d in benchmark_results_dir.iterdir():
+             if d.is_dir() and (d / "summary.json").exists():
+                 if not _is_mock_benchmark(d):
+                     real_benchmarks.append(d)
+
+     benchmark_html_path = output_dir / "benchmark.html"
+
+     if not real_benchmarks:
+         # No real benchmark data - generate empty state viewer
+         try:
+             generate_empty_benchmark_viewer(benchmark_html_path)
+             print(" Generated benchmark viewer: No real evaluation data yet")
+             return True
+         except Exception as e:
+             print(f" Could not generate empty benchmark viewer: {e}")
+             return False
+
+     # Sort by modification time (most recent first)
+     real_benchmarks.sort(key=lambda d: d.stat().st_mtime, reverse=True)
+
+     try:
+         # Generate multi-run benchmark.html in the output directory
+         generate_multi_run_benchmark_viewer(real_benchmarks, benchmark_html_path)
+
+         # Copy all tasks folders for screenshots (organized by run)
+         benchmark_tasks_dir = output_dir / "benchmark_tasks"
+         if benchmark_tasks_dir.exists():
+             shutil.rmtree(benchmark_tasks_dir)
+         benchmark_tasks_dir.mkdir(exist_ok=True)
+
+         for benchmark_dir in real_benchmarks:
+             tasks_src = benchmark_dir / "tasks"
+             if tasks_src.exists():
+                 tasks_dst = benchmark_tasks_dir / benchmark_dir.name
+                 shutil.copytree(tasks_src, tasks_dst)
+
+         print(f" Regenerated benchmark viewer with {len(real_benchmarks)} run(s)")
+         return True
+     except Exception as e:
+         print(f" Could not regenerate benchmark viewer: {e}")
+         import traceback
+         traceback.print_exc()
+         return False
+
+
+ def detect_device() -> str:
+     """Detect available compute device."""
+     try:
+         import torch
+         if torch.cuda.is_available():
+             device_name = torch.cuda.get_device_name(0)
+             return f"cuda ({device_name})"
+         elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+             return "mps (Apple Silicon)"
+         else:
+             return "cpu"
+     except ImportError:
+         return "unknown (torch not installed)"
+
+
+ def get_training_status() -> dict[str, Any]:
+     """Get current training status from training_output/current."""
+     current_dir = get_current_output_dir()
+
+     status = {
+         "running": False,
+         "epoch": 0,
+         "step": 0,
+         "loss": None,
+         "device": detect_device(),
+         "has_dashboard": False,
+         "has_viewer": False,
+         "checkpoints": [],
+         "job_id": None,
+         "output_dir": str(current_dir),
+     }
+
+     log_file = current_dir / "training_log.json"
+     if log_file.exists():
+         try:
+             with open(log_file) as f:
+                 data = json.load(f)
+             status["job_id"] = data.get("job_id")
+             status["epoch"] = data.get("epoch", 0)
+             status["step"] = data.get("step", 0)
+             status["loss"] = data.get("loss")
+             status["learning_rate"] = data.get("learning_rate")
+             status["losses"] = data.get("losses", [])
+             status["status"] = data.get("status", "unknown")
+             status["running"] = data.get("status") == "training"
+         except (json.JSONDecodeError, KeyError):
+             pass
+
+     status["has_dashboard"] = (current_dir / "dashboard.html").exists()
+     status["has_viewer"] = (current_dir / "viewer.html").exists()
+
+     # Find checkpoints
+     checkpoints_dir = Path("checkpoints")
+     if checkpoints_dir.exists():
+         status["checkpoints"] = sorted([
+             d.name for d in checkpoints_dir.iterdir()
+             if d.is_dir() and (d / "adapter_config.json").exists()
+         ])
+
+     return status
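The keys read above imply a training_log.json shaped roughly like the following (all values hypothetical; per cmd_check further down, entries in "losses" may be plain floats or dicts with a "loss" key):

    import json

    # Hypothetical example of the schema get_training_status() and cmd_viewer() read.
    example_log = {
        "job_id": "job-20250101-120000",
        "status": "training",              # "training" => reported as RUNNING
        "epoch": 2,
        "total_epochs": 5,
        "step": 120,
        "loss": 0.4321,
        "learning_rate": 5e-05,
        "elapsed_time": 312.7,
        "losses": [{"loss": 1.92}, {"loss": 1.10}, 0.61, 0.43],
    }
    print(json.dumps(example_log, indent=2))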
+
+
+ def cmd_status(args: argparse.Namespace) -> int:
+     """Show local training status."""
+     status = get_training_status()
+     current_dir = get_current_output_dir()
+
+     print(f"\n{'='*50}")
+     print("LOCAL TRAINING STATUS")
+     print(f"{'='*50}")
+     print(f"Device: {status['device']}")
+     print(f"Status: {'RUNNING' if status['running'] else 'IDLE'}")
+     if status.get("job_id"):
+         print(f"Job ID: {status['job_id']}")
+     print(f"Output: {current_dir}")
+
+     if status.get("epoch"):
+         print(f"\nProgress:")
+         print(f" Epoch: {status['epoch']}")
+         print(f" Step: {status['step']}")
+         if status.get("loss"):
+             print(f" Loss: {status['loss']:.4f}")
+         if status.get("learning_rate"):
+             print(f" LR: {status['learning_rate']:.2e}")
+
+     if status["checkpoints"]:
+         print(f"\nCheckpoints ({len(status['checkpoints'])}):")
+         for cp in status["checkpoints"][-5:]:  # Show last 5
+             print(f" - {cp}")
+
+     print(f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html")
+     print(f"Viewer: {'✓' if status['has_viewer'] else '✗'} {current_dir}/viewer.html")
+     print()
+
+     return 0
+
+
+ def cmd_train(args: argparse.Namespace) -> int:
+     """Run training locally."""
+     capture_path = Path(args.capture).expanduser().resolve()
+     if not capture_path.exists():
+         print(f"Error: Capture not found: {capture_path}")
+         return 1
+
+     # Determine goal from capture directory name if not provided
+     goal = args.goal
+     if not goal:
+         goal = capture_path.name.replace("-", " ").replace("_", " ").title()
+
+     # Select config based on device
+     config = args.config
+     if not config:
+         device = detect_device()
+         if "cuda" in device:
+             config = "configs/qwen3vl_capture.yaml"
+         else:
+             config = "configs/qwen3vl_capture_4bit.yaml"
+
+     config_path = Path(config)
+     if not config_path.exists():
+         print(f"Error: Config not found: {config_path}")
+         return 1
+
+     print(f"\n{'='*50}")
+     print("STARTING LOCAL TRAINING")
+     print(f"{'='*50}")
+     print(f"Capture: {capture_path}")
+     print(f"Goal: {goal}")
+     print(f"Config: {config}")
+     print(f"Device: {detect_device()}")
+     print()
+
+     # Build command
+     cmd = [
+         sys.executable, "-m", "openadapt_ml.scripts.train",
+         "--config", str(config_path),
+         "--capture", str(capture_path),
+         "--goal", goal,
+     ]
+
+     if args.open:
+         cmd.append("--open")
+
+     # Run training
+     try:
+         result = subprocess.run(cmd, check=False)
+         return result.returncode
+     except KeyboardInterrupt:
+         print("\nTraining interrupted by user")
+         return 130
+
+
+ def cmd_check(args: argparse.Namespace) -> int:
+     """Check training health and early stopping analysis."""
+     status = get_training_status()
+
+     print(f"\n{'='*50}")
+     print("TRAINING HEALTH CHECK")
+     print(f"{'='*50}")
+
+     raw_losses = status.get("losses", [])
+     if not raw_losses:
+         print("No training data found.")
+         print("Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>")
+         return 1
+
+     # Extract loss values (handle both dict and float formats)
+     losses = []
+     for item in raw_losses:
+         if isinstance(item, dict):
+             losses.append(item.get("loss", 0))
+         else:
+             losses.append(float(item))
+
+     print(f"Total steps: {len(losses)}")
+     print(f"Current epoch: {status.get('epoch', 0)}")
+
+     # Loss analysis
+     if len(losses) >= 2:
+         first_loss = losses[0]
+         last_loss = losses[-1]
+         min_loss = min(losses)
+         max_loss = max(losses)
+
+         print(f"\nLoss progression:")
+         print(f" First: {first_loss:.4f}")
+         print(f" Last: {last_loss:.4f}")
+         print(f" Min: {min_loss:.4f}")
+         print(f" Max: {max_loss:.4f}")
+         print(f" Reduction: {((first_loss - last_loss) / first_loss * 100):.1f}%")
+
+         # Check for convergence
+         if len(losses) >= 10:
+             recent = losses[-10:]
+             recent_avg = sum(recent) / len(recent)
+             recent_std = (sum((x - recent_avg) ** 2 for x in recent) / len(recent)) ** 0.5
+
+             print(f"\nRecent stability (last 10 steps):")
+             print(f" Avg loss: {recent_avg:.4f}")
+             print(f" Std dev: {recent_std:.4f}")
+
+             if recent_std < 0.01:
+                 print(" Status: ✓ Converged (stable)")
+             elif last_loss > first_loss:
+                 print(" Status: ⚠ Loss increasing - may need lower learning rate")
+             else:
+                 print(" Status: Training in progress")
+
+     print()
+     return 0
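For concreteness, the two health numbers above on a hypothetical loss curve (values invented purely for illustration):

    # Illustrative only: invented loss values run through the same arithmetic as cmd_check.
    losses = [1.80, 1.10, 0.72, 0.55, 0.48, 0.45, 0.44, 0.44, 0.43, 0.43, 0.43, 0.42]

    reduction = (losses[0] - losses[-1]) / losses[0] * 100
    recent = losses[-10:]
    avg = sum(recent) / len(recent)
    std = (sum((x - avg) ** 2 for x in recent) / len(recent)) ** 0.5

    print(f"Reduction: {reduction:.1f}%")   # ~76.7%
    print(f"Std dev:   {std:.4f}")          # ~0.09, above the 0.01 threshold, so not yet "converged"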
+
+
+ def cmd_serve(args: argparse.Namespace) -> int:
+     """Start local web server for dashboard.
+
+     Automatically regenerates dashboard and viewer before serving to ensure
+     the latest code and data are reflected.
+     """
+     from openadapt_ml.training.trainer import regenerate_local_dashboard
+
+     port = args.port
+
+     # Determine what to serve: benchmark directory or training output
+     if hasattr(args, 'benchmark') and args.benchmark:
+         serve_dir = Path(args.benchmark).expanduser().resolve()
+         if not serve_dir.exists():
+             print(f"Error: Benchmark directory not found: {serve_dir}")
+             return 1
+
+         # Regenerate benchmark viewer if needed
+         if not args.no_regenerate:
+             print("Regenerating benchmark viewer...")
+             try:
+                 from openadapt_ml.training.benchmark_viewer import generate_benchmark_viewer
+                 generate_benchmark_viewer(serve_dir)
+             except Exception as e:
+                 print(f"Warning: Could not regenerate benchmark viewer: {e}")
+
+         start_page = "benchmark.html"
+     else:
+         serve_dir = get_current_output_dir()
+
+         if not serve_dir.exists():
+             print(f"Error: {serve_dir} not found. Run training first.")
+             return 1
+
+         # Regenerate dashboard and viewer with latest code before serving
+         if not args.no_regenerate:
+             print("Regenerating dashboard and viewer...")
+             try:
+                 regenerate_local_dashboard(str(serve_dir))
+                 # Also regenerate viewer if comparison data exists
+                 _regenerate_viewer_if_possible(serve_dir)
+             except Exception as e:
+                 print(f"Warning: Could not regenerate: {e}")
+
+             # Also regenerate benchmark viewer from latest benchmark results
+             _regenerate_benchmark_viewer_if_available(serve_dir)
+
+         start_page = "dashboard.html"
+
+     # Serve from the specified directory
+     os.chdir(serve_dir)
+
+     # Custom handler with /api/stop support
+     quiet_mode = args.quiet
+
+     class StopHandler(http.server.SimpleHTTPRequestHandler):
+         def log_message(self, format, *log_args):
+             if quiet_mode:
+                 pass  # Suppress request logging
+             else:
+                 super().log_message(format, *log_args)
+
+         def do_POST(self):
+             if self.path == '/api/stop':
+                 # Create stop signal file
+                 stop_file = serve_dir / "STOP_TRAINING"
+                 stop_file.touch()
+                 self.send_response(200)
+                 self.send_header('Content-Type', 'application/json')
+                 self.send_header('Access-Control-Allow-Origin', '*')
+                 self.end_headers()
+                 self.wfile.write(b'{"status": "stop_signal_created"}')
+                 print(f"\n⏹ Stop signal created: {stop_file}")
+             elif self.path == '/api/run-benchmark':
+                 # Parse request body for provider
+                 content_length = int(self.headers.get('Content-Length', 0))
+                 body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+                 try:
+                     params = json.loads(body)
+                 except json.JSONDecodeError:
+                     params = {}
+
+                 provider = params.get('provider', 'anthropic')
+                 tasks = params.get('tasks', 5)
+
+                 self.send_response(200)
+                 self.send_header('Content-Type', 'application/json')
+                 self.send_header('Access-Control-Allow-Origin', '*')
+                 self.end_headers()
+                 self.wfile.write(json.dumps({"status": "started", "provider": provider, "tasks": tasks}).encode())
+
+                 # Run benchmark in background thread with progress logging
+                 def run_benchmark():
+                     import subprocess
+                     from dotenv import load_dotenv
+
+                     # Load .env file for API keys
+                     project_root = Path(__file__).parent.parent.parent
+                     load_dotenv(project_root / ".env")
+
+                     # Create progress log file (in cwd which is serve_dir)
+                     progress_file = Path("benchmark_progress.json")
+
+                     print(f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)...")
+
+                     # Write initial progress
+                     progress_file.write_text(json.dumps({
+                         "status": "running",
+                         "provider": provider,
+                         "tasks_total": tasks,
+                         "tasks_complete": 0,
+                         "message": f"Starting {provider} evaluation..."
+                     }))
+
+                     # Copy environment with loaded vars
+                     env = os.environ.copy()
+
+                     result = subprocess.run(
+                         ["uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli", "run-api",
+                          "--provider", provider, "--tasks", str(tasks),
+                          "--model-id", f"{provider}-api"],
+                         capture_output=True, text=True, cwd=str(project_root), env=env
+                     )
+
+                     print(f"\n📋 Benchmark output:\n{result.stdout}")
+                     if result.stderr:
+                         print(f"Stderr: {result.stderr}")
+
+                     if result.returncode == 0:
+                         print(f"✅ Benchmark complete. Regenerating viewer...")
+                         progress_file.write_text(json.dumps({
+                             "status": "complete",
+                             "provider": provider,
+                             "message": "Evaluation complete! Refreshing results..."
+                         }))
+                         # Regenerate benchmark viewer
+                         _regenerate_benchmark_viewer_if_available(serve_dir)
+                     else:
+                         print(f"❌ Benchmark failed: {result.stderr}")
+                         progress_file.write_text(json.dumps({
+                             "status": "error",
+                             "provider": provider,
+                             "message": f"Evaluation failed: {result.stderr[:200]}"
+                         }))
+
+                 threading.Thread(target=run_benchmark, daemon=True).start()
+             else:
+                 self.send_error(404, "Not found")
+
+         def do_GET(self):
+             if self.path.startswith('/api/benchmark-progress'):
+                 # Return benchmark progress
+                 progress_file = Path("benchmark_progress.json")  # Relative to serve_dir (cwd)
+                 if progress_file.exists():
+                     progress = progress_file.read_text()
+                 else:
+                     progress = json.dumps({"status": "idle"})
+
+                 self.send_response(200)
+                 self.send_header('Content-Type', 'application/json')
+                 self.send_header('Access-Control-Allow-Origin', '*')
+                 self.end_headers()
+                 self.wfile.write(progress.encode())
+             else:
+                 # Default file serving
+                 super().do_GET()
+
+         def do_OPTIONS(self):
+             # Handle CORS preflight
+             self.send_response(200)
+             self.send_header('Access-Control-Allow-Origin', '*')
+             self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
+             self.send_header('Access-Control-Allow-Headers', 'Content-Type')
+             self.end_headers()
+
+     with socketserver.TCPServer(("", port), StopHandler) as httpd:
+         url = f"http://localhost:{port}/{start_page}"
+         print(f"\nServing at: {url}")
+         print(f"Directory: {serve_dir}")
+         print("Press Ctrl+C to stop\n")
+
+         if args.open:
+             webbrowser.open(url)
+
+         try:
+             httpd.serve_forever()
+         except KeyboardInterrupt:
+             print("\nServer stopped")
+
+     return 0
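The handler above serves static files plus three JSON endpoints (POST /api/stop, POST /api/run-benchmark, GET /api/benchmark-progress); a minimal standard-library client sketch, assuming the serve command is running locally on the default port 8765:

    import json
    import urllib.request

    BASE = "http://localhost:8765"

    # Kick off an API-provider benchmark run (provider/tasks mirror the handler's defaults).
    req = urllib.request.Request(
        f"{BASE}/api/run-benchmark",
        data=json.dumps({"provider": "anthropic", "tasks": 5}).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    print(urllib.request.urlopen(req).read().decode())

    # Poll progress written to benchmark_progress.json.
    print(urllib.request.urlopen(f"{BASE}/api/benchmark-progress").read().decode())

    # Ask the training loop to stop (creates the STOP_TRAINING file).
    stop = urllib.request.Request(f"{BASE}/api/stop", data=b"", method="POST")
    print(urllib.request.urlopen(stop).read().decode())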
+
+
+ def cmd_viewer(args: argparse.Namespace) -> int:
+     """Regenerate viewer from local training output."""
+     from openadapt_ml.training.trainer import (
+         generate_training_dashboard,
+         generate_unified_viewer_from_output_dir,
+         TrainingState,
+         TrainingConfig,
+     )
+
+     current_dir = get_current_output_dir()
+
+     if not current_dir.exists():
+         print(f"Error: {current_dir} not found. Run training first.")
+         return 1
+
+     print(f"Regenerating viewer from {current_dir}...")
+
+     # Regenerate dashboard
+     log_file = current_dir / "training_log.json"
+     if log_file.exists():
+         with open(log_file) as f:
+             data = json.load(f)
+
+         state = TrainingState(job_id=data.get("job_id", ""))
+         state.epoch = data.get("epoch", 0)
+         state.step = data.get("step", 0)
+         state.loss = data.get("loss", 0)
+         state.learning_rate = data.get("learning_rate", 0)
+         state.losses = data.get("losses", [])
+         state.status = data.get("status", "completed")
+         state.elapsed_time = data.get("elapsed_time", 0.0)  # Load elapsed time for completed training
+
+         config = TrainingConfig(
+             num_train_epochs=data.get("total_epochs", 5),
+             learning_rate=data.get("learning_rate", 5e-5),
+         )
+
+         dashboard_html = generate_training_dashboard(state, config)
+         (current_dir / "dashboard.html").write_text(dashboard_html)
+         print(f" Regenerated: dashboard.html")
+
+     # Generate unified viewer using consolidated function
+     viewer_path = generate_unified_viewer_from_output_dir(current_dir)
+     if viewer_path:
+         print(f"\nGenerated: {viewer_path}")
+     else:
+         print("\nNo comparison data found. Run comparison first or copy from capture directory.")
+
+     # Also regenerate benchmark viewer from latest benchmark results
+     _regenerate_benchmark_viewer_if_available(current_dir)
+
+     if args.open:
+         webbrowser.open(str(current_dir / "viewer.html"))
+
+     return 0
+
+
+ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
+     """Generate benchmark viewer from benchmark results."""
+     from openadapt_ml.training.benchmark_viewer import generate_benchmark_viewer
+
+     benchmark_dir = Path(args.benchmark_dir).expanduser().resolve()
+     if not benchmark_dir.exists():
+         print(f"Error: Benchmark directory not found: {benchmark_dir}")
+         return 1
+
+     print(f"\n{'='*50}")
+     print("GENERATING BENCHMARK VIEWER")
+     print(f"{'='*50}")
+     print(f"Benchmark dir: {benchmark_dir}")
+     print()
+
+     try:
+         viewer_path = generate_benchmark_viewer(benchmark_dir)
+         print(f"\nSuccess! Benchmark viewer generated at: {viewer_path}")
+
+         if args.open:
+             webbrowser.open(str(viewer_path))
+
+         return 0
+     except Exception as e:
+         print(f"Error generating benchmark viewer: {e}")
+         import traceback
+         traceback.print_exc()
+         return 1
+
+
+ def cmd_compare(args: argparse.Namespace) -> int:
+     """Run human vs AI comparison on local checkpoint."""
+     capture_path = Path(args.capture).expanduser().resolve()
+     if not capture_path.exists():
+         print(f"Error: Capture not found: {capture_path}")
+         return 1
+
+     checkpoint = args.checkpoint
+     if checkpoint and not Path(checkpoint).exists():
+         print(f"Error: Checkpoint not found: {checkpoint}")
+         return 1
+
+     print(f"\n{'='*50}")
+     print("RUNNING COMPARISON")
+     print(f"{'='*50}")
+     print(f"Capture: {capture_path}")
+     print(f"Checkpoint: {checkpoint or 'None (capture only)'}")
+     print()
+
+     cmd = [
+         sys.executable, "-m", "openadapt_ml.scripts.compare",
+         "--capture", str(capture_path),
+     ]
+
+     if checkpoint:
+         cmd.extend(["--checkpoint", checkpoint])
+
+     if args.open:
+         cmd.append("--open")
+
+     result = subprocess.run(cmd, check=False)
+     return result.returncode
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Local GPU training CLI",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   # Train on a capture (auto-detects CUDA/MPS/CPU)
+   uv run python -m openadapt_ml.cloud.local train --capture ~/captures/my-workflow --open
+
+   # Check training status
+   uv run python -m openadapt_ml.cloud.local status
+
+   # Check training health (loss progression)
+   uv run python -m openadapt_ml.cloud.local check
+
+   # Start dashboard server
+   uv run python -m openadapt_ml.cloud.local serve --open
+
+   # Regenerate viewer
+   uv run python -m openadapt_ml.cloud.local viewer --open
+
+   # Generate benchmark viewer
+   uv run python -m openadapt_ml.cloud.local benchmark-viewer benchmark_results/test_run --open
+
+   # Run comparison
+   uv run python -m openadapt_ml.cloud.local compare --capture ~/captures/my-workflow --checkpoint checkpoints/model
+ """
+     )
+
+     subparsers = parser.add_subparsers(dest="command", help="Command")
+
+     # status
+     p_status = subparsers.add_parser("status", help="Show local training status")
+     p_status.set_defaults(func=cmd_status)
+
+     # train
+     p_train = subparsers.add_parser("train", help="Run training locally")
+     p_train.add_argument("--capture", required=True, help="Path to capture directory")
+     p_train.add_argument("--goal", help="Task goal (default: derived from capture name)")
+     p_train.add_argument("--config", help="Config file (default: auto-select based on device)")
+     p_train.add_argument("--open", action="store_true", help="Open dashboard in browser")
+     p_train.set_defaults(func=cmd_train)
+
+     # check
+     p_check = subparsers.add_parser("check", help="Check training health")
+     p_check.set_defaults(func=cmd_check)
+
+     # serve
+     p_serve = subparsers.add_parser("serve", help="Start web server for dashboard")
+     p_serve.add_argument("--port", type=int, default=8765, help="Port number")
+     p_serve.add_argument("--open", action="store_true", help="Open in browser")
+     p_serve.add_argument("--quiet", "-q", action="store_true", help="Suppress request logging")
+     p_serve.add_argument("--no-regenerate", action="store_true",
+                          help="Skip regenerating dashboard/viewer (serve existing files)")
+     p_serve.add_argument("--benchmark", help="Serve benchmark results directory instead of training output")
+     p_serve.set_defaults(func=cmd_serve)
+
+     # viewer
+     p_viewer = subparsers.add_parser("viewer", help="Regenerate viewer")
+     p_viewer.add_argument("--open", action="store_true", help="Open in browser")
+     p_viewer.set_defaults(func=cmd_viewer)
+
+     # benchmark-viewer
+     p_benchmark = subparsers.add_parser("benchmark-viewer", help="Generate benchmark viewer")
+     p_benchmark.add_argument("benchmark_dir", help="Path to benchmark results directory")
+     p_benchmark.add_argument("--open", action="store_true", help="Open viewer in browser")
+     p_benchmark.set_defaults(func=cmd_benchmark_viewer)
+
+     # compare
+     p_compare = subparsers.add_parser("compare", help="Run human vs AI comparison")
+     p_compare.add_argument("--capture", required=True, help="Path to capture directory")
+     p_compare.add_argument("--checkpoint", help="Path to checkpoint (optional)")
+     p_compare.add_argument("--open", action="store_true", help="Open viewer in browser")
+     p_compare.set_defaults(func=cmd_compare)
+
+     args = parser.parse_args()
+
+     if not args.command:
+         parser.print_help()
+         return 0
+
+     return args.func(args)
+
+
+ if __name__ == "__main__":
+     sys.exit(main())