openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Episode Schema for GUI Trajectory Data
|
|
3
|
+
|
|
4
|
+
Canonical contract for episode/demonstration data in GUI automation. Designed for
|
|
5
|
+
interoperability across training pipelines, benchmarks, and human demonstrations.
|
|
6
|
+
|
|
7
|
+
Features:
|
|
8
|
+
- Pydantic models with runtime validation
|
|
9
|
+
- JSON Schema export for language-agnostic tooling
|
|
10
|
+
- Supports pixel coordinates AND normalized (0-1) coordinates
|
|
11
|
+
- Extensible via `raw` and `metadata` fields
|
|
12
|
+
- Converters for common formats (WAA, WebArena, etc.)
|
|
13
|
+
|
|
14
|
+
Quick Start:
|
|
15
|
+
from openadapt_ml.schema import Episode, Step, Action, Observation, ActionType
|
|
16
|
+
|
|
17
|
+
episode = Episode(
|
|
18
|
+
episode_id="demo_001",
|
|
19
|
+
instruction="Open the Settings app and enable Dark Mode",
|
|
20
|
+
steps=[
|
|
21
|
+
Step(
|
|
22
|
+
step_index=0,
|
|
23
|
+
observation=Observation(screenshot_path="step_0.png"),
|
|
24
|
+
action=Action(
|
|
25
|
+
type=ActionType.CLICK,
|
|
26
|
+
coordinates={"x": 512, "y": 384},
|
|
27
|
+
# Or use normalized coords for resolution independence:
|
|
28
|
+
# normalized_coordinates=(0.5, 0.375),
|
|
29
|
+
),
|
|
30
|
+
reasoning="Click on Settings icon",
|
|
31
|
+
),
|
|
32
|
+
],
|
|
33
|
+
success=True,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Validate any dict against the schema
|
|
37
|
+
from openadapt_ml.schema import validate_episode
|
|
38
|
+
is_valid, error = validate_episode(data)
|
|
39
|
+
|
|
40
|
+
# Export JSON Schema for external tools
|
|
41
|
+
from openadapt_ml.schema import export_json_schema
|
|
42
|
+
export_json_schema("episode.schema.json")
|
|
43
|
+
|
|
44
|
+
Schema Version: 1.0.0
|
|
45
|
+
- Core models: Episode, Step, Action, Observation
|
|
46
|
+
- 24 action types covering mouse, keyboard, navigation, and system actions
|
|
47
|
+
- Support for both pixel and normalized coordinates
|
|
48
|
+
- Extension points: raw, metadata fields
|
|
49
|
+
|
|
50
|
+
Evolution Policy (SemVer):
|
|
51
|
+
- PATCH (1.0.x): Documentation, bug fixes (no schema changes)
|
|
52
|
+
- MINOR (1.x.0): New optional fields with defaults (backward compatible)
|
|
53
|
+
- MAJOR (x.0.0): Breaking changes (field removal, type changes, new required fields)
|
|
54
|
+
|
|
55
|
+
Migration Guide:
|
|
56
|
+
- MINOR bumps: No action needed, old data validates
|
|
57
|
+
- MAJOR bumps: Use converters or migration scripts (provided in release notes)
|
|
58
|
+
"""
|
|
59
|
+
|
|
60
|
+
from __future__ import annotations
|
|
61
|
+
|
|
62
|
+
import json
|
|
63
|
+
from datetime import datetime
|
|
64
|
+
from enum import Enum
|
|
65
|
+
from pathlib import Path
|
|
66
|
+
from typing import Any, Literal, Optional, Union
|
|
67
|
+
|
|
68
|
+
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Schema version - follows semver
|
|
72
|
+
# Used as the default of Episode.schema_version, so episodes that do not set a
# version explicitly are stamped with this one.
SCHEMA_VERSION = "1.0.0"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class ActionType(str, Enum):
    """Closed vocabulary of actions a recorded GUI step can carry.

    Members mix in ``str``, so every member compares equal to its lowercase
    value (``ActionType.CLICK == "click"``) and can be looked up from that
    value (``ActionType("click")``).
    """

    # -- Pointer input --
    CLICK = "click"
    DOUBLE_CLICK = "double_click"
    RIGHT_CLICK = "right_click"
    DRAG = "drag"
    SCROLL = "scroll"
    HOVER = "hover"

    # -- Keyboard input --
    TYPE = "type"
    KEY = "key"
    HOTKEY = "hotkey"

    # -- Combined / special --
    CLICK_AND_TYPE = "click_and_type"
    WAIT = "wait"
    SCREENSHOT = "screenshot"

    # -- Browser-style navigation (web) --
    GOTO = "goto"
    BACK = "back"
    FORWARD = "forward"
    REFRESH = "refresh"

    # -- Application / window management --
    OPEN_APP = "open_app"
    CLOSE_APP = "close_app"
    SELECT_MONITOR = "select_monitor"  # multi-monitor: focus a specific display
    WINDOW_FOCUS = "window_focus"  # focus a specific window
    WINDOW_RESIZE = "window_resize"  # resize window
    WINDOW_MOVE = "window_move"  # move window

    # -- Meta: markers rather than UI interactions --
    DONE = "done"
    FAIL = "fail"
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class BenchmarkSource(str, Enum):
    """Provenance label: which benchmark or dataset an episode came from.

    Members mix in ``str``, so each compares equal to its lowercase value.
    """

    # Benchmarks
    WAA = "waa"  # Windows Agent Arena
    WEBARENA = "webarena"
    OSWORLD = "osworld"
    MINIWOB = "miniwob"

    # Non-benchmark origins
    HUMAN = "human"  # human demonstration
    SYNTHETIC = "synthetic"  # generated/augmented
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class Coordinates(BaseModel):
    """Pixel position on screen, used as the target of mouse actions.

    Both axes must be non-negative; a negative ``x`` or ``y`` is rejected
    during model validation.
    """

    x: int = Field(..., description="X coordinate (pixels from left)")
    y: int = Field(..., description="Y coordinate (pixels from top)")

    @field_validator("x", "y")
    @classmethod
    def validate_non_negative(cls, v: int) -> int:
        """Return ``v`` unchanged when it is >= 0, otherwise raise ``ValueError``."""
        if v >= 0:
            return v
        raise ValueError("Coordinates must be non-negative")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class BoundingBox(BaseModel):
    """Axis-aligned rectangle describing where a UI element sits on screen.

    ``x``/``y`` locate the top-left corner and are not range-checked;
    ``width``/``height`` must be >= 0.
    """

    x: int = Field(..., description="Left edge X coordinate")
    y: int = Field(..., description="Top edge Y coordinate")
    width: int = Field(..., ge=0, description="Width in pixels")
    height: int = Field(..., ge=0, description="Height in pixels")

    @property
    def center(self) -> Coordinates:
        """Midpoint of the box, using integer (floor) halving of width and height.

        The result is built as a ``Coordinates``, which rejects negative values,
        so a box whose midpoint has a negative x or y raises a validation error.
        """
        center_x = self.x + self.width // 2
        center_y = self.y + self.height // 2
        return Coordinates(x=center_x, y=center_y)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class UIElement(BaseModel):
    """Description of one UI element, as taken from an accessibility tree or DOM.

    Every field is optional and defaults to ``None``; sources fill in whichever
    identifiers they have.
    """

    # What the element is and what it currently shows
    role: Optional[str] = Field(
        default=None, description="Element role (button, textbox, etc.)"
    )
    name: Optional[str] = Field(default=None, description="Element accessible name")
    value: Optional[str] = Field(
        default=None, description="Element value (for inputs)"
    )

    # Where it is on screen
    bounds: Optional[BoundingBox] = Field(
        default=None, description="Element bounding box"
    )

    # Ways to locate the element again
    element_id: Optional[str] = Field(
        default=None, description="Unique element identifier"
    )
    xpath: Optional[str] = Field(default=None, description="XPath selector (web)")
    selector: Optional[str] = Field(default=None, description="CSS selector (web)")
    automation_id: Optional[str] = Field(
        default=None, description="Automation ID (Windows)"
    )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class Action(BaseModel):
    """One action performed by the agent, plus the parameters that action needs.

    Only ``type`` is required. Every other field is optional and defaults to
    ``None``; after construction ``validate_action_params`` enforces the few
    parameters that specific action types cannot do without.
    """

    type: ActionType = Field(..., description="Type of action")

    # --- Pointer parameters (pixel space) ---
    coordinates: Optional[Coordinates] = Field(
        default=None, description="Target coordinates for mouse actions"
    )
    start_coordinates: Optional[Coordinates] = Field(
        default=None, description="Start coordinates for drag actions"
    )
    end_coordinates: Optional[Coordinates] = Field(
        default=None, description="End coordinates for drag actions"
    )
    scroll_direction: Optional[Literal["up", "down", "left", "right"]] = Field(
        default=None, description="Scroll direction"
    )
    scroll_amount: Optional[int] = Field(
        default=None, description="Scroll amount in pixels"
    )

    # --- Keyboard parameters ---
    text: Optional[str] = Field(default=None, description="Text to type")
    key: Optional[str] = Field(
        default=None, description="Key to press (e.g., 'enter', 'tab')"
    )
    modifiers: Optional[list[str]] = Field(
        default=None, description="Modifier keys (ctrl, alt, shift, meta)"
    )

    # --- Element targeting (alternative to coordinates) ---
    element: Optional[UIElement] = Field(
        default=None, description="Target element (for element-based actions)"
    )

    # --- Parameters for navigation / system actions ---
    url: Optional[str] = Field(default=None, description="URL for goto action")
    app_name: Optional[str] = Field(
        default=None, description="Application name for open/close"
    )
    duration: Optional[float] = Field(
        default=None, description="Duration in seconds (for wait)"
    )
    monitor_id: Optional[int] = Field(
        default=None, description="Monitor ID for select_monitor action"
    )
    window_title: Optional[str] = Field(
        default=None, description="Window title for window_focus action"
    )

    # --- Pointer parameters (normalized space) ---
    # Resolution-independent alternative to the pixel fields above. The
    # documented range is 0.0-1.0; this model does not enforce it.
    normalized_coordinates: Optional[tuple[float, float]] = Field(
        default=None, description="Normalized (x, y) coordinates (0.0-1.0 range)"
    )
    normalized_start: Optional[tuple[float, float]] = Field(
        default=None,
        description="Normalized start coordinates for drag (0.0-1.0 range)",
    )
    normalized_end: Optional[tuple[float, float]] = Field(
        default=None,
        description="Normalized end coordinates for drag (0.0-1.0 range)",
    )

    # --- Extension point: untouched payload from the source format ---
    raw: Optional[dict[str, Any]] = Field(
        default=None, description="Original action data from source format"
    )

    @model_validator(mode="after")
    def validate_action_params(self) -> "Action":
        """Check that the action type's mandatory parameter is present.

        TYPE needs ``text``, KEY needs ``key`` and GOTO needs ``url``. Click-style
        actions are deliberately accepted without ``coordinates`` or ``element``
        (the target can be inferred from context), so nothing is checked for them.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If the parameter required by the action type is ``None``.
        """
        mandatory_params = (
            (ActionType.TYPE, "text", "TYPE action requires 'text' parameter"),
            (ActionType.KEY, "key", "KEY action requires 'key' parameter"),
            (ActionType.GOTO, "url", "GOTO action requires 'url' parameter"),
        )
        for action_type, attr_name, message in mandatory_params:
            if self.type == action_type and getattr(self, attr_name) is None:
                raise ValueError(message)
        return self
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class Observation(BaseModel):
    """Snapshot of the environment state attached to a step.

    All fields are optional and default to ``None``, so a source only records
    the modalities it actually captured.
    """

    # --- Visual ---
    screenshot_path: Optional[str] = Field(
        default=None, description="Path to screenshot image file"
    )
    screenshot_base64: Optional[str] = Field(
        default=None, description="Base64-encoded screenshot (for inline storage)"
    )

    # --- Structured ---
    a11y_tree: Optional[dict[str, Any]] = Field(
        default=None, description="Accessibility tree snapshot"
    )
    dom: Optional[str] = Field(default=None, description="DOM HTML snapshot (web)")

    # --- Window / screen context ---
    window_title: Optional[str] = Field(
        default=None, description="Active window title"
    )
    app_name: Optional[str] = Field(
        default=None,
        description="Application name (e.g., 'Chrome', 'System Settings')",
    )
    url: Optional[str] = Field(default=None, description="Current URL (for web apps)")
    screen_size: Optional[tuple[int, int]] = Field(
        default=None, description="Screen dimensions (width, height)"
    )

    # --- Focus ---
    focused_element: Optional[UIElement] = Field(
        default=None, description="Currently focused UI element"
    )

    # --- Bookkeeping and extension point ---
    timestamp: Optional[float] = Field(default=None, description="Unix timestamp")
    raw: Optional[dict[str, Any]] = Field(
        default=None, description="Original observation data from source format"
    )
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class Step(BaseModel):
    """One observation -> action pair inside an episode.

    ``step_index``, ``observation`` and ``action`` are required; the reasoning,
    outcome and timing fields are optional and default to ``None``.
    """

    step_index: int = Field(..., ge=0, description="Step number (0-indexed)")

    # --- The pair itself ---
    observation: Observation = Field(
        ..., description="State observation before action"
    )
    action: Action = Field(..., description="Action taken")

    # --- Why the agent acted (useful for demos/training) ---
    reasoning: Optional[str] = Field(
        default=None,
        description="Agent's reasoning for the action (chain-of-thought)",
    )

    # --- Outcome ---
    reward: Optional[float] = Field(
        default=None, description="Reward signal (if available)"
    )
    done: Optional[bool] = Field(
        default=None, description="Whether episode ended after this step"
    )

    # --- Timing ---
    timestamp: Optional[float] = Field(
        default=None, description="Unix timestamp of action"
    )
    duration_ms: Optional[int] = Field(
        default=None, description="Time taken for this step in milliseconds"
    )
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Stands in for ``datetime.utcnow()``, which is deprecated since Python 3.12.
    The tzinfo is stripped so the value (and its JSON serialization) has exactly
    the naive form that ``utcnow()`` produced.
    """
    # Local import: the module-level import only brings in ``datetime`` itself.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class Episode(BaseModel):
    """A complete episode/demonstration for GUI automation.

    This is the canonical format for storing and exchanging GUI trajectory data.
    All benchmark-specific formats should be converted to/from this format.

    Required fields are ``episode_id``, ``instruction`` and ``steps``. All other
    fields are optional; ``schema_version`` defaults to ``SCHEMA_VERSION`` and
    ``created_at`` defaults to the construction time as a naive UTC datetime.
    """

    # Schema metadata
    schema_version: str = Field(
        default=SCHEMA_VERSION, description="Schema version for compatibility checking"
    )

    # Episode identification
    episode_id: str = Field(..., description="Unique episode identifier")
    task_id: Optional[str] = Field(None, description="Task identifier (from benchmark)")

    # Task specification
    instruction: str = Field(..., description="Natural language task instruction")
    goal: Optional[str] = Field(
        None, description="Detailed goal description (if different from instruction)"
    )

    # Episode data
    steps: list[Step] = Field(..., description="Sequence of steps in the episode")

    # Outcome
    success: Optional[bool] = Field(
        None, description="Whether task was completed successfully"
    )
    final_reward: Optional[float] = Field(None, description="Final reward/score")

    # Provenance
    source: Optional[BenchmarkSource] = Field(
        None, description="Source benchmark/dataset"
    )
    source_file: Optional[str] = Field(None, description="Original source file path")

    # Metadata
    # The factory yields naive UTC without calling the deprecated utcnow().
    created_at: Optional[datetime] = Field(
        default_factory=_utc_now_naive, description="When episode was created/recorded"
    )
    agent_model: Optional[str] = Field(
        None, description="Model that generated this episode (e.g., 'gpt-4o')"
    )
    environment: Optional[str] = Field(
        None, description="Environment info (OS, browser, etc.)"
    )
    tags: Optional[list[str]] = Field(None, description="Tags for categorization")

    # Extension point for benchmark-specific data
    metadata: Optional[dict[str, Any]] = Field(
        None, description="Additional metadata from source"
    )

    @property
    def num_steps(self) -> int:
        """Number of steps in the episode."""
        return len(self.steps)

    @property
    def action_types(self) -> list[ActionType]:
        """Action type of every step, in step order (duplicates preserved)."""
        return [step.action.type for step in self.steps]

    def to_json(self, indent: int = 2) -> str:
        """Serialize the episode to a JSON string.

        Args:
            indent: Indentation width passed to ``model_dump_json``.
        """
        return self.model_dump_json(indent=indent)

    @classmethod
    def from_json(cls, json_str: str) -> "Episode":
        """Parse and validate an episode from a JSON string.

        Raises:
            pydantic.ValidationError: If the JSON does not satisfy the schema.
        """
        return cls.model_validate_json(json_str)

    @classmethod
    def json_schema(cls) -> dict[str, Any]:
        """Return the JSON Schema describing the Episode format."""
        return cls.model_json_schema()
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
# ============================================================================
|
|
401
|
+
# Utility Functions
|
|
402
|
+
# ============================================================================
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
def validate_episode(data: dict[str, Any]) -> tuple[bool, Optional[str]]:
    """Check whether a dictionary satisfies the Episode schema, without raising.

    Args:
        data: Candidate episode as a plain dictionary.

    Returns:
        ``(True, None)`` when validation succeeds, otherwise ``(False, message)``
        where ``message`` is the string form of the exception that was raised.
    """
    try:
        Episode.model_validate(data)
    except Exception as exc:
        # Deliberately broad: any failure is reported through the return value.
        return False, str(exc)
    return True, None
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def load_episode(path: Union[str, Path]) -> Episode:
    """Load and validate an episode from a JSON file.

    The file is read as UTF-8 regardless of the platform's default encoding,
    matching what ``save_episode`` writes.

    Args:
        path: Path to JSON file

    Returns:
        Episode instance. If the stored data has no ``source_file``, the returned
        copy has it set to ``str(path)``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        pydantic.ValidationError: If the JSON does not satisfy the schema.
    """
    path = Path(path)
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    episode = Episode.model_validate(data)

    # Record where the episode came from, without overwriting stored provenance.
    if episode.source_file is None:
        episode = episode.model_copy(update={"source_file": str(path)})

    return episode


def save_episode(episode: Episode, path: Union[str, Path], indent: int = 2) -> None:
    """Save an episode to a JSON file, creating parent directories as needed.

    The file is written as UTF-8 regardless of the platform's default encoding:
    the serialized JSON may contain non-ASCII text, which a non-UTF-8 locale
    encoding could fail to encode or encode non-portably.

    Args:
        episode: Episode to save
        path: Output path; an existing file is overwritten
        indent: JSON indentation
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    with open(path, "w", encoding="utf-8") as f:
        f.write(episode.to_json(indent=indent))
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def export_json_schema(path: Union[str, Path]) -> None:
    """Write the Episode JSON Schema to a file for documentation/tooling.

    Parent directories are created if missing, and an existing file is
    overwritten. The schema is written with two-space indentation.

    Args:
        path: Output path for schema file
    """
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    schema_text = json.dumps(Episode.json_schema(), indent=2)

    with open(out_path, "w") as handle:
        handle.write(schema_text)
|