openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,817 @@
|
|
|
1
|
+
"""Main Demo Retriever class for finding similar demonstrations.
|
|
2
|
+
|
|
3
|
+
This module provides the DemoRetriever class that indexes demos by their task
|
|
4
|
+
descriptions using embeddings and retrieves the most similar demo(s) from a library.
|
|
5
|
+
|
|
6
|
+
Key features:
|
|
7
|
+
- Supports both local embeddings (sentence-transformers) and API embeddings (OpenAI)
|
|
8
|
+
- Uses FAISS or simple cosine similarity for vector search
|
|
9
|
+
- Caches embeddings to avoid recomputing
|
|
10
|
+
- Returns top-k most similar demos with formatting for prompts
|
|
11
|
+
|
|
12
|
+
Example usage:
|
|
13
|
+
from openadapt_ml.retrieval import DemoRetriever
|
|
14
|
+
from openadapt_ml.schema import Episode
|
|
15
|
+
|
|
16
|
+
# Create retriever with local embeddings
|
|
17
|
+
retriever = DemoRetriever(embedding_method="sentence_transformers")
|
|
18
|
+
|
|
19
|
+
# Add demos
|
|
20
|
+
retriever.add_demo(episode1)
|
|
21
|
+
retriever.add_demo(episode2, app_name="Chrome", domain="github.com")
|
|
22
|
+
|
|
23
|
+
# Build index
|
|
24
|
+
retriever.build_index()
|
|
25
|
+
|
|
26
|
+
# Retrieve similar demos
|
|
27
|
+
results = retriever.retrieve("Turn off Night Shift", top_k=3)
|
|
28
|
+
|
|
29
|
+
# Format for prompt
|
|
30
|
+
prompt_text = retriever.format_for_prompt(results)
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations
|
|
34
|
+
|
|
35
|
+
import hashlib
|
|
36
|
+
import json
|
|
37
|
+
import logging
|
|
38
|
+
from dataclasses import dataclass, field
|
|
39
|
+
from pathlib import Path
|
|
40
|
+
from typing import Any, Callable, List, Optional, Union
|
|
41
|
+
|
|
42
|
+
from openadapt_ml.schema import Episode
|
|
43
|
+
|
|
44
|
+
logger = logging.getLogger(__name__)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class DemoMetadata:
    """Metadata for a single demonstration.

    Stores both the episode and computed features for retrieval.

    Attributes:
        demo_id: Unique identifier for the demo.
        episode: The full Episode object.
        goal: Task description/instruction (used as the primary retrieval text).
        app_name: Optional application name (e.g., "System Settings").
        domain: Optional domain (e.g., "github.com").
        platform: Operating system platform ("macos", "windows", "web").
        action_types: List of action types used in the demo.
        key_elements: Important UI elements touched.
        step_count: Number of steps in the demo.
        tags: User-provided tags for categorization.
        file_path: Path to the source file (if loaded from disk).
        embedding: Computed embedding vector (numpy array).
        metadata: Additional custom metadata.
    """

    demo_id: str
    episode: Episode
    goal: str
    app_name: Optional[str] = None
    domain: Optional[str] = None
    platform: Optional[str] = None
    action_types: List[str] = field(default_factory=list)
    key_elements: List[str] = field(default_factory=list)
    step_count: int = 0
    tags: List[str] = field(default_factory=list)
    file_path: Optional[str] = None
    # Populated by DemoRetriever.build_index(); typed Any so importing this
    # module does not require numpy.
    embedding: Optional[Any] = None  # numpy array when computed
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@dataclass
class RetrievalResult:
    """A single retrieval result with score breakdown.

    Attributes:
        demo: The demo metadata.
        score: Combined retrieval score (higher is better); equals
            ``text_score + domain_bonus``.
        text_score: Text/embedding similarity component.
        domain_bonus: Combined app/domain context bonus applied.
        rank: Rank in the result list (1-indexed). Defaults to 0 and is
            assigned by DemoRetriever.retrieve() for returned results.
    """

    demo: DemoMetadata
    score: float
    text_score: float
    domain_bonus: float
    rank: int = 0
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class DemoRetriever:
    """Retrieves similar demonstrations from a library using embeddings.

    Supports multiple embedding backends:
    - "tfidf": Simple TF-IDF (no external dependencies, baseline)
    - "sentence_transformers": Local embedding model (recommended)
    - "openai": OpenAI text-embedding API

    The retriever uses FAISS for efficient similarity search when available,
    falling back to brute-force cosine similarity for small indices.

    Example:
        >>> retriever = DemoRetriever(embedding_method="sentence_transformers")
        >>> retriever.add_demo(episode, app_name="Chrome")
        >>> retriever.build_index()
        >>> results = retriever.retrieve("Search on GitHub", top_k=3)
        >>> print(results[0].demo.goal)
    """

    def __init__(
        self,
        embedding_method: str = "tfidf",
        embedding_model: str = "all-MiniLM-L6-v2",
        cache_dir: Optional[Path] = None,
        domain_bonus: float = 0.2,
        app_bonus: float = 0.15,
        use_faiss: bool = True,
    ) -> None:
        """Initialize the DemoRetriever.

        Args:
            embedding_method: Embedding backend ("tfidf", "sentence_transformers", "openai").
            embedding_model: Model name for sentence_transformers or OpenAI.
                - For sentence_transformers: "all-MiniLM-L6-v2", "all-mpnet-base-v2", etc.
                - For OpenAI: "text-embedding-3-small", "text-embedding-3-large", etc.
            cache_dir: Directory for caching embeddings. If None, uses
                ~/.cache/openadapt_ml/embeddings.
            domain_bonus: Bonus score for matching domain (default: 0.2).
            app_bonus: Bonus score for matching app name (default: 0.15).
            use_faiss: Whether to use FAISS for vector search (default: True).
        """
        # Backend and scoring configuration.
        self.embedding_method = embedding_method
        self.embedding_model = embedding_model
        self.domain_bonus = domain_bonus
        self.app_bonus = app_bonus
        self.use_faiss = use_faiss

        # Resolve and eagerly create the embedding cache directory so later
        # cache writes cannot fail on a missing parent.
        default_cache = Path.home() / ".cache" / "openadapt_ml" / "embeddings"
        self.cache_dir = Path(cache_dir) if cache_dir is not None else default_cache
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Mutable library / index state.
        self._demos: List[DemoMetadata] = []
        self._embedder: Optional[Any] = None
        self._faiss_index: Optional[Any] = None
        self._embeddings_matrix: Optional[Any] = None
        self._is_indexed = False
|
|
161
|
+
|
|
162
|
+
# =========================================================================
|
|
163
|
+
# Demo Management
|
|
164
|
+
# =========================================================================
|
|
165
|
+
|
|
166
|
+
def add_demo(
|
|
167
|
+
self,
|
|
168
|
+
episode: Episode,
|
|
169
|
+
demo_id: Optional[str] = None,
|
|
170
|
+
app_name: Optional[str] = None,
|
|
171
|
+
domain: Optional[str] = None,
|
|
172
|
+
platform: Optional[str] = None,
|
|
173
|
+
tags: Optional[List[str]] = None,
|
|
174
|
+
file_path: Optional[str] = None,
|
|
175
|
+
metadata: Optional[dict[str, Any]] = None,
|
|
176
|
+
) -> DemoMetadata:
|
|
177
|
+
"""Add a demonstration episode to the library.
|
|
178
|
+
|
|
179
|
+
Args:
|
|
180
|
+
episode: The Episode to add.
|
|
181
|
+
demo_id: Unique ID (auto-generated from episode_id if not provided).
|
|
182
|
+
app_name: Application name (auto-extracted if not provided).
|
|
183
|
+
domain: Domain (auto-extracted from URLs if not provided).
|
|
184
|
+
platform: Platform ("macos", "windows", "web"). Auto-detected if not provided.
|
|
185
|
+
tags: User-provided tags for categorization.
|
|
186
|
+
file_path: Path to the source file.
|
|
187
|
+
metadata: Additional custom metadata.
|
|
188
|
+
|
|
189
|
+
Returns:
|
|
190
|
+
DemoMetadata object for the added demo.
|
|
191
|
+
"""
|
|
192
|
+
# Auto-generate demo_id
|
|
193
|
+
if demo_id is None:
|
|
194
|
+
demo_id = episode.episode_id
|
|
195
|
+
|
|
196
|
+
# Auto-extract app_name
|
|
197
|
+
if app_name is None:
|
|
198
|
+
app_name = self._extract_app_name(episode)
|
|
199
|
+
|
|
200
|
+
# Auto-extract domain
|
|
201
|
+
if domain is None:
|
|
202
|
+
domain = self._extract_domain(episode)
|
|
203
|
+
|
|
204
|
+
# Auto-detect platform
|
|
205
|
+
if platform is None:
|
|
206
|
+
platform = self._detect_platform(episode, app_name, domain)
|
|
207
|
+
|
|
208
|
+
# Extract action types
|
|
209
|
+
action_types = list(set(
|
|
210
|
+
step.action.type.value if hasattr(step.action.type, 'value') else str(step.action.type)
|
|
211
|
+
for step in episode.steps
|
|
212
|
+
if step.action
|
|
213
|
+
))
|
|
214
|
+
|
|
215
|
+
# Extract key elements
|
|
216
|
+
key_elements = self._extract_key_elements(episode)
|
|
217
|
+
|
|
218
|
+
demo_meta = DemoMetadata(
|
|
219
|
+
demo_id=demo_id,
|
|
220
|
+
episode=episode,
|
|
221
|
+
goal=episode.instruction,
|
|
222
|
+
app_name=app_name,
|
|
223
|
+
domain=domain,
|
|
224
|
+
platform=platform,
|
|
225
|
+
action_types=action_types,
|
|
226
|
+
key_elements=key_elements,
|
|
227
|
+
step_count=len(episode.steps),
|
|
228
|
+
tags=tags or [],
|
|
229
|
+
file_path=file_path,
|
|
230
|
+
metadata=metadata or {},
|
|
231
|
+
)
|
|
232
|
+
|
|
233
|
+
self._demos.append(demo_meta)
|
|
234
|
+
self._is_indexed = False # Need to rebuild index
|
|
235
|
+
|
|
236
|
+
return demo_meta
|
|
237
|
+
|
|
238
|
+
def add_demos(
|
|
239
|
+
self,
|
|
240
|
+
episodes: List[Episode],
|
|
241
|
+
**kwargs: Any,
|
|
242
|
+
) -> List[DemoMetadata]:
|
|
243
|
+
"""Add multiple demonstration episodes.
|
|
244
|
+
|
|
245
|
+
Args:
|
|
246
|
+
episodes: List of Episodes to add.
|
|
247
|
+
**kwargs: Additional arguments passed to add_demo.
|
|
248
|
+
|
|
249
|
+
Returns:
|
|
250
|
+
List of DemoMetadata objects.
|
|
251
|
+
"""
|
|
252
|
+
return [self.add_demo(ep, **kwargs) for ep in episodes]
|
|
253
|
+
|
|
254
|
+
def get_demo_count(self) -> int:
|
|
255
|
+
"""Get the number of demos in the library."""
|
|
256
|
+
return len(self._demos)
|
|
257
|
+
|
|
258
|
+
def get_all_demos(self) -> List[DemoMetadata]:
|
|
259
|
+
"""Get all demo metadata objects."""
|
|
260
|
+
return list(self._demos)
|
|
261
|
+
|
|
262
|
+
def get_apps(self) -> List[str]:
|
|
263
|
+
"""Get unique app names in the library."""
|
|
264
|
+
apps = {d.app_name for d in self._demos if d.app_name}
|
|
265
|
+
return sorted(apps)
|
|
266
|
+
|
|
267
|
+
def get_domains(self) -> List[str]:
|
|
268
|
+
"""Get unique domains in the library."""
|
|
269
|
+
domains = {d.domain for d in self._demos if d.domain}
|
|
270
|
+
return sorted(domains)
|
|
271
|
+
|
|
272
|
+
def clear(self) -> None:
|
|
273
|
+
"""Clear all demos and reset the index."""
|
|
274
|
+
self._demos = []
|
|
275
|
+
self._faiss_index = None
|
|
276
|
+
self._embeddings_matrix = None
|
|
277
|
+
self._is_indexed = False
|
|
278
|
+
|
|
279
|
+
# =========================================================================
|
|
280
|
+
# Indexing
|
|
281
|
+
# =========================================================================
|
|
282
|
+
|
|
283
|
+
    def build_index(self, force: bool = False) -> None:
        """Build the search index from all added demos.

        This computes embeddings for all demos and builds the FAISS index.
        Must be called before retrieve().

        Args:
            force: If True, rebuild even if already indexed.

        Raises:
            ValueError: If no demos have been added.
        """
        if self._is_indexed and not force:
            logger.debug("Index already built, skipping (use force=True to rebuild)")
            return

        if not self._demos:
            raise ValueError("Cannot build index: no demos added. Use add_demo() first.")

        logger.info(f"Building index for {len(self._demos)} demos using {self.embedding_method}...")

        # Lazily create the embedding backend on first build.
        if self._embedder is None:
            self._init_embedder()

        # Compute embeddings for all demos in one batch.
        texts = self._get_indexable_texts()
        embeddings = self._compute_embeddings(texts)

        # Store each row on its demo so per-demo scoring can read it directly.
        for demo, emb in zip(self._demos, embeddings):
            demo.embedding = emb

        # Keep the full matrix for brute-force search; optionally build FAISS.
        self._embeddings_matrix = embeddings
        if self.use_faiss:
            self._build_faiss_index(embeddings)

        self._is_indexed = True
        logger.info(f"Index built successfully with {len(self._demos)} demos")
|
|
323
|
+
|
|
324
|
+
    def save_index(self, path: Union[str, Path]) -> None:
        """Save the index and embeddings to disk.

        Writes ``index.json`` (demo metadata plus TF-IDF embedder state) and,
        when numpy is available, ``embeddings.npy`` with one row per demo in
        library order. Episodes themselves are NOT saved; load_index() relies
        on each demo's ``file_path`` to restore them.

        Args:
            path: Directory to save index files.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save demo metadata (without embeddings - too large)
        metadata = []
        for demo in self._demos:
            meta = {
                "demo_id": demo.demo_id,
                "goal": demo.goal,
                "app_name": demo.app_name,
                "domain": demo.domain,
                "platform": demo.platform,
                "action_types": demo.action_types,
                "key_elements": demo.key_elements,
                "step_count": demo.step_count,
                "tags": demo.tags,
                "file_path": demo.file_path,
                "metadata": demo.metadata,
            }
            metadata.append(meta)

        # Prepare embedder state for TF-IDF (needed to recreate same embedding dimension)
        embedder_state = {}
        if self.embedding_method == "tfidf" and self._embedder is not None:
            embedder_state = {
                "vocab": self._embedder.vocab,
                "vocab_to_idx": self._embedder.vocab_to_idx,
                "idf": self._embedder.idf,
            }

        # NOTE(review): assumes demo.metadata values are JSON-serializable —
        # json.dump will raise otherwise; confirm against callers.
        with open(path / "index.json", "w") as f:
            json.dump({
                "embedding_method": self.embedding_method,
                "embedding_model": self.embedding_model,
                "demos": metadata,
                "embedder_state": embedder_state,
            }, f, indent=2)

        # Save embeddings as numpy array
        try:
            import numpy as np
            if self._embeddings_matrix is not None:
                np.save(path / "embeddings.npy", self._embeddings_matrix)
        except ImportError:
            logger.warning("numpy not available, embeddings not saved")

        logger.info(f"Index saved to {path}")
