keep-skill 0.1.0 (keep_skill-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- keep/__init__.py +53 -0
- keep/__main__.py +8 -0
- keep/api.py +686 -0
- keep/chunking.py +364 -0
- keep/cli.py +503 -0
- keep/config.py +323 -0
- keep/context.py +127 -0
- keep/indexing.py +208 -0
- keep/logging_config.py +73 -0
- keep/paths.py +67 -0
- keep/pending_summaries.py +166 -0
- keep/providers/__init__.py +40 -0
- keep/providers/base.py +416 -0
- keep/providers/documents.py +250 -0
- keep/providers/embedding_cache.py +260 -0
- keep/providers/embeddings.py +245 -0
- keep/providers/llm.py +371 -0
- keep/providers/mlx.py +256 -0
- keep/providers/summarization.py +107 -0
- keep/store.py +403 -0
- keep/types.py +65 -0
- keep_skill-0.1.0.dist-info/METADATA +290 -0
- keep_skill-0.1.0.dist-info/RECORD +26 -0
- keep_skill-0.1.0.dist-info/WHEEL +4 -0
- keep_skill-0.1.0.dist-info/entry_points.txt +2 -0
- keep_skill-0.1.0.dist-info/licenses/LICENSE +21 -0
`keep/providers/embedding_cache.py` (@@ -0,0 +1,260 @@)

```python
"""
Embedding cache using SQLite.

Wraps any EmbeddingProvider to cache embeddings by content hash,
avoiding redundant embedding calls for unchanged content.
"""

import hashlib
import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional

from .base import EmbeddingProvider


class EmbeddingCache:
    """
    SQLite-based embedding cache.

    Cache key is SHA256(model_name + content), so different models
    don't share cached embeddings.
    """

    def __init__(self, cache_path: Path, max_entries: int = 50000):
        """
        Args:
            cache_path: Path to SQLite database file
            max_entries: Maximum cache entries (LRU eviction when exceeded)
        """
        self._cache_path = cache_path
        self._max_entries = max_entries
        self._conn: Optional[sqlite3.Connection] = None
        self._init_db()

    def _init_db(self) -> None:
        """Initialize the SQLite database."""
        self._cache_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(str(self._cache_path), check_same_thread=False)
        self._conn.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                content_hash TEXT PRIMARY KEY,
                model_name TEXT NOT NULL,
                embedding BLOB NOT NULL,
                dimension INTEGER NOT NULL,
                created_at TEXT NOT NULL,
                last_accessed TEXT NOT NULL
            )
        """)
        self._conn.execute("""
            CREATE INDEX IF NOT EXISTS idx_last_accessed
            ON embedding_cache(last_accessed)
        """)
        self._conn.commit()

    def _hash_key(self, model_name: str, content: str) -> str:
        """Generate cache key from model name and content."""
        key_input = f"{model_name}:{content}"
        return hashlib.sha256(key_input.encode("utf-8")).hexdigest()

    def get(self, model_name: str, content: str) -> Optional[list[float]]:
        """
        Get cached embedding if it exists.

        Updates last_accessed timestamp on hit.
        """
        content_hash = self._hash_key(model_name, content)
        cursor = self._conn.execute(
            "SELECT embedding FROM embedding_cache WHERE content_hash = ?",
            (content_hash,)
        )
        row = cursor.fetchone()

        if row is not None:
            # Update last_accessed
            now = datetime.now(timezone.utc).isoformat()
            self._conn.execute(
                "UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
                (now, content_hash)
            )
            self._conn.commit()

            # Deserialize embedding
            return json.loads(row[0])

        return None

    def put(
        self,
        model_name: str,
        content: str,
        embedding: list[float]
    ) -> None:
        """
        Cache an embedding.

        Evicts oldest entries if cache exceeds max_entries.
        """
        content_hash = self._hash_key(model_name, content)
        now = datetime.now(timezone.utc).isoformat()
        embedding_blob = json.dumps(embedding)

        self._conn.execute("""
            INSERT OR REPLACE INTO embedding_cache
            (content_hash, model_name, embedding, dimension, created_at, last_accessed)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (content_hash, model_name, embedding_blob, len(embedding), now, now))
        self._conn.commit()

        # Evict old entries if needed
        self._maybe_evict()

    def _maybe_evict(self) -> None:
        """Evict oldest entries if cache exceeds max size."""
        cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
        count = cursor.fetchone()[0]

        if count > self._max_entries:
            # Delete oldest 10% by last_accessed
            evict_count = max(1, count // 10)
            self._conn.execute("""
                DELETE FROM embedding_cache
                WHERE content_hash IN (
                    SELECT content_hash FROM embedding_cache
                    ORDER BY last_accessed ASC
                    LIMIT ?
                )
            """, (evict_count,))
            self._conn.commit()

    def stats(self) -> dict:
        """Get cache statistics."""
        cursor = self._conn.execute("""
            SELECT COUNT(*), COUNT(DISTINCT model_name)
            FROM embedding_cache
        """)
        count, models = cursor.fetchone()
        return {
            "entries": count,
            "models": models,
            "max_entries": self._max_entries,
            "cache_path": str(self._cache_path),
        }

    def clear(self) -> None:
        """Clear all cached embeddings."""
        self._conn.execute("DELETE FROM embedding_cache")
        self._conn.commit()

    def close(self) -> None:
        """Close the database connection."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    def __del__(self) -> None:
        """Ensure connection is closed on cleanup."""
        self.close()

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close connection."""
        self.close()
        return False
```
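To make the lookup and eviction flow concrete, here is a minimal usage sketch of `EmbeddingCache` on its own. The cache path and vector values are illustrative, not part of the package:

```python
from pathlib import Path

from keep.providers.embedding_cache import EmbeddingCache

# Illustrative path; any writable location works (parent dirs are created).
cache = EmbeddingCache(Path("/tmp/keep-demo/embeddings.db"), max_entries=1000)

cache.put("all-MiniLM-L6-v2", "hello world", [0.1, 0.2, 0.3])

# Hit: same model + content hashes to the same key.
assert cache.get("all-MiniLM-L6-v2", "hello world") == [0.1, 0.2, 0.3]

# Miss: the model name is part of the SHA256 key, so models don't collide.
assert cache.get("some-other-model", "hello world") is None

print(cache.stats())  # {'entries': 1, 'models': 1, 'max_entries': 1000, ...}
cache.close()
```

Once the entry count crosses `max_entries`, the next `put` deletes the oldest 10% of rows by `last_accessed`, so the cache degrades gracefully instead of growing without bound.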
The second half of the file wraps any provider with this cache:

```python
class CachingEmbeddingProvider:
    """
    Wrapper that adds caching to any EmbeddingProvider.

    Usage:
        base_provider = SentenceTransformerEmbedding()
        cached = CachingEmbeddingProvider(base_provider, cache_path)
    """

    def __init__(
        self,
        provider: EmbeddingProvider,
        cache_path: Path,
        max_entries: int = 50000
    ):
        self._provider = provider
        self._cache = EmbeddingCache(cache_path, max_entries)
        self._hits = 0
        self._misses = 0

    @property
    def model_name(self) -> str:
        """Get the underlying provider's model name."""
        return getattr(self._provider, "model_name", "unknown")

    @property
    def dimension(self) -> int:
        """Get embedding dimension from the wrapped provider."""
        return self._provider.dimension

    def embed(self, text: str) -> list[float]:
        """
        Get embedding, using cache when available.
        """
        # Check cache
        cached = self._cache.get(self.model_name, text)
        if cached is not None:
            self._hits += 1
            return cached

        # Cache miss - compute embedding
        self._misses += 1
        embedding = self._provider.embed(text)

        # Store in cache
        self._cache.put(self.model_name, text, embedding)

        return embedding

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Get embeddings for batch, using cache where available.

        Only computes embeddings for cache misses.
        """
        results: list[Optional[list[float]]] = [None] * len(texts)
        to_embed: list[tuple[int, str]] = []

        # Check cache for each text
        for i, text in enumerate(texts):
            cached = self._cache.get(self.model_name, text)
            if cached is not None:
                self._hits += 1
                results[i] = cached
            else:
                self._misses += 1
                to_embed.append((i, text))

        # Batch embed cache misses
        if to_embed:
            indices, texts_to_embed = zip(*to_embed)
            embeddings = self._provider.embed_batch(list(texts_to_embed))

            for idx, text, embedding in zip(indices, texts_to_embed, embeddings):
                results[idx] = embedding
                self._cache.put(self.model_name, text, embedding)

        return results  # type: ignore

    def stats(self) -> dict:
        """Get cache and hit/miss statistics."""
        cache_stats = self._cache.stats()
        total = self._hits + self._misses
        hit_rate = self._hits / total if total > 0 else 0.0
        return {
            **cache_stats,
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.1%}",
        }
```
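And a sketch of the wrapper in use, following the `Usage` note in its docstring (the cache path is again illustrative):

```python
from pathlib import Path

from keep.providers.embedding_cache import CachingEmbeddingProvider
from keep.providers.embeddings import SentenceTransformerEmbedding

base = SentenceTransformerEmbedding()  # local model, no API key needed
provider = CachingEmbeddingProvider(base, Path("/tmp/keep-demo/embeddings.db"))

provider.embed("hello world")  # miss: computed by the model, then cached
provider.embed("hello world")  # hit: served from SQLite

# Batches only send the misses to the underlying provider.
provider.embed_batch(["hello world", "brand new text"])

print(provider.stats())  # cache stats plus hits/misses, e.g. 'hit_rate': '50.0%'
```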
`keep/providers/embeddings.py` (@@ -0,0 +1,245 @@)

```python
"""
Embedding providers for generating vector representations of text.
"""

import os
from typing import Any

from .base import EmbeddingProvider, get_registry


class SentenceTransformerEmbedding:
    """
    Embedding provider using sentence-transformers library.

    Runs locally, no API key required. Good default for getting started.

    Requires: pip install sentence-transformers
    """

    def __init__(self, model: str = "all-MiniLM-L6-v2"):
        """
        Args:
            model: Model name from sentence-transformers hub
        """
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise RuntimeError(
                "SentenceTransformerEmbedding requires 'sentence-transformers' library. "
                "Install with: pip install sentence-transformers"
            )

        self.model_name = model
        self._model = SentenceTransformer(model)

    @property
    def dimension(self) -> int:
        """Get embedding dimension from the model."""
        return self._model.get_sentence_embedding_dimension()

    def embed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        embedding = self._model.encode(text, convert_to_numpy=True)
        return embedding.tolist()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts."""
        embeddings = self._model.encode(texts, convert_to_numpy=True)
        return embeddings.tolist()
```
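A quick sketch of this default local provider; the only assumption is that `sentence-transformers` is installed (the model downloads on first use):

```python
from keep.providers.embeddings import SentenceTransformerEmbedding

provider = SentenceTransformerEmbedding()          # all-MiniLM-L6-v2 by default
print(provider.dimension)                          # 384 for all-MiniLM-L6-v2

vec = provider.embed("an example sentence")
batch = provider.embed_batch(["first", "second"])  # one encode() call for the list
assert len(vec) == provider.dimension
```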
The file continues with the OpenAI-backed provider:

```python
class OpenAIEmbedding:
    """
    Embedding provider using OpenAI's API.

    Requires: KEEP_OPENAI_API_KEY or OPENAI_API_KEY environment variable.
    Requires: pip install openai
    """

    # Model dimensions (as of 2024)
    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: str | None = None,
    ):
        """
        Args:
            model: OpenAI embedding model name
            api_key: API key (defaults to environment variable)
        """
        try:
            from openai import OpenAI
        except ImportError:
            raise RuntimeError(
                "OpenAIEmbedding requires 'openai' library. "
                "Install with: pip install openai"
            )

        self.model_name = model
        self._dimension = self.MODEL_DIMENSIONS.get(model, 1536)

        # Resolve API key
        key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
        if not key:
            raise ValueError(
                "OpenAI API key required. Set KEEP_OPENAI_API_KEY or OPENAI_API_KEY"
            )

        self._client = OpenAI(api_key=key)

    @property
    def dimension(self) -> int:
        """Get embedding dimension for the model."""
        return self._dimension

    def embed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        response = self._client.embeddings.create(
            model=self.model_name,
            input=text,
        )
        return response.data[0].embedding

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts."""
        response = self._client.embeddings.create(
            model=self.model_name,
            input=texts,
        )
        # Sort by index to ensure order matches input
        sorted_data = sorted(response.data, key=lambda x: x.index)
        return [d.embedding for d in sorted_data]
```
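Used like so, assuming a key is already exported in the environment:

```python
from keep.providers.embeddings import OpenAIEmbedding

# Requires KEEP_OPENAI_API_KEY or OPENAI_API_KEY to be set.
provider = OpenAIEmbedding(model="text-embedding-3-small")
print(provider.dimension)  # 1536, looked up from MODEL_DIMENSIONS

# One API request for the whole batch; order is restored via the index sort above.
vectors = provider.embed_batch(["first", "second"])
```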
Then the Gemini-backed provider:

```python
class GeminiEmbedding:
    """
    Embedding provider using Google's Gemini API.

    Requires: GEMINI_API_KEY or GOOGLE_API_KEY environment variable.
    Requires: pip install google-genai
    """

    # Model dimensions (as of 2025)
    MODEL_DIMENSIONS = {
        "text-embedding-004": 768,
        "embedding-001": 768,
        "gemini-embedding-001": 768,
    }

    def __init__(
        self,
        model: str = "text-embedding-004",
        api_key: str | None = None,
    ):
        """
        Args:
            model: Gemini embedding model name
            api_key: API key (defaults to environment variable)
        """
        try:
            from google import genai
        except ImportError:
            raise RuntimeError(
                "GeminiEmbedding requires 'google-genai' library. "
                "Install with: pip install google-genai"
            )

        self.model_name = model
        self._dimension = self.MODEL_DIMENSIONS.get(model, 768)

        # Resolve API key
        key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
        if not key:
            raise ValueError(
                "Gemini API key required. Set GEMINI_API_KEY or GOOGLE_API_KEY"
            )

        self._client = genai.Client(api_key=key)

    @property
    def dimension(self) -> int:
        """Get embedding dimension for the model."""
        return self._dimension

    def embed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        result = self._client.models.embed_content(
            model=self.model_name,
            contents=text,
        )
        return list(result.embeddings[0].values)

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts."""
        result = self._client.models.embed_content(
            model=self.model_name,
            contents=texts,
        )
        return [list(e.values) for e in result.embeddings]
```
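Same pattern for Gemini, assuming `google-genai` is installed and a key is set:

```python
from keep.providers.embeddings import GeminiEmbedding

# Requires GEMINI_API_KEY or GOOGLE_API_KEY to be set.
provider = GeminiEmbedding()  # text-embedding-004 by default
vec = provider.embed("an example sentence")
assert len(vec) == provider.dimension  # 768 per MODEL_DIMENSIONS
```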
And a local Ollama provider:

```python
class OllamaEmbedding:
    """
    Embedding provider using Ollama's local API.

    Requires: Ollama running locally (default: http://localhost:11434)
    """

    def __init__(
        self,
        model: str = "nomic-embed-text",
        base_url: str = "http://localhost:11434",
    ):
        """
        Args:
            model: Ollama model name
            base_url: Ollama API base URL
        """
        self.model_name = model
        self.base_url = base_url.rstrip("/")
        self._dimension: int | None = None

    @property
    def dimension(self) -> int:
        """Get embedding dimension (determined on first embed call)."""
        if self._dimension is None:
            # Generate a test embedding to determine dimension
            test_embedding = self.embed("test")
            self._dimension = len(test_embedding)
        return self._dimension

    def embed(self, text: str) -> list[float]:
        """Generate embedding for a single text."""
        import requests

        response = requests.post(
            f"{self.base_url}/api/embeddings",
            json={"model": self.model_name, "prompt": text},
        )
        response.raise_for_status()

        embedding = response.json()["embedding"]

        if self._dimension is None:
            self._dimension = len(embedding)

        return embedding

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts (sequential for Ollama)."""
        return [self.embed(text) for text in texts]
```
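For Ollama, the sketch assumes the daemon is running locally and the model has been pulled (`ollama pull nomic-embed-text`):

```python
from keep.providers.embeddings import OllamaEmbedding

provider = OllamaEmbedding()  # talks to http://localhost:11434
vec = provider.embed("an example sentence")

# Note: reading .dimension before any embed() triggers a throwaway "test" call,
# since Ollama models don't advertise their dimension up front.
print(provider.dimension)
```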
Finally, the module registers the four providers with the shared registry:

```python
# Register providers
_registry = get_registry()
_registry.register_embedding("sentence-transformers", SentenceTransformerEmbedding)
_registry.register_embedding("openai", OpenAIEmbedding)
_registry.register_embedding("gemini", GeminiEmbedding)
_registry.register_embedding("ollama", OllamaEmbedding)
```
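The same hook is open to third-party providers. A hypothetical sketch, assuming the registry accepts any class with the `embed`/`embed_batch`/`dimension`/`model_name` surface the built-ins share (`FakeEmbedding` is invented here for illustration):

```python
from keep.providers.base import get_registry


class FakeEmbedding:
    """Hypothetical deterministic provider, handy for tests."""

    model_name = "fake"
    dimension = 4

    def embed(self, text: str) -> list[float]:
        # Toy featurization, not a real embedding.
        return [float(len(text)), float(text.count(" ")), 0.0, 1.0]

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        return [self.embed(t) for t in texts]


get_registry().register_embedding("fake", FakeEmbedding)
```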