openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63) hide show
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,817 @@
1
+ """Main Demo Retriever class for finding similar demonstrations.
2
+
3
+ This module provides the DemoRetriever class that indexes demos by their task
4
+ descriptions using embeddings and retrieves the most similar demo(s) from a library.
5
+
6
+ Key features:
7
+ - Supports both local embeddings (sentence-transformers) and API embeddings (OpenAI)
8
+ - Uses FAISS or simple cosine similarity for vector search
9
+ - Caches embeddings to avoid recomputing
10
+ - Returns top-k most similar demos with formatting for prompts
11
+
12
+ Example usage:
13
+ from openadapt_ml.retrieval import DemoRetriever
14
+ from openadapt_ml.schema import Episode
15
+
16
+ # Create retriever with local embeddings
17
+ retriever = DemoRetriever(embedding_method="sentence_transformers")
18
+
19
+ # Add demos
20
+ retriever.add_demo(episode1)
21
+ retriever.add_demo(episode2, app_name="Chrome", domain="github.com")
22
+
23
+ # Build index
24
+ retriever.build_index()
25
+
26
+ # Retrieve similar demos
27
+ results = retriever.retrieve("Turn off Night Shift", top_k=3)
28
+
29
+ # Format for prompt
30
+ prompt_text = retriever.format_for_prompt(results)
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ import hashlib
36
+ import json
37
+ import logging
38
+ from dataclasses import dataclass, field
39
+ from pathlib import Path
40
+ from typing import Any, Callable, List, Optional, Union
41
+
42
+ from openadapt_ml.schema import Episode
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
@dataclass
class DemoMetadata:
    """Record describing one demonstration held by the retriever.

    Bundles the source episode with the descriptive fields that are used
    when demos are indexed, filtered and scored. Only ``demo_id``,
    ``episode`` and ``goal`` are required; everything else has a default.
    """

    # --- Required identity fields -------------------------------------
    demo_id: str  # Unique identifier for the demo.
    episode: Episode  # The full Episode object.
    goal: str  # Task description / instruction.

    # --- Optional context used for bonuses and filtering --------------
    app_name: Optional[str] = None  # Application name, e.g. "System Settings".
    domain: Optional[str] = None  # Domain, e.g. "github.com".
    platform: Optional[str] = None  # Platform: "macos", "windows" or "web".

    # --- Features derived from the episode ----------------------------
    action_types: List[str] = field(default_factory=list)  # Action types used in the demo.
    key_elements: List[str] = field(default_factory=list)  # Important UI elements touched.
    step_count: int = 0  # Number of steps in the demo.

    # --- Bookkeeping --------------------------------------------------
    tags: List[str] = field(default_factory=list)  # User-provided tags.
    file_path: Optional[str] = None  # Source file path, if loaded from disk.
    embedding: Optional[Any] = None  # numpy array when computed
    metadata: dict[str, Any] = field(default_factory=dict)  # Additional custom metadata.
82
+
83
+
84
@dataclass
class RetrievalResult:
    """One scored hit returned by the retriever.

    Keeps the matched demo together with the pieces that make up its
    final score so callers can inspect how a ranking was produced.
    """

    demo: DemoMetadata  # The demo metadata that matched.
    score: float  # Combined retrieval score (higher is better).
    text_score: float  # Text/embedding similarity component.
    domain_bonus: float  # Context-match bonus that was added to the text score.
    rank: int = 0  # 1-indexed position in the result list (0 until assigned).
101
+
102
+
103
+ class DemoRetriever:
104
+ """Retrieves similar demonstrations from a library using embeddings.
105
+
106
+ Supports multiple embedding backends:
107
+ - "tfidf": Simple TF-IDF (no external dependencies, baseline)
108
+ - "sentence_transformers": Local embedding model (recommended)
109
+ - "openai": OpenAI text-embedding API
110
+
111
+ The retriever uses FAISS for efficient similarity search when available,
112
+ falling back to brute-force cosine similarity for small indices.
113
+
114
+ Example:
115
+ >>> retriever = DemoRetriever(embedding_method="sentence_transformers")
116
+ >>> retriever.add_demo(episode, app_name="Chrome")
117
+ >>> retriever.build_index()
118
+ >>> results = retriever.retrieve("Search on GitHub", top_k=3)
119
+ >>> print(results[0].demo.goal)
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ embedding_method: str = "tfidf",
125
+ embedding_model: str = "all-MiniLM-L6-v2",
126
+ cache_dir: Optional[Path] = None,
127
+ domain_bonus: float = 0.2,
128
+ app_bonus: float = 0.15,
129
+ use_faiss: bool = True,
130
+ ) -> None:
131
+ """Initialize the DemoRetriever.
132
+
133
+ Args:
134
+ embedding_method: Embedding backend ("tfidf", "sentence_transformers", "openai").
135
+ embedding_model: Model name for sentence_transformers or OpenAI.
136
+ - For sentence_transformers: "all-MiniLM-L6-v2", "all-mpnet-base-v2", etc.
137
+ - For OpenAI: "text-embedding-3-small", "text-embedding-3-large", etc.
138
+ cache_dir: Directory for caching embeddings. If None, uses ~/.cache/openadapt_ml/embeddings.
139
+ domain_bonus: Bonus score for matching domain (default: 0.2).
140
+ app_bonus: Bonus score for matching app name (default: 0.15).
141
+ use_faiss: Whether to use FAISS for vector search (default: True).
142
+ """
143
+ self.embedding_method = embedding_method
144
+ self.embedding_model = embedding_model
145
+ self.domain_bonus = domain_bonus
146
+ self.app_bonus = app_bonus
147
+ self.use_faiss = use_faiss
148
+
149
+ # Set up cache directory
150
+ if cache_dir is None:
151
+ cache_dir = Path.home() / ".cache" / "openadapt_ml" / "embeddings"
152
+ self.cache_dir = Path(cache_dir)
153
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
154
+
155
+ # Internal state
156
+ self._demos: List[DemoMetadata] = []
157
+ self._embedder: Optional[Any] = None
158
+ self._faiss_index: Optional[Any] = None
159
+ self._is_indexed = False
160
+ self._embeddings_matrix: Optional[Any] = None
161
+
162
+ # =========================================================================
163
+ # Demo Management
164
+ # =========================================================================
165
+
166
def add_demo(
    self,
    episode: Episode,
    demo_id: Optional[str] = None,
    app_name: Optional[str] = None,
    domain: Optional[str] = None,
    platform: Optional[str] = None,
    tags: Optional[List[str]] = None,
    file_path: Optional[str] = None,
    metadata: Optional[dict[str, Any]] = None,
) -> DemoMetadata:
    """Add a demonstration episode to the library.

    Any of ``demo_id``, ``app_name``, ``domain`` and ``platform`` that is
    left as None is derived from the episode. Adding a demo marks the
    index as stale, so build_index() must be called again before
    retrieving.

    Args:
        episode: The Episode to add.
        demo_id: Unique ID (defaults to ``episode.episode_id``).
        app_name: Application name (auto-extracted if not provided).
        domain: Domain (auto-extracted from URLs if not provided).
        platform: Platform ("macos", "windows", "web"). Auto-detected if not provided.
        tags: User-provided tags for categorization.
        file_path: Path to the source file.
        metadata: Additional custom metadata.

    Returns:
        DemoMetadata object for the added demo.
    """
    if demo_id is None:
        demo_id = episode.episode_id
    if app_name is None:
        app_name = self._extract_app_name(episode)
    if domain is None:
        domain = self._extract_domain(episode)
    if platform is None:
        platform = self._detect_platform(episode, app_name, domain)

    # De-duplicate action types while keeping first-seen order. A plain
    # set() would make the order depend on string hashing, so the same
    # episode could produce differently ordered metadata between runs.
    unique_types = {}
    for step in episode.steps:
        if not step.action:
            continue
        raw_type = step.action.type
        # Enum members expose .value; anything else is stringified.
        label = raw_type.value if hasattr(raw_type, "value") else str(raw_type)
        unique_types.setdefault(label, None)
    action_types = list(unique_types)

    key_elements = self._extract_key_elements(episode)

    demo_meta = DemoMetadata(
        demo_id=demo_id,
        episode=episode,
        goal=episode.instruction,
        app_name=app_name,
        domain=domain,
        platform=platform,
        action_types=action_types,
        key_elements=key_elements,
        step_count=len(episode.steps),
        tags=tags or [],
        file_path=file_path,
        metadata=metadata or {},
    )

    self._demos.append(demo_meta)
    self._is_indexed = False  # Embeddings no longer cover every demo.

    return demo_meta
