tribalmemory 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tribalmemory/cli.py +147 -4
- tribalmemory/interfaces.py +44 -0
- tribalmemory/mcp/server.py +160 -14
- tribalmemory/server/app.py +53 -2
- tribalmemory/server/config.py +41 -0
- tribalmemory/server/models.py +65 -0
- tribalmemory/server/routes.py +68 -0
- tribalmemory/services/fts_store.py +255 -0
- tribalmemory/services/memory.py +193 -33
- tribalmemory/services/reranker.py +267 -0
- tribalmemory/services/session_store.py +412 -0
- tribalmemory/services/vector_store.py +86 -1
- {tribalmemory-0.1.1.dist-info → tribalmemory-0.2.0.dist-info}/METADATA +1 -1
- {tribalmemory-0.1.1.dist-info → tribalmemory-0.2.0.dist-info}/RECORD +18 -15
- {tribalmemory-0.1.1.dist-info → tribalmemory-0.2.0.dist-info}/WHEEL +0 -0
- {tribalmemory-0.1.1.dist-info → tribalmemory-0.2.0.dist-info}/entry_points.txt +0 -0
- {tribalmemory-0.1.1.dist-info → tribalmemory-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {tribalmemory-0.1.1.dist-info → tribalmemory-0.2.0.dist-info}/top_level.txt +0 -0
tribalmemory/services/memory.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
"""Tribal Memory Service - Main API for agents."""
|
|
2
2
|
|
|
3
|
+
import logging
|
|
3
4
|
import os
|
|
4
5
|
from datetime import datetime
|
|
6
|
+
from pathlib import Path
|
|
5
7
|
from typing import Optional
|
|
6
8
|
import uuid
|
|
7
9
|
|
|
@@ -15,6 +17,10 @@ from ..interfaces import (
|
|
|
15
17
|
StoreResult,
|
|
16
18
|
)
|
|
17
19
|
from .deduplication import SemanticDeduplicationService
|
|
20
|
+
from .fts_store import FTSStore, hybrid_merge
|
|
21
|
+
from .reranker import IReranker, NoopReranker, create_reranker
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
18
24
|
|
|
19
25
|
|
|
20
26
|
class TribalMemoryService(IMemoryService):
|
|
@@ -39,11 +45,25 @@ class TribalMemoryService(IMemoryService):
|
|
|
39
45
|
dedup_exact_threshold: float = 0.98,
|
|
40
46
|
dedup_near_threshold: float = 0.90,
|
|
41
47
|
auto_reject_duplicates: bool = True,
|
|
48
|
+
fts_store: Optional[FTSStore] = None,
|
|
49
|
+
hybrid_search: bool = True,
|
|
50
|
+
hybrid_vector_weight: float = 0.7,
|
|
51
|
+
hybrid_text_weight: float = 0.3,
|
|
52
|
+
hybrid_candidate_multiplier: int = 4,
|
|
53
|
+
reranker: Optional[IReranker] = None,
|
|
54
|
+
rerank_pool_multiplier: int = 2,
|
|
42
55
|
):
|
|
43
56
|
self.instance_id = instance_id
|
|
44
57
|
self.embedding_service = embedding_service
|
|
45
58
|
self.vector_store = vector_store
|
|
46
59
|
self.auto_reject_duplicates = auto_reject_duplicates
|
|
60
|
+
self.fts_store = fts_store
|
|
61
|
+
self.hybrid_search = hybrid_search and fts_store is not None
|
|
62
|
+
self.hybrid_vector_weight = hybrid_vector_weight
|
|
63
|
+
self.hybrid_text_weight = hybrid_text_weight
|
|
64
|
+
self.hybrid_candidate_multiplier = hybrid_candidate_multiplier
|
|
65
|
+
self.reranker = reranker or NoopReranker()
|
|
66
|
+
self.rerank_pool_multiplier = rerank_pool_multiplier
|
|
47
67
|
|
|
48
68
|
self.dedup_service = SemanticDeduplicationService(
|
|
49
69
|
vector_store=vector_store,
|
|
@@ -89,7 +109,16 @@ class TribalMemoryService(IMemoryService):
|
|
|
89
109
|
confidence=1.0,
|
|
90
110
|
)
|
|
91
111
|
|
|
92
|
-
|
|
112
|
+
result = await self.vector_store.store(entry)
|
|
113
|
+
|
|
114
|
+
# Index in FTS for hybrid search (best-effort; vector store is primary)
|
|
115
|
+
if result.success and self.fts_store:
|
|
116
|
+
try:
|
|
117
|
+
self.fts_store.index(entry.id, content, tags or [])
|
|
118
|
+
except Exception as e:
|
|
119
|
+
logger.warning("FTS indexing failed for %s: %s", entry.id, e)
|
|
120
|
+
|
|
121
|
+
return result
|
|
93
122
|
|
|
94
123
|
async def recall(
|
|
95
124
|
self,
|
|
@@ -98,7 +127,11 @@ class TribalMemoryService(IMemoryService):
|
|
|
98
127
|
min_relevance: float = 0.7,
|
|
99
128
|
tags: Optional[list[str]] = None,
|
|
100
129
|
) -> list[RecallResult]:
|
|
101
|
-
"""Recall relevant memories.
|
|
130
|
+
"""Recall relevant memories using hybrid search.
|
|
131
|
+
|
|
132
|
+
When hybrid search is enabled (FTS store available), combines
|
|
133
|
+
vector similarity with BM25 keyword matching for better results.
|
|
134
|
+
Falls back to vector-only search when FTS is unavailable.
|
|
102
135
|
|
|
103
136
|
Args:
|
|
104
137
|
query: Natural language query
|
|
@@ -112,7 +145,13 @@ class TribalMemoryService(IMemoryService):
|
|
|
112
145
|
return []
|
|
113
146
|
|
|
114
147
|
filters = {"tags": tags} if tags else None
|
|
148
|
+
|
|
149
|
+
if self.hybrid_search and self.fts_store:
|
|
150
|
+
return await self._hybrid_recall(
|
|
151
|
+
query, query_embedding, limit, min_relevance, filters
|
|
152
|
+
)
|
|
115
153
|
|
|
154
|
+
# Vector-only fallback
|
|
116
155
|
results = await self.vector_store.recall(
|
|
117
156
|
query_embedding,
|
|
118
157
|
limit=limit,
|
|
@@ -121,6 +160,90 @@ class TribalMemoryService(IMemoryService):
|
|
|
121
160
|
)
|
|
122
161
|
|
|
123
162
|
return self._filter_superseded(results)
|
|
163
|
+
|
|
164
|
+
async def _hybrid_recall(
    self,
    query: str,
    query_embedding: list[float],
    limit: int,
    min_relevance: float,
    filters: Optional[dict],
) -> list[RecallResult]:
    """Hybrid recall: vector + BM25 candidates merged, then reranked.

    Pipeline:
      1. Vector search over a widened candidate pool (relaxed threshold).
      2. BM25 keyword search via the FTS store.
      3. Weighted merge of the two score lists (``hybrid_merge``).
      4. Fetch full entries for BM25-only hits (best-effort).
      5. Rerank the pooled candidates and drop superseded memories.

    Args:
        query: Natural language query.
        query_embedding: Embedding vector for ``query``.
        limit: Maximum number of results to return.
        min_relevance: Caller's relevance floor; candidates are gathered
            at half this threshold so the merge step has material to
            work with, and BM25-only hits are re-checked against it.
        filters: Optional vector-store filters (e.g. ``{"tags": [...]}``).

    Returns:
        Up to ``limit`` reranked results with superseded memories removed.
    """
    import asyncio  # local import: only needed on the hybrid path

    candidate_limit = limit * self.hybrid_candidate_multiplier

    # 1. Vector search — wide candidate pool with a relaxed threshold.
    vector_results = await self.vector_store.recall(
        query_embedding,
        limit=candidate_limit,
        min_similarity=min_relevance * 0.5,  # lower bar for candidates
        filters=filters,
    )

    # 2. BM25 keyword search over the same pool size.
    bm25_results = self.fts_store.search(query, limit=candidate_limit)

    # 3./4. Weighted merge of vector and BM25 scores by memory id.
    vector_for_merge = [
        {"id": r.memory.id, "score": r.similarity_score}
        for r in vector_results
    ]
    merged = hybrid_merge(
        vector_for_merge,
        bm25_results,
        self.hybrid_vector_weight,
        self.hybrid_text_weight,
    )

    # Full RecallResult objects are already cached for vector hits;
    # BM25-only hits must be fetched from the vector store by id.
    entry_map = {r.memory.id: r for r in vector_results}
    rerank_pool_size = min(limit * self.rerank_pool_multiplier, len(merged))

    cached_hits: list[tuple[dict, RecallResult]] = []
    bm25_only: list[dict] = []
    for m in merged[:rerank_pool_size]:
        if m["id"] in entry_map:
            cached_hits.append((m, entry_map[m["id"]]))
        else:
            bm25_only.append(m)

    # Batch-fetch BM25-only hits concurrently. Fetches are best-effort:
    # a single failing get() degrades the result set instead of failing
    # the whole recall (consistent with the best-effort FTS handling in
    # store()/forget()). Previously a lone exception from gather()
    # aborted the entire hybrid recall.
    fetched_entries: list = []
    if bm25_only:
        fetched = await asyncio.gather(
            *(self.vector_store.get(m["id"]) for m in bm25_only),
            return_exceptions=True,
        )
        for m, res in zip(bm25_only, fetched):
            if isinstance(res, BaseException):
                logger.warning(
                    "Fetch failed for BM25 hit %s: %s", m["id"], res
                )
                fetched_entries.append(None)
            else:
                fetched_entries.append(res)

    # Assemble rerank candidates carrying the merged ("final") scores.
    candidates: list[RecallResult] = []
    for m, recall_result in cached_hits:
        candidates.append(RecallResult(
            memory=recall_result.memory,
            similarity_score=m["final_score"],
            retrieval_time_ms=recall_result.retrieval_time_ms,
        ))
    for m, entry in zip(bm25_only, fetched_entries):
        # Apply the same relaxed floor used for vector candidates.
        if entry and m["final_score"] >= min_relevance * 0.5:
            candidates.append(RecallResult(
                memory=entry,
                similarity_score=m["final_score"],
                retrieval_time_ms=0,
            ))

    # 5. Rerank and drop memories superseded by corrections in the set.
    reranked = self.reranker.rerank(query, candidates, top_k=limit)
    return self._filter_superseded(reranked)
