tritopic 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of tritopic has been flagged as potentially problematic; consult the registry's advisory page for details.

tritopic/__init__.py CHANGED
@@ -1,46 +1,36 @@
1
1
  """
2
2
  TriTopic: Tri-Modal Graph Topic Modeling with Iterative Refinement
3
- ===================================================================
4
3
 
5
- A state-of-the-art topic modeling library that combines:
6
- - Semantic embeddings (Sentence-BERT, Instructor, BGE)
7
- - Lexical similarity (BM25)
8
- - Metadata context (optional)
4
+ A state-of-the-art topic modeling library that consistently outperforms
5
+ BERTopic and traditional approaches.
9
6
 
10
- With advanced techniques:
11
- - Leiden clustering with consensus
12
- - Mutual kNN + SNN graph construction
13
- - Iterative refinement loop
14
- - LLM-powered topic labeling
7
+ Key Features:
8
+ - Multi-view representation (semantic, lexical, metadata)
9
+ - Hybrid graph construction (Mutual kNN + SNN)
10
+ - Consensus Leiden clustering for stability
11
+ - Iterative refinement for improved coherence
12
+ - Multilingual support (60+ languages)
13
+ - LLM-powered labeling
15
14
 
16
- Basic usage:
17
- -----------
18
- >>> from tritopic import TriTopic
19
- >>> model = TriTopic()
20
- >>> topics = model.fit_transform(documents)
21
- >>> model.visualize()
22
-
23
- Author: Roman Egger
24
- License: MIT
15
+ Example:
16
+ >>> from tritopic import TriTopic
17
+ >>> model = TriTopic(verbose=True)
18
+ >>> topics = model.fit_transform(documents)
19
+ >>> print(model.get_topic_info())
25
20
  """
26
21
 
27
- __version__ = "0.1.0"
22
+ __version__ = "1.0.0"
28
23
  __author__ = "Roman Egger"
29
24
 
30
- from tritopic.core.model import TriTopic
31
- from tritopic.core.graph_builder import GraphBuilder
32
- from tritopic.core.clustering import ConsensusLeiden
33
- from tritopic.core.embeddings import EmbeddingEngine
34
- from tritopic.core.keywords import KeywordExtractor
35
- from tritopic.labeling.llm_labeler import LLMLabeler
36
- from tritopic.visualization.plotter import TopicVisualizer
25
+ from .model import TriTopic, Topic
26
+ from .config import TriTopicConfig, get_config
27
+ from .labeling import LLMLabeler, KeywordLabeler
37
28
 
38
29
  __all__ = [
39
30
  "TriTopic",
40
- "GraphBuilder",
41
- "ConsensusLeiden",
42
- "EmbeddingEngine",
43
- "KeywordExtractor",
31
+ "Topic",
32
+ "TriTopicConfig",
33
+ "get_config",
44
34
  "LLMLabeler",
45
- "TopicVisualizer",
35
+ "KeywordLabeler",
46
36
  ]
