convmemory 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
convmemory/reranker.py ADDED
@@ -0,0 +1,253 @@
1
+ from dataclasses import dataclass
2
+ from typing import Iterable, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from .scoring import (
8
+ build_memory_to_windows,
9
+ cosine_scores,
10
+ rerank_candidates,
11
+ score_ce_lite,
12
+ window_scores,
13
+ )
14
+
15
+
16
@dataclass(init=True, repr=True, eq=True)
class RerankConfig:
    """Default knobs used by ``ConvMemoryReranker``.

    Per-call arguments of ``ConvMemoryReranker.rerank_item`` can override
    ``candidate_top_n``, ``raw_weight`` and ``window_mode``; the remaining
    fields are only read from here.
    """

    # Length of each memory window (used for sliding and candidate-local windows).
    window_size: int = 5
    # Step between consecutive sliding-window start positions.
    stride: int = 1
    # Number of top raw-score memories reranked when no candidates are supplied.
    candidate_top_n: int = 500
    # Forwarded to ``rerank_candidates(raw_weight=...)``; presumably the weight of
    # the raw cosine score when combining scores -- TODO confirm in scoring.py.
    raw_weight: float = 0.0
    # Forwarded unchanged to ``score_ce_lite(dca_router_block_size=...)``.
    dca_router_block_size: int = 32
    # Forwarded unchanged to ``score_ce_lite(lexical_features=...)``.
    lexical_features: bool = True
    # "full" (sliding windows over every memory) or "candidate_local"
    # (one window placed around each candidate).
    window_mode: str = "full"
25
+
26
+
27
@dataclass(init=True, repr=True, eq=True)
class RerankResult:
    """A single memory in the ordering returned by ``ConvMemoryReranker``."""

    # Memory identifier (``make_item`` converts every id to ``str``).
    memory_id: str
    # Reranked score; memories without one fall back to ``raw_score``.
    score: float
    # Score from ``cosine_scores`` between the query and this memory.
    raw_score: float
    # 1-based position in the returned ordering.
    rank: int
    # Memory text when known, otherwise None.
    text: Optional[str] = None
