headroom-ai 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/memory/store.py
ADDED
@@ -0,0 +1,434 @@
"""SQLite + FTS5 memory storage for Headroom Memory.

Simple, fast, local-first storage with full-text search.
No external dependencies - just SQLite (built into Python).
"""

from __future__ import annotations

import json
import sqlite3
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Literal


@dataclass
class Memory:
    """A single memory entry."""

    content: str
    category: Literal["preference", "fact", "context"] = "fact"
    importance: float = 0.5
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    created_at: datetime = field(default_factory=datetime.utcnow)
    metadata: dict = field(default_factory=dict)


@dataclass
class PendingExtraction:
    """A conversation pending memory extraction."""

    user_id: str
    query: str
    response: str
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    created_at: datetime = field(default_factory=datetime.utcnow)
    status: Literal["pending", "processing", "done", "failed"] = "pending"


class SQLiteMemoryStore:
    """SQLite + FTS5 storage for memories.

    Features:
    - Full-text search via FTS5
    - User isolation (each user_id has separate memories)
    - Pending extractions for crash recovery
    - Thread-safe with connection per call

    Usage:
        store = SQLiteMemoryStore("./memory.db")
        store.save("alice", Memory(content="Prefers Python"))
        results = store.search("alice", "python")
    """

    def __init__(self, db_path: str | Path = "headroom_memory.db"):
        """Initialize the store.

        Args:
            db_path: Path to SQLite database file. Created if it doesn't exist.
        """
        self.db_path = Path(db_path)
        self._init_db()

    def _get_conn(self) -> sqlite3.Connection:
        """Get a new connection (thread-safe pattern)."""
        conn = sqlite3.connect(str(self.db_path))
        conn.row_factory = sqlite3.Row
        return conn

    def _init_db(self) -> None:
        """Initialize database schema."""
        with self._get_conn() as conn:
            # Main memories table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    content TEXT NOT NULL,
                    category TEXT NOT NULL DEFAULT 'fact',
                    importance REAL NOT NULL DEFAULT 0.5,
                    created_at TEXT NOT NULL,
                    metadata TEXT NOT NULL DEFAULT '{}'
                )
            """)

            # FTS5 virtual table for full-text search
            conn.execute("""
                CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
                    content,
                    content='memories',
                    content_rowid='rowid'
                )
            """)

            # Triggers to keep FTS in sync
            conn.execute("""
                CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
                    INSERT INTO memories_fts(rowid, content)
                    VALUES (new.rowid, new.content);
                END
            """)

            conn.execute("""
                CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
                    INSERT INTO memories_fts(memories_fts, rowid, content)
                    VALUES ('delete', old.rowid, old.content);
                END
            """)

            conn.execute("""
                CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
                    INSERT INTO memories_fts(memories_fts, rowid, content)
                    VALUES ('delete', old.rowid, old.content);
                    INSERT INTO memories_fts(rowid, content)
                    VALUES (new.rowid, new.content);
                END
            """)

            # Index for user_id filtering
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_memories_user_id
                ON memories(user_id)
            """)

            # Pending extractions table (for crash recovery)
            conn.execute("""
                CREATE TABLE IF NOT EXISTS pending_extractions (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    query TEXT NOT NULL,
                    response TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending'
                )
            """)

            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_pending_status
                ON pending_extractions(status)
            """)

            conn.commit()

    def save(self, user_id: str, memory: Memory) -> None:
        """Save a memory for a user.

        Args:
            user_id: User identifier for isolation
            memory: Memory to save
        """
        with self._get_conn() as conn:
            conn.execute(
                """
                INSERT INTO memories (id, user_id, content, category, importance, created_at, metadata)
                VALUES (?, ?, ?, ?, ?, ?, ?)
                """,
                (
                    memory.id,
                    user_id,
                    memory.content,
                    memory.category,
                    memory.importance,
                    memory.created_at.isoformat(),
                    json.dumps(memory.metadata),
                ),
            )
            conn.commit()

    def search(self, user_id: str, query: str, top_k: int = 5) -> list[Memory]:
        """Search memories using FTS5 full-text search.

        Args:
            user_id: User identifier for isolation
            query: Search query (auto-escaped, or use raw FTS5 syntax with prefix '_raw:')
            top_k: Maximum number of results

        Returns:
            List of matching memories, ranked by relevance
        """
        # Sanitize query for FTS5 (escape special characters unless raw mode)
        if query.startswith("_raw:"):
            fts_query = query[5:]  # Use raw FTS5 syntax
        else:
            fts_query = self._sanitize_fts_query(query)

        if not fts_query.strip():
            return []

        with self._get_conn() as conn:
            # Use FTS5 MATCH with BM25 ranking, filtered by user_id
            cursor = conn.execute(
                """
                SELECT m.*, bm25(memories_fts) as rank
                FROM memories m
                JOIN memories_fts ON m.rowid = memories_fts.rowid
                WHERE memories_fts MATCH ? AND m.user_id = ?
                ORDER BY rank
                LIMIT ?
                """,
                (fts_query, user_id, top_k),
            )

            results = []
            for row in cursor:
                results.append(
                    Memory(
                        id=row["id"],
                        content=row["content"],
                        category=row["category"],
                        importance=row["importance"],
                        created_at=datetime.fromisoformat(row["created_at"]),
                        metadata=json.loads(row["metadata"]),
                    )
                )
            return results

    def _sanitize_fts_query(self, query: str) -> str:
        """Sanitize a query for FTS5.

        Extracts plain words and joins them with OR, so FTS5 syntax
        characters in user input cannot break the query.

        Args:
            query: Raw user query

        Returns:
            FTS5-safe query string
        """
        # FTS5 treats characters like -, *, ( ) and : as query syntax.
        # We use a simple approach: extract words and use OR between them.
        import re

        # Extract alphanumeric words
        words = re.findall(r"\w+", query)

        if not words:
            return ""

        # Use OR between words so a multi-word query matches any memory
        # containing at least one of its terms (e.g. "python language"
        # matches a memory that only mentions Python).
        escaped_words = []
        for word in words:
            # Quote each word so it is treated as a literal token
            escaped_words.append(f'"{word}"')

        return " OR ".join(escaped_words)

    def get_all(self, user_id: str) -> list[Memory]:
        """Get all memories for a user.

        Args:
            user_id: User identifier

        Returns:
            All memories for the user, ordered by creation time (newest first)
        """
        with self._get_conn() as conn:
            cursor = conn.execute(
                """
                SELECT * FROM memories
                WHERE user_id = ?
                ORDER BY created_at DESC
                """,
                (user_id,),
            )

            return [
                Memory(
                    id=row["id"],
                    content=row["content"],
                    category=row["category"],
                    importance=row["importance"],
                    created_at=datetime.fromisoformat(row["created_at"]),
                    metadata=json.loads(row["metadata"]),
                )
                for row in cursor
            ]

    def delete(self, user_id: str, memory_id: str) -> bool:
        """Delete a specific memory.

        Args:
            user_id: User identifier
            memory_id: ID of memory to delete

        Returns:
            True if deleted, False if not found
        """
        with self._get_conn() as conn:
            cursor = conn.execute(
                "DELETE FROM memories WHERE id = ? AND user_id = ?",
                (memory_id, user_id),
            )
            conn.commit()
            return cursor.rowcount > 0

    def clear(self, user_id: str) -> int:
        """Delete all memories for a user.

        Args:
            user_id: User identifier

        Returns:
            Number of memories deleted
        """
        with self._get_conn() as conn:
            cursor = conn.execute(
                "DELETE FROM memories WHERE user_id = ?",
                (user_id,),
            )
            conn.commit()
            return cursor.rowcount

    def stats(self, user_id: str) -> dict:
        """Get memory statistics for a user.

        Args:
            user_id: User identifier

        Returns:
            Dict with count, categories breakdown, etc.
        """
        with self._get_conn() as conn:
            # Total count
            total = conn.execute(
                "SELECT COUNT(*) as count FROM memories WHERE user_id = ?",
                (user_id,),
            ).fetchone()["count"]

            # Category breakdown
            categories = {}
            for row in conn.execute(
                """
                SELECT category, COUNT(*) as count
                FROM memories WHERE user_id = ?
                GROUP BY category
                """,
                (user_id,),
            ):
                categories[row["category"]] = row["count"]

            return {
                "total": total,
                "categories": categories,
            }

    # --- Pending Extractions (for crash recovery) ---

    def queue_extraction(self, pending: PendingExtraction) -> None:
        """Queue a conversation for memory extraction.

        Args:
            pending: The pending extraction to queue
        """
        with self._get_conn() as conn:
            conn.execute(
                """
                INSERT INTO pending_extractions (id, user_id, query, response, created_at, status)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (
                    pending.id,
                    pending.user_id,
                    pending.query,
                    pending.response,
                    pending.created_at.isoformat(),
                    pending.status,
                ),
            )
            conn.commit()

    def get_pending_extractions(
        self, limit: int = 10, status: str = "pending"
    ) -> list[PendingExtraction]:
        """Get pending extractions for processing.

        Args:
            limit: Maximum number to return
            status: Filter by status

        Returns:
            List of pending extractions
        """
        with self._get_conn() as conn:
            cursor = conn.execute(
                """
                SELECT * FROM pending_extractions
                WHERE status = ?
                ORDER BY created_at ASC
                LIMIT ?
                """,
                (status, limit),
            )

            return [
                PendingExtraction(
                    id=row["id"],
                    user_id=row["user_id"],
                    query=row["query"],
                    response=row["response"],
                    created_at=datetime.fromisoformat(row["created_at"]),
                    status=row["status"],
                )
                for row in cursor
            ]

    def update_extraction_status(self, extraction_id: str, status: str) -> None:
        """Update the status of a pending extraction.

        Args:
            extraction_id: ID of the extraction
            status: New status
        """
        with self._get_conn() as conn:
            conn.execute(
                "UPDATE pending_extractions SET status = ? WHERE id = ?",
                (status, extraction_id),
            )
            conn.commit()

    def delete_extraction(self, extraction_id: str) -> None:
        """Delete a completed extraction.

        Args:
            extraction_id: ID of the extraction to delete
        """
        with self._get_conn() as conn:
            conn.execute(
                "DELETE FROM pending_extractions WHERE id = ?",
                (extraction_id,),
            )
            conn.commit()
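
A minimal usage sketch of the store above. The SQLiteMemoryStore and Memory APIs, the sanitized-search behavior, and the '_raw:' escape hatch come from this file; the database path and sample memories are illustrative.

# Hypothetical demo script; only the headroom.memory.store API is from the package.
from headroom.memory.store import Memory, SQLiteMemoryStore

store = SQLiteMemoryStore("demo_memory.db")

# Memories are isolated per user_id; other users won't see these.
store.save("alice", Memory(content="Prefers Python over Go", category="preference"))
store.save("alice", Memory(content="Deploys services on Kubernetes", category="context"))

# Sanitized mode: words are quoted and OR-ed, so "python language"
# matches any memory containing either term, BM25-ranked.
for m in store.search("alice", "python language"):
    print(m.category, m.content)

# Raw FTS5 syntax via the '_raw:' prefix, e.g. prefix matching:
for m in store.search("alice", "_raw:kube*"):
    print(m.content)

print(store.stats("alice"))  # e.g. {'total': 2, 'categories': {'preference': 1, 'context': 1}}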
headroom/memory/worker.py
ADDED
@@ -0,0 +1,260 @@
"""Background worker for batched memory extraction.

Collects conversations in a queue and processes them in batches,
reducing LLM calls and improving efficiency.
"""

from __future__ import annotations

import atexit
import logging
import threading
import time
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from headroom.memory.extractor import MemoryExtractor
    from headroom.memory.store import SQLiteMemoryStore


logger = logging.getLogger(__name__)


class ExtractionWorker:
    """Background worker that batches memory extractions.

    Features:
    - Collects conversations in a queue
    - Processes in batches (configurable size and timeout)
    - Persists pending work to SQLite for crash recovery
    - Thread-safe, daemon thread (stops with main program)

    Usage:
        worker = ExtractionWorker(store, extractor)
        worker.start()
        worker.schedule("alice", "I prefer Python", "Great choice!")
        # ... later, memories are extracted and saved automatically
    """

    def __init__(
        self,
        store: SQLiteMemoryStore,
        extractor: MemoryExtractor,
        batch_size: int = 10,
        max_wait_seconds: float = 30.0,
    ):
        """Initialize the worker.

        Args:
            store: Memory store for saving extracted memories
            extractor: Extractor for processing conversations
            batch_size: Max conversations per batch
            max_wait_seconds: Max time to wait before processing a partial batch
        """
        self.store = store
        self.extractor = extractor
        self.batch_size = batch_size
        self.max_wait_seconds = max_wait_seconds

        self._queue: list[tuple[str, str, str]] = []  # (user_id, query, response)
        self._lock = threading.Lock()
        self._event = threading.Event()
        self._running = False
        self._thread: threading.Thread | None = None

        # Register cleanup on exit
        atexit.register(self._cleanup)

    def start(self) -> None:
        """Start the background worker thread."""
        if self._running:
            return

        self._running = True
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

        # Process any pending extractions from previous runs (crash recovery)
        self._recover_pending()

    def stop(self, wait: bool = True, timeout: float = 5.0) -> None:
        """Stop the worker.

        Args:
            wait: If True, process remaining queue before stopping
            timeout: Max time to wait for remaining work
        """
        if not self._running:
            return

        self._running = False
        self._event.set()  # Wake up the thread

        if wait and self._thread:
            self._thread.join(timeout=timeout)

    def schedule(self, user_id: str, query: str, response: str) -> None:
        """Schedule a conversation for memory extraction.

        Non-blocking - returns immediately and extracts in the background.

        Args:
            user_id: User identifier
            query: User's message
            response: Assistant's response
        """
        # Persist to SQLite first (crash recovery)
        from headroom.memory.store import PendingExtraction

        pending = PendingExtraction(
            user_id=user_id,
            query=query,
            response=response,
        )
        self.store.queue_extraction(pending)

        # Add to in-memory queue
        with self._lock:
            self._queue.append((user_id, query, response))

            # Wake up worker if batch is full
            if len(self._queue) >= self.batch_size:
                self._event.set()

    def flush(self, timeout: float = 60.0) -> bool:
        """Force immediate processing of all queued extractions.

        Blocks until all pending extractions are processed or timeout.

        Args:
            timeout: Max time to wait in seconds

        Returns:
            True if all extractions completed, False if timed out
        """
        # Signal worker to process immediately by temporarily setting max_wait to 0
        original_max_wait = self.max_wait_seconds
        self.max_wait_seconds = 0
        self._event.set()

        # Wait for queue to empty
        start = time.time()
        while time.time() - start < timeout:
            pending = self.store.get_pending_extractions(limit=1, status="pending")
            if not pending:
                self.max_wait_seconds = original_max_wait
                return True
            time.sleep(0.5)

        self.max_wait_seconds = original_max_wait
        return False

    def _run(self) -> None:
        """Main worker loop."""
        last_process_time = time.time()

        while self._running:
            # Wait for batch to fill or timeout
            self._event.wait(timeout=1.0)
            self._event.clear()

            now = time.time()
            time_since_last = now - last_process_time

            with self._lock:
                should_process = len(self._queue) >= self.batch_size or (
                    self._queue and time_since_last >= self.max_wait_seconds
                )

                if should_process:
                    batch = self._queue[: self.batch_size]
                    self._queue = self._queue[self.batch_size :]
                else:
                    batch = []

            if batch:
                self._process_batch(batch)
                last_process_time = time.time()

        # Process remaining queue on shutdown
        with self._lock:
            remaining = self._queue[:]
            self._queue = []

        if remaining:
            self._process_batch(remaining)

    def _process_batch(self, batch: list[tuple[str, str, str]]) -> None:
        """Process a batch of conversations.

        Args:
            batch: List of (user_id, query, response) tuples
        """
        logger.debug(f"Processing batch of {len(batch)} conversations")

        try:
            # Extract memories
            result = self.extractor.extract_batch(batch)

            # Save memories
            for user_id, memories in result.items():
                for memory in memories:
                    self.store.save(user_id, memory)
                    logger.debug(f"Saved memory for {user_id}: {memory.content[:50]}...")

            # Mark pending extractions as done
            # Note: In a production system, we'd track exact IDs
            # For simplicity, we clear pending by matching user/query/response
            self._mark_batch_done(batch)

        except Exception as e:
            logger.error(f"Batch extraction failed: {e}")
            self._mark_batch_failed(batch)

    def _recover_pending(self) -> None:
        """Recover pending extractions from previous runs."""
        pending = self.store.get_pending_extractions(limit=100, status="pending")

        if not pending:
            return

        logger.info(f"Recovering {len(pending)} pending extractions")

        with self._lock:
            for p in pending:
                self._queue.append((p.user_id, p.query, p.response))

        # Trigger processing
        self._event.set()

    def _mark_batch_done(self, batch: list[tuple[str, str, str]]) -> None:
        """Mark batch items as completed in the pending table."""
        # Get pending extractions and mark matching ones as done
        pending = self.store.get_pending_extractions(limit=100)

        for user_id, query, response in batch:
            for p in pending:
                if p.user_id == user_id and p.query == query and p.response == response:
                    self.store.delete_extraction(p.id)
                    break

    def _mark_batch_failed(self, batch: list[tuple[str, str, str]]) -> None:
        """Mark batch items as failed in the pending table."""
        pending = self.store.get_pending_extractions(limit=100)

        for user_id, query, response in batch:
            for p in pending:
                if p.user_id == user_id and p.query == query and p.response == response:
                    self.store.update_extraction_status(p.id, "failed")
                    break

    def _cleanup(self) -> None:
        """Cleanup on program exit."""
        if self._running:
            self.stop(wait=True, timeout=2.0)

    @property
    def queue_size(self) -> int:
        """Get current queue size."""
        with self._lock:
            return len(self._queue)
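
A sketch of driving the worker end-to-end. The real MemoryExtractor lives in headroom/memory/extractor.py and is not shown in this diff, so a hypothetical stub stands in for it below; the stub only mimics the extract_batch() shape that _process_batch() relies on, a dict mapping user_id to a list of Memory objects.

# Hypothetical demo; StubExtractor is illustrative, not part of the package.
from headroom.memory.store import Memory, SQLiteMemoryStore
from headroom.memory.worker import ExtractionWorker


class StubExtractor:
    def extract_batch(self, batch: list[tuple[str, str, str]]) -> dict[str, list[Memory]]:
        # Turn each (user_id, query, response) pair into one crude memory.
        out: dict[str, list[Memory]] = {}
        for user_id, query, _response in batch:
            out.setdefault(user_id, []).append(Memory(content=f"User said: {query}"))
        return out


store = SQLiteMemoryStore("demo_memory.db")
worker = ExtractionWorker(store, StubExtractor(), batch_size=2, max_wait_seconds=5.0)
worker.start()

worker.schedule("alice", "I prefer Python", "Great choice!")
worker.schedule("alice", "I deploy on Kubernetes", "Noted.")  # fills the batch, wakes the worker

# flush() blocks until the pending_extractions table drains (or times out),
# which makes the batched pipeline testable without arbitrary sleeps.
worker.flush(timeout=10.0)
print([m.content for m in store.get_all("alice")])
worker.stop()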