tritopic 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritopic might be problematic. Click here for more details.

@@ -0,0 +1,231 @@
1
+ """
2
+ Iterative Refinement Module
3
+
4
+ Implements the iterative embedding refinement process that
5
+ improves topic quality by pulling embeddings toward topic centroids.
6
+ """
7
+
8
+ from typing import Tuple, Optional, Callable
9
+ import numpy as np
10
+ from scipy import sparse
11
+
12
+
13
class IterativeRefinement:
    """
    Iteratively refines document embeddings based on topic assignments.

    The key insight is that after initial clustering, we can use the
    topic structure to improve the embeddings, which in turn improves
    the topic structure.

    Attributes
    ----------
    convergence_history_ : list
        Stability scores recorded by the most recent ``refine`` call, one
        per iteration after the first (the first iteration is not scored).
    """

    def __init__(
        self,
        max_iterations: int = 5,
        convergence_threshold: float = 0.95,
        refinement_strength: float = 0.15,
        verbose: bool = True,
    ):
        """
        Initialize the refinement process.

        Parameters
        ----------
        max_iterations : int
            Maximum number of refinement iterations
        convergence_threshold : float
            ARI threshold for convergence detection
        refinement_strength : float
            How strongly to pull embeddings toward centroids (0-1)
        verbose : bool
            Print progress information
        """
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.refinement_strength = refinement_strength
        self.verbose = verbose

        self.convergence_history_ = []
        # Explicit counter so `n_iterations` is 0 until `refine` has run.
        self._n_iterations = 0

    def refine(
        self,
        embeddings: np.ndarray,
        initial_labels: np.ndarray,
        graph_builder_fn: Callable[[np.ndarray], sparse.csr_matrix],
        cluster_fn: Callable[[sparse.csr_matrix], np.ndarray],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Iteratively refine embeddings and labels.

        Each iteration pulls the embeddings toward the centroids of the
        current labels, rebuilds the graph, and re-clusters. From the second
        iteration on, the new labels are compared with the previous ones and
        the loop stops early once the score reaches
        ``convergence_threshold``.

        Parameters
        ----------
        embeddings : np.ndarray
            Initial document embeddings (not modified)
        initial_labels : np.ndarray
            Initial topic labels (not modified)
        graph_builder_fn : callable
            Function to build graph from embeddings
        cluster_fn : callable
            Function to cluster a graph

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Refined embeddings and final labels
        """
        # Imported lazily, as in the original module layout.
        from .clustering import compute_clustering_stability

        # Work on copies so the caller's arrays are never mutated.
        current_embeddings = np.array(embeddings, copy=True)
        current_labels = np.array(initial_labels, copy=True)

        self.convergence_history_ = []
        self._n_iterations = 0

        for iteration in range(self.max_iterations):
            if self.verbose:
                print(f" Iteration {iteration + 1}...")

            # Refine embeddings based on current labels
            refined_embeddings = self._refine_embeddings(
                current_embeddings, current_labels
            )

            # Build new graph with refined embeddings, then cluster it
            new_graph = graph_builder_fn(refined_embeddings)
            new_labels = cluster_fn(new_graph)
            self._n_iterations = iteration + 1

            # Convergence is only checked from the second iteration on;
            # the first pass always replaces the initial labels.
            converged = False
            if iteration > 0:
                stability = compute_clustering_stability(
                    current_labels, new_labels
                )
                self.convergence_history_.append(stability)

                if self.verbose:
                    print(f" ARI vs previous: {stability:.4f}")

                converged = stability >= self.convergence_threshold
                if converged and self.verbose:
                    print(f" ✓ Converged at iteration {iteration + 1}")

            current_embeddings = refined_embeddings
            current_labels = new_labels

            if converged:
                break

        return current_embeddings, current_labels

    def _refine_embeddings(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
    ) -> np.ndarray:
        """
        Refine embeddings by pulling them toward topic centroids.

        For each document in topic t:
            new_embedding = (1 - β) * old_embedding + β * centroid_t

        where β is the refinement_strength. Documents with a negative label
        (outliers) are not pulled anywhere. Every row is L2-normalized before
        returning; all-zero rows are left as zeros.

        Parameters
        ----------
        embeddings : np.ndarray
            2-D array with one row per document (not modified)
        labels : np.ndarray
            Topic label per document; negative values mark outliers

        Returns
        -------
        np.ndarray
            New array of refined, row-normalized embeddings
        """
        embeddings = np.asarray(embeddings)
        # Integer/bool input would silently truncate the blended rows when
        # they are written back, so promote to float first.
        if not np.issubdtype(embeddings.dtype, np.inexact):
            embeddings = embeddings.astype(np.float64)
        labels = np.asarray(labels)

        refined = embeddings.copy()
        strength = self.refinement_strength

        for label in np.unique(labels):
            if label < 0:  # Skip outliers
                continue

            # `label` comes from np.unique(labels), so the mask is non-empty.
            mask = labels == label
            members = embeddings[mask]

            # Compute centroid for this topic and pull its documents to it
            centroid = members.mean(axis=0)
            refined[mask] = (1 - strength) * members + strength * centroid

        # Re-normalize embeddings; guard against division by zero
        norms = np.linalg.norm(refined, axis=1, keepdims=True)
        norms[norms == 0] = 1
        return refined / norms

    @property
    def n_iterations(self) -> int:
        """Number of iterations performed by the last ``refine`` call (0 if none)."""
        return getattr(self, "_n_iterations", 0)
162
+
163
+
164
class AdaptiveRefinement(IterativeRefinement):
    """
    Adaptive refinement that adjusts strength based on topic coherence.

    Topics with low internal coherence get stronger refinement,
    while already coherent topics get lighter refinement.

    The per-topic strength is
    ``base_strength + (1 - coherence) * (max_strength - base_strength)``,
    clipped to ``[min_strength, max_strength]``. With
    ``max_strength >= base_strength`` a fully coherent topic therefore gets
    ``base_strength``; ``min_strength`` only acts as a lower clip bound.
    """

    def __init__(
        self,
        max_iterations: int = 5,
        convergence_threshold: float = 0.95,
        base_strength: float = 0.15,
        min_strength: float = 0.05,
        max_strength: float = 0.30,
        verbose: bool = True,
    ):
        """
        Initialize the adaptive refinement process.

        Parameters
        ----------
        max_iterations : int
            Maximum number of refinement iterations
        convergence_threshold : float
            Stability threshold for convergence detection
        base_strength : float
            Strength applied to a perfectly coherent topic; also passed to
            the parent class as ``refinement_strength``
        min_strength : float
            Lower clip bound for the per-topic strength
        max_strength : float
            Upper clip bound for the per-topic strength
        verbose : bool
            Print progress information
        """
        super().__init__(
            max_iterations=max_iterations,
            convergence_threshold=convergence_threshold,
            refinement_strength=base_strength,
            verbose=verbose,
        )
        self.base_strength = base_strength
        self.min_strength = min_strength
        self.max_strength = max_strength

    def _refine_embeddings(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
    ) -> np.ndarray:
        """
        Refine embeddings with adaptive strength per topic.

        Topics with a negative label (outliers) or fewer than two documents
        are left untouched. Every row is L2-normalized before returning;
        all-zero rows are left as zeros.

        Parameters
        ----------
        embeddings : np.ndarray
            2-D array with one row per document (not modified)
        labels : np.ndarray
            Topic label per document; negative values mark outliers

        Returns
        -------
        np.ndarray
            New array of refined, row-normalized embeddings
        """
        embeddings = np.asarray(embeddings)
        # Integer/bool input would silently truncate the blended rows when
        # they are written back, so promote to float first.
        if not np.issubdtype(embeddings.dtype, np.inexact):
            embeddings = embeddings.astype(np.float64)
        labels = np.asarray(labels)

        refined = embeddings.copy()
        strength_span = self.max_strength - self.base_strength

        for label in np.unique(labels):
            if label < 0:
                continue

            mask = labels == label
            if mask.sum() < 2:
                # A single document is its own centroid; nothing to pull.
                continue

            topic_embeddings = embeddings[mask]
            centroid = topic_embeddings.mean(axis=0)

            # Internal coherence: average cosine similarity to the centroid.
            # The small epsilon avoids division by zero for zero vectors.
            centroid_unit = centroid / (np.linalg.norm(centroid) + 1e-8)
            row_lengths = np.linalg.norm(
                topic_embeddings, axis=1, keepdims=True
            )
            unit_rows = topic_embeddings / (row_lengths + 1e-8)
            coherence = np.mean(np.dot(unit_rows, centroid_unit))

            # Adaptive strength: less coherent topics get stronger refinement.
            # Cosine similarity lies in [-1, 1], so the raw value can exceed
            # max_strength; the clip keeps it inside the configured bounds.
            strength = self.base_strength + (1 - coherence) * strength_span
            strength = np.clip(strength, self.min_strength, self.max_strength)

            # Apply refinement
            refined[mask] = (1 - strength) * topic_embeddings + strength * centroid

        # Re-normalize; guard against division by zero
        norms = np.linalg.norm(refined, axis=1, keepdims=True)
        norms[norms == 0] = 1
        return refined / norms