convmemory 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- convmemory/__init__.py +35 -0
- convmemory/api.py +733 -0
- convmemory/ccge.py +391 -0
- convmemory/encoder.py +150 -0
- convmemory/hub.py +45 -0
- convmemory/metrics.py +14 -0
- convmemory/models.py +31 -0
- convmemory/reranker.py +253 -0
- convmemory/routing.py +208 -0
- convmemory/scoring.py +314 -0
- convmemory-0.4.0.dist-info/LICENSE +21 -0
- convmemory-0.4.0.dist-info/METADATA +517 -0
- convmemory-0.4.0.dist-info/RECORD +15 -0
- convmemory-0.4.0.dist-info/WHEEL +5 -0
- convmemory-0.4.0.dist-info/top_level.txt +1 -0
convmemory/reranker.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Iterable, Optional
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from .scoring import (
|
|
8
|
+
build_memory_to_windows,
|
|
9
|
+
cosine_scores,
|
|
10
|
+
rerank_candidates,
|
|
11
|
+
score_ce_lite,
|
|
12
|
+
window_scores,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass
class RerankConfig:
    """Default knobs consumed by ``ConvMemoryReranker``.

    Several of these can be overridden per call through the keyword
    arguments of ``ConvMemoryReranker.rerank_item``.
    """

    # Length of each window of consecutive memory indices.
    window_size: int = 5
    # Step between consecutive window starts when building sliding windows.
    stride: int = 1
    # How many top raw-cosine memories become candidates when the caller
    # does not pass explicit candidate indices.
    candidate_top_n: int = 500
    # Forwarded to ``rerank_candidates`` as ``raw_weight``.
    raw_weight: float = 0.0
    # Forwarded to ``score_ce_lite`` as ``dca_router_block_size``.
    dca_router_block_size: int = 32
    # Forwarded to ``score_ce_lite`` as ``lexical_features``.
    lexical_features: bool = True
    # Either "full" (sliding windows over every memory) or
    # "candidate_local" (one window around each candidate).
    window_mode: str = "full"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class RerankResult:
    """One entry of a reranked result list."""

    # Memory identifier (string form of the id supplied by the caller).
    memory_id: str
    # Score used for the final ordering.
    score: float
    # Cosine similarity between the query and this memory.
    raw_score: float
    # 1-based position in the reranked list.
    rank: int
    # Memory text, when the caller supplied one.
    text: Optional[str] = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def sliding_windows(num_items: int, window_size: int, stride: int):
    """Return index windows that cover ``range(num_items)``.

    Args:
        num_items: Number of items to cover.
        window_size: Length of every window.
        stride: Distance between the starts of consecutive windows.

    Returns:
        A list of index lists. Empty when ``num_items <= 0``; a single window
        holding every index when ``num_items <= window_size``. Otherwise
        windows start at ``0, stride, 2 * stride, ...`` and one final window
        is appended so the last ``window_size`` items are always covered.

    Raises:
        ValueError: If more than one window is needed and ``window_size`` or
            ``stride`` is smaller than 1.
    """
    if num_items <= 0:
        return []
    if num_items <= window_size:
        return [list(range(num_items))]

    # Validate only once several windows are required, so the trivial cases
    # above keep returning exactly what they always did.
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    final_start = num_items - window_size
    starts = list(range(0, final_start + 1, stride))
    # The stride may skip the tail; add one last window flush with the end.
    if starts[-1] != final_start:
        starts.append(final_start)
    return [list(range(start, start + window_size)) for start in starts]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def normalize_rows(matrix):
    """Scale every row of ``matrix`` to (approximately) unit L2 norm.

    Args:
        matrix: Array-like of row vectors. A single 1-D vector is treated as
            a one-row matrix; an empty 1-D input yields a ``(0, 0)`` array.

    Returns:
        A float32 array in which each row has been divided by its L2 norm
        plus ``1e-8``, so all-zero rows stay zero instead of producing NaNs.
    """
    matrix = np.asarray(matrix, dtype=np.float32)
    if matrix.ndim == 1:
        # ``axis=1`` below needs a 2-D array; promote a lone vector to one
        # row and map "no data" to zero rows rather than one empty row.
        matrix = matrix.reshape(1, -1) if matrix.size else matrix.reshape(0, 0)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True) + 1e-8
    return matrix / norms
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def window_tensor(memory_embeddings, windows):
    """Gather the embeddings of each window into one float32 tensor.

    Args:
        memory_embeddings: 2-D array indexed by memory position.
        windows: List of index lists into ``memory_embeddings``.

    Returns:
        A tensor shaped ``(len(windows), longest_window, dim)``. Windows
        shorter than the longest one are zero-padded at the end. With no
        windows the result has shape ``(0, 0, dim)``.
    """
    if not windows:
        empty_shape = (0, 0, memory_embeddings.shape[1])
        return torch.zeros(empty_shape, dtype=torch.float32)

    lengths = [len(window) for window in windows]
    longest = max(lengths)

    # Equal-length windows can be gathered with one fancy-indexing call.
    if min(lengths) == longest:
        gather = np.asarray(windows, dtype=np.int64)
        return torch.as_tensor(memory_embeddings[gather], dtype=torch.float32)

    # Ragged windows: copy each one into a zero-initialised padded batch.
    padded = np.zeros(
        (len(windows), longest, memory_embeddings.shape[1]),
        dtype=np.float32,
    )
    for row, (window, length) in enumerate(zip(windows, lengths)):
        padded[row, :length] = memory_embeddings[window]
    return torch.tensor(padded, dtype=torch.float32)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def candidate_local_windows(num_items, candidate_indices, window_size):
    """Build one window of neighbouring indices around each candidate.

    Each window spans ``window_size`` consecutive indices and begins
    ``window_size // 2`` positions before the candidate. It is shifted as
    needed so it stays inside ``[0, num_items)``. Duplicate windows are
    emitted once, in first-seen order.

    Returns:
        ``[]`` when ``num_items <= 0``; a single window over every index when
        ``num_items <= window_size``; otherwise the list of distinct windows.
    """
    if num_items <= 0:
        return []
    if num_items <= window_size:
        return [list(range(num_items))]

    offset = window_size // 2
    # Largest start that still leaves room for a whole window.
    last_start = num_items - window_size

    emitted = set()
    result = []
    for candidate in candidate_indices:
        begin = min(max(int(candidate) - offset, 0), last_start)
        span = tuple(range(begin, begin + window_size))
        if span in emitted:
            continue
        emitted.add(span)
        result.append(list(span))
    return result
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class ConvMemoryReranker:
    """Plug-in reranker over precomputed memory embeddings.

    The class deliberately does not own text embedding. In production, users can
    bring any retriever or embedding model, pass the raw top-k candidates here,
    and get a ConvMemory-enhanced ordering back.
    """

    def __init__(self, conv_model, scorer, config=None, device="cpu"):
        """Store the collaborators used for scoring.

        Args:
            conv_model: Model handed to ``window_scores`` and ``score_ce_lite``.
            scorer: Scorer handed to ``score_ce_lite``.
            config: Optional ``RerankConfig``; a default one is used when falsy.
            device: Device argument forwarded to the scoring helpers.
        """
        self.conv_model = conv_model
        self.scorer = scorer
        self.config = config or RerankConfig()
        self.device = device

    def make_item(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts: Optional[Iterable[str]] = None,
        query: str = "",
    ):
        """Package one query and its memories into the dict used for scoring.

        Args:
            query_embedding: Query vector; it is L2-normalised here.
            memory_embeddings: One embedding row per memory; rows are
                L2-normalised here.
            memory_ids: Memory identifiers; each is converted with ``str``.
            memory_texts: Optional texts, one per memory. Defaults to empty
                strings.
            query: Query text stored on the item.

        Returns:
            A dict with the normalised embeddings, the ids, ``memories``
            (``{"id", "text"}`` dicts), the sliding ``windows`` built from the
            configured window size and stride, and ``window_tensor`` set to
            ``None`` so it is built on demand by ``rerank_item``.

        Raises:
            ValueError: If ``memory_embeddings`` or ``memory_texts`` does not
                have exactly one entry per memory id.
        """
        memory_ids = [str(x) for x in memory_ids]
        memory_embeddings = normalize_rows(memory_embeddings)
        if len(memory_embeddings) != len(memory_ids):
            raise ValueError("memory_embeddings and memory_ids must have the same length")

        query_embedding = np.asarray(query_embedding, dtype=np.float32)
        query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-8)

        if memory_texts is None:
            memory_texts = [""] * len(memory_ids)
        else:
            # Materialise first: ``zip`` below would silently drop memories
            # if a shorter (or one-shot) iterable were passed.
            memory_texts = list(memory_texts)
            if len(memory_texts) != len(memory_ids):
                raise ValueError("memory_texts and memory_ids must have the same length")

        memories = [
            {"id": memory_id, "text": text}
            for memory_id, text in zip(memory_ids, memory_texts)
        ]
        windows = sliding_windows(
            len(memory_ids),
            self.config.window_size,
            self.config.stride,
        )

        return {
            "question_id": "query",
            "question_type": "unknown",
            "query": query,
            "query_embedding": query_embedding.astype(np.float32),
            "memory_embeddings": memory_embeddings.astype(np.float32),
            "memory_ids": memory_ids,
            "memories": memories,
            "windows": windows,
            "window_tensor": None,
            "gold_memory_ids": [],
        }

    def rerank_embeddings(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids,
        memory_texts: Optional[Iterable[str]] = None,
        query: str = "",
        candidate_indices=None,
        candidate_top_n: Optional[int] = None,
        raw_weight: Optional[float] = None,
        window_mode: Optional[str] = None,
    ):
        """Build an item with ``make_item`` and rerank it with ``rerank_item``.

        The first five arguments are passed to ``make_item``; the remaining
        ones are passed to ``rerank_item`` unchanged.

        Returns:
            The list of ``RerankResult`` produced by ``rerank_item``.
        """
        item = self.make_item(
            query_embedding=query_embedding,
            memory_embeddings=memory_embeddings,
            memory_ids=memory_ids,
            memory_texts=memory_texts,
            query=query,
        )
        return self.rerank_item(
            item,
            candidate_indices=candidate_indices,
            candidate_top_n=candidate_top_n,
            raw_weight=raw_weight,
            window_mode=window_mode,
        )

    def rerank_item(
        self,
        item,
        candidate_indices=None,
        candidate_top_n: Optional[int] = None,
        raw_weight: Optional[float] = None,
        window_mode: Optional[str] = None,
    ):
        """Rerank the memories of a prepared item.

        Args:
            item: Dict in the layout produced by ``make_item``.
            candidate_indices: Memory positions to rerank. When ``None`` the
                top ``candidate_top_n`` positions by raw cosine score are used.
            candidate_top_n: Overrides ``config.candidate_top_n`` when truthy.
            raw_weight: Overrides ``config.raw_weight`` when not ``None``.
            window_mode: ``"full"`` or ``"candidate_local"``; overrides
                ``config.window_mode`` when truthy.

        Returns:
            ``RerankResult`` objects in the order returned by
            ``rerank_candidates``, ranked from 1. ``score`` is the CE-lite
            score for candidates and falls back to the raw cosine score for
            any returned id that was not a candidate.

        Raises:
            ValueError: If the selected window mode is not recognised.
        """
        memory_ids = item["memory_ids"]
        raw_scores = cosine_scores(item["query_embedding"], item["memory_embeddings"])

        if candidate_indices is None:
            top_n = candidate_top_n or self.config.candidate_top_n
            candidate_indices = np.argsort(-raw_scores)[: min(top_n, len(raw_scores))]
        else:
            candidate_indices = np.asarray(candidate_indices, dtype=np.int64)

        selected_window_mode = window_mode or self.config.window_mode
        scoring_item = item
        if selected_window_mode == "candidate_local":
            # Score only windows centred on the candidates.
            local_windows = candidate_local_windows(
                len(memory_ids),
                candidate_indices,
                self.config.window_size,
            )
            scoring_item = {
                **item,
                "windows": local_windows,
                "window_tensor": window_tensor(item["memory_embeddings"], local_windows),
            }
        elif selected_window_mode == "full":
            # Build the window tensor lazily; the caller's item is not mutated.
            if item.get("window_tensor") is None:
                scoring_item = {
                    **item,
                    "window_tensor": window_tensor(item["memory_embeddings"], item["windows"]),
                }
        else:
            raise ValueError(f"Unknown window_mode: {selected_window_mode}")

        # NOTE(review): the flattened source did not preserve how far the
        # no_grad block extended. Only floats and ids leave this method, so
        # all model scoring is kept inside it; confirm against upstream.
        with torch.no_grad():
            conv_tensor = window_scores(self.conv_model, scoring_item, self.device)
            memory_to_windows = build_memory_to_windows(scoring_item["windows"])
            _, _, ce_lite_scores = score_ce_lite(
                self.conv_model,
                self.scorer,
                scoring_item,
                candidate_indices,
                self.device,
                raw_scores_all=raw_scores,
                window_logits=conv_tensor,
                memory_to_windows=memory_to_windows,
                dca_router_block_size=self.config.dca_router_block_size,
                lexical_features=self.config.lexical_features,
            )

        ranked_ids = rerank_candidates(
            raw_scores,
            candidate_indices,
            ce_lite_scores,
            memory_ids,
            raw_weight=self.config.raw_weight if raw_weight is None else raw_weight,
        )

        # Lookup tables keyed by memory id for assembling the results.
        score_by_id = {
            memory_ids[int(idx)]: float(score)
            for idx, score in zip(candidate_indices, ce_lite_scores)
        }
        raw_by_id = {
            memory_ids[idx]: float(score)
            for idx, score in enumerate(raw_scores)
        }
        text_by_id = {
            str(memory.get("id", idx)): memory.get("text")
            for idx, memory in enumerate(item.get("memories", []))
        }
        return [
            RerankResult(
                memory_id=memory_id,
                score=score_by_id.get(memory_id, raw_by_id[memory_id]),
                raw_score=raw_by_id[memory_id],
                rank=rank,
                text=text_by_id.get(memory_id),
            )
            for rank, memory_id in enumerate(ranked_ids, start=1)
        ]