|
|
124
247
|
|
|
125
248
|
async def correct(
|
|
126
249
|
self,
|
|
@@ -157,7 +280,13 @@ class TribalMemoryService(IMemoryService):
|
|
|
157
280
|
|
|
158
281
|
async def forget(self, memory_id: str) -> bool:
    """Forget (soft delete) a memory.

    The vector store is authoritative; removal from the FTS index is
    best-effort, so a failure there is logged rather than raised.

    Args:
        memory_id: Identifier of the memory to delete.

    Returns:
        Whatever the vector store reports for the delete.
    """
    deleted = await self.vector_store.delete(memory_id)
    if not (deleted and self.fts_store):
        return deleted
    try:
        self.fts_store.delete(memory_id)
    except Exception as exc:  # FTS is a secondary index — never fatal
        logger.warning("FTS cleanup failed for %s: %s", memory_id, exc)
    return deleted
|
|
161
290
|
|
|
162
291
|
async def get(self, memory_id: str) -> Optional[MemoryEntry]:
|
|
163
292
|
"""Get a memory by ID with full provenance."""
|
|
@@ -165,40 +294,17 @@ class TribalMemoryService(IMemoryService):
|
|
|
165
294
|
|
|
166
295
|
async def get_stats(self) -> dict:
    """Return aggregate memory statistics.

    All aggregation is delegated to ``vector_store.get_stats()``, which
    computes the counts efficiently (paginated by default, with native
    queries for SQL-backed stores).
    """
    stats = await self.vector_store.get_stats()
    return stats
|
|
197
303
|
|
|
198
304
|
@staticmethod
|
|
199
305
|
def _filter_superseded(results: list[RecallResult]) -> list[RecallResult]:
|
|
200
306
|
"""Remove memories that are superseded by corrections in the result set."""
|
|
201
|
-
superseded_ids = {
|
|
307
|
+
superseded_ids: set[str] = {
|
|
202
308
|
r.memory.supersedes for r in results if r.memory.supersedes
|
|
203
309
|
}
|
|
204
310
|
if not superseded_ids:
|
|
@@ -213,6 +319,14 @@ def create_memory_service(
|
|
|
213
319
|
api_base: Optional[str] = None,
|
|
214
320
|
embedding_model: Optional[str] = None,
|
|
215
321
|
embedding_dimensions: Optional[int] = None,
|
|
322
|
+
hybrid_search: bool = True,
|
|
323
|
+
hybrid_vector_weight: float = 0.7,
|
|
324
|
+
hybrid_text_weight: float = 0.3,
|
|
325
|
+
hybrid_candidate_multiplier: int = 4,
|
|
326
|
+
reranking: str = "heuristic",
|
|
327
|
+
recency_decay_days: float = 30.0,
|
|
328
|
+
tag_boost_weight: float = 0.1,
|
|
329
|
+
rerank_pool_multiplier: int = 2,
|
|
216
330
|
) -> TribalMemoryService:
|
|
217
331
|
"""Factory function to create a memory service with sensible defaults.
|
|
218
332
|
|
|
@@ -225,6 +339,18 @@ def create_memory_service(
|
|
|
225
339
|
For Ollama: "http://localhost:11434/v1"
|
|
226
340
|
embedding_model: Embedding model name. Default: "text-embedding-3-small".
|
|
227
341
|
embedding_dimensions: Embedding output dimensions. Default: 1536.
|
|
342
|
+
hybrid_search: Enable BM25 hybrid search (default: True).
|
|
343
|
+
hybrid_vector_weight: Weight for vector similarity (default: 0.7).
|
|
344
|
+
hybrid_text_weight: Weight for BM25 text score (default: 0.3).
|
|
345
|
+
hybrid_candidate_multiplier: Multiplier for candidate pool size
|
|
346
|
+
(default: 4). Retrieves 4× limit from each source before
|
|
347
|
+
merging.
|
|
348
|
+
reranking: Reranking mode: "auto", "cross-encoder", "heuristic", "none"
|
|
349
|
+
(default: "heuristic").
|
|
350
|
+
recency_decay_days: Half-life for recency boost (default: 30.0).
|
|
351
|
+
tag_boost_weight: Weight for tag match boost (default: 0.1).
|
|
352
|
+
rerank_pool_multiplier: How many candidates to give the reranker
|
|
353
|
+
(N × limit). Default: 2.
|
|
228
354
|
|
|
229
355
|
Returns:
|
|
230
356
|
Configured TribalMemoryService ready for use.
|
|
@@ -267,9 +393,43 @@ def create_memory_service(
|
|
|
267
393
|
vector_store = InMemoryVectorStore(embedding_service)
|
|
268
394
|
else:
|
|
269
395
|
vector_store = InMemoryVectorStore(embedding_service)
|
|
396
|
+
|
|
397
|
+
# Create FTS store for hybrid search (co-located with LanceDB)
|
|
398
|
+
fts_store = None
|
|
399
|
+
if hybrid_search and db_path:
|
|
400
|
+
try:
|
|
401
|
+
fts_db_path = str(Path(db_path) / "fts_index.db")
|
|
402
|
+
fts_store = FTSStore(fts_db_path)
|
|
403
|
+
if fts_store.is_available():
|
|
404
|
+
logger.info("Hybrid search enabled (SQLite FTS5)")
|
|
405
|
+
else:
|
|
406
|
+
logger.warning(
|
|
407
|
+
"FTS5 not available in SQLite build. "
|
|
408
|
+
"Hybrid search disabled, using vector-only."
|
|
409
|
+
)
|
|
410
|
+
fts_store = None
|
|
411
|
+
except Exception as e:
|
|
412
|
+
logger.warning(f"FTS store init failed: {e}. Using vector-only.")
|
|
413
|
+
fts_store = None
|
|
414
|
+
|
|
415
|
+
# Create reranker
|
|
416
|
+
from ..server.config import SearchConfig
|
|
417
|
+
search_config = SearchConfig(
|
|
418
|
+
reranking=reranking,
|
|
419
|
+
recency_decay_days=recency_decay_days,
|
|
420
|
+
tag_boost_weight=tag_boost_weight,
|
|
421
|
+
)
|
|
422
|
+
reranker = create_reranker(search_config)
|
|
270
423
|
|
|
271
424
|
return TribalMemoryService(
|
|
272
425
|
instance_id=instance_id,
|
|
273
426
|
embedding_service=embedding_service,
|
|
274
|
-
vector_store=vector_store
|
|
427
|
+
vector_store=vector_store,
|
|
428
|
+
fts_store=fts_store,
|
|
429
|
+
hybrid_search=hybrid_search,
|
|
430
|
+
hybrid_vector_weight=hybrid_vector_weight,
|
|
431
|
+
hybrid_text_weight=hybrid_text_weight,
|
|
432
|
+
hybrid_candidate_multiplier=hybrid_candidate_multiplier,
|
|
433
|
+
reranker=reranker,
|
|
434
|
+
rerank_pool_multiplier=rerank_pool_multiplier,
|
|
275
435
|
)
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""Result reranking for improved retrieval quality.
|
|
2
|
+
|
|
3
|
+
Provides multiple reranking strategies:
|
|
4
|
+
- NoopReranker: Pass-through, no reranking
|
|
5
|
+
- HeuristicReranker: Fast heuristic scoring (recency, tags, length)
|
|
6
|
+
- CrossEncoderReranker: Model-based reranking (sentence-transformers)
|
|
7
|
+
|
|
8
|
+
Reranking happens after initial retrieval (vector + BM25) to refine ordering.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import math
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from typing import TYPE_CHECKING, Protocol
|
|
15
|
+
|
|
16
|
+
from ..interfaces import RecallResult
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from ..server.config import SearchConfig
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# Lazy import for optional dependency
|
|
24
|
+
CROSS_ENCODER_AVAILABLE = False
|
|
25
|
+
CrossEncoder = None
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
from sentence_transformers import CrossEncoder as _CrossEncoder
|
|
29
|
+
CrossEncoder = _CrossEncoder
|
|
30
|
+
CROSS_ENCODER_AVAILABLE = True
|
|
31
|
+
except ImportError:
|
|
32
|
+
pass
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class IReranker(Protocol):
    """Structural (duck-typed) interface for result reranking.

    Any object exposing a matching ``rerank`` method satisfies this
    Protocol — no inheritance is required.
    """

    def rerank(
        self, query: str, candidates: list[RecallResult], top_k: int
    ) -> list[RecallResult]:
        """Reorder candidates by relevance to *query* and truncate.

        Args:
            query: Original search query
            candidates: Initial retrieval results to reorder
            top_k: Maximum number of results to return

        Returns:
            Reranked results (up to top_k)
        """
        ...
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class NoopReranker:
|
|
55
|
+
"""Pass-through reranker (no reranking)."""
|
|
56
|
+
|
|
57
|
+
def rerank(
|
|
58
|
+
self, query: str, candidates: list[RecallResult], top_k: int
|
|
59
|
+
) -> list[RecallResult]:
|
|
60
|
+
"""Return top_k candidates unchanged."""
|
|
61
|
+
return candidates[:top_k]
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class HeuristicReranker:
|
|
65
|
+
"""Heuristic reranking with recency, tag match, and length signals.
|
|
66
|
+
|
|
67
|
+
Combines multiple quality signals:
|
|
68
|
+
- Recency: newer memories score higher (exponential decay)
|
|
69
|
+
- Tag match: query terms matching tags boost score
|
|
70
|
+
- Length penalty: very short or very long content penalized
|
|
71
|
+
|
|
72
|
+
Final score: original_score * (1 + boost_sum)
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
def __init__(
|
|
76
|
+
self,
|
|
77
|
+
recency_decay_days: float = 30.0,
|
|
78
|
+
tag_boost_weight: float = 0.1,
|
|
79
|
+
min_length: int = 10,
|
|
80
|
+
max_length: int = 2000,
|
|
81
|
+
short_penalty: float = 0.05,
|
|
82
|
+
long_penalty: float = 0.03,
|
|
83
|
+
):
|
|
84
|
+
"""Initialize heuristic reranker.
|
|
85
|
+
|
|
86
|
+
Args:
|
|
87
|
+
recency_decay_days: Half-life for recency boost (days)
|
|
88
|
+
tag_boost_weight: Weight for tag match boost
|
|
89
|
+
min_length: Content shorter than this gets penalty
|
|
90
|
+
max_length: Content longer than this gets penalty
|
|
91
|
+
short_penalty: Penalty for content shorter than min_length
|
|
92
|
+
long_penalty: Penalty for content longer than max_length
|
|
93
|
+
"""
|
|
94
|
+
self.recency_decay_days = recency_decay_days
|
|
95
|
+
self.tag_boost_weight = tag_boost_weight
|
|
96
|
+
self.min_length = min_length
|
|
97
|
+
self.max_length = max_length
|
|
98
|
+
self.short_penalty = short_penalty
|
|
99
|
+
self.long_penalty = long_penalty
|
|
100
|
+
|
|
101
|
+
def rerank(
|
|
102
|
+
self, query: str, candidates: list[RecallResult], top_k: int
|
|
103
|
+
) -> list[RecallResult]:
|
|
104
|
+
"""Rerank using heuristic scoring."""
|
|
105
|
+
if not candidates:
|
|
106
|
+
return []
|
|
107
|
+
|
|
108
|
+
# Compute boost for each candidate
|
|
109
|
+
scored = []
|
|
110
|
+
query_lower = query.lower()
|
|
111
|
+
query_terms = set(query_lower.split())
|
|
112
|
+
now = datetime.utcnow()
|
|
113
|
+
|
|
114
|
+
for i, candidate in enumerate(candidates):
|
|
115
|
+
boost = 0.0
|
|
116
|
+
|
|
117
|
+
# Recency boost (exponential decay)
|
|
118
|
+
# Brand new memory (age=0) gets boost of 1.0, older memories decay exponentially
|
|
119
|
+
age_days = (now - candidate.memory.created_at).total_seconds() / 86400
|
|
120
|
+
recency_boost = math.exp(-age_days / self.recency_decay_days)
|
|
121
|
+
boost += recency_boost
|
|
122
|
+
|
|
123
|
+
# Tag match boost (exact term matching, not substring)
|
|
124
|
+
if candidate.memory.tags:
|
|
125
|
+
tag_lower = set(t.lower() for t in candidate.memory.tags)
|
|
126
|
+
# Count query terms that exactly match tags
|
|
127
|
+
matches = sum(1 for term in query_terms if term in tag_lower)
|
|
128
|
+
if matches > 0:
|
|
129
|
+
boost += self.tag_boost_weight * matches
|
|
130
|
+
|
|
131
|
+
# Length penalty
|
|
132
|
+
content_length = len(candidate.memory.content)
|
|
133
|
+
if content_length < self.min_length:
|
|
134
|
+
boost -= self.short_penalty
|
|
135
|
+
elif content_length > self.max_length:
|
|
136
|
+
boost -= self.long_penalty
|
|
137
|
+
|
|
138
|
+
# Combine with original score
|
|
139
|
+
final_score = candidate.similarity_score * (1.0 + boost)
|
|
140
|
+
|
|
141
|
+
scored.append((final_score, i, candidate))
|
|
142
|
+
|
|
143
|
+
# Sort by final score (desc), then original index (preserve order on ties)
|
|
144
|
+
scored.sort(key=lambda x: (-x[0], x[1]))
|
|
145
|
+
|
|
146
|
+
# Build reranked results with updated scores
|
|
147
|
+
reranked = []
|
|
148
|
+
for final_score, _, candidate in scored[:top_k]:
|
|
149
|
+
reranked.append(
|
|
150
|
+
RecallResult(
|
|
151
|
+
memory=candidate.memory,
|
|
152
|
+
similarity_score=final_score,
|
|
153
|
+
retrieval_time_ms=candidate.retrieval_time_ms,
|
|
154
|
+
)
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
return reranked
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class CrossEncoderReranker:
    """Model-based reranking via a sentence-transformers cross-encoder.

    Scores (query, candidate) pairs directly with the model, which
    typically produces a better ordering than retrieval scores alone.

    Requires sentence-transformers package.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Load the cross-encoder model.

        Args:
            model_name: HuggingFace model name

        Raises:
            ImportError: If sentence-transformers not installed
        """
        if not CROSS_ENCODER_AVAILABLE:
            raise ImportError(
                "sentence-transformers required for CrossEncoderReranker. "
                "Install with: pip install sentence-transformers"
            )

        logger.info(f"Loading cross-encoder model: {model_name}")
        self.model = CrossEncoder(model_name)

    def rerank(
        self, query: str, candidates: list[RecallResult], top_k: int
    ) -> list[RecallResult]:
        """Score and reorder candidates with the cross-encoder model."""
        if not candidates:
            logger.debug("No candidates to rerank")
            return []

        # Model scores each (query, content) pair directly.
        pair_scores = self.model.predict(
            [(query, cand.memory.content) for cand in candidates]
        )

        # Highest model score first; sorted() is stable, so ties keep
        # their original retrieval order.
        ranked = sorted(
            zip(pair_scores, candidates),
            key=lambda pair: pair[0],
            reverse=True,
        )

        # Carry the model scores into fresh result objects.
        return [
            RecallResult(
                memory=cand.memory,
                similarity_score=float(score),
                retrieval_time_ms=cand.retrieval_time_ms,
            )
            for score, cand in ranked[:top_k]
        ]
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def create_reranker(config: "SearchConfig") -> IReranker:
    """Factory function to create reranker from config.

    Args:
        config: SearchConfig with reranking settings

    Returns:
        Configured reranker instance

    Raises:
        ValueError: If reranking mode is invalid
        ImportError: If cross-encoder requested but unavailable
    """
    mode = getattr(config, "reranking", "heuristic")

    def _build_heuristic() -> HeuristicReranker:
        # Shared construction for "heuristic" mode and the "auto" fallback.
        return HeuristicReranker(
            recency_decay_days=getattr(config, "recency_decay_days", 30.0),
            tag_boost_weight=getattr(config, "tag_boost_weight", 0.1),
        )

    if mode == "none":
        return NoopReranker()

    if mode == "heuristic":
        return _build_heuristic()

    if mode == "cross-encoder":
        if not CROSS_ENCODER_AVAILABLE:
            raise ImportError(
                "Cross-encoder reranking requires sentence-transformers. "
                "Install with: pip install sentence-transformers"
            )
        return CrossEncoderReranker()

    if mode == "auto":
        # Prefer the cross-encoder when installed; fall back to heuristic
        # on any initialization failure.
        if CROSS_ENCODER_AVAILABLE:
            try:
                return CrossEncoderReranker()
            except Exception as e:
                logger.warning(f"Cross-encoder init failed: {e}. Falling back to heuristic.")
        return _build_heuristic()

    raise ValueError(
        f"Unknown reranking mode: {mode}. "
        f"Valid options: 'none', 'heuristic', 'cross-encoder', 'auto'"
    )
|