tribalmemory 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,9 @@
1
1
  """Tribal Memory Service - Main API for agents."""
2
2
 
3
+ import logging
3
4
  import os
4
5
  from datetime import datetime
6
+ from pathlib import Path
5
7
  from typing import Optional
6
8
  import uuid
7
9
 
@@ -15,6 +17,10 @@ from ..interfaces import (
15
17
  StoreResult,
16
18
  )
17
19
  from .deduplication import SemanticDeduplicationService
20
+ from .fts_store import FTSStore, hybrid_merge
21
+ from .reranker import IReranker, NoopReranker, create_reranker
22
+
23
+ logger = logging.getLogger(__name__)
18
24
 
19
25
 
20
26
  class TribalMemoryService(IMemoryService):
@@ -39,11 +45,25 @@ class TribalMemoryService(IMemoryService):
39
45
  dedup_exact_threshold: float = 0.98,
40
46
  dedup_near_threshold: float = 0.90,
41
47
  auto_reject_duplicates: bool = True,
48
+ fts_store: Optional[FTSStore] = None,
49
+ hybrid_search: bool = True,
50
+ hybrid_vector_weight: float = 0.7,
51
+ hybrid_text_weight: float = 0.3,
52
+ hybrid_candidate_multiplier: int = 4,
53
+ reranker: Optional[IReranker] = None,
54
+ rerank_pool_multiplier: int = 2,
42
55
  ):
43
56
  self.instance_id = instance_id
44
57
  self.embedding_service = embedding_service
45
58
  self.vector_store = vector_store
46
59
  self.auto_reject_duplicates = auto_reject_duplicates
60
+ self.fts_store = fts_store
61
+ self.hybrid_search = hybrid_search and fts_store is not None
62
+ self.hybrid_vector_weight = hybrid_vector_weight
63
+ self.hybrid_text_weight = hybrid_text_weight
64
+ self.hybrid_candidate_multiplier = hybrid_candidate_multiplier
65
+ self.reranker = reranker or NoopReranker()
66
+ self.rerank_pool_multiplier = rerank_pool_multiplier
47
67
 
48
68
  self.dedup_service = SemanticDeduplicationService(
49
69
  vector_store=vector_store,
@@ -89,7 +109,16 @@ class TribalMemoryService(IMemoryService):
89
109
  confidence=1.0,
90
110
  )
91
111
 
92
- return await self.vector_store.store(entry)
112
+ result = await self.vector_store.store(entry)
113
+
114
+ # Index in FTS for hybrid search (best-effort; vector store is primary)
115
+ if result.success and self.fts_store:
116
+ try:
117
+ self.fts_store.index(entry.id, content, tags or [])
118
+ except Exception as e:
119
+ logger.warning("FTS indexing failed for %s: %s", entry.id, e)
120
+
121
+ return result
93
122
 
94
123
  async def recall(
95
124
  self,
@@ -98,7 +127,11 @@ class TribalMemoryService(IMemoryService):
98
127
  min_relevance: float = 0.7,
99
128
  tags: Optional[list[str]] = None,
100
129
  ) -> list[RecallResult]:
101
- """Recall relevant memories.
130
+ """Recall relevant memories using hybrid search.
131
+
132
+ When hybrid search is enabled (FTS store available), combines
133
+ vector similarity with BM25 keyword matching for better results.
134
+ Falls back to vector-only search when FTS is unavailable.
102
135
 
103
136
  Args:
104
137
  query: Natural language query
@@ -112,7 +145,13 @@ class TribalMemoryService(IMemoryService):
112
145
  return []
113
146
 
114
147
  filters = {"tags": tags} if tags else None
148
+
149
+ if self.hybrid_search and self.fts_store:
150
+ return await self._hybrid_recall(
151
+ query, query_embedding, limit, min_relevance, filters
152
+ )
115
153
 
154
+ # Vector-only fallback
116
155
  results = await self.vector_store.recall(
117
156
  query_embedding,
118
157
  limit=limit,
@@ -121,6 +160,90 @@ class TribalMemoryService(IMemoryService):
121
160
  )
122
161
 
123
162
  return self._filter_superseded(results)
