tritopic 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tritopic/__init__.py +22 -32
- tritopic/config.py +305 -0
- tritopic/core/__init__.py +0 -17
- tritopic/core/clustering.py +229 -243
- tritopic/core/embeddings.py +151 -157
- tritopic/core/graph.py +435 -0
- tritopic/core/keywords.py +213 -249
- tritopic/core/refinement.py +231 -0
- tritopic/core/representatives.py +560 -0
- tritopic/labeling.py +313 -0
- tritopic/model.py +718 -0
- tritopic/multilingual/__init__.py +38 -0
- tritopic/multilingual/detection.py +208 -0
- tritopic/multilingual/stopwords.py +467 -0
- tritopic/multilingual/tokenizers.py +275 -0
- tritopic/visualization.py +371 -0
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/METADATA +92 -48
- tritopic-1.0.0.dist-info/RECORD +20 -0
- tritopic/core/graph_builder.py +0 -493
- tritopic/core/model.py +0 -810
- tritopic/labeling/__init__.py +0 -5
- tritopic/labeling/llm_labeler.py +0 -279
- tritopic/utils/__init__.py +0 -13
- tritopic/utils/metrics.py +0 -254
- tritopic/visualization/__init__.py +0 -5
- tritopic/visualization/plotter.py +0 -523
- tritopic-0.1.0.dist-info/RECORD +0 -18
- tritopic-0.1.0.dist-info/licenses/LICENSE +0 -21
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/WHEEL +0 -0
- {tritopic-0.1.0.dist-info → tritopic-1.0.0.dist-info}/top_level.txt +0 -0
tritopic/core/model.py
DELETED
@@ -1,810 +0,0 @@
"""
TriTopic: Main Model Class
===========================

The core class that orchestrates all components of the topic modeling pipeline.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import Any, Callable, Literal

import numpy as np
import pandas as pd
from tqdm import tqdm

from tritopic.core.embeddings import EmbeddingEngine
from tritopic.core.graph_builder import GraphBuilder
from tritopic.core.clustering import ConsensusLeiden
from tritopic.core.keywords import KeywordExtractor
from tritopic.utils.metrics import compute_coherence, compute_diversity, compute_stability


@dataclass
class TopicInfo:
    """Container for topic information."""

    topic_id: int
    size: int
    keywords: list[str]
    keyword_scores: list[float]
    representative_docs: list[int]
    label: str | None = None
    description: str | None = None
    centroid: np.ndarray | None = None
    coherence: float | None = None


@dataclass
class TriTopicConfig:
    """Configuration for TriTopic model."""

    # Embedding settings
    embedding_model: str = "all-MiniLM-L6-v2"
    embedding_batch_size: int = 32

    # Graph settings
    n_neighbors: int = 15
    metric: str = "cosine"
    graph_type: Literal["mutual_knn", "snn", "hybrid"] = "hybrid"
    snn_weight: float = 0.5

    # Multi-view settings
    use_lexical_view: bool = True
    use_metadata_view: bool = False
    lexical_weight: float = 0.3
    metadata_weight: float = 0.2
    semantic_weight: float = 0.5

    # Clustering settings
    resolution: float = 1.0
    resolution_range: tuple[float, float] | None = None
    n_consensus_runs: int = 10
    min_cluster_size: int = 5

    # Iterative refinement
    use_iterative_refinement: bool = True
    max_iterations: int = 5
    convergence_threshold: float = 0.95

    # Keyword extraction
    n_keywords: int = 10
    n_representative_docs: int = 5
    keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"

    # Outlier handling
    outlier_threshold: float = 0.1

    # Misc
    random_state: int = 42
    verbose: bool = True


class TriTopic:
    """
    Tri-Modal Graph Topic Modeling with Iterative Refinement.

    A state-of-the-art topic modeling approach that combines semantic embeddings,
    lexical similarity, and optional metadata to create robust, interpretable topics.

    Key innovations:
    - Multi-view graph fusion (semantic + lexical + metadata)
    - Leiden clustering with consensus for stability
    - Iterative refinement loop for optimal topic separation
    - Advanced keyword extraction with representative documents
    - Optional LLM-powered topic labeling

    Parameters
    ----------
    config : TriTopicConfig, optional
        Configuration object. If None, uses defaults.
    embedding_model : str, optional
        Name of sentence-transformers model. Default: "all-MiniLM-L6-v2"
    n_neighbors : int, optional
        Number of neighbors for graph construction. Default: 15
    n_topics : int or "auto", optional
        Number of topics. "auto" uses Leiden's natural resolution. Default: "auto"
    use_iterative_refinement : bool, optional
        Whether to use the iterative refinement loop. Default: True
    verbose : bool, optional
        Print progress information. Default: True

    Attributes
    ----------
    topics_ : list[TopicInfo]
        Information about each discovered topic.
    labels_ : np.ndarray
        Topic assignment for each document.
    embeddings_ : np.ndarray
        Document embeddings.
    graph_ : igraph.Graph
        The constructed similarity graph.
    topic_embeddings_ : np.ndarray
        Centroid embeddings for each topic.

    Examples
    --------
    Basic usage:

    >>> from tritopic import TriTopic
    >>> model = TriTopic(n_neighbors=15, verbose=True)
    >>> topics = model.fit_transform(documents)
    >>> print(model.get_topic_info())

    With metadata:

    >>> model = TriTopic()
    >>> model.config.use_metadata_view = True
    >>> topics = model.fit_transform(documents, metadata=df[['source', 'date']])

    With LLM labeling:

    >>> from tritopic import TriTopic, LLMLabeler
    >>> model = TriTopic()
    >>> model.fit_transform(documents)
    >>> labeler = LLMLabeler(provider="anthropic", api_key="...")
    >>> model.generate_labels(labeler)
    """

    def __init__(
        self,
        config: TriTopicConfig | None = None,
        embedding_model: str | None = None,
        n_neighbors: int | None = None,
        n_topics: int | Literal["auto"] = "auto",
        use_iterative_refinement: bool | None = None,
        verbose: bool | None = None,
        random_state: int | None = None,
    ):
        # Initialize config
        self.config = config or TriTopicConfig()

        # Override config with explicit parameters
        if embedding_model is not None:
            self.config.embedding_model = embedding_model
        if n_neighbors is not None:
            self.config.n_neighbors = n_neighbors
        if use_iterative_refinement is not None:
            self.config.use_iterative_refinement = use_iterative_refinement
        if verbose is not None:
            self.config.verbose = verbose
        if random_state is not None:
            self.config.random_state = random_state

        self.n_topics = n_topics

        # Initialize components
        self._embedding_engine = EmbeddingEngine(
            model_name=self.config.embedding_model,
            batch_size=self.config.embedding_batch_size,
        )
        self._graph_builder = GraphBuilder(
            n_neighbors=self.config.n_neighbors,
            metric=self.config.metric,
            graph_type=self.config.graph_type,
            snn_weight=self.config.snn_weight,
        )
        self._clusterer = ConsensusLeiden(
            resolution=self.config.resolution,
            n_runs=self.config.n_consensus_runs,
            random_state=self.config.random_state,
        )
        self._keyword_extractor = KeywordExtractor(
            method=self.config.keyword_method,
            n_keywords=self.config.n_keywords,
        )

        # State
        self.topics_: list[TopicInfo] = []
        self.labels_: np.ndarray | None = None
        self.embeddings_: np.ndarray | None = None
        self.lexical_matrix_: Any | None = None
        self.graph_: Any | None = None
        self.topic_embeddings_: np.ndarray | None = None
        self.documents_: list[str] | None = None
        self._is_fitted: bool = False
        self._iteration_history: list[dict] = []

    def fit(
        self,
        documents: list[str],
        embeddings: np.ndarray | None = None,
        metadata: pd.DataFrame | None = None,
    ) -> "TriTopic":
        """
        Fit the topic model to documents.

        Parameters
        ----------
        documents : list[str]
            List of document texts.
        embeddings : np.ndarray, optional
            Pre-computed embeddings. If None, computed automatically.
        metadata : pd.DataFrame, optional
            Document metadata for the metadata view.

        Returns
        -------
        self : TriTopic
            Fitted model.
        """
        self.documents_ = documents
        n_docs = len(documents)

        if self.config.verbose:
            print(f"🚀 TriTopic: Fitting model on {n_docs} documents")
            print(f"   Config: {self.config.graph_type} graph, "
                  f"{'iterative' if self.config.use_iterative_refinement else 'single-pass'} mode")

        # Step 1: Generate embeddings
        if embeddings is not None:
            self.embeddings_ = embeddings
            if self.config.verbose:
                print("   ✓ Using provided embeddings")
        else:
            if self.config.verbose:
                print(f"   → Generating embeddings ({self.config.embedding_model})...")
            self.embeddings_ = self._embedding_engine.encode(documents)

        # Step 2: Build lexical representation
        if self.config.use_lexical_view:
            if self.config.verbose:
                print("   → Building lexical similarity matrix...")
            self.lexical_matrix_ = self._graph_builder.build_lexical_matrix(documents)

        # Step 3: Build metadata graph (if provided)
        metadata_graph = None
        if self.config.use_metadata_view and metadata is not None:
            if self.config.verbose:
                print("   → Building metadata similarity graph...")
            metadata_graph = self._graph_builder.build_metadata_graph(metadata)

        # Step 4: Main fitting loop
        if self.config.use_iterative_refinement:
            self._fit_iterative(documents, metadata_graph)
        else:
            self._fit_single_pass(documents, metadata_graph)

        # Step 5: Extract keywords and representative docs
        if self.config.verbose:
            print("   → Extracting keywords and representative documents...")
        self._extract_topic_info(documents)

        # Step 6: Compute topic centroids
        self._compute_topic_centroids()

        self._is_fitted = True

        if self.config.verbose:
            n_topics = len([t for t in self.topics_ if t.topic_id != -1])
            n_outliers = np.sum(self.labels_ == -1) if self.labels_ is not None else 0
            print(f"\n✅ Fitting complete!")
            print(f"   Found {n_topics} topics")
            print(f"   {n_outliers} outlier documents ({100*n_outliers/n_docs:.1f}%)")

        return self

    def _fit_single_pass(
        self,
        documents: list[str],
        metadata_graph: Any | None = None,
    ) -> None:
        """Single-pass fitting without iterative refinement."""
        # Build graph
        if self.config.verbose:
            print("   → Building multi-view graph...")

        self.graph_ = self._graph_builder.build_multiview_graph(
            semantic_embeddings=self.embeddings_,
            lexical_matrix=self.lexical_matrix_ if self.config.use_lexical_view else None,
            metadata_graph=metadata_graph,
            weights={
                "semantic": self.config.semantic_weight,
                "lexical": self.config.lexical_weight,
                "metadata": self.config.metadata_weight,
            }
        )

        # Cluster
        if self.config.verbose:
            print(f"   → Running Leiden consensus clustering ({self.config.n_consensus_runs} runs)...")

        self.labels_ = self._clusterer.fit_predict(
            self.graph_,
            min_cluster_size=self.config.min_cluster_size,
        )

    def _fit_iterative(
        self,
        documents: list[str],
        metadata_graph: Any | None = None,
    ) -> None:
        """Iterative refinement fitting loop."""
        if self.config.verbose:
            print(f"   → Starting iterative refinement (max {self.config.max_iterations} iterations)...")

        current_embeddings = self.embeddings_.copy()
        previous_labels = None

        for iteration in range(self.config.max_iterations):
            if self.config.verbose:
                print(f"     Iteration {iteration + 1}...")

            # Build graph with current embeddings
            self.graph_ = self._graph_builder.build_multiview_graph(
                semantic_embeddings=current_embeddings,
                lexical_matrix=self.lexical_matrix_ if self.config.use_lexical_view else None,
                metadata_graph=metadata_graph,
                weights={
                    "semantic": self.config.semantic_weight,
                    "lexical": self.config.lexical_weight,
                    "metadata": self.config.metadata_weight,
                }
            )

            # Cluster
            self.labels_ = self._clusterer.fit_predict(
                self.graph_,
                min_cluster_size=self.config.min_cluster_size,
            )

            # Check convergence
            if previous_labels is not None:
                from sklearn.metrics import adjusted_rand_score
                ari = adjusted_rand_score(previous_labels, self.labels_)
                self._iteration_history.append({
                    "iteration": iteration + 1,
                    "ari": ari,
                    "n_topics": len(np.unique(self.labels_[self.labels_ != -1])),
                })

                if self.config.verbose:
                    print(f"     ARI vs previous: {ari:.4f}")

                if ari >= self.config.convergence_threshold:
                    if self.config.verbose:
                        print(f"   ✓ Converged at iteration {iteration + 1}")
                    break

            previous_labels = self.labels_.copy()

            # Refine embeddings based on topic structure
            current_embeddings = self._refine_embeddings(
                documents, self.embeddings_, self.labels_
            )

        # Store final refined embeddings
        self.embeddings_ = current_embeddings

    def _refine_embeddings(
        self,
        documents: list[str],
        original_embeddings: np.ndarray,
        labels: np.ndarray,
    ) -> np.ndarray:
        """
        Refine embeddings by incorporating topic context.

        This is the key innovation: we modify embeddings to be more
        topic-aware by pulling documents toward their topic centroid.
        """
        refined = original_embeddings.copy()
        unique_labels = np.unique(labels[labels != -1])

        # Compute topic centroids
        centroids = {}
        for label in unique_labels:
            mask = labels == label
            centroids[label] = original_embeddings[mask].mean(axis=0)

        # Soft refinement: blend original embedding with topic centroid
        blend_factor = 0.2  # How much to pull toward centroid

        for i, label in enumerate(labels):
            if label != -1:  # Skip outliers
                centroid = centroids[label]
                refined[i] = (1 - blend_factor) * refined[i] + blend_factor * centroid
                # Re-normalize
                refined[i] = refined[i] / np.linalg.norm(refined[i])

        return refined

    def _extract_topic_info(self, documents: list[str]) -> None:
        """Extract keywords and representative documents for each topic."""
        self.topics_ = []
        unique_labels = np.unique(self.labels_)

        for label in unique_labels:
            mask = self.labels_ == label
            topic_docs = [documents[i] for i in np.where(mask)[0]]
            topic_indices = np.where(mask)[0]

            # Extract keywords
            keywords, scores = self._keyword_extractor.extract(
                topic_docs,
                all_docs=documents,
                n_keywords=self.config.n_keywords,
            )

            # Find representative documents (closest to centroid)
            if self.embeddings_ is not None and label != -1:
                topic_embeddings = self.embeddings_[mask]
                centroid = topic_embeddings.mean(axis=0)
                distances = np.linalg.norm(topic_embeddings - centroid, axis=1)
                top_indices = np.argsort(distances)[:self.config.n_representative_docs]
                representative_docs = [int(topic_indices[i]) for i in top_indices]
            else:
                representative_docs = list(topic_indices[:self.config.n_representative_docs])

            topic_info = TopicInfo(
                topic_id=int(label),
                size=int(mask.sum()),
                keywords=keywords,
                keyword_scores=scores,
                representative_docs=representative_docs,
                label=None,
                description=None,
            )
            self.topics_.append(topic_info)

        # Sort by size (excluding outliers)
        self.topics_ = sorted(
            self.topics_,
            key=lambda t: (t.topic_id == -1, -t.size)
        )

    def _compute_topic_centroids(self) -> None:
        """Compute centroid embeddings for each topic."""
        if self.embeddings_ is None:
            return

        unique_labels = [t.topic_id for t in self.topics_ if t.topic_id != -1]
        self.topic_embeddings_ = np.zeros((len(unique_labels), self.embeddings_.shape[1]))

        for i, label in enumerate(unique_labels):
            mask = self.labels_ == label
            self.topic_embeddings_[i] = self.embeddings_[mask].mean(axis=0)

            # Store in topic info
            for topic in self.topics_:
                if topic.topic_id == label:
                    topic.centroid = self.topic_embeddings_[i]
                    break

    def fit_transform(
        self,
        documents: list[str],
        embeddings: np.ndarray | None = None,
        metadata: pd.DataFrame | None = None,
    ) -> np.ndarray:
        """
        Fit the model and return topic assignments.

        Parameters
        ----------
        documents : list[str]
            List of document texts.
        embeddings : np.ndarray, optional
            Pre-computed embeddings.
        metadata : pd.DataFrame, optional
            Document metadata.

        Returns
        -------
        labels : np.ndarray
            Topic assignment for each document. -1 indicates outlier.
        """
        self.fit(documents, embeddings, metadata)
        return self.labels_

    def transform(self, documents: list[str]) -> np.ndarray:
        """
        Assign topics to new documents.

        Parameters
        ----------
        documents : list[str]
            New documents to classify.

        Returns
        -------
        labels : np.ndarray
            Topic assignments.
        """
        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        # Encode new documents
        new_embeddings = self._embedding_engine.encode(documents)

        # Find nearest topic centroid
        labels = np.zeros(len(documents), dtype=int)

        for i, emb in enumerate(new_embeddings):
            distances = np.linalg.norm(self.topic_embeddings_ - emb, axis=1)
            nearest_topic_idx = np.argmin(distances)

            # Check if it's an outlier (too far from any centroid)
            if distances[nearest_topic_idx] > self.config.outlier_threshold * 2:
                labels[i] = -1
            else:
                # Map index back to topic_id
                non_outlier_topics = [t for t in self.topics_ if t.topic_id != -1]
                labels[i] = non_outlier_topics[nearest_topic_idx].topic_id

        return labels

    def get_topic_info(self) -> pd.DataFrame:
        """
        Get a DataFrame with topic information.

        Returns
        -------
        df : pd.DataFrame
            DataFrame with columns: Topic, Size, Keywords, Label, Coherence
        """
        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        data = []
        for topic in self.topics_:
            data.append({
                "Topic": topic.topic_id,
                "Size": topic.size,
                "Keywords": ", ".join(topic.keywords[:5]),
                "All_Keywords": topic.keywords,
                "Keyword_Scores": topic.keyword_scores,
                "Label": topic.label or f"Topic {topic.topic_id}",
                "Description": topic.description,
                "Representative_Docs": topic.representative_docs,
                "Coherence": topic.coherence,
            })

        return pd.DataFrame(data)

    def get_topic(self, topic_id: int) -> TopicInfo | None:
        """Get information about a specific topic."""
        for topic in self.topics_:
            if topic.topic_id == topic_id:
                return topic
        return None

    def get_representative_docs(
        self,
        topic_id: int,
        n_docs: int = 5,
    ) -> list[tuple[int, str]]:
        """
        Get representative documents for a topic.

        Parameters
        ----------
        topic_id : int
            Topic ID.
        n_docs : int
            Number of documents to return.

        Returns
        -------
        docs : list[tuple[int, str]]
            List of (index, document_text) tuples.
        """
        if not self._is_fitted or self.documents_ is None:
            raise ValueError("Model not fitted. Call fit() first.")

        topic = self.get_topic(topic_id)
        if topic is None:
            raise ValueError(f"Topic {topic_id} not found.")

        indices = topic.representative_docs[:n_docs]
        return [(idx, self.documents_[idx]) for idx in indices]

    def generate_labels(
        self,
        labeler: "LLMLabeler",
        topics: list[int] | None = None,
    ) -> None:
        """
        Generate labels for topics using an LLM.

        Parameters
        ----------
        labeler : LLMLabeler
            Configured LLM labeler instance.
        topics : list[int], optional
            Specific topics to label. If None, labels all.
        """
        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        target_topics = topics or [t.topic_id for t in self.topics_ if t.topic_id != -1]

        for topic_id in tqdm(target_topics, desc="Generating labels", disable=not self.config.verbose):
            topic = self.get_topic(topic_id)
            if topic is None:
                continue

            # Get representative docs
            rep_docs = self.get_representative_docs(topic_id, n_docs=5)
            doc_texts = [doc for _, doc in rep_docs]

            # Generate label
            label, description = labeler.generate_label(
                keywords=topic.keywords,
                representative_docs=doc_texts,
            )

            topic.label = label
            topic.description = description

    def visualize(
        self,
        method: Literal["umap", "pacmap"] = "umap",
        color_by: Literal["topic", "custom"] = "topic",
        custom_labels: list[str] | None = None,
        show_outliers: bool = True,
        interactive: bool = True,
        **kwargs,
    ):
        """
        Visualize topics in 2D.

        Parameters
        ----------
        method : str
            Dimensionality reduction method. "umap" or "pacmap".
        color_by : str
            How to color points. "topic" uses topic assignments.
        custom_labels : list[str], optional
            Custom labels for hover text.
        show_outliers : bool
            Whether to show outlier documents.
        interactive : bool
            If True, returns interactive Plotly figure.
        **kwargs
            Additional arguments passed to the visualizer.

        Returns
        -------
        fig : plotly.graph_objects.Figure
            Interactive visualization.
        """
        from tritopic.visualization.plotter import TopicVisualizer

        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        visualizer = TopicVisualizer(method=method)

        return visualizer.plot_documents(
            embeddings=self.embeddings_,
            labels=self.labels_,
            documents=self.documents_,
            topics=self.topics_,
            show_outliers=show_outliers,
            interactive=interactive,
            **kwargs,
        )

    def visualize_hierarchy(self, **kwargs):
        """Visualize topic hierarchy as a dendrogram."""
        from tritopic.visualization.plotter import TopicVisualizer

        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        visualizer = TopicVisualizer()
        return visualizer.plot_hierarchy(
            topic_embeddings=self.topic_embeddings_,
            topics=self.topics_,
            **kwargs,
        )

    def visualize_topics(self, **kwargs):
        """Visualize topics as a heatmap or bar chart."""
        from tritopic.visualization.plotter import TopicVisualizer

        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        visualizer = TopicVisualizer()
        return visualizer.plot_topics(
            topics=self.topics_,
            **kwargs,
        )

    def evaluate(self) -> dict[str, float]:
        """
        Evaluate topic model quality.

        Returns
        -------
        metrics : dict
            Dictionary with coherence, diversity, and stability scores.
        """
        if not self._is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        # Compute coherence for each topic
        coherences = []
        for topic in self.topics_:
            if topic.topic_id != -1:
                coh = compute_coherence(
                    topic.keywords,
                    [self.documents_[i] for i in np.where(self.labels_ == topic.topic_id)[0]]
                )
                topic.coherence = coh
                coherences.append(coh)

        # Compute diversity
        all_keywords = [kw for t in self.topics_ if t.topic_id != -1 for kw in t.keywords]
        diversity = compute_diversity(all_keywords, n_topics=len(coherences))

        # Get stability from consensus clustering
        stability = self._clusterer.stability_score_ if hasattr(self._clusterer, 'stability_score_') else None

        metrics = {
            "coherence_mean": float(np.mean(coherences)) if coherences else 0.0,
            "coherence_std": float(np.std(coherences)) if coherences else 0.0,
            "diversity": diversity,
            "stability": stability,
            "n_topics": len([t for t in self.topics_ if t.topic_id != -1]),
            "outlier_ratio": float(np.mean(self.labels_ == -1)) if self.labels_ is not None else 0.0,
        }

        if self.config.verbose:
            print("\n📊 Evaluation Metrics:")
            print(f"   Coherence (mean): {metrics['coherence_mean']:.4f}")
            print(f"   Diversity: {metrics['diversity']:.4f}")
            if stability:
                print(f"   Stability: {stability:.4f}")
            print(f"   Outlier ratio: {metrics['outlier_ratio']:.2%}")

        return metrics

    def save(self, path: str) -> None:
        """Save model to disk."""
        import pickle

        state = {
            "config": self.config,
            "topics_": self.topics_,
            "labels_": self.labels_,
            "embeddings_": self.embeddings_,
            "topic_embeddings_": self.topic_embeddings_,
            "documents_": self.documents_,
            "_is_fitted": self._is_fitted,
            "_iteration_history": self._iteration_history,
        }

        with open(path, "wb") as f:
            pickle.dump(state, f)

        if self.config.verbose:
            print(f"💾 Model saved to {path}")

    @classmethod
    def load(cls, path: str) -> "TriTopic":
        """Load model from disk."""
        import pickle

        with open(path, "rb") as f:
            state = pickle.load(f)

        model = cls(config=state["config"])
        model.topics_ = state["topics_"]
        model.labels_ = state["labels_"]
        model.embeddings_ = state["embeddings_"]
        model.topic_embeddings_ = state["topic_embeddings_"]
        model.documents_ = state["documents_"]
        model._is_fitted = state["_is_fitted"]
        model._iteration_history = state["_iteration_history"]

        return model

    def __repr__(self) -> str:
        status = "fitted" if self._is_fitted else "not fitted"
        n_topics = len([t for t in self.topics_ if t.topic_id != -1]) if self._is_fitted else "?"
        return f"TriTopic(n_topics={n_topics}, status={status})"