convmemory 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- convmemory/__init__.py +35 -0
- convmemory/api.py +733 -0
- convmemory/ccge.py +391 -0
- convmemory/encoder.py +150 -0
- convmemory/hub.py +45 -0
- convmemory/metrics.py +14 -0
- convmemory/models.py +31 -0
- convmemory/reranker.py +253 -0
- convmemory/routing.py +208 -0
- convmemory/scoring.py +314 -0
- convmemory-0.4.0.dist-info/LICENSE +21 -0
- convmemory-0.4.0.dist-info/METADATA +517 -0
- convmemory-0.4.0.dist-info/RECORD +15 -0
- convmemory-0.4.0.dist-info/WHEEL +5 -0
- convmemory-0.4.0.dist-info/top_level.txt +1 -0
convmemory/api.py
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Iterable, Optional, Sequence
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from sentence_transformers import SentenceTransformer
|
|
9
|
+
|
|
10
|
+
from .ccge import CCGELowAmplitudeEditor, build_ccge_features
|
|
11
|
+
from .hub import resolve_checkpoint_path
|
|
12
|
+
from .models import build_default_components
|
|
13
|
+
from .reranker import ConvMemoryReranker, RerankConfig, RerankResult
|
|
14
|
+
from .scoring import cosine_scores, lexical_signature
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ConvMemory:
    """User-facing ConvMemory reranker.

    Use `from_pretrained` for normal usage. `from_config` is mainly for
    development and examples because it creates randomly initialized weights.

    An instance wraps a `ConvMemoryReranker` (conv model + scorer), an optional
    sentence-transformer encoder used by the text-level methods (`rerank`,
    `retrieve`, `expand_context`), and an optional CCGE-LA editor that can
    re-score the head of a ranking after ConvMemory.
    """

    def __init__(
        self,
        conv_model,
        scorer,
        config=None,
        device="cpu",
        embedding_model=None,
        embedding_model_name=None,
        model_config=None,
        ccge_editor=None,
    ):
        """Wrap already-built ConvMemory components.

        Args:
            conv_model: Conv model handed to `ConvMemoryReranker`; put in eval mode.
            scorer: Scorer handed to `ConvMemoryReranker`; put in eval mode.
            config: `RerankConfig`; a default `RerankConfig()` is used when falsy.
            device: Device forwarded to the reranker and to attached editors.
            embedding_model: Loaded encoder used by `encode`, or None.
            embedding_model_name: Backbone name written by `save_pretrained` and
                compared against CCGE editor metadata.
            model_config: Model-dimension dict written by `save_pretrained`.
            ccge_editor: Optional `CCGELowAmplitudeEditor` attached immediately.
        """
        self.device = device
        self.config = config or RerankConfig()
        self.embedding_model_name = embedding_model_name
        self.embedding_model = embedding_model
        self.model_config = model_config or {}
        self.ccge_editor = None
        self.reranker = ConvMemoryReranker(
            conv_model=conv_model,
            scorer=scorer,
            config=self.config,
            device=device,
        )
        # This wrapper is inference-only, so both modules are put in eval mode.
        self.reranker.conv_model.eval()
        self.reranker.scorer.eval()
        if ccge_editor is not None:
            self.attach_ccge_editor(ccge_editor)

    @classmethod
    def from_config(
        cls,
        embedding_dim,
        device="cpu",
        embedding_model=None,
        config=None,
        ccge_editor=None,
        **model_kwargs,
    ):
        """Create a ConvMemory instance from dimensions and config.

        This initializes random weights and is intended for development, tests,
        or custom training code. Pass `embedding_model=None` to use only
        precomputed embeddings, or a model name to attach a local encoder.

        Recognised `model_kwargs`: `window_size` (5), `kernel_size` (3),
        `hidden_dim` (256), `token_mlp_dim` (32), `channel_mlp_dim` (512) and
        `extra_scalar_features`. When `extra_scalar_features` is not given it is
        derived from `config`: +1 when `dca_router_block_size > 0` and +4 when
        `lexical_features` is enabled.
        """
        rerank_config = config or RerankConfig()
        extra_scalar_features = model_kwargs.get("extra_scalar_features")
        if extra_scalar_features is None:
            # Only derive the width when the caller did not pin it explicitly.
            extra_scalar_features = 0
            if rerank_config.dca_router_block_size > 0:
                extra_scalar_features += 1
            if rerank_config.lexical_features:
                extra_scalar_features += 4
        model_config = {
            "embedding_dim": int(embedding_dim),
            "window_size": int(model_kwargs.get("window_size", 5)),
            "kernel_size": int(model_kwargs.get("kernel_size", 3)),
            "hidden_dim": int(model_kwargs.get("hidden_dim", 256)),
            "token_mlp_dim": int(model_kwargs.get("token_mlp_dim", 32)),
            "channel_mlp_dim": int(model_kwargs.get("channel_mlp_dim", 512)),
            "extra_scalar_features": int(extra_scalar_features),
        }
        conv_model, scorer = build_default_components(device=device, **model_config)
        embedder = None
        if embedding_model:
            embedder = SentenceTransformer(embedding_model, device=device)
        return cls(
            conv_model=conv_model,
            scorer=scorer,
            config=rerank_config,
            device=device,
            embedding_model=embedder,
            # Record "no backbone" as None rather than "" / False.
            embedding_model_name=embedding_model or None,
            model_config=model_config,
            ccge_editor=ccge_editor,
        )

    @classmethod
    def from_pretrained(
        cls,
        path,
        device="cpu",
        embedding_model=None,
        load_ccge: bool = False,
    ):
        """Load a ConvMemory checkpoint from disk or Hugging Face Hub.

        `embedding_model` may be `None` to use checkpoint metadata, a string to
        override the encoder, or `False` to skip encoder loading for precomputed
        embeddings. With `False`, the checkpoint's backbone name is still
        recorded, so `save_pretrained` and CCGE backbone checks keep working.

        `load_ccge=True` auto-attaches `ccge_la.pt` when present; the default is
        `False` so CCGE-LA remains explicit opt-in. If `path` does not exist and
        looks like `namespace/repo`, it is downloaded through
        `huggingface_hub.snapshot_download`.
        """
        path = resolve_checkpoint_path(path)
        metadata = json.loads((path / "config.json").read_text(encoding="utf-8"))
        rerank_config = RerankConfig(**metadata["rerank_config"])
        model_config = metadata["model_config"]
        conv_model, scorer = build_default_components(device=device, **model_config)
        # NOTE(security): torch.load unpickles the checkpoint, and `path` may have
        # been downloaded from the Hub. Only load checkpoints from trusted sources.
        state = torch.load(path / "model.pt", map_location="cpu")
        conv_model.load_state_dict(state["conv_model"])
        scorer.load_state_dict(state["scorer"])
        conv_model.to(device).eval()
        scorer.to(device).eval()

        skip_encoder = embedding_model is False
        if embedding_model is None or skip_encoder:
            # Keep the checkpoint's backbone name even when no encoder is loaded.
            # Previously `False` leaked into `embedding_model_name` and was then
            # saved to config.json.
            embedding_model_name = metadata.get("embedding_model") or None
        else:
            embedding_model_name = embedding_model
        embedder = None
        if embedding_model_name and not skip_encoder:
            embedder = SentenceTransformer(embedding_model_name, device=device)

        ccge_editor = None
        ccge_path = path / "ccge_la.pt"
        if load_ccge and ccge_path.exists():
            ccge_editor = CCGELowAmplitudeEditor.from_pretrained(ccge_path, device=device)

        model = cls(
            conv_model=conv_model,
            scorer=scorer,
            config=rerank_config,
            device=device,
            embedding_model=embedder,
            embedding_model_name=embedding_model_name,
            model_config=model_config,
            ccge_editor=ccge_editor,
        )
        if ccge_editor is not None:
            print(f"[ConvMemory] auto-attached CCGE-LA editor from {ccge_path}")
        return model

    def save_pretrained(self, path):
        """Save ConvMemory weights, config, and an attached CCGE-LA editor.

        Writes `config.json` (format/version, backbone name, model and rerank
        config) and `model.pt` (conv model + scorer state dicts). It also writes
        `ccge_la.pt` when an editor is attached. The directory is created if needed.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        metadata = {
            "format": "convmemory",
            "version": 1,
            "embedding_model": self.embedding_model_name,
            "model_config": self.model_config,
            "rerank_config": self.config.__dict__,
        }
        (path / "config.json").write_text(
            json.dumps(metadata, indent=2, sort_keys=True),
            encoding="utf-8",
        )
        torch.save(
            {
                "conv_model": self.reranker.conv_model.state_dict(),
                "scorer": self.reranker.scorer.state_dict(),
            },
            path / "model.pt",
        )
        if self.ccge_editor is not None:
            self.ccge_editor.save_pretrained(path / "ccge_la.pt")

    def attach_ccge_editor(self, editor):
        """Attach a trained CCGE-LA editor to this ConvMemory instance.

        Returns `self`. If both the ConvMemory checkpoint and editor declare
        embedding backbone names and they differ, a `UserWarning` is emitted
        because quality may degrade. An editor without backbone metadata
        inherits this instance's backbone name, when one is known.

        Raises:
            TypeError: if `editor` is not a `CCGELowAmplitudeEditor`.
        """
        if not isinstance(editor, CCGELowAmplitudeEditor):
            raise TypeError("editor must be a CCGELowAmplitudeEditor")
        editor_backbone = getattr(editor, "trained_embedding_model_name", None)
        if self.embedding_model_name and editor_backbone and self.embedding_model_name != editor_backbone:
            warnings.warn(
                "CCGE editor was trained on backbone "
                f"{editor_backbone} but is being attached to ConvMemory with "
                f"backbone {self.embedding_model_name}; quality may degrade.",
                UserWarning,
                stacklevel=2,
            )
        # Truthiness check (consistent with the mismatch check above) so a falsy
        # placeholder such as False / "" is never stamped onto the editor.
        if editor_backbone is None and self.embedding_model_name:
            editor.trained_embedding_model_name = self.embedding_model_name
        self.ccge_editor = editor.to(self.device).eval()
        return self

    def load_ccge_editor(self, path, strict: bool = True):
        """Load and attach a CCGE-LA editor checkpoint.

        `path` may be a local checkpoint path or a Hugging Face Hub repo id.
        Returns `self`. `strict` is forwarded to the editor state-dict loader;
        mismatched embedding backbone metadata emits the same warning as
        `attach_ccge_editor`.
        """
        editor = CCGELowAmplitudeEditor.from_pretrained(
            path,
            device=self.device,
            strict=strict,
        )
        return self.attach_ccge_editor(editor)

    def encode(self, texts):
        """Encode texts with the attached sentence-transformer encoder.

        Returns a float32 numpy array of L2-normalized embeddings, one row per
        text. Raises `ValueError` when no encoder is attached; use
        `rerank_embeddings` for precomputed embeddings.
        """
        if self.embedding_model is None:
            raise ValueError(
                "No embedding model is attached. Pass embeddings directly with "
                "`rerank_embeddings`, or load with `from_pretrained(..., embedding_model=...)`."
            )
        return self.embedding_model.encode(
            list(texts),
            convert_to_numpy=True,
            normalize_embeddings=True,
            show_progress_bar=False,
        ).astype(np.float32)

    def prewarm_lexical(self, memories: Iterable):
        """Cache lexical signatures for stable memory stores.

        This is optional, but useful when reranking many queries over the same
        user or agent memory. It keeps online reranking focused on scoring.
        The return value of `lexical_signature` is discarded; this relies on
        that helper caching its results (NOTE(review): confirm in scoring.py).
        """
        _, memory_texts = self._parse_memories(memories)
        for text in memory_texts:
            lexical_signature(text)

    def rerank(
        self,
        query: str,
        memories: Iterable,
        top_k: Optional[int] = None,
        candidate_ids: Optional[Iterable[str]] = None,
        window_mode=None,
        editor=None,
        ccge_top_n: Optional[int] = None,
    ):
        """Rerank text memories and return `list[RerankResult]`.

        Encodes `query` and `memories`, optionally restricts to `candidate_ids`,
        and applies `editor="ccge_la"` or a `CCGELowAmplitudeEditor` instance
        after ConvMemory. `ccge_top_n` limits how many top candidates are edited.
        `top_k` truncates the returned list when given.
        Raises `ValueError` for invalid editor or window-mode settings, or when
        no encoder is attached.
        """
        memory_ids, memory_texts, query_embedding, memory_embeddings, candidate_indices = (
            self._prepare_text_inputs(query, memories, candidate_ids)
        )
        results = self.rerank_embeddings(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            window_mode=window_mode,
            editor=editor,
            ccge_top_n=ccge_top_n,
        )
        return results[:top_k] if top_k is not None else results

    def retrieve(
        self,
        query: str,
        memories: Iterable,
        top_k: Optional[int] = 10,
        mode: str = "rerank",
        candidate_ids: Optional[Iterable[str]] = None,
        protected_k: int = 10,
        context_budget: Optional[int] = None,
        expansion_policy: str = "balanced",
        expert_rankers: Optional[Sequence["ConvMemory"]] = None,
        window_mode=None,
        editor=None,
        ccge_top_n: Optional[int] = None,
    ):
        """Retrieve memories and return `list[RerankResult]`.

        `mode="rerank"` returns the normal ConvMemory ranking.
        `mode="expand"` (aliases: "context", "expand_context") protects the
        strongest reranked memories, then fills the remaining context budget
        with complementary candidates. In expand mode with no explicit
        `context_budget`, the budget is `top_k` (or `protected_k + 5` when
        `top_k` is None), raised to `protected_k + 5` if it would not exceed
        `protected_k`. `editor` and `ccge_top_n` are passed through to the
        scoring path. Raises `ValueError` for unknown modes, policies, editors,
        or window modes.
        """
        selected_mode = mode.lower().strip()
        if selected_mode == "rerank":
            return self.rerank(
                query=query,
                memories=memories,
                top_k=top_k,
                candidate_ids=candidate_ids,
                window_mode=window_mode,
                editor=editor,
                ccge_top_n=ccge_top_n,
            )
        if selected_mode not in {"expand", "context", "expand_context"}:
            raise ValueError("mode must be either 'rerank' or 'expand'")

        if context_budget is None:
            budget = top_k if top_k is not None else protected_k + 5
            # Guarantee room for at least some expansion beyond the protected prefix.
            if budget <= protected_k:
                budget = protected_k + 5
        else:
            budget = context_budget
        return self.expand_context(
            query=query,
            memories=memories,
            protected_k=protected_k,
            context_budget=budget,
            candidate_ids=candidate_ids,
            expansion_policy=expansion_policy,
            expert_rankers=expert_rankers,
            window_mode=window_mode,
            editor=editor,
            ccge_top_n=ccge_top_n,
        )

    def expand_context(
        self,
        query: str,
        memories: Iterable,
        protected_k: int = 10,
        context_budget: int = 15,
        candidate_ids: Optional[Iterable[str]] = None,
        expansion_policy: str = "balanced",
        expert_rankers: Optional[Sequence["ConvMemory"]] = None,
        window_mode=None,
        editor=None,
        ccge_top_n: Optional[int] = None,
    ):
        """Build a wider memory context and return `list[RerankResult]`.

        The first `protected_k` memories come from the main ConvMemory ranking.
        The remaining slots are filled from complementary rankings, which can
        include raw dense retrieval, candidate-local window scoring, optional
        expert rankers, and optional CCGE-LA editing via `editor`/`ccge_top_n`.
        Raises `ValueError` for invalid expansion policies or editor settings,
        or when no encoder is attached.
        """
        memory_ids, memory_texts, query_embedding, memory_embeddings, candidate_indices = (
            self._prepare_text_inputs(query, memories, candidate_ids)
        )
        return self.expand_context_embeddings(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            protected_k=protected_k,
            context_budget=context_budget,
            candidate_indices=candidate_indices,
            expansion_policy=expansion_policy,
            expert_rankers=expert_rankers,
            window_mode=window_mode,
            editor=editor,
            ccge_top_n=ccge_top_n,
        )

    def rerank_embeddings(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts=None,
        query="",
        top_k: Optional[int] = None,
        candidate_indices=None,
        window_mode=None,
        editor=None,
        ccge_top_n: Optional[int] = None,
    ):
        """Rerank precomputed embeddings and return `list[RerankResult]`.

        This is the no-encoder path for systems that already store embeddings.
        `editor="ccge_la"` applies an attached CCGE-LA editor; `ccge_top_n`
        limits the edited prefix. Raises `ValueError` for invalid editor or
        window-mode settings.
        """
        results = self.reranker.rerank_embeddings(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            window_mode=window_mode,
        )
        results = self._maybe_apply_editor(
            results=results,
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            editor=editor,
            ccge_top_n=ccge_top_n,
        )
        return results[:top_k] if top_k is not None else results

    def expand_context_embeddings(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts=None,
        query="",
        protected_k: int = 10,
        context_budget: int = 15,
        candidate_indices=None,
        expansion_policy: str = "balanced",
        expert_rankers: Optional[Sequence["ConvMemory"]] = None,
        window_mode=None,
        editor=None,
        ccge_top_n: Optional[int] = None,
    ):
        """Expand context over precomputed embeddings.

        Protects a ConvMemory prefix, fills the remaining budget from
        complementary rankings, and optionally applies `editor="ccge_la"`.
        Policies: "model" (ConvMemory order), "raw" (cosine order), "local"
        (candidate-local window scoring) or "balanced" (all three,
        interleaved). Rankings from `expert_rankers` are always added.
        Returns `list[RerankResult]`; raises `ValueError` for invalid policy,
        editor, or window-mode arguments.
        """
        if context_budget <= 0:
            return []
        protected_k = max(0, min(int(protected_k), int(context_budget)))
        policy = expansion_policy.lower().strip()
        if policy not in {"balanced", "model", "raw", "local"}:
            raise ValueError(
                "expansion_policy must be one of: 'balanced', 'model', 'raw', 'local'"
            )

        memory_ids = [str(memory_id) for memory_id in memory_ids]
        base_results = self.rerank_embeddings(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            window_mode=window_mode,
            editor=editor,
            ccge_top_n=ccge_top_n,
        )
        if context_budget <= protected_k:
            return self._rerank_with_new_positions(base_results[:context_budget])

        result_by_id = {result.memory_id: result for result in base_results}
        selected = list(base_results[:protected_k])
        selected_ids = {result.memory_id for result in selected}

        rankings = []
        if policy in {"balanced", "model"}:
            rankings.append([result.memory_id for result in base_results])
        if policy in {"balanced", "raw"}:
            rankings.append(
                self._raw_ranking_ids(
                    query_embedding=query_embedding,
                    memory_embeddings=memory_embeddings,
                    memory_ids=memory_ids,
                    candidate_indices=candidate_indices,
                )
            )
        if policy in {"balanced", "local"}:
            local_results = self.rerank_embeddings(
                query_embedding=query_embedding,
                memory_embeddings=memory_embeddings,
                memory_ids=memory_ids,
                memory_texts=memory_texts,
                query=query,
                candidate_indices=candidate_indices,
                window_mode="candidate_local",
                editor=editor,
                ccge_top_n=ccge_top_n,
            )
            rankings.append([result.memory_id for result in local_results])
            # Candidate-local results replace the base entries for the same ids
            # (expert results below only fill ids not seen yet).
            result_by_id.update({result.memory_id: result for result in local_results})

        for expert in expert_rankers or []:
            expert_results = expert.rerank_embeddings(
                query_embedding=query_embedding,
                memory_embeddings=memory_embeddings,
                memory_ids=memory_ids,
                memory_texts=memory_texts,
                query=query,
                candidate_indices=candidate_indices,
                window_mode=window_mode,
            )
            rankings.append([result.memory_id for result in expert_results])
            for result in expert_results:
                result_by_id.setdefault(result.memory_id, result)

        self._round_robin_fill(
            selected=selected,
            selected_ids=selected_ids,
            rankings=rankings,
            result_by_id=result_by_id,
            context_budget=int(context_budget),
        )
        if len(selected) < context_budget:
            # Top up from the plain ConvMemory order if the mixed rankings ran dry.
            self._round_robin_fill(
                selected=selected,
                selected_ids=selected_ids,
                rankings=[[result.memory_id for result in base_results]],
                result_by_id=result_by_id,
                context_budget=int(context_budget),
            )
        return self._rerank_with_new_positions(selected)

    def _prepare_text_inputs(self, query, memories, candidate_ids):
        """Parse and encode text inputs shared by `rerank` and `expand_context`.

        Returns `(memory_ids, memory_texts, query_embedding, memory_embeddings,
        candidate_indices)`; `candidate_indices` is None when `candidate_ids` is None.
        """
        memory_ids, memory_texts = self._parse_memories(memories)
        # Encode the query together with the memories in one batch; row 0 is the query.
        embeddings = self.encode([query, *memory_texts])
        candidate_indices = self._resolve_candidate_indices(memory_ids, candidate_ids)
        return memory_ids, memory_texts, embeddings[0], embeddings[1:], candidate_indices

    @staticmethod
    def _resolve_candidate_indices(memory_ids, candidate_ids):
        """Map `candidate_ids` (compared as strings) to positions in `memory_ids`.

        Unknown ids are dropped; duplicate memory ids resolve to their last
        position. Returns None when `candidate_ids` is None.
        """
        if candidate_ids is None:
            return None
        id_to_idx = {memory_id: i for i, memory_id in enumerate(memory_ids)}
        return [
            id_to_idx[str(memory_id)]
            for memory_id in candidate_ids
            if str(memory_id) in id_to_idx
        ]

    def _resolve_editor(self, editor):
        """Turn the public `editor` argument into an editor module (or None).

        Accepts None, the string "ccge_la" (the attached editor), or a
        `CCGELowAmplitudeEditor` instance (moved to `self.device`, eval mode).
        Raises `ValueError` for anything else, or for "ccge_la" with no editor attached.
        """
        message = "editor must be None, 'ccge_la', or a CCGELowAmplitudeEditor instance"
        if editor is None:
            return None
        if isinstance(editor, CCGELowAmplitudeEditor):
            return editor.to(self.device).eval()
        if isinstance(editor, str):
            if editor != "ccge_la":
                raise ValueError(message)
            if self.ccge_editor is None:
                raise ValueError(
                    "No CCGE-LA editor is attached. Call `load_ccge_editor(path)` "
                    "or `attach_ccge_editor(editor)` before using editor='ccge_la'."
                )
            return self.ccge_editor
        raise ValueError(message)

    def _maybe_apply_editor(
        self,
        *,
        results,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts,
        query,
        candidate_indices,
        editor,
        ccge_top_n: Optional[int],
    ):
        """Apply the resolved editor to `results`, or return them unchanged.

        `editor` is validated first (even when `results` is empty). `query_embedding`
        is accepted for signature symmetry but not forwarded to the editor.
        """
        editor_module = self._resolve_editor(editor)
        if editor_module is None or not results:
            return results
        return self._apply_ccge_editor(
            results=results,
            editor=editor_module,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
            candidate_indices=candidate_indices,
            ccge_top_n=ccge_top_n,
        )

    def _apply_ccge_editor(
        self,
        *,
        results,
        editor,
        memory_embeddings,
        memory_ids,
        memory_texts,
        query,
        candidate_indices,
        ccge_top_n: Optional[int],
    ):
        """Re-score the head of `results` with a CCGE-LA editor.

        Edited memories are re-sorted by editor score and placed ahead of every
        unedited result, then ranks are renumbered from 1.

        Raises:
            ValueError: if `memory_embeddings` rows do not match `memory_ids`.
        """
        memory_ids = [str(memory_id) for memory_id in memory_ids]
        id_to_idx = {memory_id: i for i, memory_id in enumerate(memory_ids)}

        if candidate_indices is None:
            # No restriction: edit the first `config.candidate_top_n` results.
            edit_count = min(int(self.config.candidate_top_n), len(results))
            edit_candidates = list(results[:edit_count])
        else:
            # Restrict editing to the caller's candidates; out-of-range indices are ignored.
            candidate_ids = {
                memory_ids[int(idx)]
                for idx in np.asarray(candidate_indices, dtype=np.int64)
                if 0 <= int(idx) < len(memory_ids)
            }
            edit_candidates = [result for result in results if result.memory_id in candidate_ids]

        if ccge_top_n is not None:
            edit_candidates = edit_candidates[: max(0, int(ccge_top_n))]
        if not edit_candidates:
            return results

        edit_ids = [result.memory_id for result in edit_candidates]
        edit_id_set = set(edit_ids)
        edit_indices = [id_to_idx[memory_id] for memory_id in edit_ids]
        matrix = np.asarray(memory_embeddings, dtype=np.float32)
        if matrix.shape[0] != len(memory_ids):
            raise ValueError("memory_embeddings must match memory_ids")

        # Prefer texts carried on the results; fall back to the caller's memory_texts.
        text_by_id = {result.memory_id: result.text for result in results}
        if memory_texts is not None:
            for memory_id, text in zip(memory_ids, memory_texts):
                text_by_id.setdefault(memory_id, text)
        candidate_texts = [text_by_id.get(memory_id) or "" for memory_id in edit_ids]

        batch = build_ccge_features(
            candidate_ids=edit_ids,
            convmemory_scores=[result.score for result in edit_candidates],
            dense_scores=[result.raw_score for result in edit_candidates],
            positions=edit_indices,
            candidate_embeddings=matrix[edit_indices],
            query=query,
            candidate_texts=candidate_texts,
        )
        edited_scores, _ = editor.edit_batch(batch, device=self.device)
        score_by_id = {
            memory_id: float(score)
            for memory_id, score in zip(edit_ids, edited_scores)
        }
        original_by_id = {result.memory_id: result for result in results}
        edited_results = [
            RerankResult(
                memory_id=memory_id,
                score=score_by_id[memory_id],
                raw_score=original_by_id[memory_id].raw_score,
                rank=rank,
                text=original_by_id[memory_id].text,
            )
            for rank, memory_id in enumerate(
                sorted(edit_ids, key=lambda memory_id: score_by_id[memory_id], reverse=True),
                start=1,
            )
        ]
        tail = [result for result in results if result.memory_id not in edit_id_set]
        return self._rerank_with_new_positions([*edited_results, *tail])

    @staticmethod
    def _raw_ranking_ids(query_embedding, memory_embeddings, memory_ids, candidate_indices=None):
        """Order memory ids by `cosine_scores` of L2-normalized embeddings, best first.

        When `candidate_indices` is given, candidates come first (in score order)
        followed by all remaining memories (also in score order).
        """
        matrix = np.asarray(memory_embeddings, dtype=np.float32)
        matrix = matrix / (np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8)
        query_vec = np.asarray(query_embedding, dtype=np.float32)
        query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
        raw_scores = cosine_scores(query_vec, matrix)
        raw_order = [int(i) for i in np.argsort(-raw_scores)]
        if candidate_indices is None:
            return [memory_ids[i] for i in raw_order]

        candidate_set = {int(i) for i in candidate_indices}
        candidate_order = [i for i in raw_order if i in candidate_set]
        tail_order = [i for i in raw_order if i not in candidate_set]
        return [memory_ids[i] for i in [*candidate_order, *tail_order]]

    @staticmethod
    def _round_robin_fill(selected, selected_ids, rankings, result_by_id, context_budget):
        """Interleave `rankings` into `selected` until `context_budget` is reached.

        Each pass takes the next not-yet-selected id from every ranking in turn.
        `selected` and `selected_ids` are mutated in place. Ids with no entry in
        `result_by_id` are skipped. Stops early when a full pass adds nothing.
        """
        if not rankings:
            return
        cursors = [0 for _ in rankings]
        while len(selected) < context_budget:
            added = False
            for ranking_idx, ranking in enumerate(rankings):
                while cursors[ranking_idx] < len(ranking):
                    memory_id = ranking[cursors[ranking_idx]]
                    cursors[ranking_idx] += 1
                    if memory_id in selected_ids:
                        continue
                    if memory_id not in result_by_id:
                        # No RerankResult exists for this id (e.g. a raw-ranking id
                        # that the reranker did not return); skip instead of KeyError.
                        continue
                    selected_ids.add(memory_id)
                    selected.append(result_by_id[memory_id])
                    added = True
                    break
                if len(selected) >= context_budget:
                    return
            if not added:
                return

    @staticmethod
    def _rerank_with_new_positions(results):
        """Return new `RerankResult`s with `rank` renumbered from 1 in list order."""
        return [
            RerankResult(
                memory_id=result.memory_id,
                score=result.score,
                raw_score=result.raw_score,
                rank=rank,
                text=result.text,
            )
            for rank, result in enumerate(results, start=1)
        ]

    @staticmethod
    def _parse_memories(memories):
        """Split memories into parallel `(ids, texts)` lists of strings.

        A plain string gets its position as id. Other items are read with
        `.get("id", position)` and `.get("text", "")`, i.e. dict-like records.
        """
        memory_ids = []
        memory_texts = []
        for i, memory in enumerate(memories):
            if isinstance(memory, str):
                memory_ids.append(str(i))
                memory_texts.append(memory)
            else:
                memory_ids.append(str(memory.get("id", i)))
                memory_texts.append(str(memory.get("text", "")))
        return memory_ids, memory_texts
|