convmemory 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
convmemory/ccge.py ADDED
@@ -0,0 +1,391 @@
1
+ """CCGE-LA conflict-aware candidate-set editor.
2
+
3
+ CCGE-LA stands for Low-Amplitude Counterfactual Conflict Graph Editor. It is a
4
+ lightweight editor that runs after ConvMemory and applies a small residual score
5
+ correction when the retrieved candidate set looks conflict-prone.
6
+
7
+ The module is intentionally checkpoint-agnostic. Applications can attach a
8
+ trained editor with ``ConvMemory.attach_ccge_editor`` or load one from disk with
9
+ ``ConvMemory.load_ccge_editor``.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import asdict, dataclass
15
+ from pathlib import Path
16
+ from typing import Sequence
17
+
18
+ import numpy as np
19
+ import torch
20
+ from torch import nn
21
+
22
+ from .hub import resolve_checkpoint_path
23
+ from .scoring import lexical_signature
24
+
25
+
26
# Column order of the per-candidate feature matrix. The position of each name
# is the column index used when the features are stacked and when the editor
# reads individual columns, so entries must not be reordered.
FEATURE_NAMES = [
    # 0-3: per-candidate standardized signals.
    "base_score_z", "dense_score_z", "position_z", "query_overlap_z",
    # 4-5: normalized ranks under each score.
    "base_rank_norm", "dense_rank_norm",
    # 6-9: similarity / overlap relative to the top candidates.
    "sim_to_base_top", "sim_to_dense_top",
    "semantic_density_top16", "token_overlap_to_top",
    # 10-12: position relative to the base top candidate.
    "newer_than_base_top", "older_than_base_top", "abs_pos_gap_top_z",
    # 13-17: query-level values repeated for every candidate.
    "base_margin_1_2", "base_entropy_top16", "conflict_density_top16",
    "time_span_top16", "top_overlap",
]
46
+
47
+
48
@dataclass(frozen=True)
class CCGEConfig:
    """Immutable hyperparameter record for the public CCGE-LA editor.

    The fields mirror the constructor arguments of ``CCGELowAmplitudeEditor``
    and are what gets serialized under the ``"config"`` checkpoint key.
    """

    # Number of input features per candidate (defaults to the feature list size).
    feature_dim: int = len(FEATURE_NAMES)
    # Hidden width of the transformer encoder.
    model_dim: int = 96
    # Number of stacked encoder layers.
    layers: int = 2
    # Attention heads per encoder layer.
    num_heads: int = 4
    # Dropout probability inside the encoder layers.
    dropout: float = 0.08
    # Initial bias of the final gate layer.
    gate_bias: float = -2.0
    # Initial value of the learnable residual scale.
    residual_init: float = 0.35
59
+
60
+
61
@dataclass(frozen=True)
class CCGEFeatureBatch:
    """Candidate ids plus their feature matrix for a single query."""

    # Candidate identifiers, one per feature row.
    candidate_ids: list[str]
    # Feature matrix; presumably [num_candidates, feature_dim] -- confirm with producer.
    features: np.ndarray
67
+
68
+
69
def zscore(values: np.ndarray) -> np.ndarray:
    """Standardize ``values`` to zero mean and unit standard deviation.

    The input is converted to ``float32``. An empty input is returned as-is.
    When the standard deviation is below ``1e-6`` the values are only
    mean-centered, which avoids dividing by a near-zero spread.
    """
    arr = np.asarray(values, dtype=np.float32)
    if not arr.size:
        return arr
    centered = arr - float(arr.mean())
    spread = float(arr.std())
    return centered if spread < 1.0e-6 else centered / spread
77
+
78
+
79
def rank_norm(scores: np.ndarray) -> np.ndarray:
    """Return each score's descending rank scaled to ``[0, 1]``.

    The highest score gets ``0.0`` and the lowest gets ``1.0``; ties keep
    their input order because the sort is stable. A single score maps to 0.
    """
    count = len(scores)
    descending = np.argsort(-scores, kind="mergesort")
    normalized = np.zeros(count, dtype=np.float32)
    # Scatter rank / (count - 1) back to each element's original position.
    normalized[descending] = np.arange(count) / max(1, count - 1)
    return normalized
85
+
86
+
87
def softmax_entropy(values: np.ndarray) -> float:
    """Return the softmax entropy of ``values`` divided by ``log(len(values))``.

    Inputs with at most one element yield ``0.0``. The maximum is subtracted
    before exponentiation for numerical stability, and small epsilons guard
    the normalization and the logarithm.
    """
    if values.size <= 1:
        return 0.0
    logits = np.asarray(values, dtype=np.float32)
    weights = np.exp(logits - float(logits.max()))
    probs = weights / max(float(weights.sum()), 1.0e-8)
    raw_entropy = -(probs * np.log(probs + 1.0e-8)).sum()
    return float(raw_entropy / np.log(len(probs)))
95
+
96
+
97
def normalized_embeddings(embeddings: np.ndarray | None, n: int) -> np.ndarray:
    """Return row-wise L2-normalized embeddings for ``n`` candidates.

    When ``embeddings`` is ``None`` an ``n x n`` identity matrix is returned,
    so each candidate is only similar to itself.

    Raises:
        ValueError: if ``embeddings`` is not 2-D with ``n`` rows.
    """
    if embeddings is None:
        return np.eye(n, dtype=np.float32)
    matrix = np.asarray(embeddings, dtype=np.float32)
    shape_ok = matrix.ndim == 2 and matrix.shape[0] == n
    if not shape_ok:
        raise ValueError("candidate_embeddings must have shape [num_candidates, dim]")
    row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # The epsilon keeps all-zero rows from dividing by zero.
    return matrix / (row_norms + 1.0e-8)
104
+
105
+
106
def query_overlap_scores(query: str, candidate_texts: Sequence[str]) -> np.ndarray:
    """Score each candidate by the fraction of query tokens it shares.

    For every candidate the score is ``|query ∩ candidate| / max(1, |query|)``
    over the token sets produced by ``lexical_signature`` (assumed to return a
    ``(token_set, extra)`` pair -- only the first element is used here).
    The result is a ``float32`` array with one entry per candidate.
    """
    query_tokens, _ = lexical_signature(query)
    denominator = max(1, len(query_tokens))
    signatures = (lexical_signature(str(text)) for text in candidate_texts)
    ratios = [len(query_tokens & tokens) / denominator for tokens, _ in signatures]
    return np.asarray(ratios, dtype=np.float32)
115
+
116
+
117
def token_overlap_to_text(candidate_texts: Sequence[str], top_index: int) -> np.ndarray:
    """Intersection-over-union token overlap of each candidate with the top one.

    ``top_index`` selects the reference candidate. Token sets come from
    ``lexical_signature`` (assumed to return a ``(token_set, extra)`` pair).
    An empty union is treated as size 1, so two empty texts score 0. The
    result is a ``float32`` array with one entry per candidate.
    """
    anchor_tokens, _ = lexical_signature(str(candidate_texts[int(top_index)]))
    signatures = (lexical_signature(str(text)) for text in candidate_texts)
    ratios = [
        len(anchor_tokens & tokens) / max(1, len(anchor_tokens | tokens))
        for tokens, _ in signatures
    ]
    return np.asarray(ratios, dtype=np.float32)
127
+
128
+
129
def build_ccge_features(
    *,
    candidate_ids: Sequence[str],
    convmemory_scores: Sequence[float],
    dense_scores: Sequence[float] | None = None,
    positions: Sequence[float] | None = None,
    candidate_embeddings: np.ndarray | None = None,
    query_overlaps: Sequence[float] | None = None,
    query: str | None = None,
    candidate_texts: Sequence[str] | None = None,
    top_k_density: int = 16,
) -> CCGEFeatureBatch:
    """Build CCGE-LA candidate-set features.

    The features describe the retrieved candidate set. They do not encode
    gold/current/stale labels and are safe to compute at inference time.

    Args:
        candidate_ids: Non-empty sequence of candidate identifiers.
        convmemory_scores: Base score per candidate.
        dense_scores: Optional second score per candidate; defaults to the
            base scores.
        positions: Optional position per candidate; defaults to ``0..n-1``.
        candidate_embeddings: Optional ``[num_candidates, dim]`` embeddings;
            when omitted each candidate is only similar to itself.
        query_overlaps: Optional precomputed query overlap per candidate. It
            takes precedence over ``query``/``candidate_texts``.
        query: Query text, used with ``candidate_texts`` to compute overlaps
            when ``query_overlaps`` is not given.
        candidate_texts: Optional candidate texts, one per candidate.
        top_k_density: Number of top base-ranked candidates used for the
            set-level statistics. Values below zero are treated as zero.

    Returns:
        A ``CCGEFeatureBatch`` whose feature columns follow ``FEATURE_NAMES``.

    Raises:
        ValueError: if ``candidate_ids`` is empty or any per-candidate input
            is not one-dimensional with one entry per candidate.
    """

    ids = [str(x) for x in candidate_ids]
    n = len(ids)
    if n == 0:
        raise ValueError("candidate_ids must not be empty")

    base = np.asarray(convmemory_scores, dtype=np.float32)
    if base.ndim != 1 or base.shape[0] != n:
        raise ValueError("convmemory_scores must match candidate_ids")

    # Validate early: a short text list would otherwise surface as an
    # IndexError or an opaque stacking error further down.
    if candidate_texts is not None and len(candidate_texts) != n:
        raise ValueError("candidate_texts must match candidate_ids")

    dense = np.asarray(dense_scores if dense_scores is not None else base, dtype=np.float32)
    pos = np.asarray(positions if positions is not None else np.arange(n), dtype=np.float32)
    if query_overlaps is not None:
        overlap = np.asarray(query_overlaps, dtype=np.float32)
    elif query is not None and candidate_texts is not None:
        overlap = query_overlap_scores(query, candidate_texts)
    else:
        overlap = np.zeros(n, dtype=np.float32)
    if any(arr.ndim != 1 or arr.shape[0] != n for arr in (dense, pos, overlap)):
        raise ValueError("dense_scores, positions, and query_overlaps must match candidate_ids")

    emb = normalized_embeddings(candidate_embeddings, n)
    base_order = np.argsort(-base, kind="mergesort")
    dense_order = np.argsort(-dense, kind="mergesort")
    top_base = int(base_order[0])
    top_dense = int(dense_order[0])
    # Clamp so a negative value cannot silently slice from the end.
    k = max(0, min(int(top_k_density), n))
    topk = base_order[:k]

    sim_to_base_top = emb @ emb[top_base]
    sim_to_dense_top = emb @ emb[top_dense]
    density = (emb @ emb[topk].T).mean(axis=1) if len(topk) else np.zeros(n, dtype=np.float32)
    if candidate_texts is not None:
        overlap_to_top = token_overlap_to_text(candidate_texts, top_base)
    else:
        # Without texts, every candidate gets the top candidate's query overlap.
        overlap_to_top = np.full(n, float(overlap[top_base]), dtype=np.float32)
    pos_gap = np.abs(pos - pos[top_base])

    # Standardized base scores are needed several times; compute them once.
    base_z = zscore(base)
    sorted_base_z = np.sort(base_z)[::-1]
    margin = float(sorted_base_z[0] - sorted_base_z[1]) if len(sorted_base_z) > 1 else 0.0
    entropy = softmax_entropy(base_z[topk])
    # Share of top-k candidates that resemble the top hit (similarity > 0.45)
    # but sit at a different position.
    conflict_density = (
        float(np.mean((sim_to_base_top[topk] > 0.45) & (pos_gap[topk] > 0)))
        if len(topk)
        else 0.0
    )
    span = float(pos[topk].max() - pos[topk].min()) if len(topk) else 0.0
    full_span = max(1.0, float(pos.max() - pos.min()))
    top_overlap = float(overlap[top_base])

    # Column order must match FEATURE_NAMES.
    features = np.stack(
        [
            base_z,
            zscore(dense),
            zscore(pos),
            zscore(overlap),
            rank_norm(base),
            rank_norm(dense),
            sim_to_base_top.astype(np.float32),
            sim_to_dense_top.astype(np.float32),
            density.astype(np.float32),
            overlap_to_top.astype(np.float32),
            (pos > pos[top_base]).astype(np.float32),
            (pos < pos[top_base]).astype(np.float32),
            zscore(pos_gap),
            np.full(n, margin, dtype=np.float32),
            np.full(n, entropy, dtype=np.float32),
            np.full(n, conflict_density, dtype=np.float32),
            np.full(n, span / full_span, dtype=np.float32),
            np.full(n, top_overlap, dtype=np.float32),
        ],
        axis=1,
    ).astype(np.float32)
    return CCGEFeatureBatch(candidate_ids=ids, features=features)
219
+
220
+
221
class CCGELowAmplitudeEditor(nn.Module):
    """Low-amplitude residual editor over ConvMemory candidate scores.

    The editor encodes the candidate set with a transformer encoder, predicts
    a per-candidate residual, and adds it to the base score (feature column 0)
    after scaling by a query-level sigmoid gate and a clamped learnable scale.
    """

    def __init__(
        self,
        feature_dim: int = len(FEATURE_NAMES),
        *,
        model_dim: int = 96,
        layers: int = 2,
        num_heads: int = 4,
        dropout: float = 0.08,
        gate_bias: float = -2.0,
        residual_init: float = 0.35,
    ):
        """Build the editor.

        Args:
            feature_dim: Number of input features per candidate.
            model_dim: Encoder width; must be divisible by ``num_heads``.
            layers: Number of transformer encoder layers.
            num_heads: Attention heads per layer.
            dropout: Dropout probability inside the encoder layers.
            gate_bias: Initial bias of the final gate layer.
            residual_init: Initial value of the learnable residual scale.

        Raises:
            ValueError: if ``model_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        self.config = CCGEConfig(
            feature_dim=int(feature_dim),
            model_dim=int(model_dim),
            layers=int(layers),
            num_heads=int(num_heads),
            dropout=float(dropout),
            gate_bias=float(gate_bias),
            residual_init=float(residual_init),
        )
        self.trained_embedding_model_name = None
        self.in_proj = nn.Sequential(
            nn.Linear(feature_dim, model_dim),
            nn.GELU(),
            nn.LayerNorm(model_dim),
        )
        enc = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=model_dim * 3,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc, num_layers=layers)
        self.residual = nn.Sequential(
            nn.Linear(model_dim, model_dim),
            nn.GELU(),
            nn.Dropout(0.05),
            nn.Linear(model_dim, 1),
        )
        # The gate sees the pooled encoding plus 7 query-level statistics.
        self.gate = nn.Sequential(nn.Linear(model_dim + 7, 64), nn.GELU(), nn.Linear(64, 1))
        self.residual_scale = nn.Parameter(torch.tensor(float(residual_init)))
        # Zero weights + constant bias: the gate starts at sigmoid(gate_bias)
        # for every query, independent of its input.
        nn.init.zeros_(self.gate[-1].weight)
        nn.init.constant_(self.gate[-1].bias, gate_bias)

    def forward(self, features: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(edited_scores, gate)`` for a batch of candidate sets.

        ``features`` is indexed as ``[batch, candidates, feature]`` (pooling
        runs over ``dim=1``) with columns ordered as in ``FEATURE_NAMES``.
        """
        base = features[..., 0]
        h = self.encoder(self.in_proj(features))
        residual = self.residual(h).squeeze(-1)
        pooled = h.mean(dim=1)
        # Query-level state; the column indices follow FEATURE_NAMES.
        state = torch.stack(
            [
                features[..., 13].mean(dim=1),  # base_margin_1_2
                features[..., 14].mean(dim=1),  # base_entropy_top16
                features[..., 15].mean(dim=1),  # conflict_density_top16
                features[..., 16].mean(dim=1),  # time_span_top16
                features[..., 17].mean(dim=1),  # top_overlap
                (features[..., 4] < 0.05).float().mean(dim=1),  # share with base_rank_norm < 0.05
                features[..., 8].max(dim=1).values,  # max semantic_density_top16
            ],
            dim=-1,
        )
        gate = torch.sigmoid(self.gate(torch.cat([pooled, state], dim=-1))).squeeze(-1)
        # Keep the residual amplitude inside a fixed range.
        scale = torch.clamp(self.residual_scale, 0.05, 2.0)
        scores = base + gate.unsqueeze(-1) * scale * residual
        return scores, gate

    @torch.no_grad()
    def edit_batch(
        self,
        batch: CCGEFeatureBatch,
        *,
        device: str | torch.device | None = None,
    ) -> tuple[np.ndarray, float]:
        """Return edited scores and the query-level gate for one feature batch.

        Note that this switches the module to eval mode and moves it to
        ``device`` (defaults to the device of the module's parameters).
        """

        if device is None:
            device = next(self.parameters()).device
        self.eval()
        x = torch.tensor(batch.features, dtype=torch.float32, device=device).unsqueeze(0)
        scores, gate = self.to(device)(x)
        return scores.detach().cpu().numpy()[0], float(gate.detach().cpu().numpy()[0])

    def save_pretrained(self, path: str | Path) -> None:
        """Save a CCGE-LA editor checkpoint.

        A ``path`` with a file suffix is written as that file; otherwise it is
        treated as a directory and the checkpoint is stored as ``ccge_la.pt``
        inside it. Missing parent directories are created.
        """

        path = Path(path)
        if path.suffix:
            path.parent.mkdir(parents=True, exist_ok=True)
            target = path
        else:
            path.mkdir(parents=True, exist_ok=True)
            target = path / "ccge_la.pt"
        torch.save(
            {
                "format": "convmemory-ccge-la",
                "version": 1,
                "config": asdict(self.config),
                "state_dict": self.state_dict(),
                "trained_embedding_model_name": getattr(
                    self,
                    "trained_embedding_model_name",
                    None,
                ),
            },
            target,
        )

    @classmethod
    def from_pretrained(
        cls,
        path: str | Path,
        *,
        device: str | torch.device = "cpu",
        strict: bool = True,
    ) -> "CCGELowAmplitudeEditor":
        """Load a CCGE-LA editor checkpoint from disk or Hugging Face Hub.

        Args:
            path: Checkpoint file, directory containing ``ccge_la.pt``, or a
                value that ``resolve_checkpoint_path`` can resolve.
            device: Device the loaded model is moved to.
            strict: Forwarded to ``load_state_dict``.

        Returns:
            The loaded editor in eval mode.
        """

        path = resolve_checkpoint_path(path)
        source = path / "ccge_la.pt" if path.is_dir() else path
        # SECURITY: checkpoints may be downloaded from a remote repository, so
        # refuse arbitrary pickled objects. The payload written by
        # ``save_pretrained`` holds only tensors and plain Python containers,
        # which the restricted unpickler accepts. Requires torch >= 1.13.
        payload = torch.load(source, map_location="cpu", weights_only=True)
        config = payload.get("config", {})
        model = cls(**config)
        # A bare state dict (no wrapper keys) is accepted as well.
        state_dict = payload.get("state_dict", payload)
        model.load_state_dict(state_dict, strict=strict)
        model.trained_embedding_model_name = payload.get("trained_embedding_model_name")
        return model.to(device).eval()
356
+
357
+
358
def multi_positive_retrieval_loss(scores: torch.Tensor, gold_mask: torch.Tensor) -> torch.Tensor:
    """Retrieval cross-entropy for one or more positive candidates.

    For each query the loss is ``logsumexp(all scores) - logsumexp(gold
    scores)`` along the last dimension, averaged over queries.

    Args:
        scores: Floating-point candidate scores, candidates on the last axis.
        gold_mask: Mask of the same shape marking positives; any dtype whose
            non-zero entries mean "positive" is accepted.

    Returns:
        Scalar loss. Queries without any positive are excluded from the
        average (they carry no learning signal and would otherwise dominate
        the mean); if no query has a positive the loss is zero.
    """

    mask = gold_mask.to(torch.bool)
    # Use the dtype's own minimum so the fill is representable in half precision.
    fill_value = torch.finfo(scores.dtype).min
    all_lse = torch.logsumexp(scores, dim=-1)
    gold_lse = torch.logsumexp(scores.masked_fill(~mask, fill_value), dim=-1)
    per_query = all_lse - gold_lse
    has_gold = mask.any(dim=-1)
    # Zero out queries without positives and average over the remaining ones.
    total = per_query.masked_fill(~has_gold, 0.0).sum()
    return total / has_gold.sum().clamp(min=1)
365
+
366
+
367
@torch.no_grad()
def rank_candidates(
    editor: CCGELowAmplitudeEditor,
    batch: CCGEFeatureBatch,
    *,
    device: str | torch.device = "cpu",
) -> list[tuple[str, float]]:
    """Return ``(candidate_id, edited_score)`` pairs, best score first.

    Scores come from ``editor.edit_batch``; the gate value it returns is
    ignored. Ties keep their original candidate order (stable sort).
    """

    edited, _gate = editor.edit_batch(batch, device=device)
    ranking = np.argsort(-edited, kind="mergesort")
    ids = batch.candidate_ids
    return [(ids[idx], float(edited[idx])) for idx in map(int, ranking)]
379
+
380
+
381
# Public names exported by ``from convmemory.ccge import *``.
__all__ = [
    # Constants and data containers.
    "FEATURE_NAMES", "CCGEConfig", "CCGEFeatureBatch",
    # Editor model.
    "CCGELowAmplitudeEditor",
    # Feature building, loss, and ranking helpers.
    "build_ccge_features", "multi_positive_retrieval_loss",
    "query_overlap_scores", "rank_candidates", "token_overlap_to_text",
]
convmemory/encoder.py ADDED
@@ -0,0 +1,150 @@
1
+ import torch
2
+
3
+
4
class MixerConvMemoryEncoder(torch.nn.Module):
    """Lightweight temporal encoder over a short memory window.

    The input shape is `[batch, window, embedding_dim]`. The query embedding is
    used both for feature construction and query-aware pooling.

    Args:
        dim: Embedding dimension of the inputs and of the output vector.
        window_size: Number of positions the token-mixing MLP operates on.
            Shorter inputs are zero-padded for mixing; longer inputs only get
            token mixing on their first ``window_size`` positions.
        kernel_size: Kernel size of the depthwise temporal convolution.
        hidden_dim: Width of the internal representation.
        token_mlp_dim: Hidden width of the token-mixing MLP.
        channel_mlp_dim: Hidden width of the channel-mixing MLP.
        type_vocab_size: When non-zero, size of an optional type embedding
            table added to the inputs.
        output_mode: ``"residual"`` adds the gated head output to a
            query-weighted average of the inputs; any other value returns the
            head output alone.
        output_gate_init: Initial value of the output gate.
        score_mode: ``"cosine"`` makes ``score_windows`` return plain cosine
            scores; any other value adds a gated learned correction.
        score_gate_init: Initial value of the score-correction gate.
    """

    def __init__(
        self,
        dim,
        window_size=5,
        kernel_size=3,
        hidden_dim=256,
        token_mlp_dim=32,
        channel_mlp_dim=512,
        type_vocab_size=0,
        output_mode="residual",
        output_gate_init=0.1,
        score_mode="cosine",
        score_gate_init=0.1,
    ):
        super().__init__()
        self.window_size = window_size
        self.output_mode = output_mode
        self.score_mode = score_mode
        self.type_embedding = None
        if type_vocab_size:
            self.type_embedding = torch.nn.Embedding(type_vocab_size, dim)

        # Input features are [x, x * query, |x - query|], hence dim * 3.
        self.input_proj = torch.nn.Sequential(
            torch.nn.Linear(dim * 3, hidden_dim),
            torch.nn.GELU(),
            torch.nn.LayerNorm(hidden_dim),
        )
        self.conv_norm = torch.nn.LayerNorm(hidden_dim)
        self.depthwise_conv = torch.nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=hidden_dim,
        )
        self.pointwise = torch.nn.Linear(hidden_dim, hidden_dim)
        self.conv_gate = torch.nn.Parameter(torch.tensor(0.1))

        self.token_norm = torch.nn.LayerNorm(window_size)
        self.token_mlp = torch.nn.Sequential(
            torch.nn.Linear(window_size, token_mlp_dim),
            torch.nn.GELU(),
            torch.nn.Linear(token_mlp_dim, window_size),
        )
        self.token_gate = torch.nn.Parameter(torch.tensor(0.1))

        self.channel_norm = torch.nn.LayerNorm(hidden_dim)
        self.channel_mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, channel_mlp_dim),
            torch.nn.GELU(),
            torch.nn.Linear(channel_mlp_dim, hidden_dim),
        )
        self.channel_gate = torch.nn.Parameter(torch.tensor(0.1))

        self.query_proj = torch.nn.Linear(dim, hidden_dim)
        self.attn_x = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.attn_q = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.attn_v = torch.nn.Linear(hidden_dim, 1, bias=False)
        # Head input is [pooled, q, pooled * q, |pooled - q|], hence hidden_dim * 4.
        self.output_head = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 4, dim),
            torch.nn.LayerNorm(dim),
        )
        self.output_gate = torch.nn.Parameter(torch.tensor(float(output_gate_init)))
        self.score_head = torch.nn.Sequential(
            torch.nn.Linear(dim * 4, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, 1),
        )
        self.score_gate = torch.nn.Parameter(torch.tensor(float(score_gate_init)))

    def _token_mix(self, h):
        """Mix information across positions; output has the same shape as ``h``.

        The token MLP works on exactly ``window_size`` positions, so shorter
        inputs are zero-padded before mixing and cropped afterwards. For longer
        inputs only the first ``window_size`` positions are mixed and the
        remaining positions receive a zero update, so the residual addition in
        ``forward`` stays shape-compatible.
        """
        length = h.shape[1]
        window = self.window_size
        if length < window:
            pad = torch.zeros(
                h.shape[0],
                window - length,
                h.shape[2],
                dtype=h.dtype,
                device=h.device,
            )
            h_for_mix = torch.cat([h, pad], dim=1)
        else:
            h_for_mix = h[:, :window]

        # LayerNorm/MLP act over the position axis, so move it last.
        mixed = h_for_mix.transpose(1, 2)
        mixed = self.token_mlp(self.token_norm(mixed)).transpose(1, 2)
        if length > window:
            tail = torch.zeros(
                mixed.shape[0],
                length - window,
                mixed.shape[2],
                dtype=mixed.dtype,
                device=mixed.device,
            )
            return torch.cat([mixed, tail], dim=1)
        return mixed[:, :length]

    def forward(self, x, query=None, type_ids=None):
        """Encode a window into one L2-normalized vector per batch element.

        Args:
            x: Inputs of shape `[batch, window, embedding_dim]`.
            query: Optional query embeddings; defaults to the mean of the
                (type-augmented) inputs over the window axis.
            type_ids: Optional ids for the type embedding, used only when the
                encoder was built with a type vocabulary.
        """
        base_x = x
        if self.type_embedding is not None and type_ids is not None:
            x = x + self.type_embedding(type_ids)
            base_x = x
        if query is None:
            query = x.mean(dim=1)

        # Query-weighted average of the inputs, used as the residual base.
        query_norm = torch.nn.functional.normalize(query, dim=-1)
        base_norm = torch.nn.functional.normalize(base_x, dim=-1)
        base_scores = (base_norm * query_norm[:, None, :]).sum(dim=-1)
        base_weights = torch.softmax(base_scores, dim=1)
        base = (base_x * base_weights[:, :, None]).sum(dim=1)

        query_per_turn = query[:, None, :].expand(-1, x.shape[1], -1)
        features = torch.cat([x, x * query_per_turn, torch.abs(x - query_per_turn)], dim=-1)
        h = self.input_proj(features)

        # Gated residual blocks: temporal conv, token mixing, channel mixing.
        conv_in = self.conv_norm(h).transpose(1, 2)
        conv_out = self.depthwise_conv(conv_in).transpose(1, 2)
        h = h + self.conv_gate * self.pointwise(torch.nn.functional.gelu(conv_out))

        h = h + self.token_gate * self._token_mix(h)
        h = h + self.channel_gate * self.channel_mlp(self.channel_norm(h))

        # Additive attention pooling conditioned on the projected query.
        qh = self.query_proj(query)
        attn = self.attn_v(torch.tanh(self.attn_x(h) + self.attn_q(qh)[:, None, :])).squeeze(-1)
        weights = torch.softmax(attn, dim=1)
        pooled = (h * weights[:, :, None]).sum(dim=1)

        out = self.output_head(
            torch.cat([pooled, qh, pooled * qh, torch.abs(pooled - qh)], dim=-1)
        )
        if self.output_mode == "residual":
            out = base + self.output_gate * out
        return torch.nn.functional.normalize(out, dim=-1)

    def score_windows(self, x, query=None, type_ids=None):
        """Score each window against the query.

        Returns the cosine between the encoded window and the normalized
        query; unless ``score_mode`` is ``"cosine"``, a gated ``tanh``
        correction from the score head is added.
        """
        vectors = self.forward(x, query=query, type_ids=type_ids)
        if query is None:
            # NOTE(review): forward() derives its default query after adding
            # type embeddings, while this uses the raw inputs -- confirm the
            # difference is intended.
            query = x.mean(dim=1)
        query_norm = torch.nn.functional.normalize(query, dim=-1)
        cosine = (vectors * query_norm).sum(dim=-1)
        if self.score_mode == "cosine":
            return cosine

        features = torch.cat(
            [vectors, query_norm, vectors * query_norm, torch.abs(vectors - query_norm)],
            dim=-1,
        )
        correction = torch.tanh(self.score_head(features).squeeze(-1))
        return cosine + self.score_gate * correction
convmemory/hub.py ADDED
@@ -0,0 +1,45 @@
1
+ """Optional Hugging Face Hub path resolution helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ try:
8
+ from huggingface_hub import snapshot_download as _hf_snapshot_download
9
+ except Exception: # pragma: no cover - exercised when optional dep is absent
10
+ _hf_snapshot_download = None
11
+
12
+
13
def looks_like_hub_id(path: str | Path) -> bool:
    """Return whether a missing path looks like a `namespace/repo` Hub id.

    Backslashes are treated as forward slashes and surrounding whitespace is
    ignored. Anything containing a colon, or starting with ``/``, ``./``,
    ``../`` or ``~``, is rejected; what remains must be exactly two non-empty
    segments separated by a single slash.
    """

    normalized = str(path).replace("\\", "/").strip()
    if not normalized:
        return False
    if "://" in normalized or ":" in normalized:
        return False
    if normalized.startswith(("/", "./", "../", "~")):
        return False
    namespace, slash, repo = normalized.partition("/")
    # Exactly one slash with a non-empty segment on each side.
    return bool(slash) and bool(namespace) and bool(repo) and "/" not in repo
23
+
24
+
25
def resolve_checkpoint_path(path: str | Path, *, repo_type: str = "model") -> Path:
    """Resolve a local checkpoint path or download a Hugging Face Hub repo id.

    An existing path is returned unchanged, as is a missing path that does not
    look like a Hub id. Otherwise the repository snapshot is downloaded and
    its local directory is returned.

    Raises:
        ValueError: if the path looks like a Hub id but ``huggingface_hub`` is
            not installed, or if the download fails (the original error is
            chained as the cause).
    """

    local = Path(path)
    if local.exists() or not looks_like_hub_id(path):
        return local
    if _hf_snapshot_download is None:
        raise ValueError(
            "Checkpoint path does not exist and looks like a Hugging Face Hub "
            "repo id, but `huggingface_hub` is not installed. Install it with "
            "`pip install huggingface_hub` or pass a local checkpoint path."
        )
    try:
        snapshot_dir = _hf_snapshot_download(repo_id=str(path), repo_type=repo_type)
        return Path(snapshot_dir)
    except Exception as exc:
        raise ValueError(
            f"Could not download Hugging Face Hub checkpoint repo '{path}'. "
            "Pass a local checkpoint path or verify repo access."
        ) from exc
convmemory/metrics.py ADDED
@@ -0,0 +1,14 @@
1
def recall_at_k(ranked_ids, gold_ids, k):
    """Return the share of gold ids found in the first ``k`` ranked ids.

    The denominator is ``len(gold_ids)`` (at least 1), so an empty gold list
    yields ``0.0``.
    """
    found = set(ranked_ids[:k]).intersection(set(gold_ids))
    denominator = len(gold_ids) or 1
    return len(found) / denominator
3
+
4
+
5
def hit_at_k(ranked_ids, gold_ids, k):
    """Return ``1.0`` if any gold id is in the first ``k`` ranked ids, else ``0.0``."""
    top_k = set(ranked_ids[:k])
    return 0.0 if top_k.isdisjoint(set(gold_ids)) else 1.0
7
+
8
+
9
def mrr(ranked_ids, gold_ids):
    """Return the reciprocal rank (1-based) of the first gold id, or ``0.0``."""
    relevant = set(gold_ids)
    reciprocal_ranks = (
        1.0 / position
        for position, item_id in enumerate(ranked_ids, start=1)
        if item_id in relevant
    )
    # Only the first match matters; fall back to 0.0 when nothing matches.
    return next(reciprocal_ranks, 0.0)
convmemory/models.py ADDED
@@ -0,0 +1,31 @@
1
+ from .encoder import MixerConvMemoryEncoder
2
+ from .scoring import CELiteScorer
3
+
4
+
5
def build_default_components(
    embedding_dim,
    window_size=5,
    kernel_size=3,
    hidden_dim=256,
    token_mlp_dim=32,
    channel_mlp_dim=512,
    extra_scalar_features=5,
    device="cpu",
):
    """Create the default encoder and scorer pair on ``device``.

    The encoder is built in residual output mode with cosine scoring and an
    output gate initialized to 0.1. Returns ``(conv_model, scorer)``.
    """
    encoder_options = dict(
        window_size=window_size,
        kernel_size=kernel_size,
        hidden_dim=hidden_dim,
        token_mlp_dim=token_mlp_dim,
        channel_mlp_dim=channel_mlp_dim,
        output_mode="residual",
        output_gate_init=0.1,
        score_mode="cosine",
    )
    conv_model = MixerConvMemoryEncoder(embedding_dim, **encoder_options)
    conv_model = conv_model.to(device)

    scorer = CELiteScorer(
        embedding_dim,
        hidden_dim=hidden_dim,
        extra_scalar_features=extra_scalar_features,
    )
    scorer = scorer.to(device)
    return conv_model, scorer