openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
"""Configuration dataclasses for the Representation Shootout experiment.
|
|
2
|
+
|
|
3
|
+
This module defines all configuration structures using dataclasses with
|
|
4
|
+
sensible defaults for the experiment comparing Coordinates vs Marks approaches.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ConditionName(str, Enum):
    """Identifiers for the experimental conditions being compared.

    Members:
        RAW_COORDS: Condition A - raw coordinate regression.
        COORDS_CUES: Condition B - coordinates plus visual cues.
        MARKS: Condition C - element-ID classification (Set-of-Marks style).
    """

    RAW_COORDS = "raw_coords"
    COORDS_CUES = "coords_cues"
    MARKS = "marks"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DriftType(str, Enum):
    """Kinds of distribution drift applied at evaluation time.

    Members:
        RESOLUTION: Rescale the UI resolution.
        TRANSLATION: Shift the window position.
        THEME: Switch UI theme (light / dark / high-contrast).
        SCROLL: Change the scroll offset.
    """

    RESOLUTION = "resolution"
    TRANSLATION = "translation"
    THEME = "theme"
    SCROLL = "scroll"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class MetricName(str, Enum):
    """Names of the metrics computed during evaluation.

    Members:
        CLICK_HIT_RATE: Share of clicks landing inside the target bbox.
        GROUNDING_TOP1_ACCURACY: Share of correct element IDs (marks only).
        EPISODE_SUCCESS_RATE: Share of episodes that reach the goal.
        COORD_DISTANCE: Normalized L2 distance to the target.
        ROBUSTNESS_SCORE: Performance ratio, drift / canonical.
    """

    CLICK_HIT_RATE = "click_hit_rate"
    GROUNDING_TOP1_ACCURACY = "grounding_top1_accuracy"
    EPISODE_SUCCESS_RATE = "episode_success_rate"
    COORD_DISTANCE = "coord_distance"
    ROBUSTNESS_SCORE = "robustness_score"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class OutputFormat(str, Enum):
    """Shape of the action a model is expected to emit.

    Members:
        COORDINATES: ``{"type": "CLICK", "x": float, "y": float}``
        ELEMENT_ID: ``{"type": "CLICK", "element_id": str}``
    """

    COORDINATES = "coordinates"
    ELEMENT_ID = "element_id"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class VisualCuesConfig:
    """Visual-cue settings used by Condition B (coordinates + cues).

    Attributes:
        marker_enabled: If True, a marker is drawn at the click target.
        marker_radius: Marker circle radius, in pixels.
        marker_color: Marker colour as an RGB tuple (default red).
        zoom_enabled: If True, a zoomed inset patch is included.
        zoom_factor: Magnification applied to the inset patch.
        zoom_patch_size: Side length of the square inset patch, in pixels.
        zoom_position: Inset placement ("auto", "top-left", "top-right", ...);
            "auto" places it opposite the click location.
    """

    # -- target marker --
    marker_enabled: bool = True
    marker_radius: int = 8
    marker_color: tuple[int, int, int] = (255, 0, 0)

    # -- zoom inset --
    zoom_enabled: bool = True
    zoom_factor: float = 2.0
    zoom_patch_size: int = 100
    zoom_position: str = "auto"
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass
|
|
75
|
+
class MarksConfig:
|
|
76
|
+
"""Configuration for SoM-style marks in Condition C.
|
|
77
|
+
|
|
78
|
+
Attributes:
|
|
79
|
+
overlay_enabled: Whether to draw element ID overlays on screenshot.
|
|
80
|
+
font_size: Font size for element ID labels.
|
|
81
|
+
label_background: Background color for labels (RGBA).
|
|
82
|
+
label_text_color: Text color for labels.
|
|
83
|
+
max_elements: Maximum number of elements to include.
|
|
84
|
+
include_roles: Element roles to include (None = all).
|
|
85
|
+
exclude_roles: Element roles to exclude.
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
overlay_enabled: bool = True
|
|
89
|
+
font_size: int = 12
|
|
90
|
+
label_background: tuple[int, int, int, int] = (
|
|
91
|
+
0,
|
|
92
|
+
0,
|
|
93
|
+
255,
|
|
94
|
+
200,
|
|
95
|
+
) # Blue, semi-transparent
|
|
96
|
+
label_text_color: tuple[int, int, int] = (255, 255, 255) # White
|
|
97
|
+
|
|
98
|
+
max_elements: int = 50
|
|
99
|
+
include_roles: list[str] | None = None # None = include all
|
|
100
|
+
exclude_roles: list[str] = field(
|
|
101
|
+
default_factory=lambda: ["group", "generic", "static_text"]
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass
class ConditionConfig:
    """Configuration for a single experimental condition.

    Attributes:
        name: Condition identifier (raw_coords, coords_cues, marks).
        output_format: Expected model output format.
        include_history: Whether to include action history in prompt.
        max_history_steps: Maximum number of history steps to include.
        visual_cues: Configuration for visual cues (Condition B only);
            ``None`` otherwise.
        marks: Configuration for element marks (Condition C only); ``None``
            otherwise.
        loss_type: Loss function type ("mse" for coords, "cross_entropy" for
            marks).

    Note:
        ``ConditionConfig.marks(...)`` on the class is the Condition C factory.
        ``config.marks`` on an instance is the ``MarksConfig`` field. The
        factory is attached after the class body; the comment below the class
        explains why.
    """

    name: ConditionName
    output_format: OutputFormat
    include_history: bool = True
    max_history_steps: int = 5

    # Condition-specific configs
    visual_cues: VisualCuesConfig | None = None
    marks: MarksConfig | None = None

    # Training config
    loss_type: str = (
        "mse"  # "mse" for coordinate regression, "cross_entropy" for classification
    )

    @classmethod
    def raw_coords(cls, **kwargs: Any) -> ConditionConfig:
        """Create Condition A (Raw Coordinates) config.

        Args:
            **kwargs: Extra field values forwarded to the constructor.
                ``name``, ``output_format`` and ``loss_type`` are fixed by
                this factory, so passing them raises ``TypeError``.
        """
        return cls(
            name=ConditionName.RAW_COORDS,
            output_format=OutputFormat.COORDINATES,
            loss_type="mse",
            **kwargs,
        )

    @classmethod
    def coords_cues(cls, **kwargs: Any) -> ConditionConfig:
        """Create Condition B (Coordinates + Visual Cues) config.

        Args:
            **kwargs: Extra field values forwarded to the constructor. A falsy
                or missing ``visual_cues`` is replaced by a default
                ``VisualCuesConfig()``.
        """
        visual_cues = kwargs.pop("visual_cues", None) or VisualCuesConfig()
        return cls(
            name=ConditionName.COORDS_CUES,
            output_format=OutputFormat.COORDINATES,
            visual_cues=visual_cues,
            loss_type="mse",
            **kwargs,
        )

    @classmethod
    def _marks_condition(cls, **kwargs: Any) -> ConditionConfig:
        """Create Condition C (Marks/Element IDs) config.

        Public callers use ``ConditionConfig.marks(...)``, which is this same
        classmethod attached under the public name after class creation.

        Args:
            **kwargs: Extra field values forwarded to the constructor. A falsy
                or missing ``marks`` is replaced by a default ``MarksConfig()``.
        """
        marks_config = kwargs.pop("marks", None) or MarksConfig()
        return cls(
            name=ConditionName.MARKS,
            output_format=OutputFormat.ELEMENT_ID,
            marks=marks_config,
            loss_type="cross_entropy",
            **kwargs,
        )


# ``marks`` is both a dataclass field and the public Condition C factory.
# If the factory were defined in the class body, @dataclass would read the
# bound classmethod as the field's default. Every raw_coords/coords_cues
# config would then get ``config.marks`` set to a method instead of None.
# Attaching the factory after decoration keeps the field default at None.
# ``ConditionConfig.marks(...)`` still works because instance attributes
# shadow the (non-data) classmethod descriptor.
ConditionConfig.marks = vars(ConditionConfig)["_marks_condition"]  # type: ignore[assignment]
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
@dataclass
class ResolutionDriftParams:
    """Resolution-scaling drift parameters.

    Attributes:
        scale: Resolution multiplier (e.g. 0.75, 1.0, 1.25, 1.5).
    """

    scale: float
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@dataclass
class TranslationDriftParams:
    """Window-translation drift parameters.

    Attributes:
        offset_x: Horizontal shift, in pixels.
        offset_y: Vertical shift, in pixels.
    """

    offset_x: int
    offset_y: int
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@dataclass
class ThemeDriftParams:
    """UI-theme drift parameters.

    Attributes:
        theme: Theme name, e.g. "light", "dark" or "high_contrast".
    """

    theme: str
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@dataclass
class ScrollDriftParams:
    """Scroll-offset drift parameters.

    Attributes:
        offset_y: Distance scrolled down from the top, in pixels.
    """

    offset_y: int
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
@dataclass
class DriftConfig:
    """A single drift evaluation test.

    Attributes:
        name: Human-readable label for the test (the factories derive it from
            their arguments, e.g. ``"resolution_1.5x"``).
        drift_type: Kind of drift (resolution, translation, theme, scroll).
        params: Parameter object for the drift.
        is_canonical: True for the no-drift baseline.

    The classmethod constructors set ``name``, ``drift_type`` and ``params``
    together. The canonical baseline is expressed as a 1.0x resolution drift.
    """

    name: str
    drift_type: DriftType
    params: (
        ResolutionDriftParams
        | TranslationDriftParams
        | ThemeDriftParams
        | ScrollDriftParams
    )
    is_canonical: bool = False

    @classmethod
    def canonical(cls) -> DriftConfig:
        """Build the canonical (no-drift) baseline."""
        identity = ResolutionDriftParams(scale=1.0)
        return cls(
            name="canonical",
            drift_type=DriftType.RESOLUTION,
            params=identity,
            is_canonical=True,
        )

    @classmethod
    def resolution(cls, scale: float) -> DriftConfig:
        """Build a resolution-scaling drift by factor ``scale``."""
        label = f"resolution_{scale}x"
        params = ResolutionDriftParams(scale=scale)
        return cls(name=label, drift_type=DriftType.RESOLUTION, params=params)

    @classmethod
    def translation(cls, offset_x: int, offset_y: int) -> DriftConfig:
        """Build a window-translation drift by ``(offset_x, offset_y)`` pixels."""
        label = f"translation_{offset_x}_{offset_y}"
        params = TranslationDriftParams(offset_x=offset_x, offset_y=offset_y)
        return cls(name=label, drift_type=DriftType.TRANSLATION, params=params)

    @classmethod
    def theme(cls, theme_name: str) -> DriftConfig:
        """Build a UI-theme drift to ``theme_name``."""
        label = f"theme_{theme_name}"
        params = ThemeDriftParams(theme=theme_name)
        return cls(name=label, drift_type=DriftType.THEME, params=params)

    @classmethod
    def scroll(cls, offset_y: int) -> DriftConfig:
        """Build a scroll drift of ``offset_y`` pixels."""
        label = f"scroll_{offset_y}px"
        params = ScrollDriftParams(offset_y=offset_y)
        return cls(name=label, drift_type=DriftType.SCROLL, params=params)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@dataclass
class DatasetConfig:
    """Locations and size requirements for training/evaluation data.

    Attributes:
        train_path: Training data directory or file (``None`` if unset).
        eval_path: Evaluation data directory or file (``None`` if unset).
        canonical_resolution: Expected resolution for canonical data,
            default 1920x1080.
        min_train_samples: Minimum number of training samples required.
        min_eval_samples: Minimum evaluation samples per drift condition.
    """

    # -- data locations --
    train_path: str | None = None
    eval_path: str | None = None

    # -- expectations --
    canonical_resolution: tuple[int, int] = (1920, 1080)
    min_train_samples: int = 1000
    min_eval_samples: int = 100
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
@dataclass
class ExperimentConfig:
    """Top-level configuration for the Representation Shootout experiment.

    Attributes:
        name: Experiment name for logging and results.
        conditions: List of conditions to evaluate (A, B, C).
        drift_tests: List of drift conditions for evaluation.
        metrics: Metrics to compute during evaluation.
        decision_tolerance: Tolerance for decision rule, as a fraction
            (default 0.05, i.e. 5%).
        dataset: Dataset configuration.
        output_dir: Directory for experiment outputs.
        seed: Random seed for reproducibility.
    """

    name: str = "representation_shootout"
    conditions: list[ConditionConfig] = field(default_factory=list)
    drift_tests: list[DriftConfig] = field(default_factory=list)
    metrics: list[MetricName] = field(default_factory=list)
    decision_tolerance: float = 0.05  # 5% tolerance for decision rule
    dataset: DatasetConfig = field(default_factory=DatasetConfig)
    output_dir: str = "experiment_results/representation_shootout"
    seed: int = 42

    @classmethod
    def default(cls) -> ExperimentConfig:
        """Create default experiment configuration with all conditions and drifts."""
        return cls(
            name="representation_shootout_default",
            conditions=[
                ConditionConfig.raw_coords(),
                ConditionConfig.coords_cues(),
                ConditionConfig.marks(),
            ],
            drift_tests=[
                # Canonical baseline
                DriftConfig.canonical(),
                # Resolution scaling
                DriftConfig.resolution(0.75),
                DriftConfig.resolution(1.25),
                DriftConfig.resolution(1.5),
                # Window translation
                DriftConfig.translation(200, 0),
                DriftConfig.translation(0, 100),
                DriftConfig.translation(200, 100),
                # Theme changes
                DriftConfig.theme("dark"),
                DriftConfig.theme("high_contrast"),
                # Scroll offset
                DriftConfig.scroll(300),
                DriftConfig.scroll(600),
            ],
            metrics=[
                MetricName.CLICK_HIT_RATE,
                MetricName.GROUNDING_TOP1_ACCURACY,
                MetricName.EPISODE_SUCCESS_RATE,
                MetricName.COORD_DISTANCE,
                MetricName.ROBUSTNESS_SCORE,
            ],
        )

    @classmethod
    def minimal(cls) -> ExperimentConfig:
        """Create minimal configuration for quick testing."""
        return cls(
            name="representation_shootout_minimal",
            conditions=[
                ConditionConfig.raw_coords(),
                ConditionConfig.marks(),
            ],
            drift_tests=[
                DriftConfig.canonical(),
                DriftConfig.resolution(0.75),
                DriftConfig.resolution(1.5),
            ],
            metrics=[
                MetricName.CLICK_HIT_RATE,
                MetricName.COORD_DISTANCE,
            ],
        )

    def validate(self) -> list[str]:
        """Validate configuration and return list of issues.

        Checks, in this order:
            - at least one condition and one drift test are given;
            - a canonical (no-drift) baseline is among the drift tests;
            - GROUNDING_TOP1_ACCURACY is only requested with a MARKS condition;
            - ``decision_tolerance`` lies in [0, 1] (NaN is rejected);
            - condition identifiers and drift-test names are unique.

        Returns:
            List of validation error messages (empty if valid).
        """
        from collections import Counter

        issues: list[str] = []

        if not self.conditions:
            issues.append("At least one condition must be specified")

        if not self.drift_tests:
            issues.append("At least one drift test must be specified")

        # Check for canonical baseline
        has_canonical = any(d.is_canonical for d in self.drift_tests)
        if not has_canonical:
            issues.append("At least one canonical (no-drift) baseline must be included")

        # Check metrics
        if MetricName.GROUNDING_TOP1_ACCURACY in self.metrics:
            has_marks = any(c.name == ConditionName.MARKS for c in self.conditions)
            if not has_marks:
                issues.append("GROUNDING_TOP1_ACCURACY metric requires MARKS condition")

        # The tolerance is a fraction (0.05 == 5%). The chained comparison is
        # False for NaN, so NaN is reported too.
        if not 0.0 <= self.decision_tolerance <= 1.0:
            issues.append(
                "decision_tolerance must be between 0 and 1, "
                f"got {self.decision_tolerance!r}"
            )

        # Use ``.value`` when present so str-Enum names render the same way
        # on every Python version (Enum.__format__ changed in 3.12).
        condition_counts = Counter(
            str(getattr(c.name, "value", c.name)) for c in self.conditions
        )
        for label, count in condition_counts.items():
            if count > 1:
                issues.append(
                    f"Condition '{label}' is specified {count} times; "
                    "condition identifiers must be unique"
                )

        drift_counts = Counter(d.name for d in self.drift_tests)
        for label, count in drift_counts.items():
            if count > 1:
                issues.append(
                    f"Drift test name '{label}' is used {count} times; "
                    "drift test names must be unique"
                )

        return issues
|