openadapt-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1490 @@
1
+ """Compare human actions vs model predictions on a capture.
2
+
3
+ Generates an enhanced viewer showing both human and predicted actions side-by-side.
4
+
5
+ Usage:
6
+ uv run python -m openadapt_ml.scripts.compare \
7
+ --capture /path/to/capture \
8
+ --checkpoint checkpoints/qwen3vl2b_capture_lora \
9
+ --output comparison.html
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ from pathlib import Path
17
+ from typing import Any
18
+
19
+ from openadapt_ml.ingest.capture import capture_to_episode
20
+ from openadapt_ml.schemas.sessions import Episode, Step
21
+ from openadapt_ml.datasets.next_action import SYSTEM_PROMPT, format_action
22
+ from openadapt_ml.training.trainer import _get_shared_header_css, _generate_shared_header_html
23
+
24
+
25
+ def load_model(checkpoint_path: str | None, config_path: str | None = None):
26
+ """Load a trained model for inference.
27
+
28
+ Returns None if no checkpoint specified (will skip predictions).
29
+ """
30
+ if not checkpoint_path:
31
+ return None
32
+
33
+ checkpoint = Path(checkpoint_path)
34
+ if not checkpoint.exists():
35
+ print(f"Warning: Checkpoint not found at {checkpoint}")
36
+ return None
37
+
38
+ try:
39
+ from openadapt_ml.models.qwen_vl import QwenVLAdapter
40
+
41
+ # Load base model with LoRA weights
42
+ model_name = "Qwen/Qwen3-VL-2B-Instruct"
43
+ adapter = QwenVLAdapter.from_pretrained(
44
+ model_name=model_name,
45
+ lora_config={"weights_path": str(checkpoint)},
46
+ load_in_4bit=True, # Use 4-bit for inference too
47
+ )
48
+ print(f"Loaded model from {checkpoint}")
49
+ return adapter
50
+ except Exception as e:
51
+ print(f"Warning: Could not load model: {e}")
52
+ import traceback
53
+ traceback.print_exc()
54
+ return None
55
+
56
+
57
+ def predict_action(
58
+ model,
59
+ observation_image: str,
60
+ goal: str,
61
+ step_index: int = 0,
62
+ total_steps: int = 1,
63
+ action_history: list[str] | None = None,
64
+ ) -> dict[str, Any] | None:
65
+ """Run inference to predict an action given an observation.
66
+
67
+ Returns predicted action dict or None if no model.
68
+ """
69
+ if model is None:
70
+ return None
71
+
72
+ try:
73
+ import re
74
+
75
+ # Build history section matching training format
76
+ if action_history:
77
+ history_text = "ACTIONS COMPLETED SO FAR:\n"
78
+ for i, action_text in enumerate(action_history, 1):
79
+ history_text += f" {i}. {action_text}\n"
80
+ history_text += f"\nThis is step {step_index + 1} of {total_steps}. "
81
+ else:
82
+ history_text = f"This is step 1 of {total_steps} (no actions completed yet). "
83
+
84
+ # Match training prompt format exactly
85
+ user_content = (
86
+ f"Goal: {goal}\n\n"
87
+ f"{history_text}"
88
+ "Look at the screenshot and determine the NEXT action.\n\n"
89
+ "Thought: [what element to interact with and why]\n"
90
+ "Action: [CLICK(x=..., y=...) or TYPE(text=\"...\") or WAIT() or DONE()]"
91
+ )
92
+
93
+ # Build sample in the format expected by the adapter
94
+ sample = {
95
+ "images": [observation_image],
96
+ "messages": [
97
+ {"role": "system", "content": SYSTEM_PROMPT},
98
+ {"role": "user", "content": user_content},
99
+ ],
100
+ }
101
+
102
+ # Run inference using generate method
103
+ result = model.generate(sample, max_new_tokens=128)
104
+
105
+ # Parse result - look for CLICK(x=..., y=...) or similar patterns
106
+ action = {"type": "predicted", "raw_output": result}
107
+
108
+ # Try to extract coordinates from output
109
+ # Match patterns like: CLICK(x=0.42, y=0.31) or click at (0.42, 0.31)
110
+ click_match = re.search(r'CLICK\s*\(\s*x\s*=\s*([\d.]+)\s*,\s*y\s*=\s*([\d.]+)\s*\)', result, re.IGNORECASE)
111
+ if not click_match:
112
+ click_match = re.search(r'click.*?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)', result, re.IGNORECASE)
113
+ if not click_match:
114
+ # Try to find any two decimal numbers
115
+ nums = re.findall(r'(0\.\d+)', result)
116
+ if len(nums) >= 2:
117
+ click_match = type('Match', (), {'group': lambda s, i: nums[i-1]})()
118
+
119
+ if click_match:
120
+ action["x"] = float(click_match.group(1))
121
+ action["y"] = float(click_match.group(2))
122
+ action["type"] = "click"
123
+
124
+ return action
125
+ except Exception as e:
126
+ import traceback
127
+ traceback.print_exc()
128
+ return {"type": "error", "error": str(e)}
129
+
130
+
131
def generate_comparison_data(
    episode: Episode,
    model=None,
) -> list[dict[str, Any]]:
    """Generate comparison data for each step.

    Returns list of dicts with human action, predicted action, and metadata.
    When no model is given (or a step has no screenshot), the predicted
    action and match flag stay None for that step.
    """
    rows: list[dict[str, Any]] = []
    history: list[str] = []
    n_steps = len(episode.steps)

    for idx, step in enumerate(episode.steps):
        row: dict[str, Any] = {
            "index": idx,
            "time": step.t,
            "image_path": step.observation.image_path,
            "human_action": {
                "type": step.action.type,
                "x": step.action.x,
                "y": step.action.y,
                "text": step.action.text,
            },
            "predicted_action": None,
            "match": None,
        }

        # Only query the model when there is a screenshot to look at.
        if model and step.observation.image_path:
            prediction = predict_action(
                model,
                step.observation.image_path,
                episode.goal,
                step_index=idx,
                total_steps=n_steps,
                action_history=list(history),
            )
            row["predicted_action"] = prediction
            # Exact action-type agreement counts as a match.
            row["match"] = bool(prediction and prediction.get("type") == step.action.type)

        # Record this step's ground-truth action for the next prompt's history.
        history.append(format_action(step.action, use_som=False))
        rows.append(row)

    return rows
181
+
182
+
183
def generate_comparison_html(
    capture_path: Path,
    episode: Episode,
    comparison_data: list[dict],
    output_path: Path,
) -> None:
    """Generate an HTML viewer with comparison data.

    Builds the base viewer via openadapt-capture's ``create_html``, then
    injects: shared header CSS/HTML, a comparison panel (human vs predicted
    action per step), CSS for click markers/overlays, and a script that
    renders per-step comparisons and aggregate metrics.

    Args:
        capture_path: Directory of the openadapt-capture recording.
        episode: The episode (currently unused here; data comes pre-computed
            in ``comparison_data``).
        comparison_data: Per-step dicts from ``generate_comparison_data``.
        output_path: Where the final HTML file is written.
    """

    # Use openadapt-capture's viewer as base, then enhance
    try:
        from openadapt_capture.visualize.html import create_html

        # Generate base viewer
        # NOTE(review): assumes create_html(capture_path, output=None) returns
        # the HTML as a string rather than writing a file -- confirm against
        # openadapt-capture's API.
        base_html = create_html(capture_path, output=None)

        # Inject comparison data and UI
        comparison_json = json.dumps(comparison_data)

        # Add comparison panel above screenshot in main content
        comparison_panel = '''
    <div class="comparison-panel" id="comparison-panel">
        <div class="comparison-header">
            <h2>Action Comparison</h2>
            <div class="metrics-summary"></div>
            <div class="overlay-toggles"></div>
        </div>
        <div class="comparison-content">
            <div class="action-box human">
                <div class="action-label">Human Action</div>
                <div class="action-details" id="human-action"></div>
            </div>
            <div class="action-box predicted">
                <div class="action-label">Model Prediction</div>
                <div class="action-details" id="predicted-action"></div>
            </div>
            <div class="match-indicator" id="match-indicator"></div>
        </div>
    </div>
    '''

        # Styles for the nav bar, comparison panel, click markers and toggles.
        # These rely on CSS variables (--bg-secondary etc.) defined by the
        # base viewer's stylesheet.
        comparison_styles = '''
    <style>
    /* Navigation bar */
    .nav-bar {
        display: flex;
        gap: 8px;
        padding: 12px 16px;
        background: var(--bg-secondary);
        border: 1px solid var(--border-color);
        border-radius: 8px;
        margin-bottom: 16px;
        flex-wrap: wrap;
    }
    .nav-link {
        padding: 8px 16px;
        border-radius: 6px;
        font-size: 0.8rem;
        text-decoration: none;
        color: var(--text-secondary);
        background: var(--bg-tertiary);
        border: 1px solid var(--border-color);
        transition: all 0.2s;
    }
    .nav-link:hover {
        border-color: var(--accent);
        color: var(--text-primary);
    }
    .nav-link.active {
        background: var(--accent);
        color: var(--bg-primary);
        border-color: var(--accent);
        font-weight: 600;
    }
    .nav-label {
        font-size: 0.75rem;
        color: var(--text-secondary);
        margin-right: 8px;
        align-self: center;
    }
    .comparison-panel {
        background: var(--bg-secondary);
        border: 1px solid var(--border-color);
        border-radius: 12px;
        margin-bottom: 16px;
        width: 100%;
    }
    .comparison-header {
        display: flex;
        align-items: center;
        gap: 16px;
        padding: 12px 18px;
        border-bottom: 1px solid var(--border-color);
        flex-wrap: wrap;
    }
    .comparison-panel h2 {
        font-size: 0.9rem;
        font-weight: 600;
        margin: 0;
        white-space: nowrap;
    }
    .comparison-content {
        padding: 14px 18px;
        display: grid;
        grid-template-columns: 1fr 1fr auto;
        gap: 16px;
        align-items: start;
    }
    .action-box {
        padding: 12px;
        border-radius: 8px;
        margin-bottom: 0;
    }
    .action-box.human {
        background: rgba(0, 212, 170, 0.1);
        border: 1px solid rgba(0, 212, 170, 0.3);
    }
    .action-box.predicted {
        background: rgba(167, 139, 250, 0.1);
        border: 1px solid rgba(167, 139, 250, 0.3);
    }
    .action-label {
        font-size: 0.75rem;
        text-transform: uppercase;
        letter-spacing: 0.05em;
        color: var(--text-muted);
        margin-bottom: 6px;
    }
    .action-details {
        font-family: "SF Mono", Monaco, monospace;
        font-size: 0.85rem;
    }
    .match-indicator {
        text-align: center;
        padding: 8px;
        border-radius: 6px;
        font-weight: 600;
    }
    .match-indicator.match {
        background: rgba(52, 211, 153, 0.2);
        color: #34d399;
    }
    .match-indicator.mismatch {
        background: rgba(255, 95, 95, 0.2);
        color: #ff5f5f;
    }
    .match-indicator.pending {
        background: var(--bg-tertiary);
        color: var(--text-muted);
    }
    /* Visual overlays for clicks on screenshot */
    .click-overlay {
        position: absolute;
        pointer-events: none;
        z-index: 100;
    }
    .click-marker {
        position: absolute;
        width: 30px;
        height: 30px;
        border-radius: 50%;
        transform: translate(-50%, -50%);
        display: flex;
        align-items: center;
        justify-content: center;
        font-size: 12px;
        font-weight: bold;
        animation: pulse-marker 1.5s ease-in-out infinite;
    }
    .click-marker.human {
        background: rgba(0, 212, 170, 0.3);
        border: 3px solid #00d4aa;
        color: #00d4aa;
    }
    .click-marker.predicted {
        background: rgba(167, 139, 250, 0.3);
        border: 3px solid #a78bfa;
        color: #a78bfa;
    }
    .click-marker.human::after {
        content: 'H';
    }
    .click-marker.predicted::after {
        content: 'AI';
        font-size: 10px;
    }
    @keyframes pulse-marker {
        0%, 100% { transform: translate(-50%, -50%) scale(1); opacity: 1; }
        50% { transform: translate(-50%, -50%) scale(1.1); opacity: 0.8; }
    }
    /* Distance line between human and predicted */
    .distance-line {
        position: absolute;
        height: 2px;
        background: linear-gradient(90deg, #00d4aa, #a78bfa);
        transform-origin: left center;
        pointer-events: none;
        z-index: 99;
    }
    /* Metrics summary - inline in header */
    .metrics-summary {
        display: flex;
        gap: 16px;
        padding: 6px 12px;
        background: var(--bg-tertiary);
        border-radius: 6px;
    }
    .metric-item {
        display: flex;
        align-items: center;
        gap: 6px;
    }
    .metric-value {
        font-size: 0.9rem;
        font-weight: 600;
        color: var(--accent);
    }
    .metric-label {
        font-size: 0.7rem;
        color: var(--text-muted);
        text-transform: uppercase;
    }
    /* Toggle buttons - inline in header */
    .overlay-toggles {
        display: flex;
        gap: 6px;
        margin-left: auto;
    }
    .toggle-btn {
        padding: 6px 12px;
        border: 1px solid var(--border-color);
        background: var(--bg-tertiary);
        color: var(--text-primary);
        border-radius: 6px;
        cursor: pointer;
        font-size: 0.75rem;
        transition: all 0.2s;
        white-space: nowrap;
    }
    .toggle-btn.active {
        background: var(--accent);
        color: var(--bg-primary);
        border-color: var(--accent);
    }
    .toggle-btn:hover {
        border-color: var(--accent);
    }
    </style>
    '''

        # Viewer script. This is an f-string, so literal JS braces are doubled
        # ({{ }}); the only interpolation is {comparison_json} below.
        comparison_script = f'''
    <script>
    // Consolidated viewer script - all variables and functions in one scope
    // Export to window for cross-script access (for checkpoint dropdown script)
    window.comparisonData = {comparison_json};
    const comparisonData = window.comparisonData; // Local alias
    window.currentIndex = 0; // Explicit currentIndex declaration
    let currentIndex = window.currentIndex; // Local alias
    let showHumanOverlay = true;
    let showPredictedOverlay = true;

    // Compute aggregate metrics
    window.computeMetrics = function() {{
        let matches = 0;
        let total = 0;
        let totalDistance = 0;
        let distanceCount = 0;

        comparisonData.forEach(d => {{
            if (d.match !== null) {{
                total++;
                if (d.match) matches++;
            }}
            // Compute spatial distance if both have coordinates
            if (d.human_action.x !== null && d.predicted_action && d.predicted_action.x !== undefined) {{
                const dx = d.human_action.x - d.predicted_action.x;
                const dy = d.human_action.y - d.predicted_action.y;
                totalDistance += Math.sqrt(dx*dx + dy*dy);
                distanceCount++;
            }}
        }});

        return {{
            accuracy: total > 0 ? (matches / total * 100).toFixed(1) : 'N/A',
            avgDistance: distanceCount > 0 ? (totalDistance / distanceCount * 100).toFixed(1) : 'N/A',
            total: comparisonData.length
        }};
    }};
    const computeMetrics = window.computeMetrics; // Local alias

    window.updateClickOverlays = function(index) {{
        // Remove existing overlays
        document.querySelectorAll('.click-marker, .distance-line').forEach(el => el.remove());

        const data = comparisonData[index];
        if (!data) return;

        const imgContainer = document.querySelector('.display-container');
        if (!imgContainer) return;

        // Make container relative for absolute positioning
        imgContainer.style.position = 'relative';

        // Human click marker
        if (showHumanOverlay && data.human_action.x !== null) {{
            const humanMarker = document.createElement('div');
            humanMarker.className = 'click-marker human';
            humanMarker.style.left = (data.human_action.x * 100) + '%';
            humanMarker.style.top = (data.human_action.y * 100) + '%';
            imgContainer.appendChild(humanMarker);
        }}

        // Predicted click marker
        if (showPredictedOverlay && data.predicted_action && data.predicted_action.x !== undefined) {{
            const predMarker = document.createElement('div');
            predMarker.className = 'click-marker predicted';
            predMarker.style.left = (data.predicted_action.x * 100) + '%';
            predMarker.style.top = (data.predicted_action.y * 100) + '%';
            imgContainer.appendChild(predMarker);

            // Draw line between human and predicted if both visible
            if (showHumanOverlay && data.human_action.x !== null) {{
                const line = document.createElement('div');
                line.className = 'distance-line';
                const x1 = data.human_action.x * imgContainer.offsetWidth;
                const y1 = data.human_action.y * imgContainer.offsetHeight;
                const x2 = data.predicted_action.x * imgContainer.offsetWidth;
                const y2 = data.predicted_action.y * imgContainer.offsetHeight;
                const length = Math.sqrt((x2-x1)**2 + (y2-y1)**2);
                const angle = Math.atan2(y2-y1, x2-x1) * 180 / Math.PI;
                line.style.left = x1 + 'px';
                line.style.top = y1 + 'px';
                line.style.width = length + 'px';
                line.style.transform = `rotate(${{angle}}deg)`;
                imgContainer.appendChild(line);
            }}
        }}
    }};
    const updateClickOverlays = window.updateClickOverlays; // Local alias

    window.updateComparison = function(index) {{
        const data = comparisonData[index];
        if (!data) return;

        const humanEl = document.getElementById('human-action');
        const predictedEl = document.getElementById('predicted-action');
        const matchEl = document.getElementById('match-indicator');

        // Human action
        humanEl.innerHTML = `
            <div>Type: ${{data.human_action.type}}</div>
            ${{data.human_action.x !== null ? `<div>Position: (${{(data.human_action.x * 100).toFixed(1)}}%, ${{(data.human_action.y * 100).toFixed(1)}}%)</div>` : ''}}
            ${{data.human_action.text ? `<div>Text: ${{data.human_action.text}}</div>` : ''}}
        `;

        // Predicted action
        if (data.predicted_action) {{
            const pred = data.predicted_action;
            if (pred.x !== undefined) {{
                predictedEl.innerHTML = `
                    <div>Type: ${{pred.type || 'click'}}</div>
                    <div>Position: (${{(pred.x * 100).toFixed(1)}}%, ${{(pred.y * 100).toFixed(1)}}%)</div>
                `;
            }} else {{
                predictedEl.innerHTML = `<div>${{pred.raw_output || JSON.stringify(pred)}}</div>`;
            }}
        }} else {{
            predictedEl.innerHTML = '<em style="color: var(--text-muted);">No model loaded</em>';
        }}

        // Match indicator
        if (data.match === true) {{
            matchEl.className = 'match-indicator match';
            matchEl.textContent = '✓ Match';
        }} else if (data.match === false) {{
            matchEl.className = 'match-indicator mismatch';
            matchEl.textContent = '✗ Mismatch';
        }} else {{
            matchEl.className = 'match-indicator pending';
            matchEl.textContent = '— No prediction';
        }}

        // Update visual overlays
        updateClickOverlays(index);

        // Sync currentIndex to window
        window.currentIndex = index;
    }};
    const updateComparison = window.updateComparison; // Local alias

    window.setupOverlayToggles = function() {{
        const togglesContainer = document.querySelector('.overlay-toggles');
        if (!togglesContainer) return;

        togglesContainer.innerHTML = `
            <button class="toggle-btn active" id="toggle-human">Human (H)</button>
            <button class="toggle-btn active" id="toggle-predicted">AI (P)</button>
        `;

        document.getElementById('toggle-human').addEventListener('click', function() {{
            showHumanOverlay = !showHumanOverlay;
            this.classList.toggle('active', showHumanOverlay);
            updateClickOverlays(currentIndex);
        }});

        document.getElementById('toggle-predicted').addEventListener('click', function() {{
            showPredictedOverlay = !showPredictedOverlay;
            this.classList.toggle('active', showPredictedOverlay);
            updateClickOverlays(currentIndex);
        }});

        // Keyboard shortcuts
        document.addEventListener('keydown', (e) => {{
            if (e.key === 'h' || e.key === 'H') {{
                document.getElementById('toggle-human').click();
            }} else if (e.key === 'p' || e.key === 'P') {{
                document.getElementById('toggle-predicted').click();
            }}
        }});
    }};
    const setupOverlayToggles = window.setupOverlayToggles; // Local alias

    window.setupMetricsSummary = function() {{
        const metricsEl = document.querySelector('.metrics-summary');
        if (!metricsEl) return;

        const metrics = computeMetrics();
        metricsEl.innerHTML = `
            <div class="metric-item">
                <span class="metric-label">Accuracy:</span>
                <span class="metric-value">${{metrics.accuracy}}%</span>
            </div>
            <div class="metric-item">
                <span class="metric-label">Avg Dist:</span>
                <span class="metric-value">${{metrics.avgDistance}}%</span>
            </div>
            <div class="metric-item">
                <span class="metric-label">Steps:</span>
                <span class="metric-value">${{metrics.total}}</span>
            </div>
        `;
    }};
    const setupMetricsSummary = window.setupMetricsSummary; // Local alias

    // Hook into existing updateDisplay
    const originalUpdateDisplay = typeof updateDisplay !== 'undefined' ? updateDisplay : function() {{}};
    window.updateDisplay = updateDisplay = function(skipAudioSync) {{
        originalUpdateDisplay(skipAudioSync);
        // Sync currentIndex from base viewer if it exists
        if (typeof currentIndex !== 'undefined') {{
            window.currentIndex = currentIndex;
        }}
        updateComparison(window.currentIndex);
    }};

    // Discover other dashboards in the same directory
    // NOTE(review): this function is defined but no longer called (see the
    // setTimeout below) -- kept for reference / possible re-enablement.
    async function discoverDashboards() {{
        const currentFile = window.location.pathname.split('/').pop() || 'comparison.html';

        // Create nav bar at top of container
        const container = document.querySelector('.container') || document.body.firstElementChild;
        if (!container) return;

        const navBar = document.createElement('nav');
        navBar.className = 'nav-bar';
        navBar.innerHTML = '';
        container.insertBefore(navBar, container.firstChild);

        // Known dashboard patterns to look for
        const patterns = [
            'dashboard.html',
            'comparison.html',
            'comparison_preview.html',
            'comparison_epoch0.html', 'comparison_epoch1.html', 'comparison_epoch2.html',
            'comparison_epoch3.html', 'comparison_epoch4.html', 'comparison_epoch5.html',
            'viewer.html'
        ];

        // For file:// protocol, only show essential links (fetch doesn't work)
        const isFileProtocol = window.location.protocol === 'file:';

        // Minimal links for file:// protocol - just the main ones
        const fileProtocolLinks = ['dashboard.html', currentFile];

        for (const file of patterns) {{
            try {{
                let exists = false;
                if (isFileProtocol) {{
                    // For file://, only show essential links
                    exists = fileProtocolLinks.includes(file);
                }} else {{
                    const response = await fetch(file, {{ method: 'HEAD' }});
                    exists = response.ok;
                }}

                if (exists) {{
                    const link = document.createElement('a');
                    link.href = file;
                    link.className = 'nav-link' + (file === currentFile ? ' active' : '');
                    // Pretty name - make comparison labels clear
                    if (file === 'dashboard.html') {{
                        link.textContent = 'Training';
                    }} else if (file.startsWith('comparison_epoch')) {{
                        const epoch = file.match(/epoch(\\d+)/)?.[1];
                        link.textContent = `Comparison (E${{epoch}})`;
                    }} else if (file === 'comparison.html') {{
                        link.textContent = 'Comparison';
                    }} else if (file === 'comparison_preview.html') {{
                        link.textContent = 'Preview';
                    }} else if (file === 'viewer.html') {{
                        link.textContent = 'Viewer';
                    }} else {{
                        link.textContent = file.replace('.html', '');
                    }}
                    navBar.appendChild(link);
                }}
            }} catch (e) {{
                // File doesn't exist, skip
            }}
        }}
    }}

    // Initial setup
    setTimeout(() => {{
        setupOverlayToggles();
        setupMetricsSummary();
        updateComparison(window.currentIndex);
        // Note: Nav is now injected via shared header HTML, no need for discoverDashboards()
    }}, 100);
    </script>
    '''

        # Insert into HTML
        # Add shared header CSS and comparison styles before </head>
        shared_header_css = f'<style>{_get_shared_header_css()}</style>'
        html = base_html.replace('</head>', shared_header_css + comparison_styles + '</head>')

        # Add shared header HTML after container div
        shared_header_html = _generate_shared_header_html("viewer")
        html = html.replace(
            '<div class="container">',
            '<div class="container">\n' + shared_header_html
        )

        # Add comparison panel as full-width row BEFORE the main-content/sidebar flex row
        # Insert right BEFORE <div class="main-content"> as a sibling
        html = html.replace(
            '<div class="main-content">',
            comparison_panel + '\n <div class="main-content">'
        )

        # Add script before </body>
        html = html.replace('</body>', comparison_script + '</body>')

        # Write output
        output_path.write_text(html, encoding='utf-8')
        print(f"Generated comparison viewer: {output_path}")

    except ImportError:
        # openadapt-capture is an optional dependency for visualization only.
        print("Error: openadapt-capture is required for visualization")
        print("Install with: pip install openadapt-capture")
743
+
744
+
745
def main():
    """CLI entry point: load a capture, optionally run a model, emit HTML.

    Returns 0 on success, 1 when the capture directory does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Compare human actions vs model predictions on a capture."
    )
    parser.add_argument(
        "--capture", "-c",
        required=True,
        help="Path to openadapt-capture recording directory",
    )
    parser.add_argument(
        "--checkpoint", "-m",
        help="Path to trained model checkpoint (optional)",
    )
    parser.add_argument(
        "--output", "-o",
        help="Output HTML path (default: capture_dir/comparison.html)",
    )
    parser.add_argument(
        "--goal", "-g",
        help="Task goal/description (auto-detected from capture if not provided)",
    )
    parser.add_argument(
        "--open",
        action="store_true",
        help="Open viewer in browser after generation",
    )
    args = parser.parse_args()

    capture_dir = Path(args.capture)
    if not capture_dir.exists():
        print(f"Error: Capture not found at {capture_dir}")
        return 1

    # Convert the raw capture into an Episode of (observation, action) steps.
    print(f"Loading capture from: {capture_dir}")
    episode = capture_to_episode(capture_dir, goal=args.goal)
    print(f"Loaded {len(episode.steps)} steps")

    # Model is optional; without it the viewer shows human actions only.
    model = load_model(args.checkpoint)

    print("Generating comparison data...")
    rows = generate_comparison_data(episode, model)

    # Report the exact action-type match rate over steps that got predictions.
    if model:
        verdicts = [row.get("match") for row in rows]
        hits = verdicts.count(True)
        judged = sum(1 for v in verdicts if v is not None)
        if judged > 0:
            print(f"Match rate: {hits}/{judged} ({100*hits/judged:.1f}%)")

    html_path = Path(args.output) if args.output else capture_dir / "comparison.html"
    generate_comparison_html(capture_dir, episode, rows, html_path)

    if args.open:
        import webbrowser
        webbrowser.open(f"file://{html_path.absolute()}")

    return 0
