headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/memory/fast_store.py
@@ -0,0 +1,621 @@
"""Fast embedding-based memory store.

Sub-100ms write and read latency by:
1. NO LLM extraction - just embed and store
2. Vector similarity search - not keyword matching
3. Optional local embeddings for sub-10ms latency

This replaces the slow LLM-based extraction approach.
"""

from __future__ import annotations

import json
import logging
import sqlite3
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import uuid4

import numpy as np

logger = logging.getLogger(__name__)


@dataclass
class MemoryChunk:
    """A memory chunk with text and embedding."""

    id: str = field(default_factory=lambda: str(uuid4()))
    text: str = ""
    role: str = "user"  # "user" or "assistant"
    embedding: np.ndarray | None = None
    timestamp: datetime = field(default_factory=datetime.utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary for storage."""
        return {
            "id": self.id,
            "text": self.text,
            "role": self.role,
            "embedding": self.embedding.tolist() if self.embedding is not None else None,
            "timestamp": self.timestamp.isoformat(),
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: dict) -> MemoryChunk:
        """Create from dictionary."""
        embedding = None
        if data.get("embedding"):
            embedding = np.array(data["embedding"], dtype=np.float32)
        return cls(
            id=data["id"],
            text=data["text"],
            role=data.get("role", "user"),
            embedding=embedding,
            timestamp=datetime.fromisoformat(data["timestamp"]),
            metadata=data.get("metadata", {}),
        )


# Type aliases for embedding functions
EmbedFn = Callable[[str], np.ndarray]
BatchEmbedFn = Callable[[list[str]], list[np.ndarray]]


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute cosine similarity between two vectors."""
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(a, b) / (norm_a * norm_b))


def cosine_similarity_batch(query: np.ndarray, vectors: np.ndarray) -> np.ndarray:
    """Compute cosine similarity between query and multiple vectors."""
    # Normalize query
    query_norm = query / (np.linalg.norm(query) + 1e-9)
    # Normalize vectors
    norms = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-9
    vectors_norm = vectors / norms
    # Dot product - cast to ndarray to satisfy mypy
    result: np.ndarray = np.dot(vectors_norm, query_norm)
    return result


class FastMemoryStore:
    """Fast embedding-based memory store.

    Features:
    - Sub-100ms write latency (no LLM, just embedding)
    - Sub-50ms read latency (vector similarity search)
    - Pluggable embedding functions (local or API)
    - SQLite storage with in-memory vector cache

    Usage:
        store = FastMemoryStore(db_path, embed_fn=my_embed_fn)
        store.add("user_123", "I prefer Python", role="user")
        results = store.search("user_123", "programming language", top_k=5)
    """

    def __init__(
        self,
        db_path: str | Path,
        embed_fn: EmbedFn | None = None,
        embedding_dim: int = 1536,  # OpenAI default
    ):
        """Initialize the store.

        Args:
            db_path: Path to SQLite database
            embed_fn: Function to embed text (if None, must call set_embed_fn later)
            embedding_dim: Dimension of embeddings
        """
        self.db_path = Path(db_path)
        self.embed_fn = embed_fn
        self.embedding_dim = embedding_dim

        # In-memory vector cache for fast similarity search
        self._vector_cache: dict[
            str, dict[str, np.ndarray]
        ] = {}  # user_id -> {chunk_id -> embedding}
        self._chunk_cache: dict[str, dict[str, MemoryChunk]] = {}  # user_id -> {chunk_id -> chunk}

        self._init_db()
        self._load_cache()

    def _init_db(self) -> None:
        """Initialize SQLite database."""
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        with sqlite3.connect(str(self.db_path)) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memory_chunks (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    text TEXT NOT NULL,
                    role TEXT DEFAULT 'user',
                    embedding BLOB,
                    timestamp TEXT NOT NULL,
                    metadata TEXT DEFAULT '{}',
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_chunks_user_id
                ON memory_chunks(user_id)
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_chunks_timestamp
                ON memory_chunks(user_id, timestamp DESC)
            """)
            conn.commit()

    def _load_cache(self) -> None:
        """Load all embeddings into memory for fast search."""
        with sqlite3.connect(str(self.db_path)) as conn:
            cursor = conn.execute("""
                SELECT id, user_id, text, role, embedding, timestamp, metadata
                FROM memory_chunks
                WHERE embedding IS NOT NULL
            """)

            for row in cursor:
                chunk_id, user_id, text, role, embedding_blob, timestamp, metadata = row

                if user_id not in self._vector_cache:
                    self._vector_cache[user_id] = {}
                    self._chunk_cache[user_id] = {}

                # Deserialize embedding
                embedding = np.frombuffer(embedding_blob, dtype=np.float32)

                self._vector_cache[user_id][chunk_id] = embedding

                chunk = MemoryChunk(
                    id=chunk_id,
                    text=text,
                    role=role,
                    embedding=embedding,
                    timestamp=datetime.fromisoformat(timestamp),
                    metadata=json.loads(metadata) if metadata else {},
                )
                self._chunk_cache[user_id][chunk_id] = chunk

        logger.debug(f"Loaded {sum(len(v) for v in self._vector_cache.values())} chunks into cache")

    def set_embed_fn(self, embed_fn: EmbedFn) -> None:
        """Set the embedding function."""
        self.embed_fn = embed_fn

    def add(
        self,
        user_id: str,
        text: str,
        role: str = "user",
        metadata: dict[str, Any] | None = None,
    ) -> MemoryChunk:
        """Add a memory chunk.

        This is the FAST path - just embed and store, no LLM extraction.
        Typical latency: <50ms with API embeddings, <10ms with local.

        Args:
            user_id: User/entity identifier
            text: Text to store
            role: "user" or "assistant"
            metadata: Optional metadata

        Returns:
            The created MemoryChunk
        """
        if not self.embed_fn:
            raise ValueError("No embedding function set. Call set_embed_fn() first.")

        start_time = time.perf_counter()

        # Embed the text
        embedding = self.embed_fn(text)
        embed_time = time.perf_counter() - start_time

        # Create chunk
        chunk = MemoryChunk(
            text=text,
            role=role,
            embedding=embedding,
            metadata=metadata or {},
        )

        # Store in SQLite
        with sqlite3.connect(str(self.db_path)) as conn:
            conn.execute(
                """
                INSERT INTO memory_chunks (id, user_id, text, role, embedding, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    chunk.id,
                    user_id,
                    chunk.text,
                    chunk.role,
                    embedding.astype(np.float32).tobytes(),
                    chunk.timestamp.isoformat(),
                    json.dumps(chunk.metadata),
                ),
            )
            conn.commit()

        # Update cache
        if user_id not in self._vector_cache:
            self._vector_cache[user_id] = {}
            self._chunk_cache[user_id] = {}

        self._vector_cache[user_id][chunk.id] = embedding
        self._chunk_cache[user_id][chunk.id] = chunk

        total_time = time.perf_counter() - start_time
        logger.debug(f"Added chunk in {total_time * 1000:.1f}ms (embed: {embed_time * 1000:.1f}ms)")

        return chunk

    def add_turn(
        self,
        user_id: str,
        user_message: str,
        assistant_response: str,
        metadata: dict[str, Any] | None = None,
    ) -> tuple[MemoryChunk, MemoryChunk]:
        """Add a conversation turn (user message + assistant response).

        Convenience method that stores both parts of a turn.

        Args:
            user_id: User/entity identifier
            user_message: The user's message
            assistant_response: The assistant's response
            metadata: Optional metadata for both chunks

        Returns:
            Tuple of (user_chunk, assistant_chunk)
        """
        user_chunk = self.add(user_id, user_message, role="user", metadata=metadata)
        assistant_chunk = self.add(user_id, assistant_response, role="assistant", metadata=metadata)
        return user_chunk, assistant_chunk

    def add_turn_batched(
        self,
        user_id: str,
        user_message: str,
        assistant_response: str,
        batch_embed_fn: BatchEmbedFn,
        metadata: dict[str, Any] | None = None,
    ) -> tuple[MemoryChunk, MemoryChunk]:
        """Add a conversation turn using BATCHED embedding (single API call).

        This is the FASTEST path - embeds both messages in ONE API call.
        Typical latency: 50-100ms total vs 200-400ms with individual calls.

        Args:
            user_id: User/entity identifier
            user_message: The user's message
            assistant_response: The assistant's response
            batch_embed_fn: Batch embedding function
            metadata: Optional metadata for both chunks

        Returns:
            Tuple of (user_chunk, assistant_chunk)
        """
        start_time = time.perf_counter()

        # Embed BOTH messages in ONE API call
        embeddings = batch_embed_fn([user_message, assistant_response])
        embed_time = time.perf_counter() - start_time

        # Create chunks
        user_chunk = MemoryChunk(
            text=user_message,
            role="user",
            embedding=embeddings[0],
            metadata=metadata or {},
        )
        assistant_chunk = MemoryChunk(
            text=assistant_response,
            role="assistant",
            embedding=embeddings[1],
            metadata=metadata or {},
        )

        # Store in SQLite (batch insert)
        with sqlite3.connect(str(self.db_path)) as conn:
            conn.executemany(
                """
                INSERT INTO memory_chunks (id, user_id, text, role, embedding, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    (
                        user_chunk.id,
                        user_id,
                        user_chunk.text,
                        user_chunk.role,
                        embeddings[0].astype(np.float32).tobytes(),
                        user_chunk.timestamp.isoformat(),
                        json.dumps(user_chunk.metadata),
                    ),
                    (
                        assistant_chunk.id,
                        user_id,
                        assistant_chunk.text,
                        assistant_chunk.role,
                        embeddings[1].astype(np.float32).tobytes(),
                        assistant_chunk.timestamp.isoformat(),
                        json.dumps(assistant_chunk.metadata),
                    ),
                ],
            )
            conn.commit()

        # Update cache
        if user_id not in self._vector_cache:
            self._vector_cache[user_id] = {}
            self._chunk_cache[user_id] = {}

        self._vector_cache[user_id][user_chunk.id] = embeddings[0]
        self._vector_cache[user_id][assistant_chunk.id] = embeddings[1]
        self._chunk_cache[user_id][user_chunk.id] = user_chunk
        self._chunk_cache[user_id][assistant_chunk.id] = assistant_chunk

        total_time = time.perf_counter() - start_time
        logger.debug(
            f"Added turn (batched) in {total_time * 1000:.1f}ms (embed: {embed_time * 1000:.1f}ms)"
        )

        return user_chunk, assistant_chunk

    def search(
        self,
        user_id: str,
        query: str,
        top_k: int = 5,
        min_similarity: float = 0.0,
        role_filter: str | None = None,
    ) -> list[tuple[MemoryChunk, float]]:
        """Search for relevant memory chunks.

        Uses vector similarity search for semantic matching.
        Typical latency: <50ms with API embeddings, <10ms with local.

        Args:
            user_id: User/entity identifier
            query: Search query
            top_k: Number of results to return
            min_similarity: Minimum cosine similarity threshold
            role_filter: Optional filter by role ("user" or "assistant")

        Returns:
            List of (chunk, similarity_score) tuples, sorted by relevance
        """
        if not self.embed_fn:
            raise ValueError("No embedding function set. Call set_embed_fn() first.")

        start_time = time.perf_counter()

        # Check if user has any memories
        if user_id not in self._vector_cache or not self._vector_cache[user_id]:
            return []

        # Embed query
        query_embedding = self.embed_fn(query)
        embed_time = time.perf_counter() - start_time

        # Get user's vectors
        chunk_ids = list(self._vector_cache[user_id].keys())
        vectors = np.array([self._vector_cache[user_id][cid] for cid in chunk_ids])

        # Compute similarities
        similarities = cosine_similarity_batch(query_embedding, vectors)
        search_time = time.perf_counter() - start_time - embed_time

        # Sort by similarity
        sorted_indices = np.argsort(similarities)[::-1]

        # Collect results
        results = []
        for idx in sorted_indices:
            chunk_id = chunk_ids[idx]
            similarity = float(similarities[idx])

            if similarity < min_similarity:
                break

            chunk = self._chunk_cache[user_id][chunk_id]

            # Apply role filter
            if role_filter and chunk.role != role_filter:
                continue

            results.append((chunk, similarity))

            if len(results) >= top_k:
                break

        total_time = time.perf_counter() - start_time
        logger.debug(
            f"Search completed in {total_time * 1000:.1f}ms "
            f"(embed: {embed_time * 1000:.1f}ms, search: {search_time * 1000:.1f}ms)"
        )

        return results

    def get_recent(
        self,
        user_id: str,
        limit: int = 10,
        role_filter: str | None = None,
    ) -> list[MemoryChunk]:
        """Get recent memory chunks.

        Args:
            user_id: User/entity identifier
            limit: Maximum number of chunks to return
            role_filter: Optional filter by role

        Returns:
            List of chunks, sorted by timestamp (newest first)
        """
        if user_id not in self._chunk_cache:
            return []

        chunks = list(self._chunk_cache[user_id].values())

        # Apply role filter
        if role_filter:
            chunks = [c for c in chunks if c.role == role_filter]

        # Sort by timestamp
        chunks.sort(key=lambda c: c.timestamp, reverse=True)

        return chunks[:limit]

    def get_all(self, user_id: str) -> list[MemoryChunk]:
        """Get all memory chunks for a user."""
        if user_id not in self._chunk_cache:
            return []
        return list(self._chunk_cache[user_id].values())

    def delete(self, user_id: str, chunk_id: str) -> bool:
        """Delete a specific chunk."""
        with sqlite3.connect(str(self.db_path)) as conn:
            cursor = conn.execute(
                "DELETE FROM memory_chunks WHERE id = ? AND user_id = ?",
                (chunk_id, user_id),
            )
            conn.commit()
            deleted = cursor.rowcount > 0

        if deleted and user_id in self._vector_cache:
            self._vector_cache[user_id].pop(chunk_id, None)
            self._chunk_cache[user_id].pop(chunk_id, None)

        return deleted

    def clear(self, user_id: str) -> int:
        """Clear all memories for a user."""
        with sqlite3.connect(str(self.db_path)) as conn:
            cursor = conn.execute(
                "DELETE FROM memory_chunks WHERE user_id = ?",
                (user_id,),
            )
            conn.commit()
            count = cursor.rowcount

        self._vector_cache.pop(user_id, None)
        self._chunk_cache.pop(user_id, None)

        return count

    def stats(self, user_id: str) -> dict[str, Any]:
        """Get statistics for a user."""
        chunks = self.get_all(user_id)
        return {
            "total": len(chunks),
            "user_messages": sum(1 for c in chunks if c.role == "user"),
            "assistant_messages": sum(1 for c in chunks if c.role == "assistant"),
        }


# =============================================================================
# Embedding Functions
# =============================================================================


def create_openai_embed_fn(
    client: Any,
    model: str = "text-embedding-3-small",
) -> EmbedFn:
    """Create an embedding function using OpenAI API.

    Typical latency: 30-100ms per call.

    Args:
        client: OpenAI client
        model: Embedding model to use

    Returns:
        Embedding function
    """

    def embed(text: str) -> np.ndarray:
        response = client.embeddings.create(
            model=model,
            input=text,
        )
        return np.array(response.data[0].embedding, dtype=np.float32)

    return embed


def create_openai_batch_embed_fn(
    client: Any,
    model: str = "text-embedding-3-small",
) -> BatchEmbedFn:
    """Create a BATCH embedding function using OpenAI API.

    Much faster than individual calls - single API round trip for multiple texts.
    Typical latency: 50-200ms for 10 texts vs 500-2000ms for 10 individual calls.

    Args:
        client: OpenAI client
        model: Embedding model to use

    Returns:
        Batch embedding function
    """

    def embed_batch(texts: list[str]) -> list[np.ndarray]:
        if not texts:
            return []
        response = client.embeddings.create(
            model=model,
            input=texts,
        )
        # Sort by index to maintain order
        sorted_data = sorted(response.data, key=lambda x: x.index)
        return [np.array(d.embedding, dtype=np.float32) for d in sorted_data]

    return embed_batch


def create_local_embed_fn(
    model_name: str = "all-MiniLM-L6-v2",
) -> EmbedFn:
    """Create an embedding function using local sentence-transformers.

    Typical latency: 5-20ms per call (after model load).

    Args:
        model_name: Sentence-transformers model name

    Returns:
        Embedding function
    """
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError:
        raise ImportError(
            "sentence-transformers not installed. Install with: pip install sentence-transformers"
        ) from None

    model = SentenceTransformer(model_name)

    def embed(text: str) -> np.ndarray:
        return model.encode(text, convert_to_numpy=True).astype(np.float32)

    return embed
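
For orientation, here is a minimal end-to-end sketch of how the functions in this file fit together. It only uses names defined above; the database path, user id, and example texts are illustrative, and the local-embedding path assumes the optional sentence-transformers dependency is installed.

    from headroom.memory.fast_store import (
        FastMemoryStore,
        create_local_embed_fn,
        create_openai_batch_embed_fn,
    )

    # Local embeddings keep add()/search() in the ~5-20ms range per the docstrings above.
    embed_fn = create_local_embed_fn("all-MiniLM-L6-v2")
    store = FastMemoryStore("./memory.db", embed_fn=embed_fn)

    # Write path: embed and persist a chunk, no LLM extraction involved.
    store.add("user_123", "I prefer Python for data work", role="user")

    # Read path: vector similarity search over the in-memory cache.
    for chunk, score in store.search("user_123", "programming language", top_k=3):
        print(f"{score:.2f} {chunk.role}: {chunk.text}")

    # With an OpenAI client, a whole turn can be embedded in one API call:
    #   from openai import OpenAI
    #   batch_fn = create_openai_batch_embed_fn(OpenAI())
    #   store.add_turn_batched("user_123", "What do I prefer?", "You prefer Python.", batch_fn)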