|
|
377
|
+
|
|
378
|
+
    def load_index(self, path: Union[str, Path], episode_loader: Optional[Callable[[str], Episode]] = None) -> None:
        """Load index from disk.

        Restores backend configuration, demo metadata, embeddings (when
        ``embeddings.npy`` exists and numpy is importable), and TF-IDF
        embedder state previously written by save_index(). Replaces any
        demos currently held by this retriever.

        Args:
            path: Directory containing index files.
            episode_loader: Optional function to load episodes from file_path.
                If not provided (or loading fails), a one-step placeholder
                episode is substituted.
        """
        path = Path(path)

        with open(path / "index.json") as f:
            data = json.load(f)

        # Restore the backend configuration saved with the index.
        self.embedding_method = data.get("embedding_method", self.embedding_method)
        self.embedding_model = data.get("embedding_model", self.embedding_model)

        # Load embeddings; both the file and numpy itself are optional.
        embeddings = None
        try:
            import numpy as np
            embeddings_path = path / "embeddings.npy"
            if embeddings_path.exists():
                embeddings = np.load(embeddings_path)
        except ImportError:
            pass

        # Reconstruct demos; row i of the embeddings matrix belongs to demo i,
        # matching the library order used by save_index().
        self._demos = []
        for i, meta in enumerate(data.get("demos", [])):
            episode = None
            if episode_loader and meta.get("file_path"):
                try:
                    episode = episode_loader(meta["file_path"])
                except Exception as e:
                    logger.warning(f"Failed to load episode from {meta['file_path']}: {e}")

            # Create placeholder episode if not loaded
            if episode is None:
                from openadapt_ml.schema import Action, ActionType, Observation, Step
                episode = Episode(
                    episode_id=meta["demo_id"],
                    instruction=meta["goal"],
                    steps=[
                        Step(
                            step_index=0,
                            observation=Observation(),
                            action=Action(type=ActionType.DONE),
                        )
                    ],
                )

            demo = DemoMetadata(
                demo_id=meta["demo_id"],
                episode=episode,
                goal=meta["goal"],
                app_name=meta.get("app_name"),
                domain=meta.get("domain"),
                platform=meta.get("platform"),
                action_types=meta.get("action_types", []),
                key_elements=meta.get("key_elements", []),
                step_count=meta.get("step_count", 0),
                tags=meta.get("tags", []),
                file_path=meta.get("file_path"),
                metadata=meta.get("metadata", {}),
                embedding=embeddings[i] if embeddings is not None else None,
            )
            self._demos.append(demo)

        # Rebuild FAISS index if we have embeddings; without them the
        # retriever stays un-indexed and build_index() must be called.
        if embeddings is not None:
            self._embeddings_matrix = embeddings
            if self.use_faiss:
                self._build_faiss_index(embeddings)
            self._is_indexed = True

        # Restore embedder state for TF-IDF (needed for query embedding)
        embedder_state = data.get("embedder_state", {})
        if embedder_state and self.embedding_method == "tfidf":
            from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
            self._embedder = TFIDFEmbedder()
            self._embedder.vocab = embedder_state.get("vocab", [])
            self._embedder.vocab_to_idx = embedder_state.get("vocab_to_idx", {})
            self._embedder.idf = embedder_state.get("idf", {})
            self._embedder._is_fitted = True

        logger.info(f"Index loaded from {path} with {len(self._demos)} demos")
