openadapt_ml-0.1.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in the public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/grounding/detector.py

@@ -0,0 +1,570 @@
+"""Detector-based grounding using external vision APIs.
+
+This module provides grounding implementations that use external element detection
+services (Gemini, OmniParser) to locate UI elements on screenshots.
+
+Functions:
+    extract_ui_elements: Extract all interactive UI elements from a screenshot
+    overlay_element_marks: Overlay numbered labels (Set-of-Marks) on elements
+"""
+
+from __future__ import annotations
+
+import base64
+import io
+import json
+import re
+from typing import TYPE_CHECKING
+
+from openadapt_ml.config import settings
+from openadapt_ml.grounding.base import GroundingModule, RegionCandidate
+
+if TYPE_CHECKING:
+    from PIL import Image, ImageDraw, ImageFont
+
+
+class GeminiGrounder(GroundingModule):
+    """Grounding using Google Gemini's vision capabilities.
+
+    Uses Gemini to identify UI elements matching a description and return
+    their bounding boxes.
+
+    Requires:
+        - GOOGLE_API_KEY environment variable
+        - google-generativeai package: pip install google-generativeai
+
+    Example:
+        grounder = GeminiGrounder()
+        candidates = grounder.ground(screenshot, "the login button")
+    """
+
+    def __init__(
+        self,
+        model: str = "gemini-2.5-flash",
+        api_key: str | None = None,
+    ) -> None:
+        """Initialize Gemini grounder.
+
+        Args:
+            model: Gemini model to use. Options include:
+                - "gemini-2.5-flash" (fast, good for grounding)
+                - "gemini-2.5-pro" (higher quality)
+                - "gemini-3-pro-preview" (most capable)
+            api_key: Google API key. If None, uses GOOGLE_API_KEY from settings.
+        """
+        self._model_name = model
+        self._api_key = api_key or settings.google_api_key
+        self._model = None
+
+    def _get_model(self):
+        """Lazy-load the Gemini model."""
+        if self._model is None:
+            try:
+                import google.generativeai as genai
+            except ImportError as e:
+                raise ImportError(
+                    "google-generativeai is required. Install with: "
+                    "pip install google-generativeai"
+                ) from e
+
+            if not self._api_key:
+                raise ValueError(
+                    "GOOGLE_API_KEY environment variable not set. "
+                    "Get an API key from https://makersuite.google.com/app/apikey"
+                )
+
+            genai.configure(api_key=self._api_key)
+            self._model = genai.GenerativeModel(self._model_name)
+
+        return self._model
+
+    def _image_to_base64(self, image: "Image") -> str:
+        """Convert PIL Image to base64 string."""
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+    def _parse_bbox_response(
+        self,
+        response_text: str,
+        image_width: int,
+        image_height: int,
+    ) -> list[RegionCandidate]:
+        """Parse Gemini's bbox response into RegionCandidates.
+
+        Args:
+            response_text: Raw text response from Gemini.
+            image_width: Image width for normalization.
+            image_height: Image height for normalization.
+
+        Returns:
+            List of RegionCandidate objects.
+        """
+        candidates = []
+
+        # Try to parse JSON from the response
+        # Look for JSON array or object in the response
+        json_match = re.search(r'\[[\s\S]*\]|\{[\s\S]*\}', response_text)
+        if not json_match:
+            return candidates
+
+        try:
+            data = json.loads(json_match.group())
+
+            # Handle both single object and array
+            if isinstance(data, dict):
+                data = [data]
+
+            for item in data:
+                # Extract bbox - handle various formats
+                bbox = item.get("bbox") or item.get("bounding_box") or item.get("box")
+                if not bbox:
+                    # Try to get individual coordinates
+                    if all(k in item for k in ["x1", "y1", "x2", "y2"]):
+                        bbox = [item["x1"], item["y1"], item["x2"], item["y2"]]
+                    elif all(k in item for k in ["x", "y", "width", "height"]):
+                        x, y, w, h = item["x"], item["y"], item["width"], item["height"]
+                        bbox = [x, y, x + w, y + h]
+                    else:
+                        continue
+
+                # Normalize coordinates
+                if len(bbox) == 4:
+                    x1, y1, x2, y2 = bbox
+
+                    # Check if already normalized (all values <= 1)
+                    if all(0 <= v <= 1 for v in [x1, y1, x2, y2]):
+                        norm_bbox = (x1, y1, x2, y2)
+                    else:
+                        # Normalize pixel coordinates
+                        norm_bbox = (
+                            x1 / image_width,
+                            y1 / image_height,
+                            x2 / image_width,
+                            y2 / image_height,
+                        )
+
+                    # Clamp to valid range
+                    norm_bbox = tuple(max(0, min(1, v)) for v in norm_bbox)
+
+                    # Compute centroid
+                    cx = (norm_bbox[0] + norm_bbox[2]) / 2
+                    cy = (norm_bbox[1] + norm_bbox[3]) / 2
+
+                    # Get confidence (default to 0.8 if not provided)
+                    confidence = item.get("confidence", 0.8)
+
+                    candidates.append(
+                        RegionCandidate(
+                            bbox=norm_bbox,
+                            centroid=(cx, cy),
+                            confidence=confidence,
+                            element_label=item.get("label") or item.get("type"),
+                            text_content=item.get("text"),
+                            metadata={"raw": item},
+                        )
+                    )
+
+        except (json.JSONDecodeError, KeyError, TypeError):
+            pass
+
+        return candidates
+
+    def ground(
+        self,
+        image: "Image",
+        target_description: str,
+        k: int = 1,
+    ) -> list[RegionCandidate]:
+        """Locate regions matching the target description using Gemini.
+
+        Args:
+            image: PIL Image of the screenshot.
+            target_description: Natural language description of the target.
+            k: Maximum number of candidates to return.
+
+        Returns:
+            List of RegionCandidate objects sorted by confidence.
+        """
+        model = self._get_model()
+
+        # Include image dimensions in prompt for accurate coordinate detection
+        prompt = f"""Analyze this screenshot and find the UI element matching this description: "{target_description}"
+
+The image is {image.width} pixels wide and {image.height} pixels tall.
+
+Return a JSON array with the bounding box(es) of matching elements. Each element should have:
+- "bbox": [x1, y1, x2, y2] in pixel coordinates (top-left to bottom-right)
+- "confidence": float between 0 and 1
+- "label": element type (button, input, link, etc.)
+- "text": visible text content if any
+
+IMPORTANT: Use exact pixel coordinates based on the image dimensions provided above.
+
+Return up to {k} best matches. If no match found, return an empty array [].
+
+Example response format:
+[{{"bbox": [100, 200, 250, 240], "confidence": 0.95, "label": "button", "text": "Submit"}}]
+
+Return ONLY the JSON array, no other text."""
+
+        try:
+            # Create content with image
+            import google.generativeai as genai
+
+            response = model.generate_content(
+                [prompt, image],
+                generation_config=genai.GenerationConfig(
+                    temperature=0.1,
+                    max_output_tokens=1024,
+                ),
+            )
+
+            candidates = self._parse_bbox_response(
+                response.text,
+                image.width,
+                image.height,
+            )
+
+            # Sort by confidence and limit to k
+            candidates.sort(key=lambda c: c.confidence, reverse=True)
+            return candidates[:k]
+
+        except Exception as e:
+            # Log error but don't crash
+            print(f"Gemini grounding error: {e}")
+            return []
+
+    @property
+    def supports_batch(self) -> bool:
+        """Gemini doesn't have optimized batch processing."""
+        return False
+
+
+def extract_ui_elements(
+    screenshot: "Image",
+    model_name: str = "gemini-2.0-flash",
+    api_key: str | None = None,
+) -> list[dict]:
+    """Extract all interactive UI elements from a screenshot using Gemini.
+
+    This function uses Gemini's vision capabilities to detect and extract
+    all interactive UI elements (buttons, text fields, links, etc.) with
+    their bounding boxes. Useful for Set-of-Marks (SoM) processing.
+
+    Args:
+        screenshot: PIL Image of the screenshot to analyze.
+        model_name: Gemini model to use (default: "gemini-2.0-flash").
+        api_key: Google API key. If None, uses GOOGLE_API_KEY from settings.
+
+    Returns:
+        List of element dictionaries with format:
+        {
+            "id": int,              # Sequential ID starting at 1
+            "label": str,           # Descriptive name (e.g., "Login button")
+            "bbox": [x1,y1,x2,y2],  # Normalized coordinates [0,1]
+            "type": str,            # Element type (button, text_field, etc.)
+            "text": str,            # Visible text content (optional)
+        }
+
+    Example:
+        >>> from PIL import Image
+        >>> img = Image.open("login.png")
+        >>> elements = extract_ui_elements(img)
+        >>> print(elements[0])
+        {
+            "id": 1,
+            "label": "Username text field",
+            "bbox": [0.25, 0.30, 0.75, 0.38],
+            "type": "text_field",
+            "text": ""
+        }
+
+    Raises:
+        ImportError: If google-generativeai package not installed.
+        ValueError: If GOOGLE_API_KEY not set.
+    """
+    try:
+        import google.generativeai as genai
+    except ImportError as e:
+        raise ImportError(
+            "google-generativeai is required. Install with: "
+            "pip install google-generativeai"
+        ) from e
+
+    api_key = api_key or settings.google_api_key
+    if not api_key:
+        raise ValueError(
+            "GOOGLE_API_KEY environment variable not set. "
+            "Get an API key from https://makersuite.google.com/app/apikey"
+        )
+
+    genai.configure(api_key=api_key)
+    model = genai.GenerativeModel(model_name)
+
+    prompt = f"""Analyze this screenshot and identify ALL interactive UI elements.
+
+The image is {screenshot.width} pixels wide and {screenshot.height} pixels tall.
+
+For each interactive element (buttons, text fields, links, checkboxes, dropdowns, icons, tabs, menu items), output a JSON object with:
+
+- "id": Sequential integer starting at 1
+- "label": Descriptive name (e.g., "Login button", "Username text field", "Submit icon")
+- "bbox": Bounding box as [x1, y1, x2, y2] in pixel coordinates (top-left to bottom-right)
+- "type": One of: "button", "text_field", "checkbox", "link", "icon", "dropdown", "tab", "menu_item", "other"
+- "text": Visible text content if any (empty string if no text)
+
+IMPORTANT:
+1. Use exact pixel coordinates based on the image dimensions provided above
+2. Include ALL interactive elements you can see, even if they're small
+3. Order elements from top-to-bottom, left-to-right
+4. Return ONLY a valid JSON array, no markdown formatting, no explanation
+
+Example output format:
+[
+{{"id": 1, "label": "Username text field", "bbox": [100, 150, 400, 185], "type": "text_field", "text": ""}},
+{{"id": 2, "label": "Password text field", "bbox": [100, 200, 400, 235], "type": "text_field", "text": ""}},
+{{"id": 3, "label": "Login button", "bbox": [200, 260, 300, 295], "type": "button", "text": "Login"}}
+]"""
+
+    try:
+        response = model.generate_content(
+            [prompt, screenshot],
+            generation_config=genai.GenerationConfig(
+                temperature=0.1,
+                max_output_tokens=4096,  # More tokens for many elements
+            ),
+        )
+
+        # Parse JSON response
+        response_text = response.text
+
+        # Try to extract JSON array from response
+        json_match = re.search(r'\[[\s\S]*\]', response_text)
+        if not json_match:
+            # Maybe it's just a plain array
+            if response_text.strip().startswith('['):
+                json_match = re.match(r'.*', response_text)
+            else:
+                return []
+
+        elements = json.loads(json_match.group())
+
+        # Normalize coordinates to [0, 1]
+        normalized_elements = []
+        for elem in elements:
+            bbox = elem.get("bbox", [])
+            if len(bbox) == 4:
+                x1, y1, x2, y2 = bbox
+
+                # Check if already normalized
+                if all(0 <= v <= 1 for v in [x1, y1, x2, y2]):
+                    norm_bbox = [x1, y1, x2, y2]
+                else:
+                    # Normalize pixel coordinates
+                    norm_bbox = [
+                        max(0, min(1, x1 / screenshot.width)),
+                        max(0, min(1, y1 / screenshot.height)),
+                        max(0, min(1, x2 / screenshot.width)),
+                        max(0, min(1, y2 / screenshot.height)),
+                    ]
+
+                normalized_elements.append({
+                    "id": elem.get("id", len(normalized_elements) + 1),
+                    "label": elem.get("label", f"Element {elem.get('id', len(normalized_elements) + 1)}"),
+                    "bbox": norm_bbox,
+                    "type": elem.get("type", "other"),
+                    "text": elem.get("text", ""),
+                })
+
+        return normalized_elements
+
+    except (json.JSONDecodeError, KeyError, TypeError) as e:
+        print(f"Failed to parse Gemini response: {e}")
+        return []
+    except Exception as e:
+        print(f"Error extracting UI elements: {e}")
+        return []
+
+
+def overlay_element_marks(
+    screenshot: "Image",
+    elements: list[dict],
+    style: str = "compact",
+) -> "Image":
+    """Overlay numbered labels (Set-of-Marks) on UI elements.
+
+    Creates a new image with numbered markers ([1], [2], [3], etc.) overlaid
+    on each UI element. This enables element-based interaction using indices
+    instead of coordinates (e.g., CLICK([1]) instead of CLICK(x=0.42, y=0.31)).
+
+    Args:
+        screenshot: PIL Image to annotate.
+        elements: List of element dicts from extract_ui_elements().
+            Each element must have "id" and "bbox" keys.
+        style: Label style - "compact" (small circles) or "full" (larger boxes).
+
+    Returns:
+        New PIL Image with numbered labels overlaid.
+
+    Example:
+        >>> elements = extract_ui_elements(screenshot)
+        >>> marked_img = overlay_element_marks(screenshot, elements)
+        >>> marked_img.save("screenshot_with_marks.png")
+    """
+    from PIL import ImageDraw, ImageFont
+
+    img = screenshot.copy()
+    draw = ImageDraw.Draw(img)
+
+    width, height = img.size
+
+    # Try to load a good font
+    try:
+        # Try common font paths
+        font_paths = [
+            "/System/Library/Fonts/Helvetica.ttc",  # macOS
+            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",  # Linux
+            "C:\\Windows\\Fonts\\arial.ttf",  # Windows
+        ]
+        font = None
+        for font_path in font_paths:
+            try:
+                font = ImageFont.truetype(font_path, 14)
+                break
+            except OSError:
+                continue
+
+        if font is None:
+            font = ImageFont.load_default()
+    except Exception:
+        font = ImageFont.load_default()
+
+    for elem in elements:
+        elem_id = elem.get("id", 0)
+        bbox = elem.get("bbox", [])
+
+        if len(bbox) != 4:
+            continue
+
+        # Convert normalized coords to pixels
+        x1 = int(bbox[0] * width)
+        y1 = int(bbox[1] * height)
+        x2 = int(bbox[2] * width)
+        y2 = int(bbox[3] * height)
+
+        label = f"[{elem_id}]"
+
+        if style == "compact":
+            # Small circle with number at top-left corner
+            circle_radius = 12
+            circle_x = x1 + circle_radius
+            circle_y = y1 + circle_radius
+
+            # Ensure circle is within image bounds
+            circle_x = max(circle_radius, min(width - circle_radius, circle_x))
+            circle_y = max(circle_radius, min(height - circle_radius, circle_y))
+
+            # Draw red circle background
+            draw.ellipse(
+                [
+                    circle_x - circle_radius,
+                    circle_y - circle_radius,
+                    circle_x + circle_radius,
+                    circle_y + circle_radius,
+                ],
+                fill="red",
+                outline="white",
+                width=1,
+            )
+
+            # Draw white text centered in circle
+            text_bbox = draw.textbbox((0, 0), label, font=font)
+            text_width = text_bbox[2] - text_bbox[0]
+            text_height = text_bbox[3] - text_bbox[1]
+            text_x = circle_x - text_width // 2
+            text_y = circle_y - text_height // 2
+
+            draw.text((text_x, text_y), label, fill="white", font=font)
+
+        else:  # "full" style
+            # Draw bounding box
+            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
+
+            # Draw label box at top-right corner
+            text_bbox = draw.textbbox((0, 0), label, font=font)
+            text_width = text_bbox[2] - text_bbox[0] + 8
+            text_height = text_bbox[3] - text_bbox[1] + 4
+
+            label_x = x2 - text_width
+            label_y = y1 - text_height
+
+            # Ensure label is within image bounds
+            label_x = max(0, min(width - text_width, label_x))
+            label_y = max(0, min(height - text_height, label_y))
+
+            # Draw label background
+            draw.rectangle(
+                [label_x, label_y, label_x + text_width, label_y + text_height],
+                fill="red",
+                outline="white",
+                width=1,
+            )
+
+            # Draw label text
+            draw.text(
+                (label_x + 4, label_y + 2),
+                label,
+                fill="white",
+                font=font,
+            )
+
+    return img
+
+
+class DetectorGrounder(GroundingModule):
+    """Generic detector-based grounding with fallback support.
+
+    Wraps multiple detection backends and provides fallback if one fails.
+
+    Example:
+        grounder = DetectorGrounder()  # Uses Gemini by default
+        grounder = DetectorGrounder(backend="omniparser")  # Use OmniParser
+    """
+
+    def __init__(
+        self,
+        backend: str = "gemini",
+        **kwargs,
+    ) -> None:
+        """Initialize detector grounder.
+
+        Args:
+            backend: Detection backend ("gemini", "omniparser").
+            **kwargs: Backend-specific arguments.
+        """
+        self._backend_name = backend
+
+        if backend == "gemini":
+            self._backend = GeminiGrounder(**kwargs)
+        elif backend == "omniparser":
+            raise NotImplementedError(
+                "OmniParser backend not yet implemented. "
+                "Use backend='gemini' for now."
+            )
+        else:
+            raise ValueError(f"Unknown backend: {backend}")
+
+    def ground(
+        self,
+        image: "Image",
+        target_description: str,
+        k: int = 1,
+    ) -> list[RegionCandidate]:
+        """Delegate to backend grounder."""
+        return self._backend.ground(image, target_description, k=k)
+
+    @property
+    def name(self) -> str:
+        """Return name including backend."""
+        return f"DetectorGrounder({self._backend_name})"
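A minimal usage sketch for the grounding module above (not part of the packaged code), assuming the package is installed, GOOGLE_API_KEY is set, and a local screenshot exists; the file names are illustrative. Only functions and signatures shown in detector.py are used.

from PIL import Image

from openadapt_ml.grounding.detector import (
    GeminiGrounder,
    extract_ui_elements,
    overlay_element_marks,
)

screenshot = Image.open("login.png")  # illustrative local screenshot

# Description-based grounding: best candidate region for a natural-language target.
grounder = GeminiGrounder()
candidates = grounder.ground(screenshot, "the login button", k=1)

# Set-of-Marks workflow: detect all interactive elements, then overlay numbered markers.
elements = extract_ui_elements(screenshot)
marked = overlay_element_marks(screenshot, elements, style="compact")
marked.save("login_marked.png")  # illustrative output name

ground() answers a single natural-language query with RegionCandidate objects, while extract_ui_elements() plus overlay_element_marks() produce a Set-of-Marks image so downstream policies can act on element indices (e.g., CLICK([1])) instead of raw coordinates.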
openadapt_ml/ingest/__init__.py

@@ -0,0 +1,43 @@
+"""Data ingestion modules for openadapt-ml.
+
+This package provides adapters for loading GUI interaction data from various sources
+and converting them to the format used for training.
+
+Data Model:
+    - Episode: A single task attempt (e.g., "log into the app"). Contains a sequence
+      of Steps, each with an Observation (screenshot) and Action (click/type/etc).
+    - Session: A container grouping one or more Episodes with shared metadata.
+
+Functions:
+    - load_episodes(): Load Episodes from JSON files (primary entry point)
+    - save_episodes(): Save Episodes to JSON file
+    - capture_to_episode(): Converts one openadapt-capture recording → one Episode
+    - capture_to_session(): Converts one recording → Session containing one Episode
+    - load_captures_as_sessions(): Loads multiple recordings → list of Sessions
+    - generate_synthetic_sessions(): Creates synthetic training data
+"""
+
+from openadapt_ml.ingest.loader import load_episodes, save_episodes
+from openadapt_ml.ingest.synthetic import generate_synthetic_sessions
+
+__all__ = [
+    "load_episodes",
+    "save_episodes",
+    "generate_synthetic_sessions",
+]
+
+# Conditionally export capture functions if openadapt-capture is installed
+try:
+    from openadapt_ml.ingest.capture import (
+        capture_to_episode,
+        capture_to_session,
+        load_captures_as_sessions,
+    )
+
+    __all__.extend([
+        "capture_to_episode",
+        "capture_to_session",
+        "load_captures_as_sessions",
+    ])
+except ImportError:
+    pass
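A minimal usage sketch for the ingest entry points above (not part of the packaged code). The loader and synthetic signatures are defined in loader.py and synthetic.py, which are not reproduced in this section, so the path and count arguments below are assumptions rather than the package's documented API.

from openadapt_ml.ingest import (
    generate_synthetic_sessions,
    load_episodes,
    save_episodes,
)

# Assumed signature: load_episodes(path) reads Episodes from a JSON file.
episodes = load_episodes("episodes.json")

# Assumed signature: save_episodes(episodes, path) writes them back out.
save_episodes(episodes, "episodes_copy.json")

# Assumed signature: generate_synthetic_sessions(n) creates n synthetic Sessions.
sessions = generate_synthetic_sessions(5)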