openadapt_ml-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/grounding/base.py
@@ -0,0 +1,236 @@
+ """Base interface for grounding modules.
+
+ Grounding is the process of converting a natural language target description
+ (e.g., "the login button") into executable coordinates on a screenshot.
+
+ This module defines the abstract interface that all grounding strategies must implement,
+ enabling policy/grounding separation as described in the architecture document.
+ """
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     from PIL import Image
+
+
+ @dataclass
+ class RegionCandidate:
+     """A candidate region for action execution.
+
+     Represents a potential target location on the screen, with a confidence
+     score and optional metadata.
+
+     Attributes:
+         bbox: Bounding box as (x1, y1, x2, y2) in normalized [0, 1] coordinates.
+         centroid: Click point as (x, y) in normalized coordinates.
+         confidence: Confidence score in [0, 1], higher is better.
+         element_label: Optional label describing the element (e.g., "button", "input").
+         text_content: Optional text content of the element.
+         metadata: Additional grounding-specific data.
+     """
+
+     bbox: tuple[float, float, float, float]  # x1, y1, x2, y2 normalized
+     centroid: tuple[float, float]  # click point (x, y)
+     confidence: float
+     element_label: str | None = None
+     text_content: str | None = None
+     metadata: dict = field(default_factory=dict)
+
+     def __post_init__(self) -> None:
+         """Validate coordinates are in [0, 1] range."""
+         x1, y1, x2, y2 = self.bbox
+         cx, cy = self.centroid
+
+         for val in [x1, y1, x2, y2, cx, cy]:
+             if not 0 <= val <= 1:
+                 raise ValueError(f"Coordinates must be in [0, 1] range, got {val}")
+
+         if x1 > x2 or y1 > y2:
+             raise ValueError(f"Invalid bbox: x1 > x2 or y1 > y2: {self.bbox}")
+
+         if not 0 <= self.confidence <= 1:
+             raise ValueError(f"Confidence must be in [0, 1], got {self.confidence}")
+
+     @property
+     def area(self) -> float:
+         """Compute normalized area of the bounding box."""
+         x1, y1, x2, y2 = self.bbox
+         return (x2 - x1) * (y2 - y1)
+
+     def iou(self, other: "RegionCandidate") -> float:
+         """Compute Intersection over Union with another region.
+
+         Args:
+             other: Another RegionCandidate.
+
+         Returns:
+             IoU score in [0, 1].
+         """
+         x1, y1, x2, y2 = self.bbox
+         ox1, oy1, ox2, oy2 = other.bbox
+
+         # Intersection
+         ix1 = max(x1, ox1)
+         iy1 = max(y1, oy1)
+         ix2 = min(x2, ox2)
+         iy2 = min(y2, oy2)
+
+         if ix1 >= ix2 or iy1 >= iy2:
+             return 0.0
+
+         intersection = (ix2 - ix1) * (iy2 - iy1)
+         union = self.area + other.area - intersection
+
+         return intersection / union if union > 0 else 0.0
+
+     def contains_point(self, x: float, y: float) -> bool:
+         """Check if a point is inside the bounding box.
+
+         Args:
+             x: X coordinate in normalized [0, 1].
+             y: Y coordinate in normalized [0, 1].
+
+         Returns:
+             True if point is inside bbox.
+         """
+         x1, y1, x2, y2 = self.bbox
+         return x1 <= x <= x2 and y1 <= y <= y2
+
+
+ class GroundingModule(ABC):
+     """Abstract base class for grounding strategies.
+
+     A grounding module takes a screenshot and a natural language description
+     of the target element, and returns candidate regions where the target
+     might be located.
+
+     Implementations include:
+     - SoMGrounder: Uses pre-labeled element indices (for synthetic/SoM mode)
+     - CoordinateGrounder: Fine-tuned VLM regression
+     - DetectorGrounder: External detection (OmniParser, Gemini bbox API)
+     - AttentionGrounder: GUI-Actor style attention-based region selection
+
+     Example:
+         grounder = DetectorGrounder()
+         candidates = grounder.ground(screenshot, "the submit button", k=3)
+         best = candidates[0]  # Highest confidence
+         click(best.centroid[0], best.centroid[1])
+     """
+
+     @abstractmethod
+     def ground(
+         self,
+         image: "Image",
+         target_description: str,
+         k: int = 1,
+     ) -> list[RegionCandidate]:
+         """Locate regions matching the target description.
+
+         Args:
+             image: PIL Image of the screenshot to search.
+             target_description: Natural language description of the target
+                 element (e.g., "login button", "username field", "the red X").
+             k: Maximum number of candidates to return.
+
+         Returns:
+             List of candidate regions, sorted by confidence descending.
+             Returns empty list if no candidates found.
+         """
+         pass
+
+     def ground_batch(
+         self,
+         images: list["Image"],
+         target_descriptions: list[str],
+         k: int = 1,
+     ) -> list[list[RegionCandidate]]:
+         """Batch grounding for multiple images/targets.
+
+         Default implementation calls ground() for each pair.
+         Subclasses can override for more efficient batching.
+
+         Args:
+             images: List of PIL Images.
+             target_descriptions: List of target descriptions (same length as images).
+             k: Maximum candidates per image.
+
+         Returns:
+             List of candidate lists, one per input image.
+         """
+         if len(images) != len(target_descriptions):
+             raise ValueError("images and target_descriptions must have same length")
+
+         return [
+             self.ground(img, desc, k=k)
+             for img, desc in zip(images, target_descriptions)
+         ]
+
+     @property
+     def name(self) -> str:
+         """Return the name of this grounding module."""
+         return self.__class__.__name__
+
+     @property
+     def supports_batch(self) -> bool:
+         """Whether this module has optimized batch processing."""
+         return False
+
+
+ class OracleGrounder(GroundingModule):
+     """Oracle grounding using ground-truth bounding boxes.
+
+     Used for evaluation to measure policy performance independent of grounding.
+     Returns the ground-truth bbox as the only candidate with confidence 1.0.
+     """
+
+     def __init__(self) -> None:
+         """Initialize oracle grounder."""
+         self._ground_truth: dict[str, RegionCandidate] = {}
+
+     def set_ground_truth(
+         self,
+         target_description: str,
+         bbox: tuple[float, float, float, float],
+         centroid: tuple[float, float] | None = None,
+     ) -> None:
+         """Set ground truth for a target description.
+
+         Args:
+             target_description: The target to set ground truth for.
+             bbox: Ground truth bounding box (x1, y1, x2, y2).
+             centroid: Optional click point. If None, uses bbox center.
+         """
+         if centroid is None:
+             x1, y1, x2, y2 = bbox
+             centroid = ((x1 + x2) / 2, (y1 + y2) / 2)
+
+         self._ground_truth[target_description] = RegionCandidate(
+             bbox=bbox,
+             centroid=centroid,
+             confidence=1.0,
+             element_label="ground_truth",
+         )
+
+     def ground(
+         self,
+         image: "Image",
+         target_description: str,
+         k: int = 1,
+     ) -> list[RegionCandidate]:
+         """Return ground truth if available.
+
+         Args:
+             image: Screenshot (ignored, we use ground truth).
+             target_description: Target to look up.
+             k: Ignored (always returns 0 or 1 candidate).
+
+         Returns:
+             List containing ground truth candidate, or empty list.
+         """
+         if target_description in self._ground_truth:
+             return [self._ground_truth[target_description]]
+         return []
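
The geometry helpers on RegionCandidate can be sanity-checked in isolation. Below is a minimal sketch (an editor's illustration, not code shipped in the wheel) using two overlapping candidates and the validation, IoU, and point-test logic shown above:

from openadapt_ml.grounding.base import RegionCandidate

# Two overlapping boxes of equal size (all coordinates normalized to [0, 1]).
a = RegionCandidate(bbox=(0.1, 0.1, 0.3, 0.2), centroid=(0.2, 0.15), confidence=0.9)
b = RegionCandidate(bbox=(0.2, 0.1, 0.4, 0.2), centroid=(0.3, 0.15), confidence=0.7)

print(a.area)                        # ~0.02  (0.2 wide x 0.1 tall)
print(a.iou(b))                      # ~0.333 (intersection 0.01, union 0.03)
print(a.contains_point(0.25, 0.15))  # True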
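
A concrete grounder only has to implement ground(); ground_batch(), name, and supports_batch are inherited from the base class. A sketch of a minimal subclass (FixedCenterGrounder is a hypothetical name invented for illustration):

from openadapt_ml.grounding.base import GroundingModule, RegionCandidate

class FixedCenterGrounder(GroundingModule):
    """Toy grounder that always proposes the center of the screen (hypothetical)."""

    def ground(self, image, target_description, k=1):
        # Ignores the screenshot and description entirely; a real grounder
        # would inspect the image here.
        center = RegionCandidate(
            bbox=(0.4, 0.4, 0.6, 0.6),
            centroid=(0.5, 0.5),
            confidence=0.1,
            element_label="center_stub",
        )
        return [center][:k]

grounder = FixedCenterGrounder()
print(grounder.name)            # "FixedCenterGrounder"
print(grounder.supports_batch)  # False (base-class default)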
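
OracleGrounder keys its lookup on the exact target_description string and ignores the screenshot, which is what makes it useful for isolating policy errors from grounding errors. A sketch of the round trip (again an illustration; None stands in for a real PIL Image, which the oracle's ground() ignores):

from openadapt_ml.grounding.base import OracleGrounder

oracle = OracleGrounder()
oracle.set_ground_truth("the submit button", bbox=(0.7, 0.8, 0.9, 0.88))

hits = oracle.ground(None, "the submit button")
print(hits[0].centroid)    # ~(0.8, 0.84), the bbox center, since no centroid was given
print(hits[0].confidence)  # 1.0
print(oracle.ground(None, "unknown target"))  # []

Because lookups are exact-match on the description string, the descriptions registered with set_ground_truth must be replayed verbatim at evaluation time.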