keep_skill-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,260 @@
+ """
+ Embedding cache using SQLite.
+
+ Wraps any EmbeddingProvider to cache embeddings by content hash,
+ avoiding redundant embedding calls for unchanged content.
+ """
+
+ import hashlib
+ import json
+ import sqlite3
+ from datetime import datetime, timezone
+ from pathlib import Path
+ from typing import Optional
+
+ from .base import EmbeddingProvider
+
+
+ class EmbeddingCache:
+     """
+     SQLite-based embedding cache.
+
+     Cache key is SHA-256 of "{model_name}:{content}", so different
+     models don't share cached embeddings.
+     """
+
+     def __init__(self, cache_path: Path, max_entries: int = 50000):
+         """
+         Args:
+             cache_path: Path to SQLite database file
+             max_entries: Maximum cache entries (LRU eviction when exceeded)
+         """
+         self._cache_path = cache_path
+         self._max_entries = max_entries
+         self._conn: Optional[sqlite3.Connection] = None
+         self._init_db()
+
+     def _init_db(self) -> None:
+         """Initialize the SQLite database."""
+         self._cache_path.parent.mkdir(parents=True, exist_ok=True)
+         self._conn = sqlite3.connect(str(self._cache_path), check_same_thread=False)  # callers must serialize cross-thread access
+         self._conn.execute("""
+             CREATE TABLE IF NOT EXISTS embedding_cache (
+                 content_hash TEXT PRIMARY KEY,
+                 model_name TEXT NOT NULL,
+                 embedding BLOB NOT NULL,
+                 dimension INTEGER NOT NULL,
+                 created_at TEXT NOT NULL,
+                 last_accessed TEXT NOT NULL
+             )
+         """)
+         self._conn.execute("""
+             CREATE INDEX IF NOT EXISTS idx_last_accessed
+             ON embedding_cache(last_accessed)
+         """)
+         self._conn.commit()
+
+     def _hash_key(self, model_name: str, content: str) -> str:
+         """Generate cache key from model name and content."""
+         key_input = f"{model_name}:{content}"
+         return hashlib.sha256(key_input.encode("utf-8")).hexdigest()
+
+     def get(self, model_name: str, content: str) -> Optional[list[float]]:
+         """
+         Get cached embedding if it exists.
+
+         Updates last_accessed timestamp on hit.
+         """
+         content_hash = self._hash_key(model_name, content)
+         cursor = self._conn.execute(
+             "SELECT embedding FROM embedding_cache WHERE content_hash = ?",
+             (content_hash,)
+         )
+         row = cursor.fetchone()
+
+         if row is not None:
+             # Update last_accessed
+             now = datetime.now(timezone.utc).isoformat()
+             self._conn.execute(
+                 "UPDATE embedding_cache SET last_accessed = ? WHERE content_hash = ?",
+                 (now, content_hash)
+             )
+             self._conn.commit()
+
+             # Deserialize embedding
+             return json.loads(row[0])
+
+         return None
+
+     def put(
+         self,
+         model_name: str,
+         content: str,
+         embedding: list[float]
+     ) -> None:
+         """
+         Cache an embedding.
+
+         Evicts oldest entries if cache exceeds max_entries.
+         """
+         content_hash = self._hash_key(model_name, content)
+         now = datetime.now(timezone.utc).isoformat()
+         embedding_blob = json.dumps(embedding)  # JSON text; SQLite's BLOB affinity accepts it
+
+         self._conn.execute("""
+             INSERT OR REPLACE INTO embedding_cache
+             (content_hash, model_name, embedding, dimension, created_at, last_accessed)
+             VALUES (?, ?, ?, ?, ?, ?)
+         """, (content_hash, model_name, embedding_blob, len(embedding), now, now))
+         self._conn.commit()
+
+         # Evict old entries if needed
+         self._maybe_evict()
+
+     def _maybe_evict(self) -> None:
+         """Evict oldest entries if cache exceeds max size."""
+         cursor = self._conn.execute("SELECT COUNT(*) FROM embedding_cache")
+         count = cursor.fetchone()[0]
+
+         if count > self._max_entries:
+             # Delete oldest 10% by last_accessed
+             evict_count = max(1, count // 10)
+             self._conn.execute("""
+                 DELETE FROM embedding_cache
+                 WHERE content_hash IN (
+                     SELECT content_hash FROM embedding_cache
+                     ORDER BY last_accessed ASC
+                     LIMIT ?
+                 )
+             """, (evict_count,))
+             self._conn.commit()
+
+     def stats(self) -> dict:
+         """Get cache statistics."""
+         cursor = self._conn.execute("""
+             SELECT COUNT(*), COUNT(DISTINCT model_name)
+             FROM embedding_cache
+         """)
+         count, models = cursor.fetchone()
+         return {
+             "entries": count,
+             "models": models,
+             "max_entries": self._max_entries,
+             "cache_path": str(self._cache_path),
+         }
+
+     def clear(self) -> None:
+         """Clear all cached embeddings."""
+         self._conn.execute("DELETE FROM embedding_cache")
+         self._conn.commit()
+
+     def close(self) -> None:
+         """Close the database connection."""
+         if self._conn is not None:
+             self._conn.close()
+             self._conn = None
+
+     def __del__(self) -> None:
+         """Ensure connection is closed on cleanup."""
+         self.close()
+
+     def __enter__(self):
+         """Context manager entry."""
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Context manager exit - close connection."""
+         self.close()
+         return False
+
+
+ class CachingEmbeddingProvider:
+     """
+     Wrapper that adds caching to any EmbeddingProvider.
+
+     Usage:
+         base_provider = SentenceTransformerEmbedding()
+         cached = CachingEmbeddingProvider(base_provider, cache_path)
+     """
+
+     def __init__(
+         self,
+         provider: EmbeddingProvider,
+         cache_path: Path,
+         max_entries: int = 50000
+     ):
+         self._provider = provider
+         self._cache = EmbeddingCache(cache_path, max_entries)
+         self._hits = 0
+         self._misses = 0
+
+     @property
+     def model_name(self) -> str:
+         """Get the underlying provider's model name."""
+         return getattr(self._provider, "model_name", "unknown")
+
+     @property
+     def dimension(self) -> int:
+         """Get embedding dimension from the wrapped provider."""
+         return self._provider.dimension
+
+     def embed(self, text: str) -> list[float]:
+         """
+         Get embedding, using cache when available.
+         """
+         # Check cache
+         cached = self._cache.get(self.model_name, text)
+         if cached is not None:
+             self._hits += 1
+             return cached
+
+         # Cache miss - compute embedding
+         self._misses += 1
+         embedding = self._provider.embed(text)
+
+         # Store in cache
+         self._cache.put(self.model_name, text, embedding)
+
+         return embedding
+
+     def embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """
+         Get embeddings for batch, using cache where available.
+
+         Only computes embeddings for cache misses.
+         """
+         results: list[Optional[list[float]]] = [None] * len(texts)
+         to_embed: list[tuple[int, str]] = []
+
+         # Check cache for each text
+         for i, text in enumerate(texts):
+             cached = self._cache.get(self.model_name, text)
+             if cached is not None:
+                 self._hits += 1
+                 results[i] = cached
+             else:
+                 self._misses += 1
+                 to_embed.append((i, text))
+
+         # Batch embed cache misses
+         if to_embed:
+             indices, texts_to_embed = zip(*to_embed)
+             embeddings = self._provider.embed_batch(list(texts_to_embed))
+
+             for idx, text, embedding in zip(indices, texts_to_embed, embeddings):
+                 results[idx] = embedding
+                 self._cache.put(self.model_name, text, embedding)
+
+         return results  # type: ignore
+
+     def stats(self) -> dict:
+         """Get cache and hit/miss statistics."""
+         cache_stats = self._cache.stats()
+         total = self._hits + self._misses
+         hit_rate = self._hits / total if total > 0 else 0.0
+         return {
+             **cache_stats,
+             "hits": self._hits,
+             "misses": self._misses,
+             "hit_rate": f"{hit_rate:.1%}",
+         }
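
Taken together, EmbeddingCache and CachingEmbeddingProvider form a read-through cache: reads check SQLite first, and misses fall through to the wrapped provider and are written back. A minimal usage sketch follows; the module paths `keep.embeddings.cache` and `keep.embeddings.providers` are assumptions for illustration, not confirmed by this diff.

    from pathlib import Path

    # Hypothetical import paths; adjust to the package's actual layout.
    from keep.embeddings.cache import CachingEmbeddingProvider
    from keep.embeddings.providers import SentenceTransformerEmbedding

    provider = CachingEmbeddingProvider(
        SentenceTransformerEmbedding(),  # local model, no API key required
        cache_path=Path("~/.cache/keep/embeddings.db").expanduser(),
    )

    first = provider.embed("hello world")   # miss: the model is invoked
    second = provider.embed("hello world")  # hit: served from SQLite
    assert first == second
    print(provider.stats())  # entries, models, hits, misses, hit_rate, ...

Because the cache key hashes both model name and content, swapping in a different provider later will simply repopulate the cache rather than return stale vectors of the wrong dimension.
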
@@ -0,0 +1,245 @@
+ """
+ Embedding providers for generating vector representations of text.
+ """
+
+ import os
+
+ from .base import EmbeddingProvider, get_registry
+
+
+ class SentenceTransformerEmbedding:
+     """
+     Embedding provider using sentence-transformers library.
+
+     Runs locally, no API key required. Good default for getting started.
+
+     Requires: pip install sentence-transformers
+     """
+
+     def __init__(self, model: str = "all-MiniLM-L6-v2"):
+         """
+         Args:
+             model: Model name from sentence-transformers hub
+         """
+         try:
+             from sentence_transformers import SentenceTransformer
+         except ImportError:
+             raise RuntimeError(
+                 "SentenceTransformerEmbedding requires 'sentence-transformers' library. "
+                 "Install with: pip install sentence-transformers"
+             )
+
+         self.model_name = model
+         self._model = SentenceTransformer(model)
+
+     @property
+     def dimension(self) -> int:
+         """Get embedding dimension from the model."""
+         return self._model.get_sentence_embedding_dimension()
+
+     def embed(self, text: str) -> list[float]:
+         """Generate embedding for a single text."""
+         embedding = self._model.encode(text, convert_to_numpy=True)
+         return embedding.tolist()
+
+     def embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for multiple texts."""
+         embeddings = self._model.encode(texts, convert_to_numpy=True)
+         return embeddings.tolist()
+
+
+ class OpenAIEmbedding:
+     """
+     Embedding provider using OpenAI's API.
+
+     Requires: KEEP_OPENAI_API_KEY or OPENAI_API_KEY environment variable.
+     Requires: pip install openai
+     """
+
+     # Model dimensions (as of 2024)
+     MODEL_DIMENSIONS = {
+         "text-embedding-3-small": 1536,
+         "text-embedding-3-large": 3072,
+         "text-embedding-ada-002": 1536,
+     }
+
+     def __init__(
+         self,
+         model: str = "text-embedding-3-small",
+         api_key: str | None = None,
+     ):
+         """
+         Args:
+             model: OpenAI embedding model name
+             api_key: API key (defaults to environment variable)
+         """
+         try:
+             from openai import OpenAI
+         except ImportError:
+             raise RuntimeError(
+                 "OpenAIEmbedding requires 'openai' library. "
+                 "Install with: pip install openai"
+             )
+
+         self.model_name = model
+         self._dimension = self.MODEL_DIMENSIONS.get(model, 1536)
+
+         # Resolve API key
+         key = api_key or os.environ.get("KEEP_OPENAI_API_KEY") or os.environ.get("OPENAI_API_KEY")
+         if not key:
+             raise ValueError(
+                 "OpenAI API key required. Set KEEP_OPENAI_API_KEY or OPENAI_API_KEY"
+             )
+
+         self._client = OpenAI(api_key=key)
+
+     @property
+     def dimension(self) -> int:
+         """Get embedding dimension for the model."""
+         return self._dimension
+
+     def embed(self, text: str) -> list[float]:
+         """Generate embedding for a single text."""
+         response = self._client.embeddings.create(
+             model=self.model_name,
+             input=text,
+         )
+         return response.data[0].embedding
+
+     def embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for multiple texts."""
+         response = self._client.embeddings.create(
+             model=self.model_name,
+             input=texts,
+         )
+         # Sort by index to ensure order matches input
+         sorted_data = sorted(response.data, key=lambda x: x.index)
+         return [d.embedding for d in sorted_data]
+
+
+ class GeminiEmbedding:
+     """
+     Embedding provider using Google's Gemini API.
+
+     Requires: GEMINI_API_KEY or GOOGLE_API_KEY environment variable.
+     Requires: pip install google-genai
+     """
+
+     # Model dimensions (as of 2025)
+     MODEL_DIMENSIONS = {
+         "text-embedding-004": 768,
+         "embedding-001": 768,
+         "gemini-embedding-001": 3072,  # default output dimensionality
+     }
+
+     def __init__(
+         self,
+         model: str = "text-embedding-004",
+         api_key: str | None = None,
+     ):
+         """
+         Args:
+             model: Gemini embedding model name
+             api_key: API key (defaults to environment variable)
+         """
+         try:
+             from google import genai
+         except ImportError:
+             raise RuntimeError(
+                 "GeminiEmbedding requires 'google-genai' library. "
+                 "Install with: pip install google-genai"
+             )
+
+         self.model_name = model
+         self._dimension = self.MODEL_DIMENSIONS.get(model, 768)
+
+         # Resolve API key
+         key = api_key or os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
+         if not key:
+             raise ValueError(
+                 "Gemini API key required. Set GEMINI_API_KEY or GOOGLE_API_KEY"
+             )
+
+         self._client = genai.Client(api_key=key)
+
+     @property
+     def dimension(self) -> int:
+         """Get embedding dimension for the model."""
+         return self._dimension
+
+     def embed(self, text: str) -> list[float]:
+         """Generate embedding for a single text."""
+         result = self._client.models.embed_content(
+             model=self.model_name,
+             contents=text,
+         )
+         return list(result.embeddings[0].values)
+
+     def embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for multiple texts."""
+         result = self._client.models.embed_content(
+             model=self.model_name,
+             contents=texts,
+         )
+         return [list(e.values) for e in result.embeddings]
+
+
+ class OllamaEmbedding:
+     """
+     Embedding provider using Ollama's local API.
+
+     Requires: Ollama running locally (default: http://localhost:11434)
+     """
+
+     def __init__(
+         self,
+         model: str = "nomic-embed-text",
+         base_url: str = "http://localhost:11434",
+     ):
+         """
+         Args:
+             model: Ollama model name
+             base_url: Ollama API base URL
+         """
+         self.model_name = model
+         self.base_url = base_url.rstrip("/")
+         self._dimension: int | None = None
+
+     @property
+     def dimension(self) -> int:
+         """Get embedding dimension (determined on first embed call)."""
+         if self._dimension is None:
+             # Generate a test embedding to determine dimension
+             test_embedding = self.embed("test")
+             self._dimension = len(test_embedding)
+         return self._dimension
+
+     def embed(self, text: str) -> list[float]:
+         """Generate embedding for a single text."""
+         import requests
+
+         response = requests.post(
+             f"{self.base_url}/api/embeddings",
+             json={"model": self.model_name, "prompt": text},
+             timeout=30,  # avoid hanging forever if the daemon stalls
+         )
+         response.raise_for_status()
+
+         embedding = response.json()["embedding"]
+
+         if self._dimension is None:
+             self._dimension = len(embedding)
+
+         return embedding
+
+     def embed_batch(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for multiple texts (sequential for Ollama)."""
+         return [self.embed(text) for text in texts]
+
+
+ # Register providers
+ _registry = get_registry()
+ _registry.register_embedding("sentence-transformers", SentenceTransformerEmbedding)
+ _registry.register_embedding("openai", OpenAIEmbedding)
+ _registry.register_embedding("gemini", GeminiEmbedding)
+ _registry.register_embedding("ollama", OllamaEmbedding)