tritopic-1.1.0-py3-none-any.whl → tritopic-1.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tritopic might be problematic; see the registry's advisory page for details.

tritopic/config.py CHANGED
@@ -1,289 +1,305 @@
1
1
  """
2
2
  TriTopic Configuration Module
3
3
 
4
- This module provides configuration classes and utilities for TriTopic.
4
+ Defines all configuration parameters for the TriTopic model.
5
5
  """
6
6
 
7
7
  from dataclasses import dataclass, field
8
- from typing import Optional, List, Literal
8
+ from typing import Optional, List, Literal, Union
9
9
 
10
10
 
11
11
  @dataclass
12
12
  class TriTopicConfig:
13
13
  """
14
- Configuration class for TriTopic model.
14
+ Configuration for TriTopic model.
15
+
16
+ Attributes
17
+ ----------
18
+ # Embedding & Language Settings
19
+ embedding_model : str
20
+ Sentence-Transformer model name or "auto" for automatic selection.
21
+ Auto-selection considers the language parameter.
22
+ embedding_batch_size : int
23
+ Batch size for embedding generation.
24
+ language : str
25
+ ISO 639-1 language code (e.g., "en", "de", "zh") or "auto" for detection.
26
+ multilingual : bool
27
+ If True, uses multilingual embedding models regardless of detected language.
28
+ language_detection_sample : int
29
+ Number of documents to sample for automatic language detection.
30
+ tokenizer : str
31
+ Tokenizer to use: "auto", "whitespace", "spacy", "jieba", "fugashi", "konlpy", "pythainlp".
32
+ custom_stopwords : List[str]
33
+ Additional stopwords to add to the language-specific list.
34
+ min_token_length : int
35
+ Minimum token length to keep.
36
+ max_token_length : int
37
+ Maximum token length to keep.
15
38
 
16
- All hyperparameters can be set here and passed to TriTopic.
39
+ # Graph Construction
40
+ n_neighbors : int
41
+ Number of neighbors for kNN graph construction.
42
+ metric : str
43
+ Distance metric for similarity calculation.
44
+ graph_type : str
45
+ Type of graph: "knn", "mutual_knn", "snn", "hybrid".
46
+ snn_weight : float
47
+ Weight of SNN component in hybrid graph (0-1).
17
48
 
18
- Example:
19
- config = TriTopicConfig(
20
- embedding_model="all-mpnet-base-v2",
21
- n_neighbors=20,
22
- use_iterative_refinement=True
23
- )
24
- model = TriTopic(config=config)
25
- """
49
+ # Multi-View Fusion
50
+ use_lexical_view : bool
51
+ Whether to include lexical (TF-IDF/BM25) similarity.
52
+ use_metadata_view : bool
53
+ Whether to include metadata-based similarity.
54
+ semantic_weight : float
55
+ Weight for semantic (embedding) view.
56
+ lexical_weight : float
57
+ Weight for lexical view.
58
+ metadata_weight : float
59
+ Weight for metadata view.
60
+ lexical_method : str
61
+ Method for lexical similarity: "tfidf", "bm25".
62
+ ngram_range : tuple
63
+ N-gram range for lexical features.
64
+
65
+ # Clustering
66
+ resolution : float
67
+ Resolution parameter for Leiden algorithm.
68
+ n_consensus_runs : int
69
+ Number of clustering runs for consensus.
70
+ min_cluster_size : int
71
+ Minimum number of documents per topic.
26
72
 