163
+
164
+ async def _hybrid_recall(
165
+ self,
166
+ query: str,
167
+ query_embedding: list[float],
168
+ limit: int,
169
+ min_relevance: float,
170
+ filters: Optional[dict],
171
+ ) -> list[RecallResult]:
172
+ """Hybrid recall: vector + BM25 combined, then reranked."""
173
+ candidate_limit = limit * self.hybrid_candidate_multiplier
174
+
175
+ # 1. Vector search — get wide candidate pool
176
+ vector_results = await self.vector_store.recall(
177
+ query_embedding,
178
+ limit=candidate_limit,
179
+ min_similarity=min_relevance * 0.5, # Lower threshold for candidates
180
+ filters=filters,
181
+ )
182
+
183
+ # 2. BM25 search
184
+ bm25_results = self.fts_store.search(query, limit=candidate_limit)
185
+
186
+ # 3. Build lookup for vector results
187
+ vector_for_merge = [
188
+ {"id": r.memory.id, "score": r.similarity_score}
189
+ for r in vector_results
190
+ ]
191
+
192
+ # 4. Hybrid merge
193
+ merged = hybrid_merge(
194
+ vector_for_merge,
195
+ bm25_results,
196
+ self.hybrid_vector_weight,
197
+ self.hybrid_text_weight,
198
+ )
199
+
200
+ # 5. Build candidate results for reranking — need full MemoryEntry for each
201
+ # Create lookup from vector results
202
+ entry_map = {r.memory.id: r for r in vector_results}
203
+
204
+ # Get rerank_pool_multiplier * limit candidates before reranking
205
+ rerank_pool_size = min(limit * self.rerank_pool_multiplier, len(merged))
206
+
207
+ # Separate cached (vector) hits from BM25-only hits that need fetching
208
+ cached_hits: list[tuple[dict, RecallResult]] = []
209
+ bm25_only_ids: list[dict] = []
210
+
211
+ for m in merged[:rerank_pool_size]:
212
+ if m["id"] in entry_map:
213
+ cached_hits.append((m, entry_map[m["id"]]))
214
+ else:
215
+ bm25_only_ids.append(m)
216
+
217
+ # Batch-fetch BM25-only hits concurrently
218
+ import asyncio
219
+ fetched_entries = await asyncio.gather(
220
+ *(self.vector_store.get(m["id"]) for m in bm25_only_ids)
221
+ ) if bm25_only_ids else []
222
+
223
+ # Build candidate list
224
+ candidates: list[RecallResult] = []
225
+
226
+ # Add cached vector hits
227
+ for m, recall_result in cached_hits:
228
+ candidates.append(RecallResult(
229
+ memory=recall_result.memory,
230
+ similarity_score=m["final_score"],
231
+ retrieval_time_ms=recall_result.retrieval_time_ms,
232
+ ))
233
+
234
+ # Add fetched BM25-only hits
235
+ for m, entry in zip(bm25_only_ids, fetched_entries):
236
+ if entry and m["final_score"] >= min_relevance * 0.5:
237
+ candidates.append(RecallResult(
238
+ memory=entry,
239
+ similarity_score=m["final_score"],
240
+ retrieval_time_ms=0,
241
+ ))
242
+
243
+ # 6. Rerank candidates
244
+ reranked = self.reranker.rerank(query, candidates, top_k=limit)
245
+
246
+ return self._filter_superseded(reranked)
124
247
 