237
+
238
def add_demos(
    self,
    episodes: List[Episode],
    **kwargs: Any,
) -> List[DemoMetadata]:
    """Register several episodes in one call.

    Each episode is passed to add_demo() in order, with the same keyword
    arguments applied to every one of them.

    Args:
        episodes: List of Episodes to add.
        **kwargs: Additional arguments passed to add_demo.

    Returns:
        List of DemoMetadata objects, one per episode, in input order.
    """
    added = []
    for episode in episodes:
        added.append(self.add_demo(episode, **kwargs))
    return added
253
+
254
def get_demo_count(self) -> int:
    """Return how many demos are currently stored in the library."""
    stored = self._demos
    return len(stored)
257
+
258
def get_all_demos(self) -> List[DemoMetadata]:
    """Return a new list holding every stored demo metadata object.

    The list is a shallow copy, so mutating it does not change the library.
    """
    return [*self._demos]
261
+
262
def get_apps(self) -> List[str]:
    """Return the distinct, non-empty app names in the library, sorted."""
    names = set()
    for demo in self._demos:
        if demo.app_name:
            names.add(demo.app_name)
    return sorted(names)
266
+
267
def get_domains(self) -> List[str]:
    """Return the distinct, non-empty domains in the library, sorted."""
    seen = set()
    for demo in self._demos:
        if demo.domain:
            seen.add(demo.domain)
    return sorted(seen)
