openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.2.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/experiments/representation_shootout/runner.py
@@ -0,0 +1,687 @@
+"""Main experiment runner for the Representation Shootout.
+
+This module provides:
+1. ExperimentRunner class for running the full experiment
+2. CLI interface for experiment execution
+3. Results reporting and recommendation generation
+
+Usage:
+    # Run full experiment
+    python -m openadapt_ml.experiments.representation_shootout.runner run
+
+    # Run specific condition
+    python -m openadapt_ml.experiments.representation_shootout.runner run --condition marks
+
+    # Evaluate under specific drift
+    python -m openadapt_ml.experiments.representation_shootout.runner eval --drift resolution
+
+    # Generate recommendation from existing results
+    python -m openadapt_ml.experiments.representation_shootout.runner recommend --results-dir results/
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import sys
+from datetime import datetime
+from pathlib import Path
+
+from openadapt_ml.experiments.representation_shootout.conditions import (
+    ConditionBase,
+    Observation,
+    UIElement,
+    UIElementGraph,
+    create_condition,
+)
+from openadapt_ml.experiments.representation_shootout.config import (
+    ConditionName,
+    DriftConfig,
+    ExperimentConfig,
+    MetricName,
+)
+from openadapt_ml.experiments.representation_shootout.evaluator import (
+    DriftEvaluator,
+    EvaluationResult,
+    Recommendation,
+    Sample,
+    make_recommendation,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ExperimentRunner:
+    """Orchestrates the Representation Shootout experiment.
+
+    This class manages:
+    1. Loading/generating evaluation data
+    2. Running conditions under drift
+    3. Computing metrics and generating recommendations
+    4. Saving results
+    """
+
+    def __init__(self, config: ExperimentConfig):
+        """Initialize experiment runner.
+
+        Args:
+            config: Experiment configuration.
+        """
+        self.config = config
+        self.conditions: dict[ConditionName, ConditionBase] = {}
+        self.evaluator: DriftEvaluator | None = None
+        self.results: list[EvaluationResult] = []
+
+        # Validate config
+        issues = config.validate()
+        if issues:
+            raise ValueError(f"Invalid config: {issues}")
+
+        # Initialize conditions
+        for cond_config in config.conditions:
+            self.conditions[cond_config.name] = create_condition(cond_config)
+
+        # Initialize evaluator
+        self.evaluator = DriftEvaluator(self.conditions, config.drift_tests)
+
+    def load_samples(self, data_path: str | None = None) -> list[Sample]:
+        """Load evaluation samples from data directory.
+
+        This is a scaffolding implementation. Full implementation would:
+        1. Load screenshots from data_path
+        2. Load ground truth actions
+        3. Load UI element annotations (for marks condition)
+
+        Args:
+            data_path: Path to data directory (uses config.dataset.eval_path if None).
+
+        Returns:
+            List of Sample objects.
+        """
+        data_path = data_path or self.config.dataset.eval_path
+
+        if data_path and Path(data_path).exists():
+            # Load from files
+            samples = self._load_samples_from_path(data_path)
+            if samples:
+                return samples
+
+        # Generate synthetic samples for scaffolding
+        logger.info("Generating synthetic samples for scaffolding")
+        return self._generate_synthetic_samples(num_samples=100)
+
+    def _load_samples_from_path(self, data_path: str) -> list[Sample]:
+        """Load samples from a data directory.
+
+        Expected structure:
+            data_path/
+                samples.json        # List of sample metadata
+                screenshots/
+                    sample_001.png
+                    sample_002.png
+                    ...
+
+        Args:
+            data_path: Path to data directory.
+
+        Returns:
+            List of Sample objects.
+        """
+        data_dir = Path(data_path)
+        samples_file = data_dir / "samples.json"
+
+        if not samples_file.exists():
+            logger.warning(f"No samples.json found in {data_path}")
+            return []
+
+        with open(samples_file) as f:
+            samples_data = json.load(f)
+
+        samples = []
+        for item in samples_data:
+            # Build UI elements if present
+            ui_elements = None
+            if "ui_elements" in item:
+                elements = [
+                    UIElement(
+                        element_id=el["id"],
+                        role=el.get("role", "unknown"),
+                        name=el.get("name"),
+                        bbox=tuple(el["bbox"]),  # type: ignore
+                    )
+                    for el in item["ui_elements"]
+                ]
+                ui_elements = UIElementGraph(elements=elements)
+
+            observation = Observation(
+                screenshot_path=str(data_dir / "screenshots" / item["screenshot"]),
+                screen_size=tuple(item.get("screen_size", (1920, 1080))),  # type: ignore
+                ui_elements=ui_elements,
+            )
+
+            sample = Sample(
+                sample_id=item["id"],
+                observation=observation,
+                goal=item["goal"],
+                ground_truth=item["ground_truth"],
+            )
+            samples.append(sample)
+
+        logger.info(f"Loaded {len(samples)} samples from {data_path}")
+        return samples
+
+    def _generate_synthetic_samples(self, num_samples: int = 100) -> list[Sample]:
+        """Generate synthetic samples for scaffolding.
+
+        These are placeholder samples for testing the framework.
+        Real experiments should use actual UI data.
+
+        Args:
+            num_samples: Number of samples to generate.
+
+        Returns:
+            List of synthetic Sample objects.
+        """
+        import random
+
+        random.seed(self.config.seed)
+
+        samples = []
+        for i in range(num_samples):
+            # Generate random UI elements
+            num_elements = random.randint(5, 20)
+            elements = []
+            for j in range(num_elements):
+                x1 = random.uniform(0, 0.8)
+                y1 = random.uniform(0, 0.8)
+                w = random.uniform(0.05, 0.2)
+                h = random.uniform(0.03, 0.1)
+                elements.append(
+                    UIElement(
+                        element_id=f"e{j + 1}",
+                        role=random.choice(["button", "textfield", "link", "checkbox"]),
+                        name=f"Element {j + 1}",
+                        bbox=(x1, y1, x1 + w, y1 + h),
+                    )
+                )
+
+            ui_elements = UIElementGraph(elements=elements)
+
+            # Pick a random target element
+            target_element = random.choice(elements)
+            target_x, target_y = target_element.center
+
+            observation = Observation(
+                screenshot_path=None,  # No actual screenshot in scaffolding
+                screen_size=(1920, 1080),
+                ui_elements=ui_elements,
+            )
+
+            sample = Sample(
+                sample_id=f"synthetic_{i:04d}",
+                observation=observation,
+                goal=f"Click the {target_element.role} named '{target_element.name}'",
+                ground_truth={
+                    "type": "click",
+                    "x": target_x,
+                    "y": target_y,
+                    "element_id": target_element.element_id,
+                    "target_bbox": target_element.bbox,
+                },
+            )
+            samples.append(sample)
+
+        logger.info(f"Generated {num_samples} synthetic samples")
+        return samples
+
+    def get_model_predictions(
+        self,
+        condition: ConditionBase,
+        samples: list[Sample],
+        drift_config: DriftConfig,
+    ) -> list[str]:
+        """Get model predictions for samples.
+
+        This is a scaffolding implementation that returns mock predictions.
+        Full implementation would:
+        1. Prepare inputs using condition.prepare_input()
+        2. Send to model (VLM API or local model)
+        3. Return raw model outputs
+
+        Args:
+            condition: Condition to use for input preparation.
+            samples: Samples to get predictions for.
+            drift_config: Drift applied to samples.
+
+        Returns:
+            List of raw model output strings.
+        """
+        # Scaffolding: Generate plausible mock predictions
+        import random
+
+        random.seed(
+            self.config.seed + hash(condition.name.value) + hash(drift_config.name)
+        )
+
+        predictions = []
+        for sample in samples:
+            gt = sample.ground_truth
+
+            if condition.name == ConditionName.MARKS:
+                # Generate element ID prediction
+                # Simulate some errors based on drift
+                error_rate = self._get_error_rate(drift_config)
+                if random.random() < error_rate:
+                    # Make an error - pick wrong element
+                    if sample.observation.ui_elements:
+                        wrong_el = random.choice(
+                            sample.observation.ui_elements.elements
+                        )
+                        predictions.append(f"ACTION: CLICK([{wrong_el.element_id}])")
+                    else:
+                        predictions.append("ACTION: CLICK([e1])")
+                else:
+                    # Correct prediction
+                    predictions.append(f"ACTION: CLICK([{gt.get('element_id', 'e1')}])")
+            else:
+                # Generate coordinate prediction
+                # Add some noise based on drift
+                noise_std = self._get_coordinate_noise(drift_config)
+                pred_x = gt.get("x", 0.5) + random.gauss(0, noise_std)
+                pred_y = gt.get("y", 0.5) + random.gauss(0, noise_std)
+                # Clamp to valid range
+                pred_x = max(0, min(1, pred_x))
+                pred_y = max(0, min(1, pred_y))
+                predictions.append(f"ACTION: CLICK({pred_x:.4f}, {pred_y:.4f})")
+
+        return predictions
+
+    def _get_error_rate(self, drift_config: DriftConfig) -> float:
+        """Get expected error rate for marks condition under drift."""
+        if drift_config.is_canonical:
+            return 0.05  # 5% baseline error
+
+        # Different drifts have different impacts
+        drift_impact = {
+            "resolution": 0.08,  # Small impact - elements still identifiable
+            "translation": 0.05,  # Minimal impact - relative positions preserved
+            "theme": 0.15,  # Moderate impact - visual appearance changes
+            "scroll": 0.10,  # Some elements may be off-screen
+        }
+
+        drift_type = drift_config.drift_type.value
+        base_rate = drift_impact.get(drift_type, 0.10)
+
+        # Scale by drift severity
+        if drift_config.drift_type.value == "resolution":
+            scale = abs(drift_config.params.scale - 1.0)  # type: ignore
+            return base_rate + scale * 0.2
+        elif drift_config.drift_type.value == "scroll":
+            scroll_amount = drift_config.params.offset_y  # type: ignore
+            return base_rate + (scroll_amount / 1000) * 0.2
+
+        return base_rate
+
+    def _get_coordinate_noise(self, drift_config: DriftConfig) -> float:
+        """Get expected coordinate noise for coords conditions under drift."""
+        if drift_config.is_canonical:
+            return 0.02  # 2% baseline noise (normalized coords)
+
+        # Coordinates are more sensitive to drift than marks
+        drift_impact = {
+            "resolution": 0.08,  # Significant - coordinates may not scale correctly
+            "translation": 0.06,  # Moderate - if using screen-absolute coords
+            "theme": 0.03,  # Minimal - visual changes don't affect coordinates directly
+            "scroll": 0.12,  # High - y-coordinates shift significantly
+        }
+
+        drift_type = drift_config.drift_type.value
+        return drift_impact.get(drift_type, 0.05)
+
+    def run_condition(
+        self,
+        condition_name: ConditionName,
+        samples: list[Sample],
+    ) -> list[EvaluationResult]:
+        """Run a single condition across all drift tests.
+
+        Args:
+            condition_name: Name of condition to run.
+            samples: Samples to evaluate.
+
+        Returns:
+            List of EvaluationResult for each drift test.
+        """
+        condition = self.conditions.get(condition_name)
+        if not condition:
+            raise ValueError(f"Condition {condition_name} not found")
+
+        results = []
+        for drift_config in self.config.drift_tests:
+            logger.info(f"Evaluating {condition_name.value} under {drift_config.name}")
+
+            # Get model predictions
+            predictions = self.get_model_predictions(condition, samples, drift_config)
+
+            # Evaluate
+            result = self.evaluator.evaluate_condition_under_drift(  # type: ignore
+                condition, samples, drift_config, predictions
+            )
+            results.append(result)
+
+            # Log summary
+            hit_rate = result.metrics.get(MetricName.CLICK_HIT_RATE.value, 0)
+            logger.info(f"  Click-hit rate: {hit_rate:.1%}")
+
+        return results
+
+    def run(self) -> Recommendation:
+        """Run the full experiment.
+
+        Returns:
+            Recommendation based on results.
+        """
+        logger.info(f"Starting experiment: {self.config.name}")
+        logger.info(f"Conditions: {[c.value for c in self.conditions.keys()]}")
+        logger.info(f"Drift tests: {[d.name for d in self.config.drift_tests]}")
+
+        # Load samples
+        samples = self.load_samples()
+        logger.info(f"Loaded {len(samples)} samples")
+
+        # Run all conditions
+        all_results = []
+        for condition_name in self.conditions.keys():
+            results = self.run_condition(condition_name, samples)
+            all_results.extend(results)
+
+        self.results = all_results
+
+        # Generate recommendation
+        recommendation = make_recommendation(
+            all_results,
+            tolerance=self.config.decision_tolerance,
+        )
+
+        # Save results
+        self.save_results(recommendation)
+
+        return recommendation
+
+    def save_results(self, recommendation: Recommendation) -> None:
+        """Save experiment results to output directory.
+
+        Args:
+            recommendation: Final recommendation.
+        """
+        output_dir = Path(self.config.output_dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+        results_file = output_dir / f"results_{timestamp}.json"
+
+        # Serialize results
+        results_data = {
+            "experiment": self.config.name,
+            "timestamp": timestamp,
+            "config": {
+                "conditions": [c.value for c in self.conditions.keys()],
+                "drift_tests": [d.name for d in self.config.drift_tests],
+                "decision_tolerance": self.config.decision_tolerance,
+            },
+            "results": [
+                {
+                    "condition": r.condition.value,
+                    "drift": r.drift,
+                    "num_samples": r.num_samples,
+                    "metrics": r.metrics,
+                }
+                for r in self.results
+            ],
+            "recommendation": {
+                "recommended": recommendation.recommended,
+                "reason": recommendation.reason,
+                "coords_cues_avg": recommendation.coords_cues_avg,
+                "marks_avg": recommendation.marks_avg,
+                "tolerance": recommendation.tolerance,
+                "detailed_comparison": recommendation.detailed_comparison,
+            },
+        }
+
+        with open(results_file, "w") as f:
+            json.dump(results_data, f, indent=2)
+
+        logger.info(f"Results saved to {results_file}")
+
+    def print_summary(self, recommendation: Recommendation) -> None:
+        """Print experiment summary to stdout.
+
+        Args:
+            recommendation: Final recommendation.
+        """
+        print("\n" + "=" * 70)
+        print("REPRESENTATION SHOOTOUT - EXPERIMENT SUMMARY")
+        print("=" * 70)
+
+        print(f"\nExperiment: {self.config.name}")
+        print(f"Conditions: {', '.join(c.value for c in self.conditions.keys())}")
+        print(f"Drift tests: {len(self.config.drift_tests)}")
+        print(f"Samples: {len(self.results[0].sample_results) if self.results else 0}")
+
+        print("\n" + "-" * 70)
+        print("RESULTS BY CONDITION AND DRIFT")
+        print("-" * 70)
+
+        # Group results by condition
+        by_condition: dict[str, list[EvaluationResult]] = {}
+        for r in self.results:
+            key = r.condition.value
+            if key not in by_condition:
+                by_condition[key] = []
+            by_condition[key].append(r)
+
+        # Print table
+        header = f"{'Condition':<15} {'Drift':<25} {'Hit Rate':<12} {'Distance':<12}"
+        print(header)
+        print("-" * len(header))
+
+        for condition, results in by_condition.items():
+            for r in results:
+                hit_rate = r.metrics.get(MetricName.CLICK_HIT_RATE.value, 0)
+                distance = r.metrics.get(MetricName.COORD_DISTANCE.value, 0)
+                print(
+                    f"{condition:<15} {r.drift:<25} {hit_rate:>10.1%} {distance:>10.4f}"
+                )
+            print()
+
+        print("-" * 70)
+        print("RECOMMENDATION")
+        print("-" * 70)
+        print(f"\nRecommended approach: {recommendation.recommended}")
+        print(f"\nReason: {recommendation.reason}")
+        print(f"\nCoords+Cues average: {recommendation.coords_cues_avg:.1%}")
+        print(f"Marks average: {recommendation.marks_avg:.1%}")
+        print(f"Tolerance: {recommendation.tolerance:.1%}")
+
+        print("\n" + "=" * 70)
+
+
+def run_experiment(
+    config: ExperimentConfig | None = None,
+    data_path: str | None = None,
+    verbose: bool = True,
+) -> Recommendation:
+    """Convenience function to run the experiment.
+
+    Args:
+        config: Experiment configuration (uses default if None).
+        data_path: Path to evaluation data.
+        verbose: Whether to print progress.
+
+    Returns:
+        Final recommendation.
+    """
+    if verbose:
+        logging.basicConfig(level=logging.INFO, format="%(message)s")
+
+    config = config or ExperimentConfig.default()
+    runner = ExperimentRunner(config)
+
+    recommendation = runner.run()
+
+    if verbose:
+        runner.print_summary(recommendation)
+
+    return recommendation
+
+
+def main() -> int:
+    """CLI entry point."""
+    parser = argparse.ArgumentParser(
+        description="Representation Shootout Experiment",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # Run full experiment with default config
+  python -m openadapt_ml.experiments.representation_shootout.runner run
+
+  # Run with minimal config (quick test)
+  python -m openadapt_ml.experiments.representation_shootout.runner run --minimal
+
+  # Run specific condition only
+  python -m openadapt_ml.experiments.representation_shootout.runner run --condition marks
+
+  # Specify output directory
+  python -m openadapt_ml.experiments.representation_shootout.runner run --output results/my_experiment
+
+  # Generate recommendation from existing results
+  python -m openadapt_ml.experiments.representation_shootout.runner recommend --results results/results_20260116.json
+        """,
+    )
+
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    # Run command
+    run_parser = subparsers.add_parser("run", help="Run the experiment")
+    run_parser.add_argument(
+        "--minimal",
+        action="store_true",
+        help="Use minimal config for quick testing",
+    )
+    run_parser.add_argument(
+        "--condition",
+        choices=["raw_coords", "coords_cues", "marks"],
+        help="Run only specific condition",
+    )
+    run_parser.add_argument(
+        "--data",
+        help="Path to evaluation data directory",
+    )
+    run_parser.add_argument(
+        "--output",
+        default="experiment_results/representation_shootout",
+        help="Output directory for results",
+    )
+    run_parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for reproducibility",
+    )
+    run_parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        help="Verbose output",
+    )
+
+    # Recommend command (analyze existing results)
+    rec_parser = subparsers.add_parser(
+        "recommend", help="Generate recommendation from results"
+    )
+    rec_parser.add_argument(
+        "--results",
+        required=True,
+        help="Path to results JSON file",
+    )
+
+    args = parser.parse_args()
+
+    if args.command == "run":
+        # Configure logging
+        log_level = logging.DEBUG if args.verbose else logging.INFO
+        logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")
+
+        # Build config
+        if args.minimal:
+            config = ExperimentConfig.minimal()
+        else:
+            config = ExperimentConfig.default()
+
+        # Override output dir
+        config.output_dir = args.output
+        config.seed = args.seed
+
+        # Filter to specific condition if requested
+        if args.condition:
+            condition_name = ConditionName(args.condition)
+            config.conditions = [
+                c for c in config.conditions if c.name == condition_name
+            ]
+            if not config.conditions:
+                print(f"Error: No matching condition found for {args.condition}")
+                return 1
+
+        try:
+            runner = ExperimentRunner(config)
+            if args.data:
+                runner.load_samples(args.data)
+            recommendation = runner.run()
+            runner.print_summary(recommendation)
+            return 0
+        except Exception as e:
+            logger.error(f"Experiment failed: {e}")
+            if args.verbose:
+                import traceback
+
+                traceback.print_exc()
+            return 1
+
+    elif args.command == "recommend":
+        # Load existing results and generate recommendation
+        try:
+            with open(args.results) as f:
+                data = json.load(f)
+
+            # Reconstruct EvaluationResults
+            results = []
+            for r in data["results"]:
+                results.append(
+                    EvaluationResult(
+                        condition=ConditionName(r["condition"]),
+                        drift=r["drift"],
+                        num_samples=r["num_samples"],
+                        metrics=r["metrics"],
+                    )
+                )
+
+            tolerance = data.get("config", {}).get("decision_tolerance", 0.05)
+            recommendation = make_recommendation(results, tolerance=tolerance)
+
+            print("\n" + "=" * 70)
+            print("RECOMMENDATION FROM RESULTS")
+            print("=" * 70)
+            print(f"\nRecommended approach: {recommendation.recommended}")
+            print(f"\nReason: {recommendation.reason}")
+            return 0
+
+        except Exception as e:
+            logger.error(f"Failed to load results: {e}")
+            return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
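
The runner added in this hunk exposes two entry points: the CLI documented in its module docstring and the ExperimentRunner / run_experiment() Python API. The following is a rough sketch (not part of the package) of how the loader and runner could be driven programmatically, assuming ExperimentConfig.default() and its dataset.eval_path field behave as the code above reads them; all paths and sample values are illustrative, and the field names simply mirror the keys that _load_samples_from_path reads (id, goal, ground_truth, screenshot, plus optional screen_size and ui_elements).

    # illustrative_usage.py - hedged sketch, not shipped with openadapt-ml
    import json
    from pathlib import Path

    from openadapt_ml.experiments.representation_shootout.config import ExperimentConfig
    from openadapt_ml.experiments.representation_shootout.runner import ExperimentRunner

    # Hypothetical data directory in the layout _load_samples_from_path expects.
    data_dir = Path("eval_data")
    (data_dir / "screenshots").mkdir(parents=True, exist_ok=True)

    samples = [
        {
            "id": "sample_001",
            "goal": "Click the Save button",
            "screenshot": "sample_001.png",  # joined into a path, not opened, by the loader
            "screen_size": [1920, 1080],
            "ui_elements": [
                {"id": "e1", "role": "button", "name": "Save", "bbox": [0.45, 0.90, 0.55, 0.95]},
            ],
            "ground_truth": {"type": "click", "x": 0.50, "y": 0.925, "element_id": "e1"},
        },
    ]
    (data_dir / "samples.json").write_text(json.dumps(samples, indent=2))

    # load_samples() falls back to config.dataset.eval_path when no path is passed.
    config = ExperimentConfig.default()
    config.dataset.eval_path = str(data_dir)
    recommendation = ExperimentRunner(config).run()
    print(recommendation.recommended)

The module docstring's CLI commands cover the same flow end to end; without a data directory, the runner falls back to the synthetic scaffolding samples shown above.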