openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,630 @@
1
+ """Embedding functions for demo retrieval.
2
+
3
+ Supports multiple embedding backends:
4
+ - TF-IDF: Simple baseline, no external dependencies
5
+ - Sentence Transformers: Local embedding models (recommended)
6
+ - OpenAI: API-based embeddings
7
+
8
+ All embedders implement the same interface:
9
+ - embed(text: str) -> numpy.ndarray
10
+ - embed_batch(texts: List[str]) -> numpy.ndarray
11
+
12
+ Example:
13
+ from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder
14
+
15
+ embedder = SentenceTransformerEmbedder()
16
+ embeddings = embedder.embed_batch(["Turn off Night Shift", "Search GitHub"])
17
+ print(embeddings.shape) # (2, 384)
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import hashlib
23
+ import json
24
+ import logging
25
+ import re
26
+ from abc import ABC, abstractmethod
27
+ from collections import Counter
28
+ from math import log
29
+ from pathlib import Path
30
+ from typing import Any, Dict, List, Optional
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class BaseEmbedder(ABC):
    """Abstract interface shared by all text embedders."""

    @abstractmethod
    def embed(self, text: str) -> Any:
        """Embed a single text.

        Args:
            text: Input text to embed.

        Returns:
            Embedding vector (numpy array or dict for sparse).
        """
        ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts.

        Args:
            texts: List of texts to embed.

        Returns:
            Embeddings matrix (numpy array of shape [n_texts, embedding_dim]).
        """
        ...

    def cosine_similarity(self, vec1: Any, vec2: Any) -> float:
        """Return the cosine similarity of *vec1* and *vec2*.

        Args:
            vec1: First embedding vector.
            vec2: Second embedding vector.

        Returns:
            Cosine similarity in [-1, 1]; 0.0 when either vector has zero
            norm (avoids a division by zero).
        """
        import numpy as np

        a = np.asarray(vec1, dtype=np.float32).ravel()
        b = np.asarray(vec2, dtype=np.float32).ravel()

        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            return 0.0

        return float(np.dot(a, b) / denom)
