openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Base interface for grounding modules.
|
|
2
|
+
|
|
3
|
+
Grounding is the process of converting a natural language target description
|
|
4
|
+
(e.g., "the login button") into executable coordinates on a screenshot.
|
|
5
|
+
|
|
6
|
+
This module defines the abstract interface that all grounding strategies must implement,
|
|
7
|
+
enabling policy/grounding separation as described in the architecture document.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from abc import ABC, abstractmethod
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
from typing import TYPE_CHECKING
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from PIL import Image
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class RegionCandidate:
    """A candidate region for action execution.

    Represents a potential target location on the screen, with confidence
    score and optional metadata.

    Attributes:
        bbox: Bounding box as (x1, y1, x2, y2) in normalized [0,1] coordinates.
        centroid: Click point as (x, y) in normalized coordinates.
        confidence: Confidence score in [0, 1], higher is better.
        element_label: Optional label describing the element (e.g., "button", "input").
        text_content: Optional text content of the element.
        metadata: Additional grounding-specific data.
    """

    bbox: tuple[float, float, float, float]  # x1, y1, x2, y2 normalized
    centroid: tuple[float, float]  # click point (x, y)
    confidence: float
    element_label: str | None = None
    text_content: str | None = None
    # FIX: annotated `dict` (was `dict | None`) to agree with the
    # default_factory — the field always defaults to a fresh dict.
    metadata: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate coordinates are in [0, 1] range.

        Raises:
            ValueError: If any bbox/centroid coordinate or the confidence is
                outside [0, 1], or if the bbox is inverted (x1 > x2 or y1 > y2).
        """
        x1, y1, x2, y2 = self.bbox
        cx, cy = self.centroid

        for val in (x1, y1, x2, y2, cx, cy):
            if not 0 <= val <= 1:
                raise ValueError(f"Coordinates must be in [0, 1] range, got {val}")

        if x1 > x2 or y1 > y2:
            raise ValueError(f"Invalid bbox: x1 > x2 or y1 > y2: {self.bbox}")

        if not 0 <= self.confidence <= 1:
            raise ValueError(f"Confidence must be in [0, 1], got {self.confidence}")

    @property
    def area(self) -> float:
        """Compute normalized area of the bounding box."""
        x1, y1, x2, y2 = self.bbox
        return (x2 - x1) * (y2 - y1)

    def iou(self, other: "RegionCandidate") -> float:
        """Compute Intersection over Union with another region.

        Args:
            other: Another RegionCandidate.

        Returns:
            IoU score in [0, 1].
        """
        x1, y1, x2, y2 = self.bbox
        ox1, oy1, ox2, oy2 = other.bbox

        # Intersection rectangle; empty when the boxes do not overlap.
        ix1 = max(x1, ox1)
        iy1 = max(y1, oy1)
        ix2 = min(x2, ox2)
        iy2 = min(y2, oy2)

        if ix1 >= ix2 or iy1 >= iy2:
            return 0.0

        intersection = (ix2 - ix1) * (iy2 - iy1)
        union = self.area + other.area - intersection

        # union can only be 0 for degenerate (zero-area) boxes.
        return intersection / union if union > 0 else 0.0

    def contains_point(self, x: float, y: float) -> bool:
        """Check if a point is inside the bounding box.

        Args:
            x: X coordinate in normalized [0, 1].
            y: Y coordinate in normalized [0, 1].

        Returns:
            True if point is inside bbox (edges inclusive).
        """
        x1, y1, x2, y2 = self.bbox
        return x1 <= x <= x2 and y1 <= y <= y2
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class GroundingModule(ABC):
    """Abstract base class for grounding strategies.

    A grounding module takes a screenshot and a natural language description
    of the target element, and returns candidate regions where the target
    might be located.

    Implementations include:
    - SoMGrounder: Uses pre-labeled element indices (for synthetic/SoM mode)
    - CoordinateGrounder: Fine-tuned VLM regression
    - DetectorGrounder: External detection (OmniParser, Gemini bbox API)
    - AttentionGrounder: GUI-Actor style attention-based region selection

    Example:
        grounder = DetectorGrounder()
        candidates = grounder.ground(screenshot, "the submit button", k=3)
        best = candidates[0]  # Highest confidence
        click(best.centroid[0], best.centroid[1])
    """

    @abstractmethod
    def ground(
        self,
        image: "Image",
        target_description: str,
        k: int = 1,
    ) -> list[RegionCandidate]:
        """Locate regions matching the target description.

        Args:
            image: PIL Image of the screenshot to search.
            target_description: Natural language description of the target
                element (e.g., "login button", "username field", "the red X").
            k: Maximum number of candidates to return.

        Returns:
            List of candidate regions, sorted by confidence descending.
            Returns empty list if no candidates found.
        """

    def ground_batch(
        self,
        images: list["Image"],
        target_descriptions: list[str],
        k: int = 1,
    ) -> list[list[RegionCandidate]]:
        """Batch grounding for multiple images/targets.

        Default implementation calls ground() for each pair.
        Subclasses can override for more efficient batching.

        Args:
            images: List of PIL Images.
            target_descriptions: List of target descriptions (same length as images).
            k: Maximum candidates per image.

        Returns:
            List of candidate lists, one per input image.

        Raises:
            ValueError: If the two input lists have different lengths.
        """
        if len(images) != len(target_descriptions):
            raise ValueError("images and target_descriptions must have same length")

        results: list[list[RegionCandidate]] = []
        for screenshot, description in zip(images, target_descriptions):
            results.append(self.ground(screenshot, description, k=k))
        return results

    @property
    def name(self) -> str:
        """Return the name of this grounding module."""
        return type(self).__name__

    @property
    def supports_batch(self) -> bool:
        """Whether this module has optimized batch processing."""
        return False
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class OracleGrounder(GroundingModule):
    """Oracle grounding using ground-truth bounding boxes.

    Used for evaluation to measure policy performance independent of grounding.
    Returns the ground-truth bbox as the only candidate with confidence 1.0.
    """

    def __init__(self) -> None:
        """Initialize oracle grounder."""
        # Maps target description -> its ground-truth candidate.
        self._ground_truth: dict[str, RegionCandidate] = {}

    def set_ground_truth(
        self,
        target_description: str,
        bbox: tuple[float, float, float, float],
        centroid: tuple[float, float] | None = None,
    ) -> None:
        """Set ground truth for a target description.

        Args:
            target_description: The target to set ground truth for.
            bbox: Ground truth bounding box (x1, y1, x2, y2).
            centroid: Optional click point. If None, uses bbox center.
        """
        x1, y1, x2, y2 = bbox
        # Default the click point to the geometric center of the bbox.
        click_point = centroid if centroid is not None else ((x1 + x2) / 2, (y1 + y2) / 2)

        self._ground_truth[target_description] = RegionCandidate(
            bbox=bbox,
            centroid=click_point,
            confidence=1.0,
            element_label="ground_truth",
        )

    def ground(
        self,
        image: "Image",
        target_description: str,
        k: int = 1,
    ) -> list[RegionCandidate]:
        """Return ground truth if available.

        Args:
            image: Screenshot (ignored, we use ground truth).
            target_description: Target to look up.
            k: Ignored (always returns 0 or 1 candidate).

        Returns:
            List containing ground truth candidate, or empty list.
        """
        match = self._ground_truth.get(target_description)
        return [match] if match is not None else []