convmemory 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
convmemory/api.py ADDED
@@ -0,0 +1,733 @@
1
import json
import logging
from pathlib import Path
from typing import Iterable, Optional, Sequence
import warnings

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from .ccge import CCGELowAmplitudeEditor, build_ccge_features
from .hub import resolve_checkpoint_path
from .models import build_default_components
from .reranker import ConvMemoryReranker, RerankConfig, RerankResult
from .scoring import cosine_scores, lexical_signature
16
+
17
class ConvMemory:
    """User-facing ConvMemory reranker.

    Bundles a `ConvMemoryReranker` (convolutional model plus scorer), an
    optional sentence-transformer encoder, and an optional CCGE-LA editor
    behind a single object.

    Use `from_pretrained` for normal usage. `from_config` is mainly for
    development and examples because it creates randomly initialized weights.

    The text entry points (`rerank`, `retrieve`, `expand_context`) encode
    their inputs and therefore need an attached encoder; the `*_embeddings`
    variants work on precomputed embeddings.
    """
23
+
24
def __init__(
    self,
    conv_model,
    scorer,
    config=None,
    device="cpu",
    embedding_model=None,
    embedding_model_name=None,
    model_config=None,
    ccge_editor=None,
):
    """Wire a ConvMemory instance around already-built components.

    Stores the encoder, its name and the model/rerank configuration, wraps
    `conv_model` and `scorer` in a `ConvMemoryReranker`, switches both
    modules to eval mode, and attaches `ccge_editor` when one is given.

    A falsy `config` is replaced by a default `RerankConfig()` and a falsy
    `model_config` by an empty dict.
    """
    rerank_config = config if config else RerankConfig()

    self.device = device
    self.config = rerank_config
    self.embedding_model_name = embedding_model_name
    self.embedding_model = embedding_model
    self.model_config = model_config if model_config else {}
    # Start detached; `attach_ccge_editor` below validates and installs it.
    self.ccge_editor = None

    reranker = ConvMemoryReranker(
        conv_model=conv_model,
        scorer=scorer,
        config=rerank_config,
        device=device,
    )
    self.reranker = reranker
    for module in (reranker.conv_model, reranker.scorer):
        module.eval()

    if ccge_editor is None:
        return
    self.attach_ccge_editor(ccge_editor)
51
+
52
@classmethod
def from_config(
    cls,
    embedding_dim,
    device="cpu",
    embedding_model=None,
    config=None,
    ccge_editor=None,
    **model_kwargs,
):
    """Build a ConvMemory with freshly initialized (random) weights.

    Meant for development, tests, or custom training code. Leave
    `embedding_model` as `None` to work only with precomputed embeddings,
    or pass a model name to attach a local sentence-transformer encoder.

    Recognized `model_kwargs` are `window_size`, `kernel_size`,
    `hidden_dim`, `token_mlp_dim`, `channel_mlp_dim` and
    `extra_scalar_features`. When `extra_scalar_features` is not given it
    is derived from the rerank config: 1 is added if
    `dca_router_block_size > 0` and 4 if `lexical_features` is truthy.
    """
    rerank_config = config if config else RerankConfig()

    scalar_count = model_kwargs.get("extra_scalar_features")
    if scalar_count is None:
        router_enabled = rerank_config.dca_router_block_size > 0
        lexical_enabled = rerank_config.lexical_features
        scalar_count = (1 if router_enabled else 0) + (4 if lexical_enabled else 0)

    # Insertion order matters only for readability of the saved config.
    model_config = {"embedding_dim": int(embedding_dim)}
    size_defaults = (
        ("window_size", 5),
        ("kernel_size", 3),
        ("hidden_dim", 256),
        ("token_mlp_dim", 32),
        ("channel_mlp_dim", 512),
    )
    for key, fallback in size_defaults:
        model_config[key] = int(model_kwargs.get(key, fallback))
    model_config["extra_scalar_features"] = int(scalar_count)

    conv_model, scorer = build_default_components(device=device, **model_config)

    embedder = (
        SentenceTransformer(embedding_model, device=device)
        if embedding_model
        else None
    )
    return cls(
        conv_model=conv_model,
        scorer=scorer,
        config=rerank_config,
        device=device,
        embedding_model=embedder,
        embedding_model_name=embedding_model,
        model_config=model_config,
        ccge_editor=ccge_editor,
    )
