convmemory 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
convmemory/scoring.py ADDED
@@ -0,0 +1,314 @@
1
+ import re
2
+ from functools import lru_cache
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
# A token is a maximal run of ASCII letters and digits.
TOKEN_RE = re.compile(r"[A-Za-z0-9]+")

# Lower-case function words that lexical_token_tuple filters out of token lists.
STOPWORDS = {
    "a", "an", "and", "are", "as", "at",
    "be",
    "did", "do", "does",
    "for", "from",
    "had", "has", "have", "he", "her", "him", "his",
    "i", "in", "is", "it",
    "of", "on", "or",
    "she",
    "that", "the", "their", "they", "to",
    "was", "were", "what", "when", "where", "which", "who", "why", "with",
}
52
+
53
+
54
def cosine_scores(query, matrix):
    """Score every row of ``matrix`` against ``query`` with one matrix product.

    The result is exactly ``matrix @ query``; nothing is normalized here, so the
    values are cosine similarities only when the inputs are already unit-length
    (presumably ensured by the caller -- TODO confirm).
    """
    similarities = matrix @ query
    return similarities
56
+
57
+
58
def normalize_scores(scores):
    """Return ``scores`` as a float32 array with zero mean and unit std.

    If the standard deviation is below 1e-8 the values are only mean-centered,
    which avoids dividing by a (near-)zero spread.
    """
    values = np.asarray(scores, dtype=np.float32)
    spread = float(values.std())
    centered = values - float(values.mean())
    if spread < 1e-8:
        return centered
    return centered / spread