27
- # ==========================================================================
28
- # Embedding Settings
29
- # ==========================================================================
30
- embedding_model: str = "all-MiniLM-L6-v2"
31
- """Sentence-Transformer model name or path.
32
- Options: "all-MiniLM-L6-v2", "all-mpnet-base-v2", "BAAI/bge-base-en-v1.5", etc."""
73
+ # Iterative Refinement
74
+ use_iterative_refinement : bool
75
+ Whether to use iterative embedding refinement.
76
+ max_iterations : int
77
+ Maximum refinement iterations.
78
+ convergence_threshold : float
79
+ ARI threshold for convergence detection.
80
+ refinement_strength : float
81
+ How strongly to pull embeddings toward centroids (0-1).
82
+
83
+ # Keywords
84
+ n_keywords : int
85
+ Number of keywords per topic.
86
+ keyword_method : str
87
+ Method for keyword extraction: "ctfidf", "bm25", "keybert".
88
+
89
+ # Representative Documents
90
+ n_representative_docs : int
91
+ Number of representative documents per topic.
92
+ representative_method : str
93
+ Method for selection: "centroid", "medoid", "archetype", "diverse", "hybrid".
94
+ n_archetypes : int
95
+ Number of archetypes per topic (for archetype/hybrid method).
96
+ archetype_method : str
97
+ Algorithm for archetype analysis: "pcha", "convex_hull", "furthest_sum".
33
98
 
34
- embedding_batch_size: int = 32
35
- """Batch size for embedding generation. Reduce if GPU OOM."""
99
+ # Outlier Handling
100
+ outlier_threshold : float
101
+ Threshold for outlier detection (0-1).
102
+ reassign_outliers : bool
103
+ Whether to try reassigning outliers to nearest topic.
36
104
 