100
+
101
@classmethod
def from_pretrained(
    cls,
    path,
    device="cpu",
    embedding_model=None,
    load_ccge: bool = False,
):
    """Load a ConvMemory checkpoint from disk or Hugging Face Hub.

    `path` is resolved by `resolve_checkpoint_path`; the resolved directory
    must contain `config.json` and `model.pt`.

    `embedding_model` may be `None` to use the name stored in the
    checkpoint metadata, a string to override the encoder, or `False` to
    skip encoder loading for precomputed embeddings. When no encoder is
    loaded, the instance records `embedding_model_name=None` so a
    placeholder such as `False` never ends up in saved metadata or on an
    attached editor.

    `load_ccge=True` auto-attaches `ccge_la.pt` when present; the default
    is `False` so CCGE-LA remains explicit opt-in.
    """
    path = resolve_checkpoint_path(path)
    metadata = json.loads((path / "config.json").read_text(encoding="utf-8"))
    rerank_config = RerankConfig(**metadata["rerank_config"])
    model_config = metadata["model_config"]
    conv_model, scorer = build_default_components(device=device, **model_config)

    # SECURITY NOTE(review): torch.load unpickles the file, and `path` may
    # point at a snapshot downloaded from the Hub. Consider passing
    # `weights_only=True` once the minimum supported torch version allows it.
    state = torch.load(path / "model.pt", map_location="cpu")
    conv_model.load_state_dict(state["conv_model"])
    scorer.load_state_dict(state["scorer"])
    conv_model.to(device).eval()
    scorer.to(device).eval()

    embedding_model_name = embedding_model
    if embedding_model_name is None:
        embedding_model_name = metadata.get("embedding_model")
    # `False` (skip the encoder) and empty names are not backbone names;
    # collapse them to None instead of storing them on the instance.
    embedding_model_name = embedding_model_name or None

    embedder = None
    if embedding_model_name is not None:
        embedder = SentenceTransformer(embedding_model_name, device=device)

    ccge_editor = None
    ccge_path = path / "ccge_la.pt"
    if load_ccge and ccge_path.exists():
        ccge_editor = CCGELowAmplitudeEditor.from_pretrained(ccge_path, device=device)

    model = cls(
        conv_model=conv_model,
        scorer=scorer,
        config=rerank_config,
        device=device,
        embedding_model=embedder,
        embedding_model_name=embedding_model_name,
        model_config=model_config,
        ccge_editor=ccge_editor,
    )
    if ccge_editor is not None:
        # Library code reports through logging rather than writing to stdout.
        logging.getLogger(__name__).info(
            "[ConvMemory] auto-attached CCGE-LA editor from %s", ccge_path
        )
    return model
155
+
156
def save_pretrained(self, path):
    """Write this instance to the directory `path`.

    Creates the directory (and parents) if needed, then writes
    `config.json` with the format marker, encoder name, model config and
    rerank config, `model.pt` with the conv-model and scorer state dicts,
    and `ccge_la.pt` when a CCGE-LA editor is attached.
    """
    target_dir = Path(path)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_text = json.dumps(
        {
            "format": "convmemory",
            "version": 1,
            "embedding_model": self.embedding_model_name,
            "model_config": self.model_config,
            "rerank_config": self.config.__dict__,
        },
        indent=2,
        sort_keys=True,
    )
    (target_dir / "config.json").write_text(config_text, encoding="utf-8")

    reranker = self.reranker
    weights = {
        "conv_model": reranker.conv_model.state_dict(),
        "scorer": reranker.scorer.state_dict(),
    }
    torch.save(weights, target_dir / "model.pt")

    editor = self.ccge_editor
    if editor is not None:
        editor.save_pretrained(target_dir / "ccge_la.pt")
