openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Literal, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Known values for ``Action.type``. This is a typing-only alias: the
# ``Action`` dataclass annotates ``type`` as plain ``str`` and does not
# enforce membership.
ActionType = Literal[
    # Pointer actions
    "click", "double_click", "right_click", "drag", "scroll",
    # Keyboard actions
    "type",
    "key",  # single keypress, e.g. "Enter" or "Tab"
    # Non-input actions
    "wait", "done",
    "answer",  # for benchmarks that score by the final answer
    "failed",
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class Action:
    """A single GUI action taken by an agent or demonstrator.

    Coordinates are normalized to the range [0, 1] relative to the
    associated screenshot image's width/height.

    Supports both coordinate-based and element-based grounding ("grounding-first"
    approach where both are stored when available).

    Only ``type`` is required; every other field defaults to ``None`` and is
    filled in according to the kind of action. ``type`` is annotated as plain
    ``str``, so the values listed in ``ActionType`` are not enforced here.
    """

    # Action kind, e.g. "click", "type", "key" (see ActionType for known values).
    type: str
    # Pointer position in normalized [0, 1] coordinates (also the start of a drag).
    x: Optional[float] = None
    y: Optional[float] = None
    # Text payload, e.g. the string entered by a "type" action.
    text: Optional[str] = None
    # Original/unparsed action data; its schema depends on the producer — TODO confirm.
    raw: Optional[Dict[str, Any]] = None

    # Bounding box for click targets: (x_min, y_min, x_max, y_max) in normalized coords
    bbox: Optional[tuple[float, float, float, float]] = None

    # Element index for Set-of-Marks (SoM) style actions: CLICK([1]), TYPE([2], "text")
    element_index: Optional[int] = None

    # Element grounding (for benchmark compatibility)
    target_node_id: Optional[str] = None  # DOM/AX/UIA node ID
    target_role: Optional[str] = None  # "button", "textfield", etc.
    target_name: Optional[str] = None  # Accessible name

    # Keyboard actions
    key: Optional[str] = None  # Single key: "Enter", "Tab", "Escape"
    modifiers: Optional[List[str]] = None  # ["ctrl", "shift", "alt"]

    # Scroll actions
    scroll_direction: Optional[str] = None  # "up", "down", "left", "right"
    scroll_amount: Optional[float] = None  # Pixels or normalized (unit not fixed by this schema)

    # Drag actions - end coordinates (same normalized space as x/y)
    end_x: Optional[float] = None
    end_y: Optional[float] = None

    # Answer action (for benchmarks that score by answer)
    answer: Optional[str] = None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class Observation:
    """A single observation of the GUI state.

    Supports multiple observation modalities:
    - Visual: screenshot image
    - Structured UI: accessibility tree (UIA/AXTree/DOM)
    - Context: URL, window title, focused element

    Every field is optional and defaults to ``None``.
    """

    # Path to the screenshot image for this observation.
    image_path: Optional[str] = None
    # Free-form metadata; no schema is imposed by this class.
    meta: Optional[Dict[str, Any]] = None

    # Structured UI (format varies by platform)
    accessibility_tree: Optional[Dict[str, Any]] = None  # UIA/AXTree/DOM
    dom_html: Optional[str] = None  # Raw HTML for web tasks

    # Context
    url: Optional[str] = None  # For web tasks
    window_title: Optional[str] = None  # Active window title
    app_name: Optional[str] = None  # Active application
    focused_element: Optional[Dict[str, Any]] = None  # expected keys: {node_id, bbox, text}
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class Step:
    """One timestep in an episode: observation + action (+ optional thought)."""

    # Timestamp of this step. Units/epoch are not fixed by this class
    # (presumably seconds — TODO confirm). Validation expects it to be
    # non-negative and non-decreasing within an episode.
    t: float
    # GUI state observed at this step.
    observation: Observation
    # Action taken at this step.
    action: Action
    # Optional free-text reasoning attached to the action.
    thought: Optional[str] = None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass
class Episode:
    """A single workflow instance / task attempt.

    This is the primary training unit used by dataset builders and
    training loops.
    """

    # Episode identifier (validation requires it to be non-empty).
    id: str
    # Task goal text (validation requires it to be non-empty).
    goal: str
    # Ordered timesteps; default_factory gives each instance its own list.
    steps: List[Step] = field(default_factory=list)
    # Optional free-text summary of the episode.
    summary: Optional[str] = None
    # Outcome of the attempt; None presumably means unknown/unlabelled — TODO confirm.
    success: Optional[bool] = None
    # Identifier of the workflow this episode belongs to, if any.
    workflow_id: Optional[str] = None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass
class Session:
    """A container for one or more episodes plus session-level metadata."""

    # Session identifier (validation requires it to be non-empty).
    id: str
    # Episodes in this session; default_factory gives each instance its own list.
    episodes: List[Episode] = field(default_factory=list)
    # Free-form session-level metadata; no schema is imposed by this class.
    meta: Optional[Dict[str, Any]] = None
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""Schema validation utilities for openadapt-ml.
|
|
2
|
+
|
|
3
|
+
Validates that data conforms to the canonical Episode/Session schema
|
|
4
|
+
before training or processing.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations

from collections import Counter
from dataclasses import fields
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from openadapt_ml.schemas.sessions import Action, Episode, Observation, Session, Step
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ValidationError(Exception):
    """Raised when data fails schema validation.

    Attributes:
        message: Description of the problem.
        path: Location of the offending object (e.g. ``"step.action"``);
            an empty string means no location is attached.
        details: Extra lines appended, one per line, to the formatted message.
    """

    def __init__(self, message: str, path: str = "", details: Optional[List[str]] = None):
        self.message = message
        self.path = path
        self.details = details if details else []
        super().__init__(self._format_message())

    def _format_message(self) -> str:
        # "path: message" when a location is known, otherwise just the message.
        head = f"{self.path}: {self.message}" if self.path else self.message
        # Joining the head with the details is equivalent to appending
        # "\n " + "\n ".join(details), and yields just `head` when empty.
        return "\n ".join([head, *self.details])
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def validate_action(action: Action, path: str = "action") -> List[str]:
    """Validate an Action object.

    Pointer coordinates (``x``/``y``, plus ``end_x``/``end_y`` for drags) are
    expected to be normalized to [0, 1]; values outside that range produce a
    warning.

    Args:
        action: Action to validate.
        path: Location prefix used in warning and error messages.

    Returns:
        List of warnings (non-fatal issues).

    Raises:
        ValidationError: If ``action`` is not an Action, has an empty ``type``,
            or carries a coordinate that cannot be compared with a number.
    """
    warnings: List[str] = []

    if not isinstance(action, Action):
        raise ValidationError(f"Expected Action, got {type(action).__name__}", path)

    # Type is required
    if not action.type:
        raise ValidationError("Action type is required", path)

    # Coordinate validation for actions that target a screen position
    if action.type in ("click", "double_click", "right_click", "drag"):
        if action.x is None or action.y is None:
            # Only warn if no element_index either (SoM mode doesn't need coords)
            if action.element_index is None:
                # Name the actual action type rather than always saying "Click".
                warnings.append(
                    f"{path}: {action.type} action has no coordinates or element_index"
                )
        else:
            warnings.extend(_coordinate_range_warnings(path, x=action.x, y=action.y))

        # Drag additionally requires an end point
        if action.type == "drag":
            if action.end_x is None or action.end_y is None:
                if action.element_index is None:
                    warnings.append(f"{path}: Drag action missing end coordinates")
            else:
                # End coordinates live in the same normalized space as x/y.
                warnings.extend(
                    _coordinate_range_warnings(path, end_x=action.end_x, end_y=action.end_y)
                )

    # Type action requires text
    if action.type == "type" and not action.text:
        warnings.append(f"{path}: Type action has no text")

    # Key action requires key
    if action.type == "key" and not action.key:
        warnings.append(f"{path}: Key action has no key specified")

    return warnings


def _coordinate_range_warnings(path: str, **coords: Any) -> List[str]:
    """Return one warning per named coordinate lying outside [0, 1].

    Raises:
        ValidationError: If a coordinate cannot be compared with a float
            (e.g. a string), instead of letting a bare ``TypeError`` escape.
    """
    found: List[str] = []
    for name, value in coords.items():
        try:
            in_range = 0.0 <= value <= 1.0
        except TypeError:
            raise ValidationError(
                f"{name} coordinate must be numeric, got {type(value).__name__}", path
            ) from None
        if not in_range:
            found.append(f"{path}: {name} coordinate {value} outside [0, 1] range")
    return found
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def validate_observation(obs: Observation, path: str = "observation") -> List[str]:
    """Check an Observation for structural problems.

    Args:
        obs: Observation to check.
        path: Location prefix used in messages.

    Returns:
        Non-fatal warnings; empty when the observation looks fine.

    Raises:
        ValidationError: If ``obs`` is not an Observation, or if a truthy
            ``image_path`` is not a string.
    """
    if not isinstance(obs, Observation):
        raise ValidationError(f"Expected Observation, got {type(obs).__name__}", path)

    image = obs.image_path
    # A wrongly-typed image path is fatal, so reject it before gathering warnings.
    if image and not isinstance(image, str):
        raise ValidationError(f"image_path must be string, got {type(image).__name__}", path)

    # Neither a screenshot nor an accessibility tree: suspicious, but not fatal.
    lacks_both = image is None and obs.accessibility_tree is None
    return [f"{path}: No image_path or accessibility_tree"] if lacks_both else []
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def validate_step(step: Step, path: str = "step") -> List[str]:
    """Check a Step together with its nested Observation and Action.

    Args:
        step: Step to check.
        path: Location prefix used in messages; nested objects extend it.

    Returns:
        Non-fatal warnings from the step, its observation and its action.

    Raises:
        ValidationError: If ``step`` is not a Step, or a nested object has a
            fatal problem.
    """
    if not isinstance(step, Step):
        raise ValidationError(f"Expected Step, got {type(step).__name__}", path)

    # A negative timestamp is suspicious but not fatal.
    problems = [f"{path}: Negative timestamp {step.t}"] if step.t < 0 else []

    # Nested objects report under an extended path; observation comes first.
    problems += validate_observation(step.observation, f"{path}.observation")
    problems += validate_action(step.action, f"{path}.action")
    return problems
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def validate_episode(episode: Episode, check_images: bool = False) -> List[str]:
    """Validate an Episode object.

    Args:
        episode: Episode to validate.
        check_images: If True, verify image files exist on disk.

    Returns:
        List of warnings (non-fatal issues). These cover a missing step list,
        per-step issues, missing image files (only when ``check_images``), and
        timestamps that decrease between consecutive steps.

    Raises:
        ValidationError: If episode has fatal schema violations. These are:
            not an Episode, empty ``id`` or ``goal``, or a fatal issue in any
            step.
    """
    if not isinstance(episode, Episode):
        raise ValidationError(f"Expected Episode, got {type(episode).__name__}")

    # Required fields
    if not episode.id:
        raise ValidationError("Episode id is required")
    label = f"Episode '{episode.id}'"
    if not episode.goal:
        # The id is known here, so attach it to make the failing episode findable.
        raise ValidationError("Episode goal is required", label)

    if not episode.steps:
        return [f"{label}: No steps"]

    warnings: List[str] = []
    for i, step in enumerate(episode.steps):
        step_path = f"{label}.steps[{i}]"
        warnings.extend(validate_step(step, step_path))

        # Optional: check image files exist (relative paths resolve against the CWD)
        if check_images and step.observation.image_path:
            img_path = Path(step.observation.image_path)
            if not img_path.exists():
                warnings.append(f"{step_path}: Image not found: {img_path}")

    # Timestamps must not decrease between consecutive steps.
    steps = episode.steps
    for i, (prev, curr) in enumerate(zip(steps, steps[1:]), start=1):
        if curr.t < prev.t:
            warnings.append(f"{label}: Non-monotonic timestamps at steps {i - 1} and {i}")

    return warnings
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def validate_session(session: Session, check_images: bool = False) -> List[str]:
    """Check a Session and every Episode it contains.

    Args:
        session: Session to check.
        check_images: When True, also confirm each step's image file exists.

    Returns:
        Non-fatal warnings collected from the session and its episodes.

    Raises:
        ValidationError: If ``session`` is not a Session, has an empty ``id``,
            or any episode has a fatal problem.
    """
    if not isinstance(session, Session):
        raise ValidationError(f"Expected Session, got {type(session).__name__}")
    if not session.id:
        raise ValidationError("Session id is required")

    # An empty session is reported but not treated as fatal.
    if not session.episodes:
        return [f"Session '{session.id}': No episodes"]

    collected: List[str] = []
    for ep in session.episodes:
        collected.extend(validate_episode(ep, check_images=check_images))
    return collected
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def validate_episodes(episodes: List[Episode], check_images: bool = False) -> List[str]:
    """Validate a list of Episode objects.

    Args:
        episodes: List of episodes to validate.
        check_images: If True, verify image files exist on disk.

    Returns:
        List of warnings (non-fatal issues). An empty list yields
        ``["Empty episode list"]``.

    Raises:
        ValidationError: If ``episodes`` is not a list or any episode has
            fatal schema violations. An episode-level error that carries no
            location (e.g. a missing id) is re-raised with ``path`` set to
            ``episodes[i]``, so the offending entry can be identified.
    """
    if not isinstance(episodes, list):
        raise ValidationError(f"Expected list of Episodes, got {type(episodes).__name__}")

    if not episodes:
        return ["Empty episode list"]

    warnings: List[str] = []
    for i, episode in enumerate(episodes):
        try:
            ep_warnings = validate_episode(episode, check_images=check_images)
        except ValidationError as err:
            if err.path:
                # Already located (e.g. by episode id / step index); keep as-is.
                raise
            raise ValidationError(err.message, f"episodes[{i}]", err.details) from err
        warnings.extend(ep_warnings)

    return warnings
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def summarize_episodes(episodes: List[Episode]) -> Dict[str, Any]:
    """Generate a summary of episode statistics.

    Useful for quick sanity checks after loading data. The returned dict has
    the same keys for every input, including an empty one:

    - ``count``: number of episodes
    - ``total_steps``: total number of steps across all episodes
    - ``avg_steps_per_episode``: ``total_steps / count`` (``0.0`` when empty)
    - ``action_types``: plain dict mapping action type -> number of steps
    - ``goals``: goals of the first five episodes, as a sample
    """
    count = len(episodes)
    total_steps = sum(len(ep.steps) for ep in episodes)
    action_types = Counter(step.action.type for ep in episodes for step in ep.steps)

    return {
        "count": count,
        "total_steps": total_steps,
        # Guard the division so an empty input yields 0.0 instead of raising.
        "avg_steps_per_episode": total_steps / count if count else 0.0,
        "action_types": dict(action_types),
        "goals": [ep.goal for ep in episodes[:5]],  # First 5 goals as sample
    }
|
|
File without changes
|