tritopic 0.1.0-py3-none-any.whl → 1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritopic might be problematic.
- tritopic/__init__.py +22 -32
- tritopic/config.py +289 -0
- tritopic/core/__init__.py +0 -17
- tritopic/core/clustering.py +229 -243
- tritopic/core/embeddings.py +151 -157
- tritopic/core/graph.py +435 -0
- tritopic/core/keywords.py +213 -249
- tritopic/core/refinement.py +231 -0
- tritopic/core/representatives.py +560 -0
- tritopic/labeling.py +313 -0
- tritopic/model.py +718 -0
- tritopic/multilingual/__init__.py +38 -0
- tritopic/multilingual/detection.py +208 -0
- tritopic/multilingual/stopwords.py +467 -0
- tritopic/multilingual/tokenizers.py +275 -0
- tritopic/visualization.py +371 -0
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/METADATA +91 -51
- tritopic-1.1.0.dist-info/RECORD +20 -0
- tritopic/core/graph_builder.py +0 -493
- tritopic/core/model.py +0 -810
- tritopic/labeling/__init__.py +0 -5
- tritopic/labeling/llm_labeler.py +0 -279
- tritopic/utils/__init__.py +0 -13
- tritopic/utils/metrics.py +0 -254
- tritopic/visualization/__init__.py +0 -5
- tritopic/visualization/plotter.py +0 -523
- tritopic-0.1.0.dist-info/RECORD +0 -18
- tritopic-0.1.0.dist-info/licenses/LICENSE +0 -21
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/WHEEL +0 -0
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/top_level.txt +0 -0
tritopic/core/refinement.py (new file)

@@ -0,0 +1,231 @@

```python
"""
Iterative Refinement Module

Implements the iterative embedding refinement process that
improves topic quality by pulling embeddings toward topic centroids.
"""

from typing import Tuple, Optional, Callable
import numpy as np
from scipy import sparse


class IterativeRefinement:
    """
    Iteratively refines document embeddings based on topic assignments.

    The key insight is that after initial clustering, we can use the
    topic structure to improve the embeddings, which in turn improves
    the topic structure.
    """

    def __init__(
        self,
        max_iterations: int = 5,
        convergence_threshold: float = 0.95,
        refinement_strength: float = 0.15,
        verbose: bool = True,
    ):
        """
        Initialize the refinement process.

        Parameters
        ----------
        max_iterations : int
            Maximum number of refinement iterations
        convergence_threshold : float
            ARI threshold for convergence detection
        refinement_strength : float
            How strongly to pull embeddings toward centroids (0-1)
        verbose : bool
            Print progress information
        """
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.refinement_strength = refinement_strength
        self.verbose = verbose

        self.convergence_history_ = []

    def refine(
        self,
        embeddings: np.ndarray,
        initial_labels: np.ndarray,
        graph_builder_fn: Callable[[np.ndarray], sparse.csr_matrix],
        cluster_fn: Callable[[sparse.csr_matrix], np.ndarray],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Iteratively refine embeddings and labels.

        Parameters
        ----------
        embeddings : np.ndarray
            Initial document embeddings
        initial_labels : np.ndarray
            Initial topic labels
        graph_builder_fn : callable
            Function to build graph from embeddings
        cluster_fn : callable
            Function to cluster a graph

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Refined embeddings and final labels
        """
        from .clustering import compute_clustering_stability

        current_embeddings = embeddings.copy()
        current_labels = initial_labels.copy()

        self.convergence_history_ = []

        for iteration in range(self.max_iterations):
            if self.verbose:
                print(f" Iteration {iteration + 1}...")

            # Refine embeddings based on current labels
            refined_embeddings = self._refine_embeddings(
                current_embeddings, current_labels
            )

            # Build new graph with refined embeddings
            new_graph = graph_builder_fn(refined_embeddings)

            # Cluster the new graph
            new_labels = cluster_fn(new_graph)

            # Check convergence
            if iteration > 0:
                stability = compute_clustering_stability(current_labels, new_labels)
                self.convergence_history_.append(stability)

                if self.verbose:
                    print(f" ARI vs previous: {stability:.4f}")

                if stability >= self.convergence_threshold:
                    if self.verbose:
                        print(f" ✓ Converged at iteration {iteration + 1}")
                    current_embeddings = refined_embeddings
                    current_labels = new_labels
                    break

            current_embeddings = refined_embeddings
            current_labels = new_labels

        return current_embeddings, current_labels

    def _refine_embeddings(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
    ) -> np.ndarray:
        """
        Refine embeddings by pulling them toward topic centroids.

        For each document in topic t:
            new_embedding = (1 - β) * old_embedding + β * centroid_t

        where β is the refinement_strength.
        """
        refined = embeddings.copy()
        unique_labels = np.unique(labels)

        for label in unique_labels:
            if label < 0:  # Skip outliers
                continue

            mask = labels == label
            if mask.sum() == 0:
                continue

            # Compute centroid for this topic
            centroid = embeddings[mask].mean(axis=0)

            # Pull documents toward centroid
            refined[mask] = (
                (1 - self.refinement_strength) * embeddings[mask] +
                self.refinement_strength * centroid
            )

        # Re-normalize embeddings
        norms = np.linalg.norm(refined, axis=1, keepdims=True)
        norms[norms == 0] = 1
        refined = refined / norms

        return refined

    @property
    def n_iterations(self) -> int:
        """Number of iterations performed."""
        return len(self.convergence_history_) + 1
```
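To make the refine loop concrete, here is a minimal usage sketch. It is not taken from the package's documentation: it assumes tritopic 1.1.0 is installed and that the import path follows the layout in this diff, and it substitutes a generic scikit-learn kNN graph builder and spectral clusterer for the package's own graph.py and clustering.py steps.

```python
# Minimal usage sketch. Assumptions: the import path matches this diff's
# layout; graph_builder_fn and cluster_fn below are generic stand-ins,
# not tritopic's own graph-construction or clustering implementations.
import numpy as np
from scipy import sparse
from sklearn.neighbors import kneighbors_graph
from sklearn.cluster import SpectralClustering

from tritopic.core.refinement import IterativeRefinement

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 64))
# Unit-normalize, matching the re-normalization the refiner applies itself.
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)

def graph_builder_fn(emb: np.ndarray) -> sparse.csr_matrix:
    # Stand-in graph builder: symmetrized 15-nearest-neighbor cosine graph.
    g = kneighbors_graph(emb, n_neighbors=15, metric="cosine", mode="connectivity")
    return ((g + g.T) * 0.5).tocsr()

def cluster_fn(graph: sparse.csr_matrix) -> np.ndarray:
    # Stand-in clusterer: spectral clustering on the precomputed affinity.
    sc = SpectralClustering(n_clusters=8, affinity="precomputed", random_state=0)
    return sc.fit_predict(graph.toarray())

initial_labels = cluster_fn(graph_builder_fn(embeddings))
refiner = IterativeRefinement(max_iterations=3, refinement_strength=0.15, verbose=False)
refined_embeddings, final_labels = refiner.refine(
    embeddings, initial_labels, graph_builder_fn, cluster_fn
)
print(refiner.n_iterations, "iterations;", len(np.unique(final_labels)), "topics")
```

The only contract refine() relies on is that graph_builder_fn maps an (n_docs, dim) embedding array to a sparse graph and cluster_fn maps that graph to an integer label array (negative labels are treated as outliers and skipped), so the package's real graph and clustering modules plug in the same way. The diff continues with the AdaptiveRefinement subclass: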
```python
class AdaptiveRefinement(IterativeRefinement):
    """
    Adaptive refinement that adjusts strength based on topic coherence.

    Topics with low internal coherence get stronger refinement,
    while already coherent topics get lighter refinement.
    """

    def __init__(
        self,
        max_iterations: int = 5,
        convergence_threshold: float = 0.95,
        base_strength: float = 0.15,
        min_strength: float = 0.05,
        max_strength: float = 0.30,
        verbose: bool = True,
    ):
        super().__init__(
            max_iterations=max_iterations,
            convergence_threshold=convergence_threshold,
            refinement_strength=base_strength,
            verbose=verbose,
        )
        self.base_strength = base_strength
        self.min_strength = min_strength
        self.max_strength = max_strength

    def _refine_embeddings(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
    ) -> np.ndarray:
        """
        Refine embeddings with adaptive strength per topic.
        """
        refined = embeddings.copy()
        unique_labels = np.unique(labels)

        for label in unique_labels:
            if label < 0:
                continue

            mask = labels == label
            if mask.sum() < 2:
                continue

            topic_embeddings = embeddings[mask]
            centroid = topic_embeddings.mean(axis=0)

            # Compute internal coherence (average cosine similarity to centroid)
            centroid_norm = centroid / (np.linalg.norm(centroid) + 1e-8)
            emb_norms = topic_embeddings / (np.linalg.norm(topic_embeddings, axis=1, keepdims=True) + 1e-8)
            coherence = np.mean(np.dot(emb_norms, centroid_norm))

            # Adaptive strength: less coherent topics get stronger refinement
            # coherence is in [0, 1], invert it for strength
            strength = self.base_strength + (1 - coherence) * (self.max_strength - self.base_strength)
            strength = np.clip(strength, self.min_strength, self.max_strength)

            # Apply refinement
            refined[mask] = (1 - strength) * embeddings[mask] + strength * centroid

        # Re-normalize
        norms = np.linalg.norm(refined, axis=1, keepdims=True)
        norms[norms == 0] = 1
        refined = refined / norms

        return refined
```
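For intuition about the adaptive rule, the strength computation can be evaluated in isolation on a few coherence values. The sketch below mirrors the arithmetic in AdaptiveRefinement above; adaptive_strength is a hypothetical helper, not a function the package exports. Note that for coherence in [0, 1] the formula already stays within [base_strength, max_strength], so min_strength only binds in edge cases where numerical coherence drifts above 1.

```python
import numpy as np

def adaptive_strength(coherence: float, base: float = 0.15,
                      lo: float = 0.05, hi: float = 0.30) -> float:
    # Mirrors AdaptiveRefinement: the less coherent the topic, the harder
    # its documents are pulled toward the centroid, clipped to [lo, hi].
    return float(np.clip(base + (1.0 - coherence) * (hi - base), lo, hi))

for coherence in (0.95, 0.70, 0.40):
    print(f"coherence={coherence:.2f} -> strength={adaptive_strength(coherence):.3f}")
# A tight topic (coherence 0.95) stays near the base strength of 0.15,
# while a diffuse topic (0.40) is pulled at 0.24, still under max_strength.
```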