181
+
182
def attach_ccge_editor(self, editor):
    """Install a trained CCGE-LA editor on this instance and return `self`.

    Raises `TypeError` unless `editor` is a `CCGELowAmplitudeEditor`.
    Emits a `UserWarning` when this instance and the editor both declare an
    embedding backbone name and the names differ. An editor that declares
    no backbone is stamped with this instance's backbone name. The editor
    is moved to this instance's device and put in eval mode.
    """
    if isinstance(editor, CCGELowAmplitudeEditor) is False:
        raise TypeError("editor must be a CCGELowAmplitudeEditor")

    editor_backbone = getattr(editor, "trained_embedding_model_name", None)
    both_declared = bool(self.embedding_model_name) and bool(editor_backbone)
    if both_declared and self.embedding_model_name != editor_backbone:
        # stacklevel=2 points the warning at the caller of this method.
        warnings.warn(
            "CCGE editor was trained on backbone "
            f"{editor_backbone} but is being attached to ConvMemory with "
            f"backbone {self.embedding_model_name}; quality may degrade.",
            UserWarning,
            stacklevel=2,
        )

    editor_is_unlabelled = editor_backbone is None
    if editor_is_unlabelled and self.embedding_model_name is not None:
        editor.trained_embedding_model_name = self.embedding_model_name

    placed = editor.to(self.device)
    self.ccge_editor = placed.eval()
    return self
205
+
206
def load_ccge_editor(self, path, strict: bool = True):
    """Load a CCGE-LA editor checkpoint, attach it, and return `self`.

    `path` and `strict` are handed to
    `CCGELowAmplitudeEditor.from_pretrained` together with this instance's
    device. The loaded editor then goes through `attach_ccge_editor`, so
    the same backbone-mismatch warning applies.
    """
    loaded_editor = CCGELowAmplitudeEditor.from_pretrained(
        path, device=self.device, strict=strict
    )
    attached = self.attach_ccge_editor(loaded_editor)
    return attached