34
+
35
+
36
def sliding_windows(num_items: int, window_size: int, stride: int):
    """Return index windows of length ``window_size`` covering ``range(num_items)``.

    A window starts every ``stride`` items. If the last regular window does not
    end on the final item, one extra end-aligned window is appended so every
    item is covered. When ``num_items <= window_size`` a single window holding
    all items is returned, and ``num_items <= 0`` yields ``[]``.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if stride < 1:
        # A zero stride would fail inside range(); a negative one would produce
        # no windows and then crash on windows[-1].
        raise ValueError(f"stride must be >= 1, got {stride}")
    if num_items <= 0:
        return []
    if num_items <= window_size:
        return [list(range(num_items))]

    windows = [
        list(range(start, start + window_size))
        for start in range(0, num_items - window_size + 1, stride)
    ]
    tail = list(range(num_items - window_size, num_items))
    # A stride > 1 can skip the final items; add an end-aligned window for them.
    if windows[-1] != tail:
        windows.append(tail)
    return windows
49
+
50
+
51
def normalize_rows(matrix):
    """Return ``matrix`` as float32 with every row scaled to unit L2 norm.

    A non-empty 1-D input is treated as a single row instead of failing on
    ``axis=1``; an empty 1-D input becomes a ``(0, 0)`` matrix. ``1e-8`` is
    added to each norm so all-zero rows stay zero instead of becoming NaN.
    """
    matrix = np.asarray(matrix, dtype=np.float32)
    if matrix.ndim == 1:
        matrix = matrix.reshape(1, -1) if matrix.size else matrix.reshape(0, 0)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8
    return matrix / norms
55
+
56
+
57
def window_tensor(memory_embeddings, windows):
    """Gather per-window embedding stacks into one ``torch.float32`` tensor.

    Args:
        memory_embeddings: array indexed by memory position; assumed 2-D
            (rows = memories) -- the empty and padded branches read
            ``shape[1]`` as the feature size.
        windows: list of index lists, one per window.

    Returns:
        For 2-D embeddings, a tensor of shape
        ``(num_windows, longest_window, dim)``. Shorter windows are
        right-padded with zero vectors, and an empty ``windows`` list gives a
        ``(0, 0, dim)`` tensor.
    """
    if not windows:
        return torch.zeros((0, 0, memory_embeddings.shape[1]), dtype=torch.float32)

    lengths = [len(window) for window in windows]
    longest = max(lengths)
    if min(lengths) == longest:
        # Every window has the same length: one fancy-indexing gather, no padding.
        gathered = memory_embeddings[np.asarray(windows, dtype=np.int64)]
        return torch.as_tensor(gathered, dtype=torch.float32)

    padded = np.zeros((len(windows), longest, memory_embeddings.shape[1]), dtype=np.float32)
    for row, (window, length) in enumerate(zip(windows, lengths)):
        padded[row, :length] = memory_embeddings[window]
    return torch.tensor(padded, dtype=torch.float32)
68
+
69
+
70
def candidate_local_windows(num_items, candidate_indices, window_size):
    """Build one window around each candidate, kept inside ``[0, num_items)``.

    Each window holds ``window_size`` consecutive indices and starts
    ``window_size // 2`` positions before its candidate. A window that would run
    past either end is shifted back inside the range. Identical windows are
    returned once, in first-seen order. When ``num_items <= window_size`` the
    result is a single all-items window, and ``num_items <= 0`` gives ``[]``.
    """
    if num_items <= 0:
        return []
    if num_items <= window_size:
        return [list(range(num_items))]

    offset = window_size // 2
    last_start = num_items - window_size
    # dicts keep insertion order, so this doubles as an ordered de-duplicator.
    unique_spans = {}
    for candidate in candidate_indices:
        first = min(max(int(candidate) - offset, 0), last_start)
        unique_spans.setdefault(tuple(range(first, first + window_size)), None)
    return [list(span) for span in unique_spans]
94
+
95
+
96
class ConvMemoryReranker:
    """Plug-in reranker over precomputed memory embeddings.

    The class deliberately does not own text embedding. In production, users can
    bring any retriever or embedding model, pass the raw top-k candidates here,
    and get a ConvMemory-enhanced ordering back.

    Args:
        conv_model: model forwarded to ``window_scores`` and ``score_ce_lite``.
        scorer: scorer forwarded to ``score_ce_lite``.
        config: optional ``RerankConfig``; a default one is created when omitted.
        device: device identifier forwarded to the scoring helpers.
    """

    def __init__(self, conv_model, scorer, config=None, device="cpu"):
        self.conv_model = conv_model
        self.scorer = scorer
        self.config = config or RerankConfig()
        self.device = device

    def make_item(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts: Optional[Iterable[str]] = None,
        query: str = "",
    ):
        """Package one query and its memories into the dict used by ``rerank_item``.

        The query vector and every memory row are L2-normalized, ids are
        converted to ``str``, and sliding windows are precomputed from
        ``config.window_size`` / ``config.stride``. ``window_tensor`` is left as
        ``None`` so ``rerank_item`` can build it on demand.

        Raises:
            ValueError: if ``memory_embeddings`` or ``memory_texts`` does not
                provide exactly one entry per memory id.
        """
        memory_ids = [str(x) for x in memory_ids]
        memory_embeddings = normalize_rows(memory_embeddings)
        if len(memory_embeddings) != len(memory_ids):
            raise ValueError("memory_ids and memory_embeddings must have the same length")
        query_embedding = np.asarray(query_embedding, dtype=np.float32)
        query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-8)
        if memory_texts is None:
            memory_texts = ["" for _ in memory_ids]
        else:
            # Materialize (generators are fine) and check the count, instead of
            # letting zip() silently drop trailing memories.
            memory_texts = list(memory_texts)
            if len(memory_texts) != len(memory_ids):
                raise ValueError("memory_texts and memory_ids must have the same length")
        memories = [
            {"id": memory_id, "text": text}
            for memory_id, text in zip(memory_ids, memory_texts)
        ]
        windows = sliding_windows(
            len(memory_ids),
            self.config.window_size,
            self.config.stride,
        )

        return {
            "question_id": "query",
            "question_type": "unknown",
            "query": query,
            "query_embedding": query_embedding.astype(np.float32),
            "memory_embeddings": memory_embeddings.astype(np.float32),
            "memory_ids": memory_ids,
            "memories": memories,
            "windows": windows,
            "window_tensor": None,
            "gold_memory_ids": [],
        }

    def rerank_embeddings(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts: Optional[Iterable[str]] = None,
        query: str = "",
        candidate_indices=None,
        candidate_top_n: Optional[int] = None,
        raw_weight: Optional[float] = None,
        window_mode: Optional[str] = None,
    ):
        """Convenience wrapper: ``make_item`` followed by ``rerank_item``."""
        item = self.make_item(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
        )
        return self.rerank_item(
            item,
            candidate_indices=candidate_indices,
            candidate_top_n=candidate_top_n,
            raw_weight=raw_weight,
            window_mode=window_mode,
        )

    def rerank_item(
        self,
        item,
        candidate_indices=None,
        candidate_top_n: Optional[int] = None,
        raw_weight: Optional[float] = None,
        window_mode: Optional[str] = None,
    ):
        """Rerank the memories of ``item`` and return ``RerankResult`` objects.

        Args:
            item: dict produced by ``make_item`` (or one with the same keys).
            candidate_indices: explicit memory positions to rerank. When
                ``None``, the top ``candidate_top_n`` memories by
                ``cosine_scores`` are used.
            candidate_top_n: overrides ``config.candidate_top_n`` when not
                ``None``; ``0`` (or a negative value) selects no candidates.
            raw_weight: overrides ``config.raw_weight`` when not ``None``.
            window_mode: ``"full"`` or ``"candidate_local"``; falls back to
                ``config.window_mode`` when falsy.

        Returns:
            One ``RerankResult`` per id returned by ``rerank_candidates``, in
            that order, with ranks starting at 1. Ids without a reranked score
            report their raw score as ``score``.

        Raises:
            ValueError: for an unknown window mode.
        """
        raw_scores = cosine_scores(item["query_embedding"], item["memory_embeddings"])
        candidate_indices = self._select_candidates(raw_scores, candidate_indices, candidate_top_n)
        scoring_item = self._prepare_scoring_item(item, candidate_indices, window_mode)

        # Inference only: no autograd bookkeeping for any model call.
        with torch.no_grad():
            conv_tensor = window_scores(self.conv_model, scoring_item, self.device)
            memory_to_windows = build_memory_to_windows(scoring_item["windows"])
            _, _, ce_lite_scores = score_ce_lite(
                self.conv_model,
                self.scorer,
                scoring_item,
                candidate_indices,
                self.device,
                raw_scores_all=raw_scores,
                window_logits=conv_tensor,
                memory_to_windows=memory_to_windows,
                dca_router_block_size=self.config.dca_router_block_size,
                lexical_features=self.config.lexical_features,
            )
        ranked_ids = rerank_candidates(
            raw_scores,
            candidate_indices,
            ce_lite_scores,
            item["memory_ids"],
            raw_weight=self.config.raw_weight if raw_weight is None else raw_weight,
        )
        return self._build_results(item, raw_scores, candidate_indices, ce_lite_scores, ranked_ids)

    def _select_candidates(self, raw_scores, candidate_indices, candidate_top_n):
        """Resolve which memory positions to rerank (explicit, or top-N by raw score)."""
        if candidate_indices is not None:
            return np.asarray(candidate_indices, dtype=np.int64)
        # `is None` (not `or`) so an explicit 0 is honoured, like raw_weight.
        top_n = self.config.candidate_top_n if candidate_top_n is None else candidate_top_n
        # Clamp to [0, len]: a negative slice bound would drop the tail instead.
        top_n = max(0, min(int(top_n), len(raw_scores)))
        return np.argsort(-raw_scores)[:top_n]

    def _prepare_scoring_item(self, item, candidate_indices, window_mode):
        """Return ``item`` with the windows/window tensor required by the window mode."""
        selected_window_mode = window_mode or self.config.window_mode
        if selected_window_mode == "candidate_local":
            local_windows = candidate_local_windows(
                len(item["memory_ids"]),
                candidate_indices,
                self.config.window_size,
            )
            return {
                **item,
                "windows": local_windows,
                "window_tensor": window_tensor(item["memory_embeddings"], local_windows),
            }
        if selected_window_mode != "full":
            raise ValueError(f"Unknown window_mode: {selected_window_mode}")
        if item.get("window_tensor") is None:
            # Build lazily; a shallow copy keeps the caller's item untouched.
            return {
                **item,
                "window_tensor": window_tensor(item["memory_embeddings"], item["windows"]),
            }
        return item

    @staticmethod
    def _build_results(item, raw_scores, candidate_indices, ce_lite_scores, ranked_ids):
        """Turn the ranked id list into ``RerankResult`` objects."""
        memory_ids = item["memory_ids"]
        score_by_id = {
            memory_ids[int(idx)]: float(score)
            for idx, score in zip(candidate_indices, ce_lite_scores)
        }
        raw_by_id = {
            memory_ids[idx]: float(score)
            for idx, score in enumerate(raw_scores)
        }
        text_by_id = {
            str(memory.get("id", idx)): memory.get("text")
            for idx, memory in enumerate(item.get("memories", []))
        }
        return [
            RerankResult(
                memory_id=memory_id,
                score=score_by_id.get(memory_id, raw_by_id[memory_id]),
                raw_score=raw_by_id[memory_id],
                rank=rank,
                text=text_by_id.get(memory_id),
            )
            for rank, memory_id in enumerate(ranked_ids, start=1)
        ]
convmemory/routing.py ADDED
@@ -0,0 +1,208 @@
1
+ from dataclasses import dataclass
2
+ from typing import Iterable, List, Mapping, Optional, Sequence
3
+
4
+ import numpy as np
5
+
6
+
7
@dataclass(eq=True, frozen=True)
class CompressionRouteConfig:
    """Knobs for routing a query to raw memories through compressed notes.

    Defaults are deliberately conservative and mirror the stable LoCoMo v0.31
    setting: compressed notes narrow the raw-memory pool, and ConvMemory then
    reranks that smaller pool.
    """

    # Number of best-scoring notes that get expanded into their source memories.
    note_depth: int = 240
    # Cap on source memories taken from one note (non-positive means no cap).
    max_sources_per_note: int = 5
    # Routing stops once this many candidates have been collected.
    max_candidates: int = 450
    # Number of top raw-score memories admitted first, before any note expansion.
    raw_anchor: int = 80
20
+
21
+
22
@dataclass(eq=True, frozen=True)
class CompressionRouteResult:
    """Outcome of ``CompressionRouter.route``."""

    # Positions into the routed memory list, in the order they were selected.
    candidate_indices: List[int]
    # Memory ids at ``candidate_indices``, same order.
    candidate_ids: List[str]
    # Indices of the notes considered, best-scoring first.
    note_indices: List[int]
    # Candidates credited to the raw-score anchor phase
    # (computed as min(len(candidates), raw_anchor)).
    raw_anchor_count: int
28
+
29
+
30
@dataclass(eq=True, frozen=True)
class CompressedNoteConfig:
    """Settings for ``build_compressed_notes`` (lightweight raw-memory notes)."""

    # "session": one note per run of consecutive memories sharing ``session_key``;
    # "block": those runs are further cut into chunks of ``block_size``.
    mode: str = "session"
    # Chunk length used in "block" mode.
    block_size: int = 32
    # How many memories' texts are joined into a note (at least one is used).
    representatives: int = 3
    # "central": members closest to the group centroid; "first": earliest members.
    strategy: str = "central"
    # Memory-dict key holding the session identifier.
    session_key: str = "session_id"
39
+
40
+
41
class CompressionRouter:
    """Route raw memories through compressed note blocks.

    Compressed memories are dictionaries with at least:
    - `text`: note text used for embedding
    - `source_ids`: raw memory ids covered by this note

    The router does not call ConvMemory directly. It only returns candidate ids
    so it can be plugged into any retrieval or agent-memory pipeline.
    """

    def __init__(self, config: Optional[CompressionRouteConfig] = None):
        self.config = config or CompressionRouteConfig()

    def route(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids: Sequence[str],
        compressed_embeddings,
        compressed_memories: Iterable[Mapping],
    ) -> CompressionRouteResult:
        """Select raw-memory candidates for a query.

        Candidates are gathered in two phases. Collection stops as soon as
        ``config.max_candidates`` indices have been gathered.

        1. The ``config.raw_anchor`` memories with the highest score against
           the normalized query.
        2. For each of the ``config.note_depth`` best-scoring notes, its source
           memories ranked by raw score. At most ``config.max_sources_per_note``
           are taken per note; a non-positive value keeps them all.

        Empty inputs (no memories and/or no notes) are handled, including empty
        1-D arrays such as ``[]``.

        Raises:
            ValueError: if the note or memory counts do not match their
                embedding counts.
        """
        memory_ids = [str(memory_id) for memory_id in memory_ids]
        compressed_memories = list(compressed_memories)
        if len(compressed_memories) != len(compressed_embeddings):
            raise ValueError("compressed_memories and compressed_embeddings must have the same length")

        query = _normalize_vector(query_embedding)
        dim = query.shape[-1]
        memories = self._unit_rows(memory_embeddings, dim)
        if len(memories) != len(memory_ids):
            # Without this, a mismatch surfaced later as an IndexError.
            raise ValueError("memory_ids and memory_embeddings must have the same length")
        notes = self._unit_rows(compressed_embeddings, dim)

        empty_scores = np.asarray([], dtype=np.float32)
        raw_scores = memories @ query if len(memories) else empty_scores
        raw_order = np.argsort(-raw_scores)
        note_scores = notes @ query if len(notes) else empty_scores
        note_order = np.argsort(-note_scores)[: max(0, int(self.config.note_depth))]

        id_to_index = {memory_id: idx for idx, memory_id in enumerate(memory_ids)}
        selected: List[int] = []
        seen = set()

        # Phase 1: raw-score anchors.
        # NOTE(review): the cap is checked after each insertion, so a
        # non-positive max_candidates still admits one candidate -- confirm intent.
        for idx in raw_order[: max(0, int(self.config.raw_anchor))]:
            self._add_candidate(int(idx), selected, seen)
            if len(selected) >= self.config.max_candidates:
                return self._result(selected, memory_ids, note_order)

        # Phase 2: expand the best notes into their highest-scoring sources.
        per_note_limit = int(self.config.max_sources_per_note)
        for note_idx in note_order:
            source_ids = compressed_memories[int(note_idx)].get("source_ids")
            if source_ids is None:
                source_ids = []
            source_indices = [
                id_to_index[key]
                for key in (str(source_id) for source_id in source_ids)
                if key in id_to_index
            ]
            source_indices.sort(key=lambda idx: -float(raw_scores[idx]))
            if per_note_limit > 0:
                source_indices = source_indices[:per_note_limit]

            for idx in source_indices:
                self._add_candidate(int(idx), selected, seen)
                if len(selected) >= self.config.max_candidates:
                    return self._result(selected, memory_ids, note_order)

        return self._result(selected, memory_ids, note_order)

    @staticmethod
    def _unit_rows(x, dim: int) -> np.ndarray:
        """Row-normalize ``x``; an empty 1-D input becomes a ``(0, dim)`` matrix.

        ``_normalize_matrix`` reshapes any 1-D input into a single row, which
        would turn ``[]`` into a phantom ``(1, 0)`` row and break the score
        products above.
        """
        arr = np.asarray(x, dtype=np.float32)
        if arr.ndim == 1 and arr.size == 0:
            return np.zeros((0, dim), dtype=np.float32)
        return _normalize_matrix(arr)

    @staticmethod
    def _add_candidate(idx: int, selected: List[int], seen) -> None:
        """Append ``idx`` to ``selected`` unless it was already added."""
        if idx in seen:
            return
        selected.append(idx)
        seen.add(idx)

    def _result(self, selected: List[int], memory_ids: Sequence[str], note_order) -> CompressionRouteResult:
        """Package the selected indices into a ``CompressionRouteResult``."""
        return CompressionRouteResult(
            candidate_indices=list(selected),
            candidate_ids=[memory_ids[idx] for idx in selected],
            note_indices=[int(idx) for idx in note_order],
            raw_anchor_count=min(len(selected), max(0, int(self.config.raw_anchor))),
        )
118
+
119
+
120
def _normalize_vector(x):
    """Return ``x`` as a float32 array divided by its L2 norm (with a 1e-8 guard)."""
    vec = np.asarray(x, dtype=np.float32)
    length = np.linalg.norm(vec)
    return vec / (length + 1e-8)
123
+
124
+
125
def _normalize_matrix(x):
    """Return ``x`` as a float32 matrix whose rows have unit L2 norm.

    A non-empty 1-D input is treated as a single row. An empty 1-D input
    (e.g. ``[]``) becomes a ``(0, 0)`` matrix. Reshaping it to ``(1, 0)`` would
    invent a phantom row, making ``len()`` report one note/memory and the
    following ``@ query`` product fail on a dimension mismatch.
    """
    arr = np.asarray(x, dtype=np.float32)
    if arr.ndim == 1:
        arr = arr.reshape(1, -1) if arr.size else arr.reshape(0, 0)
    return arr / (np.linalg.norm(arr, axis=1, keepdims=True) + 1e-8)
130
+
131
+
132
+ def build_compressed_notes(
133
+ memories: Iterable[Mapping],
134
+ memory_embeddings,
135
+ config: Optional[CompressedNoteConfig] = None,
136
+ ):
137
+ """Build compressed notes from raw memories.
138
+
139
+ This helper is intentionally simple: it groups an ordered memory stream by
140
+ session or fixed-size blocks, chooses representative turns, and keeps
141
+ `source_ids` so the note can be expanded back to raw memories.
142
+ """
143
+
144
+ cfg = config or CompressedNoteConfig()
145
+ memories = list(memories)
146
+ embeddings = _normalize_matrix(memory_embeddings)
147
+ if len(memories) != len(embeddings):
148
+ raise ValueError("memories and memory_embeddings must have the same length")
149
+ if cfg.mode not in {"session", "block"}:
150
+ raise ValueError("CompressedNoteConfig.mode must be 'session' or 'block'")
151
+ if cfg.strategy not in {"central", "first"}:
152
+ raise ValueError("CompressedNoteConfig.strategy must be 'central' or 'first'")
153
+
154
+ groups = _session_groups(memories, cfg.session_key)
155
+ if cfg.mode == "block":
156
+ groups = [
157
+ group[start : start + cfg.block_size]
158
+ for group in groups
159
+ for start in range(0, len(group), cfg.block_size)
160
+ ]
161
+
162
+ notes = []
163
+ for note_idx, group in enumerate(groups):
164
+ if not group:
165
+ continue
166
+ reps = _representative_indices(group, embeddings, cfg.strategy, cfg.representatives)
167
+ text = " ".join(str(memories[idx].get("text", "")) for idx in reps)
168
+ source_ids = [str(memories[idx].get("id", idx)) for idx in group]
169
+ session_id = str(memories[group[0]].get(cfg.session_key, ""))
170
+ notes.append(
171
+ {
172
+ "id": f"{cfg.mode}:{note_idx}",
173
+ "text": text,
174
+ "source_ids": source_ids,
175
+ cfg.session_key: session_id,
176
+ "granularity": cfg.mode,
177
+ }
178
+ )
179
+ return notes
180
+
181
+
182
def _session_groups(memories, session_key):
    """Split memory positions into runs of consecutive equal session values.

    A memory without ``session_key`` counts as session ``""``. Memories that
    share a session value but are not adjacent land in separate runs.
    """
    runs = []
    previous = None
    for position, record in enumerate(memories):
        session = record.get(session_key, "")
        if not runs or session != previous:
            runs.append([position])
        else:
            runs[-1].append(position)
        previous = session
    return runs
196
+
197
+
198
def _representative_indices(group, embeddings, strategy, representatives):
    """Pick up to ``representatives`` members of ``group`` (always at least one).

    With ``"first"`` the earliest members are kept. Any other strategy (in
    practice ``"central"``) ranks members by dot product with the unit-length
    group centroid. This is cosine similarity when rows are normalized, as
    ``build_compressed_notes`` ensures. The best ones are returned in their
    original group order.
    """
    k = max(1, min(int(representatives), len(group)))
    if strategy == "first":
        return group[:k]

    members = embeddings[group]
    center = members.mean(axis=0)
    center = center / (np.linalg.norm(center) + 1e-8)
    ranking = np.argsort(-(members @ center))
    chosen = sorted(ranking[:k])
    return [group[int(pos)] for pos in chosen]