openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/retrieval/embeddings.py (new file)
@@ -0,0 +1,630 @@

"""Embedding functions for demo retrieval.

Supports multiple embedding backends:
- TF-IDF: Simple baseline, no external dependencies
- Sentence Transformers: Local embedding models (recommended)
- OpenAI: API-based embeddings

All embedders implement the same interface:
- embed(text: str) -> numpy.ndarray
- embed_batch(texts: List[str]) -> numpy.ndarray

Example:
    from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder

    embedder = SentenceTransformerEmbedder()
    embeddings = embedder.embed_batch(["Turn off Night Shift", "Search GitHub"])
    print(embeddings.shape)  # (2, 384)
"""

from __future__ import annotations

import hashlib
import json
import logging
import re
from abc import ABC, abstractmethod
from collections import Counter
from math import log
from pathlib import Path
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class BaseEmbedder(ABC):
    """Abstract base class for text embedders."""

    @abstractmethod
    def embed(self, text: str) -> Any:
        """Embed a single text.

        Args:
            text: Input text to embed.

        Returns:
            Embedding vector (numpy array or dict for sparse).
        """
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts.

        Args:
            texts: List of texts to embed.

        Returns:
            Embeddings matrix (numpy array of shape [n_texts, embedding_dim]).
        """
        pass

    def cosine_similarity(self, vec1: Any, vec2: Any) -> float:
        """Compute cosine similarity between two vectors.

        Args:
            vec1: First embedding vector.
            vec2: Second embedding vector.

        Returns:
            Cosine similarity in [-1, 1].
        """
        import numpy as np

        vec1 = np.asarray(vec1, dtype=np.float32).flatten()
        vec2 = np.asarray(vec2, dtype=np.float32).flatten()

        dot = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        return float(dot / (norm1 * norm2))
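The two abstract methods are the whole contract: a new backend only has to supply `embed` and `embed_batch`, and it inherits `cosine_similarity` for scoring. A minimal sketch of a custom backend against this interface (the hashing-trick embedder below is illustrative, not part of the package):

```python
import numpy as np

from openadapt_ml.retrieval.embeddings import BaseEmbedder


class HashingEmbedder(BaseEmbedder):
    """Toy backend: hash each token into a fixed-size bag-of-words bucket."""

    def __init__(self, dim: int = 256) -> None:
        self.dim = dim

    def embed(self, text: str) -> np.ndarray:
        vec = np.zeros(self.dim, dtype=np.float32)
        for token in text.lower().split():
            vec[hash(token) % self.dim] += 1.0  # hashing trick, no vocabulary
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    def embed_batch(self, texts: list[str]) -> np.ndarray:
        return np.stack([self.embed(t) for t in texts])


emb = HashingEmbedder()
# The inherited helper works unchanged on the custom backend's vectors.
print(emb.cosine_similarity(emb.embed("night shift"), emb.embed("turn off night shift")))
```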
# =============================================================================
# TF-IDF Embedder (Baseline)
# =============================================================================


class TFIDFEmbedder(BaseEmbedder):
    """Simple TF-IDF based text embedder.

    This is a minimal implementation for baseline/testing that doesn't require
    any external ML libraries. Uses sparse representations internally but
    converts to dense for compatibility.

    Note: Must call fit() before embed() to build vocabulary.
    """

    def __init__(self, max_features: int = 1000) -> None:
        """Initialize the TF-IDF embedder.

        Args:
            max_features: Maximum vocabulary size.
        """
        self.max_features = max_features
        self.documents: List[str] = []
        self.idf: Dict[str, float] = {}
        self.vocab: List[str] = []
        self.vocab_to_idx: Dict[str, int] = {}
        self._is_fitted = False

    def _tokenize(self, text: str) -> List[str]:
        """Simple tokenization - lowercase and split on non-alphanumeric.

        Args:
            text: Input text to tokenize.

        Returns:
            List of tokens.
        """
        tokens = re.findall(r"\b\w+\b", text.lower())
        return tokens

    def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
        """Compute term frequency for a document.

        Args:
            tokens: List of tokens.

        Returns:
            Dictionary mapping term to frequency.
        """
        counter = Counter(tokens)
        total = len(tokens)
        if total == 0:
            return {}
        return {term: count / total for term, count in counter.items()}

    def fit(self, documents: List[str]) -> "TFIDFEmbedder":
        """Fit the IDF on a corpus of documents.

        Args:
            documents: List of text documents.

        Returns:
            self for chaining.
        """
        self.documents = documents

        # Count document frequency for each term
        doc_freq: Dict[str, int] = {}
        all_terms: Counter[str] = Counter()

        for doc in documents:
            tokens = self._tokenize(doc)
            unique_tokens = set(tokens)
            all_terms.update(tokens)
            for token in unique_tokens:
                doc_freq[token] = doc_freq.get(token, 0) + 1

        # Select top features by frequency
        top_terms = [term for term, _ in all_terms.most_common(self.max_features)]
        self.vocab = top_terms
        self.vocab_to_idx = {term: idx for idx, term in enumerate(top_terms)}

        # Compute IDF: log(N / df) + 1
        n_docs = max(len(documents), 1)
        self.idf = {
            term: log(n_docs / doc_freq.get(term, 1)) + 1 for term in self.vocab
        }

        self._is_fitted = True
        return self

    def embed(self, text: str) -> Any:
        """Convert text to TF-IDF vector.

        Args:
            text: Input text.

        Returns:
            Dense embedding vector (numpy array).
        """
        import numpy as np

        if not self._is_fitted:
            # Fit on single document for compatibility
            self.fit([text])

        tokens = self._tokenize(text)
        tf = self._compute_tf(tokens)

        # Create dense vector
        vec = np.zeros(len(self.vocab), dtype=np.float32)
        for term, tf_val in tf.items():
            if term in self.vocab_to_idx:
                idx = self.vocab_to_idx[term]
                vec[idx] = tf_val * self.idf.get(term, 1.0)

        # L2 normalize
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm

        return vec

    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts.

        If not fitted, fits on the input texts first.

        Args:
            texts: List of texts to embed.

        Returns:
            Embeddings matrix (numpy array).
        """
        import numpy as np

        if not self._is_fitted:
            self.fit(texts)

        embeddings = np.array([self.embed(text) for text in texts], dtype=np.float32)
        return embeddings


# Alias for backward compatibility
TextEmbedder = TFIDFEmbedder
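As the class docstring notes, `embed` is only meaningful after `fit` builds the vocabulary. A short sketch of the fit-then-embed flow and the smoothed IDF (`log(N / df) + 1`) it produces; the corpus strings are illustrative:

```python
from openadapt_ml.retrieval.embeddings import TFIDFEmbedder

corpus = [
    "turn off night shift",
    "turn on dark mode",
    "search github for openadapt",
]
embedder = TFIDFEmbedder(max_features=100).fit(corpus)

# "turn" appears in 2 of the 3 documents: idf = log(3 / 2) + 1 ≈ 1.405
print(round(embedder.idf["turn"], 3))

vec = embedder.embed("turn off dark mode")
print(vec.shape)  # (11,) -- one slot per vocabulary term, L2-normalized
```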
# =============================================================================
# Sentence Transformers Embedder
# =============================================================================


class SentenceTransformerEmbedder(BaseEmbedder):
    """Embedding using sentence-transformers library.

    Recommended models:
    - "all-MiniLM-L6-v2": Fast, 22MB, 384 dims (default)
    - "all-mpnet-base-v2": Better quality, 420MB, 768 dims
    - "BAAI/bge-small-en-v1.5": Good balance, 130MB, 384 dims
    - "BAAI/bge-base-en-v1.5": Best quality, 440MB, 768 dims

    Requires: pip install sentence-transformers
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        cache_dir: Optional[Path] = None,
        device: Optional[str] = None,
        normalize: bool = True,
    ) -> None:
        """Initialize the Sentence Transformer embedder.

        Args:
            model_name: Name of the sentence-transformers model.
            cache_dir: Directory for caching model and embeddings.
            device: Device to run on ("cpu", "cuda", "mps"). Auto-detected if None.
            normalize: Whether to L2-normalize embeddings (for cosine similarity).
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.device = device
        self.normalize = normalize
        self._model = None
        self._embedding_cache: Dict[str, Any] = {}

    def _load_model(self) -> None:
        """Lazy-load the model."""
        if self._model is not None:
            return

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "sentence-transformers is required for SentenceTransformerEmbedder. "
                "Install with: pip install sentence-transformers"
            )

        logger.info(f"Loading sentence-transformers model: {self.model_name}")
        self._model = SentenceTransformer(
            self.model_name,
            cache_folder=str(self.cache_dir) if self.cache_dir else None,
            device=self.device,
        )

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text."""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def embed(self, text: str) -> Any:
        """Embed a single text.

        Args:
            text: Input text.

        Returns:
            Embedding vector (numpy array).
        """
        import numpy as np

        # Check cache
        cache_key = self._get_cache_key(text)
        if cache_key in self._embedding_cache:
            return self._embedding_cache[cache_key]

        self._load_model()

        embedding = self._model.encode(
            text,
            normalize_embeddings=self.normalize,
            convert_to_numpy=True,
        )
        embedding = np.asarray(embedding, dtype=np.float32)

        # Cache result
        self._embedding_cache[cache_key] = embedding

        return embedding

    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts efficiently.

        Args:
            texts: List of texts.

        Returns:
            Embeddings matrix (numpy array of shape [n_texts, dim]).
        """
        import numpy as np

        self._load_model()

        # Check which texts are cached
        cached_embeddings = {}
        uncached_texts = []
        uncached_indices = []

        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            if cache_key in self._embedding_cache:
                cached_embeddings[i] = self._embedding_cache[cache_key]
            else:
                uncached_texts.append(text)
                uncached_indices.append(i)

        # Embed uncached texts
        if uncached_texts:
            new_embeddings = self._model.encode(
                uncached_texts,
                normalize_embeddings=self.normalize,
                convert_to_numpy=True,
                show_progress_bar=len(uncached_texts) > 10,
            )

            # Cache new embeddings
            for text, embedding in zip(uncached_texts, new_embeddings):
                cache_key = self._get_cache_key(text)
                self._embedding_cache[cache_key] = embedding

        # Reassemble in original order
        dim = self._model.get_sentence_embedding_dimension()
        result = np.zeros((len(texts), dim), dtype=np.float32)

        for i, emb in cached_embeddings.items():
            result[i] = emb

        for i, idx in enumerate(uncached_indices):
            result[idx] = new_embeddings[i]

        return result

    def clear_cache(self) -> None:
        """Clear the embedding cache."""
        self._embedding_cache = {}
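Results are memoized in memory under an MD5 key of `model_name:text`, so only strings the embedder has not seen before hit `encode`. A hedged sketch of typical use (requires the optional `sentence-transformers` dependency; the texts and cache path are illustrative):

```python
from pathlib import Path

from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder

embedder = SentenceTransformerEmbedder(
    model_name="all-MiniLM-L6-v2",                   # default: 384 dims
    cache_dir=Path.home() / ".cache" / "openadapt",  # model download cache
    normalize=True,                                  # unit vectors: dot == cosine
)

first = embedder.embed_batch(["Turn off Night Shift", "Search GitHub"])
second = embedder.embed("Search GitHub")  # served from the in-memory cache
print(first.shape)  # (2, 384)
```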
# =============================================================================
# OpenAI Embedder
# =============================================================================


class OpenAIEmbedder(BaseEmbedder):
    """Embedding using OpenAI's text-embedding API.

    Models:
    - "text-embedding-3-small": Cheap, fast, 1536 dims ($0.00002/1K tokens)
    - "text-embedding-3-large": Best quality, 3072 dims ($0.00013/1K tokens)
    - "text-embedding-ada-002": Legacy, 1536 dims

    Requires: pip install openai
    Environment: OPENAI_API_KEY must be set
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        cache_dir: Optional[Path] = None,
        api_key: Optional[str] = None,
        normalize: bool = True,
        batch_size: int = 100,
    ) -> None:
        """Initialize the OpenAI embedder.

        Args:
            model_name: OpenAI embedding model name.
            cache_dir: Directory for caching embeddings to disk.
            api_key: OpenAI API key. If None, uses OPENAI_API_KEY env var.
            normalize: Whether to L2-normalize embeddings.
            batch_size: Maximum texts per API call.
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.api_key = api_key
        self.normalize = normalize
        self.batch_size = batch_size
        self._client = None
        self._embedding_cache: Dict[str, Any] = {}

        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_disk_cache()

    def _load_disk_cache(self) -> None:
        """Load cache from disk."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "embeddings_cache.json"
        if cache_file.exists():
            try:
                with open(cache_file) as f:
                    cached = json.load(f)
                # Convert lists back to arrays
                import numpy as np

                for key, val in cached.items():
                    self._embedding_cache[key] = np.array(val, dtype=np.float32)
                logger.debug(f"Loaded {len(self._embedding_cache)} cached embeddings")
            except Exception as e:
                logger.warning(f"Failed to load cache: {e}")

    def _save_disk_cache(self) -> None:
        """Save cache to disk."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "embeddings_cache.json"
        try:
            # Convert arrays to lists for JSON
            cache_data = {
                key: val.tolist() for key, val in self._embedding_cache.items()
            }
            with open(cache_file, "w") as f:
                json.dump(cache_data, f)
        except Exception as e:
            logger.warning(f"Failed to save cache: {e}")

    def _get_client(self) -> Any:
        """Get or create OpenAI client."""
        if self._client is not None:
            return self._client

        try:
            from openai import OpenAI
        except ImportError:
            raise ImportError(
                "openai is required for OpenAIEmbedder. "
                "Install with: pip install openai"
            )

        self._client = OpenAI(api_key=self.api_key)
        return self._client

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text."""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def embed(self, text: str) -> Any:
        """Embed a single text.

        Args:
            text: Input text.

        Returns:
            Embedding vector (numpy array).
        """
        return self.embed_batch([text])[0]

    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts.

        Args:
            texts: List of texts.

        Returns:
            Embeddings matrix (numpy array).
        """
        import numpy as np

        # Check cache first
        cached_embeddings = {}
        uncached_texts = []
        uncached_indices = []

        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            if cache_key in self._embedding_cache:
                cached_embeddings[i] = self._embedding_cache[cache_key]
            else:
                uncached_texts.append(text)
                uncached_indices.append(i)

        # Fetch uncached embeddings from API
        new_embeddings = {}
        if uncached_texts:
            client = self._get_client()

            # Process in batches
            for batch_start in range(0, len(uncached_texts), self.batch_size):
                batch_texts = uncached_texts[
                    batch_start : batch_start + self.batch_size
                ]

                try:
                    response = client.embeddings.create(
                        model=self.model_name,
                        input=batch_texts,
                    )

                    for j, item in enumerate(response.data):
                        idx = uncached_indices[batch_start + j]
                        embedding = np.array(item.embedding, dtype=np.float32)

                        if self.normalize:
                            norm = np.linalg.norm(embedding)
                            if norm > 0:
                                embedding = embedding / norm

                        new_embeddings[idx] = embedding

                        # Cache the result
                        cache_key = self._get_cache_key(batch_texts[j])
                        self._embedding_cache[cache_key] = embedding

                except Exception as e:
                    logger.error(f"OpenAI API error: {e}")
                    raise

            # Save to disk cache periodically
            self._save_disk_cache()

        # Determine embedding dimension
        if cached_embeddings:
            dim = next(iter(cached_embeddings.values())).shape[0]
        elif new_embeddings:
            dim = next(iter(new_embeddings.values())).shape[0]
        else:
            # Default dimensions by model
            dim = {"text-embedding-3-small": 1536, "text-embedding-3-large": 3072}.get(
                self.model_name, 1536
            )

        # Assemble result
        result = np.zeros((len(texts), dim), dtype=np.float32)

        for i, emb in cached_embeddings.items():
            result[i] = emb

        for i, emb in new_embeddings.items():
            result[i] = emb

        return result

    def clear_cache(self) -> None:
        """Clear embedding cache."""
        self._embedding_cache = {}
        if self.cache_dir:
            cache_file = self.cache_dir / "embeddings_cache.json"
            if cache_file.exists():
                cache_file.unlink()
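Unlike the local backends, this one persists its cache: with a `cache_dir` set, every fetched vector is written to `embeddings_cache.json`, so a re-run embeds the same texts without any API calls. A sketch (needs `openai` installed and `OPENAI_API_KEY` exported; the cache path is illustrative):

```python
from pathlib import Path

from openadapt_ml.retrieval.embeddings import OpenAIEmbedder

embedder = OpenAIEmbedder(
    model_name="text-embedding-3-small",  # 1536 dims
    cache_dir=Path("./embedding_cache"),  # persisted as embeddings_cache.json
    batch_size=100,                       # texts per embeddings.create call
)

vecs = embedder.embed_batch(["Turn off Night Shift", "Search GitHub"])
print(vecs.shape)  # (2, 1536)
# A second run loads both vectors from the disk cache: zero API calls.
```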
# =============================================================================
# Factory Function
# =============================================================================


def create_embedder(
    method: str = "tfidf",
    model_name: Optional[str] = None,
    cache_dir: Optional[Path] = None,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create an embedder.

    Args:
        method: Embedding method ("tfidf", "sentence_transformers", "openai").
        model_name: Model name (method-specific defaults if None).
        cache_dir: Cache directory for embeddings.
        **kwargs: Additional arguments passed to embedder.

    Returns:
        Embedder instance.
    """
    if method == "tfidf":
        return TFIDFEmbedder(**kwargs)

    elif method == "sentence_transformers":
        return SentenceTransformerEmbedder(
            model_name=model_name or "all-MiniLM-L6-v2",
            cache_dir=cache_dir,
            **kwargs,
        )

    elif method == "openai":
        return OpenAIEmbedder(
            model_name=model_name or "text-embedding-3-small",
            cache_dir=cache_dir,
            **kwargs,
        )

    else:
        raise ValueError(f"Unknown embedding method: {method}")
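The factory keeps call sites backend-agnostic: switching from the dependency-free baseline to a local or API-backed model is a one-string change, as in this sketch (the arguments shown are illustrative):

```python
from openadapt_ml.retrieval.embeddings import create_embedder

embedder = create_embedder("tfidf", max_features=500)  # no extra dependencies
# embedder = create_embedder("sentence_transformers")  # local model
# embedder = create_embedder("openai", model_name="text-embedding-3-large")

emb = embedder.embed_batch(["Turn off Night Shift", "Search GitHub"])
print(emb.shape)
```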