openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.2.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
"""Configuration for baseline adapters.
|
|
2
|
+
|
|
3
|
+
Defines track types, model registry, and configuration dataclasses.
|
|
4
|
+
Based on SOTA patterns from:
|
|
5
|
+
- Claude Computer Use API
|
|
6
|
+
- Microsoft UFO/UFO2
|
|
7
|
+
- OSWorld benchmark
|
|
8
|
+
- Agent-S/Agent-S2
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from enum import Enum
|
|
15
|
+
from typing import Any
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class TrackType(str, Enum):
    """Baseline evaluation track types.

    Members are also ``str`` instances, so they compare equal to their
    raw string values.

    TRACK_A: Direct coordinate prediction (CLICK(x, y))
    TRACK_B: ReAct-style reasoning with coordinates
    TRACK_C: Set-of-Mark element selection (CLICK([id]))
    """

    TRACK_A = "direct_coords"
    TRACK_B = "react_coords"
    TRACK_C = "set_of_mark"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CoordinateSystem(str, Enum):
    """Coordinate system for action output.

    Members are also ``str`` instances, so they compare equal to their
    raw string values.

    NORMALIZED: Coordinates in 0.0-1.0 range (relative to screen)
    PIXEL: Absolute pixel coordinates
    PERCENTAGE: Coordinates as percentages (0-100)
    """

    NORMALIZED = "normalized"
    PIXEL = "pixel"
    PERCENTAGE = "percentage"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ActionOutputFormat(str, Enum):
    """Output format style for model responses.

    Members are also ``str`` instances, so they compare equal to their
    raw string values.

    JSON: Structured JSON object
    FUNCTION_CALL: Function-style like CLICK(x, y)
    PYAUTOGUI: PyAutoGUI-style Python code (OSWorld compatible)
    """

    JSON = "json"
    FUNCTION_CALL = "function_call"
    PYAUTOGUI = "pyautogui"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class SoMConfig:
    """Configuration for Set-of-Mark (SoM) overlay.

    Controls how UI elements are labeled and displayed.
    Based on patterns from SoM paper and OMNI-parser.

    Attributes:
        overlay_enabled: Whether to draw element overlays on screenshot.
        label_format: Format for element labels ("[{id}]", "{id}", "e{id}").
        font_size: Font size for labels in pixels.
        label_background_color: RGBA tuple for label background.
        label_text_color: RGB tuple for label text.
        max_elements: Maximum elements to include (0=unlimited).
        include_roles: Element roles to include (None=all).
        exclude_roles: Element roles to exclude.
        min_element_area: Minimum element area in pixels to include.
        include_invisible: Whether to include non-visible elements.
    """

    overlay_enabled: bool = True
    label_format: str = "[{id}]"  # "[1]", "1", "e1"
    font_size: int = 12
    label_background_color: tuple[int, int, int, int] = (0, 120, 255, 200)  # Blue
    label_text_color: tuple[int, int, int] = (255, 255, 255)  # White
    max_elements: int = 100
    include_roles: list[str] | None = None  # None = include all
    # default_factory gives every instance its own list, so mutating one
    # config's exclusions never leaks into another.
    exclude_roles: list[str] = field(
        default_factory=lambda: ["group", "generic", "static_text", "separator"]
    )
    min_element_area: int = 100  # Minimum bbox area in pixels
    include_invisible: bool = False
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class ReActConfig:
    """Configuration for ReAct-style reasoning.

    Controls the observation-thought-action cycle used in Track B.
    Based on ReAct paper and UFO's Observation->Thought->Action pattern.

    Attributes:
        require_observation: Whether to require explicit observation.
        require_thought: Whether to require reasoning explanation.
        require_plan: Whether to require multi-step plan.
        max_plan_steps: Maximum steps in plan output.
        thinking_budget: Token budget for thinking (Claude extended thinking).
    """

    require_observation: bool = True
    require_thought: bool = True
    require_plan: bool = False
    max_plan_steps: int = 5
    thinking_budget: int | None = None  # For Claude extended thinking
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@dataclass
class ScreenConfig:
    """Screen/display configuration for coordinate handling.

    Attributes:
        width: Display width in pixels.
        height: Display height in pixels.
        coordinate_system: How coordinates are represented.
        scale_factor: DPI scale factor (1.0 = standard, 2.0 = retina).
    """

    width: int = 1920
    height: int = 1080
    coordinate_system: CoordinateSystem = CoordinateSystem.NORMALIZED
    scale_factor: float = 1.0

    def normalize_coords(self, x: float, y: float) -> tuple[float, float]:
        """Convert pixel coordinates to normalized (0-1).

        Divides ``x`` by ``width`` and ``y`` by ``height``. ``scale_factor``
        is not applied here.
        """
        return (x / self.width, y / self.height)

    def denormalize_coords(self, x: float, y: float) -> tuple[int, int]:
        """Convert normalized coordinates to pixels.

        Multiplies by ``width``/``height`` and truncates toward zero with
        ``int()``; the result is not rounded to the nearest pixel.
        """
        return (int(x * self.width), int(y * self.height))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
@dataclass
class TrackConfig:
    """Configuration for a specific evaluation track.

    The factory classmethods (``track_a``, ``track_b``, ``track_c``,
    ``osworld_compatible``, ``ufo_compatible``) each fill in a preset for
    one track. Keyword arguments passed to a factory override the preset
    values; only ``track_type`` is fixed by the factory itself.

    Attributes:
        track_type: The track type (A, B, or C).
        output_format: Expected output format string.
        action_format: Style of action output (JSON, function, pyautogui).
        use_som: Whether to use Set-of-Mark overlay.
        som_config: Configuration for SoM (Track C).
        use_a11y_tree: Whether to include accessibility tree.
        max_a11y_elements: Max elements in a11y tree (truncation).
        include_reasoning: Whether to request reasoning steps.
        react_config: Configuration for ReAct (Track B).
        include_history: Whether to include action history.
        max_history_steps: Max history steps to include.
        screen_config: Screen/coordinate configuration.
        verify_after_action: Request screenshot verification after actions.
    """

    track_type: TrackType
    output_format: str
    action_format: ActionOutputFormat = ActionOutputFormat.JSON
    use_som: bool = False
    som_config: SoMConfig | None = None
    use_a11y_tree: bool = True
    max_a11y_elements: int = 50
    include_reasoning: bool = False
    react_config: ReActConfig | None = None
    include_history: bool = True
    max_history_steps: int = 5
    screen_config: ScreenConfig = field(default_factory=ScreenConfig)
    verify_after_action: bool = False  # Claude computer use best practice

    @classmethod
    def track_a(cls, **kwargs: Any) -> "TrackConfig":
        """Create Track A (Direct Coordinates) config.

        Simplest track: screenshot + goal -> coordinates.
        No reasoning or element IDs.

        Args:
            **kwargs: Field overrides; these take precedence over the preset.
        """
        preset: dict[str, Any] = {
            "output_format": '{"action": "CLICK", "x": float, "y": float}',
            "action_format": ActionOutputFormat.JSON,
            "use_som": False,
            "use_a11y_tree": True,
            "include_reasoning": False,
        }
        # Merge before calling cls(): passing the preset keywords and
        # **kwargs side by side raises TypeError on any overlapping key.
        return cls(track_type=TrackType.TRACK_A, **{**preset, **kwargs})

    @classmethod
    def track_b(cls, **kwargs: Any) -> "TrackConfig":
        """Create Track B (ReAct with Coordinates) config.

        Includes observation->thought->action cycle.
        Based on ReAct, UFO, and Claude thinking patterns.

        Args:
            **kwargs: Field overrides; these take precedence over the preset.
                A missing or ``None`` ``react_config`` becomes ``ReActConfig()``.
        """
        react_config = kwargs.pop("react_config", None) or ReActConfig()
        preset: dict[str, Any] = {
            "output_format": (
                '{"observation": str, "thought": str, "action": "CLICK", '
                '"x": float, "y": float}'
            ),
            "action_format": ActionOutputFormat.JSON,
            "use_som": False,
            "use_a11y_tree": True,
            "include_reasoning": True,
            "react_config": react_config,
        }
        return cls(track_type=TrackType.TRACK_B, **{**preset, **kwargs})

    @classmethod
    def track_c(cls, **kwargs: Any) -> "TrackConfig":
        """Create Track C (Set-of-Mark) config.

        Uses numbered element labels instead of coordinates.
        Based on SoM paper and OMNI-parser patterns.

        Args:
            **kwargs: Field overrides; these take precedence over the preset.
                A missing or ``None`` ``som_config`` becomes ``SoMConfig()``.
        """
        som_config = kwargs.pop("som_config", None) or SoMConfig()
        preset: dict[str, Any] = {
            "output_format": '{"action": "CLICK", "element_id": int}',
            "action_format": ActionOutputFormat.JSON,
            "use_som": True,
            "som_config": som_config,
            "use_a11y_tree": True,
            "include_reasoning": False,
        }
        return cls(track_type=TrackType.TRACK_C, **{**preset, **kwargs})

    @classmethod
    def osworld_compatible(cls, **kwargs: Any) -> "TrackConfig":
        """Create OSWorld-compatible config.

        Uses PyAutoGUI-style action format for OSWorld benchmark.

        Args:
            **kwargs: Field overrides; these take precedence over the preset.
        """
        preset: dict[str, Any] = {
            "output_format": "pyautogui.click(x, y)",
            "action_format": ActionOutputFormat.PYAUTOGUI,
            "use_som": False,
            "use_a11y_tree": True,
            "include_reasoning": False,
        }
        return cls(track_type=TrackType.TRACK_A, **{**preset, **kwargs})

    @classmethod
    def ufo_compatible(cls, **kwargs: Any) -> "TrackConfig":
        """Create UFO-compatible config.

        Uses UFO's AppAgent output format with observation/thought/plan.

        Args:
            **kwargs: Field overrides; these take precedence over the preset.
                A missing or ``None`` ``react_config`` becomes a ReActConfig
                that requires observation, thought and plan; a missing or
                ``None`` ``som_config`` becomes ``SoMConfig()``.
        """
        react_config = kwargs.pop("react_config", None) or ReActConfig(
            require_observation=True,
            require_thought=True,
            require_plan=True,
        )
        # Popped the same way track_c does, so a caller-supplied som_config
        # is honoured rather than colliding with a hard-coded keyword.
        som_config = kwargs.pop("som_config", None) or SoMConfig()
        preset: dict[str, Any] = {
            "output_format": (
                '{"Observation": str, "Thought": str, "ControlLabel": int, '
                '"Function": str, "Args": list}'
            ),
            "action_format": ActionOutputFormat.JSON,
            "use_som": True,
            "som_config": som_config,
            "use_a11y_tree": True,
            "include_reasoning": True,
            "react_config": react_config,
        }
        return cls(track_type=TrackType.TRACK_B, **{**preset, **kwargs})
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
@dataclass
class ModelSpec:
    """Specification for a supported model.

    Attributes:
        provider: Provider name (anthropic, openai, google).
        model_id: Full model identifier for the API.
        display_name: Human-readable name.
        is_default: Whether this is the default for the provider.
        max_tokens: Default max tokens for this model.
        supports_vision: Whether the model supports images.
    """

    provider: str
    model_id: str
    display_name: str
    is_default: bool = False
    max_tokens: int = 1024
    supports_vision: bool = True
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# Model registry: maps a short alias to its ModelSpec. At most one entry per
# provider sets is_default=True; get_default_model() returns the first such
# entry in insertion order.
MODELS: dict[str, ModelSpec] = {
    # Anthropic Claude
    "claude-opus-4.5": ModelSpec(
        provider="anthropic",
        model_id="claude-opus-4-5-20251101",
        display_name="Claude Opus 4.5",
        is_default=True,
        max_tokens=4096,
    ),
    "claude-sonnet-4.5": ModelSpec(
        provider="anthropic",
        model_id="claude-sonnet-4-5-20250929",
        display_name="Claude Sonnet 4.5",
        max_tokens=4096,
    ),
    # OpenAI GPT
    "gpt-5.2": ModelSpec(
        provider="openai",
        model_id="gpt-5.2",
        display_name="GPT-5.2",
        is_default=True,
        max_tokens=4096,
    ),
    "gpt-5.1": ModelSpec(
        provider="openai",
        model_id="gpt-5.1",
        display_name="GPT-5.1",
        max_tokens=4096,
    ),
    "gpt-4o": ModelSpec(
        provider="openai",
        model_id="gpt-4o",
        display_name="GPT-4o",
        max_tokens=4096,
    ),
    # Google Gemini
    "gemini-3-pro": ModelSpec(
        provider="google",
        model_id="gemini-3-pro",
        display_name="Gemini 3 Pro",
        is_default=True,
        max_tokens=4096,
    ),
    "gemini-3-flash": ModelSpec(
        provider="google",
        model_id="gemini-3-flash",
        display_name="Gemini 3 Flash",
        max_tokens=4096,
    ),
    "gemini-2.5-pro": ModelSpec(
        provider="google",
        model_id="gemini-2.5-pro",
        display_name="Gemini 2.5 Pro",
        max_tokens=4096,
    ),
    "gemini-2.5-flash": ModelSpec(
        provider="google",
        model_id="gemini-2.5-flash",
        display_name="Gemini 2.5 Flash",
        max_tokens=4096,
    ),
}
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def get_model_spec(model_alias: str) -> ModelSpec:
    """Get model specification by alias.

    Args:
        model_alias: Model alias (e.g., 'claude-opus-4.5').

    Returns:
        ModelSpec for the model.

    Raises:
        ValueError: If alias not recognized. The message lists every
            registered alias.
    """
    try:
        return MODELS[model_alias]
    except KeyError:
        available = ", ".join(MODELS)
        # "from None" hides the internal KeyError; callers only see ValueError.
        raise ValueError(
            f"Unknown model: {model_alias}. Available: {available}"
        ) from None
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def get_default_model(provider: str) -> ModelSpec:
    """Get default model for a provider.

    Scans the registry in insertion order and returns the first entry whose
    provider matches and whose ``is_default`` flag is set.

    Args:
        provider: Provider name (anthropic, openai, google).

    Returns:
        Default ModelSpec for the provider.

    Raises:
        ValueError: If no default found.
    """
    default_spec = next(
        (
            spec
            for spec in MODELS.values()
            if spec.provider == provider and spec.is_default
        ),
        None,
    )
    if default_spec is None:
        raise ValueError(f"No default model for provider: {provider}")
    return default_spec
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
@dataclass
class BaselineConfig:
    """Configuration for a baseline adapter run.

    After construction, ``__post_init__`` resolves registry aliases: a
    ``provider`` that is a key in ``MODELS`` replaces both ``provider`` and
    ``model``; otherwise a ``model`` that is a key in ``MODELS`` is replaced
    by its full model ID.

    Attributes:
        provider: Provider name or model alias.
        model: Model identifier (full ID or alias).
        track: Track configuration.
        api_key: Optional API key (defaults to env).
        temperature: Sampling temperature.
        max_tokens: Max response tokens.
        demo: Optional demo text to include.
        verbose: Whether to log verbose output.
    """

    provider: str
    model: str
    track: TrackConfig = field(default_factory=TrackConfig.track_a)
    api_key: str | None = None
    temperature: float = 0.1
    max_tokens: int = 1024
    demo: str | None = None
    verbose: bool = False

    def __post_init__(self) -> None:
        """Resolve model alias if needed."""
        if self.provider in MODELS:
            # Provider is actually a model alias: take both the provider and
            # the model ID from the registry, discarding the passed model.
            spec = MODELS[self.provider]
            self.provider = spec.provider
            self.model = spec.model_id
        elif self.model in MODELS:
            # Model is an alias: expand it to the full model ID. The provider
            # is left exactly as passed.
            self.model = MODELS[self.model].model_id

    @classmethod
    def from_alias(
        cls,
        model_alias: str,
        track: TrackConfig | None = None,
        **kwargs: Any,
    ) -> "BaselineConfig":
        """Create config from model alias.

        Args:
            model_alias: Model alias (e.g., 'claude-opus-4.5').
            track: Track config (defaults to Track A).
            **kwargs: Additional config options.

        Returns:
            BaselineConfig instance.

        Raises:
            ValueError: If the alias is not in the model registry
                (raised by ``get_model_spec``).
        """
        spec = get_model_spec(model_alias)
        return cls(
            provider=spec.provider,
            model=spec.model_id,
            track=track or TrackConfig.track_a(),
            **kwargs,
        )
|