221
+
222
def encode(self, texts):
    """Embed `texts` with the attached encoder and return a float32 array.

    `texts` is materialized into a list before encoding, and the encoder is
    asked for normalized NumPy embeddings without a progress bar.

    Raises `ValueError` when no encoder is attached; use
    `rerank_embeddings` for precomputed embeddings in that case.
    """
    encoder = self.embedding_model
    if encoder is None:
        raise ValueError(
            "No embedding model is attached. Pass embeddings directly with "
            "`rerank_embeddings`, or load with `from_pretrained(..., embedding_model=...)`."
        )

    vectors = encoder.encode(
        list(texts),
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return vectors.astype(np.float32)
240
+
241
def prewarm_lexical(self, memories: Iterable):
    """Run `lexical_signature` once over every memory text.

    Optional warm-up for stores that are queried repeatedly. The return
    value of `lexical_signature` is discarded here, so any benefit comes
    from caching inside that helper (presumably memoized; confirm in
    `scoring`).
    """
    texts = self._parse_memories(memories)[1]
    for memory_text in texts:
        lexical_signature(memory_text)
250
+
251
def rerank(
    self,
    query: str,
    memories: Iterable,
    top_k: Optional[int] = None,
    candidate_ids: Optional[Iterable[str]] = None,
    window_mode=None,
    editor=None,
    ccge_top_n: Optional[int] = None,
):
    """Rerank text memories for `query` and return the ranked results.

    The query and all memory texts are encoded in a single `encode` call.
    When `candidate_ids` is given, each id is stringified and mapped to its
    position in `memories`; ids that match no memory are dropped. Scoring,
    `window_mode`, `editor` and `ccge_top_n` are delegated to
    `rerank_embeddings`. The result is cut to `top_k` entries unless
    `top_k` is `None`.
    """
    memory_ids, memory_texts = self._parse_memories(memories)

    # Row 0 is the query; the remaining rows line up with `memory_ids`.
    encoded = self.encode([query, *memory_texts])
    query_embedding, memory_embeddings = encoded[0], encoded[1:]

    candidate_indices = None
    if candidate_ids is not None:
        position_of = {
            memory_id: position for position, memory_id in enumerate(memory_ids)
        }
        candidate_indices = []
        for raw_id in candidate_ids:
            key = str(raw_id)
            if key in position_of:
                candidate_indices.append(position_of[key])

    ranked = self.rerank_embeddings(
        query_embedding=query_embedding,
        memory_embeddings=memory_embeddings,
        memory_ids=memory_ids,
        memory_texts=memory_texts,
        query=query,
        candidate_indices=candidate_indices,
        window_mode=window_mode,
        editor=editor,
        ccge_top_n=ccge_top_n,
    )
    if top_k is None:
        return ranked
    return ranked[:top_k]
293
+
294
def retrieve(
    self,
    query: str,
    memories: Iterable,
    top_k: Optional[int] = 10,
    mode: str = "rerank",
    candidate_ids: Optional[Iterable[str]] = None,
    protected_k: int = 10,
    context_budget: Optional[int] = None,
    expansion_policy: str = "balanced",
    expert_rankers: Optional[Sequence["ConvMemory"]] = None,
    window_mode=None,
    editor=None,
    ccge_top_n: Optional[int] = None,
):
    """Dispatch to `rerank` or `expand_context` depending on `mode`.

    `mode` is lower-cased and stripped before matching. `"rerank"` returns
    the plain ranking cut to `top_k`. `"expand"`, `"context"` and
    `"expand_context"` build an expanded context instead; any other value
    raises `ValueError`.

    In expansion mode an explicit `context_budget` is used as given.
    Otherwise the budget is `top_k`, unless `top_k` is `None` or does not
    exceed `protected_k`, in which case it is `protected_k + 5`.
    `window_mode`, `editor` and `ccge_top_n` are forwarded unchanged.
    """
    normalized_mode = mode.lower().strip()
    scoring_options = {
        "window_mode": window_mode,
        "editor": editor,
        "ccge_top_n": ccge_top_n,
    }

    if normalized_mode == "rerank":
        return self.rerank(
            query=query,
            memories=memories,
            top_k=top_k,
            candidate_ids=candidate_ids,
            **scoring_options,
        )

    if normalized_mode not in ("expand", "context", "expand_context"):
        raise ValueError("mode must be either 'rerank' or 'expand'")

    if context_budget is not None:
        budget = context_budget
    elif top_k is None or top_k <= protected_k:
        # Always leave room for expansion beyond the protected prefix.
        budget = protected_k + 5
    else:
        budget = top_k

    return self.expand_context(
        query=query,
        memories=memories,
        protected_k=protected_k,
        context_budget=budget,
        candidate_ids=candidate_ids,
        expansion_policy=expansion_policy,
        expert_rankers=expert_rankers,
        **scoring_options,
    )
349
+
350
def expand_context(
    self,
    query: str,
    memories: Iterable,
    protected_k: int = 10,
    context_budget: int = 15,
    candidate_ids: Optional[Iterable[str]] = None,
    expansion_policy: str = "balanced",
    expert_rankers: Optional[Sequence["ConvMemory"]] = None,
    window_mode=None,
    editor=None,
    ccge_top_n: Optional[int] = None,
):
    """Build a wider memory context from text memories.

    Encodes the query and memories in one `encode` call, maps
    `candidate_ids` (stringified; unknown ids dropped) to memory positions,
    and delegates selection to `expand_context_embeddings`. All other
    arguments are forwarded unchanged.
    """
    memory_ids, memory_texts = self._parse_memories(memories)

    # Row 0 is the query; the remaining rows line up with `memory_ids`.
    encoded = self.encode([query, *memory_texts])
    query_embedding, memory_embeddings = encoded[0], encoded[1:]

    candidate_indices = None
    if candidate_ids is not None:
        position_of = {
            memory_id: position for position, memory_id in enumerate(memory_ids)
        }
        candidate_indices = []
        for raw_id in candidate_ids:
            key = str(raw_id)
            if key in position_of:
                candidate_indices.append(position_of[key])

    expanded = self.expand_context_embeddings(
        query_embedding=query_embedding,
        memory_embeddings=memory_embeddings,
        memory_ids=memory_ids,
        memory_texts=memory_texts,
        query=query,
        protected_k=protected_k,
        context_budget=context_budget,
        candidate_indices=candidate_indices,
        expansion_policy=expansion_policy,
        expert_rankers=expert_rankers,
        window_mode=window_mode,
        editor=editor,
        ccge_top_n=ccge_top_n,
    )
    return expanded
398
+
399
def rerank_embeddings(
    self,
    query_embedding,
    memory_embeddings,
    memory_ids,
    memory_texts=None,
    query="",
    top_k: Optional[int] = None,
    candidate_indices=None,
    window_mode=None,
    editor=None,
    ccge_top_n: Optional[int] = None,
):
    """Rerank precomputed embeddings and return the ranked results.

    This is the path for callers that already store embeddings; no encoder
    is involved. The wrapped reranker scores first, then
    `_maybe_apply_editor` applies the requested editor (if any), and the
    result is cut to `top_k` entries unless `top_k` is `None`.
    """
    # Arguments shared by the scoring step and the optional editing step.
    shared = {
        "query_embedding": query_embedding,
        "memory_embeddings": memory_embeddings,
        "memory_ids": memory_ids,
        "memory_texts": memory_texts,
        "query": query,
        "candidate_indices": candidate_indices,
    }
    scored = self.reranker.rerank_embeddings(window_mode=window_mode, **shared)
    final = self._maybe_apply_editor(
        results=scored,
        editor=editor,
        ccge_top_n=ccge_top_n,
        **shared,
    )
    if top_k is None:
        return final
    return final[:top_k]
441
+
442
def expand_context_embeddings(
    self,
    query_embedding,
    memory_embeddings,
    memory_ids,
    memory_texts=None,
    query="",
    protected_k: int = 10,
    context_budget: int = 15,
    candidate_indices=None,
    expansion_policy: str = "balanced",
    expert_rankers: Optional[Sequence["ConvMemory"]] = None,
    window_mode=None,
    editor=None,
    ccge_top_n: Optional[int] = None,
):
    """Expand context over precomputed embeddings.

    The first `protected_k` results (clamped to `[0, context_budget]`) come
    from the main ranking. The remaining slots are filled round-robin from
    the rankings chosen by `expansion_policy`:

    * `"model"` - the main ranking itself,
    * `"raw"` - raw similarity order from `_raw_ranking_ids`,
    * `"local"` - a second pass with `window_mode="candidate_local"`,
    * `"balanced"` - all three of the above.

    Rankings from `expert_rankers` are always added. If slots are still
    open afterwards, the main ranking is used as a final fallback.

    Returns a list of results with ranks renumbered from 1, or `[]` when
    `context_budget` is not positive. Raises `ValueError` for an unknown
    `expansion_policy`; editor and window-mode errors propagate from
    `rerank_embeddings`.
    """
    if context_budget <= 0:
        return []
    # Convert once. The budget was previously sliced with its raw value but
    # passed as int(...) elsewhere, so a float budget raised TypeError only
    # on the "budget does not exceed protected_k" path.
    context_budget = int(context_budget)
    protected_k = max(0, min(int(protected_k), context_budget))

    policy = expansion_policy.lower().strip()
    if policy not in {"balanced", "model", "raw", "local"}:
        raise ValueError(
            "expansion_policy must be one of: 'balanced', 'model', 'raw', 'local'"
        )

    memory_ids = [str(memory_id) for memory_id in memory_ids]
    base_results = self.rerank_embeddings(
        query_embedding=query_embedding,
        memory_embeddings=memory_embeddings,
        memory_ids=memory_ids,
        memory_texts=memory_texts,
        query=query,
        candidate_indices=candidate_indices,
        window_mode=window_mode,
        editor=editor,
        ccge_top_n=ccge_top_n,
    )
    if context_budget <= protected_k:
        # The whole budget is protected; nothing to expand.
        return self._rerank_with_new_positions(base_results[:context_budget])

    result_by_id = {result.memory_id: result for result in base_results}
    selected = list(base_results[:protected_k])
    selected_ids = {result.memory_id for result in selected}

    rankings = []
    if policy in {"balanced", "model"}:
        rankings.append([result.memory_id for result in base_results])
    if policy in {"balanced", "raw"}:
        rankings.append(
            self._raw_ranking_ids(
                query_embedding=query_embedding,
                memory_embeddings=memory_embeddings,
                memory_ids=memory_ids,
                candidate_indices=candidate_indices,
            )
        )
    if policy in {"balanced", "local"}:
        local_results = self.rerank_embeddings(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            window_mode="candidate_local",
            editor=editor,
            ccge_top_n=ccge_top_n,
        )
        rankings.append([result.memory_id for result in local_results])
        # Local results replace the main-ranking entries used for filling;
        # the protected prefix in `selected` keeps its main-ranking objects.
        result_by_id.update({result.memory_id: result for result in local_results})

    for expert in expert_rankers or []:
        expert_results = expert.rerank_embeddings(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            window_mode=window_mode,
        )
        rankings.append([result.memory_id for result in expert_results])
        # Expert results only supply entries no earlier ranking provided.
        for result in expert_results:
            result_by_id.setdefault(result.memory_id, result)

    self._round_robin_fill(
        selected=selected,
        selected_ids=selected_ids,
        rankings=rankings,
        result_by_id=result_by_id,
        context_budget=context_budget,
    )
    if len(selected) < context_budget:
        # Fallback: top up from the main ranking when the policy's rankings
        # were exhausted before the budget was reached.
        self._round_robin_fill(
            selected=selected,
            selected_ids=selected_ids,
            rankings=[[result.memory_id for result in base_results]],
            result_by_id=result_by_id,
            context_budget=context_budget,
        )
    return self._rerank_with_new_positions(selected)
551
+
552
def _resolve_editor(self, editor):
    """Turn the public `editor` argument into an editor module or `None`.

    `None` means no editing. A `CCGELowAmplitudeEditor` instance is moved
    to this instance's device and put in eval mode. The string `"ccge_la"`
    selects the attached editor and raises `ValueError` when none is
    attached. Anything else raises `ValueError`.
    """
    invalid_editor = "editor must be None, 'ccge_la', or a CCGELowAmplitudeEditor instance"

    if editor is None:
        return None

    if isinstance(editor, CCGELowAmplitudeEditor):
        placed = editor.to(self.device)
        return placed.eval()

    if not isinstance(editor, str) or editor != "ccge_la":
        raise ValueError(invalid_editor)

    attached = self.ccge_editor
    if attached is None:
        raise ValueError(
            "No CCGE-LA editor is attached. Call `load_ccge_editor(path)` "
            "or `attach_ccge_editor(editor)` before using editor='ccge_la'."
        )
    return attached
568
+
569
def _maybe_apply_editor(
    self,
    *,
    results,
    query_embedding,
    memory_embeddings,
    memory_ids,
    memory_texts,
    query,
    candidate_indices,
    editor,
    ccge_top_n: Optional[int],
):
    """Apply the requested editor to `results`, if there is one.

    `editor` is resolved first, so an invalid value raises even when
    `results` is empty. `results` is returned untouched when no editor is
    requested or there is nothing to edit; otherwise the work is handed to
    `_apply_ccge_editor`. `query_embedding` is accepted but not forwarded.
    """
    resolved = self._resolve_editor(editor)
    nothing_to_edit = resolved is None or not results
    if nothing_to_edit:
        return results

    edited = self._apply_ccge_editor(
        results=results,
        editor=resolved,
        memory_embeddings=memory_embeddings,
        memory_ids=memory_ids,
        memory_texts=memory_texts,
        query=query,
        candidate_indices=candidate_indices,
        ccge_top_n=ccge_top_n,
    )
    return edited
595
+
596
def _apply_ccge_editor(
    self,
    *,
    results,
    editor,
    memory_embeddings,
    memory_ids,
    memory_texts,
    query,
    candidate_indices,
    ccge_top_n: Optional[int],
):
    """Re-score the head of `results` with a CCGE-LA editor.

    The edit set is the first `config.candidate_top_n` results, or, when
    `candidate_indices` is given, the results whose ids sit at those
    (in-range) positions of `memory_ids`. `ccge_top_n` further truncates
    that set. `results` is returned unchanged when the set is empty.

    Edited results are ordered by their new score (descending) and placed
    in front of all remaining results; ranks are then renumbered from 1.
    Raises `ValueError` when `memory_embeddings` and `memory_ids` differ in
    length.
    """
    memory_ids = [str(memory_id) for memory_id in memory_ids]
    index_of = {memory_id: position for position, memory_id in enumerate(memory_ids)}

    if candidate_indices is not None:
        allowed_ids = set()
        for raw_index in np.asarray(candidate_indices, dtype=np.int64):
            position = int(raw_index)
            if 0 <= position < len(memory_ids):
                allowed_ids.add(memory_ids[position])
        to_edit = [item for item in results if item.memory_id in allowed_ids]
    else:
        head_size = min(int(self.config.candidate_top_n), len(results))
        to_edit = list(results[:head_size])

    if ccge_top_n is not None:
        to_edit = to_edit[: max(0, int(ccge_top_n))]
    if not to_edit:
        return results

    edit_ids = [item.memory_id for item in to_edit]
    edit_rows = [index_of[memory_id] for memory_id in edit_ids]
    embedding_matrix = np.asarray(memory_embeddings, dtype=np.float32)
    if embedding_matrix.shape[0] != len(memory_ids):
        raise ValueError("memory_embeddings must match memory_ids")

    # Texts carried by the results win; `memory_texts` only fills in ids
    # that no result mentions.
    text_lookup = {item.memory_id: item.text for item in results}
    if memory_texts is not None:
        for memory_id, text in zip(memory_ids, memory_texts):
            text_lookup.setdefault(memory_id, text)
    candidate_texts = [text_lookup.get(memory_id) or "" for memory_id in edit_ids]

    batch = build_ccge_features(
        candidate_ids=edit_ids,
        convmemory_scores=[item.score for item in to_edit],
        dense_scores=[item.raw_score for item in to_edit],
        positions=edit_rows,
        candidate_embeddings=embedding_matrix[edit_rows],
        query=query,
        candidate_texts=candidate_texts,
    )
    edited_scores, _ = editor.edit_batch(batch, device=self.device)

    new_score = {}
    for memory_id, score in zip(edit_ids, edited_scores):
        new_score[memory_id] = float(score)
    original_of = {item.memory_id: item for item in results}

    edited_results = []
    by_new_score = sorted(edit_ids, key=lambda memory_id: new_score[memory_id], reverse=True)
    for rank, memory_id in enumerate(by_new_score, start=1):
        source = original_of[memory_id]
        edited_results.append(
            RerankResult(
                memory_id=memory_id,
                score=new_score[memory_id],
                raw_score=source.raw_score,
                rank=rank,
                text=source.text,
            )
        )

    edited_id_set = set(edit_ids)
    untouched = [item for item in results if item.memory_id not in edited_id_set]
    return self._rerank_with_new_positions([*edited_results, *untouched])
670
+
671
@staticmethod
def _raw_ranking_ids(query_embedding, memory_embeddings, memory_ids, candidate_indices=None):
    """Return memory ids ordered by raw query similarity, best first.

    Both the query and the memory rows are L2-normalized (with a 1e-8
    guard against zero norms) before `cosine_scores` is applied. When
    `candidate_indices` is given, candidates come first in similarity
    order, followed by all remaining memories in similarity order.
    """
    memory_matrix = np.asarray(memory_embeddings, dtype=np.float32)
    row_norms = np.linalg.norm(memory_matrix, axis=1, keepdims=True)
    memory_matrix = memory_matrix / (row_norms + 1e-8)

    query_vector = np.asarray(query_embedding, dtype=np.float32)
    query_vector = query_vector / (np.linalg.norm(query_vector) + 1e-8)

    similarities = cosine_scores(query_vector, memory_matrix)
    descending = [int(position) for position in np.argsort(-similarities)]
    if candidate_indices is None:
        return [memory_ids[position] for position in descending]

    preferred = {int(position) for position in candidate_indices}
    head = []
    rest = []
    for position in descending:
        bucket = head if position in preferred else rest
        bucket.append(position)
    return [memory_ids[position] for position in head + rest]
686
+
687
@staticmethod
def _round_robin_fill(selected, selected_ids, rankings, result_by_id, context_budget):
    """Top up `selected` in place by taking turns across `rankings`.

    In each round every ranking contributes at most one new id: its next
    entry that is not already selected. Filling stops once `selected`
    holds `context_budget` items or a full round adds nothing.

    Ids that have no entry in `result_by_id` are skipped. Previously such
    an id raised `KeyError` after it had already been recorded in
    `selected_ids`, leaving the two containers out of sync.

    `selected` and `selected_ids` are mutated; nothing is returned.
    """
    if not rankings:
        return
    cursors = [0] * len(rankings)
    while len(selected) < context_budget:
        added = False
        for ranking_idx, ranking in enumerate(rankings):
            # Advance this ranking's cursor to its next usable id.
            while cursors[ranking_idx] < len(ranking):
                memory_id = ranking[cursors[ranking_idx]]
                cursors[ranking_idx] += 1
                if memory_id in selected_ids or memory_id not in result_by_id:
                    continue
                selected_ids.add(memory_id)
                selected.append(result_by_id[memory_id])
                added = True
                break
            if len(selected) >= context_budget:
                return
        if not added:
            # Every ranking is exhausted; avoid spinning forever.
            return
708
+
709
@staticmethod
def _rerank_with_new_positions(results):
    """Copy `results` into new `RerankResult` objects ranked 1..n.

    Ids, scores, raw scores and texts are carried over unchanged; only
    `rank` is rewritten to match each item's position in `results`.
    """
    renumbered = []
    for position, item in enumerate(results, start=1):
        renumbered.append(
            RerankResult(
                memory_id=item.memory_id,
                score=item.score,
                raw_score=item.raw_score,
                rank=position,
                text=item.text,
            )
        )
    return renumbered
721
+
722
@staticmethod
def _parse_memories(memories):
    """Split `memories` into parallel lists of string ids and texts.

    Each memory is either a plain string, which gets its position as id,
    or a mapping-like object read through `.get("id")` and `.get("text")`.
    A missing or `None` id falls back to the position, and a missing or
    `None` text becomes `""`. Previously an explicit `None` was turned
    into the literal string `"None"`.

    Returns `(memory_ids, memory_texts)`, both lists of `str`.
    """
    memory_ids = []
    memory_texts = []
    for i, memory in enumerate(memories):
        if isinstance(memory, str):
            memory_ids.append(str(i))
            memory_texts.append(memory)
            continue
        memory_id = memory.get("id")
        text = memory.get("text")
        memory_ids.append(str(i if memory_id is None else memory_id))
        memory_texts.append("" if text is None else str(text))
    return memory_ids, memory_texts