125
248
  async def correct(
126
249
  self,
@@ -157,7 +280,13 @@ class TribalMemoryService(IMemoryService):
157
280
 
158
281
  async def forget(self, memory_id: str) -> bool:
159
282
  """Forget (soft delete) a memory."""
160
- return await self.vector_store.delete(memory_id)
283
+ result = await self.vector_store.delete(memory_id)
284
+ if result and self.fts_store:
285
+ try:
286
+ self.fts_store.delete(memory_id)
287
+ except Exception as e:
288
+ logger.warning("FTS cleanup failed for %s: %s", memory_id, e)
289
+ return result
161
290
 
162
291
  async def get(self, memory_id: str) -> Optional[MemoryEntry]:
163
292
  """Get a memory by ID with full provenance."""
@@ -165,40 +294,17 @@ class TribalMemoryService(IMemoryService):
165
294
 
166
295
  async def get_stats(self) -> dict:
167
296
  """Get memory statistics.
168
-
169
- Note: Stats are computed over up to 10,000 most recent memories.
170
- For systems with >10k memories, consider using count() with filters.
297
+
298
+ Delegates to vector_store.get_stats() which computes aggregates
299
+ efficiently (paginated by default, native queries for SQL-backed
300
+ stores).
171
301
  """
172
- all_memories = await self.vector_store.list(limit=10000)
173
-
174
- by_source: dict[str, int] = {}
175
- by_instance: dict[str, int] = {}
176
- by_tag: dict[str, int] = {}
177
-
178
- for m in all_memories:
179
- source = m.source_type.value
180
- by_source[source] = by_source.get(source, 0) + 1
181
-
182
- instance = m.source_instance
183
- by_instance[instance] = by_instance.get(instance, 0) + 1
184
-
185
- for tag in m.tags:
186
- by_tag[tag] = by_tag.get(tag, 0) + 1
187
-
188
- corrections = sum(1 for m in all_memories if m.supersedes)
189
-
190
- return {
191
- "total_memories": len(all_memories),
192
- "by_source_type": by_source,
193
- "by_tag": by_tag,
194
- "by_instance": by_instance,
195
- "corrections": corrections,
196
- }
302
+ return await self.vector_store.get_stats()
197
303
 
198
304
  @staticmethod
199
305
  def _filter_superseded(results: list[RecallResult]) -> list[RecallResult]:
200
306
  """Remove memories that are superseded by corrections in the result set."""
201
- superseded_ids = {
307
+ superseded_ids: set[str] = {
202
308
  r.memory.supersedes for r in results if r.memory.supersedes
203
309
  }
204
310
  if not superseded_ids:
@@ -213,6 +319,14 @@ def create_memory_service(
213
319
  api_base: Optional[str] = None,
214
320
  embedding_model: Optional[str] = None,
215
321
  embedding_dimensions: Optional[int] = None,
322
+ hybrid_search: bool = True,
323
+ hybrid_vector_weight: float = 0.7,
324
+ hybrid_text_weight: float = 0.3,
325
+ hybrid_candidate_multiplier: int = 4,
326
+ reranking: str = "heuristic",
327
+ recency_decay_days: float = 30.0,
328
+ tag_boost_weight: float = 0.1,
329
+ rerank_pool_multiplier: int = 2,
216
330
  ) -> TribalMemoryService:
217
331
  """Factory function to create a memory service with sensible defaults.
218
332
 
@@ -225,6 +339,18 @@ def create_memory_service(
225
339
  For Ollama: "http://localhost:11434/v1"
226
340
  embedding_model: Embedding model name. Default: "text-embedding-3-small".
227
341
  embedding_dimensions: Embedding output dimensions. Default: 1536.
342
+ hybrid_search: Enable BM25 hybrid search (default: True).
343
+ hybrid_vector_weight: Weight for vector similarity (default: 0.7).
344
+ hybrid_text_weight: Weight for BM25 text score (default: 0.3).
345
+ hybrid_candidate_multiplier: Multiplier for candidate pool size
346
+ (default: 4). Retrieves 4× limit from each source before
347
+ merging.
348
+ reranking: Reranking mode: "auto", "cross-encoder", "heuristic", "none"
349
+ (default: "heuristic").
350
+ recency_decay_days: Half-life for recency boost (default: 30.0).
351
+ tag_boost_weight: Weight for tag match boost (default: 0.1).
352
+ rerank_pool_multiplier: How many candidates to give the reranker
353
+ (N × limit). Default: 2.
228
354
 
229
355
  Returns:
230
356
  Configured TribalMemoryService ready for use.
@@ -267,9 +393,43 @@ def create_memory_service(
267
393
  vector_store = InMemoryVectorStore(embedding_service)
268
394
  else:
269
395
  vector_store = InMemoryVectorStore(embedding_service)
396
+
397
+ # Create FTS store for hybrid search (co-located with LanceDB)
398
+ fts_store = None
399
+ if hybrid_search and db_path:
400
+ try:
401
+ fts_db_path = str(Path(db_path) / "fts_index.db")
402
+ fts_store = FTSStore(fts_db_path)
403
+ if fts_store.is_available():
404
+ logger.info("Hybrid search enabled (SQLite FTS5)")
405
+ else:
406
+ logger.warning(
407
+ "FTS5 not available in SQLite build. "
408
+ "Hybrid search disabled, using vector-only."
409
+ )
410
+ fts_store = None
411
+ except Exception as e:
412
+ logger.warning(f"FTS store init failed: {e}. Using vector-only.")
413
+ fts_store = None
414
+
415
+ # Create reranker
416
+ from ..server.config import SearchConfig
417
+ search_config = SearchConfig(
418
+ reranking=reranking,
419
+ recency_decay_days=recency_decay_days,
420
+ tag_boost_weight=tag_boost_weight,
421
+ )
422
+ reranker = create_reranker(search_config)
270
423
 
271
424
  return TribalMemoryService(
272
425
  instance_id=instance_id,
273
426
  embedding_service=embedding_service,
274
- vector_store=vector_store
427
+ vector_store=vector_store,
428
+ fts_store=fts_store,
429
+ hybrid_search=hybrid_search,
430
+ hybrid_vector_weight=hybrid_vector_weight,
431
+ hybrid_text_weight=hybrid_text_weight,
432
+ hybrid_candidate_multiplier=hybrid_candidate_multiplier,
433
+ reranker=reranker,
434
+ rerank_pool_multiplier=rerank_pool_multiplier,
275
435
  )
@@ -0,0 +1,267 @@
1
+ """Result reranking for improved retrieval quality.
2
+
3
+ Provides multiple reranking strategies:
4
+ - NoopReranker: Pass-through, no reranking
5
+ - HeuristicReranker: Fast heuristic scoring (recency, tags, length)
6
+ - CrossEncoderReranker: Model-based reranking (sentence-transformers)
7
+
8
+ Reranking happens after initial retrieval (vector + BM25) to refine ordering.
9
+ """
10
+
11
+ import logging
12
+ import math
13
+ from datetime import datetime
14
+ from typing import TYPE_CHECKING, Protocol
15
+
16
+ from ..interfaces import RecallResult
17
+
18
+ if TYPE_CHECKING:
19
+ from ..server.config import SearchConfig
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # Lazy import for optional dependency
24
+ CROSS_ENCODER_AVAILABLE = False
25
+ CrossEncoder = None
26
+
27
+ try:
28
+ from sentence_transformers import CrossEncoder as _CrossEncoder
29
+ CrossEncoder = _CrossEncoder
30
+ CROSS_ENCODER_AVAILABLE = True
31
+ except ImportError:
32
+ pass
33
+
34
+
35
class IReranker(Protocol):
    """Structural interface implemented by all result rerankers."""

    def rerank(
        self, query: str, candidates: list[RecallResult], top_k: int
    ) -> list[RecallResult]:
        """Reorder ``candidates`` by relevance to ``query``.

        Args:
            query: Original search query.
            candidates: Initial retrieval results to reorder.
            top_k: Maximum number of results to return.

        Returns:
            The reranked results, at most ``top_k`` of them.
        """
        ...
52
+
53
+
54
class NoopReranker:
    """Identity reranker: preserves the incoming order untouched."""

    def rerank(
        self, query: str, candidates: list[RecallResult], top_k: int
    ) -> list[RecallResult]:
        """Truncate to the first ``top_k`` candidates without rescoring."""
        del query  # unused; signature mandated by IReranker
        return candidates[:top_k]
62
+
63
+
64
class HeuristicReranker:
    """Heuristic reranking with recency, tag match, and length signals.

    Combines multiple quality signals:
    - Recency: newer memories score higher (exponential decay)
    - Tag match: query terms matching tags boost score
    - Length penalty: very short or very long content penalized

    Final score: original_score * (1 + boost_sum)
    """

    def __init__(
        self,
        recency_decay_days: float = 30.0,
        tag_boost_weight: float = 0.1,
        min_length: int = 10,
        max_length: int = 2000,
        short_penalty: float = 0.05,
        long_penalty: float = 0.03,
    ):
        """Initialize heuristic reranker.

        Args:
            recency_decay_days: Half-life for recency boost (days)
            tag_boost_weight: Weight for tag match boost
            min_length: Content shorter than this gets penalty
            max_length: Content longer than this gets penalty
            short_penalty: Penalty for content shorter than min_length
            long_penalty: Penalty for content longer than max_length
        """
        self.recency_decay_days = recency_decay_days
        self.tag_boost_weight = tag_boost_weight
        self.min_length = min_length
        self.max_length = max_length
        self.short_penalty = short_penalty
        self.long_penalty = long_penalty

    def _recency_boost(self, created_at: datetime) -> float:
        """Exponential-decay recency boost in (0.0, 1.0].

        Brand-new memories get 1.0; older ones decay with half-life
        ``recency_decay_days``. Handles both naive and timezone-aware
        ``created_at`` values — subtracting a naive "now" from an aware
        datetime raises TypeError, so the reference clock matches the
        value's own awareness. Negative ages (clock skew, future-dated
        entries) are clamped to zero so the boost never exceeds 1.0.
        """
        if created_at.tzinfo is None:
            now = datetime.utcnow()
        else:
            now = datetime.now(created_at.tzinfo)
        age_days = max(0.0, (now - created_at).total_seconds() / 86400)
        return math.exp(-age_days / self.recency_decay_days)

    def rerank(
        self, query: str, candidates: list[RecallResult], top_k: int
    ) -> list[RecallResult]:
        """Rerank using heuristic scoring."""
        if not candidates:
            return []

        query_terms = set(query.lower().split())

        scored = []
        for i, candidate in enumerate(candidates):
            boost = self._recency_boost(candidate.memory.created_at)

            # Tag match boost (exact term matching, not substring)
            if candidate.memory.tags:
                tag_lower = {t.lower() for t in candidate.memory.tags}
                matches = sum(1 for term in query_terms if term in tag_lower)
                boost += self.tag_boost_weight * matches

            # Length penalty for degenerate content sizes
            content_length = len(candidate.memory.content)
            if content_length < self.min_length:
                boost -= self.short_penalty
            elif content_length > self.max_length:
                boost -= self.long_penalty

            # Combine with original score
            final_score = candidate.similarity_score * (1.0 + boost)
            scored.append((final_score, i, candidate))

        # Sort by final score (desc), then original index (preserve order on ties)
        scored.sort(key=lambda x: (-x[0], x[1]))

        # Build reranked results with updated scores
        return [
            RecallResult(
                memory=candidate.memory,
                similarity_score=final_score,
                retrieval_time_ms=candidate.retrieval_time_ms,
            )
            for final_score, _, candidate in scored[:top_k]
        ]
158
+
159
+
160
class CrossEncoderReranker:
    """Cross-encoder model-based reranking.

    Scores (query, candidate) pairs directly with a sentence-transformers
    cross-encoder; direct relevance scoring typically ranks better than
    the retrieval scores alone.

    Requires the sentence-transformers package.
    """

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """Initialize cross-encoder reranker.

        Args:
            model_name: HuggingFace model name

        Raises:
            ImportError: If sentence-transformers not installed
        """
        if not CROSS_ENCODER_AVAILABLE:
            raise ImportError(
                "sentence-transformers required for CrossEncoderReranker. "
                "Install with: pip install sentence-transformers"
            )

        logger.info(f"Loading cross-encoder model: {model_name}")
        self.model = CrossEncoder(model_name)

    def rerank(
        self, query: str, candidates: list[RecallResult], top_k: int
    ) -> list[RecallResult]:
        """Rerank using cross-encoder model."""
        if not candidates:
            logger.debug("No candidates to rerank")
            return []

        # Score every (query, content) pair in one model call.
        pair_scores = self.model.predict(
            [(query, c.memory.content) for c in candidates]
        )

        # Stable sort, highest model score first.
        ranked = sorted(zip(pair_scores, candidates), key=lambda sc: -sc[0])

        # Rebuild results carrying the model score as the similarity score.
        return [
            RecallResult(
                memory=c.memory,
                similarity_score=float(s),
                retrieval_time_ms=c.retrieval_time_ms,
            )
            for s, c in ranked[:top_k]
        ]
217
+
218
+
219
def create_reranker(config: "SearchConfig") -> IReranker:
    """Factory function to create reranker from config.

    Args:
        config: SearchConfig with reranking settings (``reranking``,
            ``recency_decay_days``, ``tag_boost_weight``); missing
            attributes fall back to defaults via ``getattr``.

    Returns:
        Configured reranker instance

    Raises:
        ValueError: If reranking mode is invalid
        ImportError: If cross-encoder requested but unavailable
    """
    mode = getattr(config, "reranking", "heuristic")

    def _heuristic() -> HeuristicReranker:
        # Single construction point shared by the "heuristic" mode and the
        # "auto" fallback, so the two paths cannot drift apart.
        return HeuristicReranker(
            recency_decay_days=getattr(config, "recency_decay_days", 30.0),
            tag_boost_weight=getattr(config, "tag_boost_weight", 0.1),
        )

    if mode == "none":
        return NoopReranker()

    if mode == "heuristic":
        return _heuristic()

    if mode == "cross-encoder":
        if not CROSS_ENCODER_AVAILABLE:
            raise ImportError(
                "Cross-encoder reranking requires sentence-transformers. "
                "Install with: pip install sentence-transformers"
            )
        return CrossEncoderReranker()

    if mode == "auto":
        # Prefer the model-based reranker; degrade gracefully to heuristics
        # when it is missing or fails to initialize.
        if CROSS_ENCODER_AVAILABLE:
            try:
                return CrossEncoderReranker()
            except Exception as e:
                # Lazy %-formatting: message only rendered if the record is emitted.
                logger.warning(
                    "Cross-encoder init failed: %s. Falling back to heuristic.", e
                )
        return _heuristic()

    raise ValueError(
        f"Unknown reranking mode: {mode}. "
        f"Valid options: 'none', 'heuristic', 'cross-encoder', 'auto'"
    )