85
+
86
+
87
+ # =============================================================================
88
+ # TF-IDF Embedder (Baseline)
89
+ # =============================================================================
90
+
91
+
92
class TFIDFEmbedder(BaseEmbedder):
    """Simple TF-IDF based text embedder.

    This is a minimal implementation for baseline/testing that doesn't require
    any external ML libraries. It builds a vocabulary of the ``max_features``
    most frequent terms and represents each text as an L2-normalized dense
    TF-IDF vector.

    Note: Must call fit() before embed() to build vocabulary; embed() and
    embed_batch() fall back to fitting on their own input if unfitted.
    """

    def __init__(self, max_features: int = 1000) -> None:
        """Initialize the TF-IDF embedder.

        Args:
            max_features: Maximum vocabulary size.
        """
        self.max_features = max_features
        self.documents: List[str] = []
        self.idf: Dict[str, float] = {}
        self.vocab: List[str] = []
        self.vocab_to_idx: Dict[str, int] = {}
        self._is_fitted = False

    def _tokenize(self, text: str) -> List[str]:
        """Lowercase *text* and split it into alphanumeric word tokens.

        Args:
            text: Input text to tokenize.

        Returns:
            List of tokens.
        """
        return re.findall(r"\b\w+\b", text.lower())

    def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
        """Compute relative term frequency for a document.

        Args:
            tokens: List of tokens.

        Returns:
            Dictionary mapping term to frequency (empty for no tokens).
        """
        total = len(tokens)
        if total == 0:
            return {}
        return {term: count / total for term, count in Counter(tokens).items()}

    def fit(self, documents: List[str]) -> "TFIDFEmbedder":
        """Fit the vocabulary and IDF weights on a corpus of documents.

        Args:
            documents: List of text documents.

        Returns:
            self for chaining.
        """
        self.documents = documents

        # Document frequency feeds IDF; total frequency selects the vocabulary.
        doc_freq: Dict[str, int] = {}
        all_terms: Counter[str] = Counter()

        for doc in documents:
            tokens = self._tokenize(doc)
            all_terms.update(tokens)
            for token in set(tokens):
                doc_freq[token] = doc_freq.get(token, 0) + 1

        # Keep only the most frequent terms.
        self.vocab = [term for term, _ in all_terms.most_common(self.max_features)]
        self.vocab_to_idx = {term: idx for idx, term in enumerate(self.vocab)}

        # Smoothed IDF: log(N / df) + 1, so in-vocab terms never get zero weight.
        n_docs = max(len(documents), 1)
        self.idf = {
            term: log(n_docs / doc_freq.get(term, 1)) + 1 for term in self.vocab
        }

        self._is_fitted = True
        return self

    def embed(self, text: str) -> Any:
        """Convert text to an L2-normalized TF-IDF vector.

        If the embedder is unfitted, it is first fitted on *text* alone.

        Args:
            text: Input text.

        Returns:
            Dense embedding vector (numpy array of shape [len(vocab)]).
        """
        import numpy as np

        if not self._is_fitted:
            # Fit on single document for compatibility
            self.fit([text])

        tf = self._compute_tf(self._tokenize(text))

        # Create dense vector; out-of-vocabulary terms are simply ignored.
        vec = np.zeros(len(self.vocab), dtype=np.float32)
        for term, tf_val in tf.items():
            idx = self.vocab_to_idx.get(term)
            if idx is not None:
                vec[idx] = tf_val * self.idf.get(term, 1.0)

        # L2 normalize so dot products equal cosine similarity.
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm

        return vec

    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts.

        If not fitted, fits on the input texts first.

        Args:
            texts: List of texts to embed.

        Returns:
            Embeddings matrix (numpy array of shape [n_texts, len(vocab)]).
            An empty input yields a correctly shaped (0, len(vocab)) matrix.
        """
        import numpy as np

        if not self._is_fitted:
            self.fit(texts)

        # Pre-allocate the matrix so an empty `texts` still yields a 2-D
        # result; np.array([]) would collapse to shape (0,) instead.
        result = np.zeros((len(texts), len(self.vocab)), dtype=np.float32)
        for i, text in enumerate(texts):
            result[i] = self.embed(text)
        return result
228
+
229
+
230
# Alias for backward compatibility
# (older code imported `TextEmbedder`; it is the same TF-IDF implementation).
TextEmbedder = TFIDFEmbedder
232
+
233
+
234
+ # =============================================================================
235
+ # Sentence Transformers Embedder
236
+ # =============================================================================
237
+
238
+
239
class SentenceTransformerEmbedder(BaseEmbedder):
    """Embedding using sentence-transformers library.

    Recommended models:
        - "all-MiniLM-L6-v2": Fast, 22MB, 384 dims (default)
        - "all-mpnet-base-v2": Better quality, 420MB, 768 dims
        - "BAAI/bge-small-en-v1.5": Good balance, 130MB, 384 dims
        - "BAAI/bge-base-en-v1.5": Best quality, 440MB, 768 dims

    Requires: pip install sentence-transformers
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        cache_dir: Optional[Path] = None,
        device: Optional[str] = None,
        normalize: bool = True,
    ) -> None:
        """Initialize the Sentence Transformer embedder.

        Args:
            model_name: Name of the sentence-transformers model.
            cache_dir: Directory for caching model and embeddings.
            device: Device to run on ("cpu", "cuda", "mps"). Auto-detected if None.
            normalize: Whether to L2-normalize embeddings (for cosine similarity).
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.device = device
        self.normalize = normalize
        self._model = None  # loaded lazily on first uncached embed
        self._embedding_cache: Dict[str, Any] = {}  # md5 key -> float32 vector

    def _load_model(self) -> None:
        """Lazy-load the model (no-op if already loaded).

        Raises:
            ImportError: If sentence-transformers is not installed.
        """
        if self._model is not None:
            return

        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "sentence-transformers is required for SentenceTransformerEmbedder. "
                "Install with: pip install sentence-transformers"
            )

        logger.info(f"Loading sentence-transformers model: {self.model_name}")
        self._model = SentenceTransformer(
            self.model_name,
            cache_folder=str(self.cache_dir) if self.cache_dir else None,
            device=self.device,
        )

    def _get_cache_key(self, text: str) -> str:
        """Generate a cache key unique per (model, text) pair."""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def embed(self, text: str) -> Any:
        """Embed a single text.

        Args:
            text: Input text.

        Returns:
            Embedding vector (numpy float32 array).
        """
        import numpy as np

        # Serve from cache without touching the model when possible.
        cache_key = self._get_cache_key(text)
        if cache_key in self._embedding_cache:
            return self._embedding_cache[cache_key]

        self._load_model()

        embedding = self._model.encode(
            text,
            normalize_embeddings=self.normalize,
            convert_to_numpy=True,
        )
        embedding = np.asarray(embedding, dtype=np.float32)

        # Cache result
        self._embedding_cache[cache_key] = embedding

        return embedding

    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts efficiently, reusing cached results.

        The model is only loaded when at least one text is not already
        cached, so fully-cached batches (and empty input) work without
        loading — or even installing — sentence-transformers.

        Args:
            texts: List of texts.

        Returns:
            Embeddings matrix (numpy float32 array of shape [n_texts, dim]).
        """
        import numpy as np

        # Partition the inputs into cache hits and texts needing encoding.
        cached_embeddings: Dict[int, Any] = {}
        uncached_texts: List[str] = []
        uncached_indices: List[int] = []

        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            if cache_key in self._embedding_cache:
                cached_embeddings[i] = self._embedding_cache[cache_key]
            else:
                uncached_texts.append(text)
                uncached_indices.append(i)

        # Only load the (potentially large) model when actually needed.
        new_embeddings = None
        if uncached_texts:
            self._load_model()
            new_embeddings = self._model.encode(
                uncached_texts,
                normalize_embeddings=self.normalize,
                convert_to_numpy=True,
                show_progress_bar=len(uncached_texts) > 10,
            )
            # Store consistently as float32 (matching embed()).
            new_embeddings = np.asarray(new_embeddings, dtype=np.float32)

            for text, embedding in zip(uncached_texts, new_embeddings):
                self._embedding_cache[self._get_cache_key(text)] = embedding

        # Determine the embedding dimension without forcing a model load.
        if new_embeddings is not None:
            dim = new_embeddings.shape[1]
        elif cached_embeddings:
            dim = len(next(iter(cached_embeddings.values())))
        else:
            # Empty input: no dimension information is available.
            return np.zeros((0, 0), dtype=np.float32)

        # Reassemble rows in the original input order.
        result = np.zeros((len(texts), dim), dtype=np.float32)

        for i, emb in cached_embeddings.items():
            result[i] = emb

        for row, idx in enumerate(uncached_indices):
            result[idx] = new_embeddings[row]

        return result

    def clear_cache(self) -> None:
        """Clear the embedding cache."""
        self._embedding_cache = {}
