openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,399 @@
+ """
+ Integration Bridge between openadapt-grounding and openadapt-ml
+
+ This module provides the UIElementGraph class which wraps parsed UI elements
+ from openadapt-grounding parsers and converts them to the openadapt-ml schema.
+
+ Types imported from openadapt-grounding:
+ - Parser: Protocol for UI element parsers (OmniParser, UITars, etc.)
+ - Element: A detected UI element with normalized bounds
+ - LocatorResult: Result of attempting to locate an element
+ - RegistryEntry: A stable element that survived temporal filtering
+ """
+
+ from __future__ import annotations
+
+ import uuid
+ from collections import Counter
+ from typing import Any, Literal, Optional, Union
+
+ from pydantic import BaseModel, Field
+
+ from openadapt_ml.schema.episode import BoundingBox, UIElement
+
+ # Lazy import for openadapt-grounding to make it an optional dependency
+ _grounding_available: Optional[bool] = None
+
+
+ def _check_grounding_available() -> bool:
+     """Check if openadapt-grounding is installed."""
+     global _grounding_available
+     if _grounding_available is None:
+         try:
+             import openadapt_grounding  # noqa: F401
+
+             _grounding_available = True
+         except ImportError:
+             _grounding_available = False
+     return _grounding_available
+
+
+ def _get_grounding_types():
+     """Import and return openadapt-grounding types.
+
+     Raises:
+         ImportError: If openadapt-grounding is not installed
+     """
+     if not _check_grounding_available():
+         raise ImportError(
+             "openadapt-grounding is required for perception integration. "
+             "Install it with: pip install openadapt-grounding"
+         )
+     from openadapt_grounding import Element, LocatorResult, Parser, RegistryEntry
+
+     return Element, LocatorResult, Parser, RegistryEntry
+
+
+ # Source types for UI element graphs
+ SourceType = Literal["omniparser", "uitars", "ax", "dom", "mixed"]
+
+
+ def element_to_ui_element(
+     element: Any,  # openadapt_grounding.Element
+     image_width: Optional[int] = None,
+     image_height: Optional[int] = None,
+     element_index: Optional[int] = None,
+     source: Optional[str] = None,
+ ) -> UIElement:
+     """Convert an openadapt-grounding Element to a schema UIElement.
+
+     Args:
+         element: Element from openadapt-grounding with normalized bounds
+         image_width: Image width in pixels (for coordinate conversion)
+         image_height: Image height in pixels (for coordinate conversion)
+         element_index: Optional index to use as element_id
+         source: Source parser name (stored in automation_id for reference)
+
+     Returns:
+         UIElement with converted bounds and properties
+     """
+     Element, _, _, _ = _get_grounding_types()
+
+     if not isinstance(element, Element):
+         raise TypeError(f"Expected Element, got {type(element)}")
+
+     # Extract normalized bounds (x, y, width, height)
+     norm_x, norm_y, norm_w, norm_h = element.bounds
+
+     # Convert to pixel coordinates if dimensions provided
+     if image_width is not None and image_height is not None:
+         bounds = BoundingBox(
+             x=int(norm_x * image_width),
+             y=int(norm_y * image_height),
+             width=int(norm_w * image_width),
+             height=int(norm_h * image_height),
+         )
+     else:
+         # Store normalized as integers (multiply by 10000 for precision)
+         # This allows using BoundingBox which requires int values
+         bounds = BoundingBox(
+             x=int(norm_x * 10000),
+             y=int(norm_y * 10000),
+             width=int(norm_w * 10000),
+             height=int(norm_h * 10000),
+         )
+
+     # Generate element_id from index if provided
+     element_id = str(element_index) if element_index is not None else None
+
+     return UIElement(
+         role=element.element_type,
+         name=element.text,
+         bounds=bounds,
+         element_id=element_id,
+         automation_id=source,  # Store source in automation_id for reference
+     )
+
+
+ def ui_element_to_element(
+     ui_element: UIElement,
+     image_width: Optional[int] = None,
+     image_height: Optional[int] = None,
+ ) -> Any:  # Returns openadapt_grounding.Element
+     """Convert a schema UIElement back to an openadapt-grounding Element.
+
+     Useful for evaluation or when passing elements back to grounding functions.
+
+     Args:
+         ui_element: UIElement from openadapt-ml schema
+         image_width: Image width in pixels (for coordinate normalization)
+         image_height: Image height in pixels (for coordinate normalization)
+
+     Returns:
+         Element with normalized bounds
+     """
+     Element, _, _, _ = _get_grounding_types()
+
+     if ui_element.bounds is None:
+         # Return element with zero bounds if no bounds available
+         return Element(
+             bounds=(0.0, 0.0, 0.0, 0.0),
+             text=ui_element.name,
+             element_type=ui_element.role or "unknown",
+             confidence=1.0,
+         )
+
+     # Convert pixel coordinates to normalized
+     if image_width is not None and image_height is not None:
+         norm_x = ui_element.bounds.x / image_width
+         norm_y = ui_element.bounds.y / image_height
+         norm_w = ui_element.bounds.width / image_width
+         norm_h = ui_element.bounds.height / image_height
+     else:
+         # Assume bounds are already in 10000-scale normalized format
+         norm_x = ui_element.bounds.x / 10000
+         norm_y = ui_element.bounds.y / 10000
+         norm_w = ui_element.bounds.width / 10000
+         norm_h = ui_element.bounds.height / 10000
+
+     return Element(
+         bounds=(norm_x, norm_y, norm_w, norm_h),
+         text=ui_element.name,
+         element_type=ui_element.role or "unknown",
+         confidence=1.0,
+     )
+
+
+ class UIElementGraph(BaseModel):
+     """A graph of UI elements parsed from a screenshot.
+
+     This class wraps a list of UIElement objects and tracks their source
+     (which parser produced them). It provides a bridge between openadapt-grounding
+     parsers and the openadapt-ml schema system.
+
+     Attributes:
+         graph_id: Unique identifier for this graph (UUID)
+         elements: List of UIElement objects
+         source: Primary source parser ("omniparser", "uitars", "ax", "dom", "mixed")
+         source_summary: Count of elements by source (e.g., {"omniparser": 15, "ax": 8})
+         timestamp_ms: Optional timestamp when the screenshot was captured
+         image_width: Original image width (for coordinate reference)
+         image_height: Original image height (for coordinate reference)
+
+     Example:
+         >>> from openadapt_grounding import OmniParserClient
+         >>> from openadapt_ml.perception import UIElementGraph
+         >>>
+         >>> parser = OmniParserClient(endpoint="http://localhost:8000")
+         >>> elements = parser.parse(image)
+         >>> graph = UIElementGraph.from_parser_output(elements, "omniparser")
+         >>> print(f"Found {len(graph.elements)} elements")
+     """
+
+     graph_id: str = Field(
+         default_factory=lambda: str(uuid.uuid4()),
+         description="Unique identifier for this graph",
+     )
+     elements: list[UIElement] = Field(
+         default_factory=list,
+         description="List of UI elements in the graph",
+     )
+     source: SourceType = Field(
+         default="mixed",
+         description="Primary source parser for the elements",
+     )
+     source_summary: dict[str, int] = Field(
+         default_factory=dict,
+         description="Count of elements by source",
+     )
+     timestamp_ms: Optional[int] = Field(
+         None,
+         description="Timestamp when the screenshot was captured (milliseconds)",
+     )
+     image_width: Optional[int] = Field(
+         None,
+         description="Original image width in pixels",
+     )
+     image_height: Optional[int] = Field(
+         None,
+         description="Original image height in pixels",
+     )
+
+     @classmethod
+     def from_parser_output(
+         cls,
+         elements: list[Any],  # list[openadapt_grounding.Element]
+         source: Union[SourceType, str],
+         image_width: Optional[int] = None,
+         image_height: Optional[int] = None,
+         timestamp_ms: Optional[int] = None,
+     ) -> "UIElementGraph":
+         """Create a UIElementGraph from parser output.
+
+         Args:
+             elements: List of Element objects from openadapt-grounding parser
+             source: Parser source name ("omniparser", "uitars", "ax", "dom")
+             image_width: Image width in pixels (for coordinate conversion)
+             image_height: Image height in pixels (for coordinate conversion)
+             timestamp_ms: Optional timestamp when screenshot was captured
+
+         Returns:
+             UIElementGraph with converted UIElement objects
+         """
+         # Validate source type
+         valid_sources = {"omniparser", "uitars", "ax", "dom", "mixed"}
+         if source not in valid_sources:
+             # Allow custom sources, but warn; unknown sources are recorded as "mixed"
+             import warnings
+
+             warnings.warn(f"Unknown parser source {source!r}; recording as 'mixed'.")
+
+         # Convert elements
+         ui_elements = [
+             element_to_ui_element(
+                 element=el,
+                 image_width=image_width,
+                 image_height=image_height,
+                 element_index=i,
+                 source=source,
+             )
+             for i, el in enumerate(elements)
+         ]
+
+         # Build source summary
+         source_summary = {source: len(elements)}
+
+         return cls(
+             elements=ui_elements,
+             source=source if source in valid_sources else "mixed",
+             source_summary=source_summary,
+             timestamp_ms=timestamp_ms,
+             image_width=image_width,
+             image_height=image_height,
+         )
+
+     @classmethod
+     def merge(
+         cls,
+         graphs: list["UIElementGraph"],
+         deduplicate_iou_threshold: Optional[float] = None,
+     ) -> "UIElementGraph":
+         """Merge multiple UIElementGraphs into one.
+
+         Args:
+             graphs: List of UIElementGraph objects to merge
+             deduplicate_iou_threshold: If provided, remove duplicate elements
+                 with IoU greater than this threshold (0.0-1.0)
+
+         Returns:
+             New UIElementGraph with combined elements
+         """
+         if not graphs:
+             return cls()
+
+         # Combine all elements
+         all_elements: list[UIElement] = []
+         source_counts: Counter[str] = Counter()
+
+         for graph in graphs:
+             all_elements.extend(graph.elements)
+             for src, count in graph.source_summary.items():
+                 source_counts[src] += count
+
+         # Get image dimensions from the first graph that has them
+         image_width = None
+         image_height = None
+         for graph in graphs:
+             if graph.image_width is not None:
+                 image_width = graph.image_width
+                 image_height = graph.image_height
+                 break
+
+         # TODO: Implement deduplication by IoU if threshold provided
+         # For now, just combine all elements
+
+         return cls(
+             elements=all_elements,
+             source=next(iter(source_counts)) if len(source_counts) == 1 else "mixed",
+             source_summary=dict(source_counts),
+             timestamp_ms=graphs[0].timestamp_ms,
+             image_width=image_width,
+             image_height=image_height,
+         )
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to JSON-serializable dictionary.
+
+         Returns:
+             Dictionary representation suitable for JSON serialization
+         """
+         return self.model_dump()
+
+     @classmethod
+     def from_dict(cls, data: dict[str, Any]) -> "UIElementGraph":
+         """Create from dictionary.
+
+         Args:
+             data: Dictionary representation
+
+         Returns:
+             UIElementGraph instance
+         """
+         return cls.model_validate(data)
+
+     def get_element_by_id(self, element_id: str) -> Optional[UIElement]:
+         """Find element by ID.
+
+         Args:
+             element_id: Element ID to search for
+
+         Returns:
+             UIElement if found, None otherwise
+         """
+         for element in self.elements:
+             if element.element_id == element_id:
+                 return element
+         return None
+
+     def get_elements_by_role(self, role: str) -> list[UIElement]:
+         """Find elements by role.
+
+         Args:
+             role: Role to filter by (e.g., "button", "textbox")
+
+         Returns:
+             List of matching UIElements
+         """
+         return [el for el in self.elements if el.role == role]
+
+     def get_elements_by_text(
+         self,
+         text: str,
+         exact: bool = False,
+     ) -> list[UIElement]:
+         """Find elements by text content.
+
+         Args:
+             text: Text to search for
+             exact: If True, require exact match; if False, use substring match
+
+         Returns:
+             List of matching UIElements
+         """
+         results = []
+         for el in self.elements:
+             if el.name is None:
+                 continue
+             if exact:
+                 if el.name == text:
+                     results.append(el)
+             else:
+                 if text.lower() in el.name.lower():
+                     results.append(el)
+         return results
+
+     def __len__(self) -> int:
+         """Return number of elements in the graph."""
+         return len(self.elements)
+
+     def __iter__(self):
+         """Iterate over elements."""
+         return iter(self.elements)
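
`UIElementGraph.merge()` accepts a `deduplicate_iou_threshold` parameter but leaves the IoU filtering itself as a TODO. A minimal sketch of what that filter might look like, assuming all merged graphs share one pixel coordinate space; the greedy keep-first strategy is illustrative, not the package's implementation:

```python
from openadapt_ml.schema.episode import BoundingBox, UIElement


def _iou(a: BoundingBox, b: BoundingBox) -> float:
    """Intersection-over-union of two boxes in the same coordinate space."""
    ix = max(0, min(a.x + a.width, b.x + b.width) - max(a.x, b.x))
    iy = max(0, min(a.y + a.height, b.y + b.height) - max(a.y, b.y))
    inter = ix * iy
    union = a.width * a.height + b.width * b.height - inter
    return inter / union if union else 0.0


def deduplicate_elements(
    elements: list[UIElement], iou_threshold: float
) -> list[UIElement]:
    """Keep the first element seen; drop later ones that overlap a kept one."""
    kept: list[UIElement] = []
    for el in elements:
        if el.bounds is None or all(
            k.bounds is None or _iou(el.bounds, k.bounds) <= iou_threshold
            for k in kept
        ):
            kept.append(el)
    return kept
```
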
@@ -0,0 +1,226 @@
+ # Demo Retrieval Module
+
+ This module provides functionality to index and retrieve similar demonstrations for few-shot prompting in GUI automation.
+
+ ## Overview
+
+ The retrieval module consists of three main components:
+
+ 1. **TextEmbedder** (`embeddings.py`) - Simple TF-IDF-based text embeddings
+ 2. **DemoIndex** (`index.py`) - Stores episodes with metadata and embeddings
+ 3. **DemoRetriever** (`retriever.py`) - Retrieves the top-K most similar demos
+
+ ## Quick Start
+
+ ```python
+ from openadapt_ml.retrieval import DemoIndex, DemoRetriever
+ from openadapt_ml.schema import Episode
+
+ # 1. Create index and add episodes
+ index = DemoIndex()
+ index.add_many(episodes)  # episodes is a list of Episode objects
+ index.build()  # Compute embeddings
+
+ # 2. Create retriever
+ retriever = DemoRetriever(index, domain_bonus=0.2)
+
+ # 3. Retrieve similar demos
+ task = "Turn off Night Shift on macOS"
+ app_context = "System Settings"
+ similar_demos = retriever.retrieve(task, app_context, top_k=3)
+
+ # 4. Use with prompt formatting
+ from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
+ formatted_demo = format_episode_as_demo(similar_demos[0])
+ ```
+
+ ## Features
+
+ ### Text Similarity
+ - Uses TF-IDF with cosine similarity for v1
+ - No external ML libraries required
+ - Can be upgraded to sentence-transformers later
+
+ ### Domain Matching
+ - Auto-extracts app name from observations
+ - Auto-extracts domain from URLs
+ - Applies a bonus score for domain/app matches
+
+ ### Metadata Support
+ - Stores arbitrary metadata with each demo
+ - Tracks app name, domain, and custom fields
+ - Efficient filtering by app/domain (see the sketch below)
+
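A minimal sketch of the metadata flow, using only the calls documented in the API Reference below (`github_episode` and `settings_episode` are placeholder Episode objects):

```python
from openadapt_ml.retrieval import DemoIndex, DemoRetriever

index = DemoIndex()
# Attach app/domain metadata so retrieval can apply the domain bonus
index.add(github_episode, app_name="Chrome", domain="github.com")
index.add(settings_episode, app_name="System Settings")
index.build()

print(index.get_apps())     # unique app names seen so far
print(index.get_domains())  # unique domains seen so far

# An app_context that matches a stored app or domain earns the bonus
retriever = DemoRetriever(index, domain_bonus=0.2)
demos = retriever.retrieve("Open a pull request", app_context="github.com", top_k=1)
```
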
+ ## API Reference
+
+ ### DemoIndex
+
+ ```python
+ index = DemoIndex()
+
+ # Add episodes
+ index.add(episode, app_name="Chrome", domain="github.com")
+ index.add_many(episodes)
+
+ # Build index (required before retrieval)
+ index.build()
+
+ # Query index
+ index.get_apps()     # List of unique app names
+ index.get_domains()  # List of unique domains
+ len(index)           # Number of demos
+ index.is_fitted()    # Check if built
+ ```
+
+ ### DemoRetriever
+
+ ```python
+ retriever = DemoRetriever(
+     index,
+     domain_bonus=0.2,  # Bonus score for domain match
+ )
+
+ # Retrieve episodes
+ episodes = retriever.retrieve(
+     task="Description of task",
+     app_context="Chrome",  # Optional
+     top_k=3,
+ )
+
+ # Retrieve with scores (for debugging)
+ results = retriever.retrieve_with_scores(task, app_context, top_k=3)
+ for result in results:
+     print(f"Score: {result.score}")
+     print(f"  Text similarity: {result.text_score}")
+     print(f"  Domain bonus: {result.domain_bonus}")
+     print(f"  Goal: {result.demo.episode.goal}")
+ ```
+
+ ### TextEmbedder
+
+ ```python
+ from openadapt_ml.retrieval.embeddings import TextEmbedder
+
+ embedder = TextEmbedder()
+
+ # Fit on corpus
+ documents = ["task 1", "task 2", "task 3"]
+ embedder.fit(documents)
+
+ # Embed text
+ vec1 = embedder.embed("new task")
+ vec2 = embedder.embed("another task")
+
+ # Compute similarity
+ similarity = embedder.cosine_similarity(vec1, vec2)
+ ```
+
+ ## Scoring
+
+ The retrieval score combines text similarity and domain matching:
+
+ ```
+ total_score = text_similarity + domain_bonus
+ ```
+
+ - **Text similarity**: TF-IDF cosine similarity between task descriptions (0-1)
+ - **Domain bonus**: Fixed bonus if app_context matches the demo's app or domain (default: 0.2)
+
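Restated as code (a hypothetical one-liner mirroring the rule above, not the library's internals):

```python
def combined_score(
    text_similarity: float, domain_match: bool, domain_bonus: float = 0.2
) -> float:
    # total_score = text_similarity + domain_bonus, bonus applied only on a match
    return text_similarity + (domain_bonus if domain_match else 0.0)

assert abs(combined_score(0.678, domain_match=True) - 0.878) < 1e-9  # Demo 1 below
```
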
+ ### Example Scores
+
+ ```
+ Query: "Search GitHub for ML papers"
+ App context: "github.com"
+
+ Demo 1: "Search for machine learning papers on GitHub"
+ - Text similarity: 0.678
+ - Domain bonus: 0.200 (github.com match)
+ - Total: 0.878 ⭐ Best match
+
+ Demo 2: "Create a new GitHub repository"
+ - Text similarity: 0.111
+ - Domain bonus: 0.200 (github.com match)
+ - Total: 0.311
+
+ Demo 3: "Search for Python documentation on Google"
+ - Text similarity: 0.232
+ - Domain bonus: 0.000 (no match)
+ - Total: 0.232
+ ```
+
+ ## Loading Real Episodes
+
+ ```python
+ from openadapt_ml.ingest.capture import load_capture
+ from openadapt_ml.retrieval import DemoIndex, DemoRetriever
+
+ # Load from capture directory
+ capture_path = "/path/to/capture"
+ episodes = load_capture(capture_path)
+
+ # Build index
+ index = DemoIndex()
+ index.add_many(episodes)
+ index.build()
+
+ # Retrieve
+ retriever = DemoRetriever(index)
+ demos = retriever.retrieve("New task description", top_k=3)
+ ```
+
+ ## Integration with Prompting
+
+ ```python
+ from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
+
+ # Retrieve demo
+ demos = retriever.retrieve(task, app_context, top_k=1)
+
+ # Format for prompt
+ demo_text = format_episode_as_demo(demos[0], max_steps=10)
+
+ # Inject into prompt
+ prompt = f"""Here is a demonstration of a similar task:
+
+ {demo_text}
+
+ Now perform this task:
+ Task: {task}
+ """
+ ```
+
+ ## Examples
+
+ See `examples/demo_retrieval_example.py` for a complete working example.
+
+ Run it with:
+ ```bash
+ uv run python examples/demo_retrieval_example.py
+ ```
+
+ ## Future Improvements
+
+ ### v2: Better Embeddings
+ Replace TF-IDF with sentence-transformers:
+ ```python
+ from sentence_transformers import SentenceTransformer
+ model = SentenceTransformer('all-MiniLM-L6-v2')
+ ```
+
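A slightly fuller sketch of that upgrade path (assuming the sentence-transformers package is installed; the strings reuse the Scoring examples above):

```python
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

# Dense embeddings replace the TF-IDF vectors; cosine similarity remains the rule
corpus = [
    "Search for machine learning papers on GitHub",
    "Create a new GitHub repository",
]
corpus_emb = model.encode(corpus, normalize_embeddings=True)
query_emb = model.encode("Search GitHub for ML papers", normalize_embeddings=True)

scores = util.cos_sim(query_emb, corpus_emb)  # shape: (1, len(corpus))
print(scores)
```
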
+ ### v3: Semantic Search
+ - Use FAISS or Qdrant for large-scale retrieval (see the sketch below)
+ - Add metadata filtering before similarity search
+ - Support multi-modal embeddings (text + screenshots)
+
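For the FAISS route, a minimal sketch (assuming faiss-cpu and numpy; the random vectors stand in for real demo embeddings):

```python
import faiss
import numpy as np

dim = 384  # e.g. the output size of all-MiniLM-L6-v2
index = faiss.IndexFlatIP(dim)  # inner product == cosine similarity on unit vectors

demo_vecs = np.random.rand(100, dim).astype("float32")  # placeholder embeddings
faiss.normalize_L2(demo_vecs)
index.add(demo_vecs)

query = np.random.rand(1, dim).astype("float32")
faiss.normalize_L2(query)
scores, ids = index.search(query, 3)  # top-3 nearest demos
```
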
+ ### v4: Learning to Rank
+ - Train a ranking model using success/failure data
+ - Incorporate user feedback
+ - Personalize retrieval based on agent history
+
+ ## Design Principles
+
+ 1. **Start simple** - v1 uses no ML models, just text matching
+ 2. **Functional over optimal** - Works out of the box; can be improved later
+ 3. **Clear API** - Simple retrieve() interface, complex details hidden
+ 4. **Composable** - Each component can be used independently
+ 5. **Schema-first** - Works with the Episode schema, no custom data structures