openadapt-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,122 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Literal, Optional
5
+
6
+
7
# Closed set of action kinds recognised by the schema. Note that
# ``Action.type`` is annotated as a plain ``str``, so this Literal documents the
# expected values (and helps type checkers) but is not enforced when an Action
# is constructed.
ActionType = Literal[
    "click",
    "double_click",
    "right_click",
    "drag",  # Pointer drag; the end position lives in Action.end_x / Action.end_y
    "scroll",
    "type",  # Text entry; the text lives in Action.text
    "key",  # Single keypress (e.g., "Enter", "Tab")
    "wait",
    "done",
    "answer",  # For benchmarks that score by final answer
    "failed",
]
20
+
21
+
22
@dataclass
class Action:
    """A single GUI action taken by an agent or demonstrator.

    Coordinates are normalized to the range [0, 1] relative to the
    associated screenshot image's width/height.

    Supports both coordinate-based and element-based grounding ("grounding-first"
    approach where both are stored when available).

    Only ``type`` is required; every other field defaults to ``None`` and is
    filled in only for the action kinds that use it. ``type`` is a plain
    ``str`` (not ``ActionType``), and dataclasses perform no runtime type
    checks, so no field value is validated at construction time.
    """

    # Action kind, normally one of the ActionType literals ("click", "type", ...).
    type: str
    # Primary pointer position, normalized to [0, 1] of the screenshot size.
    x: Optional[float] = None
    y: Optional[float] = None
    # Text payload, e.g. the string entered by a "type" action.
    text: Optional[str] = None
    # Unparsed source record for this action; its format is not defined here.
    raw: Optional[Dict[str, Any]] = None

    # Bounding box for click targets: (x_min, y_min, x_max, y_max) in normalized coords
    bbox: Optional[tuple[float, float, float, float]] = None

    # Element index for Set-of-Marks (SoM) style actions: CLICK([1]), TYPE([2], "text")
    element_index: Optional[int] = None

    # Element grounding (for benchmark compatibility)
    target_node_id: Optional[str] = None  # DOM/AX/UIA node ID
    target_role: Optional[str] = None  # "button", "textfield", etc.
    target_name: Optional[str] = None  # Accessible name

    # Keyboard actions
    key: Optional[str] = None  # Single key: "Enter", "Tab", "Escape"
    modifiers: Optional[List[str]] = None  # ["ctrl", "shift", "alt"]

    # Scroll actions
    scroll_direction: Optional[str] = None  # "up", "down", "left", "right"
    scroll_amount: Optional[float] = None  # Pixels or normalized

    # Drag actions - end coordinates (presumably normalized like x/y — confirm)
    end_x: Optional[float] = None
    end_y: Optional[float] = None

    # Answer action (for benchmarks that score by answer)
    answer: Optional[str] = None
+
65
+
66
@dataclass
class Observation:
    """A single observation of the GUI state.

    Supports multiple observation modalities:
    - Visual: screenshot image
    - Structured UI: accessibility tree (UIA/AXTree/DOM)
    - Context: URL, window title, focused element

    Every field is optional and defaults to ``None``; which ones are populated
    depends on the data source.
    """

    # Path to the screenshot file for this observation.
    image_path: Optional[str] = None
    # Free-form extra metadata; schema not defined here.
    meta: Optional[Dict[str, Any]] = None

    # Structured UI (format varies by platform)
    accessibility_tree: Optional[Dict[str, Any]] = None  # UIA/AXTree/DOM
    dom_html: Optional[str] = None  # Raw HTML for web tasks

    # Context
    url: Optional[str] = None  # For web tasks
    window_title: Optional[str] = None  # Active window title
    app_name: Optional[str] = None  # Active application
    focused_element: Optional[Dict[str, Any]] = None  # {node_id, bbox, text}