|
|
464
|
+
|
|
465
|
+
# =========================================================================
|
|
466
|
+
# Retrieval
|
|
467
|
+
# =========================================================================
|
|
468
|
+
|
|
469
|
+
    def retrieve(
        self,
        query: str,
        top_k: int = 3,
        app_context: Optional[str] = None,
        domain_context: Optional[str] = None,
        filter_platform: Optional[str] = None,
        filter_tags: Optional[List[str]] = None,
    ) -> List[RetrievalResult]:
        """Retrieve top-K most similar demos for a query.

        Args:
            query: Task description to find demos for.
            top_k: Number of demos to retrieve.
            app_context: Optional app context for bonus scoring (e.g., "Chrome").
            domain_context: Optional domain context for bonus scoring (e.g., "github.com").
            filter_platform: Only return demos from this platform.
            filter_tags: Only return demos with all these tags.

        Returns:
            List of RetrievalResult objects, ordered by relevance (best first),
            with 1-indexed ``rank`` set on each returned result.

        Raises:
            ValueError: If index has not been built.
        """
        if not self._is_indexed:
            raise ValueError("Index not built. Call build_index() first.")

        if not self._demos:
            return []

        # Embed the query with the same backend used for the demos.
        query_embedding = self._compute_embeddings([query])[0]

        # Get candidates (optionally filtered)
        candidates = self._get_candidates(filter_platform, filter_tags)
        if not candidates:
            return []

        # Score every candidate: embedding cosine similarity plus context
        # bonus. NOTE(review): this is always brute force over candidates —
        # the FAISS index built in build_index() is never consulted here.
        results = []
        for demo in candidates:
            text_score = self._compute_similarity(query_embedding, demo.embedding)
            bonus = self._compute_context_bonus(demo, app_context, domain_context)
            total_score = text_score + bonus

            results.append(RetrievalResult(
                demo=demo,
                score=total_score,
                text_score=text_score,
                domain_bonus=bonus,
            ))

        # Sort by score (descending)
        results.sort(key=lambda r: r.score, reverse=True)

        # Assign 1-indexed ranks to the slice that will be returned.
        for i, result in enumerate(results[:top_k]):
            result.rank = i + 1

        return results[:top_k]
