tritopic-0.1.0-py3-none-any.whl → tritopic-1.0.0-py3-none-any.whl
- tritopic/__init__.py +22 -32
- tritopic/config.py +305 -0
- tritopic/core/__init__.py +0 -17
- tritopic/core/clustering.py +229 -243
- tritopic/core/embeddings.py +151 -157
- tritopic/core/graph.py +435 -0
- tritopic/core/keywords.py +213 -249
- tritopic/core/refinement.py +231 -0
- tritopic/core/representatives.py +560 -0
- tritopic/labeling.py +313 -0
- tritopic/model.py +718 -0
- tritopic/multilingual/__init__.py +38 -0
- tritopic/multilingual/detection.py +208 -0
- tritopic/multilingual/stopwords.py +467 -0
- tritopic/multilingual/tokenizers.py +275 -0
- tritopic/visualization.py +371 -0
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/METADATA +92 -48
- tritopic-1.0.0.dist-info/RECORD +20 -0
- tritopic/core/graph_builder.py +0 -493
- tritopic/core/model.py +0 -810
- tritopic/labeling/__init__.py +0 -5
- tritopic/labeling/llm_labeler.py +0 -279
- tritopic/utils/__init__.py +0 -13
- tritopic/utils/metrics.py +0 -254
- tritopic/visualization/__init__.py +0 -5
- tritopic/visualization/plotter.py +0 -523
- tritopic-0.1.0.dist-info/RECORD +0 -18
- tritopic-0.1.0.dist-info/licenses/LICENSE +0 -21
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/WHEEL +0 -0
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/top_level.txt +0 -0
tritopic/core/representatives.py (new file)

@@ -0,0 +1,560 @@
+"""
+Representative Documents Module
+
+Implements various methods for selecting representative documents per topic:
+- Centroid: documents closest to the topic centroid
+- Medoid: the actual document minimizing total distance to all others
+- Archetype: documents on the convex hull (extremes/facets)
+- Diverse: a diverse set maximizing coverage
+- Hybrid: a combination of medoid, archetypes, and keyword champion
+"""
+
+from dataclasses import dataclass, field
+from typing import Dict, List, Literal, Optional, Tuple
+
+import numpy as np
+
+
+@dataclass
+class RepresentativeDocument:
+    """A representative document with metadata."""
+    doc_id: int
+    text: str
+    score: float
+    method: str  # How this document was selected
+
+
+@dataclass
+class TopicRepresentatives:
+    """Collection of representative documents for a topic."""
+    topic_id: int
+    medoid: Optional[RepresentativeDocument] = None
+    archetypes: List[RepresentativeDocument] = field(default_factory=list)
+    keyword_champion: Optional[RepresentativeDocument] = None
+    centroid_nearest: List[RepresentativeDocument] = field(default_factory=list)
+
+    def get_all(self) -> List[RepresentativeDocument]:
+        """Get all representative documents, without duplicates."""
+        result = []
+        if self.medoid:
+            result.append(self.medoid)
+        result.extend(self.archetypes)
+        if self.keyword_champion and self.keyword_champion not in result:
+            result.append(self.keyword_champion)
+        for doc in self.centroid_nearest:
+            if doc not in result:
+                result.append(doc)
+        return result
+
+
+class RepresentativeSelector:
+    """Selects representative documents for topics using various methods."""
+
+    def __init__(
+        self,
+        method: Literal["centroid", "medoid", "archetype", "diverse", "hybrid"] = "hybrid",
+        n_representatives: int = 5,
+        n_archetypes: int = 4,
+        archetype_method: Literal["pcha", "convex_hull", "furthest_sum"] = "furthest_sum",
+    ):
+        """
+        Initialize the representative selector.
+
+        Parameters
+        ----------
+        method : str
+            Selection method.
+        n_representatives : int
+            Total number of representatives per topic.
+        n_archetypes : int
+            Number of archetypes (for the archetype/hybrid methods).
+        archetype_method : str
+            Algorithm used for archetype analysis.
+        """
+        self.method = method
+        self.n_representatives = n_representatives
+        self.n_archetypes = n_archetypes
+        self.archetype_method = archetype_method
+
+    def select(
+        self,
+        embeddings: np.ndarray,
+        labels: np.ndarray,
+        documents: List[str],
+        topic_keywords: Optional[Dict[int, List[str]]] = None,
+    ) -> Dict[int, TopicRepresentatives]:
+        """
+        Select representative documents for each topic.
+
+        Parameters
+        ----------
+        embeddings : np.ndarray
+            Document embeddings.
+        labels : np.ndarray
+            Topic labels.
+        documents : List[str]
+            Original documents.
+        topic_keywords : dict, optional
+            Keywords per topic (used for the keyword champion).
+
+        Returns
+        -------
+        Dict[int, TopicRepresentatives]
+            Representative documents per topic.
+        """
+        representatives = {}
+        unique_topics = sorted(t for t in np.unique(labels) if t >= 0)
+
+        for topic_id in unique_topics:
+            mask = labels == topic_id
+            topic_indices = np.where(mask)[0]
+            topic_embeddings = embeddings[mask]
+            topic_docs = [documents[i] for i in topic_indices]
+
+            if len(topic_indices) < 2:
+                # Too few documents for meaningful selection
+                representatives[topic_id] = TopicRepresentatives(
+                    topic_id=topic_id,
+                    centroid_nearest=[RepresentativeDocument(
+                        doc_id=topic_indices[0],
+                        text=topic_docs[0],
+                        score=1.0,
+                        method="only_document",
+                    )] if len(topic_indices) > 0 else [],
+                )
+                continue
+
+            keywords = topic_keywords.get(topic_id, []) if topic_keywords else []
+
+            if self.method == "centroid":
+                reps = self._select_centroid(topic_embeddings, topic_indices, topic_docs)
+            elif self.method == "medoid":
+                reps = self._select_medoid(topic_embeddings, topic_indices, topic_docs)
+            elif self.method == "archetype":
+                reps = self._select_archetypes(topic_embeddings, topic_indices, topic_docs)
+            elif self.method == "diverse":
+                reps = self._select_diverse(topic_embeddings, topic_indices, topic_docs)
+            elif self.method == "hybrid":
+                reps = self._select_hybrid(topic_embeddings, topic_indices, topic_docs, keywords)
+            else:
+                raise ValueError(f"Unknown method: {self.method}")
+
+            reps.topic_id = topic_id
+            representatives[topic_id] = reps
+
+        return representatives
+
+    def _select_centroid(
+        self,
+        embeddings: np.ndarray,
+        indices: np.ndarray,
+        documents: List[str],
+    ) -> TopicRepresentatives:
+        """Select the documents closest to the centroid."""
+        centroid = embeddings.mean(axis=0)
+
+        # Distances to the centroid, sorted ascending
+        distances = np.linalg.norm(embeddings - centroid, axis=1)
+        sorted_idx = np.argsort(distances)
+
+        n = min(self.n_representatives, len(indices))
+        selected = []
+
+        for i in range(n):
+            local_idx = sorted_idx[i]
+            selected.append(RepresentativeDocument(
+                doc_id=indices[local_idx],
+                text=documents[local_idx],
+                score=1.0 / (1.0 + distances[local_idx]),
+                method="centroid",
+            ))
+
+        return TopicRepresentatives(
+            topic_id=-1,
+            centroid_nearest=selected,
+        )
+
+    def _select_medoid(
+        self,
+        embeddings: np.ndarray,
+        indices: np.ndarray,
+        documents: List[str],
+    ) -> TopicRepresentatives:
+        """Select the medoid (the document minimizing total distance to all others)."""
+        n = len(embeddings)
+
+        # Compute pairwise distances
+        distances = np.zeros((n, n))
+        for i in range(n):
+            for j in range(i + 1, n):
+                d = np.linalg.norm(embeddings[i] - embeddings[j])
+                distances[i, j] = d
+                distances[j, i] = d
+
+        # The medoid has the minimum total distance
+        total_distances = distances.sum(axis=1)
+        medoid_idx = np.argmin(total_distances)
+
+        medoid = RepresentativeDocument(
+            doc_id=indices[medoid_idx],
+            text=documents[medoid_idx],
+            score=1.0,
+            method="medoid",
+        )
+
+        # Fill the remaining slots with the centroid method
+        remaining = self.n_representatives - 1
+        if remaining > 0:
+            centroid_reps = self._select_centroid(embeddings, indices, documents)
+            # Filter out the medoid itself
+            other_docs = [d for d in centroid_reps.centroid_nearest
+                          if d.doc_id != medoid.doc_id][:remaining]
+        else:
+            other_docs = []
+
+        return TopicRepresentatives(
+            topic_id=-1,
+            medoid=medoid,
+            centroid_nearest=other_docs,
+        )
+
+    def _select_archetypes(
+        self,
+        embeddings: np.ndarray,
+        indices: np.ndarray,
+        documents: List[str],
+    ) -> TopicRepresentatives:
+        """
+        Select archetypal documents using furthest-sum or the convex hull.
+
+        Archetypes are the "extreme" examples that define the boundaries
+        of the topic - its different facets or aspects.
+        """
+        n_archetypes = min(self.n_archetypes, len(indices))
+
+        if self.archetype_method == "convex_hull":
+            archetype_indices = self._convex_hull_archetypes(embeddings, n_archetypes)
+        elif self.archetype_method == "pcha":
+            archetype_indices = self._pcha_archetypes(embeddings, n_archetypes)
+        else:
+            # "furthest_sum", which is also the fallback
+            archetype_indices = self._furthest_sum_archetypes(embeddings, n_archetypes)
+
+        archetypes = []
+        for local_idx in archetype_indices:
+            archetypes.append(RepresentativeDocument(
+                doc_id=indices[local_idx],
+                text=documents[local_idx],
+                score=1.0,
+                method="archetype",
+            ))
+
+        return TopicRepresentatives(
+            topic_id=-1,
+            archetypes=archetypes,
+        )
+
+    def _furthest_sum_archetypes(
+        self,
+        embeddings: np.ndarray,
+        n_archetypes: int,
+    ) -> List[int]:
+        """
+        Select archetypes using the furthest-sum algorithm.
+
+        Iteratively selects the point that maximizes the sum of distances
+        to the already-selected points, thereby finding extreme points.
+        """
+        n = len(embeddings)
+
+        if n <= n_archetypes:
+            return list(range(n))
+
+        # Start with the point furthest from the centroid
+        centroid = embeddings.mean(axis=0)
+        distances_to_centroid = np.linalg.norm(embeddings - centroid, axis=1)
+        selected = [int(np.argmax(distances_to_centroid))]
+
+        # Iteratively add the point maximizing total distance to the selected set
+        for _ in range(n_archetypes - 1):
+            best_idx = -1
+            best_score = -1.0
+
+            for i in range(n):
+                if i in selected:
+                    continue
+
+                # Sum of distances to all selected points
+                total_dist = sum(
+                    np.linalg.norm(embeddings[i] - embeddings[j])
+                    for j in selected
+                )
+
+                if total_dist > best_score:
+                    best_score = total_dist
+                    best_idx = i
+
+            if best_idx >= 0:
+                selected.append(best_idx)
+
+        return selected
+
+    def _convex_hull_archetypes(
+        self,
+        embeddings: np.ndarray,
+        n_archetypes: int,
+    ) -> List[int]:
+        """
+        Select archetypes as vertices of the convex hull.
+
+        Note: the hull is computed in a reduced space (PCA to 2D), since a
+        convex hull in high dimensions often includes most or all points.
+        """
+        from sklearn.decomposition import PCA
+
+        n = len(embeddings)
+
+        if n <= n_archetypes:
+            return list(range(n))
+
+        # Reduce to 2D for the convex hull
+        n_components = min(2, n, embeddings.shape[1])
+        pca = PCA(n_components=n_components)
+        reduced = pca.fit_transform(embeddings)
+
+        try:
+            from scipy.spatial import ConvexHull
+
+            if n_components == 2 and n > 3:
+                hull = ConvexHull(reduced)
+                hull_vertices = hull.vertices.tolist()
+
+                if len(hull_vertices) > n_archetypes:
+                    # More hull vertices than needed: pick a diverse subset,
+                    # mapping subset positions back to indices into `embeddings`
+                    subset = self._furthest_sum_archetypes(
+                        embeddings[hull_vertices], n_archetypes
+                    )
+                    return [hull_vertices[i] for i in subset]
+                else:
+                    # Fill the remaining slots with furthest-sum
+                    remaining = n_archetypes - len(hull_vertices)
+                    if remaining > 0:
+                        candidates = [i for i in range(n) if i not in hull_vertices]
+                        if candidates:
+                            extra = self._furthest_sum_archetypes(
+                                embeddings[candidates], remaining
+                            )
+                            hull_vertices.extend(candidates[i] for i in extra)
+                    return hull_vertices[:n_archetypes]
+            else:
+                # Fall back to furthest-sum
+                return self._furthest_sum_archetypes(embeddings, n_archetypes)
+
+        except Exception:
+            # Fall back to furthest-sum
+            return self._furthest_sum_archetypes(embeddings, n_archetypes)
+
+    def _pcha_archetypes(
+        self,
+        embeddings: np.ndarray,
+        n_archetypes: int,
+    ) -> List[int]:
+        """
+        Select archetypes using Principal Convex Hull Analysis (PCHA).
+
+        PCHA finds archetypes such that all data points can be expressed
+        as convex combinations of the archetypes. A full implementation
+        requires iterative optimization, so this simplified version uses
+        furthest-sum as a proxy.
+        """
+        return self._furthest_sum_archetypes(embeddings, n_archetypes)
+
+    def _select_diverse(
+        self,
+        embeddings: np.ndarray,
+        indices: np.ndarray,
+        documents: List[str],
+    ) -> TopicRepresentatives:
+        """Select diverse documents using maximal marginal relevance (MMR)."""
+        n = min(self.n_representatives, len(indices))
+
+        # Start with the document nearest the centroid
+        centroid = embeddings.mean(axis=0)
+        distances = np.linalg.norm(embeddings - centroid, axis=1)
+        first_idx = int(np.argmin(distances))
+
+        selected = [first_idx]
+
+        # Add documents maximizing the MMR objective
+        lambda_param = 0.5  # Balance between relevance and diversity
+
+        for _ in range(n - 1):
+            best_idx = -1
+            best_score = float("-inf")
+
+            for i in range(len(embeddings)):
+                if i in selected:
+                    continue
+
+                # Relevance: closeness to the centroid
+                relevance = 1.0 / (1.0 + distances[i])
+
+                # Diversity: minimum distance to the already-selected documents
+                min_dist_to_selected = min(
+                    np.linalg.norm(embeddings[i] - embeddings[j])
+                    for j in selected
+                )
+
+                # MMR score
+                score = lambda_param * relevance + (1 - lambda_param) * min_dist_to_selected
+
+                if score > best_score:
+                    best_score = score
+                    best_idx = i
+
+            if best_idx >= 0:
+                selected.append(best_idx)
+
+        reps = []
+        for local_idx in selected:
+            reps.append(RepresentativeDocument(
+                doc_id=indices[local_idx],
+                text=documents[local_idx],
+                score=1.0 / (1.0 + distances[local_idx]),
+                method="diverse",
+            ))
+
+        return TopicRepresentatives(
+            topic_id=-1,
+            centroid_nearest=reps,
+        )
+
+    def _select_hybrid(
+        self,
+        embeddings: np.ndarray,
+        indices: np.ndarray,
+        documents: List[str],
+        keywords: List[str],
+    ) -> TopicRepresentatives:
+        """
+        Hybrid selection combining:
+        1. Medoid (the most typical document)
+        2. Archetypes (facets/extremes)
+        3. Keyword champion (highest keyword coverage)
+        """
+        # 1. Find the medoid
+        n = len(embeddings)
+        distances = np.zeros((n, n))
+        for i in range(n):
+            for j in range(i + 1, n):
+                d = np.linalg.norm(embeddings[i] - embeddings[j])
+                distances[i, j] = d
+                distances[j, i] = d
+
+        total_distances = distances.sum(axis=1)
+        medoid_idx = int(np.argmin(total_distances))
+
+        medoid = RepresentativeDocument(
+            doc_id=indices[medoid_idx],
+            text=documents[medoid_idx],
+            score=1.0,
+            method="medoid",
+        )
+
+        # 2. Find the archetypes (excluding the medoid).
+        # arch_indices must exist even when no archetypes are requested,
+        # since the keyword-champion check below refers to it.
+        arch_indices: List[int] = []
+        n_arch = min(self.n_archetypes, n - 1)
+        if n_arch > 0:
+            # Filter out the medoid
+            other_indices = [i for i in range(n) if i != medoid_idx]
+            other_embeddings = embeddings[other_indices]
+
+            arch_local = self._furthest_sum_archetypes(other_embeddings, n_arch)
+            arch_indices = [other_indices[i] for i in arch_local]
+
+        archetypes = [
+            RepresentativeDocument(
+                doc_id=indices[i],
+                text=documents[i],
+                score=1.0,
+                method="archetype",
+            )
+            for i in arch_indices
+        ]
+
+        # 3. Find the keyword champion
+        keyword_champion = None
+        if keywords:
+            keyword_set = set(w.lower() for w in keywords[:10])
+            best_score = 0.0
+            best_idx = -1
+
+            for i, doc in enumerate(documents):
+                doc_words = set(doc.lower().split())
+                score = len(keyword_set & doc_words) / len(keyword_set)
+                if score > best_score:
+                    best_score = score
+                    best_idx = i
+
+            if best_idx >= 0 and best_idx != medoid_idx and best_idx not in arch_indices:
+                keyword_champion = RepresentativeDocument(
+                    doc_id=indices[best_idx],
+                    text=documents[best_idx],
+                    score=best_score,
+                    method="keyword_champion",
+                )
+
+        return TopicRepresentatives(
+            topic_id=-1,
+            medoid=medoid,
+            archetypes=archetypes,
+            keyword_champion=keyword_champion,
+        )
+
+
+def get_representative_documents(
+    topic_representatives: Dict[int, TopicRepresentatives],
+    n_per_topic: int = 5,
+) -> Dict[int, List[Tuple[int, str, str]]]:
+    """
+    Get a flattened list of representative documents per topic.
+
+    Returns
+    -------
+    Dict[int, List[Tuple[int, str, str]]]
+        topic_id -> [(doc_id, text, method), ...]
+    """
+    result = {}
+
+    for topic_id, reps in topic_representatives.items():
+        all_reps = reps.get_all()[:n_per_topic]
+        result[topic_id] = [
+            (r.doc_id, r.text, r.method)
+            for r in all_reps
+        ]
+
+    return result
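
For orientation, here is a minimal usage sketch of the new module. It is not taken from the package's documentation or tests; the toy embeddings, labels, and keyword dict are invented for illustration, and only the import path follows the file listing above.

```python
# Minimal usage sketch for tritopic/core/representatives.py (toy data invented)
import numpy as np

from tritopic.core.representatives import (
    RepresentativeSelector,
    get_representative_documents,
)

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(8, 4))           # 8 documents, 4-dim embeddings
labels = np.array([0, 0, 0, 0, 1, 1, 1, -1])   # -1 marks an outlier; select() skips it
documents = [f"document {i}" for i in range(8)]

selector = RepresentativeSelector(method="hybrid", n_representatives=3)
reps = selector.select(embeddings, labels, documents,
                       topic_keywords={0: ["document", "example"]})

# Flatten to (doc_id, text, selection_method) triples per topic
for topic_id, docs in get_representative_documents(reps, n_per_topic=3).items():
    for doc_id, text, how in docs:
        print(topic_id, doc_id, how, text)
```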
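
The medoid computation in `_select_medoid` and `_select_hybrid` fills the pairwise distance matrix with a double Python loop, which costs O(n²) NumPy calls per topic. If that ever matters for large topics, the same matrix can be produced in a single call; a sketch assuming SciPy is available (the convex-hull path already imports it):

```python
# Equivalent, vectorized medoid computation (sketch; function name is ours)
import numpy as np
from scipy.spatial.distance import cdist


def medoid_index(embeddings: np.ndarray) -> int:
    """Index of the point with minimal total Euclidean distance to all others."""
    distances = cdist(embeddings, embeddings)  # full pairwise distance matrix
    return int(distances.sum(axis=1).argmin())
```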
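
The furthest-sum loop can be vectorized the same way: instead of recomputing the sum of distances to the selected set for every candidate, accumulate the distances to each newly selected point into a running total. A sketch of an equivalent O(n·k) version (again, the function name is ours, not the package's):

```python
# Vectorized furthest-sum seeding (sketch, equivalent to _furthest_sum_archetypes)
from typing import List

import numpy as np


def furthest_sum(embeddings: np.ndarray, k: int) -> List[int]:
    """Greedily pick k extreme points, starting furthest from the centroid."""
    n = len(embeddings)
    if n <= k:
        return list(range(n))

    centroid_dist = np.linalg.norm(embeddings - embeddings.mean(axis=0), axis=1)
    selected = [int(centroid_dist.argmax())]

    total = np.zeros(n)  # running sum of distances to all selected points
    for _ in range(k - 1):
        total += np.linalg.norm(embeddings - embeddings[selected[-1]], axis=1)
        masked = total.copy()
        masked[selected] = -np.inf  # never re-pick a selected point
        selected.append(int(masked.argmax()))
    return selected
```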