64
+
65
+
66
class CELiteScorer(torch.nn.Module):
    """Small MLP that turns one feature row per candidate into a scalar score.

    The expected feature width is ``dim * 4 + 4`` plus the requested numbers of
    extra scalar and extra dense features; it is stored as ``input_dim``.
    """

    def __init__(self, dim, hidden_dim=256, extra_scalar_features=0, extra_dense_features=0):
        super().__init__()
        nn = torch.nn
        half_hidden = hidden_dim // 2
        self.input_dim = dim * 4 + 4 + extra_scalar_features + extra_dense_features
        # Layer order is kept fixed so parameter names (net.0, net.2, ...) stay stable.
        layers = [
            nn.Linear(self.input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, half_hidden),
            nn.GELU(),
            nn.Linear(half_hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, features):
        """Run the MLP and drop the trailing singleton output dimension."""
        scores = self.net(features)
        return scores.squeeze(-1)
81
+
82
+
83
def window_scores(model, item, device):
    """Score each window of ``item`` against the item's query embedding.

    The query embedding is turned into a ``(1, dim)`` float32 tensor and
    broadcast to one row per window. If ``model`` exposes ``score_windows`` and
    its ``score_mode`` (default ``"cosine"``) is not ``"cosine"``, that method
    produces the scores directly. Otherwise ``model`` is called to encode the
    windows and the score is the dot product of the query with each encoding.

    ``item["window_type_tensor"]``, when present, is forwarded as ``type_ids``.
    """
    query_row = torch.tensor(
        item["query_embedding"][None, :], dtype=torch.float32, device=device
    )
    windows = item["window_tensor"].to(device)
    query_rows = query_row.expand(windows.shape[0], -1)

    extra = {}
    if "window_type_tensor" in item:
        extra["type_ids"] = item["window_type_tensor"].to(device)

    if hasattr(model, "score_windows"):
        if getattr(model, "score_mode", "cosine") != "cosine":
            return model.score_windows(windows, query=query_rows, **extra)

    encoded = model(windows, query=query_rows, **extra)
    per_window = query_row @ encoded.T
    return per_window.squeeze(0)
97
+
98
+
99
def build_memory_to_windows(windows):
    """Invert ``windows`` into a dict from memory index to window positions.

    Each window is an iterable of memory indices. The returned dict maps
    ``int(memory index)`` to the list of window positions that contain it, in
    encounter order (a memory repeated inside one window is listed repeatedly).
    """
    lookup = {}
    for position, members in enumerate(windows):
        for member in members:
            key = int(member)
            if key not in lookup:
                lookup[key] = []
            lookup[key].append(position)
    return lookup
105
+
106
+
107
def best_window_scores_for_candidates(window_logits, memory_to_windows, candidate_indices, device):
    """Return, per candidate memory, the highest logit among windows containing it.

    Args:
        window_logits: 1-D tensor with one logit per window.
        memory_to_windows: dict from memory index to a list of window positions.
        candidate_indices: iterable of memory indices to look up.
        device: device for tensors created here.

    Returns:
        A 1-D tensor with one value per candidate. Candidates that appear in no
        window receive the minimum of ``window_logits``. An empty candidate
        list yields an empty 1-D tensor in both execution modes.

    With autograd disabled the lookup runs in NumPy and a new float32 tensor is
    returned. With autograd enabled the result is assembled from tensor ops on
    ``window_logits`` so it stays connected to the autograd graph.
    """
    if not torch.is_grad_enabled():
        fallback_value = float(window_logits.min().detach().cpu())
        window_values = window_logits.detach().cpu().numpy()
        values = []
        for memory_idx in candidate_indices:
            touching = memory_to_windows.get(int(memory_idx))
            if touching:
                values.append(float(window_values[touching].max()))
            else:
                values.append(fallback_value)
        return torch.tensor(values, dtype=torch.float32, device=device)

    fallback = window_logits.min()
    values = []
    for memory_idx in candidate_indices:
        touching = memory_to_windows.get(int(memory_idx))
        if not touching:
            values.append(fallback)
            continue
        idx = torch.tensor(touching, dtype=torch.long, device=device)
        values.append(window_logits.index_select(0, idx).max())

    if not values:
        # torch.stack raises on an empty list; mirror the no-grad path, which
        # returns an empty 1-D tensor when there are no candidates.
        return window_logits.new_zeros((0,))
    return torch.stack(values)
130
+
131
+
132
def dca_router_outputs(item, candidate_indices, block_size, device):
    """Give each candidate the query similarity of the memory block it lives in.

    Memories are grouped into consecutive blocks of ``max(1, block_size)``
    entries. Each block is represented by the mean of its embeddings; block
    means and the query are scaled to (approximately) unit length, and their dot
    product is the block score. Every candidate receives its block's score.

    Returns a float32 tensor with one value per candidate, or all zeros when
    ``item["memory_embeddings"]`` has no rows.
    """
    embeddings = item["memory_embeddings"].astype(np.float32)
    query = item["query_embedding"].astype(np.float32)
    total = embeddings.shape[0]
    if total == 0:
        return torch.zeros(len(candidate_indices), dtype=torch.float32, device=device)

    # Block id of every memory: 0,0,...,1,1,... in runs of the block size.
    block_of = np.arange(total) // max(1, block_size)
    block_count = int(block_of.max()) + 1
    centroids = np.zeros((block_count, embeddings.shape[1]), dtype=np.float32)
    for block in range(block_count):
        members = embeddings[block_of == block]
        centroids[block] = members.mean(axis=0)

    # The 1e-8 terms guard against dividing by a zero norm.
    centroid_norms = np.linalg.norm(centroids, axis=1, keepdims=True)
    centroids = centroids / (centroid_norms + 1e-8)
    unit_query = query / (np.linalg.norm(query) + 1e-8)

    per_block = centroids @ unit_query
    picked_blocks = block_of[candidate_indices]
    picked = per_block[picked_blocks].astype(np.float32)
    return torch.tensor(picked, dtype=torch.float32, device=device)
153
+
154
+
155
@lru_cache(maxsize=250_000)
def lexical_token_tuple(text):
    """Tokenize ``text`` into a tuple of lower-cased content tokens.

    The input is converted with ``str()``, lower-cased and split with
    ``TOKEN_RE``. Single-character tokens and entries of ``STOPWORDS`` are
    dropped. Results are memoized, so ``text`` must be hashable.
    """
    lowered = str(text).lower()
    kept = []
    for token in TOKEN_RE.findall(lowered):
        if len(token) <= 1 or token in STOPWORDS:
            continue
        kept.append(token)
    return tuple(kept)
159
+
160
+
161
@lru_cache(maxsize=250_000)
def lexical_signature(text):
    """Return ``(unigrams, bigrams)`` for ``text`` as a pair of frozensets.

    Unigrams are the distinct tokens from ``lexical_token_tuple``; bigrams are
    the distinct pairs of adjacent tokens. Results are memoized.
    """
    tokens = lexical_token_tuple(str(text))
    unigrams = frozenset(tokens)
    # Pair each token with its successor; empty when fewer than two tokens.
    bigrams = frozenset(zip(tokens[:-1], tokens[1:]))
    return unigrams, bigrams
165
+
166
+
167
def ensure_lexical_cache(item):
    """Return ``item["_lexical_cache"]``, creating and attaching it if absent.

    A new cache holds the query's unigram and bigram sets plus one ``None``
    placeholder per entry of ``item["memories"]``; the placeholders are filled
    in lazily by ``lexical_overlap_features``. ``item`` is mutated in place.
    """
    existing = item.get("_lexical_cache")
    if existing is None:
        unigrams, bigrams = lexical_signature(item["query"])
        placeholders = [None for _ in item["memories"]]
        existing = {
            "query_set": unigrams,
            "query_bigrams": bigrams,
            "memory_signatures": placeholders,
        }
        item["_lexical_cache"] = existing
    return existing
180
+
181
+
182
def lexical_overlap_features(item, candidate_indices, device):
    """Compute four token-overlap features between the query and each candidate.

    For every candidate memory (looked up as ``item["memories"][idx]["text"]``)
    the row holds:

    0. fraction of query tokens that also occur in the memory,
    1. Jaccard overlap of the query and memory token sets,
    2. fraction of query bigrams that also occur in the memory,
    3. ``log1p(overlap) / log1p(query token count)``.

    All denominators are clamped to at least 1. Memory signatures are computed
    on first use and stored in the item's lexical cache.

    Returns:
        A float32 tensor of shape ``(number of candidates, 4)``; the second
        dimension is 4 even when there are no candidates.
    """
    cache = ensure_lexical_cache(item)
    query_set = cache["query_set"]
    query_bigrams = cache["query_bigrams"]
    memory_signatures = cache["memory_signatures"]

    query_size = max(1, len(query_set))
    bigram_size = max(1, len(query_bigrams))
    log_query_size = np.log1p(query_size)

    rows = []
    for idx in candidate_indices:
        idx = int(idx)
        signature = memory_signatures[idx]
        if signature is None:
            signature = lexical_signature(item["memories"][idx]["text"])
            memory_signatures[idx] = signature
        memory_set, candidate_bigrams = signature
        shared = len(query_set & memory_set)
        union_size = len(query_set | memory_set)
        rows.append(
            [
                shared / query_size,
                shared / max(1, union_size),
                len(query_bigrams & candidate_bigrams) / bigram_size,
                np.log1p(shared) / log_query_size,
            ]
        )

    features = torch.tensor(rows, dtype=torch.float32, device=device)
    # With no candidates torch.tensor([]) has shape (0,); force (0, 4) so the
    # feature width is the same for every input. No-op for non-empty input.
    return features.reshape(len(rows), 4)
206
+
207
+
208
def candidate_features(
    model,
    item,
    candidate_indices,
    device,
    raw_scores_all=None,
    window_logits=None,
    memory_to_windows=None,
    dca_router_block_size=0,
    lexical_features=False,
):
    """Build the per-candidate feature matrix consumed by the CE-lite scorer.

    Each row is the concatenation of: the query embedding, the candidate
    embedding, their element-wise product, their absolute difference, the
    standardized raw score, the standardized best-window score, a rank feature
    and a position feature (``4 * dim + 4`` columns). One router column is
    appended when ``dca_router_block_size > 0`` and four lexical columns when
    ``lexical_features`` is true.

    Args:
        model: window encoder/scorer; used only when ``window_logits`` is None.
        item: dict providing ``query_embedding``, ``memory_embeddings`` and
            ``memory_ids``, plus whatever the helper functions read.
        candidate_indices: memory indices to featurize (array or list of ints).
        device: device for the produced tensors.
        raw_scores_all: optional precomputed raw score per memory.
        window_logits: optional precomputed logit per window.
        memory_to_windows: optional precomputed memory-to-windows map.
        dca_router_block_size: block size for the router feature; 0 disables it.
        lexical_features: whether to append the lexical overlap features.

    Returns:
        ``(features, raw_scores, best_window)`` where ``features`` is a tensor
        and the other two are NumPy arrays with one entry per candidate.
    """
    q_np = item["query_embedding"].astype(np.float32)
    memory_np = item["memory_embeddings"][candidate_indices].astype(np.float32)
    if raw_scores_all is None:
        raw_scores_all = cosine_scores(q_np, item["memory_embeddings"])
    raw_scores = raw_scores_all[candidate_indices].astype(np.float32)
    raw_norm = normalize_scores(raw_scores)

    if window_logits is None:
        window_logits = window_scores(model, item, device)
    if memory_to_windows is None:
        memory_to_windows = build_memory_to_windows(item["windows"])
    best_window = best_window_scores_for_candidates(
        window_logits,
        memory_to_windows,
        candidate_indices,
        device,
    )
    # Standardize across candidates; the epsilon avoids dividing by a zero std.
    best_window_norm = (best_window - best_window.mean()) / (
        best_window.std(unbiased=False) + 1e-6
    )

    extra_features = []
    if dca_router_block_size > 0:
        router_scores = dca_router_outputs(item, candidate_indices, dca_router_block_size, device)
        router_norm = (router_scores - router_scores.mean()) / (
            router_scores.std(unbiased=False) + 1e-6
        )
        extra_features.append(router_norm[:, None])
    if lexical_features:
        extra_features.append(lexical_overlap_features(item, candidate_indices, device))

    q = torch.tensor(q_np[None, :], dtype=torch.float32, device=device)
    memories = torch.tensor(memory_np, dtype=torch.float32, device=device)
    q_batch = q.expand(memories.shape[0], -1)
    # Evenly spaced 0..1 in candidate order; presumably candidates arrive sorted
    # by raw score, which would make this a rank feature -- TODO confirm.
    raw_rank = torch.linspace(0.0, 1.0, steps=memories.shape[0], dtype=torch.float32, device=device)
    # np.asarray lets a plain list of indices be divided like an array; the
    # other uses of candidate_indices above already accept lists.
    position = torch.tensor(
        np.asarray(candidate_indices) / max(1, len(item["memory_ids"]) - 1),
        dtype=torch.float32,
        device=device,
    )
    raw_tensor = torch.tensor(raw_norm, dtype=torch.float32, device=device)

    features = torch.cat(
        [
            q_batch,
            memories,
            q_batch * memories,
            torch.abs(q_batch - memories),
            raw_tensor[:, None],
            best_window_norm[:, None],
            raw_rank[:, None],
            position[:, None],
            *extra_features,
        ],
        dim=-1,
    )
    return features, raw_scores, best_window.detach().cpu().numpy()
274
+
275
+
276
def score_ce_lite(
    model,
    scorer,
    item,
    candidate_indices,
    device,
    raw_scores_all=None,
    window_logits=None,
    memory_to_windows=None,
    dca_router_block_size=0,
    lexical_features=False,
):
    """Score candidates with ``scorer`` without tracking gradients.

    Features are built by ``candidate_features`` and passed through ``scorer``
    inside ``torch.no_grad()``.

    Returns:
        ``(raw_scores, window_values, ce_lite_scores)``; the first two come
        from ``candidate_features`` and the last is the scorer output moved to
        CPU and converted to a NumPy array.
    """
    feature_options = {
        "raw_scores_all": raw_scores_all,
        "window_logits": window_logits,
        "memory_to_windows": memory_to_windows,
        "dca_router_block_size": dca_router_block_size,
        "lexical_features": lexical_features,
    }
    with torch.no_grad():
        features, raw_scores, window_values = candidate_features(
            model, item, candidate_indices, device, **feature_options
        )
        predictions = scorer(features)
        ce_lite_scores = predictions.cpu().numpy()
    return raw_scores, window_values, ce_lite_scores
302
+
303
+
304
def rerank_candidates(raw_scores_all, candidate_indices, candidate_scores, memory_ids, raw_weight):
    """Return all memory ids, reranked candidates first.

    Candidates are ordered by a blend of their standardized raw score and
    standardized ``candidate_scores``: ``raw_weight`` weights the raw part and
    ``1 - raw_weight`` the other. Every id not already ranked is then appended
    in descending order of raw score.
    """
    raw_part = normalize_scores(raw_scores_all[candidate_indices])
    model_part = normalize_scores(candidate_scores)
    blended = raw_weight * raw_part + (1.0 - raw_weight) * model_part

    # Candidates first, best blended score first.
    ranked = []
    for slot in np.argsort(-blended):
        memory_index = int(candidate_indices[slot])
        ranked.append(memory_ids[memory_index])

    # Then every remaining memory, ordered by raw score alone.
    already_ranked = set(ranked)
    for memory_index in np.argsort(-raw_scores_all):
        memory_id = memory_ids[int(memory_index)]
        if memory_id not in already_ranked:
            ranked.append(memory_id)
    return ranked
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 ConvMemory contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.