|
|
530
|
+
|
|
531
|
+
def retrieve_episodes(
|
|
532
|
+
self,
|
|
533
|
+
query: str,
|
|
534
|
+
top_k: int = 3,
|
|
535
|
+
**kwargs: Any,
|
|
536
|
+
) -> List[Episode]:
|
|
537
|
+
"""Retrieve top-K episodes (convenience method).
|
|
538
|
+
|
|
539
|
+
Args:
|
|
540
|
+
query: Task description to find demos for.
|
|
541
|
+
top_k: Number of demos to retrieve.
|
|
542
|
+
**kwargs: Additional arguments passed to retrieve().
|
|
543
|
+
|
|
544
|
+
Returns:
|
|
545
|
+
List of Episode objects.
|
|
546
|
+
"""
|
|
547
|
+
results = self.retrieve(query, top_k=top_k, **kwargs)
|
|
548
|
+
return [r.demo.episode for r in results]
|
|
549
|
+
|
|
550
|
+
# =========================================================================
|
|
551
|
+
# Prompt Formatting
|
|
552
|
+
# =========================================================================
|
|
553
|
+
|
|
554
|
+
    def format_for_prompt(
        self,
        results: List[RetrievalResult],
        max_steps_per_demo: int = 10,
        include_scores: bool = False,
        format_style: str = "concise",
    ) -> str:
        """Format retrieved demos for inclusion in a prompt.

        Args:
            results: Retrieval results from retrieve().
            max_steps_per_demo: Maximum steps to include per demo.
            include_scores: Whether to include relevance scores.
            format_style: Formatting style ("concise", "verbose", "minimal").
                Unrecognized values fall back to "concise".

        Returns:
            Formatted string for prompt injection; empty string when there
            are no results.
        """
        if not results:
            return ""

        # Local import: the formatting helpers live in the experiments
        # package and are only needed when formatting is requested.
        from openadapt_ml.experiments.demo_prompt.format_demo import (
            format_episode_as_demo,
            format_episode_verbose,
        )

        lines = []

        # Header line, singular vs plural.
        if len(results) == 1:
            lines.append("Here is a relevant demonstration:")
        else:
            lines.append(f"Here are {len(results)} relevant demonstrations:")
        lines.append("")

        for i, result in enumerate(results, 1):
            # Per-demo header: always shown with scores, otherwise only when
            # there is more than one demo to distinguish.
            if include_scores:
                lines.append(f"Demo {i} (relevance: {result.score:.2f}):")
            elif len(results) > 1:
                lines.append(f"Demo {i}:")

            if format_style == "verbose":
                demo_text = format_episode_verbose(
                    result.demo.episode,
                    max_steps=max_steps_per_demo,
                )
            elif format_style == "minimal":
                # Just goal and action sequence
                steps_text = " -> ".join(
                    self._format_action_minimal(step.action)
                    for step in result.demo.episode.steps[:max_steps_per_demo]
                    if step.action
                )
                demo_text = f"Task: {result.demo.goal}\nSteps: {steps_text}"
            else:  # concise (default)
                demo_text = format_episode_as_demo(
                    result.demo.episode,
                    max_steps=max_steps_per_demo,
                )

            lines.append(demo_text)
            lines.append("")

        return "\n".join(lines)