|
convmemory/routing.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Iterable, List, Mapping, Optional, Sequence
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class CompressionRouteConfig:
    """Settings for routing a query from compressed notes to raw memories.

    The defaults are deliberately conservative and mirror the stable LoCoMo
    v0.31 setting: compressed notes pick a smaller pool of raw memories, and
    ConvMemory then reranks that pool.
    """

    # Number of top-scoring notes that are expanded into their sources.
    note_depth: int = 240
    # Cap on sources taken from one note; a non-positive value disables it.
    max_sources_per_note: int = 5
    # Routing stops once this many candidates have been collected.
    max_candidates: int = 450
    # Number of top raw-similarity memories added before any note is used.
    raw_anchor: int = 80
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class CompressionRouteResult:
    """Immutable outcome of one ``CompressionRouter.route`` call."""

    # Positions of the selected raw memories, in selection order.
    candidate_indices: List[int]
    # Memory ids matching ``candidate_indices`` one-to-one.
    candidate_ids: List[str]
    # Positions of the notes that were considered, best-scoring first.
    note_indices: List[int]
    # min(number of selected candidates, configured raw anchor depth).
    raw_anchor_count: int
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class CompressedNoteConfig:
    """Settings for building lightweight notes out of raw memories."""

    # "session": one note per run of memories sharing a session value.
    # "block": each such run is further cut into fixed-size blocks.
    mode: str = "session"
    # Number of memories per block when ``mode`` is "block".
    block_size: int = 32
    # How many representative memories supply the text of a note.
    representatives: int = 3
    # "central": memories closest to the group centroid.
    # "first": the leading memories of the group.
    strategy: str = "central"
    # Key read from each memory dict to detect session boundaries.
    session_key: str = "session_id"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class CompressionRouter:
    """Route raw memories through compressed note blocks.

    Compressed memories are dictionaries with at least:
    - `text`: note text used for embedding
    - `source_ids`: raw memory ids covered by this note

    The router does not call ConvMemory directly. It only returns candidate ids
    so it can be plugged into any retrieval or agent-memory pipeline.
    """

    def __init__(self, config: Optional[CompressionRouteConfig] = None):
        """Use ``config``, or the default ``CompressionRouteConfig`` when falsy."""
        self.config = config or CompressionRouteConfig()

    def route(
        self,
        query_embedding,
        memory_embeddings,
        memory_ids: Sequence[str],
        compressed_embeddings,
        compressed_memories: Iterable[Mapping],
    ) -> CompressionRouteResult:
        """Select candidate raw memories for a query.

        The top ``raw_anchor`` memories by cosine similarity are taken first.
        After that, the top ``note_depth`` notes are visited best-first, and
        each contributes its known sources, ordered by raw similarity and
        capped at ``max_sources_per_note``. Selection stops as soon as
        ``max_candidates`` memories have been collected.

        Args:
            query_embedding: Query vector.
            memory_embeddings: One embedding row per raw memory.
            memory_ids: Ids of the raw memories, in embedding order.
            compressed_embeddings: One embedding row per note.
            compressed_memories: Note dicts; ``source_ids`` is read from each.

        Returns:
            A ``CompressionRouteResult`` with at most ``max_candidates``
            candidates.

        Raises:
            ValueError: If the notes and their embeddings, or the memory ids
                and their embeddings, differ in length.
        """
        memory_ids = [str(memory_id) for memory_id in memory_ids]
        compressed_memories = list(compressed_memories)
        if len(compressed_memories) != len(compressed_embeddings):
            raise ValueError("compressed_memories and compressed_embeddings must have the same length")

        query = _normalize_vector(query_embedding)
        empty_scores = np.asarray([], dtype=np.float32)

        # Only build the score vectors when there is data: normalising an
        # empty list would otherwise yield a row with no columns and make the
        # matrix product fail.
        if memory_ids or np.size(memory_embeddings):
            memories = _normalize_matrix(memory_embeddings)
            if len(memories) != len(memory_ids):
                raise ValueError("memory_ids and memory_embeddings must have the same length")
            raw_scores = memories @ query
        else:
            raw_scores = empty_scores
        raw_order = np.argsort(-raw_scores)

        if compressed_memories:
            note_scores = _normalize_matrix(compressed_embeddings) @ query
        else:
            note_scores = empty_scores
        note_order = np.argsort(-note_scores)[: max(0, int(self.config.note_depth))]

        id_to_index = {memory_id: idx for idx, memory_id in enumerate(memory_ids)}
        selected: List[int] = []
        seen = set()
        # Checked before every insertion so that a budget of zero (or less)
        # yields no candidates at all.
        budget = max(0, int(self.config.max_candidates))

        for idx in raw_order[: max(0, int(self.config.raw_anchor))]:
            if len(selected) >= budget:
                return self._result(selected, memory_ids, note_order)
            self._add_candidate(int(idx), selected, seen)

        limit = int(self.config.max_sources_per_note)
        for note_idx in note_order:
            # Source ids that are not present in ``memory_ids`` are ignored.
            source_indices = [
                id_to_index[str(source_id)]
                for source_id in compressed_memories[int(note_idx)].get("source_ids", [])
                if str(source_id) in id_to_index
            ]
            source_indices.sort(key=lambda idx: -float(raw_scores[idx]))
            if limit > 0:
                source_indices = source_indices[:limit]

            for idx in source_indices:
                if len(selected) >= budget:
                    return self._result(selected, memory_ids, note_order)
                self._add_candidate(int(idx), selected, seen)

        return self._result(selected, memory_ids, note_order)

    @staticmethod
    def _add_candidate(idx: int, selected: List[int], seen) -> None:
        """Append ``idx`` to ``selected`` unless it has been added before."""
        if idx in seen:
            return
        selected.append(idx)
        seen.add(idx)

    def _result(self, selected: List[int], memory_ids: Sequence[str], note_order) -> CompressionRouteResult:
        """Wrap the selected indices and visited notes in a result object."""
        return CompressionRouteResult(
            candidate_indices=list(selected),
            candidate_ids=[memory_ids[idx] for idx in selected],
            note_indices=[int(idx) for idx in note_order],
            raw_anchor_count=min(len(selected), max(0, int(self.config.raw_anchor))),
        )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _normalize_vector(x):