271
+
272
def clear(self) -> None:
    """Drop every stored demo and discard all index state.

    The embedder itself is left in place; only demos, the FAISS index,
    the embedding matrix and the indexed flag are reset.
    """
    self._is_indexed = False
    self._embeddings_matrix = None
    self._faiss_index = None
    self._demos = []
278
+
279
+ # =========================================================================
280
+ # Indexing
281
+ # =========================================================================
282
+
283
def build_index(self, force: bool = False) -> None:
    """Embed every stored demo and prepare the search index.

    Embeddings are computed for the indexable text of each demo, attached
    to the corresponding DemoMetadata, and kept as a matrix. When
    ``use_faiss`` is set, a FAISS index is built from that matrix as well.
    Must be called before retrieve().

    Args:
        force: If True, rebuild even if already indexed.

    Raises:
        ValueError: If no demos have been added.
    """
    up_to_date = self._is_indexed and not force
    if up_to_date:
        logger.debug("Index already built, skipping (use force=True to rebuild)")
        return

    if not self._demos:
        raise ValueError("Cannot build index: no demos added. Use add_demo() first.")

    logger.info(f"Building index for {len(self._demos)} demos using {self.embedding_method}...")

    # The embedding backend is created lazily on first use.
    if self._embedder is None:
        self._init_embedder()

    vectors = self._compute_embeddings(self._get_indexable_texts())

    # Attach each row to its demo so per-demo scoring can use it directly.
    for entry, vector in zip(self._demos, vectors):
        entry.embedding = vector

    self._embeddings_matrix = vectors
    if self.use_faiss:
        self._build_faiss_index(vectors)

    self._is_indexed = True
    logger.info(f"Index built successfully with {len(self._demos)} demos")
