tritopic 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tritopic/model.py ADDED
@@ -0,0 +1,718 @@
+ """
+ TriTopic: Tri-Modal Graph Topic Modeling with Iterative Refinement
+
+ Main model class that orchestrates all components.
+ """
+
+ from __future__ import annotations
+
+ import pickle
+ import warnings
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ import pandas as pd
+ from scipy import sparse
+ from tqdm import tqdm
+
+ from .config import TriTopicConfig, get_config
+ from .core.embeddings import EmbeddingEngine
+ from .core.graph import GraphBuilder, MultiViewGraphBuilder
+ from .core.clustering import ConsensusLeiden
+ from .core.refinement import IterativeRefinement
+ from .core.keywords import KeywordExtractor
+ from .core.representatives import RepresentativeSelector
+
+
+ @dataclass
+ class Topic:
+     """Represents a single topic with its metadata."""
+     topic_id: int
+     size: int
+     keywords: List[str]
+     keyword_scores: List[float]
+     representative_docs: List[int]
+     representative_texts: List[str]
+     centroid: Optional[np.ndarray] = None
+     label: Optional[str] = None
+
+     def __repr__(self) -> str:
+         kw_str = ", ".join(self.keywords[:5])
+         label_str = f" ({self.label})" if self.label else ""
+         return f"Topic {self.topic_id}{label_str}: [{kw_str}...] (n={self.size})"
+
+
+ class TriTopic:
+     """
+     Tri-Modal Graph Topic Modeling with Iterative Refinement.
+
+     A graph-based topic modeling approach that combines:
+     - Multi-view representation (semantic, lexical, metadata)
+     - Hybrid graph construction (Mutual kNN + SNN)
+     - Consensus Leiden clustering for stability
+     - Iterative refinement for improved coherence
+     - LLM-powered labeling (optional)
+
+     Parameters
+     ----------
+     config : TriTopicConfig, str, or None
+         Configuration object or name of a preset config.
+         If None, uses the default config.
+     **kwargs
+         Override any config parameter directly.
+
+     Examples
+     --------
+     >>> from tritopic import TriTopic
+     >>> model = TriTopic(verbose=True)
+     >>> topics = model.fit_transform(documents)
+     >>> print(model.get_topic_info())
+
+     >>> # With custom config
+     >>> model = TriTopic(
+     ...     embedding_model="BAAI/bge-base-en-v1.5",
+     ...     n_neighbors=20,
+     ...     use_iterative_refinement=True
+     ... )
+     """
+
+     def __init__(
+         self,
+         config: Optional[Union[TriTopicConfig, str]] = None,
+         **kwargs
+     ):
+         # Load config
+         if config is None:
+             self.config = TriTopicConfig()
+         elif isinstance(config, str):
+             self.config = get_config(config)
+         else:
+             self.config = config
+
+         # Override with kwargs
+         for key, value in kwargs.items():
+             if hasattr(self.config, key):
+                 setattr(self.config, key, value)
+             else:
+                 warnings.warn(f"Unknown config parameter: {key}")
+
+         # Initialize components (lazy loading)
+         self._embedding_engine: Optional[EmbeddingEngine] = None
+         self._graph_builder: Optional[MultiViewGraphBuilder] = None
+         self._clusterer: Optional[ConsensusLeiden] = None
+         self._refiner: Optional[IterativeRefinement] = None
+         self._keyword_extractor: Optional[KeywordExtractor] = None
+         self._representative_selector: Optional[RepresentativeSelector] = None
+
+         # Fitted attributes
+         self.documents_: Optional[List[str]] = None
+         self.embeddings_: Optional[np.ndarray] = None
+         self.graph_: Optional[sparse.csr_matrix] = None
+         self.topic_labels_: Optional[np.ndarray] = None
+         self.topics_: Optional[List[Topic]] = None
+         self.n_topics_: int = 0
+         self.outlier_count_: int = 0
+
+         # Language detection results
+         self.detected_language_: Optional[str] = None
+         self.is_multilingual_: bool = False
+
+     def _log(self, message: str, level: int = 1) -> None:
+         """Print message if verbose mode is enabled."""
+         if self.config.verbose:
+             indent = " " * (level - 1)
+             print(f"{indent}{message}")
+
+     def _initialize_components(self, documents: List[str]) -> None:
+         """Initialize all components based on config and detected language."""
+         # Detect language if auto
+         if self.config.language == "auto":
+             self._detect_language(documents)
+         else:
+             self.detected_language_ = self.config.language
+
+         # Get appropriate embedding model
+         embedding_model = self.config.embedding_model
+         if embedding_model == "auto":
+             embedding_model = self.config.get_embedding_model_for_language(
+                 self.detected_language_ or "en"
+             )
+             self._log(f"Auto-selected embedding model: {embedding_model}", 2)
+
+         # Initialize embedding engine
+         self._embedding_engine = EmbeddingEngine(
+             model_name=embedding_model,
+             batch_size=self.config.embedding_batch_size,
+             device=self.config.device
+         )
+
+         # Initialize graph builder
+         self._graph_builder = MultiViewGraphBuilder(
+             n_neighbors=self.config.n_neighbors,
+             metric=self.config.metric,
+             graph_type=self.config.graph_type,
+             snn_weight=self.config.snn_weight,
+             semantic_weight=self.config.semantic_weight,
+             lexical_weight=self.config.lexical_weight,
+             metadata_weight=self.config.metadata_weight,
+             use_lexical=self.config.use_lexical_view,
+             use_metadata=self.config.use_metadata_view,
+             lexical_method=self.config.lexical_method,
+             ngram_range=self.config.ngram_range
+         )
+
+         # Initialize clusterer
+         self._clusterer = ConsensusLeiden(
+             resolution=self.config.resolution,
+             n_runs=self.config.n_consensus_runs,
+             min_cluster_size=self.config.min_cluster_size,
+             random_state=self.config.random_state
+         )
+
+         # Initialize refiner
+         if self.config.use_iterative_refinement:
+             self._refiner = IterativeRefinement(
+                 max_iterations=self.config.max_iterations,
+                 convergence_threshold=self.config.convergence_threshold,
+                 refinement_strength=self.config.refinement_strength
+             )
+
+         # Initialize keyword extractor
+         self._keyword_extractor = KeywordExtractor(
+             method=self.config.keyword_method,
+             n_keywords=self.config.n_keywords,
+             language=self.detected_language_ or "en",
+             ngram_range=self.config.ngram_range
+         )
+
+         # Initialize representative selector
+         self._representative_selector = RepresentativeSelector(
+             method=self.config.representative_method,
+             n_representatives=self.config.n_representative_docs,
+             n_archetypes=self.config.n_archetypes,
+             archetype_method=self.config.archetype_method
+         )
+
+     def _detect_language(self, documents: List[str]) -> None:
+         """Detect the dominant language of the corpus."""
+         try:
+             from .multilingual.detection import detect_corpus_language
+
+             result = detect_corpus_language(
+                 documents,
+                 sample_size=self.config.language_detection_sample
+             )
+             self.detected_language_ = result["dominant_language"]
+             self.is_multilingual_ = result["is_multilingual"]
+
+             self._log(f"Detected language: {self.detected_language_} "
+                       f"(confidence: {result['confidence']:.2f})", 2)
+             if self.is_multilingual_:
+                 self._log("Corpus appears multilingual", 2)
+
+         except ImportError:
+             self._log("Language detection not available, defaulting to English", 2)
+             self.detected_language_ = "en"
+
+     def fit(
+         self,
+         documents: List[str],
+         embeddings: Optional[np.ndarray] = None,
+         metadata: Optional[pd.DataFrame] = None
+     ) -> "TriTopic":
+         """
+         Fit the topic model on documents.
+
+         Parameters
+         ----------
+         documents : List[str]
+             List of documents to model.
+         embeddings : np.ndarray, optional
+             Pre-computed embeddings. If None, embeddings are generated.
+         metadata : pd.DataFrame, optional
+             Document metadata for multi-view fusion.
+
+         Returns
+         -------
+         self
+             Fitted model.
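+
+         Examples
+         --------
+         >>> # Illustrative sketch: ``docs`` and ``precomputed`` are placeholder
+         >>> # names for your own corpus and a (n_docs, dim) embedding array.
+         >>> model = TriTopic()
+         >>> _ = model.fit(docs, embeddings=precomputed)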
+ """
242
+ n_docs = len(documents)
243
+ self.documents_ = documents
244
+
245
+ self._log(f"🚀 TriTopic: Fitting model on {n_docs} documents")
246
+ self._log(f"Config: {self.config.graph_type} graph, "
247
+ f"{'iterative' if self.config.use_iterative_refinement else 'single-pass'} mode", 1)
248
+
249
+ # Initialize components
250
+ self._initialize_components(documents)
251
+
252
+ # Step 1: Generate embeddings
253
+ if embeddings is not None:
254
+ self._log("→ Using provided embeddings", 1)
255
+ self.embeddings_ = embeddings
256
+ else:
257
+ model_name = self._embedding_engine.model_name
258
+ self._log(f"→ Generating embeddings ({model_name})...", 1)
259
+ self.embeddings_ = self._embedding_engine.encode(
260
+ documents,
261
+ show_progress=self.config.verbose
262
+ )
263
+
264
+ # Step 2: Build lexical similarity (if enabled)
265
+ if self.config.use_lexical_view:
266
+ self._log("→ Building lexical similarity matrix...", 1)
267
+
268
+ # Step 3: Build graph
269
+ self._log("→ Constructing multi-view graph...", 1)
270
+ self.graph_ = self._graph_builder.build(
271
+ embeddings=self.embeddings_,
272
+ documents=documents,
273
+ metadata=metadata
274
+ )
275
+
276
+ # Step 4: Clustering (with optional refinement)
277
+ if self.config.use_iterative_refinement:
278
+ self._log(f"→ Starting iterative refinement (max {self.config.max_iterations} iterations)...", 1)
279
+
280
+ def graph_builder_fn(emb):
281
+ return self._graph_builder.build(
282
+ embeddings=emb,
283
+ documents=documents,
284
+ metadata=metadata
285
+ )
286
+
287
+ def cluster_fn(g):
288
+ return self._clusterer.fit_predict(g)
289
+
290
+ self.topic_labels_, self.embeddings_, iterations = self._refiner.refine(
291
+ embeddings=self.embeddings_,
292
+ initial_labels=self._clusterer.fit_predict(self.graph_),
293
+ graph_builder_fn=graph_builder_fn,
294
+ cluster_fn=cluster_fn,
295
+ verbose=self.config.verbose
296
+ )
297
+
298
+ # Rebuild final graph with refined embeddings
299
+ self.graph_ = graph_builder_fn(self.embeddings_)
300
+
301
+ else:
302
+ self._log("→ Consensus clustering...", 1)
303
+ self.topic_labels_ = self._clusterer.fit_predict(self.graph_)
304
+
305
+ # Step 5: Extract keywords and representatives
306
+ self._log("→ Extracting keywords and representative documents...", 1)
307
+ self._extract_topic_info(documents)
308
+
309
+ # Summary
310
+ self.n_topics_ = len([t for t in self.topics_ if t.topic_id >= 0])
311
+ self.outlier_count_ = sum(1 for l in self.topic_labels_ if l < 0)
312
+ outlier_pct = 100 * self.outlier_count_ / n_docs
313
+
314
+ self._log("")
315
+ self._log(f"✅ Fitting complete!")
316
+ self._log(f" Found {self.n_topics_} topics")
317
+ self._log(f" {self.outlier_count_} outlier documents ({outlier_pct:.1f}%)")
318
+
319
+ return self
320
+
321
+ def fit_transform(
322
+ self,
323
+ documents: List[str],
324
+ embeddings: Optional[np.ndarray] = None,
325
+ metadata: Optional[pd.DataFrame] = None
326
+ ) -> np.ndarray:
327
+ """
328
+ Fit the model and return topic assignments.
329
+
330
+ Parameters
331
+ ----------
332
+ documents : List[str]
333
+ List of documents to model.
334
+ embeddings : np.ndarray, optional
335
+ Pre-computed embeddings.
336
+ metadata : pd.DataFrame, optional
337
+ Document metadata.
338
+
339
+ Returns
340
+ -------
341
+ np.ndarray
342
+ Topic assignments for each document (-1 for outliers).
343
+ """
344
+ self.fit(documents, embeddings, metadata)
345
+ return self.topic_labels_
346
+
347
+ def transform(self, documents: List[str]) -> np.ndarray:
348
+ """
349
+ Assign topics to new documents.
350
+
351
+ Parameters
352
+ ----------
353
+ documents : List[str]
354
+ New documents to classify.
355
+
356
+ Returns
357
+ -------
358
+ np.ndarray
359
+ Topic assignments for each document.
360
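+
+         Examples
+         --------
+         >>> # Illustrative; requires a fitted model.
+         >>> new_labels = model.transform(["a brand-new document"])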
+ """
361
+ if self.topics_ is None:
362
+ raise ValueError("Model not fitted. Call fit() first.")
363
+
364
+ # Encode new documents
365
+ new_embeddings = self._embedding_engine.encode(documents)
366
+
367
+ # Find nearest topic centroid for each document
368
+ topic_centroids = np.array([
369
+ t.centroid for t in self.topics_
370
+ if t.topic_id >= 0 and t.centroid is not None
371
+ ])
372
+ topic_ids = [t.topic_id for t in self.topics_ if t.topic_id >= 0]
373
+
374
+ # Compute distances to centroids
375
+ from sklearn.metrics.pairwise import cosine_distances
376
+ distances = cosine_distances(new_embeddings, topic_centroids)
377
+
378
+ # Assign to nearest topic (with outlier threshold)
379
+ assignments = []
380
+ for i, doc_distances in enumerate(distances):
381
+ min_idx = np.argmin(doc_distances)
382
+ min_dist = doc_distances[min_idx]
383
+
384
+ if min_dist > self.config.outlier_threshold:
385
+ assignments.append(-1)
386
+ else:
387
+ assignments.append(topic_ids[min_idx])
388
+
389
+ return np.array(assignments)
390
+
391
+ def _extract_topic_info(self, documents: List[str]) -> None:
392
+ """Extract keywords and representatives for each topic."""
393
+ unique_labels = sorted(set(self.topic_labels_))
394
+ self.topics_ = []
395
+
396
+ for topic_id in unique_labels:
397
+ # Get documents in this topic
398
+ mask = self.topic_labels_ == topic_id
399
+ topic_docs = [documents[i] for i, m in enumerate(mask) if m]
400
+ topic_embeddings = self.embeddings_[mask]
401
+ topic_indices = np.where(mask)[0]
402
+
403
+ # Extract keywords
404
+ if topic_id >= 0:
405
+ keywords, scores = self._keyword_extractor.extract(
406
+ topic_docs,
407
+ all_documents=documents
408
+ )
409
+ else:
410
+ keywords, scores = ["[outlier]"], [0.0]
411
+
412
+ # Get representative documents
413
+ if topic_id >= 0 and len(topic_docs) > 0:
414
+ reps = self._representative_selector.select(
415
+ embeddings=topic_embeddings,
416
+ documents=topic_docs,
417
+ keywords=keywords[:5],
418
+ global_indices=topic_indices
419
+ )
420
+ rep_indices = [r.doc_id for r in reps.representatives]
421
+ rep_texts = [r.text[:200] for r in reps.representatives]
422
+ else:
423
+ rep_indices = list(topic_indices[:self.config.n_representative_docs])
424
+ rep_texts = topic_docs[:self.config.n_representative_docs]
425
+
426
+ # Compute centroid
427
+ centroid = topic_embeddings.mean(axis=0) if len(topic_embeddings) > 0 else None
428
+
429
+ topic = Topic(
430
+ topic_id=topic_id,
431
+ size=len(topic_docs),
432
+ keywords=keywords,
433
+ keyword_scores=scores,
434
+ representative_docs=rep_indices,
435
+ representative_texts=rep_texts,
436
+ centroid=centroid
437
+ )
438
+ self.topics_.append(topic)
439
+
440
+ def get_topic_info(self) -> pd.DataFrame:
441
+ """
442
+ Get summary information about all topics.
443
+
444
+ Returns
445
+ -------
446
+ pd.DataFrame
447
+ DataFrame with topic information.
448
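+
+         Examples
+         --------
+         >>> # Illustrative: the column layout of the returned frame.
+         >>> list(model.get_topic_info().columns)
+         ['Topic', 'Size', 'Label', 'Keywords', 'Representative']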
+ """
449
+ if self.topics_ is None:
450
+ raise ValueError("Model not fitted. Call fit() first.")
451
+
452
+ data = []
453
+ for topic in self.topics_:
454
+ data.append({
455
+ "Topic": topic.topic_id,
456
+ "Size": topic.size,
457
+ "Label": topic.label or "",
458
+ "Keywords": ", ".join(topic.keywords[:5]),
459
+ "Representative": topic.representative_texts[0][:100] + "..."
460
+ if topic.representative_texts else ""
461
+ })
462
+
463
+ return pd.DataFrame(data)
464
+
465
+ def get_topic(self, topic_id: int) -> Optional[Topic]:
466
+ """
467
+ Get detailed information about a specific topic.
468
+
469
+ Parameters
470
+ ----------
471
+ topic_id : int
472
+ Topic ID to retrieve.
473
+
474
+ Returns
475
+ -------
476
+ Topic or None
477
+ Topic object if found.
478
+ """
479
+ if self.topics_ is None:
480
+ return None
481
+
482
+ for topic in self.topics_:
483
+ if topic.topic_id == topic_id:
484
+ return topic
485
+ return None
486
+
487
+ def get_document_topics(
488
+ self,
489
+ doc_indices: Optional[List[int]] = None
490
+ ) -> pd.DataFrame:
491
+ """
492
+ Get topic assignments for documents.
493
+
494
+ Parameters
495
+ ----------
496
+ doc_indices : List[int], optional
497
+ Specific document indices. If None, returns all.
498
+
499
+ Returns
500
+ -------
501
+ pd.DataFrame
502
+ DataFrame with document-topic assignments.
503
+ """
504
+ if self.topic_labels_ is None:
505
+ raise ValueError("Model not fitted.")
506
+
507
+ if doc_indices is None:
508
+ doc_indices = list(range(len(self.topic_labels_)))
509
+
510
+ data = []
511
+ for idx in doc_indices:
512
+ topic_id = self.topic_labels_[idx]
513
+ topic = self.get_topic(topic_id)
514
+ data.append({
515
+ "Document": idx,
516
+ "Topic": topic_id,
517
+ "Topic_Label": topic.label if topic else "",
518
+ "Text_Preview": self.documents_[idx][:100] + "..."
519
+ })
520
+
521
+ return pd.DataFrame(data)
522
+
523
+ def generate_labels(self, labeler: Any) -> None:
524
+ """
525
+ Generate human-readable labels for topics using an LLM.
526
+
527
+ Parameters
528
+ ----------
529
+ labeler : LLMLabeler
530
+ Labeler instance configured with LLM provider.
531
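+
+         Examples
+         --------
+         >>> # Illustrative; any object exposing
+         >>> # generate_label(keywords=..., representative_docs=...) will do.
+         >>> model.generate_labels(my_labeler)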
+ """
532
+ if self.topics_ is None:
533
+ raise ValueError("Model not fitted.")
534
+
535
+ for topic in self.topics_:
536
+ if topic.topic_id < 0:
537
+ topic.label = "Outliers"
538
+ continue
539
+
540
+ label = labeler.generate_label(
541
+ keywords=topic.keywords,
542
+ representative_docs=topic.representative_texts
543
+ )
544
+ topic.label = label
545
+
546
+ def evaluate(self) -> Dict[str, float]:
547
+ """
548
+ Evaluate the topic model quality.
549
+
550
+ Returns
551
+ -------
552
+ Dict[str, float]
553
+ Dictionary with evaluation metrics.
554
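+
+         Notes
+         -----
+         ``diversity`` is the fraction of unique keywords across all topics:
+         e.g. 3 topics x 10 keywords with 24 distinct terms gives
+         24 / 30 = 0.8.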
+ """
555
+ if self.topics_ is None:
556
+ raise ValueError("Model not fitted.")
557
+
558
+ metrics = {}
559
+
560
+ # Number of topics
561
+ metrics["n_topics"] = self.n_topics_
562
+
563
+ # Outlier ratio
564
+ metrics["outlier_ratio"] = self.outlier_count_ / len(self.topic_labels_)
565
+
566
+ # Topic diversity
567
+ all_keywords = []
568
+ for topic in self.topics_:
569
+ if topic.topic_id >= 0:
570
+ all_keywords.extend(topic.keywords[:10])
571
+
572
+ unique_keywords = len(set(all_keywords))
573
+ total_keywords = len(all_keywords)
574
+ metrics["diversity"] = unique_keywords / total_keywords if total_keywords > 0 else 0
575
+
576
+ # Coherence (simplified NPMI)
577
+ try:
578
+ from .core.keywords import compute_coherence
579
+ coherence_scores = []
580
+ for topic in self.topics_:
581
+ if topic.topic_id >= 0:
582
+ score = compute_coherence(topic.keywords[:10], self.documents_)
583
+ coherence_scores.append(score)
584
+
585
+ metrics["coherence_mean"] = np.mean(coherence_scores) if coherence_scores else 0
586
+ metrics["coherence_std"] = np.std(coherence_scores) if coherence_scores else 0
587
+ except:
588
+ metrics["coherence_mean"] = 0
589
+ metrics["coherence_std"] = 0
590
+
591
+ # Clustering stability
592
+ if hasattr(self._clusterer, "stability_score_"):
593
+ metrics["stability"] = self._clusterer.stability_score_
594
+
595
+ return metrics
596
+
597
+ def visualize(
598
+ self,
599
+ method: str = "umap",
600
+ **kwargs
601
+ ) -> Any:
602
+ """
603
+ Visualize the topic model.
604
+
605
+ Parameters
606
+ ----------
607
+ method : str
608
+ Visualization method: "umap", "tsne", or "pca".
609
+ **kwargs
610
+ Additional arguments for visualization.
611
+
612
+ Returns
613
+ -------
614
+ plotly.graph_objects.Figure
615
+ Interactive visualization.
616
+ """
617
+ try:
618
+ from .visualization import create_topic_visualization
619
+ return create_topic_visualization(
620
+ embeddings=self.embeddings_,
621
+ labels=self.topic_labels_,
622
+ topics=self.topics_,
623
+ documents=self.documents_,
624
+ method=method,
625
+ **kwargs
626
+ )
627
+ except ImportError:
628
+ raise ImportError(
629
+ "Visualization requires plotly. Install with: pip install tritopic[visualization]"
630
+ )
631
+
632
+ def visualize_topics(self, **kwargs) -> Any:
633
+ """Visualize topic keywords as bar charts."""
634
+ try:
635
+ from .visualization import create_topic_barchart
636
+ return create_topic_barchart(self.topics_, **kwargs)
637
+ except ImportError:
638
+ raise ImportError("Visualization requires plotly.")
639
+
640
+ def visualize_hierarchy(self, **kwargs) -> Any:
641
+ """Visualize topic hierarchy as dendrogram."""
642
+ try:
643
+ from .visualization import create_topic_hierarchy
644
+ return create_topic_hierarchy(
645
+ embeddings=self.embeddings_,
646
+ labels=self.topic_labels_,
647
+ topics=self.topics_,
648
+ **kwargs
649
+ )
650
+ except ImportError:
651
+ raise ImportError("Visualization requires plotly.")
652
+
653
+ def save(self, path: Union[str, Path]) -> None:
654
+ """
655
+ Save the model to disk.
656
+
657
+ Parameters
658
+ ----------
659
+ path : str or Path
660
+ File path for saving.
661
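+
+         Examples
+         --------
+         >>> # Illustrative round trip.
+         >>> model.save("tritopic_model.pkl")
+         >>> restored = TriTopic.load("tritopic_model.pkl")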
+ """
662
+ path = Path(path)
663
+
664
+ # Prepare state dict (excluding non-picklable objects)
665
+ state = {
666
+ "config": self.config,
667
+ "documents_": self.documents_,
668
+ "embeddings_": self.embeddings_,
669
+ "topic_labels_": self.topic_labels_,
670
+ "topics_": self.topics_,
671
+ "n_topics_": self.n_topics_,
672
+ "outlier_count_": self.outlier_count_,
673
+ "detected_language_": self.detected_language_,
674
+ "is_multilingual_": self.is_multilingual_,
675
+ }
676
+
677
+ with open(path, "wb") as f:
678
+ pickle.dump(state, f)
679
+
680
+ @classmethod
681
+ def load(cls, path: Union[str, Path]) -> "TriTopic":
682
+ """
683
+ Load a model from disk.
684
+
685
+ Parameters
686
+ ----------
687
+ path : str or Path
688
+ File path to load from.
689
+
690
+ Returns
691
+ -------
692
+ TriTopic
693
+ Loaded model.
694
+ """
695
+ path = Path(path)
696
+
697
+ with open(path, "rb") as f:
698
+ state = pickle.load(f)
699
+
700
+ model = cls(config=state["config"])
701
+ model.documents_ = state["documents_"]
702
+ model.embeddings_ = state["embeddings_"]
703
+ model.topic_labels_ = state["topic_labels_"]
704
+ model.topics_ = state["topics_"]
705
+ model.n_topics_ = state["n_topics_"]
706
+ model.outlier_count_ = state["outlier_count_"]
707
+ model.detected_language_ = state.get("detected_language_")
708
+ model.is_multilingual_ = state.get("is_multilingual_", False)
709
+
710
+ # Re-initialize components for transform()
711
+ model._initialize_components(model.documents_)
712
+
713
+ return model
714
+
715
+ def __repr__(self) -> str:
716
+ if self.topics_ is None:
717
+ return "TriTopic(not fitted)"
718
+ return f"TriTopic(n_topics={self.n_topics_}, n_docs={len(self.documents_)})"