tribalmemory 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,267 @@
+ """Result reranking for improved retrieval quality.
+
+ Provides multiple reranking strategies:
+ - NoopReranker: Pass-through, no reranking
+ - HeuristicReranker: Fast heuristic scoring (recency, tags, length)
+ - CrossEncoderReranker: Model-based reranking (sentence-transformers)
+
+ Reranking happens after initial retrieval (vector + BM25) to refine ordering.
+ """
+
+ import logging
+ import math
+ from datetime import datetime, timezone
+ from typing import TYPE_CHECKING, Protocol
+
+ from ..interfaces import RecallResult
+
+ if TYPE_CHECKING:
+     from ..server.config import SearchConfig
+
+ logger = logging.getLogger(__name__)
+
+ # Optional dependency: probe for sentence-transformers at import time
+ CROSS_ENCODER_AVAILABLE = False
+ CrossEncoder = None
+
+ try:
+     from sentence_transformers import CrossEncoder as _CrossEncoder
+     CrossEncoder = _CrossEncoder
+     CROSS_ENCODER_AVAILABLE = True
+ except ImportError:
+     pass
+
+
+ class IReranker(Protocol):
+     """Interface for result reranking."""
+
+     def rerank(
+         self, query: str, candidates: list[RecallResult], top_k: int
+     ) -> list[RecallResult]:
+         """Rerank candidates and return top_k results.
+
+         Args:
+             query: Original search query
+             candidates: Initial retrieval results
+             top_k: Number of results to return
+
+         Returns:
+             Reranked results (up to top_k)
+         """
+         ...
+
+
+ class NoopReranker:
+     """Pass-through reranker (no reranking)."""
+
+     def rerank(
+         self, query: str, candidates: list[RecallResult], top_k: int
+     ) -> list[RecallResult]:
+         """Return top_k candidates unchanged."""
+         return candidates[:top_k]
+
+
+ class HeuristicReranker:
+     """Heuristic reranking with recency, tag match, and length signals.
+
+     Combines multiple quality signals:
+     - Recency: newer memories score higher (exponential decay)
+     - Tag match: query terms matching tags boost score
+     - Length penalty: very short or very long content penalized
+
+     Final score: original_score * (1 + boost_sum)
+     """
+
+     def __init__(
+         self,
+         recency_decay_days: float = 30.0,
+         tag_boost_weight: float = 0.1,
+         min_length: int = 10,
+         max_length: int = 2000,
+         short_penalty: float = 0.05,
+         long_penalty: float = 0.03,
+     ):
+         """Initialize heuristic reranker.
+
+         Args:
+             recency_decay_days: Decay constant for the recency boost (days);
+                 the boost falls to 1/e of its initial value after this many days
+             tag_boost_weight: Weight for tag match boost
+             min_length: Content shorter than this gets penalty
+             max_length: Content longer than this gets penalty
+             short_penalty: Penalty for content shorter than min_length
+             long_penalty: Penalty for content longer than max_length
+         """
+         self.recency_decay_days = recency_decay_days
+         self.tag_boost_weight = tag_boost_weight
+         self.min_length = min_length
+         self.max_length = max_length
+         self.short_penalty = short_penalty
+         self.long_penalty = long_penalty
+
+     def rerank(
+         self, query: str, candidates: list[RecallResult], top_k: int
+     ) -> list[RecallResult]:
+         """Rerank using heuristic scoring."""
+         if not candidates:
+             return []
+
+         # Compute boost for each candidate
+         scored = []
+         query_lower = query.lower()
+         query_terms = set(query_lower.split())
+         now = datetime.now(timezone.utc)
+
+         for i, candidate in enumerate(candidates):
+             boost = 0.0
+
+             # Recency boost (exponential decay)
+             # Brand new memory (age=0) gets boost of 1.0, older memories decay exponentially
+             created_at = candidate.memory.created_at
+             if created_at.tzinfo is None:
+                 # Treat naive timestamps as UTC so aware/naive subtraction cannot raise
+                 created_at = created_at.replace(tzinfo=timezone.utc)
+             age_days = (now - created_at).total_seconds() / 86400
+             recency_boost = math.exp(-age_days / self.recency_decay_days)
+             boost += recency_boost
+
+             # Tag match boost (exact term matching, not substring)
+             if candidate.memory.tags:
+                 tag_lower = set(t.lower() for t in candidate.memory.tags)
+                 # Count query terms that exactly match tags
+                 matches = sum(1 for term in query_terms if term in tag_lower)
+                 if matches > 0:
+                     boost += self.tag_boost_weight * matches
+
+             # Length penalty
+             content_length = len(candidate.memory.content)
+             if content_length < self.min_length:
+                 boost -= self.short_penalty
+             elif content_length > self.max_length:
+                 boost -= self.long_penalty
+
+             # Combine with original score
+             final_score = candidate.similarity_score * (1.0 + boost)
+
+             scored.append((final_score, i, candidate))
+
+         # Sort by final score (desc), then original index (preserve order on ties)
+         scored.sort(key=lambda x: (-x[0], x[1]))
+
+         # Build reranked results with updated scores
+         reranked = []
+         for final_score, _, candidate in scored[:top_k]:
+             reranked.append(
+                 RecallResult(
+                     memory=candidate.memory,
+                     similarity_score=final_score,
+                     retrieval_time_ms=candidate.retrieval_time_ms,
+                 )
+             )
+
+         return reranked
+
+
+ class CrossEncoderReranker:
+     """Cross-encoder model-based reranking.
+
+     Uses a sentence-transformers cross-encoder to score (query, candidate) pairs.
+     Model scores relevance directly, producing better ranking than retrieval alone.
+
+     Requires sentence-transformers package.
+     """
+
+     def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
+         """Initialize cross-encoder reranker.
+
+         Args:
+             model_name: HuggingFace model name
+
+         Raises:
+             ImportError: If sentence-transformers not installed
+         """
+         if not CROSS_ENCODER_AVAILABLE:
+             raise ImportError(
+                 "sentence-transformers required for CrossEncoderReranker. "
+                 "Install with: pip install sentence-transformers"
+             )
+
+         logger.info(f"Loading cross-encoder model: {model_name}")
+         self.model = CrossEncoder(model_name)
+
+     def rerank(
+         self, query: str, candidates: list[RecallResult], top_k: int
+     ) -> list[RecallResult]:
+         """Rerank using cross-encoder model."""
+         if not candidates:
+             logger.debug("No candidates to rerank")
+             return []
+
+         # Build (query, content) pairs
+         pairs = [(query, candidate.memory.content) for candidate in candidates]
+
+         # Score with model
+         scores = self.model.predict(pairs)
+
+         # Sort by score descending
+         scored = list(zip(scores, candidates))
+         scored.sort(key=lambda x: -x[0])
+
+         # Build reranked results with updated scores
+         reranked = []
+         for score, candidate in scored[:top_k]:
+             reranked.append(
+                 RecallResult(
+                     memory=candidate.memory,
+                     similarity_score=float(score),
+                     retrieval_time_ms=candidate.retrieval_time_ms,
+                 )
+             )
+
+         return reranked
+
+
+ def create_reranker(config: "SearchConfig") -> IReranker:
+     """Factory function to create reranker from config.
+
+     Args:
+         config: SearchConfig with reranking settings
+
+     Returns:
+         Configured reranker instance
+
+     Raises:
+         ValueError: If reranking mode is invalid
+         ImportError: If cross-encoder requested but unavailable
+     """
+     mode = getattr(config, "reranking", "heuristic")
+
+     if mode == "none":
+         return NoopReranker()
+
+     elif mode == "heuristic":
+         return HeuristicReranker(
+             recency_decay_days=getattr(config, "recency_decay_days", 30.0),
+             tag_boost_weight=getattr(config, "tag_boost_weight", 0.1),
+         )
+
+     elif mode == "cross-encoder":
+         if not CROSS_ENCODER_AVAILABLE:
+             raise ImportError(
+                 "Cross-encoder reranking requires sentence-transformers. "
+                 "Install with: pip install sentence-transformers"
+             )
+         return CrossEncoderReranker()
+
+     elif mode == "auto":
+         # Try cross-encoder, fall back to heuristic
+         if CROSS_ENCODER_AVAILABLE:
+             try:
+                 return CrossEncoderReranker()
+             except Exception as e:
+                 logger.warning(f"Cross-encoder init failed: {e}. Falling back to heuristic.")
+         return HeuristicReranker(
+             recency_decay_days=getattr(config, "recency_decay_days", 30.0),
+             tag_boost_weight=getattr(config, "tag_boost_weight", 0.1),
+         )
+
+     else:
+         raise ValueError(
+             f"Unknown reranking mode: {mode}. "
+             f"Valid options: 'none', 'heuristic', 'cross-encoder', 'auto'"
+         )
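
To make the heuristic scoring concrete, here is a minimal worked example of the boost arithmetic in HeuristicReranker.rerank. All inputs are made-up values for illustration, not package APIs:

```python
# Worked example of HeuristicReranker's scoring, with made-up inputs.
import math

similarity = 0.80          # score from initial retrieval (vector + BM25)
age_days = 30.0            # memory age in days
recency_decay_days = 30.0  # default decay constant

recency_boost = math.exp(-age_days / recency_decay_days)  # e^-1 ≈ 0.368
tag_boost = 0.1 * 1        # one query term exactly matches a tag
boost = recency_boost + tag_boost                         # ≈ 0.468

final_score = similarity * (1.0 + boost)
print(round(final_score, 2))  # 1.17
```

Note that boosted scores are not bounded by the original similarity scale (here 1.17 from 0.80); only the relative ordering feeds the top_k cut.
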
@@ -0,0 +1,412 @@
+ """Session transcript indexing service.
+
+ Indexes conversation transcripts as chunked embeddings for contextual recall.
+ Supports delta-based ingestion and retention-based cleanup.
+ """
+
+ import logging
+ import uuid
+ from dataclasses import dataclass, field
+ from datetime import datetime, timedelta, timezone
+ from typing import Optional
+
+ from ..interfaces import IEmbeddingService, IVectorStore
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class SessionMessage:
+     """A single message in a conversation transcript.
+
+     Attributes:
+         role: Message role (user, assistant, system)
+         content: Message content
+         timestamp: When the message was sent
+     """
+     role: str
+     content: str
+     timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+
+
+ @dataclass
+ class SessionChunk:
+     """A chunk of conversation transcript with embedding.
+
+     Attributes:
+         chunk_id: Unique identifier for this chunk
+         session_id: ID of the session this chunk belongs to
+         instance_id: Which agent instance processed this session
+         content: The actual conversation content (multiple messages)
+         embedding: Vector embedding of the content
+         start_time: Timestamp of first message in chunk
+         end_time: Timestamp of last message in chunk
+         chunk_index: Sequential index within session (0, 1, 2...)
+     """
+     chunk_id: str
+     session_id: str
+     instance_id: str
+     content: str
+     embedding: list[float]
+     start_time: datetime
+     end_time: datetime
+     chunk_index: int
+
+
+ class SessionStore:
+     """Service for indexing and searching session transcripts.
+
+     Usage:
+         store = SessionStore(
+             instance_id="clawdio-1",
+             embedding_service=embedding_service,
+             vector_store=vector_store,
+         )
+
+         # Ingest a session transcript
+         messages = [
+             SessionMessage("user", "What is Docker?", datetime.now(timezone.utc)),
+             SessionMessage("assistant", "Docker is a container platform", datetime.now(timezone.utc)),
+         ]
+         await store.ingest("session-123", messages)
+
+         # Search across all sessions
+         results = await store.search("Docker setup error")
+
+         # Search within specific session
+         results = await store.search("Docker", session_id="session-123")
+     """
+
+     # Chunking parameters
+     TARGET_CHUNK_TOKENS = 400  # Target size for each chunk
+     WORDS_PER_TOKEN = 0.75  # Approximate words per token
+     OVERLAP_TOKENS = 50  # Overlap between chunks for context
+
+     def __init__(
+         self,
+         instance_id: str,
+         embedding_service: IEmbeddingService,
+         vector_store: IVectorStore,
+     ):
+         self.instance_id = instance_id
+         self.embedding_service = embedding_service
+         self.vector_store = vector_store
+
+         # Track last ingested index per session for delta ingestion
+         self._session_state: dict[str, int] = {}
+
99
+ self,
100
+ session_id: str,
101
+ messages: list[SessionMessage],
102
+ instance_id: Optional[str] = None,
103
+ ) -> dict:
104
+ """Ingest session messages with delta-based processing.
105
+
106
+ Only processes new messages since last ingestion for this session.
107
+
108
+ Args:
109
+ session_id: Unique identifier for the session
110
+ messages: List of conversation messages
111
+ instance_id: Override instance ID (defaults to self.instance_id)
112
+
113
+ Returns:
114
+ Dict with keys: success, chunks_created, messages_processed
115
+ """
116
+ if not messages:
117
+ return {
118
+ "success": True,
119
+ "chunks_created": 0,
120
+ "messages_processed": 0,
121
+ }
122
+
123
+ # Delta ingestion: only process new messages
124
+ last_index = self._session_state.get(session_id, 0)
125
+ new_messages = messages[last_index:]
126
+
127
+ if not new_messages:
128
+ return {
129
+ "success": True,
130
+ "chunks_created": 0,
131
+ "messages_processed": 0,
132
+ }
133
+
134
+ try:
135
+ # Create chunks from new messages
136
+ chunks = await self._chunk_messages(
137
+ new_messages,
138
+ session_id,
139
+ instance_id or self.instance_id,
140
+ )
141
+
142
+ # Store chunks in vector store
143
+ for chunk in chunks:
144
+ await self._store_chunk(chunk)
145
+
146
+ # Update state
147
+ self._session_state[session_id] = len(messages)
148
+
149
+ return {
150
+ "success": True,
151
+ "chunks_created": len(chunks),
152
+ "messages_processed": len(new_messages),
153
+ }
154
+
155
+ except Exception as e:
156
+ logger.exception(f"Failed to ingest session {session_id}: {e}")
157
+ return {
158
+ "success": False,
159
+ "error": str(e),
160
+ }
161
+
+     async def search(
+         self,
+         query: str,
+         session_id: Optional[str] = None,
+         limit: int = 5,
+         min_relevance: float = 0.0,
+     ) -> list[dict]:
+         """Search session transcripts by semantic similarity.
+
+         Args:
+             query: Natural language search query
+             session_id: Optional filter to specific session
+             limit: Maximum number of results to return
+             min_relevance: Minimum similarity score (0.0 to 1.0)
+
+         Returns:
+             List of dicts with keys: chunk_id, session_id, instance_id,
+             content, similarity_score, start_time, end_time, chunk_index
+         """
+         try:
+             # Generate query embedding
+             query_embedding = await self.embedding_service.embed(query)
+
+             # Search chunks
+             results = await self._search_chunks(
+                 query_embedding,
+                 session_id,
+                 limit,
+                 min_relevance,
+             )
+
+             return results
+
+         except Exception as e:
+             logger.exception(f"Failed to search sessions: {e}")
+             return []
+
+     async def cleanup(self, retention_days: int = 30) -> int:
+         """Delete session chunks older than retention period.
+
+         Args:
+             retention_days: Number of days to retain chunks
+
+         Returns:
+             Number of chunks deleted
+         """
+         try:
+             cutoff_time = datetime.now(timezone.utc) - timedelta(days=retention_days)
+
+             # Find and delete expired chunks
+             deleted = await self._delete_chunks_before(cutoff_time)
+
+             return deleted
+
+         except Exception as e:
+             logger.exception(f"Failed to cleanup sessions: {e}")
+             return 0
+
+ async def get_stats(self) -> dict:
221
+ """Get statistics about indexed sessions.
222
+
223
+ Returns:
224
+ Dict with keys: total_chunks, total_sessions,
225
+ earliest_chunk, latest_chunk
226
+ """
227
+ try:
228
+ chunks = await self._get_all_chunks()
229
+
230
+ if not chunks:
231
+ return {
232
+ "total_chunks": 0,
233
+ "total_sessions": 0,
234
+ "earliest_chunk": None,
235
+ "latest_chunk": None,
236
+ }
237
+
238
+ session_ids = set()
239
+ timestamps = []
240
+
241
+ for chunk in chunks:
242
+ session_ids.add(chunk["session_id"])
243
+ timestamps.append(chunk["start_time"])
244
+
245
+ return {
246
+ "total_chunks": len(chunks),
247
+ "total_sessions": len(session_ids),
248
+ "earliest_chunk": min(timestamps) if timestamps else None,
249
+ "latest_chunk": max(timestamps) if timestamps else None,
250
+ }
251
+
252
+ except Exception as e:
253
+ logger.exception(f"Failed to get stats: {e}")
254
+ return {
255
+ "total_chunks": 0,
256
+ "total_sessions": 0,
257
+ "earliest_chunk": None,
258
+ "latest_chunk": None,
259
+ }
260
+
261
+ async def _chunk_messages(
262
+ self,
263
+ messages: list[SessionMessage],
264
+ session_id: str,
265
+ instance_id: str,
266
+ ) -> list[SessionChunk]:
267
+ """Chunk messages into ~400 token windows with overlap.
268
+
269
+ Uses a simple word-count approximation: words / 0.75 ≈ tokens.
270
+ """
271
+ chunks = []
272
+ chunk_index = 0
273
+
274
+ # Convert messages to text with timestamps
275
+ message_texts = []
276
+ for msg in messages:
277
+ text = f"{msg.role}: {msg.content}"
278
+ message_texts.append((text, msg.timestamp))
279
+
280
+ # Estimate tokens
281
+ target_words = int(self.TARGET_CHUNK_TOKENS * self.WORDS_PER_TOKEN)
282
+ overlap_words = int(self.OVERLAP_TOKENS * self.WORDS_PER_TOKEN)
283
+
284
+ i = 0
285
+ while i < len(message_texts):
286
+ chunk_messages = []
287
+ chunk_word_count = 0
288
+ start_time = message_texts[i][1]
289
+ end_time = start_time
290
+
291
+ # Collect messages until we reach target size
292
+ while i < len(message_texts) and chunk_word_count < target_words:
293
+ text, timestamp = message_texts[i]
294
+ words = len(text.split())
295
+ chunk_messages.append(text)
296
+ chunk_word_count += words
297
+ end_time = timestamp
298
+ i += 1
299
+
300
+ # Create chunk
301
+ if chunk_messages:
302
+ content = "\n".join(chunk_messages)
303
+ embedding = await self.embedding_service.embed(content)
304
+
305
+ chunk = SessionChunk(
306
+ chunk_id=str(uuid.uuid4()),
307
+ session_id=session_id,
308
+ instance_id=instance_id,
309
+ content=content,
310
+ embedding=embedding,
311
+ start_time=start_time,
312
+ end_time=end_time,
313
+ chunk_index=chunk_index,
314
+ )
315
+ chunks.append(chunk)
316
+ chunk_index += 1
317
+
318
+ # Backtrack for overlap
319
+ if i < len(message_texts):
320
+ # Calculate how many messages to backtrack
321
+ overlap_word_target = 0
322
+ backtrack = 0
323
+ while (backtrack < len(chunk_messages) and
324
+ overlap_word_target < overlap_words):
325
+ backtrack += 1
326
+ overlap_word_target += len(chunk_messages[-backtrack].split())
327
+
328
+ i -= min(backtrack, 2) # Backtrack at most 2 messages
329
+ i = max(i, 0)
330
+
331
+ return chunks
332
+
+     async def _store_chunk(self, chunk: SessionChunk) -> None:
+         """Store a session chunk in memory.
+
+         Note: Currently uses in-memory list storage. This is intentional for v0.2.0
+         to keep the initial implementation simple and testable. Data does not persist
+         across restarts. A future version will integrate with LanceDB for persistent
+         storage in a separate 'session_chunks' table. See issue #38 follow-up.
+         """
+         if not hasattr(self, '_chunks'):
+             self._chunks = []
+
+         self._chunks.append({
+             "chunk_id": chunk.chunk_id,
+             "session_id": chunk.session_id,
+             "instance_id": chunk.instance_id,
+             "content": chunk.content,
+             "embedding": chunk.embedding,
+             "start_time": chunk.start_time,
+             "end_time": chunk.end_time,
+             "chunk_index": chunk.chunk_index,
+         })
+
+     async def _search_chunks(
+         self,
+         query_embedding: list[float],
+         session_id: Optional[str],
+         limit: int,
+         min_relevance: float,
+     ) -> list[dict]:
+         """Search for chunks by similarity."""
+         if not hasattr(self, '_chunks'):
+             return []
+
+         # Calculate similarities
+         results = []
+         for chunk in self._chunks:
+             # Filter by session_id if provided
+             if session_id and chunk["session_id"] != session_id:
+                 continue
+
+             similarity = self.embedding_service.similarity(
+                 query_embedding,
+                 chunk["embedding"],
+             )
+
+             if similarity >= min_relevance:
+                 results.append({
+                     "chunk_id": chunk["chunk_id"],
+                     "session_id": chunk["session_id"],
+                     "instance_id": chunk["instance_id"],
+                     "content": chunk["content"],
+                     "similarity_score": similarity,
+                     "start_time": chunk["start_time"],
+                     "end_time": chunk["end_time"],
+                     "chunk_index": chunk["chunk_index"],
+                 })
+
+         # Sort by similarity
+         results.sort(key=lambda x: x["similarity_score"], reverse=True)
+
+         return results[:limit]
+
+     async def _delete_chunks_before(self, cutoff_time: datetime) -> int:
+         """Delete chunks older than cutoff time."""
+         if not hasattr(self, '_chunks'):
+             return 0
+
+         initial_count = len(self._chunks)
+         self._chunks = [
+             chunk for chunk in self._chunks
+             if chunk["end_time"] >= cutoff_time
+         ]
+
+         return initial_count - len(self._chunks)
+
+     async def _get_all_chunks(self) -> list[dict]:
+         """Get all stored chunks."""
+         if not hasattr(self, '_chunks'):
+             return []
+         return self._chunks
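
To make the chunking arithmetic and delta-ingestion behavior concrete, here is a hedged, self-contained sketch against the definitions above. `FakeEmbedder` is a hypothetical stand-in that satisfies only the `embed()`/`similarity()` calls the in-memory paths make; it is not part of the package:

```python
# Minimal sketch of SessionStore's delta ingestion, assuming the
# SessionMessage/SessionStore definitions above are importable.
# Chunking budget with the class defaults: int(400 * 0.75) = 300 words
# per chunk, with int(50 * 0.75) = 37 words of overlap.
import asyncio


class FakeEmbedder:
    async def embed(self, text: str) -> list[float]:
        # Toy 2-d embedding, just enough for the demo
        return [float(len(text)), float(text.count(" ") + 1)]

    def similarity(self, a: list[float], b: list[float]) -> float:
        # Cosine similarity
        dot = sum(x * y for x, y in zip(a, b))
        norm = (sum(x * x for x in a) ** 0.5) * (sum(x * x for x in b) ** 0.5)
        return dot / norm if norm else 0.0


async def demo() -> None:
    # vector_store is unused by the current in-memory storage paths
    store = SessionStore("demo-1", FakeEmbedder(), vector_store=None)
    msgs = [
        SessionMessage("user", "What is Docker?"),
        SessionMessage("assistant", "Docker is a container platform."),
    ]
    first = await store.ingest("session-123", msgs)   # processes both messages
    second = await store.ingest("session-123", msgs)  # delta: nothing new to do
    print(first["messages_processed"], second["messages_processed"])  # 2 0

    hits = await store.search("Docker")
    print(hits[0]["session_id"] if hits else None)    # session-123


asyncio.run(demo())
```

The second `ingest` call returns immediately because `_session_state` already records two processed messages for `session-123`; callers can therefore re-send the full transcript on every turn without re-embedding old content.
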