323
+
324
def save_index(self, path: Union[str, Path]) -> None:
    """Persist demo metadata and the embedding matrix to a directory.

    Writes ``index.json`` (backend settings, per-demo metadata and, for
    the TF-IDF backend, the embedder's vocabulary state). When an
    embedding matrix exists and numpy can be imported, it is also written
    to ``embeddings.npy``. Episodes themselves are not saved.

    Args:
        path: Directory to save index files; created if it does not exist.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Per-demo embeddings are left out of the JSON (too large); the whole
    # matrix is stored separately below.
    demo_records = [
        {
            "demo_id": entry.demo_id,
            "goal": entry.goal,
            "app_name": entry.app_name,
            "domain": entry.domain,
            "platform": entry.platform,
            "action_types": entry.action_types,
            "key_elements": entry.key_elements,
            "step_count": entry.step_count,
            "tags": entry.tags,
            "file_path": entry.file_path,
            "metadata": entry.metadata,
        }
        for entry in self._demos
    ]

    # TF-IDF needs its fitted state to recreate the same embedding dimension.
    tfidf_state = {}
    fitted = self._embedder
    if self.embedding_method == "tfidf" and fitted is not None:
        tfidf_state = {
            "vocab": fitted.vocab,
            "vocab_to_idx": fitted.vocab_to_idx,
            "idf": fitted.idf,
        }

    payload = {
        "embedding_method": self.embedding_method,
        "embedding_model": self.embedding_model,
        "demos": demo_records,
        "embedder_state": tfidf_state,
    }
    with open(path / "index.json", "w") as handle:
        json.dump(payload, handle, indent=2)

    # Embeddings are optional: skipped when numpy is missing or nothing
    # has been embedded yet.
    try:
        import numpy as np
        if self._embeddings_matrix is not None:
            np.save(path / "embeddings.npy", self._embeddings_matrix)
    except ImportError:
        logger.warning("numpy not available, embeddings not saved")

    logger.info(f"Index saved to {path}")
377
+
378
def load_index(self, path: Union[str, Path], episode_loader: Optional[Callable[[str], Episode]] = None) -> None:
    """Load index from disk, replacing the current demos and index state.

    Reads ``index.json`` and, if present and numpy is importable,
    ``embeddings.npy`` from ``path``. All previous demo and index state is
    discarded first, so the retriever only reports itself as indexed when
    usable embeddings were actually loaded.

    Args:
        path: Directory containing index files.
        episode_loader: Optional function to load episodes from file_path.
            When it is missing, the demo has no file_path, or loading
            fails, a placeholder single-step episode (one DONE action)
            carrying the saved demo_id and goal is created instead.
    """
    path = Path(path)

    with open(path / "index.json", encoding="utf-8") as f:
        data = json.load(f)

    previous_backend = (self.embedding_method, self.embedding_model)
    self.embedding_method = data.get("embedding_method", self.embedding_method)
    self.embedding_model = data.get("embedding_model", self.embedding_model)

    # An embedder built for a different backend/model must not be reused
    # for queries against the loaded vectors. TF-IDF state is specific to
    # the corpus it was fitted on, so it is always dropped and restored
    # from the saved state below when available.
    backend_changed = (self.embedding_method, self.embedding_model) != previous_backend
    if backend_changed or self.embedding_method == "tfidf":
        self._embedder = None

    # Load embeddings (optional: the file may be absent).
    embeddings = None
    try:
        import numpy as np
    except ImportError:
        logger.warning("numpy not available, embeddings not loaded")
    else:
        embeddings_path = path / "embeddings.npy"
        if embeddings_path.exists():
            embeddings = np.load(embeddings_path)

    demo_entries = data.get("demos", [])

    # Rows are matched to demos by position, so a matrix with the wrong
    # shape cannot be trusted; fall back to "not indexed" rather than
    # failing part-way through or attaching vectors to the wrong demos.
    if embeddings is not None and (
        embeddings.ndim != 2 or embeddings.shape[0] != len(demo_entries)
    ):
        logger.warning(
            f"Ignoring embeddings in {path}: shape {embeddings.shape} does not "
            f"match {len(demo_entries)} demos; call build_index() to recompute"
        )
        embeddings = None

    # Reset everything derived from the previous library so no stale index
    # state survives when the loaded directory has no usable embeddings.
    self._demos = []
    self._faiss_index = None
    self._embeddings_matrix = None
    self._is_indexed = False

    # Reconstruct demos
    for i, meta in enumerate(demo_entries):
        episode = None
        if episode_loader and meta.get("file_path"):
            try:
                episode = episode_loader(meta["file_path"])
            except Exception as e:
                # Best effort: a broken source file must not abort the load.
                logger.warning(f"Failed to load episode from {meta['file_path']}: {e}")

        # Create placeholder episode if not loaded
        if episode is None:
            from openadapt_ml.schema import Action, ActionType, Observation, Step
            episode = Episode(
                episode_id=meta["demo_id"],
                instruction=meta["goal"],
                steps=[
                    Step(
                        step_index=0,
                        observation=Observation(),
                        action=Action(type=ActionType.DONE),
                    )
                ],
            )

        demo = DemoMetadata(
            demo_id=meta["demo_id"],
            episode=episode,
            goal=meta["goal"],
            app_name=meta.get("app_name"),
            domain=meta.get("domain"),
            platform=meta.get("platform"),
            action_types=meta.get("action_types", []),
            key_elements=meta.get("key_elements", []),
            step_count=meta.get("step_count", 0),
            tags=meta.get("tags", []),
            file_path=meta.get("file_path"),
            metadata=meta.get("metadata", {}),
            embedding=embeddings[i] if embeddings is not None else None,
        )
        self._demos.append(demo)

    # Rebuild FAISS index if we have embeddings
    if embeddings is not None:
        self._embeddings_matrix = embeddings
        if self.use_faiss:
            self._build_faiss_index(embeddings)
        self._is_indexed = True

    # Restore embedder state for TF-IDF (needed for query embedding)
    embedder_state = data.get("embedder_state", {})
    if embedder_state and self.embedding_method == "tfidf":
        from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
        self._embedder = TFIDFEmbedder()
        self._embedder.vocab = embedder_state.get("vocab", [])
        self._embedder.vocab_to_idx = embedder_state.get("vocab_to_idx", {})
        self._embedder.idf = embedder_state.get("idf", {})
        self._embedder._is_fitted = True

    logger.info(f"Index loaded from {path} with {len(self._demos)} demos")
464
+
465
+ # =========================================================================
466
+ # Retrieval
467
+ # =========================================================================
468
+
469
def retrieve(
    self,
    query: str,
    top_k: int = 3,
    app_context: Optional[str] = None,
    domain_context: Optional[str] = None,
    filter_platform: Optional[str] = None,
    filter_tags: Optional[List[str]] = None,
) -> List[RetrievalResult]:
    """Retrieve top-K most similar demos for a query.

    Every candidate demo is scored as embedding similarity plus any
    app/domain context bonus, then the best ``top_k`` are returned.

    Args:
        query: Task description to find demos for.
        top_k: Number of demos to retrieve. Values <= 0 yield an empty list.
        app_context: Optional app context for bonus scoring (e.g., "Chrome").
        domain_context: Optional domain context for bonus scoring (e.g., "github.com").
        filter_platform: Only return demos from this platform.
        filter_tags: Only return demos with all these tags.

    Returns:
        List of RetrievalResult objects, ordered by relevance (best first),
        with ``rank`` set starting at 1.

    Raises:
        ValueError: If index has not been built.
    """
    if not self._is_indexed:
        raise ValueError("Index not built. Call build_index() first.")

    # A non-positive top_k can never yield results. Without this guard a
    # negative value would slice as results[:-n] and return the wrong
    # items, and top_k=0 would still pay for a query embedding.
    if top_k <= 0 or not self._demos:
        return []

    # Get query embedding
    query_embedding = self._compute_embeddings([query])[0]

    # Get candidates (optionally filtered)
    candidates = self._get_candidates(filter_platform, filter_tags)
    if not candidates:
        return []

    # Compute scores
    results = []
    for demo in candidates:
        text_score = self._compute_similarity(query_embedding, demo.embedding)
        bonus = self._compute_context_bonus(demo, app_context, domain_context)
        results.append(RetrievalResult(
            demo=demo,
            score=text_score + bonus,
            text_score=text_score,
            # Holds the combined app + domain bonus, not the domain part alone.
            domain_bonus=bonus,
        ))

    # Sort by score (descending); ties keep their library order.
    results.sort(key=lambda r: r.score, reverse=True)

    top_results = results[:top_k]
    for position, result in enumerate(top_results, start=1):
        result.rank = position

    return top_results
530
+
531
def retrieve_episodes(
    self,
    query: str,
    top_k: int = 3,
    **kwargs: Any,
) -> List[Episode]:
    """Run retrieve() and return only the matched episodes.

    Args:
        query: Task description to find demos for.
        top_k: Number of demos to retrieve.
        **kwargs: Additional arguments passed to retrieve().

    Returns:
        List of Episode objects, in the order retrieve() ranked them.
    """
    episodes = []
    for hit in self.retrieve(query, top_k=top_k, **kwargs):
        episodes.append(hit.demo.episode)
    return episodes
549
+
550
+ # =========================================================================
551
+ # Prompt Formatting
552
+ # =========================================================================
553
+
554
def format_for_prompt(
    self,
    results: List[RetrievalResult],
    max_steps_per_demo: int = 10,
    include_scores: bool = False,
    format_style: str = "concise",
) -> str:
    """Render retrieved demos as a text block for a prompt.

    The output starts with an intro line, followed by each demo (with an
    optional "Demo N" header) separated by blank lines.

    Args:
        results: Retrieval results from retrieve().
        max_steps_per_demo: Maximum steps to include per demo.
        include_scores: Whether to include relevance scores.
        format_style: Formatting style ("concise", "verbose", "minimal").
            Any other value is treated as "concise".

    Returns:
        Formatted string for prompt injection, or "" when there are no results.
    """
    if not results:
        return ""

    from openadapt_ml.experiments.demo_prompt.format_demo import (
        format_episode_as_demo,
        format_episode_verbose,
    )

    total = len(results)
    if total == 1:
        intro = "Here is a relevant demonstration:"
    else:
        intro = f"Here are {len(results)} relevant demonstrations:"
    out = [intro, ""]

    for position, hit in enumerate(results, start=1):
        # A per-demo header is always shown with scores; otherwise it is
        # only needed to tell several demos apart.
        if include_scores:
            out.append(f"Demo {position} (relevance: {hit.score:.2f}):")
        elif total > 1:
            out.append(f"Demo {position}:")

        episode = hit.demo.episode
        if format_style == "verbose":
            rendered = format_episode_verbose(episode, max_steps=max_steps_per_demo)
        elif format_style == "minimal":
            # Only the goal and the chain of actions.
            chain = " -> ".join(
                self._format_action_minimal(step.action)
                for step in episode.steps[:max_steps_per_demo]
                if step.action
            )
            rendered = f"Task: {hit.demo.goal}\nSteps: {chain}"
        else:
            rendered = format_episode_as_demo(episode, max_steps=max_steps_per_demo)

        out.extend((rendered, ""))

    return "\n".join(out)
617
+
618
def _format_action_minimal(self, action: Any) -> str:
    """Render a single action as text by delegating to ``format_action``."""
    # Imported here rather than at module level, as in the rest of the class.
    from openadapt_ml.experiments.demo_prompt.format_demo import format_action as render_action
    rendered = render_action(action)
    return rendered
622
+
623
+ # =========================================================================
624
+ # Private Methods
625
+ # =========================================================================
626
+
627
+ def _init_embedder(self) -> None:
628
+ """Initialize the embedding backend."""
629
+ if self.embedding_method == "tfidf":
630
+ from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
631
+ self._embedder = TFIDFEmbedder()
632
+
633
+ elif self.embedding_method == "sentence_transformers":
634
+ from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder
635
+ self._embedder = SentenceTransformerEmbedder(
636
+ model_name=self.embedding_model,
637
+ cache_dir=self.cache_dir / "st_cache",
638
+ )
639
+
640
+ elif self.embedding_method == "openai":
641
+ from openadapt_ml.retrieval.embeddings import OpenAIEmbedder
642
+ self._embedder = OpenAIEmbedder(
643
+ model_name=self.embedding_model,
644
+ cache_dir=self.cache_dir / "openai_cache",
645
+ )
646
+
647
+ else:
648
+ raise ValueError(f"Unknown embedding method: {self.embedding_method}")
649
+
650
+ def _get_indexable_texts(self) -> List[str]:
651
+ """Get text representations for indexing."""
652
+ texts = []
653
+ for demo in self._demos:
654
+ # Combine goal with context
655
+ parts = [demo.goal]
656
+ if demo.app_name:
657
+ parts.append(f"[APP:{demo.app_name}]")
658
+ if demo.domain:
659
+ parts.append(f"[DOMAIN:{demo.domain}]")
660
+ texts.append(" ".join(parts))
661
+ return texts
662
+
663
+ def _compute_embeddings(self, texts: List[str]) -> Any:
664
+ """Compute embeddings for texts."""
665
+ import numpy as np
666
+
667
+ if self._embedder is None:
668
+ self._init_embedder()
669
+
670
+ embeddings = self._embedder.embed_batch(texts)
671
+
672
+ # Ensure numpy array
673
+ if not isinstance(embeddings, np.ndarray):
674
+ embeddings = np.array(embeddings, dtype=np.float32)
675
+
676
+ return embeddings
677
+
678
def _build_faiss_index(self, embeddings: Any) -> None:
    """Create an inner-product FAISS index over ``embeddings``.

    When an import inside the method fails, ``_faiss_index`` is set to
    None so callers fall back to brute-force search.

    NOTE(review): inner product only equals cosine similarity for
    unit-length vectors; nothing here normalizes them, so this relies on
    the embedder's output already being normalized — confirm.
    """
    try:
        import faiss
        import numpy as np

        matrix = np.asarray(embeddings, dtype=np.float32)
        dim = matrix.shape[1]

        # Flat (exact) inner-product index; assumes normalized embeddings.
        index = faiss.IndexFlatIP(dim)
        self._faiss_index = index
        index.add(matrix)

        logger.debug(f"Built FAISS index with {len(matrix)} vectors, dim={dim}")
    except ImportError:
        logger.debug("FAISS not available, using brute-force search")
        self._faiss_index = None
695
+
696
+ def _compute_similarity(self, query_embedding: Any, doc_embedding: Any) -> float:
697
+ """Compute similarity between query and document embeddings."""
698
+ import numpy as np
699
+
700
+ query_embedding = np.asarray(query_embedding, dtype=np.float32)
701
+ doc_embedding = np.asarray(doc_embedding, dtype=np.float32)
702
+
703
+ # Normalize for cosine similarity
704
+ query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-9)
705
+ doc_norm = doc_embedding / (np.linalg.norm(doc_embedding) + 1e-9)
706
+
707
+ return float(np.dot(query_norm, doc_norm))
708
+
709
+ def _compute_context_bonus(
710
+ self,
711
+ demo: DemoMetadata,
712
+ app_context: Optional[str],
713
+ domain_context: Optional[str],
714
+ ) -> float:
715
+ """Compute context bonus for app/domain matching."""
716
+ bonus = 0.0
717
+
718
+ if app_context and demo.app_name:
719
+ if app_context.lower() in demo.app_name.lower():
720
+ bonus += self.app_bonus
721
+
722
+ if domain_context and demo.domain:
723
+ if domain_context.lower() in demo.domain.lower():
724
+ bonus += self.domain_bonus
725
+
726
+ return bonus
727
+
728
+ def _get_candidates(
729
+ self,
730
+ filter_platform: Optional[str],
731
+ filter_tags: Optional[List[str]],
732
+ ) -> List[DemoMetadata]:
733
+ """Get candidate demos after filtering."""
734
+ candidates = self._demos
735
+
736
+ if filter_platform:
737
+ candidates = [d for d in candidates if d.platform == filter_platform]
738
+
739
+ if filter_tags:
740
+ filter_tags_set = set(filter_tags)
741
+ candidates = [
742
+ d for d in candidates
743
+ if filter_tags_set.issubset(set(d.tags))
744
+ ]
745
+
746
+ return candidates
747
+
748
+ def _extract_app_name(self, episode: Episode) -> Optional[str]:
749
+ """Extract app name from episode observations."""
750
+ for step in episode.steps:
751
+ if step.observation and step.observation.app_name:
752
+ return step.observation.app_name
753
+ return None
754
+
755
+ def _extract_domain(self, episode: Episode) -> Optional[str]:
756
+ """Extract domain from episode URLs."""
757
+ for step in episode.steps:
758
+ if step.observation and step.observation.url:
759
+ url = step.observation.url
760
+ if "://" in url:
761
+ domain = url.split("://")[1].split("/")[0]
762
+ if domain.startswith("www."):
763
+ domain = domain[4:]
764
+ return domain
765
+ return None
766
+
767
+ def _detect_platform(
768
+ self,
769
+ episode: Episode,
770
+ app_name: Optional[str],
771
+ domain: Optional[str],
772
+ ) -> Optional[str]:
773
+ """Detect platform from episode context."""
774
+ # Check for web indicators
775
+ if domain:
776
+ return "web"
777
+
778
+ # Check for macOS app names
779
+ macos_apps = {"System Settings", "Finder", "Safari", "Preview", "TextEdit"}
780
+ if app_name and app_name in macos_apps:
781
+ return "macos"
782
+
783
+ # Check for Windows app names
784
+ windows_apps = {"Settings", "File Explorer", "Microsoft Edge", "Notepad"}
785
+ if app_name and app_name in windows_apps:
786
+ return "windows"
787
+
788
+ # Check episode metadata
789
+ if episode.environment:
790
+ env_lower = episode.environment.lower()
791
+ if "macos" in env_lower or "darwin" in env_lower:
792
+ return "macos"
793
+ if "windows" in env_lower:
794
+ return "windows"
795
+
796
+ return None
797
+
798
+ def _extract_key_elements(self, episode: Episode) -> List[str]:
799
+ """Extract key UI elements from episode."""
800
+ elements = []
801
+ for step in episode.steps:
802
+ if step.action and step.action.element:
803
+ elem = step.action.element
804
+ if elem.role and elem.name:
805
+ elements.append(f"{elem.role}:{elem.name}")
806
+ elif elem.name:
807
+ elements.append(elem.name)
808
+ return list(set(elements))
809
+
810
+ def __len__(self) -> int:
811
+ """Return number of demos in the library."""
812
+ return len(self._demos)
813
+
814
+ def __repr__(self) -> str:
815
+ """String representation."""
816
+ status = "indexed" if self._is_indexed else "not indexed"
817
+ return f"DemoRetriever({len(self._demos)} demos, {self.embedding_method}, {status})"