openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
# Demo Retrieval - Usage Guide
|
|
2
|
+
|
|
3
|
+
## Quick Reference
|
|
4
|
+
|
|
5
|
+
```python
|
|
6
|
+
# 1. Build index
|
|
7
|
+
from openadapt_ml.retrieval import DemoIndex, DemoRetriever
|
|
8
|
+
index = DemoIndex()
|
|
9
|
+
index.add_many(episodes)
|
|
10
|
+
index.build()
|
|
11
|
+
|
|
12
|
+
# 2. Retrieve
|
|
13
|
+
retriever = DemoRetriever(index)
|
|
14
|
+
similar_demos = retriever.retrieve("Turn off Night Shift", top_k=3)
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
## Complete Examples
|
|
18
|
+
|
|
19
|
+
### Example 1: Basic Usage with Synthetic Data
|
|
20
|
+
|
|
21
|
+
```python
|
|
22
|
+
from openadapt_ml.retrieval import DemoIndex, DemoRetriever
|
|
23
|
+
from openadapt_ml.schema import Action, ActionType, Episode, Observation, Step
|
|
24
|
+
|
|
25
|
+
# Create test episodes
|
|
26
|
+
def create_episode(instruction, app_name=None):
|
|
27
|
+
obs = Observation(app_name=app_name)
|
|
28
|
+
action = Action(type=ActionType.CLICK, normalized_coordinates=(0.5, 0.5))
|
|
29
|
+
step = Step(step_index=0, observation=obs, action=action)
|
|
30
|
+
return Episode(episode_id=f"ep_{instruction[:10]}", instruction=instruction, steps=[step])
|
|
31
|
+
|
|
32
|
+
episodes = [
|
|
33
|
+
create_episode("Turn off Night Shift", app_name="System Settings"),
|
|
34
|
+
create_episode("Search GitHub", app_name="Chrome"),
|
|
35
|
+
create_episode("Open calculator", app_name="Calculator"),
|
|
36
|
+
]
|
|
37
|
+
|
|
38
|
+
# Build index
|
|
39
|
+
index = DemoIndex()
|
|
40
|
+
index.add_many(episodes)
|
|
41
|
+
index.build()
|
|
42
|
+
|
|
43
|
+
# Retrieve
|
|
44
|
+
retriever = DemoRetriever(index, domain_bonus=0.2)
|
|
45
|
+
results = retriever.retrieve("Disable Night Shift", top_k=2)
|
|
46
|
+
|
|
47
|
+
print(f"Found {len(results)} similar demos:")
|
|
48
|
+
for ep in results:
|
|
49
|
+
print(f"- {ep.instruction}")
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
### Example 2: Loading from Capture
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
from openadapt_ml.ingest.capture import capture_to_episode
|
|
56
|
+
from openadapt_ml.retrieval import DemoIndex, DemoRetriever
|
|
57
|
+
|
|
58
|
+
# Load multiple captures
|
|
59
|
+
capture_paths = [
|
|
60
|
+
"/path/to/capture1",
|
|
61
|
+
"/path/to/capture2",
|
|
62
|
+
"/path/to/capture3",
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
episodes = [
|
|
66
|
+
capture_to_episode(path, include_moves=False)
|
|
67
|
+
for path in capture_paths
|
|
68
|
+
]
|
|
69
|
+
|
|
70
|
+
# Build index
|
|
71
|
+
index = DemoIndex()
|
|
72
|
+
index.add_many(episodes)
|
|
73
|
+
index.build()
|
|
74
|
+
|
|
75
|
+
# Retrieve for new task
|
|
76
|
+
retriever = DemoRetriever(index)
|
|
77
|
+
task = "Turn on dark mode"
|
|
78
|
+
app = "System Settings"
|
|
79
|
+
demos = retriever.retrieve(task, app_context=app, top_k=3)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
### Example 3: Integration with Prompting
|
|
83
|
+
|
|
84
|
+
```python
|
|
85
|
+
from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
|
|
86
|
+
from openadapt_ml.retrieval import DemoIndex, DemoRetriever
|
|
87
|
+
|
|
88
|
+
# Build index (assume episodes already loaded)
|
|
89
|
+
index = DemoIndex()
|
|
90
|
+
index.add_many(episodes)
|
|
91
|
+
index.build()
|
|
92
|
+
|
|
93
|
+
# Retrieve for new task
|
|
94
|
+
retriever = DemoRetriever(index)
|
|
95
|
+
task = "Turn off Night Shift"
|
|
96
|
+
demos = retriever.retrieve(task, top_k=1)
|
|
97
|
+
|
|
98
|
+
# Format for prompt
|
|
99
|
+
if demos:
|
|
100
|
+
demo_text = format_episode_as_demo(demos[0], max_steps=10)
|
|
101
|
+
|
|
102
|
+
# Create few-shot prompt
|
|
103
|
+
prompt = f"""You are a GUI automation agent.
|
|
104
|
+
|
|
105
|
+
DEMONSTRATION OF SIMILAR TASK:
|
|
106
|
+
{demo_text}
|
|
107
|
+
|
|
108
|
+
NEW TASK:
|
|
109
|
+
{task}
|
|
110
|
+
|
|
111
|
+
What is your first action?"""
|
|
112
|
+
|
|
113
|
+
print(prompt)
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
### Example 4: Retrieval with Scores (Debugging)
|
|
117
|
+
|
|
118
|
+
```python
|
|
119
|
+
from openadapt_ml.retrieval import DemoRetriever
|
|
120
|
+
|
|
121
|
+
retriever = DemoRetriever(index, domain_bonus=0.3)
|
|
122
|
+
|
|
123
|
+
# Retrieve with scores for analysis
|
|
124
|
+
results = retriever.retrieve_with_scores(
|
|
125
|
+
task="Search for Python docs",
|
|
126
|
+
app_context="github.com",
|
|
127
|
+
top_k=5,
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
# Analyze scores
|
|
131
|
+
for i, result in enumerate(results, 1):
|
|
132
|
+
print(f"\n{i}. {result.demo.episode.instruction}")
|
|
133
|
+
print(f" Total score: {result.score:.3f}")
|
|
134
|
+
print(f" Text similarity: {result.text_score:.3f}")
|
|
135
|
+
print(f" Domain bonus: {result.domain_bonus:.3f}")
|
|
136
|
+
|
|
137
|
+
if result.demo.app_name:
|
|
138
|
+
print(f" App: {result.demo.app_name}")
|
|
139
|
+
if result.demo.domain:
|
|
140
|
+
print(f" Domain: {result.demo.domain}")
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
### Example 5: Custom Metadata
|
|
144
|
+
|
|
145
|
+
```python
|
|
146
|
+
from openadapt_ml.retrieval import DemoIndex, DemoRetriever
|
|
147
|
+
|
|
148
|
+
index = DemoIndex()
|
|
149
|
+
|
|
150
|
+
# Add episodes with custom metadata
|
|
151
|
+
for episode in episodes:
|
|
152
|
+
metadata = {
|
|
153
|
+
"difficulty": "easy",
|
|
154
|
+
"success_rate": 0.95,
|
|
155
|
+
"duration_seconds": 30,
|
|
156
|
+
"tags": ["settings", "macOS"],
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
index.add(
|
|
160
|
+
episode,
|
|
161
|
+
app_name="System Settings",
|
|
162
|
+
domain=None,
|
|
163
|
+
metadata=metadata,
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
index.build()
|
|
167
|
+
|
|
168
|
+
# Access metadata after retrieval
|
|
169
|
+
retriever = DemoRetriever(index)
|
|
170
|
+
results = retriever.retrieve_with_scores("Turn off Night Shift", top_k=1)
|
|
171
|
+
|
|
172
|
+
if results:
|
|
173
|
+
demo = results[0].demo
|
|
174
|
+
print(f"Difficulty: {demo.metadata.get('difficulty')}")
|
|
175
|
+
print(f"Tags: {demo.metadata.get('tags')}")
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
## CLI Examples
|
|
179
|
+
|
|
180
|
+
Run the provided example scripts:
|
|
181
|
+
|
|
182
|
+
```bash
|
|
183
|
+
# Basic demo with synthetic data
|
|
184
|
+
uv run python examples/demo_retrieval_example.py
|
|
185
|
+
|
|
186
|
+
# Test with real capture
|
|
187
|
+
uv run python examples/retrieval_with_capture.py /path/to/capture
|
|
188
|
+
|
|
189
|
+
# With custom task
|
|
190
|
+
uv run python examples/retrieval_with_capture.py /path/to/capture "Turn off dark mode"
|
|
191
|
+
```
|
|
192
|
+
|
|
193
|
+
## Common Patterns
|
|
194
|
+
|
|
195
|
+
### Pattern 1: Multi-Domain Index
|
|
196
|
+
|
|
197
|
+
```python
|
|
198
|
+
# Build index with episodes from multiple domains
|
|
199
|
+
web_episodes = load_web_captures()
|
|
200
|
+
desktop_episodes = load_desktop_captures()
|
|
201
|
+
|
|
202
|
+
index = DemoIndex()
|
|
203
|
+
index.add_many(web_episodes)
|
|
204
|
+
index.add_many(desktop_episodes)
|
|
205
|
+
index.build()
|
|
206
|
+
|
|
207
|
+
# Retrieve with a domain preference via app_context (matching demos get a score bonus; non-matching demos are not filtered out)
|
|
208
|
+
retriever = DemoRetriever(index, domain_bonus=0.5)
|
|
209
|
+
|
|
210
|
+
# This will prefer github.com demos
|
|
211
|
+
web_demos = retriever.retrieve("Search code", app_context="github.com", top_k=3)
|
|
212
|
+
|
|
213
|
+
# This will prefer System Settings demos
|
|
214
|
+
desktop_demos = retriever.retrieve("Change settings", app_context="System Settings", top_k=3)
|
|
215
|
+
```
|
|
216
|
+
|
|
217
|
+
### Pattern 2: Incremental Index Updates
|
|
218
|
+
|
|
219
|
+
```python
|
|
220
|
+
# Build initial index
|
|
221
|
+
index = DemoIndex()
|
|
222
|
+
index.add_many(initial_episodes)
|
|
223
|
+
index.build()
|
|
224
|
+
|
|
225
|
+
# Add new episodes
|
|
226
|
+
index.add(new_episode)
|
|
227
|
+
|
|
228
|
+
# Rebuild required after adding
|
|
229
|
+
index.build()
|
|
230
|
+
|
|
231
|
+
# Now retriever will use updated index
|
|
232
|
+
retriever = DemoRetriever(index)
|
|
233
|
+
```
|
|
234
|
+
|
|
235
|
+
### Pattern 3: Batch Retrieval
|
|
236
|
+
|
|
237
|
+
```python
|
|
238
|
+
# Retrieve for multiple tasks
|
|
239
|
+
tasks = [
|
|
240
|
+
"Turn off Night Shift",
|
|
241
|
+
"Enable dark mode",
|
|
242
|
+
"Adjust brightness",
|
|
243
|
+
]
|
|
244
|
+
|
|
245
|
+
retriever = DemoRetriever(index)
|
|
246
|
+
|
|
247
|
+
for task in tasks:
|
|
248
|
+
demos = retriever.retrieve(task, top_k=3)
|
|
249
|
+
print(f"\nTask: {task}")
|
|
250
|
+
for demo in demos:
|
|
251
|
+
print(f" - {demo.instruction}")
|
|
252
|
+
```
|
|
253
|
+
|
|
254
|
+
## Tuning Parameters
|
|
255
|
+
|
|
256
|
+
### Domain Bonus
|
|
257
|
+
|
|
258
|
+
Controls how much to favor domain/app matches:
|
|
259
|
+
|
|
260
|
+
```python
|
|
261
|
+
# No domain bonus - pure text similarity
|
|
262
|
+
retriever = DemoRetriever(index, domain_bonus=0.0)
|
|
263
|
+
|
|
264
|
+
# Small bonus (default)
|
|
265
|
+
retriever = DemoRetriever(index, domain_bonus=0.2)
|
|
266
|
+
|
|
267
|
+
# Large bonus - heavily favor same domain
|
|
268
|
+
retriever = DemoRetriever(index, domain_bonus=0.5)
|
|
269
|
+
```
|
|
270
|
+
|
|
271
|
+
**Rule of thumb:**
|
|
272
|
+
- `0.0-0.1`: When task text is very specific and domain doesn't matter much
|
|
273
|
+
- `0.2-0.3`: Good default for most cases
|
|
274
|
+
- `0.4-0.5`: When domain matching is critical (e.g., domain-specific workflows)
|
|
275
|
+
|
|
276
|
+
### Top-K
|
|
277
|
+
|
|
278
|
+
Number of demos to retrieve:
|
|
279
|
+
|
|
280
|
+
```python
|
|
281
|
+
# Single best match
|
|
282
|
+
demos = retriever.retrieve(task, top_k=1)
|
|
283
|
+
|
|
284
|
+
# Few-shot with 3 examples
|
|
285
|
+
demos = retriever.retrieve(task, top_k=3)
|
|
286
|
+
|
|
287
|
+
# Retrieve more for analysis/selection
|
|
288
|
+
demos = retriever.retrieve(task, top_k=10)
|
|
289
|
+
```
|
|
290
|
+
|
|
291
|
+
**Rule of thumb:**
|
|
292
|
+
- `top_k=1`: When prompt length is constrained
|
|
293
|
+
- `top_k=3`: Good default for few-shot learning
|
|
294
|
+
- `top_k=5+`: For ensemble methods or human selection
|
|
295
|
+
|
|
296
|
+
## Performance Tips
|
|
297
|
+
|
|
298
|
+
### 1. Build Once, Retrieve Many
|
|
299
|
+
|
|
300
|
+
```python
|
|
301
|
+
# Good: Build once
|
|
302
|
+
index.build()
|
|
303
|
+
retriever = DemoRetriever(index)
|
|
304
|
+
for task in many_tasks:
|
|
305
|
+
retriever.retrieve(task)
|
|
306
|
+
|
|
307
|
+
# Bad: Build repeatedly
|
|
308
|
+
for task in many_tasks:
|
|
309
|
+
index.build() # Wasteful!
|
|
310
|
+
retriever = DemoRetriever(index)
|
|
311
|
+
retriever.retrieve(task)
|
|
312
|
+
```
|
|
313
|
+
|
|
314
|
+
### 2. Pre-extract Metadata
|
|
315
|
+
|
|
316
|
+
```python
|
|
317
|
+
# Good: Extract once when adding
|
|
318
|
+
index.add(episode, app_name="Chrome", domain="github.com")
|
|
319
|
+
|
|
320
|
+
# Less efficient: Let auto-extraction scan every episode
|
|
321
|
+
index.add(episode) # Will scan steps for app_name and domain
|
|
322
|
+
```
|
|
323
|
+
|
|
324
|
+
### 3. Filter Before Retrieval
|
|
325
|
+
|
|
326
|
+
```python
|
|
327
|
+
# If you have a large index but know the domain, create a filtered index
|
|
328
|
+
web_demos = [d for d in index.get_all_demos() if d.domain]
|
|
329
|
+
web_index = DemoIndex()
|
|
330
|
+
for demo in web_demos:
|
|
331
|
+
web_index.add(demo.episode)
|
|
332
|
+
web_index.build()
|
|
333
|
+
```
|
|
334
|
+
|
|
335
|
+
## Troubleshooting
|
|
336
|
+
|
|
337
|
+
### Issue: All scores are 0.0
|
|
338
|
+
|
|
339
|
+
**Cause:** Only one episode in index, so IDF is undefined.
|
|
340
|
+
|
|
341
|
+
**Solution:** Add more episodes or use a larger demo library.
|
|
342
|
+
|
|
343
|
+
```python
|
|
344
|
+
# Need at least 2-3 episodes for meaningful scores
|
|
345
|
+
assert len(index) >= 3, "Add more demos to the index"
|
|
346
|
+
```
|
|
347
|
+
|
|
348
|
+
### Issue: Domain bonus not applied
|
|
349
|
+
|
|
350
|
+
**Cause:** app_context doesn't match app_name or domain.
|
|
351
|
+
|
|
352
|
+
**Debug:**
|
|
353
|
+
```python
|
|
354
|
+
results = retriever.retrieve_with_scores(task, app_context, top_k=5)
|
|
355
|
+
for r in results:
|
|
356
|
+
print(f"App: {r.demo.app_name}, Domain: {r.demo.domain}, Bonus: {r.domain_bonus}")
|
|
357
|
+
```
|
|
358
|
+
|
|
359
|
+
**Solution:** Verify that the `app_context` string matches the demo's `app_name` or `domain` (the comparison is a case-insensitive "contains" check).
|
|
360
|
+
|
|
361
|
+
### Issue: Poor retrieval quality
|
|
362
|
+
|
|
363
|
+
**Causes:**
|
|
364
|
+
1. Task descriptions too generic
|
|
365
|
+
2. Demo library too small
|
|
366
|
+
3. TF-IDF limitations
|
|
367
|
+
|
|
368
|
+
**Solutions:**
|
|
369
|
+
1. Use more specific task descriptions
|
|
370
|
+
2. Add more diverse demos to index
|
|
371
|
+
3. Upgrade to sentence-transformers (see README.md § Future Improvements)
|
|
372
|
+
|
|
373
|
+
## Testing
|
|
374
|
+
|
|
375
|
+
Run unit tests:
|
|
376
|
+
```bash
|
|
377
|
+
uv run pytest tests/test_retrieval.py -v
|
|
378
|
+
```
|
|
379
|
+
|
|
380
|
+
Run integration tests:
|
|
381
|
+
```bash
|
|
382
|
+
uv run python test_retrieval.py
|
|
383
|
+
```
|
|
384
|
+
|
|
385
|
+
## Next Steps
|
|
386
|
+
|
|
387
|
+
1. **Integrate with training**: Use retrieval in data augmentation
|
|
388
|
+
2. **Experiment with prompting**: Test different demo counts and formats
|
|
389
|
+
3. **Upgrade embeddings**: Try sentence-transformers for better similarity
|
|
390
|
+
4. **Add filtering**: Support domain/app filtering before similarity scoring
|
|
391
|
+
5. **Evaluate impact**: Measure action accuracy with/without retrieval
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Demo retrieval module for finding similar demonstrations.
|
|
2
|
+
|
|
3
|
+
This module provides functionality for indexing and retrieving demonstrations
|
|
4
|
+
based on semantic similarity of task descriptions.
|
|
5
|
+
|
|
6
|
+
Main Components:
|
|
7
|
+
- DemoRetriever: Main class for indexing and retrieving demos
|
|
8
|
+
- Embedders: TFIDFEmbedder, SentenceTransformerEmbedder, OpenAIEmbedder
|
|
9
|
+
- DemoIndex: Legacy index class (use DemoRetriever instead)
|
|
10
|
+
|
|
11
|
+
Quick Start:
|
|
12
|
+
from openadapt_ml.retrieval import DemoRetriever
|
|
13
|
+
from openadapt_ml.schema import Episode
|
|
14
|
+
|
|
15
|
+
# Create retriever (TF-IDF is default, no external dependencies)
|
|
16
|
+
retriever = DemoRetriever()
|
|
17
|
+
|
|
18
|
+
# Or use sentence-transformers for better semantic matching
|
|
19
|
+
retriever = DemoRetriever(embedding_method="sentence_transformers")
|
|
20
|
+
|
|
21
|
+
# Add demos
|
|
22
|
+
retriever.add_demo(episode1)
|
|
23
|
+
retriever.add_demo(episode2, app_name="Chrome", domain="github.com")
|
|
24
|
+
|
|
25
|
+
# Build index (required before retrieval)
|
|
26
|
+
retriever.build_index()
|
|
27
|
+
|
|
28
|
+
# Retrieve similar demos
|
|
29
|
+
results = retriever.retrieve("Turn off Night Shift", top_k=3)
|
|
30
|
+
|
|
31
|
+
# Format for inclusion in a prompt
|
|
32
|
+
prompt_text = retriever.format_for_prompt(results)
|
|
33
|
+
|
|
34
|
+
Embedding Methods:
|
|
35
|
+
# TF-IDF (default, no dependencies)
|
|
36
|
+
retriever = DemoRetriever(embedding_method="tfidf")
|
|
37
|
+
|
|
38
|
+
# Sentence Transformers (recommended, requires: pip install sentence-transformers)
|
|
39
|
+
retriever = DemoRetriever(
|
|
40
|
+
embedding_method="sentence_transformers",
|
|
41
|
+
embedding_model="all-MiniLM-L6-v2", # Fast, 22MB
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# OpenAI (requires: pip install openai, OPENAI_API_KEY env var)
|
|
45
|
+
retriever = DemoRetriever(
|
|
46
|
+
embedding_method="openai",
|
|
47
|
+
embedding_model="text-embedding-3-small",
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
See Also:
|
|
51
|
+
- docs/demo_retrieval_design.md - Full design document
|
|
52
|
+
- openadapt_ml/experiments/demo_prompt/ - Demo-conditioned prompting
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
# Main retrieval class (recommended)
|
|
56
|
+
from openadapt_ml.retrieval.demo_retriever import (
|
|
57
|
+
DemoRetriever,
|
|
58
|
+
DemoMetadata,
|
|
59
|
+
RetrievalResult,
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
# Embedders
|
|
63
|
+
from openadapt_ml.retrieval.embeddings import (
|
|
64
|
+
BaseEmbedder,
|
|
65
|
+
TFIDFEmbedder,
|
|
66
|
+
TextEmbedder, # Alias for TFIDFEmbedder (backward compat)
|
|
67
|
+
SentenceTransformerEmbedder,
|
|
68
|
+
OpenAIEmbedder,
|
|
69
|
+
create_embedder,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
# Legacy classes (for backward compatibility)
|
|
73
|
+
from openadapt_ml.retrieval.index import DemoIndex
|
|
74
|
+
from openadapt_ml.retrieval.retriever import DemoRetriever as LegacyDemoRetriever
|
|
75
|
+
|
|
76
|
+
__all__ = [
|
|
77
|
+
# Main classes
|
|
78
|
+
"DemoRetriever",
|
|
79
|
+
"DemoMetadata",
|
|
80
|
+
"RetrievalResult",
|
|
81
|
+
# Embedders
|
|
82
|
+
"BaseEmbedder",
|
|
83
|
+
"TFIDFEmbedder",
|
|
84
|
+
"TextEmbedder",
|
|
85
|
+
"SentenceTransformerEmbedder",
|
|
86
|
+
"OpenAIEmbedder",
|
|
87
|
+
"create_embedder",
|
|
88
|
+
# Legacy (backward compat)
|
|
89
|
+
"DemoIndex",
|
|
90
|
+
"LegacyDemoRetriever",
|
|
91
|
+
]
|