openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/perception/integration.py (new file):

@@ -0,0 +1,399 @@

```python
"""
Integration Bridge between openadapt-grounding and openadapt-ml

This module provides the UIElementGraph class, which wraps parsed UI elements
from openadapt-grounding parsers and converts them to the openadapt-ml schema.

Types imported from openadapt-grounding:
- Parser: Protocol for UI element parsers (OmniParser, UITars, etc.)
- Element: A detected UI element with normalized bounds
- LocatorResult: Result of attempting to locate an element
- RegistryEntry: A stable element that survived temporal filtering
"""

from __future__ import annotations

import uuid
from collections import Counter
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field

from openadapt_ml.schema.episode import BoundingBox, UIElement

# Lazy import for openadapt-grounding to make it an optional dependency
_grounding_available: Optional[bool] = None


def _check_grounding_available() -> bool:
    """Check if openadapt-grounding is installed."""
    global _grounding_available
    if _grounding_available is None:
        try:
            import openadapt_grounding  # noqa: F401

            _grounding_available = True
        except ImportError:
            _grounding_available = False
    return _grounding_available


def _get_grounding_types():
    """Import and return openadapt-grounding types.

    Raises:
        ImportError: If openadapt-grounding is not installed
    """
    if not _check_grounding_available():
        raise ImportError(
            "openadapt-grounding is required for perception integration. "
            "Install it with: pip install openadapt-grounding"
        )
    from openadapt_grounding import Element, LocatorResult, Parser, RegistryEntry

    return Element, LocatorResult, Parser, RegistryEntry


# Source types for UI element graphs
SourceType = Literal["omniparser", "uitars", "ax", "dom", "mixed"]


def element_to_ui_element(
    element: Any,  # openadapt_grounding.Element
    image_width: Optional[int] = None,
    image_height: Optional[int] = None,
    element_index: Optional[int] = None,
    source: Optional[str] = None,
) -> UIElement:
    """Convert an openadapt-grounding Element to a schema UIElement.

    Args:
        element: Element from openadapt-grounding with normalized bounds
        image_width: Image width in pixels (for coordinate conversion)
        image_height: Image height in pixels (for coordinate conversion)
        element_index: Optional index to use as element_id
        source: Source parser name (stored in automation_id for reference)

    Returns:
        UIElement with converted bounds and properties
    """
    Element, _, _, _ = _get_grounding_types()

    if not isinstance(element, Element):
        raise TypeError(f"Expected Element, got {type(element)}")

    # Extract normalized bounds (x, y, width, height)
    norm_x, norm_y, norm_w, norm_h = element.bounds

    # Convert to pixel coordinates if dimensions provided
    if image_width is not None and image_height is not None:
        bounds = BoundingBox(
            x=int(norm_x * image_width),
            y=int(norm_y * image_height),
            width=int(norm_w * image_width),
            height=int(norm_h * image_height),
        )
    else:
        # Store normalized as integers (multiply by 10000 for precision)
        # This allows using BoundingBox which requires int values
        bounds = BoundingBox(
            x=int(norm_x * 10000),
            y=int(norm_y * 10000),
            width=int(norm_w * 10000),
            height=int(norm_h * 10000),
        )

    # Generate element_id from index if provided
    element_id = str(element_index) if element_index is not None else None

    return UIElement(
        role=element.element_type,
        name=element.text,
        bounds=bounds,
        element_id=element_id,
        automation_id=source,  # Store source in automation_id for reference
    )


def ui_element_to_element(
    ui_element: UIElement,
    image_width: Optional[int] = None,
    image_height: Optional[int] = None,
) -> Any:  # Returns openadapt_grounding.Element
    """Convert a schema UIElement back to an openadapt-grounding Element.

    Useful for evaluation or when passing elements back to grounding functions.

    Args:
        ui_element: UIElement from openadapt-ml schema
        image_width: Image width in pixels (for coordinate normalization)
        image_height: Image height in pixels (for coordinate normalization)

    Returns:
        Element with normalized bounds
    """
    Element, _, _, _ = _get_grounding_types()

    if ui_element.bounds is None:
        # Return element with zero bounds if no bounds available
        return Element(
            bounds=(0.0, 0.0, 0.0, 0.0),
            text=ui_element.name,
            element_type=ui_element.role or "unknown",
            confidence=1.0,
        )

    # Convert pixel coordinates to normalized
    if image_width is not None and image_height is not None:
        norm_x = ui_element.bounds.x / image_width
        norm_y = ui_element.bounds.y / image_height
        norm_w = ui_element.bounds.width / image_width
        norm_h = ui_element.bounds.height / image_height
    else:
        # Assume bounds are already in 10000-scale normalized format
        norm_x = ui_element.bounds.x / 10000
        norm_y = ui_element.bounds.y / 10000
        norm_w = ui_element.bounds.width / 10000
        norm_h = ui_element.bounds.height / 10000

    return Element(
        bounds=(norm_x, norm_y, norm_w, norm_h),
        text=ui_element.name,
        element_type=ui_element.role or "unknown",
        confidence=1.0,
    )


class UIElementGraph(BaseModel):
    """A graph of UI elements parsed from a screenshot.

    This class wraps a list of UIElement objects and tracks their source
    (which parser produced them). It provides a bridge between openadapt-grounding
    parsers and the openadapt-ml schema system.

    Attributes:
        graph_id: Unique identifier for this graph (UUID)
        elements: List of UIElement objects
        source: Primary source parser ("omniparser", "uitars", "ax", "dom", "mixed")
        source_summary: Count of elements by source (e.g., {"omniparser": 15, "ax": 8})
        timestamp_ms: Optional timestamp when the screenshot was captured
        image_width: Original image width (for coordinate reference)
        image_height: Original image height (for coordinate reference)

    Example:
        >>> from openadapt_grounding import OmniParserClient
        >>> from openadapt_ml.perception import UIElementGraph
        >>>
        >>> parser = OmniParserClient(endpoint="http://localhost:8000")
        >>> elements = parser.parse(image)
        >>> graph = UIElementGraph.from_parser_output(elements, "omniparser")
        >>> print(f"Found {len(graph.elements)} elements")
    """

    graph_id: str = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="Unique identifier for this graph",
    )
    elements: list[UIElement] = Field(
        default_factory=list,
        description="List of UI elements in the graph",
    )
    source: SourceType = Field(
        default="mixed",
        description="Primary source parser for the elements",
    )
    source_summary: dict[str, int] = Field(
        default_factory=dict,
        description="Count of elements by source",
    )
    timestamp_ms: Optional[int] = Field(
        None,
        description="Timestamp when the screenshot was captured (milliseconds)",
    )
    image_width: Optional[int] = Field(
        None,
        description="Original image width in pixels",
    )
    image_height: Optional[int] = Field(
        None,
        description="Original image height in pixels",
    )

    @classmethod
    def from_parser_output(
        cls,
        elements: list[Any],  # list[openadapt_grounding.Element]
        source: Union[SourceType, str],
        image_width: Optional[int] = None,
        image_height: Optional[int] = None,
        timestamp_ms: Optional[int] = None,
    ) -> "UIElementGraph":
        """Create a UIElementGraph from parser output.

        Args:
            elements: List of Element objects from openadapt-grounding parser
            source: Parser source name ("omniparser", "uitars", "ax", "dom")
            image_width: Image width in pixels (for coordinate conversion)
            image_height: Image height in pixels (for coordinate conversion)
            timestamp_ms: Optional timestamp when screenshot was captured

        Returns:
            UIElementGraph with converted UIElement objects
        """
        # Custom sources are allowed; unknown ones fall back to "mixed" below.
        valid_sources = {"omniparser", "uitars", "ax", "dom", "mixed"}

        # Convert elements
        ui_elements = [
            element_to_ui_element(
                element=el,
                image_width=image_width,
                image_height=image_height,
                element_index=i,
                source=source,
            )
            for i, el in enumerate(elements)
        ]

        # Build source summary
        source_summary = {source: len(elements)}

        return cls(
            elements=ui_elements,
            source=source if source in valid_sources else "mixed",
            source_summary=source_summary,
            timestamp_ms=timestamp_ms,
            image_width=image_width,
            image_height=image_height,
        )

    @classmethod
    def merge(
        cls,
        graphs: list["UIElementGraph"],
        deduplicate_iou_threshold: Optional[float] = None,
    ) -> "UIElementGraph":
        """Merge multiple UIElementGraphs into one.

        Args:
            graphs: List of UIElementGraph objects to merge
            deduplicate_iou_threshold: If provided, remove duplicate elements
                with IoU greater than this threshold (0.0-1.0)

        Returns:
            New UIElementGraph with combined elements
        """
        if not graphs:
            return cls()

        # Combine all elements
        all_elements: list[UIElement] = []
        source_counts: Counter[str] = Counter()

        for graph in graphs:
            all_elements.extend(graph.elements)
            for src, count in graph.source_summary.items():
                source_counts[src] += count

        # Get image dimensions from first graph that has them
        image_width = None
        image_height = None
        for graph in graphs:
            if graph.image_width is not None:
                image_width = graph.image_width
                image_height = graph.image_height
                break

        # TODO: Implement deduplication by IoU if threshold provided
        # For now, just combine all elements

        return cls(
            elements=all_elements,
            # A single known source keeps its name; zero or several become "mixed".
            source=next(iter(source_counts)) if len(source_counts) == 1 else "mixed",
            source_summary=dict(source_counts),
            timestamp_ms=graphs[0].timestamp_ms,
            image_width=image_width,
            image_height=image_height,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to JSON-serializable dictionary.

        Returns:
            Dictionary representation suitable for JSON serialization
        """
        return self.model_dump()

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "UIElementGraph":
        """Create from dictionary.

        Args:
            data: Dictionary representation

        Returns:
            UIElementGraph instance
        """
        return cls.model_validate(data)

    def get_element_by_id(self, element_id: str) -> Optional[UIElement]:
        """Find element by ID.

        Args:
            element_id: Element ID to search for

        Returns:
            UIElement if found, None otherwise
        """
        for element in self.elements:
            if element.element_id == element_id:
                return element
        return None

    def get_elements_by_role(self, role: str) -> list[UIElement]:
        """Find elements by role.

        Args:
            role: Role to filter by (e.g., "button", "textbox")

        Returns:
            List of matching UIElements
        """
        return [el for el in self.elements if el.role == role]

    def get_elements_by_text(
        self,
        text: str,
        exact: bool = False,
    ) -> list[UIElement]:
        """Find elements by text content.

        Args:
            text: Text to search for
            exact: If True, require exact match; if False, use substring match

        Returns:
            List of matching UIElements
        """
        results = []
        for el in self.elements:
            if el.name is None:
                continue
            if exact:
                if el.name == text:
                    results.append(el)
            else:
                if text.lower() in el.name.lower():
                    results.append(el)
        return results

    def __len__(self) -> int:
        """Return number of elements in the graph."""
        return len(self.elements)

    def __iter__(self):
        """Iterate over elements."""
        return iter(self.elements)
```
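Taken together, a minimal usage sketch of this bridge (not part of the package; it assumes openadapt-grounding is installed and that `Element` accepts the same keyword arguments `ui_element_to_element` above constructs it with):

```python
# Illustrative sketch only, under the assumptions stated above.
from openadapt_grounding import Element

from openadapt_ml.perception.integration import UIElementGraph, element_to_ui_element

# A parsed element with normalized (x, y, width, height) bounds.
raw = Element(
    bounds=(0.1, 0.2, 0.3, 0.05),
    text="Submit",
    element_type="button",
    confidence=0.9,
)

# Convert to pixel space for a 1920x1080 screenshot.
ui_el = element_to_ui_element(
    raw, image_width=1920, image_height=1080, element_index=0, source="omniparser"
)
print(ui_el.bounds.x, ui_el.bounds.y)  # 192 216

# Or build a graph from a parser's full output and query it.
graph = UIElementGraph.from_parser_output(
    [raw], "omniparser", image_width=1920, image_height=1080
)
print(len(graph), [el.name for el in graph.get_elements_by_role("button")])  # 1 ['Submit']
```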
openadapt_ml/retrieval/README.md (new file):

@@ -0,0 +1,226 @@

# Demo Retrieval Module

This module provides functionality to index and retrieve similar demonstrations for few-shot prompting in GUI automation.

## Overview

The retrieval module consists of three main components:

1. **TextEmbedder** (`embeddings.py`) - Simple TF-IDF based text embeddings
2. **DemoIndex** (`index.py`) - Stores episodes with metadata and embeddings
3. **DemoRetriever** (`retriever.py`) - Retrieves top-K similar demos

## Quick Start

```python
from openadapt_ml.retrieval import DemoIndex, DemoRetriever
from openadapt_ml.schema import Episode

# 1. Create index and add episodes
index = DemoIndex()
index.add_many(episodes)  # episodes is a list of Episode objects
index.build()  # Compute embeddings

# 2. Create retriever
retriever = DemoRetriever(index, domain_bonus=0.2)

# 3. Retrieve similar demos
task = "Turn off Night Shift on macOS"
app_context = "System Settings"
similar_demos = retriever.retrieve(task, app_context, top_k=3)

# 4. Use with prompt formatting
from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
formatted_demo = format_episode_as_demo(similar_demos[0])
```

## Features

### Text Similarity
- Uses TF-IDF with cosine similarity for v1 (see the sketch below)
- No external ML libraries required
- Can be upgraded to sentence-transformers later

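For intuition, here is a self-contained sketch of TF-IDF plus cosine similarity in plain Python. It illustrates the idea only; the module's actual `TextEmbedder` may tokenize and weight terms differently:

```python
import math
from collections import Counter

def tfidf_vectors(docs: list[str]) -> list[dict[str, float]]:
    """Naive TF-IDF with smoothed IDF (illustration, not the package's code)."""
    tokenized = [doc.lower().split() for doc in docs]
    n = len(tokenized)
    df = Counter(term for doc in tokenized for term in set(doc))
    vectors = []
    for doc in tokenized:
        tf = Counter(doc)
        vectors.append({
            term: (count / len(doc)) * (math.log((1 + n) / (1 + df[term])) + 1.0)
            for term, count in tf.items()
        })
    return vectors

def cosine(a: dict[str, float], b: dict[str, float]) -> float:
    dot = sum(w * b.get(term, 0.0) for term, w in a.items())
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

v1, v2, v3 = tfidf_vectors([
    "search github for ml papers",
    "search for machine learning papers on github",
    "create a new repository",
])
print(cosine(v1, v2), cosine(v1, v3))  # similar pair scores high; unrelated pair scores 0.0
```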
### Domain Matching
- Auto-extracts app name from observations
- Auto-extracts domain from URLs (sketched below)
- Applies a bonus score for domain/app matches

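The URL handling this implies can be approximated with the standard library alone; a sketch (the module's actual extraction heuristics may differ):

```python
from typing import Optional
from urllib.parse import urlparse

def extract_domain(url: str) -> Optional[str]:
    """Reduce a URL to a bare domain, e.g. "https://www.github.com/x" -> "github.com"."""
    netloc = urlparse(url).netloc
    return netloc.removeprefix("www.") or None

print(extract_domain("https://www.github.com/openadapt"))  # github.com
```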
### Metadata Support
- Stores arbitrary metadata with each demo
- Tracks app name, domain, and custom fields
- Efficient filtering by app/domain
## API Reference

### DemoIndex

```python
index = DemoIndex()

# Add episodes
index.add(episode, app_name="Chrome", domain="github.com")
index.add_many(episodes)

# Build index (required before retrieval)
index.build()

# Query index
index.get_apps()     # List of unique app names
index.get_domains()  # List of unique domains
len(index)           # Number of demos
index.is_fitted()    # Check if built
```

### DemoRetriever

```python
retriever = DemoRetriever(
    index,
    domain_bonus=0.2,  # Bonus score for domain match
)

# Retrieve episodes
episodes = retriever.retrieve(
    task="Description of task",
    app_context="Chrome",  # Optional
    top_k=3,
)

# Retrieve with scores (for debugging)
results = retriever.retrieve_with_scores(task, app_context, top_k=3)
for result in results:
    print(f"Score: {result.score}")
    print(f"  Text similarity: {result.text_score}")
    print(f"  Domain bonus: {result.domain_bonus}")
    print(f"  Goal: {result.demo.episode.goal}")
```

### TextEmbedder

```python
from openadapt_ml.retrieval.embeddings import TextEmbedder

embedder = TextEmbedder()

# Fit on corpus
documents = ["task 1", "task 2", "task 3"]
embedder.fit(documents)

# Embed text
vec1 = embedder.embed("new task")
vec2 = embedder.embed("another task")

# Compute similarity
similarity = embedder.cosine_similarity(vec1, vec2)
```
## Scoring

The retrieval score combines text similarity and domain matching:

```
total_score = text_similarity + domain_bonus
```

- **Text similarity**: TF-IDF cosine similarity between task descriptions (0-1)
- **Domain bonus**: Fixed bonus if app_context matches the demo's app or domain (default: 0.2)

### Example Scores

```
Query: "Search GitHub for ML papers"
App context: "github.com"

Demo 1: "Search for machine learning papers on GitHub"
- Text similarity: 0.678
- Domain bonus: 0.200 (github.com match)
- Total: 0.878 ⭐ Best match

Demo 2: "Create a new GitHub repository"
- Text similarity: 0.111
- Domain bonus: 0.200 (github.com match)
- Total: 0.311

Demo 3: "Search for Python documentation on Google"
- Text similarity: 0.232
- Domain bonus: 0.000 (no match)
- Total: 0.232
```
## Loading Real Episodes

```python
from openadapt_ml.ingest.capture import load_capture
from openadapt_ml.retrieval import DemoIndex, DemoRetriever

# Load from capture directory
capture_path = "/path/to/capture"
episodes = load_capture(capture_path)

# Build index
index = DemoIndex()
index.add_many(episodes)
index.build()

# Retrieve
retriever = DemoRetriever(index)
demos = retriever.retrieve("New task description", top_k=3)
```

## Integration with Prompting

```python
from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo

# Retrieve demo
demos = retriever.retrieve(task, app_context, top_k=1)

# Format for prompt
demo_text = format_episode_as_demo(demos[0], max_steps=10)

# Inject into prompt
prompt = f"""Here is a demonstration of a similar task:

{demo_text}

Now perform this task:
Task: {task}
"""
```

## Examples

See `examples/demo_retrieval_example.py` for a complete working example.

Run it with:
```bash
uv run python examples/demo_retrieval_example.py
```
## Future Improvements

### v2: Better Embeddings
Replace TF-IDF with sentence-transformers:
```python
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2')
```
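If that upgrade lands, usage would presumably follow the library's standard encode-and-compare API, along these lines:

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode([
    "Turn off Night Shift on macOS",
    "Disable the Night Shift setting",
])
# Cosine similarity stays high despite little exact word overlap,
# which is the main advantage over TF-IDF.
print(util.cos_sim(embeddings[0], embeddings[1]))
```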
### v3: Semantic Search
- Use FAISS or Qdrant for large-scale retrieval (see the sketch below)
- Add metadata filtering before similarity search
- Support multi-modal embeddings (text + screenshots)

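A sketch of what the FAISS variant might look like (hypothetical; the dimension and index type are illustrative choices, not part of this package):

```python
import faiss  # pip install faiss-cpu
import numpy as np

dim = 384  # e.g. all-MiniLM-L6-v2 embedding size
demo_vecs = np.random.rand(1000, dim).astype("float32")
faiss.normalize_L2(demo_vecs)  # unit-normalize so inner product == cosine similarity

index = faiss.IndexFlatIP(dim)  # exact inner-product search
index.add(demo_vecs)

query = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, 3)  # top-3 nearest demos
print(ids[0], scores[0])
```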
### v4: Learning to Rank
- Train a ranking model using success/failure data
- Incorporate user feedback
- Personalized retrieval based on agent history

## Design Principles

1. **Start simple** - v1 uses no ML models, just text matching
2. **Functional over optimal** - Works out of the box, can be improved later
3. **Clear API** - Simple retrieve() interface, complex details hidden
4. **Composable** - Each component can be used independently
5. **Schema-first** - Works with the Episode schema, no custom data structures