+
89
+
90
@dataclass
class Step:
    """One timestep in an episode: observation + action (+ optional thought)."""

    # Timestamp of this step. The unit is not specified here (presumably
    # seconds — confirm). Schema validation expects it to be non-negative and
    # non-decreasing across an episode's steps.
    t: float
    # GUI state observed at this step.
    observation: Observation
    # Action taken in response to the observation.
    action: Action
    # Optional free-text reasoning associated with the action.
    thought: Optional[str] = None
+
99
+
100
@dataclass
class Episode:
    """A single workflow instance / task attempt.

    This is the primary training unit used by dataset builders and
    training loops.
    """

    # Identifier of the episode; schema validation requires it to be non-empty.
    id: str
    # Natural-language task goal; schema validation requires it to be non-empty.
    goal: str
    # Ordered timesteps; default_factory gives each Episode its own list.
    steps: List[Step] = field(default_factory=list)
    # Optional free-text summary of the episode.
    summary: Optional[str] = None
    # Task outcome; None presumably means "unknown / not scored" — confirm.
    success: Optional[bool] = None
    # Optional identifier grouping episodes of the same workflow (semantics not defined here).
    workflow_id: Optional[str] = None
+
115
+
116
@dataclass
class Session:
    """A container for one or more episodes plus session-level metadata."""

    # Identifier of the session; schema validation requires it to be non-empty.
    id: str
    # Episodes recorded in this session; default_factory gives each Session its own list.
    episodes: List[Episode] = field(default_factory=list)
    # Free-form session-level metadata; schema not defined here.
    meta: Optional[Dict[str, Any]] = None
