tritopic 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tritopic/core/model.py DELETED
@@ -1,810 +0,0 @@
- """
- TriTopic: Main Model Class
- ===========================
-
- The core class that orchestrates all components of the topic modeling pipeline.
- """
-
- from __future__ import annotations
-
- import warnings
- from dataclasses import dataclass, field
- from typing import Any, Callable, Literal
-
- import numpy as np
- import pandas as pd
- from tqdm import tqdm
-
- from tritopic.core.embeddings import EmbeddingEngine
- from tritopic.core.graph_builder import GraphBuilder
- from tritopic.core.clustering import ConsensusLeiden
- from tritopic.core.keywords import KeywordExtractor
- from tritopic.utils.metrics import compute_coherence, compute_diversity, compute_stability
-
-
- @dataclass
- class TopicInfo:
-     """Container for topic information."""
-
-     topic_id: int
-     size: int
-     keywords: list[str]
-     keyword_scores: list[float]
-     representative_docs: list[int]
-     label: str | None = None
-     description: str | None = None
-     centroid: np.ndarray | None = None
-     coherence: float | None = None
-
-
- @dataclass
- class TriTopicConfig:
-     """Configuration for TriTopic model."""
-
-     # Embedding settings
-     embedding_model: str = "all-MiniLM-L6-v2"
-     embedding_batch_size: int = 32
-
-     # Graph settings
-     n_neighbors: int = 15
-     metric: str = "cosine"
-     graph_type: Literal["mutual_knn", "snn", "hybrid"] = "hybrid"
-     snn_weight: float = 0.5
-
-     # Multi-view settings
-     use_lexical_view: bool = True
-     use_metadata_view: bool = False
-     lexical_weight: float = 0.3
-     metadata_weight: float = 0.2
-     semantic_weight: float = 0.5
-
-     # Clustering settings
-     resolution: float = 1.0
-     resolution_range: tuple[float, float] | None = None
-     n_consensus_runs: int = 10
-     min_cluster_size: int = 5
-
-     # Iterative refinement
-     use_iterative_refinement: bool = True
-     max_iterations: int = 5
-     convergence_threshold: float = 0.95
-
-     # Keyword extraction
-     n_keywords: int = 10
-     n_representative_docs: int = 5
-     keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"
-
-     # Outlier handling
-     outlier_threshold: float = 0.1
-
-     # Misc
-     random_state: int = 42
-     verbose: bool = True
-
-
- class TriTopic:
-     """
-     Tri-Modal Graph Topic Modeling with Iterative Refinement.
-
-     A state-of-the-art topic modeling approach that combines semantic embeddings,
-     lexical similarity, and optional metadata to create robust, interpretable topics.
-
-     Key innovations:
-     - Multi-view graph fusion (semantic + lexical + metadata)
-     - Leiden clustering with consensus for stability
-     - Iterative refinement loop for optimal topic separation
-     - Advanced keyword extraction with representative documents
-     - Optional LLM-powered topic labeling
-
-     Parameters
-     ----------
-     config : TriTopicConfig, optional
-         Configuration object. If None, uses defaults.
-     embedding_model : str, optional
-         Name of sentence-transformers model. Default: "all-MiniLM-L6-v2"
-     n_neighbors : int, optional
-         Number of neighbors for graph construction. Default: 15
-     n_topics : int or "auto", optional
-         Number of topics. "auto" uses Leiden's natural resolution. Default: "auto"
-     use_iterative_refinement : bool, optional
-         Whether to use the iterative refinement loop. Default: True
-     verbose : bool, optional
-         Print progress information. Default: True
-
-     Attributes
-     ----------
-     topics_ : list[TopicInfo]
-         Information about each discovered topic.
-     labels_ : np.ndarray
-         Topic assignment for each document.
-     embeddings_ : np.ndarray
-         Document embeddings.
-     graph_ : igraph.Graph
-         The constructed similarity graph.
-     topic_embeddings_ : np.ndarray
-         Centroid embeddings for each topic.
-
-     Examples
-     --------
-     Basic usage:
-
-     >>> from tritopic import TriTopic
-     >>> model = TriTopic(n_neighbors=15, verbose=True)
-     >>> topics = model.fit_transform(documents)
-     >>> print(model.get_topic_info())
-
-     With metadata:
-
-     >>> model = TriTopic()
-     >>> model.config.use_metadata_view = True
-     >>> topics = model.fit_transform(documents, metadata=df[['source', 'date']])
-
-     With LLM labeling:
-
-     >>> from tritopic import TriTopic, LLMLabeler
-     >>> model = TriTopic()
-     >>> model.fit_transform(documents)
-     >>> labeler = LLMLabeler(provider="anthropic", api_key="...")
-     >>> model.generate_labels(labeler)
-     """
-
-     def __init__(
-         self,
-         config: TriTopicConfig | None = None,
-         embedding_model: str | None = None,
-         n_neighbors: int | None = None,
-         n_topics: int | Literal["auto"] = "auto",
-         use_iterative_refinement: bool | None = None,
-         verbose: bool | None = None,
-         random_state: int | None = None,
-     ):
-         # Initialize config
-         self.config = config or TriTopicConfig()
-
-         # Override config with explicit parameters
-         if embedding_model is not None:
-             self.config.embedding_model = embedding_model
-         if n_neighbors is not None:
-             self.config.n_neighbors = n_neighbors
-         if use_iterative_refinement is not None:
-             self.config.use_iterative_refinement = use_iterative_refinement
-         if verbose is not None:
-             self.config.verbose = verbose
-         if random_state is not None:
-             self.config.random_state = random_state
-
-         self.n_topics = n_topics
-
-         # Initialize components
-         self._embedding_engine = EmbeddingEngine(
-             model_name=self.config.embedding_model,
-             batch_size=self.config.embedding_batch_size,
-         )
-         self._graph_builder = GraphBuilder(
-             n_neighbors=self.config.n_neighbors,
-             metric=self.config.metric,
-             graph_type=self.config.graph_type,
-             snn_weight=self.config.snn_weight,
-         )
-         self._clusterer = ConsensusLeiden(
-             resolution=self.config.resolution,
-             n_runs=self.config.n_consensus_runs,
-             random_state=self.config.random_state,
-         )
-         self._keyword_extractor = KeywordExtractor(
-             method=self.config.keyword_method,
-             n_keywords=self.config.n_keywords,
-         )
-
-         # State
-         self.topics_: list[TopicInfo] = []
-         self.labels_: np.ndarray | None = None
-         self.embeddings_: np.ndarray | None = None
-         self.lexical_matrix_: Any | None = None
-         self.graph_: Any | None = None
-         self.topic_embeddings_: np.ndarray | None = None
-         self.documents_: list[str] | None = None
-         self._is_fitted: bool = False
-         self._iteration_history: list[dict] = []
-
-     def fit(
-         self,
-         documents: list[str],
-         embeddings: np.ndarray | None = None,
-         metadata: pd.DataFrame | None = None,
-     ) -> "TriTopic":
-         """
-         Fit the topic model to documents.
-
-         Parameters
-         ----------
-         documents : list[str]
-             List of document texts.
-         embeddings : np.ndarray, optional
-             Pre-computed embeddings. If None, computed automatically.
-         metadata : pd.DataFrame, optional
-             Document metadata for the metadata view.
-
-         Returns
-         -------
-         self : TriTopic
-             Fitted model.
-         """
-         self.documents_ = documents
-         n_docs = len(documents)
-
-         if self.config.verbose:
-             print(f"🚀 TriTopic: Fitting model on {n_docs} documents")
-             print(f" Config: {self.config.graph_type} graph, "
-                   f"{'iterative' if self.config.use_iterative_refinement else 'single-pass'} mode")
-
-         # Step 1: Generate embeddings
-         if embeddings is not None:
-             self.embeddings_ = embeddings
-             if self.config.verbose:
-                 print(" ✓ Using provided embeddings")
-         else:
-             if self.config.verbose:
-                 print(f" → Generating embeddings ({self.config.embedding_model})...")
-             self.embeddings_ = self._embedding_engine.encode(documents)
-
-         # Step 2: Build lexical representation
-         if self.config.use_lexical_view:
-             if self.config.verbose:
-                 print(" → Building lexical similarity matrix...")
-             self.lexical_matrix_ = self._graph_builder.build_lexical_matrix(documents)
-
-         # Step 3: Build metadata graph (if provided)
-         metadata_graph = None
-         if self.config.use_metadata_view and metadata is not None:
-             if self.config.verbose:
-                 print(" → Building metadata similarity graph...")
-             metadata_graph = self._graph_builder.build_metadata_graph(metadata)
-
-         # Step 4: Main fitting loop
-         if self.config.use_iterative_refinement:
-             self._fit_iterative(documents, metadata_graph)
-         else:
-             self._fit_single_pass(documents, metadata_graph)
-
-         # Step 5: Extract keywords and representative docs
-         if self.config.verbose:
-             print(" → Extracting keywords and representative documents...")
-         self._extract_topic_info(documents)
-
-         # Step 6: Compute topic centroids
-         self._compute_topic_centroids()
-
-         self._is_fitted = True
-
-         if self.config.verbose:
-             n_topics = len([t for t in self.topics_ if t.topic_id != -1])
-             n_outliers = np.sum(self.labels_ == -1) if self.labels_ is not None else 0
-             print(f"\n✅ Fitting complete!")
-             print(f" Found {n_topics} topics")
-             print(f" {n_outliers} outlier documents ({100*n_outliers/n_docs:.1f}%)")
-
-         return self
-
-     def _fit_single_pass(
-         self,
-         documents: list[str],
-         metadata_graph: Any | None = None,
-     ) -> None:
-         """Single-pass fitting without iterative refinement."""
-         # Build graph
-         if self.config.verbose:
-             print(" → Building multi-view graph...")
-
-         self.graph_ = self._graph_builder.build_multiview_graph(
-             semantic_embeddings=self.embeddings_,
-             lexical_matrix=self.lexical_matrix_ if self.config.use_lexical_view else None,
-             metadata_graph=metadata_graph,
-             weights={
-                 "semantic": self.config.semantic_weight,
-                 "lexical": self.config.lexical_weight,
-                 "metadata": self.config.metadata_weight,
-             }
-         )
-
-         # Cluster
-         if self.config.verbose:
-             print(f" → Running Leiden consensus clustering ({self.config.n_consensus_runs} runs)...")
-
-         self.labels_ = self._clusterer.fit_predict(
-             self.graph_,
-             min_cluster_size=self.config.min_cluster_size,
-         )
-
-     def _fit_iterative(
-         self,
-         documents: list[str],
-         metadata_graph: Any | None = None,
-     ) -> None:
-         """Iterative refinement fitting loop."""
-         if self.config.verbose:
-             print(f" → Starting iterative refinement (max {self.config.max_iterations} iterations)...")
-
-         current_embeddings = self.embeddings_.copy()
-         previous_labels = None
-
-         for iteration in range(self.config.max_iterations):
-             if self.config.verbose:
-                 print(f" Iteration {iteration + 1}...")
-
-             # Build graph with current embeddings
-             self.graph_ = self._graph_builder.build_multiview_graph(
-                 semantic_embeddings=current_embeddings,
-                 lexical_matrix=self.lexical_matrix_ if self.config.use_lexical_view else None,
-                 metadata_graph=metadata_graph,
-                 weights={
-                     "semantic": self.config.semantic_weight,
-                     "lexical": self.config.lexical_weight,
-                     "metadata": self.config.metadata_weight,
-                 }
-             )
-
-             # Cluster
-             self.labels_ = self._clusterer.fit_predict(
-                 self.graph_,
-                 min_cluster_size=self.config.min_cluster_size,
-             )
-
-             # Check convergence
-             if previous_labels is not None:
-                 from sklearn.metrics import adjusted_rand_score
-                 ari = adjusted_rand_score(previous_labels, self.labels_)
-                 self._iteration_history.append({
-                     "iteration": iteration + 1,
-                     "ari": ari,
-                     "n_topics": len(np.unique(self.labels_[self.labels_ != -1])),
-                 })
-
-                 if self.config.verbose:
-                     print(f" ARI vs previous: {ari:.4f}")
-
-                 if ari >= self.config.convergence_threshold:
-                     if self.config.verbose:
-                         print(f" ✓ Converged at iteration {iteration + 1}")
-                     break
-
-             previous_labels = self.labels_.copy()
-
-             # Refine embeddings based on topic structure
-             current_embeddings = self._refine_embeddings(
-                 documents, self.embeddings_, self.labels_
-             )
-
-         # Store final refined embeddings
-         self.embeddings_ = current_embeddings
-
-     def _refine_embeddings(
-         self,
-         documents: list[str],
-         original_embeddings: np.ndarray,
-         labels: np.ndarray,
-     ) -> np.ndarray:
-         """
-         Refine embeddings by incorporating topic context.
-
-         This is the key innovation: we modify embeddings to be more
-         topic-aware by pulling documents toward their topic centroid.
-         """
-         refined = original_embeddings.copy()
-         unique_labels = np.unique(labels[labels != -1])
-
-         # Compute topic centroids
-         centroids = {}
-         for label in unique_labels:
-             mask = labels == label
-             centroids[label] = original_embeddings[mask].mean(axis=0)
-
-         # Soft refinement: blend original embedding with topic centroid
-         blend_factor = 0.2  # How much to pull toward centroid
-
-         for i, label in enumerate(labels):
-             if label != -1:  # Skip outliers
-                 centroid = centroids[label]
-                 refined[i] = (1 - blend_factor) * refined[i] + blend_factor * centroid
-                 # Re-normalize
-                 refined[i] = refined[i] / np.linalg.norm(refined[i])
-
-         return refined
-
-     def _extract_topic_info(self, documents: list[str]) -> None:
-         """Extract keywords and representative documents for each topic."""
-         self.topics_ = []
-         unique_labels = np.unique(self.labels_)
-
-         for label in unique_labels:
-             mask = self.labels_ == label
-             topic_docs = [documents[i] for i in np.where(mask)[0]]
-             topic_indices = np.where(mask)[0]
-
-             # Extract keywords
-             keywords, scores = self._keyword_extractor.extract(
-                 topic_docs,
-                 all_docs=documents,
-                 n_keywords=self.config.n_keywords,
-             )
-
-             # Find representative documents (closest to centroid)
-             if self.embeddings_ is not None and label != -1:
-                 topic_embeddings = self.embeddings_[mask]
-                 centroid = topic_embeddings.mean(axis=0)
-                 distances = np.linalg.norm(topic_embeddings - centroid, axis=1)
-                 top_indices = np.argsort(distances)[:self.config.n_representative_docs]
-                 representative_docs = [int(topic_indices[i]) for i in top_indices]
-             else:
-                 representative_docs = list(topic_indices[:self.config.n_representative_docs])
-
-             topic_info = TopicInfo(
-                 topic_id=int(label),
-                 size=int(mask.sum()),
-                 keywords=keywords,
-                 keyword_scores=scores,
-                 representative_docs=representative_docs,
-                 label=None,
-                 description=None,
-             )
-             self.topics_.append(topic_info)
-
-         # Sort by size (excluding outliers)
-         self.topics_ = sorted(
-             self.topics_,
-             key=lambda t: (t.topic_id == -1, -t.size)
-         )
-
-     def _compute_topic_centroids(self) -> None:
-         """Compute centroid embeddings for each topic."""
-         if self.embeddings_ is None:
-             return
-
-         unique_labels = [t.topic_id for t in self.topics_ if t.topic_id != -1]
-         self.topic_embeddings_ = np.zeros((len(unique_labels), self.embeddings_.shape[1]))
-
-         for i, label in enumerate(unique_labels):
-             mask = self.labels_ == label
-             self.topic_embeddings_[i] = self.embeddings_[mask].mean(axis=0)
-
-             # Store in topic info
-             for topic in self.topics_:
-                 if topic.topic_id == label:
-                     topic.centroid = self.topic_embeddings_[i]
-                     break
-
-     def fit_transform(
-         self,
-         documents: list[str],
-         embeddings: np.ndarray | None = None,
-         metadata: pd.DataFrame | None = None,
-     ) -> np.ndarray:
-         """
-         Fit the model and return topic assignments.
-
-         Parameters
-         ----------
-         documents : list[str]
-             List of document texts.
-         embeddings : np.ndarray, optional
-             Pre-computed embeddings.
-         metadata : pd.DataFrame, optional
-             Document metadata.
-
-         Returns
-         -------
-         labels : np.ndarray
-             Topic assignment for each document. -1 indicates outlier.
-         """
-         self.fit(documents, embeddings, metadata)
-         return self.labels_
-
-     def transform(self, documents: list[str]) -> np.ndarray:
-         """
-         Assign topics to new documents.
-
-         Parameters
-         ----------
-         documents : list[str]
-             New documents to classify.
-
-         Returns
-         -------
-         labels : np.ndarray
-             Topic assignments.
-         """
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         # Encode new documents
-         new_embeddings = self._embedding_engine.encode(documents)
-
-         # Find nearest topic centroid
-         labels = np.zeros(len(documents), dtype=int)
-
-         for i, emb in enumerate(new_embeddings):
-             distances = np.linalg.norm(self.topic_embeddings_ - emb, axis=1)
-             nearest_topic_idx = np.argmin(distances)
-
-             # Check if it's an outlier (too far from any centroid)
-             if distances[nearest_topic_idx] > self.config.outlier_threshold * 2:
-                 labels[i] = -1
-             else:
-                 # Map index back to topic_id
-                 non_outlier_topics = [t for t in self.topics_ if t.topic_id != -1]
-                 labels[i] = non_outlier_topics[nearest_topic_idx].topic_id
-
-         return labels
-
-     def get_topic_info(self) -> pd.DataFrame:
-         """
-         Get a DataFrame with topic information.
-
-         Returns
-         -------
-         df : pd.DataFrame
-             DataFrame with columns: Topic, Size, Keywords, Label, Coherence
-         """
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         data = []
-         for topic in self.topics_:
-             data.append({
-                 "Topic": topic.topic_id,
-                 "Size": topic.size,
-                 "Keywords": ", ".join(topic.keywords[:5]),
-                 "All_Keywords": topic.keywords,
-                 "Keyword_Scores": topic.keyword_scores,
-                 "Label": topic.label or f"Topic {topic.topic_id}",
-                 "Description": topic.description,
-                 "Representative_Docs": topic.representative_docs,
-                 "Coherence": topic.coherence,
-             })
-
-         return pd.DataFrame(data)
-
-     def get_topic(self, topic_id: int) -> TopicInfo | None:
-         """Get information about a specific topic."""
-         for topic in self.topics_:
-             if topic.topic_id == topic_id:
-                 return topic
-         return None
-
-     def get_representative_docs(
-         self,
-         topic_id: int,
-         n_docs: int = 5,
-     ) -> list[tuple[int, str]]:
-         """
-         Get representative documents for a topic.
-
-         Parameters
-         ----------
-         topic_id : int
-             Topic ID.
-         n_docs : int
-             Number of documents to return.
-
-         Returns
-         -------
-         docs : list[tuple[int, str]]
-             List of (index, document_text) tuples.
-         """
-         if not self._is_fitted or self.documents_ is None:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         topic = self.get_topic(topic_id)
-         if topic is None:
-             raise ValueError(f"Topic {topic_id} not found.")
-
-         indices = topic.representative_docs[:n_docs]
-         return [(idx, self.documents_[idx]) for idx in indices]
-
-     def generate_labels(
-         self,
-         labeler: "LLMLabeler",
-         topics: list[int] | None = None,
-     ) -> None:
-         """
-         Generate labels for topics using an LLM.
-
-         Parameters
-         ----------
-         labeler : LLMLabeler
-             Configured LLM labeler instance.
-         topics : list[int], optional
-             Specific topics to label. If None, labels all.
-         """
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         target_topics = topics or [t.topic_id for t in self.topics_ if t.topic_id != -1]
-
-         for topic_id in tqdm(target_topics, desc="Generating labels", disable=not self.config.verbose):
-             topic = self.get_topic(topic_id)
-             if topic is None:
-                 continue
-
-             # Get representative docs
-             rep_docs = self.get_representative_docs(topic_id, n_docs=5)
-             doc_texts = [doc for _, doc in rep_docs]
-
-             # Generate label
-             label, description = labeler.generate_label(
-                 keywords=topic.keywords,
-                 representative_docs=doc_texts,
-             )
-
-             topic.label = label
-             topic.description = description
-
-     def visualize(
-         self,
-         method: Literal["umap", "pacmap"] = "umap",
-         color_by: Literal["topic", "custom"] = "topic",
-         custom_labels: list[str] | None = None,
-         show_outliers: bool = True,
-         interactive: bool = True,
-         **kwargs,
-     ):
-         """
-         Visualize topics in 2D.
-
-         Parameters
-         ----------
-         method : str
-             Dimensionality reduction method. "umap" or "pacmap".
-         color_by : str
-             How to color points. "topic" uses topic assignments.
-         custom_labels : list[str], optional
-             Custom labels for hover text.
-         show_outliers : bool
-             Whether to show outlier documents.
-         interactive : bool
-             If True, returns interactive Plotly figure.
-         **kwargs
-             Additional arguments passed to the visualizer.
-
-         Returns
-         -------
-         fig : plotly.graph_objects.Figure
-             Interactive visualization.
-         """
-         from tritopic.visualization.plotter import TopicVisualizer
-
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         visualizer = TopicVisualizer(method=method)
-
-         return visualizer.plot_documents(
-             embeddings=self.embeddings_,
-             labels=self.labels_,
-             documents=self.documents_,
-             topics=self.topics_,
-             show_outliers=show_outliers,
-             interactive=interactive,
-             **kwargs,
-         )
-
-     def visualize_hierarchy(self, **kwargs):
-         """Visualize topic hierarchy as a dendrogram."""
-         from tritopic.visualization.plotter import TopicVisualizer
-
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         visualizer = TopicVisualizer()
-         return visualizer.plot_hierarchy(
-             topic_embeddings=self.topic_embeddings_,
-             topics=self.topics_,
-             **kwargs,
-         )
-
-     def visualize_topics(self, **kwargs):
-         """Visualize topics as a heatmap or bar chart."""
-         from tritopic.visualization.plotter import TopicVisualizer
-
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         visualizer = TopicVisualizer()
-         return visualizer.plot_topics(
-             topics=self.topics_,
-             **kwargs,
-         )
-
-     def evaluate(self) -> dict[str, float]:
-         """
-         Evaluate topic model quality.
-
-         Returns
-         -------
-         metrics : dict
-             Dictionary with coherence, diversity, and stability scores.
-         """
-         if not self._is_fitted:
-             raise ValueError("Model not fitted. Call fit() first.")
-
-         # Compute coherence for each topic
-         coherences = []
-         for topic in self.topics_:
-             if topic.topic_id != -1:
-                 coh = compute_coherence(
-                     topic.keywords,
-                     [self.documents_[i] for i in np.where(self.labels_ == topic.topic_id)[0]]
-                 )
-                 topic.coherence = coh
-                 coherences.append(coh)
-
-         # Compute diversity
-         all_keywords = [kw for t in self.topics_ if t.topic_id != -1 for kw in t.keywords]
-         diversity = compute_diversity(all_keywords, n_topics=len(coherences))
-
-         # Get stability from consensus clustering
-         stability = self._clusterer.stability_score_ if hasattr(self._clusterer, 'stability_score_') else None
-
-         metrics = {
-             "coherence_mean": float(np.mean(coherences)) if coherences else 0.0,
-             "coherence_std": float(np.std(coherences)) if coherences else 0.0,
-             "diversity": diversity,
-             "stability": stability,
-             "n_topics": len([t for t in self.topics_ if t.topic_id != -1]),
-             "outlier_ratio": float(np.mean(self.labels_ == -1)) if self.labels_ is not None else 0.0,
-         }
-
-         if self.config.verbose:
-             print("\n📊 Evaluation Metrics:")
-             print(f" Coherence (mean): {metrics['coherence_mean']:.4f}")
-             print(f" Diversity: {metrics['diversity']:.4f}")
-             if stability:
-                 print(f" Stability: {stability:.4f}")
-             print(f" Outlier ratio: {metrics['outlier_ratio']:.2%}")
-
-         return metrics
-
-     def save(self, path: str) -> None:
-         """Save model to disk."""
-         import pickle
-
-         state = {
-             "config": self.config,
-             "topics_": self.topics_,
-             "labels_": self.labels_,
-             "embeddings_": self.embeddings_,
-             "topic_embeddings_": self.topic_embeddings_,
-             "documents_": self.documents_,
-             "_is_fitted": self._is_fitted,
-             "_iteration_history": self._iteration_history,
-         }
-
-         with open(path, "wb") as f:
-             pickle.dump(state, f)
-
-         if self.config.verbose:
-             print(f"💾 Model saved to {path}")
-
-     @classmethod
-     def load(cls, path: str) -> "TriTopic":
-         """Load model from disk."""
-         import pickle
-
-         with open(path, "rb") as f:
-             state = pickle.load(f)
-
-         model = cls(config=state["config"])
-         model.topics_ = state["topics_"]
-         model.labels_ = state["labels_"]
-         model.embeddings_ = state["embeddings_"]
-         model.topic_embeddings_ = state["topic_embeddings_"]
-         model.documents_ = state["documents_"]
-         model._is_fitted = state["_is_fitted"]
-         model._iteration_history = state["_iteration_history"]
-
-         return model
-
-     def __repr__(self) -> str:
-         status = "fitted" if self._is_fitted else "not fitted"
-         n_topics = len([t for t in self.topics_ if t.topic_id != -1]) if self._is_fitted else "?"
-         return f"TriTopic(n_topics={n_topics}, status={status})"
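For reference, a minimal usage sketch of the TriTopic API removed above, assembled from the module's own docstring examples (fit_transform, get_topic_info, evaluate, save); it is illustrative only and not part of the diffed file.

    from tritopic import TriTopic

    docs = ["..."]  # your corpus: a list of document strings

    # Fit on raw texts; the returned array holds one topic id per document, -1 marks outliers.
    model = TriTopic(n_neighbors=15, verbose=True)
    labels = model.fit_transform(docs)

    # Inspect the discovered topics and quality metrics, then persist the fitted model.
    print(model.get_topic_info())     # Topic, Size, Keywords, Label, Coherence
    metrics = model.evaluate()        # coherence, diversity, stability, outlier_ratio
    model.save("tritopic_model.pkl")  # reload later with TriTopic.load("tritopic_model.pkl")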