openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,790 @@
|
|
|
1
|
+
"""Local GPU training CLI.
|
|
2
|
+
|
|
3
|
+
Provides commands equivalent to lambda_labs.py but for local execution
|
|
4
|
+
on CUDA or Apple Silicon.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
# Train on a capture
|
|
8
|
+
uv run python -m openadapt_ml.cloud.local train --capture ~/captures/my-workflow
|
|
9
|
+
|
|
10
|
+
# Check training status
|
|
11
|
+
uv run python -m openadapt_ml.cloud.local status
|
|
12
|
+
|
|
13
|
+
# Check training health
|
|
14
|
+
uv run python -m openadapt_ml.cloud.local check
|
|
15
|
+
|
|
16
|
+
# Start dashboard server
|
|
17
|
+
uv run python -m openadapt_ml.cloud.local serve --open
|
|
18
|
+
|
|
19
|
+
# Regenerate viewer
|
|
20
|
+
uv run python -m openadapt_ml.cloud.local viewer
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import argparse
|
|
26
|
+
import http.server
|
|
27
|
+
import json
|
|
28
|
+
import os
|
|
29
|
+
import shutil
|
|
30
|
+
import signal
|
|
31
|
+
import socketserver
|
|
32
|
+
import subprocess
|
|
33
|
+
import sys
|
|
34
|
+
import threading
|
|
35
|
+
import webbrowser
|
|
36
|
+
from pathlib import Path
|
|
37
|
+
from typing import Any
|
|
38
|
+
|
|
39
|
+
# Training output directory
|
|
40
|
+
# Base directory for training artifacts. The path is relative, so it is
# resolved against the process's current working directory.
TRAINING_OUTPUT = Path("training_output")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def get_current_output_dir() -> Path:
    """Return the directory that holds the active job's artifacts.

    ``TRAINING_OUTPUT / "current"`` is preferred whenever that entry is a
    symlink or otherwise exists on disk. When it is absent, the plain
    ``TRAINING_OUTPUT`` directory is returned so that output written with the
    older, flat layout keeps working.
    """
    candidate = TRAINING_OUTPUT / "current"
    link_present = candidate.is_symlink() or candidate.exists()
    # Legacy layout: artifacts live directly in the base directory.
    return candidate if link_present else TRAINING_OUTPUT
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _regenerate_viewer_if_possible(output_dir: Path) -> bool:
    """Rebuild the unified viewer for ``output_dir`` on a best-effort basis.

    Args:
        output_dir: Training output directory handed to the viewer generator.

    Returns:
        ``True`` when the generator reported a viewer path, ``False`` when it
        reported nothing or raised.
    """
    from openadapt_ml.training.trainer import generate_unified_viewer_from_output_dir

    try:
        generated = generate_unified_viewer_from_output_dir(output_dir)
        if not generated:
            return False
        print(f"Regenerated viewer: {generated}")
        return True
    except Exception as exc:
        # Any generator failure is reported and treated as "not regenerated".
        print(f"Could not regenerate viewer: {exc}")
        return False
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _read_model_id(json_path: Path) -> str:
    """Return the lower-cased ``model_id`` stored in ``json_path``, or ``""``.

    A missing, unreadable or malformed file, a non-object top level, and a
    missing or non-string ``model_id`` all yield an empty string.
    """
    try:
        with open(json_path, encoding="utf-8") as f:
            payload = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return ""
    if not isinstance(payload, dict):
        return ""
    model_id = payload.get("model_id")
    return model_id.lower() if isinstance(model_id, str) else ""


def _is_mock_benchmark(benchmark_dir: Path) -> bool:
    """Check if a benchmark run is mock/test data (not real evaluation).

    A run counts as mock when the ``model_id`` in its ``summary.json`` or
    ``metadata.json`` contains ``random-agent`` or ``scripted-agent``, or when
    the directory name contains ``test_run``, ``test_cli`` or ``quick_demo``
    (all comparisons are case-insensitive).

    Note: API evaluations using the mock WAA adapter (waa-mock) are considered
    real evaluations and should NOT be filtered out, since they represent
    actual model performance on test tasks.

    Args:
        benchmark_dir: Directory of a single benchmark run.

    Returns:
        True if the benchmark is mock data that should be filtered out.
    """
    # Filter out random/scripted agent runs (but keep API models like
    # "anthropic-api").
    mock_model_terms = ("random-agent", "scripted-agent")
    for filename in ("summary.json", "metadata.json"):
        model_id = _read_model_id(benchmark_dir / filename)
        if any(term in model_id for term in mock_model_terms):
            return True

    # Only purely synthetic test-data directories are filtered by name.
    test_dir_terms = ("test_run", "test_cli", "quick_demo")
    dir_name = benchmark_dir.name.lower()
    return any(term in dir_name for term in test_dir_terms)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _regenerate_benchmark_viewer_if_available(
    output_dir: Path,
    benchmark_results_dir: Path | None = None,
) -> bool:
    """Regenerate ``benchmark.html`` in ``output_dir`` from real benchmark runs.

    Every sub-directory of the results directory that contains a
    ``summary.json`` and is not classified as mock data by
    ``_is_mock_benchmark`` counts as a run. With at least one run, a
    multi-run viewer is generated (most recently modified run first) and each
    run's ``tasks`` folder is copied to
    ``output_dir / "benchmark_tasks" / <run name>``. With no runs, an
    empty-state viewer is generated instead.

    Args:
        output_dir: Directory that receives ``benchmark.html`` and
            ``benchmark_tasks/``.
        benchmark_results_dir: Directory holding one sub-directory per run.
            Defaults to ``benchmark_results`` relative to the current working
            directory, which was previously the hard-coded location.

    Returns:
        True if a benchmark viewer was written, False if generation failed.
    """
    from openadapt_ml.training.benchmark_viewer import (
        generate_empty_benchmark_viewer,
        generate_multi_run_benchmark_viewer,
    )

    results_dir = (
        Path("benchmark_results")
        if benchmark_results_dir is None
        else Path(benchmark_results_dir)
    )

    # Find real (non-mock) benchmark runs.
    real_benchmarks = []
    if results_dir.exists():
        real_benchmarks = [
            run_dir
            for run_dir in results_dir.iterdir()
            if run_dir.is_dir()
            and (run_dir / "summary.json").exists()
            and not _is_mock_benchmark(run_dir)
        ]

    benchmark_html_path = output_dir / "benchmark.html"

    if not real_benchmarks:
        # No real benchmark data - generate the empty-state viewer.
        try:
            generate_empty_benchmark_viewer(benchmark_html_path)
            print(" Generated benchmark viewer: No real evaluation data yet")
            return True
        except Exception as e:
            print(f" Could not generate empty benchmark viewer: {e}")
            return False

    # Sort by modification time (most recent first).
    real_benchmarks.sort(key=lambda d: d.stat().st_mtime, reverse=True)

    try:
        generate_multi_run_benchmark_viewer(real_benchmarks, benchmark_html_path)

        # Start from a clean copy so folders of removed runs do not linger.
        benchmark_tasks_dir = output_dir / "benchmark_tasks"
        if benchmark_tasks_dir.exists():
            shutil.rmtree(benchmark_tasks_dir)
        benchmark_tasks_dir.mkdir(exist_ok=True)

        # Copy each run's tasks folder (used for screenshots), keyed by run.
        for benchmark_dir in real_benchmarks:
            tasks_src = benchmark_dir / "tasks"
            if tasks_src.exists():
                shutil.copytree(tasks_src, benchmark_tasks_dir / benchmark_dir.name)

        print(f" Regenerated benchmark viewer with {len(real_benchmarks)} run(s)")
        return True
    except Exception as e:
        print(f" Could not regenerate benchmark viewer: {e}")
        import traceback

        traceback.print_exc()
        return False
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def detect_device() -> str:
    """Describe the compute backend that PyTorch reports as available.

    Returns:
        ``"cuda (<device name>)"`` when CUDA is available,
        ``"mps (Apple Silicon)"`` when the MPS backend is available,
        ``"cpu"`` otherwise, or ``"unknown (torch not installed)"`` when an
        ``ImportError`` is raised.
    """
    try:
        import torch

        if torch.cuda.is_available():
            # Report the name of CUDA device 0.
            gpu_name = torch.cuda.get_device_name(0)
            return f"cuda ({gpu_name})"

        # torch.backends.mps is probed with hasattr before being queried.
        mps_ready = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        return "mps (Apple Silicon)" if mps_ready else "cpu"
    except ImportError:
        return "unknown (torch not installed)"
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def get_training_status() -> dict[str, Any]:
    """Get current training status from the current output directory.

    Reads ``training_log.json`` from the directory returned by
    ``get_current_output_dir()``. A missing, unreadable or malformed log, or
    one whose top level is not a JSON object, leaves the defaults in place
    instead of raising.

    Returns:
        A dict that always contains ``running``, ``epoch``, ``step``,
        ``loss``, ``device``, ``has_dashboard``, ``has_viewer``,
        ``checkpoints``, ``job_id`` and ``output_dir``. When the log was
        read successfully it also contains ``learning_rate``, ``losses`` and
        ``status``.
    """
    current_dir = get_current_output_dir()

    status: dict[str, Any] = {
        "running": False,
        "epoch": 0,
        "step": 0,
        "loss": None,
        "device": detect_device(),
        "has_dashboard": False,
        "has_viewer": False,
        "checkpoints": [],
        "job_id": None,
        "output_dir": str(current_dir),
    }

    log_file = current_dir / "training_log.json"
    if log_file.exists():
        try:
            with open(log_file, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Best effort: an unreadable or malformed log keeps the defaults.
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            data = None

        if isinstance(data, dict):
            status["job_id"] = data.get("job_id")
            status["epoch"] = data.get("epoch", 0)
            status["step"] = data.get("step", 0)
            status["loss"] = data.get("loss")
            status["learning_rate"] = data.get("learning_rate")
            status["losses"] = data.get("losses", [])
            status["status"] = data.get("status", "unknown")
            status["running"] = data.get("status") == "training"

    status["has_dashboard"] = (current_dir / "dashboard.html").exists()
    status["has_viewer"] = (current_dir / "viewer.html").exists()

    # A checkpoint is any sub-directory of ./checkpoints that contains an
    # adapter_config.json file.
    checkpoints_dir = Path("checkpoints")
    if checkpoints_dir.exists():
        status["checkpoints"] = sorted(
            d.name
            for d in checkpoints_dir.iterdir()
            if d.is_dir() and (d / "adapter_config.json").exists()
        )

    return status
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def cmd_status(args: argparse.Namespace) -> int:
    """Show local training status.

    Prints the device, running/idle state, job id, output directory,
    progress, the five last checkpoints (in sorted order) and whether the
    dashboard and viewer files exist.

    Args:
        args: Parsed CLI arguments (not read by this command).

    Returns:
        Always 0.
    """
    status = get_training_status()
    # Reuse the directory the status was computed from so the printed paths
    # and the has_dashboard/has_viewer flags refer to the same location.
    current_dir = status["output_dir"]

    print(f"\n{'='*50}")
    print("LOCAL TRAINING STATUS")
    print(f"{'='*50}")
    print(f"Device: {status['device']}")
    print(f"Status: {'RUNNING' if status['running'] else 'IDLE'}")
    if status.get("job_id"):
        print(f"Job ID: {status['job_id']}")
    print(f"Output: {current_dir}")

    # Show progress as soon as any step was logged, even if epoch is still 0.
    if status.get("epoch") or status.get("step"):
        print("\nProgress:")
        print(f" Epoch: {status['epoch']}")
        print(f" Step: {status['step']}")
        # Compare against None so a legitimate 0.0 value is still printed.
        loss = status.get("loss")
        if loss is not None:
            print(f" Loss: {loss:.4f}")
        learning_rate = status.get("learning_rate")
        if learning_rate is not None:
            print(f" LR: {learning_rate:.2e}")

    if status["checkpoints"]:
        print(f"\nCheckpoints ({len(status['checkpoints'])}):")
        for cp in status["checkpoints"][-5:]:  # Show last 5
            print(f" - {cp}")

    print(f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html")
    print(f"Viewer: {'✓' if status['has_viewer'] else '✗'} {current_dir}/viewer.html")
    print()

    return 0
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def cmd_train(args: argparse.Namespace) -> int:
    """Launch ``openadapt_ml.scripts.train`` as a subprocess for one capture.

    Args:
        args: Parsed CLI arguments; reads ``capture``, ``goal``, ``config``
            and ``open``.

    Returns:
        1 when the capture or the config file does not exist, 130 when the
        run is interrupted with Ctrl+C, otherwise the subprocess return code.
    """
    capture_path = Path(args.capture).expanduser().resolve()
    if not capture_path.exists():
        print(f"Error: Capture not found: {capture_path}")
        return 1

    # Without an explicit goal, title-case the capture folder name.
    goal = args.goal or capture_path.name.replace("-", " ").replace("_", " ").title()

    # Without an explicit config, pick one based on the detected device.
    config = args.config
    if not config:
        on_cuda = "cuda" in detect_device()
        config = (
            "configs/qwen3vl_capture.yaml"
            if on_cuda
            else "configs/qwen3vl_capture_4bit.yaml"
        )

    config_path = Path(config)
    if not config_path.exists():
        print(f"Error: Config not found: {config_path}")
        return 1

    print(f"\n{'='*50}")
    print("STARTING LOCAL TRAINING")
    print(f"{'='*50}")
    print(f"Capture: {capture_path}")
    print(f"Goal: {goal}")
    print(f"Config: {config}")
    print(f"Device: {detect_device()}")
    print()

    # Run the training script with the interpreter that runs this CLI.
    train_cmd = [sys.executable, "-m", "openadapt_ml.scripts.train"]
    train_cmd += ["--config", str(config_path)]
    train_cmd += ["--capture", str(capture_path)]
    train_cmd += ["--goal", goal]
    if args.open:
        train_cmd.append("--open")

    try:
        return subprocess.run(train_cmd, check=False).returncode
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        return 130
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _extract_loss_values(raw_losses: list[Any]) -> list[float]:
    """Return the numeric loss of every usable entry in ``raw_losses``.

    Entries may be dicts with a ``loss`` key or bare numbers. Entries without
    a usable value (missing key, ``None``, non-numeric) are skipped rather
    than counted as a loss of 0.
    """
    values: list[float] = []
    for item in raw_losses:
        candidate = item.get("loss") if isinstance(item, dict) else item
        try:
            values.append(float(candidate))
        except (TypeError, ValueError):
            continue
    return values


def cmd_check(args: argparse.Namespace) -> int:
    """Check training health and early stopping analysis.

    Prints the loss progression (first/last/min/max/reduction) once at least
    two loss values are available, and a stability summary over the last ten
    values once at least ten are available.

    Args:
        args: Parsed CLI arguments (not read by this command).

    Returns:
        1 when the training status contains no losses, otherwise 0.
    """
    status = get_training_status()

    print(f"\n{'='*50}")
    print("TRAINING HEALTH CHECK")
    print(f"{'='*50}")

    raw_losses = status.get("losses", [])
    if not raw_losses:
        print("No training data found.")
        print("Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>")
        return 1

    # Extract loss values (handles both dict and float entries).
    losses = _extract_loss_values(raw_losses)

    # The step count reflects every logged entry, usable loss or not.
    print(f"Total steps: {len(raw_losses)}")
    print(f"Current epoch: {status.get('epoch', 0)}")

    # Loss analysis
    if len(losses) >= 2:
        first_loss = losses[0]
        last_loss = losses[-1]
        min_loss = min(losses)
        max_loss = max(losses)

        print("\nLoss progression:")
        print(f" First: {first_loss:.4f}")
        print(f" Last: {last_loss:.4f}")
        print(f" Min: {min_loss:.4f}")
        print(f" Max: {max_loss:.4f}")
        if first_loss:
            print(f" Reduction: {((first_loss - last_loss) / first_loss * 100):.1f}%")
        else:
            # A relative reduction is undefined when the first loss is 0.
            print(" Reduction: n/a (first loss is 0)")

        # Check for convergence
        if len(losses) >= 10:
            recent = losses[-10:]
            recent_avg = sum(recent) / len(recent)
            # Population standard deviation of the last ten values.
            recent_std = (sum((x - recent_avg) ** 2 for x in recent) / len(recent)) ** 0.5

            print("\nRecent stability (last 10 steps):")
            print(f" Avg loss: {recent_avg:.4f}")
            print(f" Std dev: {recent_std:.4f}")

            if recent_std < 0.01:
                print(" Status: ✓ Converged (stable)")
            elif last_loss > first_loss:
                print(" Status: ⚠ Loss increasing - may need lower learning rate")
            else:
                print(" Status: Training in progress")

    print()
    return 0
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def cmd_serve(args: argparse.Namespace) -> int:
    """Start a local web server for the dashboard or a benchmark directory.

    Unless ``--no-regenerate`` is set, the dashboard and viewer (or the
    benchmark viewer when ``--benchmark`` is given) are regenerated before
    serving so the latest code and data are reflected.

    Besides static files from the served directory, the server exposes:

    * ``POST /api/stop`` - creates ``STOP_TRAINING`` in the served directory.
    * ``POST /api/run-benchmark`` - starts
      ``openadapt_ml.benchmarks.cli run-api`` in a background thread; the
      JSON body may carry ``provider`` and ``tasks``.
    * ``GET /api/benchmark-progress`` - returns ``benchmark_progress.json``
      from the served directory, or ``{"status": "idle"}`` when it is absent.

    Files are served through the handler's ``directory`` argument, so the
    process working directory is left unchanged and relative paths used
    elsewhere (for example ``benchmark_results``) keep resolving against the
    directory the CLI was started from.

    Args:
        args: Parsed CLI arguments; reads ``port``, ``open``, ``quiet``,
            ``no_regenerate`` and, when present, ``benchmark``.

    Returns:
        1 if the directory to serve does not exist, otherwise 0 once the
        server has been stopped with Ctrl+C.
    """
    from openadapt_ml.training.trainer import regenerate_local_dashboard

    port = args.port

    # Determine what to serve: benchmark directory or training output.
    if getattr(args, 'benchmark', None):
        serve_dir = Path(args.benchmark).expanduser().resolve()
        if not serve_dir.exists():
            print(f"Error: Benchmark directory not found: {serve_dir}")
            return 1

        # Regenerate benchmark viewer if needed.
        if not args.no_regenerate:
            print("Regenerating benchmark viewer...")
            try:
                from openadapt_ml.training.benchmark_viewer import generate_benchmark_viewer
                generate_benchmark_viewer(serve_dir)
            except Exception as e:
                print(f"Warning: Could not regenerate benchmark viewer: {e}")

        start_page = "benchmark.html"
    else:
        serve_dir = get_current_output_dir()

        if not serve_dir.exists():
            print(f"Error: {serve_dir} not found. Run training first.")
            return 1

        # Regenerate dashboard and viewer with latest code before serving.
        if not args.no_regenerate:
            print("Regenerating dashboard and viewer...")
            try:
                regenerate_local_dashboard(str(serve_dir))
                # Also regenerate viewer if comparison data exists.
                _regenerate_viewer_if_possible(serve_dir)
            except Exception as e:
                print(f"Warning: Could not regenerate: {e}")

            # Also regenerate benchmark viewer from latest benchmark results.
            _regenerate_benchmark_viewer_if_available(serve_dir)

        start_page = "dashboard.html"

    # Pin the served directory to an absolute, symlink-free path so the
    # handler, the stop file and the progress file all refer to one location
    # for the lifetime of the server.
    serve_dir = serve_dir.resolve()
    progress_path = serve_dir / "benchmark_progress.json"

    quiet_mode = args.quiet

    class StopHandler(http.server.SimpleHTTPRequestHandler):
        """Static file handler with stop and benchmark API endpoints."""

        def __init__(self, *handler_args, **handler_kwargs):
            # Serve from serve_dir without changing the process cwd.
            super().__init__(*handler_args, directory=str(serve_dir), **handler_kwargs)

        def log_message(self, format, *log_args):
            # Request logging is suppressed in quiet mode.
            if not quiet_mode:
                super().log_message(format, *log_args)

        def _send_json(self, payload: bytes) -> None:
            """Send ``payload`` as a 200 JSON response with a CORS header."""
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.send_header('Access-Control-Allow-Origin', '*')
            self.end_headers()
            self.wfile.write(payload)

        def do_POST(self):
            if self.path == '/api/stop':
                # Create stop signal file.
                stop_file = serve_dir / "STOP_TRAINING"
                stop_file.touch()
                self._send_json(b'{"status": "stop_signal_created"}')
                print(f"\n⏹ Stop signal created: {stop_file}")
            elif self.path == '/api/run-benchmark':
                # NOTE(review): this endpoint starts a subprocess, is
                # reachable from any origin ('*'), and the server binds to
                # all interfaces - confirm that exposure is intended.
                try:
                    content_length = int(self.headers.get('Content-Length', 0))
                except ValueError:
                    content_length = 0

                # Any body that is not a JSON object falls back to defaults.
                try:
                    raw_body = self.rfile.read(content_length) if content_length > 0 else b'{}'
                    params = json.loads(raw_body.decode('utf-8'))
                except ValueError:
                    params = {}
                if not isinstance(params, dict):
                    params = {}

                # Both values end up as command-line arguments, so coerce
                # them to the types the command line needs.
                provider = str(params.get('provider') or 'anthropic')
                try:
                    tasks = int(params.get('tasks', 5))
                except (TypeError, ValueError):
                    tasks = 5

                self._send_json(
                    json.dumps({"status": "started", "provider": provider, "tasks": tasks}).encode()
                )

                def write_progress(**fields):
                    progress_path.write_text(json.dumps(fields))

                # Run benchmark in background thread with progress logging.
                def run_benchmark():
                    try:
                        from dotenv import load_dotenv

                        # Load .env file for API keys.
                        project_root = Path(__file__).parent.parent.parent
                        load_dotenv(project_root / ".env")

                        print(f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)...")

                        write_progress(
                            status="running",
                            provider=provider,
                            tasks_total=tasks,
                            tasks_complete=0,
                            message=f"Starting {provider} evaluation...",
                        )

                        # Copy environment with loaded vars.
                        env = os.environ.copy()

                        result = subprocess.run(
                            ["uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli", "run-api",
                             "--provider", provider, "--tasks", str(tasks),
                             "--model-id", f"{provider}-api"],
                            capture_output=True, text=True, cwd=str(project_root), env=env
                        )
                    except Exception as exc:
                        # A failed launch (e.g. a missing import or
                        # executable) must reach the progress file, or
                        # pollers never see the run end.
                        print(f"❌ Benchmark failed: {exc}")
                        write_progress(
                            status="error",
                            provider=provider,
                            message=f"Evaluation failed: {str(exc)[:200]}",
                        )
                        return

                    print(f"\n📋 Benchmark output:\n{result.stdout}")
                    if result.stderr:
                        print(f"Stderr: {result.stderr}")

                    if result.returncode == 0:
                        print("✅ Benchmark complete. Regenerating viewer...")
                        write_progress(
                            status="complete",
                            provider=provider,
                            message="Evaluation complete! Refreshing results...",
                        )
                        # Regenerate benchmark viewer.
                        _regenerate_benchmark_viewer_if_available(serve_dir)
                    else:
                        print(f"❌ Benchmark failed: {result.stderr}")
                        write_progress(
                            status="error",
                            provider=provider,
                            message=f"Evaluation failed: {result.stderr[:200]}",
                        )

                threading.Thread(target=run_benchmark, daemon=True).start()
            else:
                self.send_error(404, "Not found")

        def do_GET(self):
            if self.path.startswith('/api/benchmark-progress'):
                # Return benchmark progress.
                if progress_path.exists():
                    progress = progress_path.read_text()
                else:
                    progress = json.dumps({"status": "idle"})
                self._send_json(progress.encode())
            else:
                # Default file serving.
                super().do_GET()

        def do_OPTIONS(self):
            # Handle CORS preflight.
            self.send_response(200)
            self.send_header('Access-Control-Allow-Origin', '*')
            self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
            self.send_header('Access-Control-Allow-Headers', 'Content-Type')
            self.end_headers()

    # An empty host binds to all interfaces.
    with socketserver.TCPServer(("", port), StopHandler) as httpd:
        url = f"http://localhost:{port}/{start_page}"
        print(f"\nServing at: {url}")
        print(f"Directory: {serve_dir}")
        print("Press Ctrl+C to stop\n")

        if args.open:
            webbrowser.open(url)

        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            print("\nServer stopped")

    return 0
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def cmd_viewer(args: argparse.Namespace) -> int:
    """Regenerate the dashboard, viewer and benchmark viewer from local output.

    ``dashboard.html`` is rebuilt from ``training_log.json`` when that log can
    be read; an unreadable or malformed log is reported and skipped so the
    remaining steps still run.

    Args:
        args: Parsed CLI arguments; reads ``open``.

    Returns:
        1 if the current output directory does not exist, otherwise 0.
    """
    from openadapt_ml.training.trainer import (
        generate_training_dashboard,
        generate_unified_viewer_from_output_dir,
        TrainingState,
        TrainingConfig,
    )

    current_dir = get_current_output_dir()

    if not current_dir.exists():
        print(f"Error: {current_dir} not found. Run training first.")
        return 1

    print(f"Regenerating viewer from {current_dir}...")

    # Regenerate dashboard
    log_file = current_dir / "training_log.json"
    if log_file.exists():
        try:
            with open(log_file, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f" Could not read {log_file}: {e}")
            data = None

        if isinstance(data, dict):
            state = TrainingState(job_id=data.get("job_id", ""))
            state.epoch = data.get("epoch", 0)
            state.step = data.get("step", 0)
            state.loss = data.get("loss", 0)
            state.learning_rate = data.get("learning_rate", 0)
            state.losses = data.get("losses", [])
            state.status = data.get("status", "completed")
            # Load elapsed time for completed training.
            state.elapsed_time = data.get("elapsed_time", 0.0)

            config = TrainingConfig(
                num_train_epochs=data.get("total_epochs", 5),
                learning_rate=data.get("learning_rate", 5e-5),
            )

            dashboard_html = generate_training_dashboard(state, config)
            (current_dir / "dashboard.html").write_text(dashboard_html)
            print(" Regenerated: dashboard.html")

    # Generate unified viewer using consolidated function
    viewer_path = generate_unified_viewer_from_output_dir(current_dir)
    if viewer_path:
        print(f"\nGenerated: {viewer_path}")
    else:
        print("\nNo comparison data found. Run comparison first or copy from capture directory.")

    # Also regenerate benchmark viewer from latest benchmark results
    _regenerate_benchmark_viewer_if_available(current_dir)

    if args.open:
        viewer_file = current_dir / "viewer.html"
        if viewer_file.exists():
            # webbrowser expects a URL; passing a bare file name is
            # documented as neither supported nor portable.
            webbrowser.open(viewer_file.resolve().as_uri())
        else:
            print(f"Viewer not found, nothing to open: {viewer_file}")

    return 0
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
    """Build the benchmark viewer for a single benchmark results directory.

    Args:
        args: Parsed CLI arguments; reads ``benchmark_dir`` and ``open``.

    Returns:
        0 when the viewer was generated, 1 when the directory is missing or
        generation raised.
    """
    from openadapt_ml.training.benchmark_viewer import generate_benchmark_viewer

    benchmark_dir = Path(args.benchmark_dir).expanduser().resolve()
    if not benchmark_dir.exists():
        print(f"Error: Benchmark directory not found: {benchmark_dir}")
        return 1

    banner_lines = (
        f"\n{'='*50}",
        "GENERATING BENCHMARK VIEWER",
        f"{'='*50}",
        f"Benchmark dir: {benchmark_dir}",
    )
    for line in banner_lines:
        print(line)
    print()

    try:
        viewer_path = generate_benchmark_viewer(benchmark_dir)
        print(f"\nSuccess! Benchmark viewer generated at: {viewer_path}")

        if args.open:
            webbrowser.open(str(viewer_path))
    except Exception as exc:
        # Report the failure together with the full traceback.
        print(f"Error generating benchmark viewer: {exc}")
        import traceback

        traceback.print_exc()
        return 1

    return 0
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def cmd_compare(args: argparse.Namespace) -> int:
    """Run human vs AI comparison on a local checkpoint.

    Launches ``openadapt_ml.scripts.compare`` as a subprocess. A leading ``~``
    is expanded in both the capture and the checkpoint path.

    Args:
        args: Parsed CLI arguments; reads ``capture``, ``checkpoint`` and
            ``open``.

    Returns:
        1 when the capture or a given checkpoint does not exist, otherwise
        the subprocess return code.
    """
    capture_path = Path(args.capture).expanduser().resolve()
    if not capture_path.exists():
        print(f"Error: Capture not found: {capture_path}")
        return 1

    checkpoint = args.checkpoint
    if checkpoint:
        # Expand "~" like the capture path, so "--checkpoint=~/..." (which
        # the shell leaves unexpanded) is found as well.
        checkpoint_path = Path(checkpoint).expanduser()
        if not checkpoint_path.exists():
            print(f"Error: Checkpoint not found: {checkpoint}")
            return 1
        checkpoint = str(checkpoint_path)

    print(f"\n{'='*50}")
    print("RUNNING COMPARISON")
    print(f"{'='*50}")
    print(f"Capture: {capture_path}")
    print(f"Checkpoint: {checkpoint or 'None (capture only)'}")
    print()

    cmd = [
        sys.executable, "-m", "openadapt_ml.scripts.compare",
        "--capture", str(capture_path),
    ]

    if checkpoint:
        cmd.extend(["--checkpoint", checkpoint])

    if args.open:
        cmd.append("--open")

    result = subprocess.run(cmd, check=False)
    return result.returncode
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
def main(argv: list[str] | None = None) -> int:
    """Parse the command line and dispatch to the selected sub-command.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the previous behavior of reading
            ``sys.argv``.

    Returns:
        The selected command's return code, or 0 after printing the help
        text when no command was given.
    """
    parser = argparse.ArgumentParser(
        description="Local GPU training CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Train on a capture (auto-detects CUDA/MPS/CPU)
uv run python -m openadapt_ml.cloud.local train --capture ~/captures/my-workflow --open

# Check training status
uv run python -m openadapt_ml.cloud.local status

# Check training health (loss progression)
uv run python -m openadapt_ml.cloud.local check

# Start dashboard server
uv run python -m openadapt_ml.cloud.local serve --open

# Regenerate viewer
uv run python -m openadapt_ml.cloud.local viewer --open

# Generate benchmark viewer
uv run python -m openadapt_ml.cloud.local benchmark-viewer benchmark_results/test_run --open

# Run comparison
uv run python -m openadapt_ml.cloud.local compare --capture ~/captures/my-workflow --checkpoint checkpoints/model
"""
    )

    subparsers = parser.add_subparsers(dest="command", help="Command")

    # status
    p_status = subparsers.add_parser("status", help="Show local training status")
    p_status.set_defaults(func=cmd_status)

    # train
    p_train = subparsers.add_parser("train", help="Run training locally")
    p_train.add_argument("--capture", required=True, help="Path to capture directory")
    p_train.add_argument("--goal", help="Task goal (default: derived from capture name)")
    p_train.add_argument("--config", help="Config file (default: auto-select based on device)")
    p_train.add_argument("--open", action="store_true", help="Open dashboard in browser")
    p_train.set_defaults(func=cmd_train)

    # check
    p_check = subparsers.add_parser("check", help="Check training health")
    p_check.set_defaults(func=cmd_check)

    # serve
    p_serve = subparsers.add_parser("serve", help="Start web server for dashboard")
    p_serve.add_argument("--port", type=int, default=8765, help="Port number")
    p_serve.add_argument("--open", action="store_true", help="Open in browser")
    p_serve.add_argument("--quiet", "-q", action="store_true", help="Suppress request logging")
    p_serve.add_argument("--no-regenerate", action="store_true",
                         help="Skip regenerating dashboard/viewer (serve existing files)")
    p_serve.add_argument("--benchmark", help="Serve benchmark results directory instead of training output")
    p_serve.set_defaults(func=cmd_serve)

    # viewer
    p_viewer = subparsers.add_parser("viewer", help="Regenerate viewer")
    p_viewer.add_argument("--open", action="store_true", help="Open in browser")
    p_viewer.set_defaults(func=cmd_viewer)

    # benchmark_viewer
    p_benchmark = subparsers.add_parser("benchmark-viewer", help="Generate benchmark viewer")
    p_benchmark.add_argument("benchmark_dir", help="Path to benchmark results directory")
    p_benchmark.add_argument("--open", action="store_true", help="Open viewer in browser")
    p_benchmark.set_defaults(func=cmd_benchmark_viewer)

    # compare
    p_compare = subparsers.add_parser("compare", help="Run human vs AI comparison")
    p_compare.add_argument("--capture", required=True, help="Path to capture directory")
    p_compare.add_argument("--checkpoint", help="Path to checkpoint (optional)")
    p_compare.add_argument("--open", action="store_true", help="Open viewer in browser")
    p_compare.set_defaults(func=cmd_compare)

    # parse_args(None) falls back to sys.argv[1:].
    args = parser.parse_args(argv)

    # Sub-commands are optional at the parser level; show help when omitted.
    if not args.command:
        parser.print_help()
        return 0

    return args.func(args)
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
if __name__ == "__main__":
    # Use the value returned by main() as the process exit status.
    raise SystemExit(main())
|