|
|
617
|
+
|
|
618
|
+
    def _format_action_minimal(self, action: Any) -> str:
        """Format action as minimal string (delegates to the shared formatter)."""
        # Local import mirrors format_for_prompt(): the helper lives in the
        # experiments package and is only needed at formatting time.
        from openadapt_ml.experiments.demo_prompt.format_demo import format_action
        return format_action(action)
|
|
622
|
+
|
|
623
|
+
# =========================================================================
|
|
624
|
+
# Private Methods
|
|
625
|
+
# =========================================================================
|
|
626
|
+
|
|
627
|
+
    def _init_embedder(self) -> None:
        """Initialize the embedding backend.

        Selects the backend from ``self.embedding_method`` and stores it on
        ``self._embedder``. Backends are imported lazily so optional
        dependencies are only required when actually used.

        Raises:
            ValueError: If ``self.embedding_method`` is not one of
                "tfidf", "sentence_transformers", or "openai".
        """
        if self.embedding_method == "tfidf":
            # Dependency-free baseline embedder.
            from openadapt_ml.retrieval.embeddings import TFIDFEmbedder
            self._embedder = TFIDFEmbedder()

        elif self.embedding_method == "sentence_transformers":
            # Local model; embeddings cached under the retriever's cache dir.
            from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder
            self._embedder = SentenceTransformerEmbedder(
                model_name=self.embedding_model,
                cache_dir=self.cache_dir / "st_cache",
            )

        elif self.embedding_method == "openai":
            # Remote API; separate cache directory from the local backend.
            from openadapt_ml.retrieval.embeddings import OpenAIEmbedder
            self._embedder = OpenAIEmbedder(
                model_name=self.embedding_model,
                cache_dir=self.cache_dir / "openai_cache",
            )

        else:
            raise ValueError(f"Unknown embedding method: {self.embedding_method}")
