openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112) hide show
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,843 @@
1
+ """Main Demo Retriever class for finding similar demonstrations.
2
+
3
+ This module provides the DemoRetriever class that indexes demos by their task
4
+ descriptions using embeddings and retrieves the most similar demo(s) from a library.
5
+
6
+ Key features:
7
+ - Supports both local embeddings (sentence-transformers) and API embeddings (OpenAI)
8
+ - Uses FAISS or simple cosine similarity for vector search
9
+ - Caches embeddings to avoid recomputing
10
+ - Returns top-k most similar demos with formatting for prompts
11
+
12
+ Example usage:
13
+ from openadapt_ml.retrieval import DemoRetriever
14
+ from openadapt_ml.schema import Episode
15
+
16
+ # Create retriever with local embeddings
17
+ retriever = DemoRetriever(embedding_method="sentence_transformers")
18
+
19
+ # Add demos
20
+ retriever.add_demo(episode1)
21
+ retriever.add_demo(episode2, app_name="Chrome", domain="github.com")
22
+
23
+ # Build index
24
+ retriever.build_index()
25
+
26
+ # Retrieve similar demos
27
+ results = retriever.retrieve("Turn off Night Shift", top_k=3)
28
+
29
+ # Format for prompt
30
+ prompt_text = retriever.format_for_prompt(results)
31
+ """
32
+
33
+ from __future__ import annotations
34
+
35
+ import json
36
+ import logging
37
+ from dataclasses import dataclass, field
38
+ from pathlib import Path
39
+ from typing import Any, Callable, List, Optional, Union
40
+
41
+ from openadapt_ml.schema import Episode
42
+
43
+ logger = logging.getLogger(__name__)
44
+
45
+
46
@dataclass
class DemoMetadata:
    """Metadata describing one demonstration in the retrieval library.

    Bundles the raw Episode with the features derived for retrieval
    (goal text, app/domain/platform context, action summary) plus the
    embedding vector once the index has been built.
    """

    demo_id: str  # unique identifier for this demo
    episode: Episode  # the full recorded episode
    goal: str  # task instruction text used for embedding
    app_name: Optional[str] = None  # e.g. "System Settings"
    domain: Optional[str] = None  # e.g. "github.com"
    platform: Optional[str] = None  # "macos", "windows", or "web"
    action_types: List[str] = field(default_factory=list)  # action kinds used in the demo
    key_elements: List[str] = field(default_factory=list)  # notable UI elements touched
    step_count: int = 0  # number of steps in the episode
    tags: List[str] = field(default_factory=list)  # user-supplied labels
    file_path: Optional[str] = None  # source file, if loaded from disk
    embedding: Optional[Any] = None  # numpy array once computed by build_index()
    metadata: dict[str, Any] = field(default_factory=dict)  # extra custom data
81
+
82
+
83
@dataclass
class RetrievalResult:
    """One ranked hit from DemoRetriever.retrieve().

    Invariant (as produced by retrieve()): score == text_score + domain_bonus,
    where domain_bonus carries the combined context bonus (app and domain).
    """

    demo: DemoMetadata  # metadata of the matched demo
    score: float  # combined relevance score (higher is better)
    text_score: float  # embedding-similarity component
    domain_bonus: float  # context-bonus component (app + domain combined)
    rank: int = 0  # 1-based position in the returned list; 0 until assigned