tritopic/config.py ADDED
@@ -0,0 +1,289 @@
1
+ """
2
+ TriTopic Configuration Module
3
+
4
+ This module provides configuration classes and utilities for TriTopic.
5
+ """
6
+
7
from dataclasses import asdict, dataclass, field, fields
from typing import List, Literal, Optional
9
+
10
+
11
@dataclass
class TriTopicConfig:
    """
    Configuration class for the TriTopic model.

    Every hyperparameter has a sensible default; construct with keyword
    overrides and pass the instance to ``TriTopic(config=config)``.

    Example:
        config = TriTopicConfig(
            embedding_model="all-mpnet-base-v2",
            n_neighbors=20,
            use_iterative_refinement=True
        )
        model = TriTopic(config=config)

    Raises:
        ValueError: from ``__post_init__`` when a parameter is out of range.
    """

    # ----- Embedding settings ------------------------------------------------
    # Sentence-Transformer model name or path.  Options include
    # "all-MiniLM-L6-v2", "all-mpnet-base-v2", "BAAI/bge-base-en-v1.5", etc.
    embedding_model: str = "all-MiniLM-L6-v2"
    # Batch size for embedding generation.  Reduce if GPU OOM.
    embedding_batch_size: int = 32

    # ----- Graph construction ------------------------------------------------
    # Number of neighbors for kNN graph construction.
    n_neighbors: int = 15
    # Distance metric: "cosine", "euclidean", "manhattan".
    metric: str = "cosine"
    # Graph type:
    #   "knn"        standard k-nearest neighbors
    #   "mutual_knn" only bidirectional connections
    #   "snn"        shared nearest neighbors
    #   "hybrid"     combination of mutual_knn + snn (recommended)
    graph_type: Literal["knn", "mutual_knn", "snn", "hybrid"] = "hybrid"
    # Weight of the SNN component in the hybrid graph (0.0 to 1.0).
    snn_weight: float = 0.5

    # ----- Multi-view fusion -------------------------------------------------
    # Include lexical (TF-IDF) similarity in graph construction.
    use_lexical_view: bool = True
    # Include metadata similarity in graph construction.
    use_metadata_view: bool = False
    # Weight for semantic (embedding) similarity.
    semantic_weight: float = 0.5
    # Weight for lexical (TF-IDF) similarity.
    lexical_weight: float = 0.3
    # Weight for metadata similarity.
    metadata_weight: float = 0.2

    # ----- Clustering (Leiden + consensus) -----------------------------------
    # Leiden resolution parameter.  Higher = more topics, lower = fewer.
    resolution: float = 1.0
    # Number of clustering runs for consensus clustering.
    n_consensus_runs: int = 10
    # Minimum documents per topic.  Smaller clusters become outliers.
    min_cluster_size: int = 5

    # ----- Iterative refinement ----------------------------------------------
    # Enable iterative embedding refinement based on discovered topics.
    use_iterative_refinement: bool = True
    # Maximum number of refinement iterations.
    max_iterations: int = 5
    # ARI threshold for convergence (0.0 to 1.0).
    convergence_threshold: float = 0.95
    # How strongly to pull embeddings toward topic centroids (0.0 to 1.0).
    refinement_strength: float = 0.1

    # ----- Keyword extraction ------------------------------------------------
    # Number of keywords to extract per topic.
    n_keywords: int = 10
    # Number of representative documents per topic.
    n_representative_docs: int = 5
    # Keyword extraction method:
    #   "ctfidf"  class-based TF-IDF (fast, good quality)
    #   "bm25"    BM25 scoring
    #   "keybert" KeyBERT extraction (slower, embedding-based)
    keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"

    # ----- Outlier handling --------------------------------------------------
    # Threshold for outlier detection (documents below this similarity).
    outlier_threshold: float = 0.1
    # Whether to reassign outliers to nearest topics.
    reduce_outliers: bool = False

    # ----- Dimensionality reduction (for visualization) ----------------------
    # UMAP n_neighbors for visualization.
    umap_n_neighbors: int = 15
    # UMAP dimensions for visualization.
    umap_n_components: int = 2
    # UMAP min_dist parameter.
    umap_min_dist: float = 0.1

    # ----- Misc ----------------------------------------------------------------
    # Random seed for reproducibility.
    random_state: Optional[int] = 42
    # Print progress information.
    verbose: bool = True
    # Number of parallel jobs (-1 = all cores).
    n_jobs: int = -1

    def __post_init__(self) -> None:
        """Normalize active view weights and validate parameter ranges.

        Raises:
            ValueError: if any parameter falls outside its allowed range.
        """
        # Normalize the active view weights so they sum to 1.0.
        # NOTE(review): as in the original code, a metadata-only configuration
        # (use_metadata_view=True, use_lexical_view=False) is intentionally
        # left un-normalized — confirm downstream code treats the weights as
        # relative before extending this.
        if self.use_lexical_view and not self.use_metadata_view:
            total = self.semantic_weight + self.lexical_weight
            if abs(total - 1.0) > 0.01:
                self.semantic_weight /= total
                self.lexical_weight /= total
        elif self.use_lexical_view and self.use_metadata_view:
            total = self.semantic_weight + self.lexical_weight + self.metadata_weight
            if abs(total - 1.0) > 0.01:
                self.semantic_weight /= total
                self.lexical_weight /= total
                self.metadata_weight /= total

        # Range checks.  Raise ValueError instead of using `assert`, which is
        # silently stripped when Python runs with -O.
        if not 0 < self.n_neighbors <= 100:
            raise ValueError("n_neighbors must be between 1 and 100")
        if not 0 < self.n_consensus_runs <= 50:
            raise ValueError("n_consensus_runs must be between 1 and 50")
        if not 0 <= self.snn_weight <= 1:
            raise ValueError("snn_weight must be between 0 and 1")
        if not 0 < self.convergence_threshold <= 1:
            raise ValueError("convergence_threshold must be between 0 and 1")
        if self.min_cluster_size < 2:
            raise ValueError("min_cluster_size must be at least 2")

    def to_dict(self) -> dict:
        """Convert config to a plain dictionary.

        Uses ``dataclasses.asdict`` so every field is included; the previous
        hand-written mapping silently omitted the three ``umap_*`` settings.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, config_dict: dict) -> "TriTopicConfig":
        """Create a config from a dictionary, ignoring unknown keys.

        Filters on actual dataclass fields.  The previous ``hasattr(cls, k)``
        check also accepted method names (e.g. ``"to_dict"``), which would then
        crash the constructor with a TypeError.
        """
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in config_dict.items() if k in valid})
210
+
211
+
212
def get_config(**kwargs) -> TriTopicConfig:
    """
    Build a TriTopicConfig, overriding defaults with keyword arguments.

    Any parameter not supplied falls back to the TriTopicConfig default.

    Example:
        config = get_config(
            embedding_model="all-mpnet-base-v2",
            n_neighbors=20,
            use_iterative_refinement=True
        )

    Args:
        **kwargs: Any TriTopicConfig parameter

    Returns:
        TriTopicConfig instance
    """
    return TriTopicConfig(**kwargs)
232
+
233
+
234
# Preset configurations for common use cases.
# Each value is a fully-constructed TriTopicConfig; look them up via
# get_preset(name) rather than mutating these shared instances.
PRESETS = {
    # Quick results with the small MiniLM encoder; refinement disabled.
    "fast": TriTopicConfig(
        embedding_model="all-MiniLM-L6-v2",
        n_neighbors=10,
        n_consensus_runs=5,
        use_iterative_refinement=False,
        keyword_method="ctfidf",
    ),
    # Default trade-off: MiniLM encoder plus a short refinement loop.
    "balanced": TriTopicConfig(
        embedding_model="all-MiniLM-L6-v2",
        n_neighbors=15,
        n_consensus_runs=10,
        use_iterative_refinement=True,
        max_iterations=3,
        keyword_method="ctfidf",
    ),
    # Larger mpnet encoder, more consensus runs and refinement iterations.
    "quality": TriTopicConfig(
        embedding_model="all-mpnet-base-v2",
        n_neighbors=20,
        n_consensus_runs=15,
        use_iterative_refinement=True,
        max_iterations=5,
        keyword_method="ctfidf",
    ),
    # Maximum-quality settings: BGE encoder, 20 consensus runs, up to 10
    # refinement iterations, and a stricter convergence threshold.
    "research": TriTopicConfig(
        embedding_model="BAAI/bge-base-en-v1.5",
        n_neighbors=20,
        n_consensus_runs=20,
        use_iterative_refinement=True,
        max_iterations=10,
        convergence_threshold=0.98,
        keyword_method="ctfidf",
    ),
}
269
+
270
+
271
def get_preset(name: str) -> TriTopicConfig:
    """
    Look up a preset configuration by name.

    Available presets:
    - "fast": Quick results, lower quality
    - "balanced": Good balance of speed and quality (default)
    - "quality": Higher quality, slower
    - "research": Maximum quality for research/publication

    Args:
        name: Preset name

    Returns:
        TriTopicConfig instance

    Raises:
        ValueError: if ``name`` is not a known preset.
    """
    preset = PRESETS.get(name)
    if preset is None:
        raise ValueError(f"Unknown preset '{name}'. Available: {list(PRESETS.keys())}")
    return preset
tritopic/core/__init__.py CHANGED
@@ -1,17 +0,0 @@
1
- """Core components for TriTopic."""
2
-
3
- from tritopic.core.model import TriTopic, TriTopicConfig, TopicInfo
4
- from tritopic.core.graph_builder import GraphBuilder
5
- from tritopic.core.clustering import ConsensusLeiden
6
- from tritopic.core.embeddings import EmbeddingEngine
7
- from tritopic.core.keywords import KeywordExtractor
8
-
9
- __all__ = [
10
- "TriTopic",
11
- "TriTopicConfig",
12
- "TopicInfo",
13
- "GraphBuilder",
14
- "ConsensusLeiden",
15
- "EmbeddingEngine",
16
- "KeywordExtractor",
17
- ]