|
|
649
|
+
|
|
650
|
+
def _get_indexable_texts(self) -> List[str]:
|
|
651
|
+
"""Get text representations for indexing."""
|
|
652
|
+
texts = []
|
|
653
|
+
for demo in self._demos:
|
|
654
|
+
# Combine goal with context
|
|
655
|
+
parts = [demo.goal]
|
|
656
|
+
if demo.app_name:
|
|
657
|
+
parts.append(f"[APP:{demo.app_name}]")
|
|
658
|
+
if demo.domain:
|
|
659
|
+
parts.append(f"[DOMAIN:{demo.domain}]")
|
|
660
|
+
texts.append(" ".join(parts))
|
|
661
|
+
return texts
|
|
662
|
+
|
|
663
|
+
    def _compute_embeddings(self, texts: List[str]) -> Any:
        """Compute embeddings for texts.

        Returns the backend's batch output coerced to a numpy array
        (float32 when coercion is needed), one embedding per input text.
        """
        import numpy as np

        # Lazily create the backend if build_index() was not the entry point
        # (e.g. query embedding after load_index()).
        if self._embedder is None:
            self._init_embedder()

        embeddings = self._embedder.embed_batch(texts)

        # Ensure numpy array
        if not isinstance(embeddings, np.ndarray):
            embeddings = np.array(embeddings, dtype=np.float32)

        return embeddings
