openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,226 @@
1
+ # Demo Retrieval Module
2
+
3
+ This module provides functionality to index and retrieve similar demonstrations for few-shot prompting in GUI automation.
4
+
5
+ ## Overview
6
+
7
+ The retrieval module consists of three main components:
8
+
9
+ 1. **TextEmbedder** (`embeddings.py`) - Simple TF-IDF based text embeddings
10
+ 2. **DemoIndex** (`index.py`) - Stores episodes with metadata and embeddings
11
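+ 3. **DemoRetriever** (`retriever.py`) - Retrieves top-K similar demos (exported from the package as `LegacyDemoRetriever`; the package-level `DemoRetriever` name refers to the newer class in `demo_retriever.py`)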
12
+
13
+ ## Quick Start
14
+
15
+ ```python
16
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
17
+ from openadapt_ml.schema import Episode
18
+
19
+ # 1. Create index and add episodes
20
+ index = DemoIndex()
21
+ index.add_many(episodes) # episodes is a list of Episode objects
22
+ index.build() # Compute embeddings
23
+
24
+ # 2. Create retriever
25
+ retriever = DemoRetriever(index, domain_bonus=0.2)
26
+
27
+ # 3. Retrieve similar demos
28
+ task = "Turn off Night Shift on macOS"
29
+ app_context = "System Settings"
30
+ similar_demos = retriever.retrieve(task, app_context, top_k=3)
31
+
32
+ # 4. Use with prompt formatting
33
+ from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
34
+ formatted_demo = format_episode_as_demo(similar_demos[0])
35
+ ```
36
+
37
+ ## Features
38
+
39
+ ### Text Similarity
40
+ - Uses TF-IDF with cosine similarity for v1
41
+ - No external ML libraries required
42
+ - Can be upgraded to sentence-transformers later
43
+
44
+ ### Domain Matching
45
+ - Auto-extracts app name from observations
46
+ - Auto-extracts domain from URLs
47
+ - Applies bonus score for domain/app matches
48
+
49
+ ### Metadata Support
50
+ - Stores arbitrary metadata with each demo
51
+ - Tracks app name, domain, and custom fields
52
+ - Efficient filtering by app/domain
53
+
54
+ ## API Reference
55
+
56
+ ### DemoIndex
57
+
58
+ ```python
59
+ index = DemoIndex()
60
+
61
+ # Add episodes
62
+ index.add(episode, app_name="Chrome", domain="github.com")
63
+ index.add_many(episodes)
64
+
65
+ # Build index (required before retrieval)
66
+ index.build()
67
+
68
+ # Query index
69
+ index.get_apps() # List of unique app names
70
+ index.get_domains() # List of unique domains
71
+ len(index) # Number of demos
72
+ index.is_fitted() # Check if built
73
+ ```
74
+
75
+ ### DemoRetriever
76
+
77
+ ```python
78
+ retriever = DemoRetriever(
79
+ index,
80
+ domain_bonus=0.2, # Bonus score for domain match
81
+ )
82
+
83
+ # Retrieve episodes
84
+ episodes = retriever.retrieve(
85
+ task="Description of task",
86
+ app_context="Chrome", # Optional
87
+ top_k=3,
88
+ )
89
+
90
+ # Retrieve with scores (for debugging)
91
+ results = retriever.retrieve_with_scores(task, app_context, top_k=3)
92
+ for result in results:
93
+ print(f"Score: {result.score}")
94
+ print(f" Text similarity: {result.text_score}")
95
+ print(f" Domain bonus: {result.domain_bonus}")
96
+ print(f" Goal: {result.demo.episode.goal}")
97
+ ```
98
+
99
+ ### TextEmbedder
100
+
101
+ ```python
102
+ from openadapt_ml.retrieval.embeddings import TextEmbedder
103
+
104
+ embedder = TextEmbedder()
105
+
106
+ # Fit on corpus
107
+ documents = ["task 1", "task 2", "task 3"]
108
+ embedder.fit(documents)
109
+
110
+ # Embed text
111
+ vec1 = embedder.embed("new task")
112
+ vec2 = embedder.embed("another task")
113
+
114
+ # Compute similarity
115
+ similarity = embedder.cosine_similarity(vec1, vec2)
116
+ ```
117
+
118
+ ## Scoring
119
+
120
+ The retrieval score combines text similarity and domain matching:
121
+
122
+ ```
123
+ total_score = text_similarity + domain_bonus
124
+ ```
125
+
126
+ - **Text similarity**: TF-IDF cosine similarity between task descriptions (0-1)
127
+ - **Domain bonus**: Fixed bonus if app_context matches demo's app or domain (default: 0.2)
128
+
129
+ ### Example Scores
130
+
131
+ ```
132
+ Query: "Search GitHub for ML papers"
133
+ App context: "github.com"
134
+
135
+ Demo 1: "Search for machine learning papers on GitHub"
136
+ - Text similarity: 0.678
137
+ - Domain bonus: 0.200 (github.com match)
138
+ - Total: 0.878 ⭐ Best match
139
+
140
+ Demo 2: "Create a new GitHub repository"
141
+ - Text similarity: 0.111
142
+ - Domain bonus: 0.200 (github.com match)
143
+ - Total: 0.311
144
+
145
+ Demo 3: "Search for Python documentation on Google"
146
+ - Text similarity: 0.232
147
+ - Domain bonus: 0.000 (no match)
148
+ - Total: 0.232
149
+ ```
150
+
151
+ ## Loading Real Episodes
152
+
153
+ ```python
154
+ from openadapt_ml.ingest.capture import load_capture
155
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
156
+
157
+ # Load from capture directory
158
+ capture_path = "/path/to/capture"
159
+ episodes = load_capture(capture_path)
160
+
161
+ # Build index
162
+ index = DemoIndex()
163
+ index.add_many(episodes)
164
+ index.build()
165
+
166
+ # Retrieve
167
+ retriever = DemoRetriever(index)
168
+ demos = retriever.retrieve("New task description", top_k=3)
169
+ ```
170
+
171
+ ## Integration with Prompting
172
+
173
+ ```python
174
+ from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
175
+
176
+ # Retrieve demo
177
+ demos = retriever.retrieve(task, app_context, top_k=1)
178
+
179
+ # Format for prompt
180
+ demo_text = format_episode_as_demo(demos[0], max_steps=10)
181
+
182
+ # Inject into prompt
183
+ prompt = f"""Here is a demonstration of a similar task:
184
+
185
+ {demo_text}
186
+
187
+ Now perform this task:
188
+ Task: {task}
189
+ """
190
+ ```
191
+
192
+ ## Examples
193
+
194
+ See `examples/demo_retrieval_example.py` for a complete working example.
195
+
196
+ Run it with:
197
+ ```bash
198
+ uv run python examples/demo_retrieval_example.py
199
+ ```
200
+
201
+ ## Future Improvements
202
+
203
+ ### v2: Better Embeddings
204
+ Replace TF-IDF with sentence-transformers:
205
+ ```python
206
+ from sentence_transformers import SentenceTransformer
207
+ model = SentenceTransformer('all-MiniLM-L6-v2')
208
+ ```
209
+
210
+ ### v3: Semantic Search
211
+ - Use FAISS or Qdrant for large-scale retrieval
212
+ - Add metadata filtering before similarity search
213
+ - Support multi-modal embeddings (text + screenshots)
214
+
215
+ ### v4: Learning to Rank
216
+ - Train a ranking model using success/failure data
217
+ - Incorporate user feedback
218
+ - Personalized retrieval based on agent history
219
+
220
+ ## Design Principles
221
+
222
+ 1. **Start simple** - v1 uses no ML models, just text matching
223
+ 2. **Functional over optimal** - Works out of the box, can be improved later
224
+ 3. **Clear API** - Simple retrieve() interface, complex details hidden
225
+ 4. **Composable** - Each component can be used independently
226
+ 5. **Schema-first** - Works with Episode schema, no custom data structures
@@ -0,0 +1,391 @@
1
+ # Demo Retrieval - Usage Guide
2
+
3
+ ## Quick Reference
4
+
5
+ ```python
6
+ # 1. Build index
7
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
8
+ index = DemoIndex()
9
+ index.add_many(episodes)
10
+ index.build()
11
+
12
+ # 2. Retrieve
13
+ retriever = DemoRetriever(index)
14
+ similar_demos = retriever.retrieve("Turn off Night Shift", top_k=3)
15
+ ```
16
+
17
+ ## Complete Examples
18
+
19
+ ### Example 1: Basic Usage with Synthetic Data
20
+
21
+ ```python
22
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
23
+ from openadapt_ml.schema import Action, ActionType, Episode, Observation, Step
24
+
25
+ # Create test episodes
26
+ def create_episode(instruction, app_name=None):
27
+ obs = Observation(app_name=app_name)
28
+ action = Action(type=ActionType.CLICK, normalized_coordinates=(0.5, 0.5))
29
+ step = Step(step_index=0, observation=obs, action=action)
30
+ return Episode(episode_id=f"ep_{instruction[:10]}", instruction=instruction, steps=[step])
31
+
32
+ episodes = [
33
+ create_episode("Turn off Night Shift", app_name="System Settings"),
34
+ create_episode("Search GitHub", app_name="Chrome"),
35
+ create_episode("Open calculator", app_name="Calculator"),
36
+ ]
37
+
38
+ # Build index
39
+ index = DemoIndex()
40
+ index.add_many(episodes)
41
+ index.build()
42
+
43
+ # Retrieve
44
+ retriever = DemoRetriever(index, domain_bonus=0.2)
45
+ results = retriever.retrieve("Disable Night Shift", top_k=2)
46
+
47
+ print(f"Found {len(results)} similar demos:")
48
+ for ep in results:
49
+ print(f"- {ep.goal}")
50
+ ```
51
+
52
+ ### Example 2: Loading from Capture
53
+
54
+ ```python
55
+ from openadapt_ml.ingest.capture import capture_to_episode
56
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
57
+
58
+ # Load multiple captures
59
+ capture_paths = [
60
+ "/path/to/capture1",
61
+ "/path/to/capture2",
62
+ "/path/to/capture3",
63
+ ]
64
+
65
+ episodes = [
66
+ capture_to_episode(path, include_moves=False)
67
+ for path in capture_paths
68
+ ]
69
+
70
+ # Build index
71
+ index = DemoIndex()
72
+ index.add_many(episodes)
73
+ index.build()
74
+
75
+ # Retrieve for new task
76
+ retriever = DemoRetriever(index)
77
+ task = "Turn on dark mode"
78
+ app = "System Settings"
79
+ demos = retriever.retrieve(task, app_context=app, top_k=3)
80
+ ```
81
+
82
+ ### Example 3: Integration with Prompting
83
+
84
+ ```python
85
+ from openadapt_ml.experiments.demo_prompt.format_demo import format_episode_as_demo
86
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
87
+
88
+ # Build index (assume episodes already loaded)
89
+ index = DemoIndex()
90
+ index.add_many(episodes)
91
+ index.build()
92
+
93
+ # Retrieve for new task
94
+ retriever = DemoRetriever(index)
95
+ task = "Turn off Night Shift"
96
+ demos = retriever.retrieve(task, top_k=1)
97
+
98
+ # Format for prompt
99
+ if demos:
100
+ demo_text = format_episode_as_demo(demos[0], max_steps=10)
101
+
102
+ # Create few-shot prompt
103
+ prompt = f"""You are a GUI automation agent.
104
+
105
+ DEMONSTRATION OF SIMILAR TASK:
106
+ {demo_text}
107
+
108
+ NEW TASK:
109
+ {task}
110
+
111
+ What is your first action?"""
112
+
113
+ print(prompt)
114
+ ```
115
+
116
+ ### Example 4: Retrieval with Scores (Debugging)
117
+
118
+ ```python
119
+ from openadapt_ml.retrieval import LegacyDemoRetriever as DemoRetriever
120
+
121
+ retriever = DemoRetriever(index, domain_bonus=0.3)
122
+
123
+ # Retrieve with scores for analysis
124
+ results = retriever.retrieve_with_scores(
125
+ task="Search for Python docs",
126
+ app_context="github.com",
127
+ top_k=5,
128
+ )
129
+
130
+ # Analyze scores
131
+ for i, result in enumerate(results, 1):
132
+ print(f"\n{i}. {result.demo.episode.goal}")
133
+ print(f" Total score: {result.score:.3f}")
134
+ print(f" Text similarity: {result.text_score:.3f}")
135
+ print(f" Domain bonus: {result.domain_bonus:.3f}")
136
+
137
+ if result.demo.app_name:
138
+ print(f" App: {result.demo.app_name}")
139
+ if result.demo.domain:
140
+ print(f" Domain: {result.demo.domain}")
141
+ ```
142
+
143
+ ### Example 5: Custom Metadata
144
+
145
+ ```python
146
+ from openadapt_ml.retrieval import DemoIndex, LegacyDemoRetriever as DemoRetriever
147
+
148
+ index = DemoIndex()
149
+
150
+ # Add episodes with custom metadata
151
+ for episode in episodes:
152
+ metadata = {
153
+ "difficulty": "easy",
154
+ "success_rate": 0.95,
155
+ "duration_seconds": 30,
156
+ "tags": ["settings", "macOS"],
157
+ }
158
+
159
+ index.add(
160
+ episode,
161
+ app_name="System Settings",
162
+ domain=None,
163
+ metadata=metadata,
164
+ )
165
+
166
+ index.build()
167
+
168
+ # Access metadata after retrieval
169
+ retriever = DemoRetriever(index)
170
+ results = retriever.retrieve_with_scores("Turn off Night Shift", top_k=1)
171
+
172
+ if results:
173
+ demo = results[0].demo
174
+ print(f"Difficulty: {demo.metadata.get('difficulty')}")
175
+ print(f"Tags: {demo.metadata.get('tags')}")
176
+ ```
177
+
178
+ ## CLI Examples
179
+
180
+ Run the provided example scripts:
181
+
182
+ ```bash
183
+ # Basic demo with synthetic data
184
+ uv run python examples/demo_retrieval_example.py
185
+
186
+ # Test with real capture
187
+ uv run python examples/retrieval_with_capture.py /path/to/capture
188
+
189
+ # With custom task
190
+ uv run python examples/retrieval_with_capture.py /path/to/capture "Turn off dark mode"
191
+ ```
192
+
193
+ ## Common Patterns
194
+
195
+ ### Pattern 1: Multi-Domain Index
196
+
197
+ ```python
198
+ # Build index with episodes from multiple domains
199
+ web_episodes = load_web_captures()
200
+ desktop_episodes = load_desktop_captures()
201
+
202
+ index = DemoIndex()
203
+ index.add_many(web_episodes)
204
+ index.add_many(desktop_episodes)
205
+ index.build()
206
+
207
+ # Retrieve with a domain preference via app_context (a score bonus, not a hard filter)
208
+ retriever = DemoRetriever(index, domain_bonus=0.5)
209
+
210
+ # This will prefer github.com demos
211
+ web_demos = retriever.retrieve("Search code", app_context="github.com", top_k=3)
212
+
213
+ # This will prefer System Settings demos
214
+ desktop_demos = retriever.retrieve("Change settings", app_context="System Settings", top_k=3)
215
+ ```
216
+
217
+ ### Pattern 2: Incremental Index Updates
218
+
219
+ ```python
220
+ # Build initial index
221
+ index = DemoIndex()
222
+ index.add_many(initial_episodes)
223
+ index.build()
224
+
225
+ # Add new episodes
226
+ index.add(new_episode)
227
+
228
+ # Rebuild required after adding
229
+ index.build()
230
+
231
+ # Now retriever will use updated index
232
+ retriever = DemoRetriever(index)
233
+ ```
234
+
235
+ ### Pattern 3: Batch Retrieval
236
+
237
+ ```python
238
+ # Retrieve for multiple tasks
239
+ tasks = [
240
+ "Turn off Night Shift",
241
+ "Enable dark mode",
242
+ "Adjust brightness",
243
+ ]
244
+
245
+ retriever = DemoRetriever(index)
246
+
247
+ for task in tasks:
248
+ demos = retriever.retrieve(task, top_k=3)
249
+ print(f"\nTask: {task}")
250
+ for demo in demos:
251
+ print(f" - {demo.goal}")
252
+ ```
253
+
254
+ ## Tuning Parameters
255
+
256
+ ### Domain Bonus
257
+
258
+ Controls how much to favor domain/app matches:
259
+
260
+ ```python
261
+ # No domain bonus - pure text similarity
262
+ retriever = DemoRetriever(index, domain_bonus=0.0)
263
+
264
+ # Small bonus (default)
265
+ retriever = DemoRetriever(index, domain_bonus=0.2)
266
+
267
+ # Large bonus - heavily favor same domain
268
+ retriever = DemoRetriever(index, domain_bonus=0.5)
269
+ ```
270
+
271
+ **Rule of thumb:**
272
+ - `0.0-0.1`: When task text is very specific and domain doesn't matter much
273
+ - `0.2-0.3`: Good default for most cases
274
+ - `0.4-0.5`: When domain matching is critical (e.g., domain-specific workflows)
275
+
276
+ ### Top-K
277
+
278
+ Number of demos to retrieve:
279
+
280
+ ```python
281
+ # Single best match
282
+ demos = retriever.retrieve(task, top_k=1)
283
+
284
+ # Few-shot with 3 examples
285
+ demos = retriever.retrieve(task, top_k=3)
286
+
287
+ # Retrieve more for analysis/selection
288
+ demos = retriever.retrieve(task, top_k=10)
289
+ ```
290
+
291
+ **Rule of thumb:**
292
+ - `top_k=1`: When prompt length is constrained
293
+ - `top_k=3`: Good default for few-shot learning
294
+ - `top_k=5+`: For ensemble methods or human selection
295
+
296
+ ## Performance Tips
297
+
298
+ ### 1. Build Once, Retrieve Many
299
+
300
+ ```python
301
+ # Good: Build once
302
+ index.build()
303
+ retriever = DemoRetriever(index)
304
+ for task in many_tasks:
305
+ retriever.retrieve(task)
306
+
307
+ # Bad: Build repeatedly
308
+ for task in many_tasks:
309
+ index.build() # Wasteful!
310
+ retriever = DemoRetriever(index)
311
+ retriever.retrieve(task)
312
+ ```
313
+
314
+ ### 2. Pre-extract Metadata
315
+
316
+ ```python
317
+ # Good: Extract once when adding
318
+ index.add(episode, app_name="Chrome", domain="github.com")
319
+
320
+ # Less efficient: Let auto-extraction scan every episode
321
+ index.add(episode) # Will scan steps for app_name and domain
322
+ ```
323
+
324
+ ### 3. Filter Before Retrieval
325
+
326
+ ```python
327
+ # If you have a large index but know the domain, create a filtered index
328
+ web_demos = [d for d in index.get_all_demos() if d.domain]
329
+ web_index = DemoIndex()
330
+ for demo in web_demos:
331
+ web_index.add(demo.episode)
332
+ web_index.build()
333
+ ```
334
+
335
+ ## Troubleshooting
336
+
337
+ ### Issue: All scores are 0.0
338
+
339
+ **Cause:** Only one episode in index, so IDF is undefined.
340
+
341
+ **Solution:** Add more episodes or use a larger demo library.
342
+
343
+ ```python
344
+ # Need at least 3 episodes for meaningful TF-IDF scores
345
+ assert len(index) >= 3, "Add more demos to the index"
346
+ ```
347
+
348
+ ### Issue: Domain bonus not applied
349
+
350
+ **Cause:** app_context doesn't match app_name or domain.
351
+
352
+ **Debug:**
353
+ ```python
354
+ results = retriever.retrieve_with_scores(task, app_context, top_k=5)
355
+ for r in results:
356
+ print(f"App: {r.demo.app_name}, Domain: {r.demo.domain}, Bonus: {r.domain_bonus}")
357
+ ```
358
+
359
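+ **Solution:** Matching is a case-insensitive substring ("contains") check rather than an exact match; make sure `app_context` and the demo's `app_name` or `domain` actually overlap (compare the values printed by the debug snippet above).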
360
+
361
+ ### Issue: Poor retrieval quality
362
+
363
+ **Causes:**
364
+ 1. Task descriptions too generic
365
+ 2. Demo library too small
366
+ 3. TF-IDF limitations
367
+
368
+ **Solutions:**
369
+ 1. Use more specific task descriptions
370
+ 2. Add more diverse demos to index
371
+ 3. Upgrade to sentence-transformers (see README.md § Future Improvements)
372
+
373
+ ## Testing
374
+
375
+ Run unit tests:
376
+ ```bash
377
+ uv run pytest tests/test_retrieval.py -v
378
+ ```
379
+
380
+ Run integration tests:
381
+ ```bash
382
+ uv run python test_retrieval.py
383
+ ```
384
+
385
+ ## Next Steps
386
+
387
+ 1. **Integrate with training**: Use retrieval in data augmentation
388
+ 2. **Experiment with prompting**: Test different demo counts and formats
389
+ 3. **Upgrade embeddings**: Try sentence-transformers for better similarity
390
+ 4. **Add filtering**: Support domain/app filtering before similarity scoring
391
+ 5. **Evaluate impact**: Measure action accuracy with/without retrieval
@@ -0,0 +1,91 @@
1
+ """Demo retrieval module for finding similar demonstrations.
2
+
3
+ This module provides functionality for indexing and retrieving demonstrations
4
+ based on semantic similarity of task descriptions.
5
+
6
+ Main Components:
7
+ - DemoRetriever: Main class for indexing and retrieving demos
8
+ - Embedders: TFIDFEmbedder, SentenceTransformerEmbedder, OpenAIEmbedder
9
+ - DemoIndex: Legacy index class (use DemoRetriever instead)
10
+
11
+ Quick Start:
12
+ from openadapt_ml.retrieval import DemoRetriever
13
+ from openadapt_ml.schema import Episode
14
+
15
+ # Create retriever (TF-IDF is default, no external dependencies)
16
+ retriever = DemoRetriever()
17
+
18
+ # Or use sentence-transformers for better semantic matching
19
+ retriever = DemoRetriever(embedding_method="sentence_transformers")
20
+
21
+ # Add demos
22
+ retriever.add_demo(episode1)
23
+ retriever.add_demo(episode2, app_name="Chrome", domain="github.com")
24
+
25
+ # Build index (required before retrieval)
26
+ retriever.build_index()
27
+
28
+ # Retrieve similar demos
29
+ results = retriever.retrieve("Turn off Night Shift", top_k=3)
30
+
31
+ # Format for inclusion in a prompt
32
+ prompt_text = retriever.format_for_prompt(results)
33
+
34
+ Embedding Methods:
35
+ # TF-IDF (default, no dependencies)
36
+ retriever = DemoRetriever(embedding_method="tfidf")
37
+
38
+ # Sentence Transformers (recommended, requires: pip install sentence-transformers)
39
+ retriever = DemoRetriever(
40
+ embedding_method="sentence_transformers",
41
+ embedding_model="all-MiniLM-L6-v2", # Fast, 22MB
42
+ )
43
+
44
+ # OpenAI (requires: pip install openai, OPENAI_API_KEY env var)
45
+ retriever = DemoRetriever(
46
+ embedding_method="openai",
47
+ embedding_model="text-embedding-3-small",
48
+ )
49
+
50
+ See Also:
51
+ - docs/demo_retrieval_design.md - Full design document
52
+ - openadapt_ml/experiments/demo_prompt/ - Demo-conditioned prompting
53
+ """
54
+
55
+ # Main retrieval class (recommended)
56
+ from openadapt_ml.retrieval.demo_retriever import (
57
+ DemoRetriever,
58
+ DemoMetadata,
59
+ RetrievalResult,
60
+ )
61
+
62
+ # Embedders
63
+ from openadapt_ml.retrieval.embeddings import (
64
+ BaseEmbedder,
65
+ TFIDFEmbedder,
66
+ TextEmbedder, # Alias for TFIDFEmbedder (backward compat)
67
+ SentenceTransformerEmbedder,
68
+ OpenAIEmbedder,
69
+ create_embedder,
70
+ )
71
+
72
+ # Legacy classes (for backward compatibility)
73
+ from openadapt_ml.retrieval.index import DemoIndex
74
+ from openadapt_ml.retrieval.retriever import DemoRetriever as LegacyDemoRetriever
75
+
76
+ __all__ = [
77
+ # Main classes
78
+ "DemoRetriever",
79
+ "DemoMetadata",
80
+ "RetrievalResult",
81
+ # Embedders
82
+ "BaseEmbedder",
83
+ "TFIDFEmbedder",
84
+ "TextEmbedder",
85
+ "SentenceTransformerEmbedder",
86
+ "OpenAIEmbedder",
87
+ "create_embedder",
88
+ # Legacy (backward compat)
89
+ "DemoIndex",
90
+ "LegacyDemoRetriever",
91
+ ]