tritopic-0.1.0-py3-none-any.whl → tritopic-1.1.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of tritopic might be problematic (see the registry page for details).
- tritopic/__init__.py +22 -32
- tritopic/config.py +289 -0
- tritopic/core/__init__.py +0 -17
- tritopic/core/clustering.py +229 -243
- tritopic/core/embeddings.py +151 -157
- tritopic/core/graph.py +435 -0
- tritopic/core/keywords.py +213 -249
- tritopic/core/refinement.py +231 -0
- tritopic/core/representatives.py +560 -0
- tritopic/labeling.py +313 -0
- tritopic/model.py +718 -0
- tritopic/multilingual/__init__.py +38 -0
- tritopic/multilingual/detection.py +208 -0
- tritopic/multilingual/stopwords.py +467 -0
- tritopic/multilingual/tokenizers.py +275 -0
- tritopic/visualization.py +371 -0
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/METADATA +91 -51
- tritopic-1.1.0.dist-info/RECORD +20 -0
- tritopic/core/graph_builder.py +0 -493
- tritopic/core/model.py +0 -810
- tritopic/labeling/__init__.py +0 -5
- tritopic/labeling/llm_labeler.py +0 -279
- tritopic/utils/__init__.py +0 -13
- tritopic/utils/metrics.py +0 -254
- tritopic/visualization/__init__.py +0 -5
- tritopic/visualization/plotter.py +0 -523
- tritopic-0.1.0.dist-info/RECORD +0 -18
- tritopic-0.1.0.dist-info/licenses/LICENSE +0 -21
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/WHEEL +0 -0
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/top_level.txt +0 -0
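The moves above flatten the package layout: the nested modules tritopic/core/model.py, tritopic/labeling/llm_labeler.py, and tritopic/visualization/plotter.py are replaced by top-level tritopic/model.py, tritopic/labeling.py, and tritopic/visualization.py. A minimal migration sketch under the assumption that the public names follow the new module paths; the old import paths are inferred from the removed files and should be confirmed against tritopic/__init__.py:

    # Hypothetical migration sketch; confirm the actual re-exports
    # against tritopic/__init__.py in 1.1.0.
    #
    # 0.1.0 (nested layout, removed in this release):
    #   from tritopic.core.model import TriTopic
    #   from tritopic.labeling.llm_labeler import LLMLabeler
    #
    # 1.1.0 (flat layout, added in this release):
    from tritopic.model import TriTopic
    from tritopic.labeling import LLMLabeler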
tritopic/model.py
ADDED
@@ -0,0 +1,718 @@
"""
TriTopic: Tri-Modal Graph Topic Modeling with Iterative Refinement

Main model class that orchestrates all components.
"""

from __future__ import annotations

import pickle
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from scipy import sparse
from tqdm import tqdm

from .config import TriTopicConfig, get_config
from .core.embeddings import EmbeddingEngine
from .core.graph import GraphBuilder, MultiViewGraphBuilder
from .core.clustering import ConsensusLeiden
from .core.refinement import IterativeRefinement
from .core.keywords import KeywordExtractor
from .core.representatives import RepresentativeSelector


@dataclass
class Topic:
    """Represents a single topic with its metadata."""
    topic_id: int
    size: int
    keywords: List[str]
    keyword_scores: List[float]
    representative_docs: List[int]
    representative_texts: List[str]
    centroid: Optional[np.ndarray] = None
    label: Optional[str] = None

    def __repr__(self) -> str:
        kw_str = ", ".join(self.keywords[:5])
        label_str = f" ({self.label})" if self.label else ""
        return f"Topic {self.topic_id}{label_str}: [{kw_str}...] (n={self.size})"


class TriTopic:
    """
    Tri-Modal Graph Topic Modeling with Iterative Refinement.

    A state-of-the-art topic modeling approach that combines:
    - Multi-view representation (semantic, lexical, metadata)
    - Hybrid graph construction (Mutual kNN + SNN)
    - Consensus Leiden clustering for stability
    - Iterative refinement for improved coherence
    - LLM-powered labeling (optional)

    Parameters
    ----------
    config : TriTopicConfig, str, or None
        Configuration object or name of preset config.
        If None, uses default config.
    **kwargs
        Override any config parameter directly.

    Examples
    --------
    >>> from tritopic import TriTopic
    >>> model = TriTopic(verbose=True)
    >>> topics = model.fit_transform(documents)
    >>> print(model.get_topic_info())

    >>> # With custom config
    >>> model = TriTopic(
    ...     embedding_model="BAAI/bge-base-en-v1.5",
    ...     n_neighbors=20,
    ...     use_iterative_refinement=True
    ... )
    """

    def __init__(
        self,
        config: Optional[Union[TriTopicConfig, str]] = None,
        **kwargs
    ):
        # Load config
        if config is None:
            self.config = TriTopicConfig()
        elif isinstance(config, str):
            self.config = get_config(config)
        else:
            self.config = config

        # Override with kwargs
        for key, value in kwargs.items():
            if hasattr(self.config, key):
                setattr(self.config, key, value)
            else:
                warnings.warn(f"Unknown config parameter: {key}")

        # Initialize components (lazy loading)
        self._embedding_engine: Optional[EmbeddingEngine] = None
        self._graph_builder: Optional[MultiViewGraphBuilder] = None
        self._clusterer: Optional[ConsensusLeiden] = None
        self._refiner: Optional[IterativeRefinement] = None
        self._keyword_extractor: Optional[KeywordExtractor] = None
        self._representative_selector: Optional[RepresentativeSelector] = None

        # Fitted attributes
        self.documents_: Optional[List[str]] = None
        self.embeddings_: Optional[np.ndarray] = None
        self.graph_: Optional[sparse.csr_matrix] = None
        self.topic_labels_: Optional[np.ndarray] = None
        self.topics_: Optional[List[Topic]] = None
        self.n_topics_: int = 0
        self.outlier_count_: int = 0

        # Language detection results
        self.detected_language_: Optional[str] = None
        self.is_multilingual_: bool = False

    def _log(self, message: str, level: int = 1) -> None:
        """Print message if verbose mode is enabled."""
        if self.config.verbose:
            indent = " " * (level - 1)
            print(f"{indent}{message}")

    def _initialize_components(self, documents: List[str]) -> None:
        """Initialize all components based on config and detected language."""
        # Detect language if auto
        if self.config.language == "auto":
            self._detect_language(documents)
        else:
            self.detected_language_ = self.config.language

        # Get appropriate embedding model
        embedding_model = self.config.embedding_model
        if embedding_model == "auto":
            embedding_model = self.config.get_embedding_model_for_language(
                self.detected_language_ or "en"
            )
            self._log(f"Auto-selected embedding model: {embedding_model}", 2)

        # Initialize embedding engine
        self._embedding_engine = EmbeddingEngine(
            model_name=embedding_model,
            batch_size=self.config.embedding_batch_size,
            device=self.config.device
        )

        # Initialize graph builder
        self._graph_builder = MultiViewGraphBuilder(
            n_neighbors=self.config.n_neighbors,
            metric=self.config.metric,
            graph_type=self.config.graph_type,
            snn_weight=self.config.snn_weight,
            semantic_weight=self.config.semantic_weight,
            lexical_weight=self.config.lexical_weight,
            metadata_weight=self.config.metadata_weight,
            use_lexical=self.config.use_lexical_view,
            use_metadata=self.config.use_metadata_view,
            lexical_method=self.config.lexical_method,
            ngram_range=self.config.ngram_range
        )

        # Initialize clusterer
        self._clusterer = ConsensusLeiden(
            resolution=self.config.resolution,
            n_runs=self.config.n_consensus_runs,
            min_cluster_size=self.config.min_cluster_size,
            random_state=self.config.random_state
        )

        # Initialize refiner
        if self.config.use_iterative_refinement:
            self._refiner = IterativeRefinement(
                max_iterations=self.config.max_iterations,
                convergence_threshold=self.config.convergence_threshold,
                refinement_strength=self.config.refinement_strength
            )

        # Initialize keyword extractor
        self._keyword_extractor = KeywordExtractor(
            method=self.config.keyword_method,
            n_keywords=self.config.n_keywords,
            language=self.detected_language_ or "en",
            ngram_range=self.config.ngram_range
        )

        # Initialize representative selector
        self._representative_selector = RepresentativeSelector(
            method=self.config.representative_method,
            n_representatives=self.config.n_representative_docs,
            n_archetypes=self.config.n_archetypes,
            archetype_method=self.config.archetype_method
        )

    def _detect_language(self, documents: List[str]) -> None:
        """Detect the dominant language of the corpus."""
        try:
            from .multilingual.detection import detect_corpus_language

            result = detect_corpus_language(
                documents,
                sample_size=self.config.language_detection_sample
            )
            self.detected_language_ = result["dominant_language"]
            self.is_multilingual_ = result["is_multilingual"]

            self._log(f"Detected language: {self.detected_language_} "
                      f"(confidence: {result['confidence']:.2f})", 2)
            if self.is_multilingual_:
                self._log("Corpus appears multilingual", 2)

        except ImportError:
            self._log("Language detection not available, defaulting to English", 2)
            self.detected_language_ = "en"

    def fit(
        self,
        documents: List[str],
        embeddings: Optional[np.ndarray] = None,
        metadata: Optional[pd.DataFrame] = None
    ) -> "TriTopic":
        """
        Fit the topic model on documents.

        Parameters
        ----------
        documents : List[str]
            List of documents to model.
        embeddings : np.ndarray, optional
            Pre-computed embeddings. If None, embeddings are generated.
        metadata : pd.DataFrame, optional
            Document metadata for multi-view fusion.

        Returns
        -------
        self
            Fitted model.
        """
        n_docs = len(documents)
        self.documents_ = documents

        self._log(f"🚀 TriTopic: Fitting model on {n_docs} documents")
        self._log(f"Config: {self.config.graph_type} graph, "
                  f"{'iterative' if self.config.use_iterative_refinement else 'single-pass'} mode", 1)

        # Initialize components
        self._initialize_components(documents)

        # Step 1: Generate embeddings
        if embeddings is not None:
            self._log("→ Using provided embeddings", 1)
            self.embeddings_ = embeddings
        else:
            model_name = self._embedding_engine.model_name
            self._log(f"→ Generating embeddings ({model_name})...", 1)
            self.embeddings_ = self._embedding_engine.encode(
                documents,
                show_progress=self.config.verbose
            )

        # Step 2: Build lexical similarity (if enabled)
        if self.config.use_lexical_view:
            self._log("→ Building lexical similarity matrix...", 1)

        # Step 3: Build graph
        self._log("→ Constructing multi-view graph...", 1)
        self.graph_ = self._graph_builder.build(
            embeddings=self.embeddings_,
            documents=documents,
            metadata=metadata
        )

        # Step 4: Clustering (with optional refinement)
        if self.config.use_iterative_refinement:
            self._log(f"→ Starting iterative refinement (max {self.config.max_iterations} iterations)...", 1)

            def graph_builder_fn(emb):
                return self._graph_builder.build(
                    embeddings=emb,
                    documents=documents,
                    metadata=metadata
                )

            def cluster_fn(g):
                return self._clusterer.fit_predict(g)

            self.topic_labels_, self.embeddings_, iterations = self._refiner.refine(
                embeddings=self.embeddings_,
                initial_labels=self._clusterer.fit_predict(self.graph_),
                graph_builder_fn=graph_builder_fn,
                cluster_fn=cluster_fn,
                verbose=self.config.verbose
            )

            # Rebuild final graph with refined embeddings
            self.graph_ = graph_builder_fn(self.embeddings_)

        else:
            self._log("→ Consensus clustering...", 1)
            self.topic_labels_ = self._clusterer.fit_predict(self.graph_)

        # Step 5: Extract keywords and representatives
        self._log("→ Extracting keywords and representative documents...", 1)
        self._extract_topic_info(documents)

        # Summary
        self.n_topics_ = len([t for t in self.topics_ if t.topic_id >= 0])
        self.outlier_count_ = sum(1 for lbl in self.topic_labels_ if lbl < 0)
        outlier_pct = 100 * self.outlier_count_ / n_docs

        self._log("")
        self._log("✅ Fitting complete!")
        self._log(f" Found {self.n_topics_} topics")
        self._log(f" {self.outlier_count_} outlier documents ({outlier_pct:.1f}%)")

        return self

    def fit_transform(
        self,
        documents: List[str],
        embeddings: Optional[np.ndarray] = None,
        metadata: Optional[pd.DataFrame] = None
    ) -> np.ndarray:
        """
        Fit the model and return topic assignments.

        Parameters
        ----------
        documents : List[str]
            List of documents to model.
        embeddings : np.ndarray, optional
            Pre-computed embeddings.
        metadata : pd.DataFrame, optional
            Document metadata.

        Returns
        -------
        np.ndarray
            Topic assignments for each document (-1 for outliers).
        """
        self.fit(documents, embeddings, metadata)
        return self.topic_labels_

    def transform(self, documents: List[str]) -> np.ndarray:
        """
        Assign topics to new documents.

        Parameters
        ----------
        documents : List[str]
            New documents to classify.

        Returns
        -------
        np.ndarray
            Topic assignments for each document.
        """
        if self.topics_ is None:
            raise ValueError("Model not fitted. Call fit() first.")

        # Encode new documents
        new_embeddings = self._embedding_engine.encode(documents)

        # Find nearest topic centroid for each document.
        # Build centroids and ids from the same filtered list so the
        # distance-matrix columns stay aligned with the topic ids.
        valid_topics = [
            t for t in self.topics_
            if t.topic_id >= 0 and t.centroid is not None
        ]
        topic_centroids = np.array([t.centroid for t in valid_topics])
        topic_ids = [t.topic_id for t in valid_topics]

        # Compute distances to centroids
        from sklearn.metrics.pairwise import cosine_distances
        distances = cosine_distances(new_embeddings, topic_centroids)

        # Assign to nearest topic (with outlier threshold)
        assignments = []
        for doc_distances in distances:
            min_idx = np.argmin(doc_distances)
            min_dist = doc_distances[min_idx]

            if min_dist > self.config.outlier_threshold:
                assignments.append(-1)
            else:
                assignments.append(topic_ids[min_idx])

        return np.array(assignments)

    def _extract_topic_info(self, documents: List[str]) -> None:
        """Extract keywords and representatives for each topic."""
        unique_labels = sorted(set(self.topic_labels_))
        self.topics_ = []

        for topic_id in unique_labels:
            # Get documents in this topic
            mask = self.topic_labels_ == topic_id
            topic_docs = [documents[i] for i, m in enumerate(mask) if m]
            topic_embeddings = self.embeddings_[mask]
            topic_indices = np.where(mask)[0]

            # Extract keywords
            if topic_id >= 0:
                keywords, scores = self._keyword_extractor.extract(
                    topic_docs,
                    all_documents=documents
                )
            else:
                keywords, scores = ["[outlier]"], [0.0]

            # Get representative documents
            if topic_id >= 0 and len(topic_docs) > 0:
                reps = self._representative_selector.select(
                    embeddings=topic_embeddings,
                    documents=topic_docs,
                    keywords=keywords[:5],
                    global_indices=topic_indices
                )
                rep_indices = [r.doc_id for r in reps.representatives]
                rep_texts = [r.text[:200] for r in reps.representatives]
            else:
                rep_indices = list(topic_indices[:self.config.n_representative_docs])
                rep_texts = topic_docs[:self.config.n_representative_docs]

            # Compute centroid
            centroid = topic_embeddings.mean(axis=0) if len(topic_embeddings) > 0 else None

            topic = Topic(
                topic_id=topic_id,
                size=len(topic_docs),
                keywords=keywords,
                keyword_scores=scores,
                representative_docs=rep_indices,
                representative_texts=rep_texts,
                centroid=centroid
            )
            self.topics_.append(topic)

    def get_topic_info(self) -> pd.DataFrame:
        """
        Get summary information about all topics.

        Returns
        -------
        pd.DataFrame
            DataFrame with topic information.
        """
        if self.topics_ is None:
            raise ValueError("Model not fitted. Call fit() first.")

        data = []
        for topic in self.topics_:
            data.append({
                "Topic": topic.topic_id,
                "Size": topic.size,
                "Label": topic.label or "",
                "Keywords": ", ".join(topic.keywords[:5]),
                "Representative": topic.representative_texts[0][:100] + "..."
                if topic.representative_texts else ""
            })

        return pd.DataFrame(data)

    def get_topic(self, topic_id: int) -> Optional[Topic]:
        """
        Get detailed information about a specific topic.

        Parameters
        ----------
        topic_id : int
            Topic ID to retrieve.

        Returns
        -------
        Topic or None
            Topic object if found.
        """
        if self.topics_ is None:
            return None

        for topic in self.topics_:
            if topic.topic_id == topic_id:
                return topic
        return None

    def get_document_topics(
        self,
        doc_indices: Optional[List[int]] = None
    ) -> pd.DataFrame:
        """
        Get topic assignments for documents.

        Parameters
        ----------
        doc_indices : List[int], optional
            Specific document indices. If None, returns all.

        Returns
        -------
        pd.DataFrame
            DataFrame with document-topic assignments.
        """
        if self.topic_labels_ is None:
            raise ValueError("Model not fitted.")

        if doc_indices is None:
            doc_indices = list(range(len(self.topic_labels_)))

        data = []
        for idx in doc_indices:
            topic_id = self.topic_labels_[idx]
            topic = self.get_topic(topic_id)
            data.append({
                "Document": idx,
                "Topic": topic_id,
                "Topic_Label": topic.label if topic else "",
                "Text_Preview": self.documents_[idx][:100] + "..."
            })

        return pd.DataFrame(data)

    def generate_labels(self, labeler: Any) -> None:
        """
        Generate human-readable labels for topics using an LLM.

        Parameters
        ----------
        labeler : LLMLabeler
            Labeler instance configured with LLM provider.
        """
        if self.topics_ is None:
            raise ValueError("Model not fitted.")

        for topic in self.topics_:
            if topic.topic_id < 0:
                topic.label = "Outliers"
                continue

            label = labeler.generate_label(
                keywords=topic.keywords,
                representative_docs=topic.representative_texts
            )
            topic.label = label

    def evaluate(self) -> Dict[str, float]:
        """
        Evaluate the topic model quality.

        Returns
        -------
        Dict[str, float]
            Dictionary with evaluation metrics.
        """
        if self.topics_ is None:
            raise ValueError("Model not fitted.")

        metrics = {}

        # Number of topics
        metrics["n_topics"] = self.n_topics_

        # Outlier ratio
        metrics["outlier_ratio"] = self.outlier_count_ / len(self.topic_labels_)

        # Topic diversity
        all_keywords = []
        for topic in self.topics_:
            if topic.topic_id >= 0:
                all_keywords.extend(topic.keywords[:10])

        unique_keywords = len(set(all_keywords))
        total_keywords = len(all_keywords)
        metrics["diversity"] = unique_keywords / total_keywords if total_keywords > 0 else 0

        # Coherence (simplified NPMI)
        try:
            from .core.keywords import compute_coherence
            coherence_scores = []
            for topic in self.topics_:
                if topic.topic_id >= 0:
                    score = compute_coherence(topic.keywords[:10], self.documents_)
                    coherence_scores.append(score)

            metrics["coherence_mean"] = np.mean(coherence_scores) if coherence_scores else 0
            metrics["coherence_std"] = np.std(coherence_scores) if coherence_scores else 0
        except Exception:
            metrics["coherence_mean"] = 0
            metrics["coherence_std"] = 0

        # Clustering stability
        if hasattr(self._clusterer, "stability_score_"):
            metrics["stability"] = self._clusterer.stability_score_

        return metrics

    def visualize(
        self,
        method: str = "umap",
        **kwargs
    ) -> Any:
        """
        Visualize the topic model.

        Parameters
        ----------
        method : str
            Visualization method: "umap", "tsne", or "pca".
        **kwargs
            Additional arguments for visualization.

        Returns
        -------
        plotly.graph_objects.Figure
            Interactive visualization.
        """
        try:
            from .visualization import create_topic_visualization
            return create_topic_visualization(
                embeddings=self.embeddings_,
                labels=self.topic_labels_,
                topics=self.topics_,
                documents=self.documents_,
                method=method,
                **kwargs
            )
        except ImportError:
            raise ImportError(
                "Visualization requires plotly. Install with: pip install tritopic[visualization]"
            )

    def visualize_topics(self, **kwargs) -> Any:
        """Visualize topic keywords as bar charts."""
        try:
            from .visualization import create_topic_barchart
            return create_topic_barchart(self.topics_, **kwargs)
        except ImportError:
            raise ImportError("Visualization requires plotly.")

    def visualize_hierarchy(self, **kwargs) -> Any:
        """Visualize topic hierarchy as dendrogram."""
        try:
            from .visualization import create_topic_hierarchy
            return create_topic_hierarchy(
                embeddings=self.embeddings_,
                labels=self.topic_labels_,
                topics=self.topics_,
                **kwargs
            )
        except ImportError:
            raise ImportError("Visualization requires plotly.")

    def save(self, path: Union[str, Path]) -> None:
        """
        Save the model to disk.

        Parameters
        ----------
        path : str or Path
            File path for saving.
        """
        path = Path(path)

        # Prepare state dict (excluding non-picklable objects)
        state = {
            "config": self.config,
            "documents_": self.documents_,
            "embeddings_": self.embeddings_,
            "topic_labels_": self.topic_labels_,
            "topics_": self.topics_,
            "n_topics_": self.n_topics_,
            "outlier_count_": self.outlier_count_,
            "detected_language_": self.detected_language_,
            "is_multilingual_": self.is_multilingual_,
        }

        with open(path, "wb") as f:
            pickle.dump(state, f)

    @classmethod
    def load(cls, path: Union[str, Path]) -> "TriTopic":
        """
        Load a model from disk.

        Parameters
        ----------
        path : str or Path
            File path to load from.

        Returns
        -------
        TriTopic
            Loaded model.
        """
        path = Path(path)

        with open(path, "rb") as f:
            state = pickle.load(f)

        model = cls(config=state["config"])
        model.documents_ = state["documents_"]
        model.embeddings_ = state["embeddings_"]
        model.topic_labels_ = state["topic_labels_"]
        model.topics_ = state["topics_"]
        model.n_topics_ = state["n_topics_"]
        model.outlier_count_ = state["outlier_count_"]
        model.detected_language_ = state.get("detected_language_")
        model.is_multilingual_ = state.get("is_multilingual_", False)

        # Re-initialize components for transform()
        model._initialize_components(model.documents_)

        return model

    def __repr__(self) -> str:
        if self.topics_ is None:
            return "TriTopic(not fitted)"
        return f"TriTopic(n_topics={self.n_topics_}, n_docs={len(self.documents_)})"
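The new model.py above is the orchestration layer for the whole 1.1.0 pipeline (embeddings → multi-view graph → consensus Leiden → optional iterative refinement → keywords and representatives). A minimal end-to-end sketch assembled from the docstrings above; the corpus is placeholder data and the keyword arguments are illustrative, not library defaults:

    # End-to-end sketch of the 1.1.0 API shown above; placeholder corpus,
    # illustrative settings (see the class docstring for documented examples).
    from tritopic import TriTopic

    documents = [
        "graph-based clustering of short texts",
        "leiden community detection on knn graphs",
        "keyword extraction for topic models",
    ]

    model = TriTopic(n_neighbors=20, use_iterative_refinement=True, verbose=True)
    labels = model.fit_transform(documents)   # np.ndarray; -1 marks outliers
    print(model.get_topic_info())             # one row per topic: size, keywords, ...
    print(model.evaluate())                   # n_topics, outlier_ratio, diversity, coherence_*

    model.save("tritopic.pkl")                # pickles config + fitted state
    restored = TriTopic.load("tritopic.pkl")  # re-initializes components for transform()
    print(restored.transform(["a new document about topic graphs"]))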