|
|
677
|
+
|
|
678
|
+
    def _build_faiss_index(self, embeddings: Any) -> None:
        """Build FAISS index from embeddings.

        Degrades gracefully: when faiss is not installed,
        ``self._faiss_index`` is set to None and brute-force search is used.
        """
        try:
            import faiss
            import numpy as np

            embeddings = np.asarray(embeddings, dtype=np.float32)
            dim = embeddings.shape[1]

            # Use IndexFlatIP for cosine similarity (assumes normalized embeddings)
            # NOTE(review): nothing here normalizes the rows — confirm the
            # embedder backends emit unit-norm vectors.
            self._faiss_index = faiss.IndexFlatIP(dim)
            self._faiss_index.add(embeddings)

            logger.debug(f"Built FAISS index with {len(embeddings)} vectors, dim={dim}")
        except ImportError:
            logger.debug("FAISS not available, using brute-force search")
            self._faiss_index = None
|
|
695
|
+
|
|
696
|
+
def _compute_similarity(self, query_embedding: Any, doc_embedding: Any) -> float:
|
|
697
|
+
"""Compute similarity between query and document embeddings."""
|
|
698
|
+
import numpy as np
|
|
699
|
+
|
|
700
|
+
query_embedding = np.asarray(query_embedding, dtype=np.float32)
|
|
701
|
+
doc_embedding = np.asarray(doc_embedding, dtype=np.float32)
|
|
702
|
+
|
|
703
|
+
# Normalize for cosine similarity
|
|
704
|
+
query_norm = query_embedding / (np.linalg.norm(query_embedding) + 1e-9)
|
|
705
|
+
doc_norm = doc_embedding / (np.linalg.norm(doc_embedding) + 1e-9)
|
|
706
|
+
|
|
707
|
+
return float(np.dot(query_norm, doc_norm))
|
|
708
|
+
|
|
709
|
+
def _compute_context_bonus(
|
|
710
|
+
self,
|
|
711
|
+
demo: DemoMetadata,
|
|
712
|
+
app_context: Optional[str],
|
|
713
|
+
domain_context: Optional[str],
|
|
714
|
+
) -> float:
|
|
715
|
+
"""Compute context bonus for app/domain matching."""
|
|
716
|
+
bonus = 0.0
|
|
717
|
+
|
|
718
|
+
if app_context and demo.app_name:
|
|
719
|
+
if app_context.lower() in demo.app_name.lower():
|
|
720
|
+
bonus += self.app_bonus
|
|
721
|
+
|
|
722
|
+
if domain_context and demo.domain:
|
|
723
|
+
if domain_context.lower() in demo.domain.lower():
|
|
724
|
+
bonus += self.domain_bonus
|
|
725
|
+
|
|
726
|
+
return bonus
|
|
727
|
+
|
|
728
|
+
def _get_candidates(
|
|
729
|
+
self,
|
|
730
|
+
filter_platform: Optional[str],
|
|
731
|
+
filter_tags: Optional[List[str]],
|
|
732
|
+
) -> List[DemoMetadata]:
|
|
733
|
+
"""Get candidate demos after filtering."""
|
|
734
|
+
candidates = self._demos
|
|
735
|
+
|
|
736
|
+
if filter_platform:
|
|
737
|
+
candidates = [d for d in candidates if d.platform == filter_platform]
|
|
738
|
+
|
|
739
|
+
if filter_tags:
|
|
740
|
+
filter_tags_set = set(filter_tags)
|
|
741
|
+
candidates = [
|
|
742
|
+
d for d in candidates
|
|
743
|
+
if filter_tags_set.issubset(set(d.tags))
|
|
744
|
+
]
|
|
745
|
+
|
|
746
|
+
return candidates
|
|
747
|
+
|
|
748
|
+
def _extract_app_name(self, episode: Episode) -> Optional[str]:
|
|
749
|
+
"""Extract app name from episode observations."""
|
|
750
|
+
for step in episode.steps:
|
|
751
|
+
if step.observation and step.observation.app_name:
|
|
752
|
+
return step.observation.app_name
|
|
753
|
+
return None
|
|
754
|
+
|
|
755
|
+
def _extract_domain(self, episode: Episode) -> Optional[str]:
|
|
756
|
+
"""Extract domain from episode URLs."""
|
|
757
|
+
for step in episode.steps:
|
|
758
|
+
if step.observation and step.observation.url:
|
|
759
|
+
url = step.observation.url
|
|
760
|
+
if "://" in url:
|
|
761
|
+
domain = url.split("://")[1].split("/")[0]
|
|
762
|
+
if domain.startswith("www."):
|
|
763
|
+
domain = domain[4:]
|
|
764
|
+
return domain
|
|
765
|
+
return None
|
|
766
|
+
|
|
767
|
+
def _detect_platform(
|
|
768
|
+
self,
|
|
769
|
+
episode: Episode,
|
|
770
|
+
app_name: Optional[str],
|
|
771
|
+
domain: Optional[str],
|
|
772
|
+
) -> Optional[str]:
|
|
773
|
+
"""Detect platform from episode context."""
|
|
774
|
+
# Check for web indicators
|
|
775
|
+
if domain:
|
|
776
|
+
return "web"
|
|
777
|
+
|
|
778
|
+
# Check for macOS app names
|
|
779
|
+
macos_apps = {"System Settings", "Finder", "Safari", "Preview", "TextEdit"}
|
|
780
|
+
if app_name and app_name in macos_apps:
|
|
781
|
+
return "macos"
|
|
782
|
+
|
|
783
|
+
# Check for Windows app names
|
|
784
|
+
windows_apps = {"Settings", "File Explorer", "Microsoft Edge", "Notepad"}
|
|
785
|
+
if app_name and app_name in windows_apps:
|
|
786
|
+
return "windows"
|
|
787
|
+
|
|
788
|
+
# Check episode metadata
|
|
789
|
+
if episode.environment:
|
|
790
|
+
env_lower = episode.environment.lower()
|
|
791
|
+
if "macos" in env_lower or "darwin" in env_lower:
|
|
792
|
+
return "macos"
|
|
793
|
+
if "windows" in env_lower:
|
|
794
|
+
return "windows"
|
|
795
|
+
|
|
796
|
+
return None
|
|
797
|
+
|
|
798
|
+
def _extract_key_elements(self, episode: Episode) -> List[str]:
|
|
799
|
+
"""Extract key UI elements from episode."""
|
|
800
|
+
elements = []
|
|
801
|
+
for step in episode.steps:
|
|
802
|
+
if step.action and step.action.element:
|
|
803
|
+
elem = step.action.element
|
|
804
|
+
if elem.role and elem.name:
|
|
805
|
+
elements.append(f"{elem.role}:{elem.name}")
|
|
806
|
+
elif elem.name:
|
|
807
|
+
elements.append(elem.name)
|
|
808
|
+
return list(set(elements))
|
|
809
|
+
|
|
810
|
+
def __len__(self) -> int:
|
|
811
|
+
"""Return number of demos in the library."""
|
|
812
|
+
return len(self._demos)
|
|
813
|
+
|
|
814
|
+
def __repr__(self) -> str:
|
|
815
|
+
"""String representation."""
|
|
816
|
+
status = "indexed" if self._is_indexed else "not indexed"
|
|
817
|
+
return f"DemoRetriever({len(self._demos)} demos, {self.embedding_method}, {status})"
|