|
|
121
|
+
arr = np.asarray(x, dtype=np.float32)
|
|
122
|
+
return arr / (np.linalg.norm(arr) + 1e-8)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _normalize_matrix(x):
|
|
126
|
+
arr = np.asarray(x, dtype=np.float32)
|
|
127
|
+
if arr.ndim == 1:
|
|
128
|
+
arr = arr.reshape(1, -1)
|
|
129
|
+
return arr / (np.linalg.norm(arr, axis=1, keepdims=True) + 1e-8)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def build_compressed_notes(
    memories: Iterable[Mapping],
    memory_embeddings,
    config: Optional[CompressedNoteConfig] = None,
):
    """Build compressed notes from raw memories.

    This helper is intentionally simple: it splits an ordered memory stream
    into runs that share a session value (and, in ``"block"`` mode, cuts each
    run into fixed-size blocks), chooses representative turns, and keeps
    `source_ids` so the note can be expanded back to raw memories.

    Args:
        memories: Ordered memory dicts; ``text``, ``id`` and the configured
            session key are read when present.
        memory_embeddings: One embedding row per memory.
        config: Optional ``CompressedNoteConfig``; defaults are used when falsy.

    Returns:
        A list of note dicts with ``id``, ``text``, ``source_ids``, the
        session key, and ``granularity``. Empty when there are no memories.

    Raises:
        ValueError: If the memories and embeddings differ in length, if the
            mode or strategy is unknown, or if ``block_size`` is smaller than
            1 in ``"block"`` mode.
    """

    cfg = config or CompressedNoteConfig()
    memories = list(memories)
    if memories or np.size(memory_embeddings):
        embeddings = _normalize_matrix(memory_embeddings)
    else:
        # Nothing to normalise; avoids a spurious length mismatch on empty input.
        embeddings = np.zeros((0, 0), dtype=np.float32)
    if len(memories) != len(embeddings):
        raise ValueError("memories and memory_embeddings must have the same length")
    if cfg.mode not in {"session", "block"}:
        raise ValueError("CompressedNoteConfig.mode must be 'session' or 'block'")
    if cfg.strategy not in {"central", "first"}:
        raise ValueError("CompressedNoteConfig.strategy must be 'central' or 'first'")
    if cfg.mode == "block" and int(cfg.block_size) < 1:
        # A zero step would crash ``range`` and a negative one would silently
        # drop every memory; fail early with a clear message instead.
        raise ValueError("CompressedNoteConfig.block_size must be >= 1")

    groups = _session_groups(memories, cfg.session_key)
    if cfg.mode == "block":
        # Blocks never cross a session boundary: each run is cut separately.
        groups = [
            group[start : start + cfg.block_size]
            for group in groups
            for start in range(0, len(group), cfg.block_size)
        ]

    notes = []
    for note_idx, group in enumerate(groups):
        if not group:
            continue
        reps = _representative_indices(group, embeddings, cfg.strategy, cfg.representatives)
        text = " ".join(str(memories[idx].get("text", "")) for idx in reps)
        # Memories without an ``id`` fall back to their position in the stream.
        source_ids = [str(memories[idx].get("id", idx)) for idx in group]
        session_id = str(memories[group[0]].get(cfg.session_key, ""))
        notes.append(
            {
                "id": f"{cfg.mode}:{note_idx}",
                "text": text,
                "source_ids": source_ids,
                cfg.session_key: session_id,
                "granularity": cfg.mode,
            }
        )
    return notes
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def _session_groups(memories, session_key):
|
|
183
|
+
groups = []
|
|
184
|
+
current_value = None
|
|
185
|
+
current = []
|
|
186
|
+
for idx, item in enumerate(memories):
|
|
187
|
+
value = item.get(session_key, "")
|
|
188
|
+
if current and value != current_value:
|
|
189
|
+
groups.append(current)
|
|
190
|
+
current = []
|
|
191
|
+
current_value = value
|
|
192
|
+
current.append(idx)
|
|
193
|
+
if current:
|
|
194
|
+
groups.append(current)
|
|
195
|
+
return groups
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _representative_indices(group, embeddings, strategy, representatives):
|
|
199
|
+
count = max(1, min(int(representatives), len(group)))
|
|
200
|
+
if strategy == "first":
|
|
201
|
+
return group[:count]
|
|
202
|
+
|
|
203
|
+
local_embeddings = embeddings[group]
|
|
204
|
+
centroid = local_embeddings.mean(axis=0)
|
|
205
|
+
centroid = centroid / (np.linalg.norm(centroid) + 1e-8)
|
|
206
|
+
scores = local_embeddings @ centroid
|
|
207
|
+
picked = np.argsort(-scores)[:count]
|
|
208
|
+
return [group[int(i)] for i in sorted(picked)]
|