807
+
808
+
809
+ def generate_unified_viewer(
810
+ capture_path: Path,
811
+ episode: Episode,
812
+ predictions_by_checkpoint: dict[str, list[dict]],
813
+ output_path: Path,
814
+ capture_id: str | None = None,
815
+ available_captures: list[dict] | None = None,
816
+ ) -> None:
817
+ """Generate a unified viewer with dropdowns for capture and checkpoint selection.
818
+
819
+ Args:
820
+ capture_path: Path to the capture directory
821
+ episode: The episode data
822
+ predictions_by_checkpoint: Dict mapping checkpoint names to prediction lists
823
+ e.g. {"Epoch 1": [...], "Epoch 3": [...], "None": [...]}
824
+ output_path: Where to write the HTML
825
+ capture_id: ID of the current capture (for display)
826
+ available_captures: List of available captures for the dropdown
827
+ e.g. [{"id": "31807990", "name": "Turn off nightshift", "steps": 21}]
828
+ """
829
+ try:
830
+ from openadapt_capture.visualize.html import create_html
831
+
832
+ # Generate base viewer
833
+ base_html = create_html(capture_path, output=None)
834
+
835
+ # Prepare capture info
836
+ if capture_id is None:
837
+ capture_id = capture_path.name if capture_path else "unknown"
838
+
839
+ if available_captures is None:
840
+ available_captures = [{
841
+ "id": capture_id,
842
+ "name": episode.goal or "Untitled",
843
+ "steps": len(episode.steps),
844
+ }]
845
+
846
+ # Prepare base capture data (human actions only, no predictions)
847
+ base_data = []
848
+ for i, step in enumerate(episode.steps):
849
+ base_data.append({
850
+ "index": i,
851
+ "time": step.t,
852
+ "image_path": step.observation.image_path,
853
+ "human_action": {
854
+ "type": step.action.type,
855
+ "x": step.action.x,
856
+ "y": step.action.y,
857
+ "text": step.action.text,
858
+ },
859
+ })
860
+
861
+ # JSON encode all data
862
+ base_data_json = json.dumps(base_data)
863
+ predictions_json = json.dumps(predictions_by_checkpoint)
864
+ captures_json = json.dumps(available_captures)
865
+ current_capture_json = json.dumps(capture_id)
866
+
867
+ # Unified viewer styles and controls
868
+ unified_styles = '''
869
+ <style>
870
+ /* Navigation bar */
871
+ .nav-bar {
872
+ display: flex;
873
+ gap: 8px;
874
+ padding: 12px 16px;
875
+ background: var(--bg-secondary);
876
+ border: 1px solid var(--border-color);
877
+ border-radius: 8px;
878
+ margin-bottom: 16px;
879
+ flex-wrap: wrap;
880
+ align-items: center;
881
+ }
882
+ .nav-link {
883
+ padding: 8px 16px;
884
+ border-radius: 6px;
885
+ font-size: 0.8rem;
886
+ text-decoration: none;
887
+ color: var(--text-secondary);
888
+ background: var(--bg-tertiary);
889
+ border: 1px solid var(--border-color);
890
+ transition: all 0.2s;
891
+ }
892
+ .nav-link:hover {
893
+ border-color: var(--accent);
894
+ color: var(--text-primary);
895
+ }
896
+ .nav-link.active {
897
+ background: var(--accent);
898
+ color: var(--bg-primary);
899
+ border-color: var(--accent);
900
+ font-weight: 600;
901
+ }
902
+ .nav-label {
903
+ font-size: 0.75rem;
904
+ color: var(--text-secondary);
905
+ margin-right: 8px;
906
+ align-self: center;
907
+ }
908
+
909
+ /* Dropdown selectors */
910
+ .viewer-controls {
911
+ display: flex;
912
+ gap: 16px;
913
+ padding: 12px 16px;
914
+ background: var(--bg-secondary);
915
+ border: 1px solid var(--border-color);
916
+ border-radius: 8px;
917
+ margin-bottom: 16px;
918
+ flex-wrap: wrap;
919
+ align-items: center;
920
+ }
921
+ .control-group {
922
+ display: flex;
923
+ align-items: center;
924
+ gap: 8px;
925
+ }
926
+ .control-label {
927
+ font-size: 0.75rem;
928
+ color: var(--text-secondary);
929
+ text-transform: uppercase;
930
+ letter-spacing: 0.05em;
931
+ }
932
+ .control-select {
933
+ padding: 8px 12px;
934
+ border-radius: 6px;
935
+ font-size: 0.85rem;
936
+ background: var(--bg-tertiary);
937
+ color: var(--text-primary);
938
+ border: 1px solid var(--border-color);
939
+ cursor: pointer;
940
+ min-width: 180px;
941
+ }
942
+ .control-select:hover {
943
+ border-color: var(--accent);
944
+ }
945
+ .control-select:focus {
946
+ outline: none;
947
+ border-color: var(--accent);
948
+ box-shadow: 0 0 0 2px rgba(0, 212, 170, 0.2);
949
+ }
950
+ .control-hint {
951
+ font-size: 0.7rem;
952
+ color: var(--text-muted);
953
+ }
954
+
955
+ /* Comparison panel */
956
+ .comparison-panel {
957
+ background: var(--bg-secondary);
958
+ border: 1px solid var(--border-color);
959
+ border-radius: 12px;
960
+ margin-bottom: 16px;
961
+ width: 100%;
962
+ }
963
+ .comparison-header {
964
+ display: flex;
965
+ align-items: center;
966
+ gap: 16px;
967
+ padding: 12px 18px;
968
+ border-bottom: 1px solid var(--border-color);
969
+ flex-wrap: wrap;
970
+ }
971
+ .comparison-panel h2 {
972
+ font-size: 0.9rem;
973
+ font-weight: 600;
974
+ margin: 0;
975
+ white-space: nowrap;
976
+ }
977
+ .comparison-content {
978
+ padding: 14px 18px;
979
+ display: grid;
980
+ grid-template-columns: 1fr 1fr auto;
981
+ gap: 16px;
982
+ align-items: start;
983
+ }
984
+ .action-box {
985
+ padding: 12px;
986
+ border-radius: 8px;
987
+ margin-bottom: 0;
988
+ }
989
+ .action-box.human {
990
+ background: rgba(0, 212, 170, 0.1);
991
+ border: 1px solid rgba(0, 212, 170, 0.3);
992
+ }
993
+ .action-box.predicted {
994
+ background: rgba(167, 139, 250, 0.1);
995
+ border: 1px solid rgba(167, 139, 250, 0.3);
996
+ }
997
+ .action-box.predicted.disabled {
998
+ opacity: 0.5;
999
+ }
1000
+ .action-label {
1001
+ font-size: 0.75rem;
1002
+ text-transform: uppercase;
1003
+ letter-spacing: 0.05em;
1004
+ color: var(--text-muted);
1005
+ margin-bottom: 6px;
1006
+ }
1007
+ .action-details {
1008
+ font-family: "SF Mono", Monaco, monospace;
1009
+ font-size: 0.85rem;
1010
+ }
1011
+ .match-indicator {
1012
+ text-align: center;
1013
+ padding: 8px;
1014
+ border-radius: 6px;
1015
+ font-weight: 600;
1016
+ min-width: 80px;
1017
+ }
1018
+ .match-indicator.match {
1019
+ background: rgba(52, 211, 153, 0.2);
1020
+ color: #34d399;
1021
+ }
1022
+ .match-indicator.mismatch {
1023
+ background: rgba(255, 95, 95, 0.2);
1024
+ color: #ff5f5f;
1025
+ }
1026
+ .match-indicator.pending {
1027
+ background: var(--bg-tertiary);
1028
+ color: var(--text-muted);
1029
+ }
1030
+
1031
+ /* Visual overlays */
1032
+ .click-marker {
1033
+ position: absolute;
1034
+ width: 30px;
1035
+ height: 30px;
1036
+ border-radius: 50%;
1037
+ transform: translate(-50%, -50%);
1038
+ display: flex;
1039
+ align-items: center;
1040
+ justify-content: center;
1041
+ font-size: 12px;
1042
+ font-weight: bold;
1043
+ pointer-events: none;
1044
+ z-index: 100;
1045
+ animation: pulse-marker 1.5s ease-in-out infinite;
1046
+ }
1047
+ .click-marker.human {
1048
+ background: rgba(0, 212, 170, 0.3);
1049
+ border: 3px solid #00d4aa;
1050
+ color: #00d4aa;
1051
+ }
1052
+ .click-marker.predicted {
1053
+ background: rgba(167, 139, 250, 0.3);
1054
+ border: 3px solid #a78bfa;
1055
+ color: #a78bfa;
1056
+ }
1057
+ .click-marker.human::after { content: 'H'; }
1058
+ .click-marker.predicted::after { content: 'AI'; font-size: 10px; }
1059
+ @keyframes pulse-marker {
1060
+ 0%, 100% { transform: translate(-50%, -50%) scale(1); opacity: 1; }
1061
+ 50% { transform: translate(-50%, -50%) scale(1.1); opacity: 0.8; }
1062
+ }
1063
+ .distance-line {
1064
+ position: absolute;
1065
+ height: 2px;
1066
+ background: linear-gradient(90deg, #00d4aa, #a78bfa);
1067
+ transform-origin: left center;
1068
+ pointer-events: none;
1069
+ z-index: 99;
1070
+ }
1071
+
1072
+ /* Metrics summary */
1073
+ .metrics-summary {
1074
+ display: flex;
1075
+ gap: 16px;
1076
+ padding: 6px 12px;
1077
+ background: var(--bg-tertiary);
1078
+ border-radius: 6px;
1079
+ }
1080
+ .metric-item {
1081
+ display: flex;
1082
+ align-items: center;
1083
+ gap: 6px;
1084
+ }
1085
+ .metric-value {
1086
+ font-size: 0.9rem;
1087
+ font-weight: 600;
1088
+ color: var(--accent);
1089
+ }
1090
+ .metric-label {
1091
+ font-size: 0.7rem;
1092
+ color: var(--text-muted);
1093
+ text-transform: uppercase;
1094
+ }
1095
+
1096
+ /* Toggle buttons */
1097
+ .overlay-toggles {
1098
+ display: flex;
1099
+ gap: 6px;
1100
+ margin-left: auto;
1101
+ }
1102
+ .toggle-btn {
1103
+ padding: 6px 12px;
1104
+ border: 1px solid var(--border-color);
1105
+ background: var(--bg-tertiary);
1106
+ color: var(--text-primary);
1107
+ border-radius: 6px;
1108
+ cursor: pointer;
1109
+ font-size: 0.75rem;
1110
+ transition: all 0.2s;
1111
+ white-space: nowrap;
1112
+ }
1113
+ .toggle-btn.active {
1114
+ background: var(--accent);
1115
+ color: var(--bg-primary);
1116
+ border-color: var(--accent);
1117
+ }
1118
+ .toggle-btn:hover {
1119
+ border-color: var(--accent);
1120
+ }
1121
+ </style>
1122
+ '''
1123
+
1124
+ # Comparison panel HTML
1125
+ comparison_panel = '''
1126
+ <div class="viewer-controls" id="viewer-controls">
1127
+ <div class="control-group">
1128
+ <span class="control-label">Training Example:</span>
1129
+ <select class="control-select" id="capture-select"></select>
1130
+ <span class="control-hint" id="capture-hint"></span>
1131
+ </div>
1132
+ <div class="control-group">
1133
+ <span class="control-label">Checkpoint:</span>
1134
+ <select class="control-select" id="checkpoint-select"></select>
1135
+ </div>
1136
+ </div>
1137
+ <div class="comparison-panel" id="comparison-panel">
1138
+ <div class="comparison-header">
1139
+ <h2>Action Comparison</h2>
1140
+ <div class="metrics-summary" id="metrics-summary"></div>
1141
+ <div class="overlay-toggles" id="overlay-toggles"></div>
1142
+ </div>
1143
+ <div class="comparison-content">
1144
+ <div class="action-box human">
1145
+ <div class="action-label">Human Action</div>
1146
+ <div class="action-details" id="human-action"></div>
1147
+ </div>
1148
+ <div class="action-box predicted" id="predicted-box">
1149
+ <div class="action-label">Model Prediction</div>
1150
+ <div class="action-details" id="predicted-action"></div>
1151
+ </div>
1152
+ <div class="match-indicator" id="match-indicator"></div>
1153
+ </div>
1154
+ </div>
1155
+ '''
1156
+
1157
+ # Unified viewer script
1158
+ unified_script = f'''
1159
+ <script>
1160
+ // Consolidated unified viewer script - all variables in one scope
1161
+ // Data
1162
+ const baseData = {base_data_json};
1163
+ const predictionsByCheckpoint = {predictions_json};
1164
+ const availableCaptures = {captures_json};
1165
+ const currentCaptureId = {current_capture_json};
1166
+
1167
+ // State
1168
+ let currentIndex = 0; // Explicit currentIndex declaration
1169
+ let currentCheckpoint = 'None';
1170
+ let showHumanOverlay = true;
1171
+ let showPredictedOverlay = true;
1172
+
1173
+ // Get merged data for current checkpoint
1174
+ function getMergedData() {{
1175
+ const predictions = predictionsByCheckpoint[currentCheckpoint] || [];
1176
+ return baseData.map((base, i) => {{
1177
+ const pred = predictions[i] || {{}};
1178
+ return {{
1179
+ ...base,
1180
+ predicted_action: pred.predicted_action || null,
1181
+ match: pred.match !== undefined ? pred.match : null,
1182
+ }};
1183
+ }});
1184
+ }}
1185
+
1186
+ // Initialize dropdowns
1187
+ function initDropdowns() {{
1188
+ const captureSelect = document.getElementById('capture-select');
1189
+ const checkpointSelect = document.getElementById('checkpoint-select');
1190
+ const captureHint = document.getElementById('capture-hint');
1191
+
1192
+ // Populate capture dropdown
1193
+ captureSelect.innerHTML = '';
1194
+ availableCaptures.forEach(cap => {{
1195
+ const opt = document.createElement('option');
1196
+ opt.value = cap.id;
1197
+ opt.textContent = `${{cap.name}} (${{cap.steps}} steps)`;
1198
+ opt.selected = cap.id === currentCaptureId;
1199
+ captureSelect.appendChild(opt);
1200
+ }});
1201
+
1202
+ // Show hint about available captures
1203
+ captureHint.textContent = `(${{availableCaptures.length}} available)`;
1204
+
1205
+ // Populate checkpoint dropdown
1206
+ checkpointSelect.innerHTML = '';
1207
+ const checkpointNames = Object.keys(predictionsByCheckpoint);
1208
+ // Sort: "None" first, then by epoch number
1209
+ checkpointNames.sort((a, b) => {{
1210
+ if (a === 'None') return -1;
1211
+ if (b === 'None') return 1;
1212
+ const aNum = parseInt(a.match(/\\d+/)?.[0] || '999');
1213
+ const bNum = parseInt(b.match(/\\d+/)?.[0] || '999');
1214
+ return aNum - bNum;
1215
+ }});
1216
+
1217
+ checkpointNames.forEach(name => {{
1218
+ const opt = document.createElement('option');
1219
+ opt.value = name;
1220
+ opt.textContent = name === 'None' ? 'None (Capture Only)' : name;
1221
+ checkpointSelect.appendChild(opt);
1222
+ }});
1223
+
1224
+ // Set default to latest non-None checkpoint if available
1225
+ const latestCheckpoint = checkpointNames.filter(n => n !== 'None').pop();
1226
+ if (latestCheckpoint) {{
1227
+ checkpointSelect.value = latestCheckpoint;
1228
+ currentCheckpoint = latestCheckpoint;
1229
+ }}
1230
+
1231
+ // Event handlers
1232
+ captureSelect.addEventListener('change', (e) => {{
1233
+ // In future: load different capture
1234
+ // For now, just show that we'd switch
1235
+ console.log('Would switch to capture:', e.target.value);
1236
+ }});
1237
+
1238
+ checkpointSelect.addEventListener('change', (e) => {{
1239
+ currentCheckpoint = e.target.value;
1240
+ updateMetrics();
1241
+ updateComparison(typeof currentIndex !== 'undefined' ? currentIndex : 0);
1242
+ }});
1243
+ }}
1244
+
1245
+ // Compute metrics for current checkpoint
1246
+ function computeMetrics() {{
1247
+ const data = getMergedData();
1248
+ let matches = 0;
1249
+ let total = 0;
1250
+ let totalDistance = 0;
1251
+ let distanceCount = 0;
1252
+
1253
+ data.forEach(d => {{
1254
+ if (d.match !== null) {{
1255
+ total++;
1256
+ if (d.match) matches++;
1257
+ }}
1258
+ if (d.human_action.x !== null && d.predicted_action && d.predicted_action.x !== undefined) {{
1259
+ const dx = d.human_action.x - d.predicted_action.x;
1260
+ const dy = d.human_action.y - d.predicted_action.y;
1261
+ totalDistance += Math.sqrt(dx*dx + dy*dy);
1262
+ distanceCount++;
1263
+ }}
1264
+ }});
1265
+
1266
+ return {{
1267
+ accuracy: total > 0 ? (matches / total * 100).toFixed(1) : 'N/A',
1268
+ avgDistance: distanceCount > 0 ? (totalDistance / distanceCount * 100).toFixed(1) : 'N/A',
1269
+ total: data.length,
1270
+ hasPredictions: total > 0,
1271
+ }};
1272
+ }}
1273
+
1274
+ // Update metrics display
1275
+ function updateMetrics() {{
1276
+ const metricsEl = document.getElementById('metrics-summary');
1277
+ const metrics = computeMetrics();
1278
+
1279
+ if (!metrics.hasPredictions) {{
1280
+ metricsEl.innerHTML = `
1281
+ <div class="metric-item">
1282
+ <span class="metric-label">Steps:</span>
1283
+ <span class="metric-value">${{metrics.total}}</span>
1284
+ </div>
1285
+ <div class="metric-item">
1286
+ <span style="color: var(--text-muted); font-size: 0.75rem;">No predictions - select a checkpoint</span>
1287
+ </div>
1288
+ `;
1289
+ }} else {{
1290
+ metricsEl.innerHTML = `
1291
+ <div class="metric-item">
1292
+ <span class="metric-label">Accuracy:</span>
1293
+ <span class="metric-value">${{metrics.accuracy}}%</span>
1294
+ </div>
1295
+ <div class="metric-item">
1296
+ <span class="metric-label">Avg Dist:</span>
1297
+ <span class="metric-value">${{metrics.avgDistance}}%</span>
1298
+ </div>
1299
+ <div class="metric-item">
1300
+ <span class="metric-label">Steps:</span>
1301
+ <span class="metric-value">${{metrics.total}}</span>
1302
+ </div>
1303
+ `;
1304
+ }}
1305
+ }}
1306
+
1307
+ // Update click overlays on screenshot
1308
+ function updateClickOverlays(index) {{
1309
+ document.querySelectorAll('.click-marker, .distance-line').forEach(el => el.remove());
1310
+
1311
+ const data = getMergedData()[index];
1312
+ if (!data) return;
1313
+
1314
+ const imgContainer = document.querySelector('.display-container');
1315
+ if (!imgContainer) return;
1316
+ imgContainer.style.position = 'relative';
1317
+
1318
+ // Human click marker
1319
+ if (showHumanOverlay && data.human_action.x !== null) {{
1320
+ const humanMarker = document.createElement('div');
1321
+ humanMarker.className = 'click-marker human';
1322
+ humanMarker.style.left = (data.human_action.x * 100) + '%';
1323
+ humanMarker.style.top = (data.human_action.y * 100) + '%';
1324
+ imgContainer.appendChild(humanMarker);
1325
+ }}
1326
+
1327
+ // Predicted click marker
1328
+ if (showPredictedOverlay && data.predicted_action && data.predicted_action.x !== undefined) {{
1329
+ const predMarker = document.createElement('div');
1330
+ predMarker.className = 'click-marker predicted';
1331
+ predMarker.style.left = (data.predicted_action.x * 100) + '%';
1332
+ predMarker.style.top = (data.predicted_action.y * 100) + '%';
1333
+ imgContainer.appendChild(predMarker);
1334
+
1335
+ // Draw line between human and predicted
1336
+ if (showHumanOverlay && data.human_action.x !== null) {{
1337
+ const line = document.createElement('div');
1338
+ line.className = 'distance-line';
1339
+ const x1 = data.human_action.x * imgContainer.offsetWidth;
1340
+ const y1 = data.human_action.y * imgContainer.offsetHeight;
1341
+ const x2 = data.predicted_action.x * imgContainer.offsetWidth;
1342
+ const y2 = data.predicted_action.y * imgContainer.offsetHeight;
1343
+ const length = Math.sqrt((x2-x1)**2 + (y2-y1)**2);
1344
+ const angle = Math.atan2(y2-y1, x2-x1) * 180 / Math.PI;
1345
+ line.style.left = x1 + 'px';
1346
+ line.style.top = y1 + 'px';
1347
+ line.style.width = length + 'px';
1348
+ line.style.transform = `rotate(${{angle}}deg)`;
1349
+ imgContainer.appendChild(line);
1350
+ }}
1351
+ }}
1352
+ }}
1353
+
1354
+ // Update comparison display
1355
+ function updateComparison(index) {{
1356
+ const data = getMergedData()[index];
1357
+ if (!data) return;
1358
+
1359
+ const humanEl = document.getElementById('human-action');
1360
+ const predictedEl = document.getElementById('predicted-action');
1361
+ const predictedBox = document.getElementById('predicted-box');
1362
+ const matchEl = document.getElementById('match-indicator');
1363
+
1364
+ // Human action
1365
+ humanEl.innerHTML = `
1366
+ <div>Type: ${{data.human_action.type}}</div>
1367
+ ${{data.human_action.x !== null ? `<div>Position: (${{(data.human_action.x * 100).toFixed(1)}}%, ${{(data.human_action.y * 100).toFixed(1)}}%)</div>` : ''}}
1368
+ ${{data.human_action.text ? `<div>Text: ${{data.human_action.text}}</div>` : ''}}
1369
+ `;
1370
+
1371
+ // Predicted action
1372
+ const hasPredictions = currentCheckpoint !== 'None';
1373
+ predictedBox.classList.toggle('disabled', !hasPredictions);
1374
+
1375
+ if (!hasPredictions) {{
1376
+ predictedEl.innerHTML = '<em style="color: var(--text-muted);">Select a checkpoint to see predictions</em>';
1377
+ }} else if (data.predicted_action) {{
1378
+ const pred = data.predicted_action;
1379
+ if (pred.x !== undefined) {{
1380
+ predictedEl.innerHTML = `
1381
+ <div>Type: ${{pred.type || 'click'}}</div>
1382
+ <div>Position: (${{(pred.x * 100).toFixed(1)}}%, ${{(pred.y * 100).toFixed(1)}}%)</div>
1383
+ `;
1384
+ }} else {{
1385
+ predictedEl.innerHTML = `<div>${{pred.raw_output || JSON.stringify(pred)}}</div>`;
1386
+ }}
1387
+ }} else {{
1388
+ predictedEl.innerHTML = '<em style="color: var(--text-muted);">No prediction available</em>';
1389
+ }}
1390
+
1391
+ // Match indicator
1392
+ if (!hasPredictions) {{
1393
+ matchEl.className = 'match-indicator pending';
1394
+ matchEl.textContent = '—';
1395
+ }} else if (data.match === true) {{
1396
+ matchEl.className = 'match-indicator match';
1397
+ matchEl.textContent = '✓ Match';
1398
+ }} else if (data.match === false) {{
1399
+ matchEl.className = 'match-indicator mismatch';
1400
+ matchEl.textContent = '✗ Mismatch';
1401
+ }} else {{
1402
+ matchEl.className = 'match-indicator pending';
1403
+ matchEl.textContent = '— No prediction';
1404
+ }}
1405
+
1406
+ updateClickOverlays(index);
1407
+ }}
1408
+
1409
+ // Setup overlay toggle buttons
1410
+ function setupOverlayToggles() {{
1411
+ const togglesContainer = document.getElementById('overlay-toggles');
1412
+ togglesContainer.innerHTML = `
1413
+ <button class="toggle-btn active" id="toggle-human">Human (H)</button>
1414
+ <button class="toggle-btn active" id="toggle-predicted">AI (P)</button>
1415
+ `;
1416
+
1417
+ document.getElementById('toggle-human').addEventListener('click', function() {{
1418
+ showHumanOverlay = !showHumanOverlay;
1419
+ this.classList.toggle('active', showHumanOverlay);
1420
+ updateClickOverlays(typeof currentIndex !== 'undefined' ? currentIndex : 0);
1421
+ }});
1422
+
1423
+ document.getElementById('toggle-predicted').addEventListener('click', function() {{
1424
+ showPredictedOverlay = !showPredictedOverlay;
1425
+ this.classList.toggle('active', showPredictedOverlay);
1426
+ updateClickOverlays(typeof currentIndex !== 'undefined' ? currentIndex : 0);
1427
+ }});
1428
+
1429
+ document.addEventListener('keydown', (e) => {{
1430
+ if (e.key === 'h' || e.key === 'H') document.getElementById('toggle-human').click();
1431
+ if (e.key === 'p' || e.key === 'P') document.getElementById('toggle-predicted').click();
1432
+ }});
1433
+ }}
1434
+
1435
+ // Create navigation bar
1436
+ function createNavBar() {{
1437
+ const container = document.querySelector('.container') || document.body.firstElementChild;
1438
+ if (!container) return;
1439
+
1440
+ const navBar = document.createElement('nav');
1441
+ navBar.className = 'nav-bar';
1442
+ navBar.id = 'nav-bar';
1443
+ navBar.innerHTML = `
1444
+ <a href="dashboard.html" class="nav-link">Training</a>
1445
+ <a href="viewer.html" class="nav-link active">Viewer</a>
1446
+ `;
1447
+ container.insertBefore(navBar, container.firstChild);
1448
+ }}
1449
+
1450
+ // Hook into existing updateDisplay
1451
+ const originalUpdateDisplay = typeof updateDisplay !== 'undefined' ? updateDisplay : function() {{}};
1452
+ updateDisplay = function(skipAudioSync) {{
1453
+ originalUpdateDisplay(skipAudioSync);
1454
+ // Sync currentIndex from base viewer if it exists
1455
+ if (typeof currentIndex !== 'undefined') {{
1456
+ currentIndex = currentIndex;
1457
+ }}
1458
+ updateComparison(currentIndex);
1459
+ }};
1460
+
1461
+ // Initialize
1462
+ setTimeout(() => {{
1463
+ createNavBar();
1464
+ initDropdowns();
1465
+ setupOverlayToggles();
1466
+ updateMetrics();
1467
+ updateComparison(currentIndex);
1468
+ }}, 100);
1469
+ </script>
1470
+ '''
1471
+
1472
+ # Inject into HTML
1473
+ html = base_html.replace('</head>', unified_styles + '</head>')
1474
+ html = html.replace(
1475
+ '<div class="main-content">',
1476
+ comparison_panel + '\n <div class="main-content">'
1477
+ )
1478
+ html = html.replace('</body>', unified_script + '</body>')
1479
+
1480
+ # Write output
1481
+ output_path.write_text(html, encoding='utf-8')
1482
+ print(f"Generated unified viewer: {output_path}")
1483
+
1484
+ except ImportError:
1485
+ print("Error: openadapt-capture is required for visualization")
1486
+ print("Install with: pip install openadapt-capture")
1487
+
1488
+
1489
if __name__ == "__main__":
    # Use SystemExit instead of the builtin exit(): exit() is an interactive
    # convenience injected by the `site` module and may be absent under
    # `python -S` or in frozen interpreters. main()'s return value becomes
    # the process exit status.
    raise SystemExit(main())