37
- # ==========================================================================
38
- # Graph Construction
39
- # ==========================================================================
40
- n_neighbors: int = 15
41
- """Number of neighbors for kNN graph construction."""
105
+ # Misc
106
+ random_state : int
107
+ Random seed for reproducibility.
108
+ verbose : bool
109
+ Whether to print progress information.
110
+ n_jobs : int
111
+ Number of parallel jobs (-1 for all cores).
112
+ """
42
113
 
114
+ # === Embedding & Language Settings ===
115
+ embedding_model: str = "auto"
116
+ embedding_batch_size: int = 32
117
+ language: str = "auto"
118
+ multilingual: bool = False
119
+ language_detection_sample: int = 100
120
+ tokenizer: str = "auto"
121
+ custom_stopwords: Optional[List[str]] = None
122
+ min_token_length: int = 2
123
+ max_token_length: int = 50
124
+
125
+ # === Graph Construction ===
126
+ n_neighbors: int = 15
43
127
  metric: str = "cosine"
44
- """Distance metric: "cosine", "euclidean", "manhattan"."""
45
-
46
128
  graph_type: Literal["knn", "mutual_knn", "snn", "hybrid"] = "hybrid"
47
- """Graph type:
48
- - "knn": Standard k-nearest neighbors
49
- - "mutual_knn": Only bidirectional connections
50
- - "snn": Shared nearest neighbors
51
- - "hybrid": Combination of mutual_knn + snn (recommended)
52
- """
53
-
54
129
  snn_weight: float = 0.5
55
- """Weight of SNN component in hybrid graph (0.0 to 1.0)."""
56
130
 
57
- # ==========================================================================
58
- # Multi-View Fusion
59
- # ==========================================================================
131
+ # === Multi-View Fusion ===
60
132
  use_lexical_view: bool = True
61
- """Include lexical (TF-IDF) similarity in graph construction."""
62
-
63
133
  use_metadata_view: bool = False
64
- """Include metadata similarity in graph construction."""
65
-
66
134
  semantic_weight: float = 0.5
67
- """Weight for semantic (embedding) similarity."""
68
-
69
135
  lexical_weight: float = 0.3
70
- """Weight for lexical (TF-IDF) similarity."""
71
-
72
136
  metadata_weight: float = 0.2
73
- """Weight for metadata similarity."""
137
+ lexical_method: Literal["tfidf", "bm25"] = "tfidf"
138
+ ngram_range: tuple = (1, 2)
74
139
 
75
- # ==========================================================================
76
- # Clustering (Leiden + Consensus)
77
- # ==========================================================================
140
+ # === Clustering ===
78
141
  resolution: float = 1.0
79
- """Leiden resolution parameter. Higher = more topics, lower = fewer topics."""
80
-
81
142
  n_consensus_runs: int = 10
82
- """Number of clustering runs for consensus clustering."""
83
-
84
143
  min_cluster_size: int = 5
85
- """Minimum documents per topic. Smaller clusters become outliers."""
86
144
 
87
- # ==========================================================================
88
- # Iterative Refinement
89
- # ==========================================================================
145
+ # === Iterative Refinement ===
90
146
  use_iterative_refinement: bool = True
91
- """Enable iterative embedding refinement based on discovered topics."""
92
-
93
147
  max_iterations: int = 5
94
- """Maximum number of refinement iterations."""
95
-
96
148
  convergence_threshold: float = 0.95
97
- """ARI threshold for convergence (0.0 to 1.0)."""
98
-
99
- refinement_strength: float = 0.1
100
- """How strongly to pull embeddings toward topic centroids (0.0 to 1.0)."""
149
+ refinement_strength: float = 0.15
101
150
 
102
- # ==========================================================================
103
- # Keyword Extraction
104
- # ==========================================================================
151
+ # === Keywords ===
105
152
  n_keywords: int = 10
106
- """Number of keywords to extract per topic."""
153
+ keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"
107
154
 
155
+ # === Representative Documents ===
108
156
  n_representative_docs: int = 5
109
- """Number of representative documents per topic."""
157
+ representative_method: Literal["centroid", "medoid", "archetype", "diverse", "hybrid"] = "hybrid"
158
+ n_archetypes: int = 4
159
+ archetype_method: Literal["pcha", "convex_hull", "furthest_sum"] = "furthest_sum"
110
160
 
111
- keyword_method: Literal["ctfidf", "bm25", "keybert"] = "ctfidf"
112
- """Keyword extraction method:
113
- - "ctfidf": Class-based TF-IDF (fast, good quality)
114
- - "bm25": BM25 scoring
115
- - "keybert": KeyBERT extraction (slower, embedding-based)
116
- """
117
-
118
- # ==========================================================================
119
- # Outlier Handling
120
- # ==========================================================================
161
+ # === Outlier Handling ===
121
162
  outlier_threshold: float = 0.1
122
- """Threshold for outlier detection (documents below this similarity)."""
123
-
124
- reduce_outliers: bool = False
125
- """Whether to reassign outliers to nearest topics."""
126
-
127
- # ==========================================================================
128
- # Dimensionality Reduction (for visualization)
129
- # ==========================================================================
130
- umap_n_neighbors: int = 15
131
- """UMAP n_neighbors for visualization."""
132
-
133
- umap_n_components: int = 2
134
- """UMAP dimensions for visualization."""
163
+ reassign_outliers: bool = False
135
164
 
136
- umap_min_dist: float = 0.1
137
- """UMAP min_dist parameter."""
138
-
139
- # ==========================================================================
140
- # Misc
141
- # ==========================================================================
165
+ # === Misc ===
142
166
  random_state: Optional[int] = 42
143
- """Random seed for reproducibility."""
144
-
145
167
  verbose: bool = True
146
- """Print progress information."""
147
-
148
168
  n_jobs: int = -1
149
- """Number of parallel jobs (-1 = all cores)."""
150
169
 
151
170
  def __post_init__(self):
152
171
  """Validate configuration after initialization."""
153
- # Validate weights
154
- if self.use_lexical_view and not self.use_metadata_view:
155
- total = self.semantic_weight + self.lexical_weight
156
- if abs(total - 1.0) > 0.01:
157
- # Normalize weights
158
- self.semantic_weight = self.semantic_weight / total
159
- self.lexical_weight = self.lexical_weight / total
160
- elif self.use_lexical_view and self.use_metadata_view:
161
- total = self.semantic_weight + self.lexical_weight + self.metadata_weight
162
- if abs(total - 1.0) > 0.01:
163
- # Normalize weights
172
+ self._validate()
173
+
174
+ def _validate(self):
175
+ """Validate configuration parameters."""
176
+ # Weights should sum to ~1.0
177
+ total_weight = self.semantic_weight
178
+ if self.use_lexical_view:
179
+ total_weight += self.lexical_weight
180
+ if self.use_metadata_view:
181
+ total_weight += self.metadata_weight
182
+
183
+ if abs(total_weight - 1.0) > 0.01:
184
+ # Auto-normalize weights
185
+ if self.use_lexical_view and self.use_metadata_view:
186
+ self.semantic_weight = self.semantic_weight / total_weight
187
+ self.lexical_weight = self.lexical_weight / total_weight
188
+ self.metadata_weight = self.metadata_weight / total_weight
189
+ elif self.use_lexical_view:
190
+ total = self.semantic_weight + self.lexical_weight
164
191
  self.semantic_weight = self.semantic_weight / total
165
192
  self.lexical_weight = self.lexical_weight / total
166
- self.metadata_weight = self.metadata_weight / total
193
+ else:
194
+ self.semantic_weight = 1.0
167
195
 
168
196
  # Validate ranges
169
197
  assert 0 < self.n_neighbors <= 100, "n_neighbors must be between 1 and 100"
170
- assert 0 < self.n_consensus_runs <= 50, "n_consensus_runs must be between 1 and 50"
171
- assert 0 <= self.snn_weight <= 1, "snn_weight must be between 0 and 1"
198
+ assert 0 < self.snn_weight <= 1, "snn_weight must be between 0 and 1"
199
+ assert 0 < self.resolution <= 5, "resolution must be between 0 and 5"
172
200
  assert 0 < self.convergence_threshold <= 1, "convergence_threshold must be between 0 and 1"
173
- assert self.min_cluster_size >= 2, "min_cluster_size must be at least 2"
201
+ assert self.n_archetypes >= 2, "n_archetypes must be at least 2"
202
+
203
+ def get_embedding_model_for_language(self, detected_language: str = None) -> str:
204
+ """
205
+ Get the appropriate embedding model based on language settings.
206
+
207
+ Parameters
208
+ ----------
209
+ detected_language : str, optional
210
+ The detected language code if language="auto"
211
+
212
+ Returns
213
+ -------
214
+ str
215
+ The embedding model name to use
216
+ """
217
+ if self.embedding_model != "auto":
218
+ return self.embedding_model
219
+
220
+ lang = detected_language or self.language
221
+
222
+ # If multilingual mode is explicitly enabled
223
+ if self.multilingual:
224
+ return "paraphrase-multilingual-mpnet-base-v2"
225
+
226
+ # Language-specific model selection
227
+ model_map = {
228
+ "en": "all-MiniLM-L6-v2",
229
+ "zh": "BAAI/bge-base-zh-v1.5",
230
+ "ja": "paraphrase-multilingual-MiniLM-L12-v2",
231
+ "ko": "paraphrase-multilingual-MiniLM-L12-v2",
232
+ }
233
+
234
+ # Default to multilingual for non-English
235
+ if lang in model_map:
236
+ return model_map[lang]
237
+ elif lang != "en" and lang != "auto":
238
+ return "paraphrase-multilingual-MiniLM-L12-v2"
239
+ else:
240
+ return "all-MiniLM-L6-v2"
174
241
 
175
242
  def to_dict(self) -> dict:
176
243
  """Convert config to dictionary."""
177
244
  return {
178
- 'embedding_model': self.embedding_model,
179
- 'embedding_batch_size': self.embedding_batch_size,
180
- 'n_neighbors': self.n_neighbors,
181
- 'metric': self.metric,
182
- 'graph_type': self.graph_type,
183
- 'snn_weight': self.snn_weight,
184
- 'use_lexical_view': self.use_lexical_view,
185
- 'use_metadata_view': self.use_metadata_view,
186
- 'semantic_weight': self.semantic_weight,
187
- 'lexical_weight': self.lexical_weight,
188
- 'metadata_weight': self.metadata_weight,
189
- 'resolution': self.resolution,
190
- 'n_consensus_runs': self.n_consensus_runs,
191
- 'min_cluster_size': self.min_cluster_size,
192
- 'use_iterative_refinement': self.use_iterative_refinement,
193
- 'max_iterations': self.max_iterations,
194
- 'convergence_threshold': self.convergence_threshold,
195
- 'refinement_strength': self.refinement_strength,
196
- 'n_keywords': self.n_keywords,
197
- 'n_representative_docs': self.n_representative_docs,
198
- 'keyword_method': self.keyword_method,
199
- 'outlier_threshold': self.outlier_threshold,
200
- 'reduce_outliers': self.reduce_outliers,
201
- 'random_state': self.random_state,
202
- 'verbose': self.verbose,
203
- 'n_jobs': self.n_jobs,
245
+ k: v for k, v in self.__dict__.items()
246
+ if not k.startswith('_')
204
247
  }
205
248
 
206
249
  @classmethod
207
250
  def from_dict(cls, config_dict: dict) -> "TriTopicConfig":
208
251
  """Create config from dictionary."""
209
- return cls(**{k: v for k, v in config_dict.items() if hasattr(cls, k)})
252
+ return cls(**config_dict)
210
253
 
211
254
 
212
- def get_config(**kwargs) -> TriTopicConfig:
213
- """
214
- Helper function to create a TriTopicConfig with custom parameters.
215
-
216
- All parameters are optional and default to TriTopicConfig defaults.
217
-
218
- Example:
219
- config = get_config(
220
- embedding_model="all-mpnet-base-v2",
221
- n_neighbors=20,
222
- use_iterative_refinement=True
223
- )
255
+ # Predefined configurations for common use cases
256
+ CONFIGS = {
257
+ "default": TriTopicConfig(),
224
258
 
225
- Args:
226
- **kwargs: Any TriTopicConfig parameter
227
-
228
- Returns:
229
- TriTopicConfig instance
230
- """
231
- return TriTopicConfig(**kwargs)
232
-
233
-
234
- # Preset configurations for common use cases
235
- PRESETS = {
236
259
  "fast": TriTopicConfig(
237
260
  embedding_model="all-MiniLM-L6-v2",
238
261
  n_neighbors=10,
239
262
  n_consensus_runs=5,
240
263
  use_iterative_refinement=False,
241
- keyword_method="ctfidf",
242
- ),
243
- "balanced": TriTopicConfig(
244
- embedding_model="all-MiniLM-L6-v2",
245
- n_neighbors=15,
246
- n_consensus_runs=10,
247
- use_iterative_refinement=True,
248
- max_iterations=3,
249
- keyword_method="ctfidf",
264
+ representative_method="centroid",
250
265
  ),
266
+
251
267
  "quality": TriTopicConfig(
252
- embedding_model="all-mpnet-base-v2",
253
- n_neighbors=20,
254
- n_consensus_runs=15,
255
- use_iterative_refinement=True,
256
- max_iterations=5,
257
- keyword_method="ctfidf",
258
- ),
259
- "research": TriTopicConfig(
260
268
  embedding_model="BAAI/bge-base-en-v1.5",
261
269
  n_neighbors=20,
262
270
  n_consensus_runs=20,
263
- use_iterative_refinement=True,
264
271
  max_iterations=10,
265
- convergence_threshold=0.98,
266
- keyword_method="ctfidf",
272
+ representative_method="hybrid",
273
+ n_archetypes=5,
267
274
  ),
268
- }
269
-
270
-
271
- def get_preset(name: str) -> TriTopicConfig:
272
- """
273
- Get a preset configuration.
274
275
 
275
- Available presets:
276
- - "fast": Quick results, lower quality
277
- - "balanced": Good balance of speed and quality (default)
278
- - "quality": Higher quality, slower
279
- - "research": Maximum quality for research/publication
276
+ "multilingual": TriTopicConfig(
277
+ multilingual=True,
278
+ embedding_model="paraphrase-multilingual-mpnet-base-v2",
279
+ semantic_weight=0.6,
280
+ lexical_weight=0.2,
281
+ metadata_weight=0.2,
282
+ ),
280
283
 
281
- Args:
282
- name: Preset name
283
-
284
- Returns:
285
- TriTopicConfig instance
286
- """
287
- if name not in PRESETS:
288
- raise ValueError(f"Unknown preset '{name}'. Available: {list(PRESETS.keys())}")
289
- return PRESETS[name]
284
+ "multilingual_quality": TriTopicConfig(
285
+ multilingual=True,
286
+ embedding_model="BAAI/bge-m3",
287
+ n_neighbors=20,
288
+ n_consensus_runs=15,
289
+ semantic_weight=0.6,
290
+ lexical_weight=0.2,
291
+ representative_method="hybrid",
292
+ ),
293
+
294
+ "chinese": TriTopicConfig(
295
+ language="zh",
296
+ embedding_model="BAAI/bge-base-zh-v1.5",
297
+ tokenizer="jieba",
298
+ ngram_range=(1, 2),
299
+ ),
300
+
301
+ "german": TriTopicConfig(
302
+ language="de",
303
+ embedding_model="paraphrase-multilingual-MiniLM-L12-v2",
304
+ ),
305
+ }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tritopic
3
- Version: 1.1.0
3
+ Version: 1.1.2
4
4
  Summary: Tri-Modal Graph Topic Modeling with Iterative Refinement
5
5
  Author-email: Roman Egger <roman.egger@smartvisions.at>
6
6
  License: MIT
@@ -1,5 +1,5 @@
1
1
  tritopic/__init__.py,sha256=BaHbardg5BW9zykYOtYG1ZM1nGwvfVt7DV7NJ7tp4l8,936
2
- tritopic/config.py,sha256=vL47vU5KAYD1iCzH3cRMFUO1w1NSibmjIuAHNsBLu5c,10614
2
+ tritopic/config.py,sha256=bsornL0etlRxQyMa6-Yx7tgXqVR1b8OZPpXM62cibhI,10120
3
3
  tritopic/labeling.py,sha256=SJsvOXRl-q8f3qtk1S66FGozTJsW8bwNnAKGkAklmVQ,8883
4
4
  tritopic/model.py,sha256=mzptfvqG_Q81OcS6kiYd7u2uU2AKjxpDYKo9u1EfpH4,25015
5
5
  tritopic/visualization.py,sha256=MCiIgIoTzFoQ7GG9WjfSZlV2j1BBGzZwxRddmvmh1OY,9841
@@ -14,7 +14,7 @@ tritopic/multilingual/__init__.py,sha256=EagOqVqMDNKX7AfEAQfVgbR92f2vBy1KSM5O88A
14
14
  tritopic/multilingual/detection.py,sha256=xeZqNp4l-fRII5s2S4EMzBdJPf3Xgt6e1a3Od2hc2q4,5700
15
15
  tritopic/multilingual/stopwords.py,sha256=viMM1pb4VpDEmDpGpx_8sDfumXfrVXKfUULyOZXFFYU,29942
16
16
  tritopic/multilingual/tokenizers.py,sha256=seTCzRiUOqO0UbAqA3nn8V8EoVYQ1wiwqcH8lafRCxM,9954
17
- tritopic-1.1.0.dist-info/METADATA,sha256=nIWD3zUMOQR9efdUFo8zUjM0JVJGgrzgZVDyLbbjJ7I,13922
18
- tritopic-1.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
19
- tritopic-1.1.0.dist-info/top_level.txt,sha256=9PASbqQyi0-wa7E2Hl3Z0u1ae7MwLcfgFliFE1ioFBA,9
20
- tritopic-1.1.0.dist-info/RECORD,,
17
+ tritopic-1.1.2.dist-info/METADATA,sha256=730Y7lueQ4nGWQeu2187uEpS03aaLLXcLLlZTODi668,13922
18
+ tritopic-1.1.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
19
+ tritopic-1.1.2.dist-info/top_level.txt,sha256=9PASbqQyi0-wa7E2Hl3Z0u1ae7MwLcfgFliFE1ioFBA,9
20
+ tritopic-1.1.2.dist-info/RECORD,,