@@ -0,0 +1,252 @@
1
+ """Schema validation utilities for openadapt-ml.
2
+
3
+ Validates that data conforms to the canonical Episode/Session schema
4
+ before training or processing.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import fields
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List, Optional, Union
12
+
13
+ from openadapt_ml.schemas.sessions import Action, Episode, Observation, Session, Step
14
+
15
+
16
class ValidationError(Exception):
    """Exception signalling that data violates the canonical schema.

    Attributes:
        message: Human-readable description of the violation.
        path: Location of the offending value (empty string when unknown).
        details: Additional lines rendered beneath the main message.
    """

    def __init__(self, message: str, path: str = "", details: Optional[List[str]] = None):
        self.message = message
        self.path = path
        self.details = details if details else []
        super().__init__(self._format_message())

    def _format_message(self) -> str:
        # Prefix with the location only when one was supplied.
        head = self.message if not self.path else f"{self.path}: {self.message}"
        if not self.details:
            return head
        sep = "\n "
        return head + sep + sep.join(self.details)
30
+
31
+
32
def validate_action(action: Action, path: str = "action") -> List[str]:
    """Validate an Action object.

    Args:
        action: Action to validate.
        path: Location prefix used in warning and error messages.

    Returns:
        List of warnings (non-fatal issues).

    Raises:
        ValidationError: If ``action`` is not an Action or has an empty type.
    """
    warnings: List[str] = []

    if not isinstance(action, Action):
        raise ValidationError(f"Expected Action, got {type(action).__name__}", path)

    # Type is required
    if not action.type:
        raise ValidationError("Action type is required", path)

    # Coordinate validation for pointer (click/drag) actions
    if action.type in ("click", "double_click", "right_click", "drag"):
        # Name the action kind correctly in the message (drag is not a click).
        label = "Drag" if action.type == "drag" else "Click"
        if action.x is None or action.y is None:
            # Only warn if no element_index either (SoM mode doesn't need coords)
            if action.element_index is None:
                warnings.append(f"{path}: {label} action has no coordinates or element_index")
        else:
            # Validate coordinate range
            for name, value in (("x", action.x), ("y", action.y)):
                if not (0.0 <= value <= 1.0):
                    warnings.append(f"{path}: {name} coordinate {value} outside [0, 1] range")

    # Drag requires end coordinates, which must also be normalized
    if action.type == "drag":
        if action.end_x is None or action.end_y is None:
            if action.element_index is None:
                warnings.append(f"{path}: Drag action missing end coordinates")
        else:
            for name, value in (("end_x", action.end_x), ("end_y", action.end_y)):
                if not (0.0 <= value <= 1.0):
                    warnings.append(f"{path}: {name} coordinate {value} outside [0, 1] range")

    # Type action requires text
    if action.type == "type" and not action.text:
        warnings.append(f"{path}: Type action has no text")

    # Key action requires key
    if action.type == "key" and not action.key:
        warnings.append(f"{path}: Key action has no key specified")

    return warnings
74
+
75
+
76
def validate_observation(obs: Observation, path: str = "observation") -> List[str]:
    """Validate an Observation object.

    Args:
        obs: Observation to validate.
        path: Location prefix used in warning and error messages.

    Returns:
        List of warnings (non-fatal issues).

    Raises:
        ValidationError: If ``obs`` is not an Observation, or ``image_path`` is
            set to a non-string value.
    """
    warnings: List[str] = []

    if not isinstance(obs, Observation):
        raise ValidationError(f"Expected Observation, got {type(obs).__name__}", path)

    # image_path, when set at all, must be a string. Checking `is not None`
    # (rather than truthiness) also catches falsy non-strings such as 0 or [].
    if obs.image_path is not None and not isinstance(obs.image_path, str):
        raise ValidationError(f"image_path must be string, got {type(obs.image_path).__name__}", path)

    # At least one of image_path or accessibility_tree should be present.
    # An empty image_path cannot point at a screenshot, so it counts as missing.
    if not obs.image_path and obs.accessibility_tree is None:
        warnings.append(f"{path}: No image_path or accessibility_tree")

    return warnings
95
+
96
+
97
def validate_step(step: Step, path: str = "step") -> List[str]:
    """Validate a Step object, including its observation and action.

    Args:
        step: Step to validate.
        path: Location prefix used in warning and error messages.

    Returns:
        List of warnings (non-fatal issues): timestamp warnings first, then
        observation warnings, then action warnings.

    Raises:
        ValidationError: If ``step`` is not a Step, its timestamp is not
            comparable to a number, or its observation/action is invalid.
    """
    warnings: List[str] = []

    if not isinstance(step, Step):
        raise ValidationError(f"Expected Step, got {type(step).__name__}", path)

    # Timestamp should be non-negative. A None/str timestamp would otherwise
    # escape as a bare TypeError; surface it as a schema error instead.
    try:
        is_negative = step.t < 0
    except TypeError as err:
        raise ValidationError(
            f"Timestamp must be a number, got {type(step.t).__name__}", path
        ) from err
    if is_negative:
        warnings.append(f"{path}: Negative timestamp {step.t}")

    # Validate nested objects
    warnings.extend(validate_observation(step.observation, f"{path}.observation"))
    warnings.extend(validate_action(step.action, f"{path}.action"))

    return warnings
116
+
117
+
118
def validate_episode(episode: Episode, check_images: bool = False) -> List[str]:
    """Validate an Episode object and all of its steps.

    Args:
        episode: Episode to validate.
        check_images: If True, verify image files exist on disk.

    Returns:
        List of warnings (non-fatal issues). Per-step warnings come first, in
        step order, followed by any non-monotonic timestamp warnings.

    Raises:
        ValidationError: If the episode has fatal schema violations (wrong
            type, empty id/goal) or any of its steps does.
    """
    warnings: List[str] = []

    if not isinstance(episode, Episode):
        raise ValidationError(f"Expected Episode, got {type(episode).__name__}")

    # Required fields
    if not episode.id:
        raise ValidationError("Episode id is required")
    if not episode.goal:
        raise ValidationError("Episode goal is required")

    label = f"Episode '{episode.id}'"
    steps = episode.steps

    # Stop here for an empty (or None) step list so the passes below never
    # iterate over or call len() on a missing list.
    if not steps:
        warnings.append(f"{label}: No steps")
        return warnings

    for i, step in enumerate(steps):
        step_path = f"{label}.steps[{i}]"
        warnings.extend(validate_step(step, step_path))

        # Optional: check image files exist
        if check_images and step.observation.image_path:
            img_path = Path(step.observation.image_path)
            if not img_path.exists():
                warnings.append(f"{step_path}: Image not found: {img_path}")

    # Check timestamps are monotonic (non-decreasing between consecutive steps)
    for i in range(1, len(steps)):
        if steps[i].t < steps[i - 1].t:
            warnings.append(f"{label}: Non-monotonic timestamps at steps {i-1} and {i}")

    return warnings
165
+
166
+
167
def validate_session(session: Session, check_images: bool = False) -> List[str]:
    """Validate a Session object and every episode it contains.

    Args:
        session: Session to validate.
        check_images: If True, verify image files exist on disk.

    Returns:
        List of warnings (non-fatal issues), in episode order.

    Raises:
        ValidationError: If session has fatal schema violations (wrong type,
            empty id) or any contained episode does.
    """
    if not isinstance(session, Session):
        raise ValidationError(f"Expected Session, got {type(session).__name__}")

    # Required fields
    if not session.id:
        raise ValidationError("Session id is required")

    # Episodes validation
    if not session.episodes:
        return [f"Session '{session.id}': No episodes"]

    warnings: List[str] = []
    for episode in session.episodes:
        warnings.extend(validate_episode(episode, check_images=check_images))

    return warnings
198
+
199
+
200
def validate_episodes(episodes: List[Episode], check_images: bool = False) -> List[str]:
    """Validate a list of Episode objects.

    Args:
        episodes: Episodes to validate. A tuple is accepted as well as a list.
        check_images: If True, verify image files exist on disk.

    Returns:
        List of warnings (non-fatal issues), in episode order.

    Raises:
        ValidationError: If ``episodes`` is not a list/tuple, or any episode has
            fatal schema violations.
    """
    # Tuples are accepted too: they are ordered, sized sequences just like lists.
    if not isinstance(episodes, (list, tuple)):
        raise ValidationError(f"Expected list of Episodes, got {type(episodes).__name__}")

    if not episodes:
        return ["Empty episode list"]

    warnings: List[str] = []
    for episode in episodes:
        warnings.extend(validate_episode(episode, check_images=check_images))

    return warnings
227
+
228
+
229
def summarize_episodes(episodes: List[Episode], max_goals: int = 5) -> Dict[str, Any]:
    """Generate a summary of episode statistics.

    Useful for quick sanity checks after loading data.

    Args:
        episodes: Episodes to summarize (``None``/empty yields a zeroed summary).
        max_goals: Number of leading episode goals to include as a sample.

    Returns:
        Dict with keys ``count``, ``total_steps``, ``avg_steps_per_episode``,
        ``action_types`` (action type -> occurrence count) and ``goals`` (the
        first ``max_goals`` goals). The same keys are present for empty input,
        so callers can index the result without special-casing it.
    """
    if not episodes:
        return {
            "count": 0,
            "total_steps": 0,
            "avg_steps_per_episode": 0.0,
            "action_types": {},
            "goals": [],
        }

    action_types: Dict[str, int] = {}
    total_steps = 0

    for ep in episodes:
        total_steps += len(ep.steps)
        for step in ep.steps:
            action_type = step.action.type
            action_types[action_type] = action_types.get(action_type, 0) + 1

    return {
        "count": len(episodes),
        "total_steps": total_steps,
        "avg_steps_per_episode": total_steps / len(episodes),
        "action_types": action_types,
        # Clamp so a negative value cannot turn the slice into "all but the last N".
        "goals": [ep.goal for ep in episodes[: max(max_goals, 0)]],
    }
File without changes