openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/retrieval/embeddings.py (new file)
@@ -0,0 +1,629 @@
+ """Embedding functions for demo retrieval.
+
+ Supports multiple embedding backends:
+ - TF-IDF: Simple baseline, no external dependencies
+ - Sentence Transformers: Local embedding models (recommended)
+ - OpenAI: API-based embeddings
+
+ All embedders implement the same interface:
+ - embed(text: str) -> numpy.ndarray
+ - embed_batch(texts: List[str]) -> numpy.ndarray
+
+ Example:
+     from openadapt_ml.retrieval.embeddings import SentenceTransformerEmbedder
+
+     embedder = SentenceTransformerEmbedder()
+     embeddings = embedder.embed_batch(["Turn off Night Shift", "Search GitHub"])
+     print(embeddings.shape)  # (2, 384)
+ """
+
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ import logging
+ import re
+ from abc import ABC, abstractmethod
+ from collections import Counter
+ from math import log, sqrt
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ logger = logging.getLogger(__name__)
+
+
+ class BaseEmbedder(ABC):
+     """Abstract base class for text embedders."""
+
+     @abstractmethod
+     def embed(self, text: str) -> Any:
+         """Embed a single text.
+
+         Args:
+             text: Input text to embed.
+
+         Returns:
+             Embedding vector (numpy array or dict for sparse).
+         """
+         pass
+
+     @abstractmethod
+     def embed_batch(self, texts: List[str]) -> Any:
+         """Embed multiple texts.
+
+         Args:
+             texts: List of texts to embed.
+
+         Returns:
+             Embeddings matrix (numpy array of shape [n_texts, embedding_dim]).
+         """
+         pass
+
+     def cosine_similarity(self, vec1: Any, vec2: Any) -> float:
+         """Compute cosine similarity between two vectors.
+
+         Args:
+             vec1: First embedding vector.
+             vec2: Second embedding vector.
+
+         Returns:
+             Cosine similarity in [-1, 1].
+         """
+         import numpy as np
+
+         vec1 = np.asarray(vec1, dtype=np.float32).flatten()
+         vec2 = np.asarray(vec2, dtype=np.float32).flatten()
+
+         dot = np.dot(vec1, vec2)
+         norm1 = np.linalg.norm(vec1)
+         norm2 = np.linalg.norm(vec2)
+
+         if norm1 == 0 or norm2 == 0:
+             return 0.0
+
+         return float(dot / (norm1 * norm2))
+
+
+ # =============================================================================
+ # TF-IDF Embedder (Baseline)
+ # =============================================================================
+
+
+ class TFIDFEmbedder(BaseEmbedder):
+     """Simple TF-IDF based text embedder.
+
+     This is a minimal implementation for baseline/testing that doesn't require
+     any external ML libraries. Uses sparse representations internally but
+     converts to dense for compatibility.
+
+     Note: Must call fit() before embed() to build vocabulary.
+     """
+
+     def __init__(self, max_features: int = 1000) -> None:
+         """Initialize the TF-IDF embedder.
+
+         Args:
+             max_features: Maximum vocabulary size.
+         """
+         self.max_features = max_features
+         self.documents: List[str] = []
+         self.idf: Dict[str, float] = {}
+         self.vocab: List[str] = []
+         self.vocab_to_idx: Dict[str, int] = {}
+         self._is_fitted = False
+
+     def _tokenize(self, text: str) -> List[str]:
+         """Simple tokenization - lowercase and split on non-alphanumeric.
+
+         Args:
+             text: Input text to tokenize.
+
+         Returns:
+             List of tokens.
+         """
+         tokens = re.findall(r'\b\w+\b', text.lower())
+         return tokens
+
+     def _compute_tf(self, tokens: List[str]) -> Dict[str, float]:
+         """Compute term frequency for a document.
+
+         Args:
+             tokens: List of tokens.
+
+         Returns:
+             Dictionary mapping term to frequency.
+         """
+         counter = Counter(tokens)
+         total = len(tokens)
+         if total == 0:
+             return {}
+         return {term: count / total for term, count in counter.items()}
+
+     def fit(self, documents: List[str]) -> "TFIDFEmbedder":
+         """Fit the IDF on a corpus of documents.
+
+         Args:
+             documents: List of text documents.
+
+         Returns:
+             self for chaining.
+         """
+         self.documents = documents
+
+         # Count document frequency for each term
+         doc_freq: Dict[str, int] = {}
+         all_terms: Counter[str] = Counter()
+
+         for doc in documents:
+             tokens = self._tokenize(doc)
+             unique_tokens = set(tokens)
+             all_terms.update(tokens)
+             for token in unique_tokens:
+                 doc_freq[token] = doc_freq.get(token, 0) + 1
+
+         # Select top features by frequency
+         top_terms = [term for term, _ in all_terms.most_common(self.max_features)]
+         self.vocab = top_terms
+         self.vocab_to_idx = {term: idx for idx, term in enumerate(top_terms)}
+
+         # Compute IDF: log(N / df) + 1
+         n_docs = max(len(documents), 1)
+         self.idf = {
+             term: log(n_docs / doc_freq.get(term, 1)) + 1
+             for term in self.vocab
+         }
+
+         self._is_fitted = True
+         return self
+
+     def embed(self, text: str) -> Any:
+         """Convert text to TF-IDF vector.
+
+         Args:
+             text: Input text.
+
+         Returns:
+             Dense embedding vector (numpy array).
+         """
+         import numpy as np
+
+         if not self._is_fitted:
+             # Fit on single document for compatibility
+             self.fit([text])
+
+         tokens = self._tokenize(text)
+         tf = self._compute_tf(tokens)
+
+         # Create dense vector
+         vec = np.zeros(len(self.vocab), dtype=np.float32)
+         for term, tf_val in tf.items():
+             if term in self.vocab_to_idx:
+                 idx = self.vocab_to_idx[term]
+                 vec[idx] = tf_val * self.idf.get(term, 1.0)
+
+         # L2 normalize
+         norm = np.linalg.norm(vec)
+         if norm > 0:
+             vec = vec / norm
+
+         return vec
+
+     def embed_batch(self, texts: List[str]) -> Any:
+         """Embed multiple texts.
+
+         If not fitted, fits on the input texts first.
+
+         Args:
+             texts: List of texts to embed.
+
+         Returns:
+             Embeddings matrix (numpy array).
+         """
+         import numpy as np
+
+         if not self._is_fitted:
+             self.fit(texts)
+
+         embeddings = np.array([self.embed(text) for text in texts], dtype=np.float32)
+         return embeddings
+
+
+ # Alias for backward compatibility
+ TextEmbedder = TFIDFEmbedder
+
+
+ # =============================================================================
+ # Sentence Transformers Embedder
+ # =============================================================================
+
+
+ class SentenceTransformerEmbedder(BaseEmbedder):
+     """Embedding using sentence-transformers library.
+
+     Recommended models:
+     - "all-MiniLM-L6-v2": Fast, 22MB, 384 dims (default)
+     - "all-mpnet-base-v2": Better quality, 420MB, 768 dims
+     - "BAAI/bge-small-en-v1.5": Good balance, 130MB, 384 dims
+     - "BAAI/bge-base-en-v1.5": Best quality, 440MB, 768 dims
+
+     Requires: pip install sentence-transformers
+     """
+
+     def __init__(
+         self,
+         model_name: str = "all-MiniLM-L6-v2",
+         cache_dir: Optional[Path] = None,
+         device: Optional[str] = None,
+         normalize: bool = True,
+     ) -> None:
+         """Initialize the Sentence Transformer embedder.
+
+         Args:
+             model_name: Name of the sentence-transformers model.
+             cache_dir: Directory for caching model and embeddings.
+             device: Device to run on ("cpu", "cuda", "mps"). Auto-detected if None.
+             normalize: Whether to L2-normalize embeddings (for cosine similarity).
+         """
+         self.model_name = model_name
+         self.cache_dir = Path(cache_dir) if cache_dir else None
+         self.device = device
+         self.normalize = normalize
+         self._model = None
+         self._embedding_cache: Dict[str, Any] = {}
+
+     def _load_model(self) -> None:
+         """Lazy-load the model."""
+         if self._model is not None:
+             return
+
+         try:
+             from sentence_transformers import SentenceTransformer
+         except ImportError:
+             raise ImportError(
+                 "sentence-transformers is required for SentenceTransformerEmbedder. "
+                 "Install with: pip install sentence-transformers"
+             )
+
+         logger.info(f"Loading sentence-transformers model: {self.model_name}")
+         self._model = SentenceTransformer(
+             self.model_name,
+             cache_folder=str(self.cache_dir) if self.cache_dir else None,
+             device=self.device,
+         )
+
+     def _get_cache_key(self, text: str) -> str:
+         """Generate cache key for text."""
+         return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()
+
+     def embed(self, text: str) -> Any:
+         """Embed a single text.
+
+         Args:
+             text: Input text.
+
+         Returns:
+             Embedding vector (numpy array).
+         """
+         import numpy as np
+
+         # Check cache
+         cache_key = self._get_cache_key(text)
+         if cache_key in self._embedding_cache:
+             return self._embedding_cache[cache_key]
+
+         self._load_model()
+
+         embedding = self._model.encode(
+             text,
+             normalize_embeddings=self.normalize,
+             convert_to_numpy=True,
+         )
+         embedding = np.asarray(embedding, dtype=np.float32)
+
+         # Cache result
+         self._embedding_cache[cache_key] = embedding
+
+         return embedding
+
+     def embed_batch(self, texts: List[str]) -> Any:
+         """Embed multiple texts efficiently.
+
+         Args:
+             texts: List of texts.
+
+         Returns:
+             Embeddings matrix (numpy array of shape [n_texts, dim]).
+         """
+         import numpy as np
+
+         self._load_model()
+
+         # Check which texts are cached
+         cached_embeddings = {}
+         uncached_texts = []
+         uncached_indices = []
+
+         for i, text in enumerate(texts):
+             cache_key = self._get_cache_key(text)
+             if cache_key in self._embedding_cache:
+                 cached_embeddings[i] = self._embedding_cache[cache_key]
+             else:
+                 uncached_texts.append(text)
+                 uncached_indices.append(i)
+
+         # Embed uncached texts
+         if uncached_texts:
+             new_embeddings = self._model.encode(
+                 uncached_texts,
+                 normalize_embeddings=self.normalize,
+                 convert_to_numpy=True,
+                 show_progress_bar=len(uncached_texts) > 10,
+             )
+
+             # Cache new embeddings
+             for text, embedding in zip(uncached_texts, new_embeddings):
+                 cache_key = self._get_cache_key(text)
+                 self._embedding_cache[cache_key] = embedding
+
+         # Reassemble in original order
+         dim = self._model.get_sentence_embedding_dimension()
+         result = np.zeros((len(texts), dim), dtype=np.float32)
+
+         for i, emb in cached_embeddings.items():
+             result[i] = emb
+
+         for i, idx in enumerate(uncached_indices):
+             result[idx] = new_embeddings[i]
+
+         return result
+
+     def clear_cache(self) -> None:
+         """Clear the embedding cache."""
+         self._embedding_cache = {}
+
+
+ # =============================================================================
+ # OpenAI Embedder
+ # =============================================================================
+
+
+ class OpenAIEmbedder(BaseEmbedder):
+     """Embedding using OpenAI's text-embedding API.
+
+     Models:
+     - "text-embedding-3-small": Cheap, fast, 1536 dims ($0.00002/1K tokens)
+     - "text-embedding-3-large": Best quality, 3072 dims ($0.00013/1K tokens)
+     - "text-embedding-ada-002": Legacy, 1536 dims
+
+     Requires: pip install openai
+     Environment: OPENAI_API_KEY must be set
+     """
+
+     def __init__(
+         self,
+         model_name: str = "text-embedding-3-small",
+         cache_dir: Optional[Path] = None,
+         api_key: Optional[str] = None,
+         normalize: bool = True,
+         batch_size: int = 100,
+     ) -> None:
+         """Initialize the OpenAI embedder.
+
+         Args:
+             model_name: OpenAI embedding model name.
+             cache_dir: Directory for caching embeddings to disk.
+             api_key: OpenAI API key. If None, uses OPENAI_API_KEY env var.
+             normalize: Whether to L2-normalize embeddings.
+             batch_size: Maximum texts per API call.
+         """
+         self.model_name = model_name
+         self.cache_dir = Path(cache_dir) if cache_dir else None
+         self.api_key = api_key
+         self.normalize = normalize
+         self.batch_size = batch_size
+         self._client = None
+         self._embedding_cache: Dict[str, Any] = {}
+
+         if self.cache_dir:
+             self.cache_dir.mkdir(parents=True, exist_ok=True)
+             self._load_disk_cache()
+
+     def _load_disk_cache(self) -> None:
+         """Load cache from disk."""
+         if not self.cache_dir:
+             return
+
+         cache_file = self.cache_dir / "embeddings_cache.json"
+         if cache_file.exists():
+             try:
+                 with open(cache_file) as f:
+                     cached = json.load(f)
+                 # Convert lists back to arrays
+                 import numpy as np
+                 for key, val in cached.items():
+                     self._embedding_cache[key] = np.array(val, dtype=np.float32)
+                 logger.debug(f"Loaded {len(self._embedding_cache)} cached embeddings")
+             except Exception as e:
+                 logger.warning(f"Failed to load cache: {e}")
+
+     def _save_disk_cache(self) -> None:
+         """Save cache to disk."""
+         if not self.cache_dir:
+             return
+
+         cache_file = self.cache_dir / "embeddings_cache.json"
+         try:
+             # Convert arrays to lists for JSON
+             cache_data = {
+                 key: val.tolist()
+                 for key, val in self._embedding_cache.items()
+             }
+             with open(cache_file, "w") as f:
+                 json.dump(cache_data, f)
+         except Exception as e:
+             logger.warning(f"Failed to save cache: {e}")
+
+     def _get_client(self) -> Any:
+         """Get or create OpenAI client."""
+         if self._client is not None:
+             return self._client
+
+         try:
+             from openai import OpenAI
+         except ImportError:
+             raise ImportError(
+                 "openai is required for OpenAIEmbedder. "
+                 "Install with: pip install openai"
+             )
+
+         self._client = OpenAI(api_key=self.api_key)
+         return self._client
+
+     def _get_cache_key(self, text: str) -> str:
+         """Generate cache key for text."""
+         return hashlib.md5(f"{self.model_name}:{text}".encode()).hexdigest()
+
+     def embed(self, text: str) -> Any:
+         """Embed a single text.
+
+         Args:
+             text: Input text.
+
+         Returns:
+             Embedding vector (numpy array).
+         """
+         return self.embed_batch([text])[0]
+
+     def embed_batch(self, texts: List[str]) -> Any:
+         """Embed multiple texts.
+
+         Args:
+             texts: List of texts.
+
+         Returns:
+             Embeddings matrix (numpy array).
+         """
+         import numpy as np
+
+         # Check cache first
+         cached_embeddings = {}
+         uncached_texts = []
+         uncached_indices = []
+
+         for i, text in enumerate(texts):
+             cache_key = self._get_cache_key(text)
+             if cache_key in self._embedding_cache:
+                 cached_embeddings[i] = self._embedding_cache[cache_key]
+             else:
+                 uncached_texts.append(text)
+                 uncached_indices.append(i)
+
+         # Fetch uncached embeddings from API
+         new_embeddings = {}
+         if uncached_texts:
+             client = self._get_client()
+
+             # Process in batches
+             for batch_start in range(0, len(uncached_texts), self.batch_size):
+                 batch_texts = uncached_texts[batch_start:batch_start + self.batch_size]
+
+                 try:
+                     response = client.embeddings.create(
+                         model=self.model_name,
+                         input=batch_texts,
+                     )
+
+                     for j, item in enumerate(response.data):
+                         idx = uncached_indices[batch_start + j]
+                         embedding = np.array(item.embedding, dtype=np.float32)
+
+                         if self.normalize:
+                             norm = np.linalg.norm(embedding)
+                             if norm > 0:
+                                 embedding = embedding / norm
+
+                         new_embeddings[idx] = embedding
+
+                         # Cache the result
+                         cache_key = self._get_cache_key(batch_texts[j])
+                         self._embedding_cache[cache_key] = embedding
+
+                 except Exception as e:
+                     logger.error(f"OpenAI API error: {e}")
+                     raise
+
+             # Persist newly fetched embeddings to the disk cache
+             self._save_disk_cache()
+
+         # Determine embedding dimension
+         if cached_embeddings:
+             dim = next(iter(cached_embeddings.values())).shape[0]
+         elif new_embeddings:
+             dim = next(iter(new_embeddings.values())).shape[0]
+         else:
+             # Default dimensions by model
+             dim = {"text-embedding-3-small": 1536, "text-embedding-3-large": 3072}.get(
+                 self.model_name, 1536
+             )
+
+         # Assemble result
+         result = np.zeros((len(texts), dim), dtype=np.float32)
+
+         for i, emb in cached_embeddings.items():
+             result[i] = emb
+
+         for i, emb in new_embeddings.items():
+             result[i] = emb
+
+         return result
+
+     def clear_cache(self) -> None:
+         """Clear embedding cache."""
+         self._embedding_cache = {}
+         if self.cache_dir:
+             cache_file = self.cache_dir / "embeddings_cache.json"
+             if cache_file.exists():
+                 cache_file.unlink()
+
+
+ # =============================================================================
+ # Factory Function
+ # =============================================================================
+
+
+ def create_embedder(
+     method: str = "tfidf",
+     model_name: Optional[str] = None,
+     cache_dir: Optional[Path] = None,
+     **kwargs: Any,
+ ) -> BaseEmbedder:
+     """Factory function to create an embedder.
+
+     Args:
+         method: Embedding method ("tfidf", "sentence_transformers", "openai").
+         model_name: Model name (method-specific defaults if None).
+         cache_dir: Cache directory for embeddings.
+         **kwargs: Additional arguments passed to embedder.
+
+     Returns:
+         Embedder instance.
+     """
+     if method == "tfidf":
+         return TFIDFEmbedder(**kwargs)
+
+     elif method == "sentence_transformers":
+         return SentenceTransformerEmbedder(
+             model_name=model_name or "all-MiniLM-L6-v2",
+             cache_dir=cache_dir,
+             **kwargs,
+         )
+
+     elif method == "openai":
+         return OpenAIEmbedder(
+             model_name=model_name or "text-embedding-3-small",
+             cache_dir=cache_dir,
+             **kwargs,
+         )
+
+     else:
+         raise ValueError(f"Unknown embedding method: {method}")
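
Example usage of the new embedding interface, end to end. This is a minimal sketch based on the module shown above; the demo strings and the ranking loop are illustrative and not part of the package.

    from openadapt_ml.retrieval.embeddings import create_embedder

    # Recorded demo descriptions to retrieve from (illustrative data).
    demos = [
        "Turn off Night Shift in System Settings",
        "Search GitHub for openadapt-ml",
        "Export a recording to Parquet",
    ]
    task = "Disable Night Shift"

    # TF-IDF backend: no external dependencies; fit() builds the vocabulary.
    # Swap in method="sentence_transformers" or method="openai" for the other
    # backends in the diff (each requires its respective extra dependency).
    embedder = create_embedder(method="tfidf")
    embedder.fit(demos)

    demo_vecs = embedder.embed_batch(demos)
    task_vec = embedder.embed(task)

    # Rank demos by cosine similarity to the task description.
    scores = [embedder.cosine_similarity(task_vec, vec) for vec in demo_vecs]
    best = max(range(len(demos)), key=lambda i: scores[i])
    print(f"Best demo: {demos[best]!r} (score={scores[best]:.3f})")

With the toy data above this selects the Night Shift demo, since it is the only entry sharing vocabulary with the task string.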