382
+
383
+
384
+ # =============================================================================
385
+ # OpenAI Embedder
386
+ # =============================================================================
387
+
388
+
389
class OpenAIEmbedder(BaseEmbedder):
    """Embedding using OpenAI's text-embedding API.

    Models:
        - "text-embedding-3-small": Cheap, fast, 1536 dims ($0.00002/1K tokens)
        - "text-embedding-3-large": Best quality, 3072 dims ($0.00013/1K tokens)
        - "text-embedding-ada-002": Legacy, 1536 dims

    Requires: pip install openai
    Environment: OPENAI_API_KEY must be set
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        cache_dir: Optional[Path] = None,
        api_key: Optional[str] = None,
        normalize: bool = True,
        batch_size: int = 100,
    ) -> None:
        """Initialize the OpenAI embedder.

        Args:
            model_name: OpenAI embedding model name.
            cache_dir: Directory for caching embeddings to disk.
            api_key: OpenAI API key. If None, uses OPENAI_API_KEY env var.
            normalize: Whether to L2-normalize embeddings.
            batch_size: Maximum texts per API call.
        """
        self.model_name = model_name
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.api_key = api_key
        self.normalize = normalize
        self.batch_size = batch_size
        self._client = None  # created lazily on first API call
        self._embedding_cache: Dict[str, Any] = {}

        if self.cache_dir:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self._load_disk_cache()

    def _load_disk_cache(self) -> None:
        """Populate the in-memory cache from the on-disk JSON cache, if any."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "embeddings_cache.json"
        if not cache_file.exists():
            return

        try:
            with open(cache_file) as f:
                cached = json.load(f)
            # JSON stores plain lists; restore them as float32 vectors.
            import numpy as np

            for key, val in cached.items():
                self._embedding_cache[key] = np.array(val, dtype=np.float32)
            logger.debug(f"Loaded {len(self._embedding_cache)} cached embeddings")
        except Exception as e:
            logger.warning(f"Failed to load cache: {e}")

    def _save_disk_cache(self) -> None:
        """Write the in-memory cache to disk as JSON (best-effort)."""
        if not self.cache_dir:
            return

        cache_file = self.cache_dir / "embeddings_cache.json"
        try:
            # JSON cannot hold numpy arrays; store plain lists instead.
            serializable = {
                key: val.tolist() for key, val in self._embedding_cache.items()
            }
            with open(cache_file, "w") as f:
                json.dump(serializable, f)
        except Exception as e:
            logger.warning(f"Failed to save cache: {e}")

    def _get_client(self) -> Any:
        """Return the lazily-created OpenAI client.

        Raises:
            ImportError: If the openai package is not installed.
        """
        if self._client is None:
            try:
                from openai import OpenAI
            except ImportError:
                raise ImportError(
                    "openai is required for OpenAIEmbedder. "
                    "Install with: pip install openai"
                )
            self._client = OpenAI(api_key=self.api_key)
        return self._client

    def _get_cache_key(self, text: str) -> str:
        """Generate a cache key unique per (model, text) pair."""
        return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()

    def embed(self, text: str) -> Any:
        """Embed a single text.

        Args:
            text: Input text.

        Returns:
            Embedding vector (numpy array).
        """
        return self.embed_batch([text])[0]

    def embed_batch(self, texts: List[str]) -> Any:
        """Embed multiple texts, serving repeats from the cache.

        Args:
            texts: List of texts.

        Returns:
            Embeddings matrix (numpy array of shape [n_texts, dim]).

        Raises:
            Exception: Re-raises any OpenAI API error after logging it.
        """
        import numpy as np

        # Partition inputs into cache hits and texts needing an API call.
        hits: Dict[int, Any] = {}
        misses: List[str] = []
        miss_positions: List[int] = []

        for position, text in enumerate(texts):
            key = self._get_cache_key(text)
            if key in self._embedding_cache:
                hits[position] = self._embedding_cache[key]
            else:
                misses.append(text)
                miss_positions.append(position)

        # Fetch anything not cached, chunked to respect the per-request limit.
        fetched: Dict[int, Any] = {}
        if misses:
            client = self._get_client()

            for start in range(0, len(misses), self.batch_size):
                chunk = misses[start : start + self.batch_size]

                try:
                    response = client.embeddings.create(
                        model=self.model_name,
                        input=chunk,
                    )

                    for offset, item in enumerate(response.data):
                        vector = np.array(item.embedding, dtype=np.float32)

                        if self.normalize:
                            length = np.linalg.norm(vector)
                            if length > 0:
                                vector = vector / length

                        fetched[miss_positions[start + offset]] = vector

                        # Cache the result for later calls.
                        self._embedding_cache[
                            self._get_cache_key(chunk[offset])
                        ] = vector

                except Exception as e:
                    logger.error(f"OpenAI API error: {e}")
                    raise

            # Persist the enlarged cache to disk.
            self._save_disk_cache()

        # Determine the embedding dimension: prefer a real vector, fall back
        # to the model's documented size.
        if hits:
            dim = next(iter(hits.values())).shape[0]
        elif fetched:
            dim = next(iter(fetched.values())).shape[0]
        else:
            dim = {"text-embedding-3-small": 1536, "text-embedding-3-large": 3072}.get(
                self.model_name, 1536
            )

        # Assemble rows in the original input order.
        result = np.zeros((len(texts), dim), dtype=np.float32)

        for position, vector in hits.items():
            result[position] = vector

        for position, vector in fetched.items():
            result[position] = vector

        return result

    def clear_cache(self) -> None:
        """Clear the in-memory cache and delete the on-disk cache file."""
        self._embedding_cache = {}
        if self.cache_dir:
            cache_file = self.cache_dir / "embeddings_cache.json"
            if cache_file.exists():
                cache_file.unlink()
