headroom-ai 0.2.13 (headroom_ai-0.2.13-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/memory/fast_store.py
@@ -0,0 +1,621 @@
+"""Fast embedding-based memory store.
+
+Sub-100ms write and read latency by:
+1. NO LLM extraction - just embed and store
+2. Vector similarity search - not keyword matching
+3. Optional local embeddings for sub-10ms latency
+
+This replaces the slow LLM-based extraction approach.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import sqlite3
+import time
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import Any
+from uuid import uuid4
+
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MemoryChunk:
+    """A memory chunk with text and embedding."""
+
+    id: str = field(default_factory=lambda: str(uuid4()))
+    text: str = ""
+    role: str = "user"  # "user" or "assistant"
+    embedding: np.ndarray | None = None
+    timestamp: datetime = field(default_factory=datetime.utcnow)
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> dict:
+        """Convert to dictionary for storage."""
+        return {
+            "id": self.id,
+            "text": self.text,
+            "role": self.role,
+            "embedding": self.embedding.tolist() if self.embedding is not None else None,
+            "timestamp": self.timestamp.isoformat(),
+            "metadata": self.metadata,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> MemoryChunk:
+        """Create from dictionary."""
+        embedding = None
+        if data.get("embedding"):
+            embedding = np.array(data["embedding"], dtype=np.float32)
+        return cls(
+            id=data["id"],
+            text=data["text"],
+            role=data.get("role", "user"),
+            embedding=embedding,
+            timestamp=datetime.fromisoformat(data["timestamp"]),
+            metadata=data.get("metadata", {}),
+        )
+
+
+# Type aliases for embedding functions
+EmbedFn = Callable[[str], np.ndarray]
+BatchEmbedFn = Callable[[list[str]], list[np.ndarray]]
+
+
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
+    """Compute cosine similarity between two vectors."""
+    norm_a = np.linalg.norm(a)
+    norm_b = np.linalg.norm(b)
+    if norm_a == 0 or norm_b == 0:
+        return 0.0
+    return float(np.dot(a, b) / (norm_a * norm_b))
+
+
+def cosine_similarity_batch(query: np.ndarray, vectors: np.ndarray) -> np.ndarray:
+    """Compute cosine similarity between query and multiple vectors."""
+    # Normalize query
+    query_norm = query / (np.linalg.norm(query) + 1e-9)
+    # Normalize vectors
+    norms = np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-9
+    vectors_norm = vectors / norms
+    # Dot product - cast to ndarray to satisfy mypy
+    result: np.ndarray = np.dot(vectors_norm, query_norm)
+    return result
+
+
+class FastMemoryStore:
+    """Fast embedding-based memory store.
+
+    Features:
+    - Sub-100ms write latency (no LLM, just embedding)
+    - Sub-50ms read latency (vector similarity search)
+    - Pluggable embedding functions (local or API)
+    - SQLite storage with in-memory vector cache
+
+    Usage:
+        store = FastMemoryStore(db_path, embed_fn=my_embed_fn)
+        store.add("user_123", "I prefer Python", role="user")
+        results = store.search("user_123", "programming language", top_k=5)
+    """
+
+    def __init__(
+        self,
+        db_path: str | Path,
+        embed_fn: EmbedFn | None = None,
+        embedding_dim: int = 1536,  # OpenAI default
+    ):
+        """Initialize the store.
+
+        Args:
+            db_path: Path to SQLite database
+            embed_fn: Function to embed text (if None, must call set_embed_fn later)
+            embedding_dim: Dimension of embeddings
+        """
+        self.db_path = Path(db_path)
+        self.embed_fn = embed_fn
+        self.embedding_dim = embedding_dim
+
+        # In-memory vector cache for fast similarity search
+        self._vector_cache: dict[
+            str, dict[str, np.ndarray]
+        ] = {}  # user_id -> {chunk_id -> embedding}
+        self._chunk_cache: dict[str, dict[str, MemoryChunk]] = {}  # user_id -> {chunk_id -> chunk}
+
+        self._init_db()
+        self._load_cache()
+
+    def _init_db(self) -> None:
+        """Initialize SQLite database."""
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+
+        with sqlite3.connect(str(self.db_path)) as conn:
+            conn.execute("""
+                CREATE TABLE IF NOT EXISTS memory_chunks (
+                    id TEXT PRIMARY KEY,
+                    user_id TEXT NOT NULL,
+                    text TEXT NOT NULL,
+                    role TEXT DEFAULT 'user',
+                    embedding BLOB,
+                    timestamp TEXT NOT NULL,
+                    metadata TEXT DEFAULT '{}',
+                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+                )
+            """)
+            conn.execute("""
+                CREATE INDEX IF NOT EXISTS idx_chunks_user_id
+                ON memory_chunks(user_id)
+            """)
+            conn.execute("""
+                CREATE INDEX IF NOT EXISTS idx_chunks_timestamp
+                ON memory_chunks(user_id, timestamp DESC)
+            """)
+            conn.commit()
+
+    def _load_cache(self) -> None:
+        """Load all embeddings into memory for fast search."""
+        with sqlite3.connect(str(self.db_path)) as conn:
+            cursor = conn.execute("""
+                SELECT id, user_id, text, role, embedding, timestamp, metadata
+                FROM memory_chunks
+                WHERE embedding IS NOT NULL
+            """)
+
+            for row in cursor:
+                chunk_id, user_id, text, role, embedding_blob, timestamp, metadata = row
+
+                if user_id not in self._vector_cache:
+                    self._vector_cache[user_id] = {}
+                    self._chunk_cache[user_id] = {}
+
+                # Deserialize embedding
+                embedding = np.frombuffer(embedding_blob, dtype=np.float32)
+
+                self._vector_cache[user_id][chunk_id] = embedding
+
+                chunk = MemoryChunk(
+                    id=chunk_id,
+                    text=text,
+                    role=role,
+                    embedding=embedding,
+                    timestamp=datetime.fromisoformat(timestamp),
+                    metadata=json.loads(metadata) if metadata else {},
+                )
+                self._chunk_cache[user_id][chunk_id] = chunk
+
+        logger.debug(f"Loaded {sum(len(v) for v in self._vector_cache.values())} chunks into cache")
+
+    def set_embed_fn(self, embed_fn: EmbedFn) -> None:
+        """Set the embedding function."""
+        self.embed_fn = embed_fn
+
+    def add(
+        self,
+        user_id: str,
+        text: str,
+        role: str = "user",
+        metadata: dict[str, Any] | None = None,
+    ) -> MemoryChunk:
+        """Add a memory chunk.
+
+        This is the FAST path - just embed and store, no LLM extraction.
+        Typical latency: <50ms with API embeddings, <10ms with local.
+
+        Args:
+            user_id: User/entity identifier
+            text: Text to store
+            role: "user" or "assistant"
+            metadata: Optional metadata
+
+        Returns:
+            The created MemoryChunk
+        """
+        if not self.embed_fn:
+            raise ValueError("No embedding function set. Call set_embed_fn() first.")
+
+        start_time = time.perf_counter()
+
+        # Embed the text
+        embedding = self.embed_fn(text)
+        embed_time = time.perf_counter() - start_time
+
+        # Create chunk
+        chunk = MemoryChunk(
+            text=text,
+            role=role,
+            embedding=embedding,
+            metadata=metadata or {},
+        )
+
+        # Store in SQLite
+        with sqlite3.connect(str(self.db_path)) as conn:
+            conn.execute(
+                """
+                INSERT INTO memory_chunks (id, user_id, text, role, embedding, timestamp, metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
+                """,
+                (
+                    chunk.id,
+                    user_id,
+                    chunk.text,
+                    chunk.role,
+                    embedding.astype(np.float32).tobytes(),
+                    chunk.timestamp.isoformat(),
+                    json.dumps(chunk.metadata),
+                ),
+            )
+            conn.commit()
+
+        # Update cache
+        if user_id not in self._vector_cache:
+            self._vector_cache[user_id] = {}
+            self._chunk_cache[user_id] = {}
+
+        self._vector_cache[user_id][chunk.id] = embedding
+        self._chunk_cache[user_id][chunk.id] = chunk
+
+        total_time = time.perf_counter() - start_time
+        logger.debug(f"Added chunk in {total_time * 1000:.1f}ms (embed: {embed_time * 1000:.1f}ms)")
+
+        return chunk
+
+    def add_turn(
+        self,
+        user_id: str,
+        user_message: str,
+        assistant_response: str,
+        metadata: dict[str, Any] | None = None,
+    ) -> tuple[MemoryChunk, MemoryChunk]:
+        """Add a conversation turn (user message + assistant response).
+
+        Convenience method that stores both parts of a turn.
+
+        Args:
+            user_id: User/entity identifier
+            user_message: The user's message
+            assistant_response: The assistant's response
+            metadata: Optional metadata for both chunks
+
+        Returns:
+            Tuple of (user_chunk, assistant_chunk)
+        """
+        user_chunk = self.add(user_id, user_message, role="user", metadata=metadata)
+        assistant_chunk = self.add(user_id, assistant_response, role="assistant", metadata=metadata)
+        return user_chunk, assistant_chunk
+
+    def add_turn_batched(
+        self,
+        user_id: str,
+        user_message: str,
+        assistant_response: str,
+        batch_embed_fn: BatchEmbedFn,
+        metadata: dict[str, Any] | None = None,
+    ) -> tuple[MemoryChunk, MemoryChunk]:
+        """Add a conversation turn using BATCHED embedding (single API call).
+
+        This is the FASTEST path - embeds both messages in ONE API call.
+        Typical latency: 50-100ms total vs 200-400ms with individual calls.
+
+        Args:
+            user_id: User/entity identifier
+            user_message: The user's message
+            assistant_response: The assistant's response
+            batch_embed_fn: Batch embedding function
+            metadata: Optional metadata for both chunks
+
+        Returns:
+            Tuple of (user_chunk, assistant_chunk)
+        """
+        start_time = time.perf_counter()
+
+        # Embed BOTH messages in ONE API call
+        embeddings = batch_embed_fn([user_message, assistant_response])
+        embed_time = time.perf_counter() - start_time
+
+        # Create chunks
+        user_chunk = MemoryChunk(
+            text=user_message,
+            role="user",
+            embedding=embeddings[0],
+            metadata=metadata or {},
+        )
+        assistant_chunk = MemoryChunk(
+            text=assistant_response,
+            role="assistant",
+            embedding=embeddings[1],
+            metadata=metadata or {},
+        )
+
+        # Store in SQLite (batch insert)
+        with sqlite3.connect(str(self.db_path)) as conn:
+            conn.executemany(
+                """
+                INSERT INTO memory_chunks (id, user_id, text, role, embedding, timestamp, metadata)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
+                """,
+                [
+                    (
+                        user_chunk.id,
+                        user_id,
+                        user_chunk.text,
+                        user_chunk.role,
+                        embeddings[0].astype(np.float32).tobytes(),
+                        user_chunk.timestamp.isoformat(),
+                        json.dumps(user_chunk.metadata),
+                    ),
+                    (
+                        assistant_chunk.id,
+                        user_id,
+                        assistant_chunk.text,
+                        assistant_chunk.role,
+                        embeddings[1].astype(np.float32).tobytes(),
+                        assistant_chunk.timestamp.isoformat(),
+                        json.dumps(assistant_chunk.metadata),
+                    ),
+                ],
+            )
+            conn.commit()
+
+        # Update cache
+        if user_id not in self._vector_cache:
+            self._vector_cache[user_id] = {}
+            self._chunk_cache[user_id] = {}
+
+        self._vector_cache[user_id][user_chunk.id] = embeddings[0]
+        self._vector_cache[user_id][assistant_chunk.id] = embeddings[1]
+        self._chunk_cache[user_id][user_chunk.id] = user_chunk
+        self._chunk_cache[user_id][assistant_chunk.id] = assistant_chunk
+
+        total_time = time.perf_counter() - start_time
+        logger.debug(
+            f"Added turn (batched) in {total_time * 1000:.1f}ms (embed: {embed_time * 1000:.1f}ms)"
+        )
+
+        return user_chunk, assistant_chunk
+
+    def search(
+        self,
+        user_id: str,
+        query: str,
+        top_k: int = 5,
+        min_similarity: float = 0.0,
+        role_filter: str | None = None,
+    ) -> list[tuple[MemoryChunk, float]]:
+        """Search for relevant memory chunks.
+
+        Uses vector similarity search for semantic matching.
+        Typical latency: <50ms with API embeddings, <10ms with local.
+
+        Args:
+            user_id: User/entity identifier
+            query: Search query
+            top_k: Number of results to return
+            min_similarity: Minimum cosine similarity threshold
+            role_filter: Optional filter by role ("user" or "assistant")
+
+        Returns:
+            List of (chunk, similarity_score) tuples, sorted by relevance
+        """
+        if not self.embed_fn:
+            raise ValueError("No embedding function set. Call set_embed_fn() first.")
+
+        start_time = time.perf_counter()
+
+        # Check if user has any memories
+        if user_id not in self._vector_cache or not self._vector_cache[user_id]:
+            return []
+
+        # Embed query
+        query_embedding = self.embed_fn(query)
+        embed_time = time.perf_counter() - start_time
+
+        # Get user's vectors
+        chunk_ids = list(self._vector_cache[user_id].keys())
+        vectors = np.array([self._vector_cache[user_id][cid] for cid in chunk_ids])
+
+        # Compute similarities
+        similarities = cosine_similarity_batch(query_embedding, vectors)
+        search_time = time.perf_counter() - start_time - embed_time
+
+        # Sort by similarity
+        sorted_indices = np.argsort(similarities)[::-1]
+
+        # Collect results
+        results = []
+        for idx in sorted_indices:
+            chunk_id = chunk_ids[idx]
+            similarity = float(similarities[idx])
+
+            if similarity < min_similarity:
+                break
+
+            chunk = self._chunk_cache[user_id][chunk_id]
+
+            # Apply role filter
+            if role_filter and chunk.role != role_filter:
+                continue
+
+            results.append((chunk, similarity))
+
+            if len(results) >= top_k:
+                break
+
+        total_time = time.perf_counter() - start_time
+        logger.debug(
+            f"Search completed in {total_time * 1000:.1f}ms "
+            f"(embed: {embed_time * 1000:.1f}ms, search: {search_time * 1000:.1f}ms)"
+        )
+
+        return results
+
+    def get_recent(
+        self,
+        user_id: str,
+        limit: int = 10,
+        role_filter: str | None = None,
+    ) -> list[MemoryChunk]:
+        """Get recent memory chunks.
+
+        Args:
+            user_id: User/entity identifier
+            limit: Maximum number of chunks to return
+            role_filter: Optional filter by role
+
+        Returns:
+            List of chunks, sorted by timestamp (newest first)
+        """
+        if user_id not in self._chunk_cache:
+            return []
+
+        chunks = list(self._chunk_cache[user_id].values())
+
+        # Apply role filter
+        if role_filter:
+            chunks = [c for c in chunks if c.role == role_filter]
+
+        # Sort by timestamp
+        chunks.sort(key=lambda c: c.timestamp, reverse=True)
+
+        return chunks[:limit]
+
+    def get_all(self, user_id: str) -> list[MemoryChunk]:
+        """Get all memory chunks for a user."""
+        if user_id not in self._chunk_cache:
+            return []
+        return list(self._chunk_cache[user_id].values())
+
+    def delete(self, user_id: str, chunk_id: str) -> bool:
+        """Delete a specific chunk."""
+        with sqlite3.connect(str(self.db_path)) as conn:
+            cursor = conn.execute(
+                "DELETE FROM memory_chunks WHERE id = ? AND user_id = ?",
+                (chunk_id, user_id),
+            )
+            conn.commit()
+            deleted = cursor.rowcount > 0
+
+        if deleted and user_id in self._vector_cache:
+            self._vector_cache[user_id].pop(chunk_id, None)
+            self._chunk_cache[user_id].pop(chunk_id, None)
+
+        return deleted
+
+    def clear(self, user_id: str) -> int:
+        """Clear all memories for a user."""
+        with sqlite3.connect(str(self.db_path)) as conn:
+            cursor = conn.execute(
+                "DELETE FROM memory_chunks WHERE user_id = ?",
+                (user_id,),
+            )
+            conn.commit()
+            count = cursor.rowcount
+
+        self._vector_cache.pop(user_id, None)
+        self._chunk_cache.pop(user_id, None)
+
+        return count
+
+    def stats(self, user_id: str) -> dict[str, Any]:
+        """Get statistics for a user."""
+        chunks = self.get_all(user_id)
+        return {
+            "total": len(chunks),
+            "user_messages": sum(1 for c in chunks if c.role == "user"),
+            "assistant_messages": sum(1 for c in chunks if c.role == "assistant"),
+        }
+
+
+# =============================================================================
+# Embedding Functions
+# =============================================================================
+
+
+def create_openai_embed_fn(
+    client: Any,
+    model: str = "text-embedding-3-small",
+) -> EmbedFn:
+    """Create an embedding function using OpenAI API.
+
+    Typical latency: 30-100ms per call.
+
+    Args:
+        client: OpenAI client
+        model: Embedding model to use
+
+    Returns:
+        Embedding function
+    """
+
+    def embed(text: str) -> np.ndarray:
+        response = client.embeddings.create(
+            model=model,
+            input=text,
+        )
+        return np.array(response.data[0].embedding, dtype=np.float32)
+
+    return embed
+
+
+def create_openai_batch_embed_fn(
+    client: Any,
+    model: str = "text-embedding-3-small",
+) -> BatchEmbedFn:
+    """Create a BATCH embedding function using OpenAI API.
+
+    Much faster than individual calls - single API round trip for multiple texts.
+    Typical latency: 50-200ms for 10 texts vs 500-2000ms for 10 individual calls.
+
+    Args:
+        client: OpenAI client
+        model: Embedding model to use
+
+    Returns:
+        Batch embedding function
+    """
+
+    def embed_batch(texts: list[str]) -> list[np.ndarray]:
+        if not texts:
+            return []
+        response = client.embeddings.create(
+            model=model,
+            input=texts,
+        )
+        # Sort by index to maintain order
+        sorted_data = sorted(response.data, key=lambda x: x.index)
+        return [np.array(d.embedding, dtype=np.float32) for d in sorted_data]
+
+    return embed_batch
+
+
+def create_local_embed_fn(
+    model_name: str = "all-MiniLM-L6-v2",
+) -> EmbedFn:
+    """Create an embedding function using local sentence-transformers.
+
+    Typical latency: 5-20ms per call (after model load).
+
+    Args:
+        model_name: Sentence-transformers model name
+
+    Returns:
+        Embedding function
+    """
+    try:
+        from sentence_transformers import SentenceTransformer
+    except ImportError:
+        raise ImportError(
+            "sentence-transformers not installed. Install with: pip install sentence-transformers"
+        ) from None
+
+    model = SentenceTransformer(model_name)
+
+    def embed(text: str) -> np.ndarray:
+        return model.encode(text, convert_to_numpy=True).astype(np.float32)
+
+    return embed
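
Usage sketch (not part of the published diff): the FastMemoryStore docstring above shows the intended call pattern. A minimal end-to-end example, assuming the import path headroom.memory.fast_store implied by the file list and an installed sentence-transformers (the model name is the module's own default):

from headroom.memory.fast_store import FastMemoryStore, create_local_embed_fn

# Local embeddings avoid network calls entirely; the module docstring
# claims sub-10ms per-call latency once the model is loaded.
store = FastMemoryStore("memories.db", embed_fn=create_local_embed_fn())

store.add("user_123", "I prefer Python", role="user")
store.add("user_123", "We deploy on AWS Lambda", role="user")

# Vector similarity, not keyword matching: "programming language" can still
# surface the Python preference even though no word overlaps.
for chunk, score in store.search("user_123", "programming language", top_k=5):
    print(f"{score:.3f} [{chunk.role}] {chunk.text}")

print(store.stats("user_123"))  # {'total': 2, 'user_messages': 2, 'assistant_messages': 0}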
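add_turn_batched exists to cut per-turn write latency: per its docstring, one batched embeddings call (50-100ms) replaces two individual calls (200-400ms). A sketch of wiring it to the OpenAI client, assuming the openai package is installed and OPENAI_API_KEY is set (the model name is the factory default):

from openai import OpenAI

from headroom.memory.fast_store import (
    FastMemoryStore,
    create_openai_batch_embed_fn,
    create_openai_embed_fn,
)

client = OpenAI()
store = FastMemoryStore(
    "memories.db",
    embed_fn=create_openai_embed_fn(client),  # single-text path, used by search()
)
batch_embed = create_openai_batch_embed_fn(client)

# Both sides of the turn go through ONE embeddings.create call.
user_chunk, assistant_chunk = store.add_turn_batched(
    "user_123",
    user_message="What's the capital of France?",
    assistant_response="The capital of France is Paris.",
    batch_embed_fn=batch_embed,
)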
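Because EmbedFn is just Callable[[str], np.ndarray], anything returning a fixed-width vector satisfies the contract, which is handy for hermetic tests. A hypothetical hash-based stub (it carries no real semantics, so similarity scores are meaningless; it only exercises the storage and serialization paths):

import hashlib

import numpy as np

from headroom.memory.fast_store import FastMemoryStore, MemoryChunk

def fake_embed(text: str) -> np.ndarray:
    # 32 deterministic pseudo-values derived from a SHA-256 digest.
    digest = hashlib.sha256(text.encode()).digest()
    return np.frombuffer(digest, dtype=np.uint8).astype(np.float32)

store = FastMemoryStore("test_memories.db", embed_fn=fake_embed)
chunk = store.add("test_user", "hello world")

# to_dict()/from_dict() round-trips the chunk, embedding included.
restored = MemoryChunk.from_dict(chunk.to_dict())
assert restored.text == chunk.text
assert restored.embedding is not None and restored.embedding.shape == (32,)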