tritopic 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,560 @@
1
+ """
2
+ Representative Documents Module
3
+
4
+ Implements various methods for selecting representative documents per topic:
5
+ - Centroid: Documents closest to topic centroid
6
+ - Medoid: Actual document that minimizes distance to all others
7
+ - Archetype: Documents on the convex hull (extremes/facets)
8
+ - Diverse: Diverse set maximizing coverage
9
+ - Hybrid: Combination of medoid + archetypes + keyword champion
10
+ """
11
+
12
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
15
+
16
+
17
@dataclass
class RepresentativeDocument:
    """A representative document with metadata."""
    doc_id: int    # index of the document in the original corpus
    text: str      # raw document text
    score: float   # selection score (method-dependent; higher is better)
    method: str    # how this document was selected (e.g. "medoid", "archetype")


@dataclass
class TopicRepresentatives:
    """Collection of representative documents for a topic."""
    topic_id: int
    medoid: Optional[RepresentativeDocument] = None
    # field(default_factory=list) replaces the former None-default +
    # __post_init__ workaround: each instance gets its own fresh list.
    archetypes: List[RepresentativeDocument] = field(default_factory=list)
    keyword_champion: Optional[RepresentativeDocument] = None
    centroid_nearest: List[RepresentativeDocument] = field(default_factory=list)

    def __post_init__(self):
        # Kept for backward compatibility: callers may still pass None
        # explicitly for either list field.
        if self.archetypes is None:
            self.archetypes = []
        if self.centroid_nearest is None:
            self.centroid_nearest = []

    def get_all(self) -> List[RepresentativeDocument]:
        """Return all representatives, deduplicated, in priority order:
        medoid, archetypes, keyword champion, then centroid-nearest."""
        result: List[RepresentativeDocument] = []
        if self.medoid:
            result.append(self.medoid)
        result.extend(self.archetypes)
        # Dataclass equality (field-wise) is used for the membership checks.
        if self.keyword_champion and self.keyword_champion not in result:
            result.append(self.keyword_champion)
        for doc in self.centroid_nearest:
            if doc not in result:
                result.append(doc)
        return result