588
+
589
+
590
+ # =============================================================================
591
+ # Factory Function
592
+ # =============================================================================
593
+
594
+
595
def create_embedder(
    method: str = "tfidf",
    model_name: Optional[str] = None,
    cache_dir: Optional[Path] = None,
    **kwargs: Any,
) -> BaseEmbedder:
    """Factory function to create an embedder.

    Args:
        method: Embedding method ("tfidf", "sentence_transformers", "openai").
            Matching is case-insensitive and ignores surrounding whitespace.
        model_name: Model name (method-specific defaults if None).
        cache_dir: Cache directory for embeddings.
        **kwargs: Additional arguments passed to embedder.

    Returns:
        Embedder instance.

    Raises:
        ValueError: If *method* is not one of the supported values.
    """
    # Normalize so "TFIDF", " tfidf ", etc. all resolve correctly.
    normalized = method.strip().lower()

    if normalized == "tfidf":
        return TFIDFEmbedder(**kwargs)

    if normalized == "sentence_transformers":
        return SentenceTransformerEmbedder(
            model_name=model_name or "all-MiniLM-L6-v2",
            cache_dir=cache_dir,
            **kwargs,
        )

    if normalized == "openai":
        return OpenAIEmbedder(
            model_name=model_name or "text-embedding-3-small",
            cache_dir=cache_dir,
            **kwargs,
        )

    raise ValueError(
        f"Unknown embedding method: {method}. "
        "Expected one of: tfidf, sentence_transformers, openai"
    )