100
+
101
+
102
+ class DemoRetriever:
103
+ """Retrieves similar demonstrations from a library using embeddings.
104
+
105
+ Supports multiple embedding backends:
106
+ - "tfidf": Simple TF-IDF (no external dependencies, baseline)
107
+ - "sentence_transformers": Local embedding model (recommended)
108
+ - "openai": OpenAI text-embedding API
109
+
110
+ The retriever uses FAISS for efficient similarity search when available,
111
+ falling back to brute-force cosine similarity for small indices.
112
+
113
+ Example:
114
+ >>> retriever = DemoRetriever(embedding_method="sentence_transformers")
115
+ >>> retriever.add_demo(episode, app_name="Chrome")
116
+ >>> retriever.build_index()
117
+ >>> results = retriever.retrieve("Search on GitHub", top_k=3)
118
+ >>> print(results[0].demo.goal)
119
+ """
120
+
121
+ def __init__(
122
+ self,
123
+ embedding_method: str = "tfidf",
124
+ embedding_model: str = "all-MiniLM-L6-v2",
125
+ cache_dir: Optional[Path] = None,
126
+ domain_bonus: float = 0.2,
127
+ app_bonus: float = 0.15,
128
+ use_faiss: bool = True,
129
+ ) -> None:
130
+ """Initialize the DemoRetriever.
131
+
132
+ Args:
133
+ embedding_method: Embedding backend ("tfidf", "sentence_transformers", "openai").
134
+ embedding_model: Model name for sentence_transformers or OpenAI.
135
+ - For sentence_transformers: "all-MiniLM-L6-v2", "all-mpnet-base-v2", etc.
136
+ - For OpenAI: "text-embedding-3-small", "text-embedding-3-large", etc.
137
+ cache_dir: Directory for caching embeddings. If None, uses ~/.cache/openadapt_ml/embeddings.
138
+ domain_bonus: Bonus score for matching domain (default: 0.2).
139
+ app_bonus: Bonus score for matching app name (default: 0.15).
140
+ use_faiss: Whether to use FAISS for vector search (default: True).
141
+ """
142
+ self.embedding_method = embedding_method
143
+ self.embedding_model = embedding_model
144
+ self.domain_bonus = domain_bonus
145
+ self.app_bonus = app_bonus
146
+ self.use_faiss = use_faiss
147
+
148
+ # Set up cache directory
149
+ if cache_dir is None:
150
+ cache_dir = Path.home() / ".cache" / "openadapt_ml" / "embeddings"
151
+ self.cache_dir = Path(cache_dir)
152
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
153
+
154
+ # Internal state
155
+ self._demos: List[DemoMetadata] = []
156
+ self._embedder: Optional[Any] = None
157
+ self._faiss_index: Optional[Any] = None
158
+ self._is_indexed = False
159
+ self._embeddings_matrix: Optional[Any] = None
160
+
161
+ # =========================================================================
162
+ # Demo Management
163
+ # =========================================================================
164
+
165
+ def add_demo(
166
+ self,
167
+ episode: Episode,
168
+ demo_id: Optional[str] = None,
169
+ app_name: Optional[str] = None,
170
+ domain: Optional[str] = None,
171
+ platform: Optional[str] = None,
172
+ tags: Optional[List[str]] = None,
173
+ file_path: Optional[str] = None,
174
+ metadata: Optional[dict[str, Any]] = None,
175
+ ) -> DemoMetadata:
176
+ """Add a demonstration episode to the library.
177
+
178
+ Args:
179
+ episode: The Episode to add.
180
+ demo_id: Unique ID (auto-generated from episode_id if not provided).
181
+ app_name: Application name (auto-extracted if not provided).
182
+ domain: Domain (auto-extracted from URLs if not provided).
183
+ platform: Platform ("macos", "windows", "web"). Auto-detected if not provided.
184
+ tags: User-provided tags for categorization.
185
+ file_path: Path to the source file.
186
+ metadata: Additional custom metadata.
187
+
188
+ Returns:
189
+ DemoMetadata object for the added demo.
190
+ """
191
+ # Auto-generate demo_id
192
+ if demo_id is None:
193
+ demo_id = episode.episode_id
194
+
195
+ # Auto-extract app_name
196
+ if app_name is None:
197
+ app_name = self._extract_app_name(episode)
198
+
199
+ # Auto-extract domain
200
+ if domain is None:
201
+ domain = self._extract_domain(episode)
202
+
203
+ # Auto-detect platform
204
+ if platform is None:
205
+ platform = self._detect_platform(episode, app_name, domain)
206
+
207
+ # Extract action types
208
+ action_types = list(
209
+ set(
210
+ step.action.type.value
211
+ if hasattr(step.action.type, "value")
212
+ else str(step.action.type)
213
+ for step in episode.steps
214
+ if step.action
215
+ )
216
+ )
217
+
218
+ # Extract key elements
219
+ key_elements = self._extract_key_elements(episode)
220
+
221
+ demo_meta = DemoMetadata(
222
+ demo_id=demo_id,
223
+ episode=episode,
224
+ goal=episode.instruction,
225
+ app_name=app_name,
226
+ domain=domain,
227
+ platform=platform,
228
+ action_types=action_types,
229
+ key_elements=key_elements,
230
+ step_count=len(episode.steps),
231
+ tags=tags or [],
232
+ file_path=file_path,
233
+ metadata=metadata or {},
234
+ )
235
+
236
+ self._demos.append(demo_meta)
237
+ self._is_indexed = False # Need to rebuild index
238
+
239
+ return demo_meta
240
+
241
+ def add_demos(
242
+ self,
243
+ episodes: List[Episode],
244
+ **kwargs: Any,
245
+ ) -> List[DemoMetadata]:
246
+ """Add multiple demonstration episodes.
247
+
248
+ Args:
249
+ episodes: List of Episodes to add.
250
+ **kwargs: Additional arguments passed to add_demo.
251
+
252
+ Returns:
253
+ List of DemoMetadata objects.
254
+ """
255
+ return [self.add_demo(ep, **kwargs) for ep in episodes]
256
+
257
+ def get_demo_count(self) -> int:
258
+ """Get the number of demos in the library."""
259
+ return len(self._demos)
260
+
261
+ def get_all_demos(self) -> List[DemoMetadata]:
262
+ """Get all demo metadata objects."""
263
+ return list(self._demos)
264
+
265
+ def get_apps(self) -> List[str]:
266
+ """Get unique app names in the library."""
267
+ apps = {d.app_name for d in self._demos if d.app_name}
268
+ return sorted(apps)
269
+
270
+ def get_domains(self) -> List[str]:
271
+ """Get unique domains in the library."""
272
+ domains = {d.domain for d in self._demos if d.domain}
273
+ return sorted(domains)
274
+
275
+ def clear(self) -> None:
276
+ """Clear all demos and reset the index."""
277
+ self._demos = []
278
+ self._faiss_index = None
279
+ self._embeddings_matrix = None
280
+ self._is_indexed = False
281
+
282
+ # =========================================================================
283
+ # Indexing
284
+ # =========================================================================
285
+
286
+ def build_index(self, force: bool = False) -> None:
287
+ """Build the search index from all added demos.
288
+
289
+ This computes embeddings for all demos and builds the FAISS index.
290
+ Must be called before retrieve().
291
+
292
+ Args:
293
+ force: If True, rebuild even if already indexed.
294
+
295
+ Raises:
296
+ ValueError: If no demos have been added.
297
+ """
298
+ if self._is_indexed and not force:
299
+ logger.debug("Index already built, skipping (use force=True to rebuild)")
300
+ return
301
+
302
+ if not self._demos:
303
+ raise ValueError(
304
+ "Cannot build index: no demos added. Use add_demo() first."
305
+ )
306
+
307
+ logger.info(
308
+ f"Building index for {len(self._demos)} demos using {self.embedding_method}..."
309
+ )
310
+
311
+ # Initialize embedder if needed
312
+ if self._embedder is None:
313
+ self._init_embedder()
314
+
315
+ # Compute embeddings for all demos
316
+ texts = self._get_indexable_texts()
317
+ embeddings = self._compute_embeddings(texts)
318
+
319
+ # Store embeddings in demo metadata
320
+ for demo, emb in zip(self._demos, embeddings):
321
+ demo.embedding = emb
322
+
323
+ # Build FAISS index if available
324
+ self._embeddings_matrix = embeddings
325
+ if self.use_faiss:
326
+ self._build_faiss_index(embeddings)
327
+
328
+ self._is_indexed = True
329
+ logger.info(f"Index built successfully with {len(self._demos)} demos")
330
+
331
+ def save_index(self, path: Union[str, Path]) -> None:
332
+ """Save the index and embeddings to disk.
333
+
334
+ Args:
335
+ path: Directory to save index files.
336
+ """
337
+ path = Path(path)
338
+ path.mkdir(parents=True, exist_ok=True)
339
+
340
+ # Save demo metadata (without embeddings - too large)
341
+ metadata = []
342
+ for demo in self._demos:
343
+ meta = {
344
+ "demo_id": demo.demo_id,
345
+ "goal": demo.goal,
346
+ "app_name": demo.app_name,
347
+ "domain": demo.domain,
348
+ "platform": demo.platform,
349
+ "action_types": demo.action_types,
350
+ "key_elements": demo.key_elements,
351
+ "step_count": demo.step_count,
352
+ "tags": demo.tags,
353
+ "file_path": demo.file_path,
354
+ "metadata": demo.metadata,
355
+ }
356
+ metadata.append(meta)
357
+
358
+ # Prepare embedder state for TF-IDF (needed to recreate same embedding dimension)
359
+ embedder_state = {}
360
+ if self.embedding_method == "tfidf" and self._embedder is not None:
361
+ embedder_state = {
362
+ "vocab": self._embedder.vocab,
363
+ "vocab_to_idx": self._embedder.vocab_to_idx,
364
+ "idf": self._embedder.idf,
365
+ }
366
+
367
+ with open(path / "index.json", "w") as f:
368
+ json.dump(
369
+ {
370
+ "embedding_method": self.embedding_method,
371
+ "embedding_model": self.embedding_model,
372
+ "demos": metadata,
373
+ "embedder_state": embedder_state,
374
+ },
375
+ f,
376
+ indent=2,
377
+ )
378
+
379
+ # Save embeddings as numpy array
380
+ try:
381
+ import numpy as np
382
+
383
+ if self._embeddings_matrix is not None:
384
+ np.save(path / "embeddings.npy", self._embeddings_matrix)
385
+ except ImportError:
386
+ logger.warning("numpy not available, embeddings not saved")
387
+
388
+ logger.info(f"Index saved to {path}")
389
+
390
+ def load_index(
391
+ self,
392
+ path: Union[str, Path],
393
+ episode_loader: Optional[Callable[[str], Episode]] = None,
394
+ ) -> None:
395
+ """Load index from disk.
396
+
397
+ Args:
398
+ path: Directory containing index files.
399
+ episode_loader: Optional function to load episodes from file_path.
400
+ If not provided, episodes will be None.
401
+ """
402
+ path = Path(path)
403
+
404
+ with open(path / "index.json") as f:
405
+ data = json.load(f)
406
+
407
+ self.embedding_method = data.get("embedding_method", self.embedding_method)
408
+ self.embedding_model = data.get("embedding_model", self.embedding_model)
409
+
410
+ # Load embeddings
411
+ embeddings = None
412
+ try:
413
+ import numpy as np
414
+
415
+ embeddings_path = path / "embeddings.npy"
416
+ if embeddings_path.exists():
417
+ embeddings = np.load(embeddings_path)
418
+ except ImportError:
419
+ pass
420
+
421
+ # Reconstruct demos
422
+ self._demos = []
423
+ for i, meta in enumerate(data.get("demos", [])):
424
+ episode = None
425
+ if episode_loader and meta.get("file_path"):
426
+ try:
427
+ episode = episode_loader(meta["file_path"])
428
+ except Exception as e:
429
+ logger.warning(
430
+ f"Failed to load episode from {meta['file_path']}: {e}"
431
+ )
432
+
433
+ # Create placeholder episode if not loaded
434
+ if episode is None:
435
+ from openadapt_ml.schema import Action, ActionType, Observation, Step
436
+
437
+ episode = Episode(
438
+ episode_id=meta["demo_id"],
439
+ instruction=meta["goal"],
440
+ steps=[
441
+ Step(
442
+ step_index=0,
443
+ observation=Observation(),
444
+ action=Action(type=ActionType.DONE),
445
+ )
446
+ ],
447
+ )
448
+
449
+ demo = DemoMetadata(
450
+ demo_id=meta["demo_id"],
451
+ episode=episode,
452
+ goal=meta["goal"],
453
+ app_name=meta.get("app_name"),
454
+ domain=meta.get("domain"),
455
+ platform=meta.get("platform"),
456
+ action_types=meta.get("action_types", []),
457
+ key_elements=meta.get("key_elements", []),
458
+ step_count=meta.get("step_count", 0),
459
+ tags=meta.get("tags", []),
460
+ file_path=meta.get("file_path"),
461
+ metadata=meta.get("metadata", {}),
462
+ embedding=embeddings[i] if embeddings is not None else None,
463
+ )
464
+ self._demos.append(demo)
465
+
466
+ # Rebuild FAISS index if we have embeddings
467
+ if embeddings is not None:
468
+ self._embeddings_matrix = embeddings
469
+ if self.use_faiss:
470
+ self._build_faiss_index(embeddings)
471
+ self._is_indexed = True
472
+
473
+ # Restore embedder state for TF-IDF (needed for query embedding)
474
+ embedder_state = data.get("embedder_state", {})
475
+ if embedder_state and self.embedding_method == "tfidf":
476
+ from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
477
+
478
+ self._embedder = TFIDFEmbedder()
479
+ self._embedder.vocab = embedder_state.get("vocab", [])
480
+ self._embedder.vocab_to_idx = embedder_state.get("vocab_to_idx", {})
481
+ self._embedder.idf = embedder_state.get("idf", {})
482
+ self._embedder._is_fitted = True
483
+
484
+ logger.info(f"Index loaded from {path} with {len(self._demos)} demos")
485
+
486
+ # =========================================================================
487
+ # Retrieval
488
+ # =========================================================================
489
+
490
+ def retrieve(
491
+ self,
492
+ query: str,
493
+ top_k: int = 3,
494
+ app_context: Optional[str] = None,
495
+ domain_context: Optional[str] = None,
496
+ filter_platform: Optional[str] = None,
497
+ filter_tags: Optional[List[str]] = None,
498
+ ) -> List[RetrievalResult]:
499
+ """Retrieve top-K most similar demos for a query.
500
+
501
+ Args:
502
+ query: Task description to find demos for.
503
+ top_k: Number of demos to retrieve.
504
+ app_context: Optional app context for bonus scoring (e.g., "Chrome").
505
+ domain_context: Optional domain context for bonus scoring (e.g., "github.com").
506
+ filter_platform: Only return demos from this platform.
507
+ filter_tags: Only return demos with all these tags.
508
+
509
+ Returns:
510
+ List of RetrievalResult objects, ordered by relevance (best first).
511
+
512
+ Raises:
513
+ ValueError: If index has not been built.
514
+ """
515
+ if not self._is_indexed:
516
+ raise ValueError("Index not built. Call build_index() first.")
517
+
518
+ if not self._demos:
519
+ return []
520
+
521
+ # Get query embedding
522
+ query_embedding = self._compute_embeddings([query])[0]
523
+
524
+ # Get candidates (optionally filtered)
525
+ candidates = self._get_candidates(filter_platform, filter_tags)
526
+ if not candidates:
527
+ return []
528
+
529
+ # Compute scores
530
+ results = []
531
+ for demo in candidates:
532
+ text_score = self._compute_similarity(query_embedding, demo.embedding)
533
+ bonus = self._compute_context_bonus(demo, app_context, domain_context)
534
+ total_score = text_score + bonus
535
+
536
+ results.append(
537
+ RetrievalResult(
538
+ demo=demo,
539
+ score=total_score,
540
+ text_score=text_score,
541
+ domain_bonus=bonus,
542
+ )
543
+ )
544
+
545
+ # Sort by score (descending)
546
+ results.sort(key=lambda r: r.score, reverse=True)
547
+
548
+ # Add ranks
549
+ for i, result in enumerate(results[:top_k]):
550
+ result.rank = i + 1
551
+
552
+ return results[:top_k]
553
+
554
+ def retrieve_episodes(
555
+ self,
556
+ query: str,
557
+ top_k: int = 3,
558
+ **kwargs: Any,
559
+ ) -> List[Episode]:
560
+ """Retrieve top-K episodes (convenience method).
561
+
562
+ Args:
563
+ query: Task description to find demos for.
564
+ top_k: Number of demos to retrieve.
565
+ **kwargs: Additional arguments passed to retrieve().
566
+
567
+ Returns:
568
+ List of Episode objects.
569
+ """
570
+ results = self.retrieve(query, top_k=top_k, **kwargs)
571
+ return [r.demo.episode for r in results]
572
+
573
+ # =========================================================================
574
+ # Prompt Formatting
575
+ # =========================================================================
576
+
577
+ def format_for_prompt(
578
+ self,
579
+ results: List[RetrievalResult],
580
+ max_steps_per_demo: int = 10,
581
+ include_scores: bool = False,
582
+ format_style: str = "concise",
583
+ ) -> str:
584
+ """Format retrieved demos for inclusion in a prompt.
585
+
586
+ Args:
587
+ results: Retrieval results from retrieve().
588
+ max_steps_per_demo: Maximum steps to include per demo.
589
+ include_scores: Whether to include relevance scores.
590
+ format_style: Formatting style ("concise", "verbose", "minimal").
591
+
592
+ Returns:
593
+ Formatted string for prompt injection.
594
+ """
595
+ if not results:
596
+ return ""
597
+
598
+ from openadapt_ml.experiments.demo_prompt.format_demo import (
599
+ format_episode_as_demo,
600
+ format_episode_verbose,
601
+ )
602
+
603
+ lines = []
604
+
605
+ if len(results) == 1:
606
+ lines.append("Here is a relevant demonstration:")
607
+ else:
608
+ lines.append(f"Here are {len(results)} relevant demonstrations:")
609
+ lines.append("")
610
+
611
+ for i, result in enumerate(results, 1):
612
+ if include_scores:
613
+ lines.append(f"Demo {i} (relevance: {result.score:.2f}):")
614
+ elif len(results) > 1:
615
+ lines.append(f"Demo {i}:")
616
+
617
+ if format_style == "verbose":
618
+ demo_text = format_episode_verbose(
619
+ result.demo.episode,
620
+ max_steps=max_steps_per_demo,
621
+ )
622
+ elif format_style == "minimal":
623
+ # Just goal and action sequence
624
+ steps_text = " -> ".join(
625
+ self._format_action_minimal(step.action)
626
+ for step in result.demo.episode.steps[:max_steps_per_demo]
627
+ if step.action
628
+ )
629
+ demo_text = f"Task: {result.demo.goal}\nSteps: {steps_text}"
630
+ else: # concise (default)
631
+ demo_text = format_episode_as_demo(
632
+ result.demo.episode,
633
+ max_steps=max_steps_per_demo,
634
+ )
635
+
636
+ lines.append(demo_text)
637
+ lines.append("")
638
+
639
+ return "\n".join(lines)
640
+
641
+ def _format_action_minimal(self, action: Any) -> str:
642
+ """Format action as minimal string."""
643
+ from openadapt_ml.experiments.demo_prompt.format_demo import format_action
644
+
645
+ return format_action(action)
646
+
647
+ # =========================================================================
648
+ # Private Methods
649
+ # =========================================================================
650
+
651
+ def _init_embedder(self) -> None:
652
+ """Initialize the embedding backend."""
653
+ if self.embedding_method == "tfidf":
654
+ from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
655
+
656
+ self._embedder = TFIDFEmbedder()
657
+
658
+ elif self.embedding_method == "sentence_transformers":
659
+ from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder
660
+
661
+ self._embedder = SentenceTransformerEmbedder(
662
+ model_name=self.embedding_model,
663
+ cache_dir=self.cache_dir / "st_cache",
664
+ )
665
+
666
+ elif self.embedding_method == "openai":
667
+ from openadapt_ml.retrieval.embeddings import OpenAIEmbedder
668
+
669
+ self._embedder = OpenAIEmbedder(
670
+ model_name=self.embedding_model,
671
+ cache_dir=self.cache_dir / "openai_cache",
672
+ )
673
+
674
+ else:
675
+ raise ValueError(f"Unknown embedding method: {self.embedding_method}")
676
+
677
+ def _get_indexable_texts(self) -> List[str]:
678
+ """Get text representations for indexing."""
679
+ texts = []
680
+ for demo in self._demos:
681
+ # Combine goal with context
682
+ parts = [demo.goal]
683
+ if demo.app_name:
684
+ parts.append(f"[APP:{demo.app_name}]")
685
+ if demo.domain:
686
+ parts.append(f"[DOMAIN:{demo.domain}]")
687
+ texts.append(" ".join(parts))
688
+ return texts
689
+
690
+ def _compute_embeddings(self, texts: List[str]) -> Any:
691
+ """Compute embeddings for texts."""
692
+ import numpy as np
693
+
694
+ if self._embedder is None:
695
+ self._init_embedder()
696
+
697
+ embeddings = self._embedder.embed_batch(texts)
698
+
699
+ # Ensure numpy array
700
+ if not isinstance(embeddings, np.ndarray):
701
+ embeddings = np.array(embeddings, dtype=np.float32)
702
+
703
+ return embeddings
704
+
705
+ def _build_faiss_index(self, embeddings: Any) -> None:
706
+ """Build FAISS index from embeddings."""
707
+ try:
708
+ import faiss
709
+ import numpy as np
710
+
711
+ embeddings = np.asarray(embeddings, dtype=np.float32)
712
+ dim = embeddings.shape[1]
713
+
714
+ # Use IndexFlatIP for cosine similarity (assumes normalized embeddings)
715
+ self._faiss_index = faiss.IndexFlatIP(dim)
716
+ self._faiss_index.add(embeddings)
717
+
718
+ logger.debug(f"Built FAISS index with {len(embeddings)} vectors, dim={dim}")
719
+ except ImportError:
720
+ logger.debug("FAISS not available, using brute-force search")
721
+ self._faiss_index = None
722
+
723
+ def _compute_similarity(self, query_embedding: Any, doc_embedding: Any) -> float:
724
+ """Compute similarity between query and document embeddings."""
725
+ import numpy as np
726
+
727
+ query_embedding = np.asarray(query_embedding, dtype=np.float32)
728
+ doc_embedding = np.asarray(doc_embedding, dtype=np.float32)
729
+
730
+ # Normalize for cosine similarity
731
+ query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-9)
732
+ doc_norm = doc_embedding / (np.linalg.norm(doc_embedding) + 1e-9)
733
+
734
+ return float(np.dot(query_norm, doc_norm))
735
+
736
+ def _compute_context_bonus(
737
+ self,
738
+ demo: DemoMetadata,
739
+ app_context: Optional[str],
740
+ domain_context: Optional[str],
741
+ ) -> float:
742
+ """Compute context bonus for app/domain matching."""
743
+ bonus = 0.0
744
+
745
+ if app_context and demo.app_name:
746
+ if app_context.lower() in demo.app_name.lower():
747
+ bonus += self.app_bonus
748
+
749
+ if domain_context and demo.domain:
750
+ if domain_context.lower() in demo.domain.lower():
751
+ bonus += self.domain_bonus
752
+
753
+ return bonus
754
+
755
+ def _get_candidates(
756
+ self,
757
+ filter_platform: Optional[str],
758
+ filter_tags: Optional[List[str]],
759
+ ) -> List[DemoMetadata]:
760
+ """Get candidate demos after filtering."""
761
+ candidates = self._demos
762
+
763
+ if filter_platform:
764
+ candidates = [d for d in candidates if d.platform == filter_platform]
765
+
766
+ if filter_tags:
767
+ filter_tags_set = set(filter_tags)
768
+ candidates = [
769
+ d for d in candidates if filter_tags_set.issubset(set(d.tags))
770
+ ]
771
+
772
+ return candidates
773
+
774
+ def _extract_app_name(self, episode: Episode) -> Optional[str]:
775
+ """Extract app name from episode observations."""
776
+ for step in episode.steps:
777
+ if step.observation and step.observation.app_name:
778
+ return step.observation.app_name
779
+ return None
780
+
781
+ def _extract_domain(self, episode: Episode) -> Optional[str]:
782
+ """Extract domain from episode URLs."""
783
+ for step in episode.steps:
784
+ if step.observation and step.observation.url:
785
+ url = step.observation.url
786
+ if "://" in url:
787
+ domain = url.split("://")[1].split("/")[0]
788
+ if domain.startswith("www."):
789
+ domain = domain[4:]
790
+ return domain
791
+ return None
792
+
793
+ def _detect_platform(
794
+ self,
795
+ episode: Episode,
796
+ app_name: Optional[str],
797
+ domain: Optional[str],
798
+ ) -> Optional[str]:
799
+ """Detect platform from episode context."""
800
+ # Check for web indicators
801
+ if domain:
802
+ return "web"
803
+
804
+ # Check for macOS app names
805
+ macos_apps = {"System Settings", "Finder", "Safari", "Preview", "TextEdit"}
806
+ if app_name and app_name in macos_apps:
807
+ return "macos"
808
+
809
+ # Check for Windows app names
810
+ windows_apps = {"Settings", "File Explorer", "Microsoft Edge", "Notepad"}
811
+ if app_name and app_name in windows_apps:
812
+ return "windows"
813
+
814
+ # Check episode metadata
815
+ if episode.environment:
816
+ env_lower = episode.environment.lower()
817
+ if "macos" in env_lower or "darwin" in env_lower:
818
+ return "macos"
819
+ if "windows" in env_lower:
820
+ return "windows"
821
+
822
+ return None
823
+
824
+ def _extract_key_elements(self, episode: Episode) -> List[str]:
825
+ """Extract key UI elements from episode."""
826
+ elements = []
827
+ for step in episode.steps:
828
+ if step.action and step.action.element:
829
+ elem = step.action.element
830
+ if elem.role and elem.name:
831
+ elements.append(f"{elem.role}:{elem.name}")
832
+ elif elem.name:
833
+ elements.append(elem.name)
834
+ return list(set(elements))
835
+
836
+ def __len__(self) -> int:
837
+ """Return number of demos in the library."""
838
+ return len(self._demos)
839
+
840
+ def __repr__(self) -> str:
841
+ """String representation."""
842
+ status = "indexed" if self._is_indexed else "not indexed"
843
+ return f"DemoRetriever({len(self._demos)} demos, {self.embedding_method}, {status})"