53
+
54
+
55
class RepresentativeSelector:
    """
    Selects representative documents for topics using various methods.

    Supported methods
    -----------------
    - "centroid": documents closest to the topic centroid
    - "medoid": the actual document minimizing total distance to all others
    - "archetype": extreme documents spanning the topic's facets
    - "diverse": maximal-marginal-relevance selection balancing relevance
      against diversity
    - "hybrid": medoid + archetypes + keyword champion
    """

    def __init__(
        self,
        method: Literal["centroid", "medoid", "archetype", "diverse", "hybrid"] = "hybrid",
        n_representatives: int = 5,
        n_archetypes: int = 4,
        archetype_method: Literal["pcha", "convex_hull", "furthest_sum"] = "furthest_sum",
    ):
        """
        Initialize the representative selector.

        Parameters
        ----------
        method : str
            Selection method
        n_representatives : int
            Total number of representatives per topic
        n_archetypes : int
            Number of archetypes (for archetype/hybrid methods)
        archetype_method : str
            Algorithm for archetype analysis
        """
        self.method = method
        self.n_representatives = n_representatives
        self.n_archetypes = n_archetypes
        self.archetype_method = archetype_method

    def select(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
        documents: List[str],
        topic_keywords: Optional[Dict[int, List[str]]] = None,
    ) -> Dict[int, TopicRepresentatives]:
        """
        Select representative documents for each topic.

        Parameters
        ----------
        embeddings : np.ndarray
            Document embeddings
        labels : np.ndarray
            Topic labels (negative labels are treated as noise and skipped)
        documents : List[str]
            Original documents
        topic_keywords : dict, optional
            Keywords per topic (used for the keyword champion in "hybrid")

        Returns
        -------
        Dict[int, TopicRepresentatives]
            Representative documents per topic

        Raises
        ------
        ValueError
            If ``self.method`` is not a recognized selection method.
        """
        representatives: Dict[int, TopicRepresentatives] = {}
        unique_topics = sorted(t for t in np.unique(labels) if t >= 0)

        for topic_id in unique_topics:
            mask = labels == topic_id
            topic_indices = np.where(mask)[0]
            topic_embeddings = embeddings[mask]
            topic_docs = [documents[i] for i in topic_indices]

            if len(topic_indices) < 2:
                # Too few documents for meaningful selection: emit the lone
                # document (if any) as its own representative.
                representatives[topic_id] = TopicRepresentatives(
                    topic_id=topic_id,
                    centroid_nearest=[RepresentativeDocument(
                        doc_id=topic_indices[0],
                        text=topic_docs[0],
                        score=1.0,
                        method="only_document"
                    )] if len(topic_indices) > 0 else []
                )
                continue

            keywords = topic_keywords.get(topic_id, []) if topic_keywords else []

            if self.method == "centroid":
                reps = self._select_centroid(
                    topic_embeddings, topic_indices, topic_docs
                )
            elif self.method == "medoid":
                reps = self._select_medoid(
                    topic_embeddings, topic_indices, topic_docs
                )
            elif self.method == "archetype":
                reps = self._select_archetypes(
                    topic_embeddings, topic_indices, topic_docs
                )
            elif self.method == "diverse":
                reps = self._select_diverse(
                    topic_embeddings, topic_indices, topic_docs
                )
            elif self.method == "hybrid":
                reps = self._select_hybrid(
                    topic_embeddings, topic_indices, topic_docs, keywords
                )
            else:
                raise ValueError(f"Unknown method: {self.method}")

            # Per-method selectors don't know the topic id; stamp it here.
            reps.topic_id = topic_id
            representatives[topic_id] = reps

        return representatives

    @staticmethod
    def _medoid_index(embeddings: np.ndarray) -> int:
        """Return the local index of the medoid: the point that minimizes
        total Euclidean distance to all other points.

        Vectorized via the Gram matrix (replaces a former O(n^2) Python
        double loop): d2[i, j] = |x_i|^2 + |x_j|^2 - 2 * x_i . x_j.
        """
        sq_norms = np.einsum("ij,ij->i", embeddings, embeddings)
        d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (embeddings @ embeddings.T)
        np.maximum(d2, 0.0, out=d2)  # clamp tiny negatives from rounding
        return int(np.argmin(np.sqrt(d2).sum(axis=1)))

    def _select_centroid(
        self,
        embeddings: np.ndarray,
        indices: np.ndarray,
        documents: List[str],
    ) -> TopicRepresentatives:
        """Select the documents nearest the topic centroid, closest first."""
        centroid = embeddings.mean(axis=0)
        distances = np.linalg.norm(embeddings - centroid, axis=1)
        order = np.argsort(distances)

        n = min(self.n_representatives, len(indices))
        selected = [
            RepresentativeDocument(
                doc_id=indices[local_idx],
                text=documents[local_idx],
                # Map distance to a (0, 1] similarity-style score.
                score=1.0 / (1.0 + distances[local_idx]),
                method="centroid",
            )
            for local_idx in order[:n]
        ]

        return TopicRepresentatives(
            topic_id=-1,
            centroid_nearest=selected,
        )

    def _select_medoid(
        self,
        embeddings: np.ndarray,
        indices: np.ndarray,
        documents: List[str],
    ) -> TopicRepresentatives:
        """Select the medoid, padding remaining slots with centroid-nearest docs."""
        medoid_idx = self._medoid_index(embeddings)

        medoid = RepresentativeDocument(
            doc_id=indices[medoid_idx],
            text=documents[medoid_idx],
            score=1.0,
            method="medoid",
        )

        # Fill remaining slots with the centroid method, excluding the medoid
        # so it is not listed twice.
        remaining = self.n_representatives - 1
        other_docs: List[RepresentativeDocument] = []
        if remaining > 0:
            centroid_reps = self._select_centroid(embeddings, indices, documents)
            other_docs = [
                d for d in centroid_reps.centroid_nearest
                if d.doc_id != medoid.doc_id
            ][:remaining]

        return TopicRepresentatives(
            topic_id=-1,
            medoid=medoid,
            centroid_nearest=other_docs,
        )

    def _select_archetypes(
        self,
        embeddings: np.ndarray,
        indices: np.ndarray,
        documents: List[str],
    ) -> TopicRepresentatives:
        """
        Select archetypal documents using the configured archetype algorithm.

        Archetypes represent the "extreme" examples that define the
        boundaries of the topic - the different facets or aspects.
        """
        n_archetypes = min(self.n_archetypes, len(indices))

        if self.archetype_method == "furthest_sum":
            archetype_indices = self._furthest_sum_archetypes(embeddings, n_archetypes)
        elif self.archetype_method == "convex_hull":
            archetype_indices = self._convex_hull_archetypes(embeddings, n_archetypes)
        elif self.archetype_method == "pcha":
            archetype_indices = self._pcha_archetypes(embeddings, n_archetypes)
        else:
            # Unknown setting: fall back to the default algorithm.
            archetype_indices = self._furthest_sum_archetypes(embeddings, n_archetypes)

        archetypes = [
            RepresentativeDocument(
                doc_id=indices[local_idx],
                text=documents[local_idx],
                score=1.0,
                method="archetype",
            )
            for local_idx in archetype_indices
        ]

        return TopicRepresentatives(
            topic_id=-1,
            archetypes=archetypes,
        )

    def _furthest_sum_archetypes(
        self,
        embeddings: np.ndarray,
        n_archetypes: int,
    ) -> List[int]:
        """
        Select archetypes using the Furthest Sum algorithm.

        Iteratively selects the point maximizing the sum of distances to
        the already-selected set - finding extreme points. Distance sums
        are accumulated incrementally (one vectorized pass per pick)
        instead of recomputed from scratch each iteration.
        """
        n = len(embeddings)

        if n <= n_archetypes:
            return list(range(n))

        # Seed with the point furthest from the centroid.
        centroid = embeddings.mean(axis=0)
        distances_to_centroid = np.linalg.norm(embeddings - centroid, axis=1)
        selected = [int(np.argmax(distances_to_centroid))]

        # Running sum of distances from every point to the selected set.
        total_dist = np.zeros(n)
        for _ in range(n_archetypes - 1):
            total_dist += np.linalg.norm(embeddings - embeddings[selected[-1]], axis=1)
            candidates = total_dist.copy()
            candidates[selected] = -np.inf  # never re-pick a selected point
            selected.append(int(np.argmax(candidates)))

        return selected

    def _convex_hull_archetypes(
        self,
        embeddings: np.ndarray,
        n_archetypes: int,
    ) -> List[int]:
        """
        Select archetypes as vertices of the convex hull.

        Note: the hull is computed on a 2D PCA projection, since a convex
        hull in high dimensions often includes most/all points.
        """
        from sklearn.decomposition import PCA

        n = len(embeddings)

        if n <= n_archetypes:
            return list(range(n))

        # Reduce to 2D for the convex hull.
        n_components = min(2, n, embeddings.shape[1])
        pca = PCA(n_components=n_components)
        reduced = pca.fit_transform(embeddings)

        try:
            from scipy.spatial import ConvexHull

            if n_components == 2 and n > 3:
                hull = ConvexHull(reduced)
                hull_vertices = hull.vertices.tolist()

                if len(hull_vertices) > n_archetypes:
                    # Pick a diverse subset of the hull vertices.
                    # BUG FIX: the furthest-sum result indexes the hull
                    # subset, so it must be mapped back to indices into
                    # `embeddings` before returning.
                    subset = self._furthest_sum_archetypes(
                        embeddings[hull_vertices], n_archetypes
                    )
                    return [hull_vertices[i] for i in subset]

                # Fewer hull vertices than requested: top up with
                # furthest-sum picks among the remaining points.
                remaining = n_archetypes - len(hull_vertices)
                if remaining > 0:
                    candidates = [i for i in range(n) if i not in hull_vertices]
                    if candidates:
                        extra = self._furthest_sum_archetypes(
                            embeddings[candidates], remaining
                        )
                        hull_vertices.extend(candidates[i] for i in extra)
                return hull_vertices[:n_archetypes]

            # 1D projection or too few points for a 2D hull.
            return self._furthest_sum_archetypes(embeddings, n_archetypes)

        except Exception:
            # scipy unavailable or hull computation failed (e.g. degenerate
            # geometry) - fall back to furthest sum.
            return self._furthest_sum_archetypes(embeddings, n_archetypes)

    def _pcha_archetypes(
        self,
        embeddings: np.ndarray,
        n_archetypes: int,
    ) -> List[int]:
        """
        Select archetypes using Principal Convex Hull Analysis (PCHA).

        PCHA finds archetypes such that all data points can be expressed
        as convex combinations of archetypes.

        Simplified version: furthest sum is used as a proxy for PCHA;
        a full implementation would require iterative optimization.
        """
        return self._furthest_sum_archetypes(embeddings, n_archetypes)

    def _select_diverse(
        self,
        embeddings: np.ndarray,
        indices: np.ndarray,
        documents: List[str],
    ) -> TopicRepresentatives:
        """Select diverse documents using maximal marginal relevance (MMR)."""
        n = min(self.n_representatives, len(indices))

        # Start with the centroid-nearest document.
        centroid = embeddings.mean(axis=0)
        distances = np.linalg.norm(embeddings - centroid, axis=1)
        relevance = 1.0 / (1.0 + distances)
        first_idx = int(np.argmin(distances))

        selected = [first_idx]
        lambda_param = 0.5  # Balance relevance vs. diversity

        # min_dist[i] = distance from point i to its nearest selected point,
        # maintained incrementally (one vectorized update per pick).
        min_dist = np.linalg.norm(embeddings - embeddings[first_idx], axis=1)
        for _ in range(n - 1):
            scores = lambda_param * relevance + (1 - lambda_param) * min_dist
            scores[selected] = -np.inf  # never re-pick a selected point
            best_idx = int(np.argmax(scores))
            selected.append(best_idx)
            min_dist = np.minimum(
                min_dist,
                np.linalg.norm(embeddings - embeddings[best_idx], axis=1),
            )

        reps = [
            RepresentativeDocument(
                doc_id=indices[local_idx],
                text=documents[local_idx],
                score=1.0 / (1.0 + distances[local_idx]),
                method="diverse",
            )
            for local_idx in selected
        ]

        return TopicRepresentatives(
            topic_id=-1,
            centroid_nearest=reps,
        )

    def _select_hybrid(
        self,
        embeddings: np.ndarray,
        indices: np.ndarray,
        documents: List[str],
        keywords: List[str],
    ) -> TopicRepresentatives:
        """
        Hybrid selection combining:
        1. Medoid (most typical document)
        2. Archetypes (facets/extremes)
        3. Keyword champion (highest keyword coverage)
        """
        n = len(embeddings)

        # 1. Medoid.
        medoid_idx = self._medoid_index(embeddings)
        medoid = RepresentativeDocument(
            doc_id=indices[medoid_idx],
            text=documents[medoid_idx],
            score=1.0,
            method="medoid",
        )

        # 2. Archetypes, excluding the medoid.
        # BUG FIX: arch_indices is always bound, so the keyword-champion
        # membership check below no longer raises NameError when
        # n_archetypes == 0.
        arch_indices: List[int] = []
        n_arch = min(self.n_archetypes, n - 1)
        if n_arch > 0:
            other_indices = [i for i in range(n) if i != medoid_idx]
            arch_local = self._furthest_sum_archetypes(
                embeddings[other_indices], n_arch
            )
            arch_indices = [other_indices[i] for i in arch_local]

        archetypes = [
            RepresentativeDocument(
                doc_id=indices[i],
                text=documents[i],
                score=1.0,
                method="archetype",
            )
            for i in arch_indices
        ]

        # 3. Keyword champion: the document covering the largest fraction
        # of the top-10 topic keywords (whole-word, case-insensitive).
        keyword_champion = None
        if keywords:
            keyword_set = {w.lower() for w in keywords[:10]}
            best_score = 0.0
            best_idx = -1

            for i, doc in enumerate(documents):
                doc_words = set(doc.lower().split())
                score = len(keyword_set & doc_words) / len(keyword_set)
                if score > best_score:
                    best_score = score
                    best_idx = i

            # Only add the champion if it isn't already the medoid or an
            # archetype.
            if best_idx >= 0 and best_idx != medoid_idx and best_idx not in arch_indices:
                keyword_champion = RepresentativeDocument(
                    doc_id=indices[best_idx],
                    text=documents[best_idx],
                    score=best_score,
                    method="keyword_champion",
                )

        return TopicRepresentatives(
            topic_id=-1,
            medoid=medoid,
            archetypes=archetypes,
            keyword_champion=keyword_champion,
        )
537
+
538
+
539
def get_representative_documents(
    topic_representatives: Dict[int, TopicRepresentatives],
    n_per_topic: int = 5,
) -> Dict[int, List[Tuple[int, str, str]]]:
    """
    Flatten each topic's representatives into simple tuples.

    Parameters
    ----------
    topic_representatives : Dict[int, TopicRepresentatives]
        Output of ``RepresentativeSelector.select``
    n_per_topic : int
        Maximum number of representatives to keep per topic

    Returns
    -------
    Dict[int, List[Tuple[int, str, str]]]
        topic_id -> [(doc_id, text, method), ...]
    """
    flattened: Dict[int, List[Tuple[int, str, str]]] = {}

    for topic_id, reps in topic_representatives.items():
        # get_all() yields representatives in priority order; keep the
        # first n_per_topic of them.
        flattened[topic_id] = [
            (rep.doc_id, rep.text, rep.method)
            for rep in reps.get_all()[:n_per_topic]
        ]

    return flattened