tritopic-1.0.0-py3-none-any.whl → tritopic-1.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release.

tritopic/config.py CHANGED
@@ -1,305 +1,289 @@
  """
  TriTopic Configuration Module

- Defines all configuration parameters for the TriTopic model.
+ This module provides configuration classes and utilities for TriTopic.
  """

  from dataclasses import dataclass, field
- from typing import Optional, List, Literal, Union
+ from typing import Optional, List, Literal


  @dataclass
  class TriTopicConfig:
      """
-     Configuration for TriTopic model.
-
-     Attributes
-     ----------
-     # Embedding & Language Settings
-     embedding_model : str
-         Sentence-Transformer model name or "auto" for automatic selection.
-         Auto-selection considers the language parameter.
-     embedding_batch_size : int
-         Batch size for embedding generation.
-     language : str
-         ISO 639-1 language code (e.g., "en", "de", "zh") or "auto" for detection.
-     multilingual : bool
-         If True, uses multilingual embedding models regardless of detected language.
-     language_detection_sample : int
-         Number of documents to sample for automatic language detection.
-     tokenizer : str
-         Tokenizer to use: "auto", "whitespace", "spacy", "jieba", "fugashi", "konlpy", "pythainlp".
-     custom_stopwords : List[str]
-         Additional stopwords to add to the language-specific list.
-     min_token_length : int
-         Minimum token length to keep.
-     max_token_length : int
-         Maximum token length to keep.
+     Configuration class for TriTopic model.

-     # Graph Construction
-     n_neighbors : int
-         Number of neighbors for kNN graph construction.
-     metric : str
-         Distance metric for similarity calculation.
-     graph_type : str
-         Type of graph: "knn", "mutual_knn", "snn", "hybrid".
-     snn_weight : float
-         Weight of SNN component in hybrid graph (0-1).
-
-     # Multi-View Fusion
-     use_lexical_view : bool
-         Whether to include lexical (TF-IDF/BM25) similarity.
-     use_metadata_view : bool
-         Whether to include metadata-based similarity.
-     semantic_weight : float
-         Weight for semantic (embedding) view.
-     lexical_weight : float
-         Weight for lexical view.
-     metadata_weight : float
-         Weight for metadata view.
-     lexical_method : str
-         Method for lexical similarity: "tfidf", "bm25".
-     ngram_range : tuple
-         N-gram range for lexical features.
-
-     # Clustering
-     resolution : float
-         Resolution parameter for Leiden algorithm.
-     n_consensus_runs : int
-         Number of clustering runs for consensus.
-     min_cluster_size : int
-         Minimum number of documents per topic.
-
-     # Iterative Refinement
-     use_iterative_refinement : bool
-         Whether to use iterative embedding refinement.
-     max_iterations : int
-         Maximum refinement iterations.
-     convergence_threshold : float
-         ARI threshold for convergence detection.
-     refinement_strength : float
-         How strongly to pull embeddings toward centroids (0-1).
-
-     # Keywords
-     n_keywords : int
-         Number of keywords per topic.
-     keyword_method : str
-         Method for keyword extraction: "ctfidf", "bm25", "keybert".
-
-     # Representative Documents
-     n_representative_docs : int
-         Number of representative documents per topic.
-     representative_method : str
-         Method for selection: "centroid", "medoid", "archetype", "diverse", "hybrid".
-     n_archetypes : int
-         Number of archetypes per topic (for archetype/hybrid method).
-     archetype_method : str
-         Algorithm for archetype analysis: "pcha", "convex_hull", "furthest_sum".
+     All hyperparameters can be set here and passed to TriTopic.

-     # Outlier Handling
-     outlier_threshold : float
-         Threshold for outlier detection (0-1).
-     reassign_outliers : bool
-         Whether to try reassigning outliers to nearest topic.
-
-     # Misc
-     random_state : int
-         Random seed for reproducibility.
-     verbose : bool
-         Whether to print progress information.
-     n_jobs : int
-         Number of parallel jobs (-1 for all cores).
+     Example:
+         config = TriTopicConfig(
+             embedding_model="all-mpnet-base-v2",
+             n_neighbors=20,
+             use_iterative_refinement=True
+         )
+         model = TriTopic(config=config)
      """

-     # === Embedding & Language Settings ===
-     embedding_model: str = "auto"
+     # ==========================================================================
+     # Embedding Settings
+     # ==========================================================================
+     embedding_model: str = "all-MiniLM-L6-v2"
+     """Sentence-Transformer model name or path.
+     Options: "all-MiniLM-L6-v2", "all-mpnet-base-v2", "BAAI/bge-base-en-v1.5", etc."""
+
      embedding_batch_size: int = 32
-     language: str = "auto"
-     multilingual: bool = False
-     language_detection_sample: int = 100
-     tokenizer: str = "auto"
-     custom_stopwords: Optional[List[str]] = None
-     min_token_length: int = 2
-     max_token_length: int = 50
-
-     # === Graph Construction ===
+     """Batch size for embedding generation. Reduce if GPU OOM."""
+
+     # ==========================================================================
+     # Graph Construction
+     # ==========================================================================
      n_neighbors: int = 15
+     """Number of neighbors for kNN graph construction."""
+
      metric: str = "cosine"
+     """Distance metric: "cosine", "euclidean", "manhattan"."""
+
      graph_type: Literal["knn", "mutual_knn", "snn", "hybrid"] = "hybrid"
+     """Graph type:
+     - "knn": Standard k-nearest neighbors
+     - "mutual_knn": Only bidirectional connections
+     - "snn": Shared nearest neighbors
+     - "hybrid": Combination of mutual_knn + snn (recommended)
+     """
+
      snn_weight: float = 0.5
+     """Weight of SNN component in hybrid graph (0.0 to 1.0)."""

-     # === Multi-View Fusion ===
+     # ==========================================================================
+     # Multi-View Fusion
+     # ==========================================================================
      use_lexical_view: bool = True
+     """Include lexical (TF-IDF) similarity in graph construction."""
+
      use_metadata_view: bool = False
+     """Include metadata similarity in graph construction."""
+
      semantic_weight: float = 0.5
+     """Weight for semantic (embedding) similarity."""
+
      lexical_weight: float = 0.3
+     """Weight for lexical (TF-IDF) similarity."""
+
      metadata_weight: float = 0.2
-     lexical_method: Literal["tfidf", "bm25"] = "tfidf"
-     ngram_range: tuple = (1, 2)
+     """Weight for metadata similarity."""

-     # === Clustering ===
+     # ==========================================================================
+     # Clustering (Leiden + Consensus)
+     # ==========================================================================
      resolution: float = 1.0
+     """Leiden resolution parameter. Higher = more topics, lower = fewer topics."""
+
      n_consensus_runs: int = 10
+     """Number of clustering runs for consensus clustering."""
+
      min_cluster_size: int = 5
+     """Minimum documents per topic. Smaller clusters become outliers."""

-     # === Iterative Refinement ===
+     # ==========================================================================
+     # Iterative Refinement
+     # ==========================================================================
      use_iterative_refinement: bool = True
+     """Enable iterative embedding refinement based on discovered topics."""
+
      max_iterations: int = 5
+     """Maximum number of refinement iterations."""
+
      convergence_threshold: float = 0.95
-     refinement_strength: float = 0.15
+     """ARI threshold for convergence (0.0 to 1.0)."""
+
+     refinement_strength: float = 0.1
+     """How strongly to pull embeddings toward topic centroids (0.0 to 1.0)."""

-     # === Keywords ===
+     # ==========================================================================
+     # Keyword Extraction
+     # ==========================================================================
      n_keywords: int = 10
-     keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"
+     """Number of keywords to extract per topic."""

-     # === Representative Documents ===
      n_representative_docs: int = 5
-     representative_method: Literal["centroid", "medoid", "archetype", "diverse", "hybrid"] = "hybrid"
-     n_archetypes: int = 4
-     archetype_method: Literal["pcha", "convex_hull", "furthest_sum"] = "furthest_sum"
+     """Number of representative documents per topic."""

-     # === Outlier Handling ===
+     keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"
+     """Keyword extraction method:
+     - "ctfidf": Class-based TF-IDF (fast, good quality)
+     - "bm25": BM25 scoring
+     - "keybert": KeyBERT extraction (slower, embedding-based)
+     """
+
+     # ==========================================================================
+     # Outlier Handling
+     # ==========================================================================
      outlier_threshold: float = 0.1
-     reassign_outliers: bool = False
+     """Threshold for outlier detection (documents below this similarity)."""
+
+     reduce_outliers: bool = False
+     """Whether to reassign outliers to nearest topics."""
+
+     # ==========================================================================
+     # Dimensionality Reduction (for visualization)
+     # ==========================================================================
+     umap_n_neighbors: int = 15
+     """UMAP n_neighbors for visualization."""
+
+     umap_n_components: int = 2
+     """UMAP dimensions for visualization."""

-     # === Misc ===
+     umap_min_dist: float = 0.1
+     """UMAP min_dist parameter."""
+
+     # ==========================================================================
+     # Misc
+     # ==========================================================================
      random_state: Optional[int] = 42
+     """Random seed for reproducibility."""
+
      verbose: bool = True
+     """Print progress information."""
+
      n_jobs: int = -1
+     """Number of parallel jobs (-1 = all cores)."""

      def __post_init__(self):
          """Validate configuration after initialization."""
-         self._validate()
-
-     def _validate(self):
-         """Validate configuration parameters."""
-         # Weights should sum to ~1.0
-         total_weight = self.semantic_weight
-         if self.use_lexical_view:
-             total_weight += self.lexical_weight
-         if self.use_metadata_view:
-             total_weight += self.metadata_weight
-
-         if abs(total_weight - 1.0) > 0.01:
-             # Auto-normalize weights
-             if self.use_lexical_view and self.use_metadata_view:
-                 self.semantic_weight = self.semantic_weight / total_weight
-                 self.lexical_weight = self.lexical_weight / total_weight
-                 self.metadata_weight = self.metadata_weight / total_weight
-             elif self.use_lexical_view:
-                 total = self.semantic_weight + self.lexical_weight
+         # Validate weights
+         if self.use_lexical_view and not self.use_metadata_view:
+             total = self.semantic_weight + self.lexical_weight
+             if abs(total - 1.0) > 0.01:
+                 # Normalize weights
+                 self.semantic_weight = self.semantic_weight / total
+                 self.lexical_weight = self.lexical_weight / total
+         elif self.use_lexical_view and self.use_metadata_view:
+             total = self.semantic_weight + self.lexical_weight + self.metadata_weight
+             if abs(total - 1.0) > 0.01:
+                 # Normalize weights
                  self.semantic_weight = self.semantic_weight / total
                  self.lexical_weight = self.lexical_weight / total
-             else:
-                 self.semantic_weight = 1.0
+                 self.metadata_weight = self.metadata_weight / total

          # Validate ranges
          assert 0 < self.n_neighbors <= 100, "n_neighbors must be between 1 and 100"
-         assert 0 < self.snn_weight <= 1, "snn_weight must be between 0 and 1"
-         assert 0 < self.resolution <= 5, "resolution must be between 0 and 5"
+         assert 0 < self.n_consensus_runs <= 50, "n_consensus_runs must be between 1 and 50"
+         assert 0 <= self.snn_weight <= 1, "snn_weight must be between 0 and 1"
          assert 0 < self.convergence_threshold <= 1, "convergence_threshold must be between 0 and 1"
-         assert self.n_archetypes >= 2, "n_archetypes must be at least 2"
-
-     def get_embedding_model_for_language(self, detected_language: str = None) -> str:
-         """
-         Get the appropriate embedding model based on language settings.
-
-         Parameters
-         ----------
-         detected_language : str, optional
-             The detected language code if language="auto"
-
-         Returns
-         -------
-         str
-             The embedding model name to use
-         """
-         if self.embedding_model != "auto":
-             return self.embedding_model
-
-         lang = detected_language or self.language
-
-         # If multilingual mode is explicitly enabled
-         if self.multilingual:
-             return "paraphrase-multilingual-mpnet-base-v2"
-
-         # Language-specific model selection
-         model_map = {
-             "en": "all-MiniLM-L6-v2",
-             "zh": "BAAI/bge-base-zh-v1.5",
-             "ja": "paraphrase-multilingual-MiniLM-L12-v2",
-             "ko": "paraphrase-multilingual-MiniLM-L12-v2",
-         }
-
-         # Default to multilingual for non-English
-         if lang in model_map:
-             return model_map[lang]
-         elif lang != "en" and lang != "auto":
-             return "paraphrase-multilingual-MiniLM-L12-v2"
-         else:
-             return "all-MiniLM-L6-v2"
+         assert self.min_cluster_size >= 2, "min_cluster_size must be at least 2"

      def to_dict(self) -> dict:
          """Convert config to dictionary."""
          return {
-             k: v for k, v in self.__dict__.items()
-             if not k.startswith('_')
+             'embedding_model': self.embedding_model,
+             'embedding_batch_size': self.embedding_batch_size,
+             'n_neighbors': self.n_neighbors,
+             'metric': self.metric,
+             'graph_type': self.graph_type,
+             'snn_weight': self.snn_weight,
+             'use_lexical_view': self.use_lexical_view,
+             'use_metadata_view': self.use_metadata_view,
+             'semantic_weight': self.semantic_weight,
+             'lexical_weight': self.lexical_weight,
+             'metadata_weight': self.metadata_weight,
+             'resolution': self.resolution,
+             'n_consensus_runs': self.n_consensus_runs,
+             'min_cluster_size': self.min_cluster_size,
+             'use_iterative_refinement': self.use_iterative_refinement,
+             'max_iterations': self.max_iterations,
+             'convergence_threshold': self.convergence_threshold,
+             'refinement_strength': self.refinement_strength,
+             'n_keywords': self.n_keywords,
+             'n_representative_docs': self.n_representative_docs,
+             'keyword_method': self.keyword_method,
+             'outlier_threshold': self.outlier_threshold,
+             'reduce_outliers': self.reduce_outliers,
+             'random_state': self.random_state,
+             'verbose': self.verbose,
+             'n_jobs': self.n_jobs,
          }

      @classmethod
      def from_dict(cls, config_dict: dict) -> "TriTopicConfig":
          """Create config from dictionary."""
-         return cls(**config_dict)
+         return cls(**{k: v for k, v in config_dict.items() if hasattr(cls, k)})


- # Predefined configurations for common use cases
- CONFIGS = {
-     "default": TriTopicConfig(),
+ def get_config(**kwargs) -> TriTopicConfig:
+     """
+     Helper function to create a TriTopicConfig with custom parameters.
+
+     All parameters are optional and default to TriTopicConfig defaults.
+
+     Example:
+         config = get_config(
+             embedding_model="all-mpnet-base-v2",
+             n_neighbors=20,
+             use_iterative_refinement=True
+         )

+     Args:
+         **kwargs: Any TriTopicConfig parameter
+
+     Returns:
+         TriTopicConfig instance
+     """
+     return TriTopicConfig(**kwargs)
+
+
+ # Preset configurations for common use cases
+ PRESETS = {
      "fast": TriTopicConfig(
          embedding_model="all-MiniLM-L6-v2",
          n_neighbors=10,
          n_consensus_runs=5,
          use_iterative_refinement=False,
-         representative_method="centroid",
+         keyword_method="ctfidf",
+     ),
+     "balanced": TriTopicConfig(
+         embedding_model="all-MiniLM-L6-v2",
+         n_neighbors=15,
+         n_consensus_runs=10,
+         use_iterative_refinement=True,
+         max_iterations=3,
+         keyword_method="ctfidf",
      ),
-
      "quality": TriTopicConfig(
+         embedding_model="all-mpnet-base-v2",
+         n_neighbors=20,
+         n_consensus_runs=15,
+         use_iterative_refinement=True,
+         max_iterations=5,
+         keyword_method="ctfidf",
+     ),
+     "research": TriTopicConfig(
          embedding_model="BAAI/bge-base-en-v1.5",
          n_neighbors=20,
          n_consensus_runs=20,
+         use_iterative_refinement=True,
          max_iterations=10,
-         representative_method="hybrid",
-         n_archetypes=5,
-     ),
-
-     "multilingual": TriTopicConfig(
-         multilingual=True,
-         embedding_model="paraphrase-multilingual-mpnet-base-v2",
-         semantic_weight=0.6,
-         lexical_weight=0.2,
-         metadata_weight=0.2,
-     ),
-
-     "multilingual_quality": TriTopicConfig(
-         multilingual=True,
-         embedding_model="BAAI/bge-m3",
-         n_neighbors=20,
-         n_consensus_runs=15,
-         semantic_weight=0.6,
-         lexical_weight=0.2,
-         representative_method="hybrid",
+         convergence_threshold=0.98,
+         keyword_method="ctfidf",
      ),
+ }
+
+
+ def get_preset(name: str) -> TriTopicConfig:
+     """
+     Get a preset configuration.

-     "chinese": TriTopicConfig(
-         language="zh",
-         embedding_model="BAAI/bge-base-zh-v1.5",
-         tokenizer="jieba",
-         ngram_range=(1, 2),
-     ),
+     Available presets:
+     - "fast": Quick results, lower quality
+     - "balanced": Good balance of speed and quality (default)
+     - "quality": Higher quality, slower
+     - "research": Maximum quality for research/publication

-     "german": TriTopicConfig(
-         language="de",
-         embedding_model="paraphrase-multilingual-MiniLM-L12-v2",
-     ),
- }
+     Args:
+         name: Preset name
+
+     Returns:
+         TriTopicConfig instance
+     """
+     if name not in PRESETS:
+         raise ValueError(f"Unknown preset '{name}'. Available: {list(PRESETS.keys())}")
+     return PRESETS[name]
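
For reference, a minimal sketch of how the 1.1.0 configuration surface shown in this diff can be exercised. `TriTopicConfig`, `get_config`, `get_preset`, `PRESETS`, `to_dict`, and `from_dict` all appear in the new `tritopic/config.py` above; everything else (the weight values chosen, the round-trip use case) is illustrative only.

```python
# Sketch against the 1.1.0 config API from the diff above; values are illustrative.
from tritopic.config import TriTopicConfig, get_config, get_preset

# Start from a named preset, or build a config directly with keyword overrides.
config = get_preset("balanced")
custom = get_config(embedding_model="all-mpnet-base-v2", n_neighbors=20)

# __post_init__ normalizes view weights that don't sum to ~1.0:
# with the metadata view off (the default), semantic/lexical are rescaled.
c = TriTopicConfig(semantic_weight=0.6, lexical_weight=0.6)
assert abs(c.semantic_weight + c.lexical_weight - 1.0) < 0.01  # each becomes 0.5

# Serialize to a plain dict and rebuild; from_dict drops unknown keys.
restored = TriTopicConfig.from_dict(config.to_dict())
```

Note that `from_dict` in 1.1.0 filters keys with `hasattr(cls, k)`, so dicts saved from 1.0.0 configs that contain removed fields (e.g., `language`, `tokenizer`) no longer raise on load.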
tritopic-1.1.0.dist-info/METADATA CHANGED
@@ -1,13 +1,9 @@
  Metadata-Version: 2.4
  Name: tritopic
- Version: 1.0.0
+ Version: 1.1.0
  Summary: Tri-Modal Graph Topic Modeling with Iterative Refinement
- Author-email: Roman Egger <roman.egger@example.com>
+ Author-email: Roman Egger <roman.egger@smartvisions.at>
  License: MIT
- Project-URL: Homepage, https://github.com/roman-egger/tritopic
- Project-URL: Documentation, https://tritopic.readthedocs.io
- Project-URL: Repository, https://github.com/roman-egger/tritopic
- Project-URL: Issues, https://github.com/roman-egger/tritopic/issues
  Keywords: topic-modeling,nlp,machine-learning,bertopic,clustering,text-analysis,multilingual
  Classifier: Development Status :: 4 - Beta
  Classifier: Intended Audience :: Developers
tritopic-1.1.0.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  tritopic/__init__.py,sha256=BaHbardg5BW9zykYOtYG1ZM1nGwvfVt7DV7NJ7tp4l8,936
- tritopic/config.py,sha256=bsornL0etlRxQyMa6-Yx7tgXqVR1b8OZPpXM62cibhI,10120
+ tritopic/config.py,sha256=vL47vU5KAYD1iCzH3cRMFUO1w1NSibmjIuAHNsBLu5c,10614
  tritopic/labeling.py,sha256=SJsvOXRl-q8f3qtk1S66FGozTJsW8bwNnAKGkAklmVQ,8883
  tritopic/model.py,sha256=mzptfvqG_Q81OcS6kiYd7u2uU2AKjxpDYKo9u1EfpH4,25015
  tritopic/visualization.py,sha256=MCiIgIoTzFoQ7GG9WjfSZlV2j1BBGzZwxRddmvmh1OY,9841
@@ -14,7 +14,7 @@ tritopic/multilingual/__init__.py,sha256=EagOqVqMDNKX7AfEAQfVgbR92f2vBy1KSM5O88A
  tritopic/multilingual/detection.py,sha256=xeZqNp4l-fRII5s2S4EMzBdJPf3Xgt6e1a3Od2hc2q4,5700
  tritopic/multilingual/stopwords.py,sha256=viMM1pb4VpDEmDpGpx_8sDfumXfrVXKfUULyOZXFFYU,29942
  tritopic/multilingual/tokenizers.py,sha256=seTCzRiUOqO0UbAqA3nn8V8EoVYQ1wiwqcH8lafRCxM,9954
- tritopic-1.0.0.dist-info/METADATA,sha256=kwoHBkE7i3m59h5i5QA10IsfBpJ5_rqI1u_SKXHFjQU,14178
- tritopic-1.0.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
- tritopic-1.0.0.dist-info/top_level.txt,sha256=9PASbqQyi0-wa7E2Hl3Z0u1ae7MwLcfgFliFE1ioFBA,9
- tritopic-1.0.0.dist-info/RECORD,,
+ tritopic-1.1.0.dist-info/METADATA,sha256=nIWD3zUMOQR9efdUFo8zUjM0JVJGgrzgZVDyLbbjJ7I,13922
+ tritopic-1.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ tritopic-1.1.0.dist-info/top_level.txt,sha256=9PASbqQyi0-wa7E2Hl3Z0u1ae7MwLcfgFliFE1ioFBA,9
+ tritopic-1.1.0.dist-info/RECORD,,