tritopic 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

@@ -0,0 +1,275 @@
+ """
+ Multilingual Tokenizers Module
+
+ Provides language-specific tokenization for various languages including CJK.
+ """
+
+ from typing import List, Callable, Optional
+ import re
+ import warnings
+
+
+ class TokenizerFactory:
+     """Factory for creating language-specific tokenizers."""
+
+     _tokenizers = {}
+
+     @classmethod
+     def get_tokenizer(cls, language: str, tokenizer_type: str = "auto") -> Callable[[str], List[str]]:
+         """
+         Get the appropriate tokenizer for a language.
+
+         Parameters
+         ----------
+         language : str
+             ISO 639-1 language code
+         tokenizer_type : str
+             Type of tokenizer: "auto", "whitespace", "spacy", "jieba", "fugashi", "konlpy", "pythainlp"
+
+         Returns
+         -------
+         Callable[[str], List[str]]
+             A tokenizer function that takes text and returns tokens
+         """
+         if tokenizer_type != "auto":
+             return cls._get_specific_tokenizer(tokenizer_type, language)
+
+         # Auto-select based on language
+         if language in ['zh', 'zh-cn', 'zh-tw']:
+             return cls._get_jieba_tokenizer()
+         elif language == 'ja':
+             return cls._get_japanese_tokenizer()
+         elif language == 'ko':
+             return cls._get_korean_tokenizer()
+         elif language == 'th':
+             return cls._get_thai_tokenizer()
+         else:
+             return cls._get_whitespace_tokenizer(language)
+
+     @classmethod
+     def _get_specific_tokenizer(cls, tokenizer_type: str, language: str) -> Callable[[str], List[str]]:
+         """Get a specific tokenizer by name."""
+         tokenizer_map = {
+             "whitespace": lambda: cls._get_whitespace_tokenizer(language),
+             "jieba": cls._get_jieba_tokenizer,
+             "fugashi": cls._get_japanese_tokenizer,
+             "mecab": cls._get_japanese_tokenizer,
+             "konlpy": cls._get_korean_tokenizer,
+             "pythainlp": cls._get_thai_tokenizer,
+             "spacy": lambda: cls._get_spacy_tokenizer(language),
+         }
+
+         if tokenizer_type in tokenizer_map:
+             return tokenizer_map[tokenizer_type]()
+         else:
+             warnings.warn(f"Unknown tokenizer '{tokenizer_type}', falling back to whitespace")
+             return cls._get_whitespace_tokenizer(language)
+
+     @classmethod
+     def _get_whitespace_tokenizer(cls, language: str) -> Callable[[str], List[str]]:
+         """Get a simple whitespace-based tokenizer with language-aware preprocessing."""
+
+         def tokenize(text: str) -> List[str]:
+             # Lowercase
+             text = text.lower()
+             # Remove punctuation but keep apostrophes for contractions
+             text = re.sub(r"[^\w\s'-]", " ", text)
+             # Split on whitespace
+             tokens = text.split()
+             # Remove tokens that are just punctuation or numbers
+             tokens = [t for t in tokens if re.search(r'[a-zA-ZäöüßàâçéèêëîïôûùüÿñæœÄÖÜ]', t)]
+             return tokens
+
+         return tokenize
+
+     @classmethod
+     def _get_jieba_tokenizer(cls) -> Callable[[str], List[str]]:
+         """Get Chinese tokenizer using jieba."""
+         try:
+             import jieba
+             jieba.setLogLevel(jieba.logging.INFO)  # Reduce verbosity
+
+             def tokenize(text: str) -> List[str]:
+                 # Use jieba's cut function for word segmentation
+                 tokens = list(jieba.cut(text, cut_all=False))
+                 # Filter out whitespace and punctuation
+                 tokens = [t.strip() for t in tokens if t.strip() and not re.match(r'^[\s\W]+$', t)]
+                 return tokens
+
+             return tokenize
+
+         except ImportError:
+             warnings.warn(
+                 "jieba not installed. Install with 'pip install jieba' for Chinese tokenization. "
+                 "Falling back to character-based tokenization."
+             )
+             return cls._get_character_tokenizer()
+
+     @classmethod
+     def _get_japanese_tokenizer(cls) -> Callable[[str], List[str]]:
+         """Get Japanese tokenizer using fugashi (MeCab)."""
+         try:
+             import fugashi
+             tagger = fugashi.Tagger()
+
+             def tokenize(text: str) -> List[str]:
+                 tokens = []
+                 for word in tagger(text):
+                     surface = word.surface
+                     # Filter out punctuation and whitespace
+                     if surface.strip() and not re.match(r'^[\s\W]+$', surface):
+                         tokens.append(surface)
+                 return tokens
+
+             return tokenize
+
+         except ImportError:
+             warnings.warn(
+                 "fugashi not installed. Install with 'pip install fugashi unidic-lite' for Japanese tokenization. "
+                 "Falling back to character-based tokenization."
+             )
+             return cls._get_character_tokenizer()
+
+     @classmethod
+     def _get_korean_tokenizer(cls) -> Callable[[str], List[str]]:
+         """Get Korean tokenizer using KoNLPy."""
+         try:
+             from konlpy.tag import Okt
+             okt = Okt()
+
+             def tokenize(text: str) -> List[str]:
+                 # Use morphological analysis
+                 tokens = okt.morphs(text)
+                 # Filter out punctuation
+                 tokens = [t for t in tokens if t.strip() and not re.match(r'^[\s\W]+$', t)]
+                 return tokens
+
+             return tokenize
+
+         except ImportError:
+             warnings.warn(
+                 "konlpy not installed. Install with 'pip install konlpy' for Korean tokenization. "
+                 "Note: KoNLPy may require Java. Falling back to character-based tokenization."
+             )
+             return cls._get_character_tokenizer()
+
+     @classmethod
+     def _get_thai_tokenizer(cls) -> Callable[[str], List[str]]:
+         """Get Thai tokenizer using pythainlp."""
+         try:
+             from pythainlp.tokenize import word_tokenize
+
+             def tokenize(text: str) -> List[str]:
+                 tokens = word_tokenize(text, engine='newmm')
+                 # Filter out whitespace and punctuation
+                 tokens = [t.strip() for t in tokens if t.strip() and not re.match(r'^[\s\W]+$', t)]
+                 return tokens
+
+             return tokenize
+
+         except ImportError:
+             warnings.warn(
+                 "pythainlp not installed. Install with 'pip install pythainlp' for Thai tokenization. "
+                 "Falling back to whitespace tokenization."
+             )
+             return cls._get_whitespace_tokenizer('th')
+
+     @classmethod
+     def _get_spacy_tokenizer(cls, language: str) -> Callable[[str], List[str]]:
+         """Get spaCy tokenizer for specified language."""
+         try:
+             import spacy
+
+             # Map language codes to spaCy model names
+             model_map = {
+                 'en': 'en_core_web_sm',
+                 'de': 'de_core_news_sm',
+                 'fr': 'fr_core_news_sm',
+                 'es': 'es_core_news_sm',
+                 'it': 'it_core_news_sm',
+                 'pt': 'pt_core_news_sm',
+                 'nl': 'nl_core_news_sm',
+                 'pl': 'pl_core_news_sm',
+                 'ru': 'ru_core_news_sm',
+                 'zh': 'zh_core_web_sm',
+                 'ja': 'ja_core_news_sm',
+             }
+
+             model_name = model_map.get(language, 'en_core_web_sm')
+
+             try:
+                 nlp = spacy.load(model_name)
+             except OSError:
+                 warnings.warn(f"spaCy model '{model_name}' not found. Falling back to whitespace tokenization.")
+                 return cls._get_whitespace_tokenizer(language)
+
+             def tokenize(text: str) -> List[str]:
+                 doc = nlp(text)
+                 tokens = [token.text.lower() for token in doc if not token.is_punct and not token.is_space]
+                 return tokens
+
+             return tokenize
+
+         except ImportError:
+             warnings.warn("spaCy not installed. Falling back to whitespace tokenization.")
+             return cls._get_whitespace_tokenizer(language)
+
+     @classmethod
+     def _get_character_tokenizer(cls) -> Callable[[str], List[str]]:
+         """Fallback character-based tokenizer for CJK without proper tokenizer."""
+
+         def tokenize(text: str) -> List[str]:
+             # For CJK, use n-gram characters
+             tokens = []
+             # Remove whitespace and punctuation
+             text = re.sub(r'[\s\W]+', '', text)
+             # Create character bigrams
+             for i in range(len(text) - 1):
+                 tokens.append(text[i:i+2])
+             return tokens
+
+         return tokenize
+
+
+ def tokenize_documents(
+     documents: List[str],
+     language: str,
+     tokenizer_type: str = "auto",
+     min_length: int = 2,
+     max_length: int = 50
+ ) -> List[List[str]]:
+     """
+     Tokenize a list of documents.
+
+     Parameters
+     ----------
+     documents : List[str]
+         List of documents to tokenize
+     language : str
+         ISO 639-1 language code
+     tokenizer_type : str
+         Type of tokenizer to use
+     min_length : int
+         Minimum token length to keep
+     max_length : int
+         Maximum token length to keep
+
+     Returns
+     -------
+     List[List[str]]
+         List of tokenized documents
+     """
+     tokenizer = TokenizerFactory.get_tokenizer(language, tokenizer_type)
+
+     tokenized = []
+     for doc in documents:
+         if not isinstance(doc, str):
+             tokenized.append([])
+             continue
+
+         tokens = tokenizer(doc)
+         # Filter by length
+         tokens = [t for t in tokens if min_length <= len(t) <= max_length]
+         tokenized.append(tokens)
+
+     return tokenized
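
For context, a minimal usage sketch of the tokenizers module added above. It assumes the file is importable as tritopic.tokenizers (the module path is not shown in this diff); the optional CJK back ends may be absent, in which case the code falls back as documented above.

    from tritopic.tokenizers import TokenizerFactory, tokenize_documents  # assumed module path

    # Auto-selection: jieba for Chinese, fugashi for Japanese, KoNLPy for Korean,
    # pythainlp for Thai, whitespace splitting otherwise. Missing back ends fall
    # back to character bigrams (CJK) or whitespace tokenization.
    docs = [
        "Topic models group similar documents together.",
        "主题模型将相似的文档分组。",
    ]
    print(tokenize_documents(docs[:1], language="en"))
    print(tokenize_documents(docs[1:], language="zh"))

    # A specific tokenizer can also be requested by name.
    tokenize = TokenizerFactory.get_tokenizer("de", tokenizer_type="whitespace")
    print(tokenize("Die Katze sitzt auf der Matte."))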
@@ -0,0 +1,371 @@
+ """
+ Visualization functions for TriTopic.
+
+ Provides interactive visualizations using Plotly.
+ """
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, List, Optional, TYPE_CHECKING
+
+ import numpy as np
+
+ if TYPE_CHECKING:
+     from .model import Topic
+
+
+ def create_topic_visualization(
+     embeddings: np.ndarray,
+     labels: np.ndarray,
+     topics: List["Topic"],
+     documents: List[str],
+     method: str = "umap",
+     width: int = 900,
+     height: int = 700,
+     point_size: int = 5,
+     **kwargs
+ ) -> Any:
+     """
+     Create a 2D visualization of documents colored by topic.
+
+     Parameters
+     ----------
+     embeddings : np.ndarray
+         Document embeddings.
+     labels : np.ndarray
+         Topic assignments.
+     topics : List[Topic]
+         Topic objects with labels.
+     documents : List[str]
+         Original documents for hover text.
+     method : str
+         Reduction method: "umap", "tsne", or "pca".
+     width, height : int
+         Figure dimensions.
+     point_size : int
+         Size of scatter points.
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Interactive scatter plot.
+     """
+     try:
+         import plotly.express as px
+         import plotly.graph_objects as go
+     except ImportError:
+         raise ImportError("Plotly required. Install with: pip install plotly")
+
+     # Reduce dimensions
+     coords = _reduce_dimensions(embeddings, method, **kwargs)
+
+     # Create topic labels for legend
+     topic_map = {t.topic_id: t.label or f"Topic {t.topic_id}" for t in topics}
+     topic_names = [topic_map.get(l, "Outlier") for l in labels]
+
+     # Create hover text
+     hover_texts = [doc[:200] + "..." if len(doc) > 200 else doc for doc in documents]
+
+     # Create figure
+     fig = go.Figure()
+
+     # Add scatter for each topic
+     unique_labels = sorted(set(labels))
+     colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel
+
+     for i, topic_id in enumerate(unique_labels):
+         mask = labels == topic_id
+         topic_name = topic_map.get(topic_id, "Outlier")
+         color = colors[i % len(colors)] if topic_id >= 0 else "lightgray"
+
+         fig.add_trace(go.Scatter(
+             x=coords[mask, 0],
+             y=coords[mask, 1],
+             mode="markers",
+             name=topic_name,
+             marker=dict(
+                 size=point_size,
+                 color=color,
+                 opacity=0.7 if topic_id >= 0 else 0.3
+             ),
+             text=[hover_texts[j] for j in np.where(mask)[0]],
+             hovertemplate="<b>%{text}</b><extra></extra>"
+         ))
+
+     fig.update_layout(
+         title="TriTopic Document Map",
+         xaxis_title=f"{method.upper()} 1",
+         yaxis_title=f"{method.upper()} 2",
+         width=width,
+         height=height,
+         template="plotly_white",
+         legend=dict(
+             yanchor="top",
+             y=0.99,
+             xanchor="left",
+             x=1.02
+         )
+     )
+
+     return fig
+
+
+ def create_topic_barchart(
+     topics: List["Topic"],
+     n_keywords: int = 10,
+     width: int = 800,
+     height: Optional[int] = None
+ ) -> Any:
+     """
+     Create horizontal bar charts showing top keywords per topic.
+
+     Parameters
+     ----------
+     topics : List[Topic]
+         Topic objects with keywords.
+     n_keywords : int
+         Number of keywords to show.
+     width : int
+         Figure width.
+     height : int, optional
+         Figure height. Auto-calculated if None.
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Bar chart figure.
+     """
+     try:
+         import plotly.graph_objects as go
+         from plotly.subplots import make_subplots
+     except ImportError:
+         raise ImportError("Plotly required.")
+
+     # Filter out outliers
+     valid_topics = [t for t in topics if t.topic_id >= 0]
+     n_topics = len(valid_topics)
+
+     if n_topics == 0:
+         raise ValueError("No valid topics to visualize")
+
+     # Calculate layout
+     n_cols = min(3, n_topics)
+     n_rows = (n_topics + n_cols - 1) // n_cols
+
+     if height is None:
+         height = n_rows * 250
+
+     # Create subplots
+     subplot_titles = [
+         t.label or f"Topic {t.topic_id}" for t in valid_topics
+     ]
+
+     fig = make_subplots(
+         rows=n_rows,
+         cols=n_cols,
+         subplot_titles=subplot_titles,
+         horizontal_spacing=0.1,
+         vertical_spacing=0.15
+     )
+
+     # Add bar for each topic
+     for i, topic in enumerate(valid_topics):
+         row = i // n_cols + 1
+         col = i % n_cols + 1
+
+         keywords = topic.keywords[:n_keywords][::-1]  # Reverse for horizontal bars
+         scores = topic.keyword_scores[:n_keywords][::-1]
+
+         fig.add_trace(
+             go.Bar(
+                 x=scores,
+                 y=keywords,
+                 orientation='h',
+                 marker_color='steelblue',
+                 showlegend=False
+             ),
+             row=row,
+             col=col
+         )
+
+     fig.update_layout(
+         title="Topic Keywords",
+         width=width,
+         height=height,
+         template="plotly_white"
+     )
+
+     return fig
+
+
+ def create_topic_hierarchy(
+     embeddings: np.ndarray,
+     labels: np.ndarray,
+     topics: List["Topic"],
+     method: str = "ward",
+     **kwargs
+ ) -> Any:
+     """
+     Create a hierarchical clustering dendrogram of topics.
+
+     Parameters
+     ----------
+     embeddings : np.ndarray
+         Document embeddings.
+     labels : np.ndarray
+         Topic assignments.
+     topics : List[Topic]
+         Topic objects.
+     method : str
+         Linkage method for hierarchical clustering.
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Dendrogram figure.
+     """
+     try:
+         import plotly.figure_factory as ff
+     except ImportError:
+         raise ImportError("Plotly required.")
+
+     from scipy.cluster.hierarchy import linkage
+     from scipy.spatial.distance import pdist
+
+     # Compute topic centroids
+     valid_topics = [t for t in topics if t.topic_id >= 0 and t.centroid is not None]
+
+     if len(valid_topics) < 2:
+         raise ValueError("Need at least 2 topics for hierarchy")
+
+     centroids = np.array([t.centroid for t in valid_topics])
+     topic_labels = [t.label or f"Topic {t.topic_id}" for t in valid_topics]
+
+     # Compute linkage
+     distances = pdist(centroids, metric='cosine')
+     Z = linkage(distances, method=method)
+
+     # Create dendrogram
+     fig = ff.create_dendrogram(
+         centroids,
+         orientation='left',
+         labels=topic_labels,
+         linkagefun=lambda x: Z
+     )
+
+     fig.update_layout(
+         title="Topic Hierarchy",
+         width=800,
+         height=max(400, len(valid_topics) * 30),
+         template="plotly_white"
+     )
+
+     return fig
+
+
+ def create_topic_heatmap(
+     topics: List["Topic"],
+     documents: List[str],
+     labels: np.ndarray,
+     n_top_keywords: int = 20
+ ) -> Any:
+     """
+     Create a heatmap showing keyword importance across topics.
+
+     Parameters
+     ----------
+     topics : List[Topic]
+         Topic objects.
+     documents : List[str]
+         Original documents.
+     labels : np.ndarray
+         Topic assignments.
+     n_top_keywords : int
+         Number of keywords to include.
+
+     Returns
+     -------
+     plotly.graph_objects.Figure
+         Heatmap figure.
+     """
+     try:
+         import plotly.graph_objects as go
+     except ImportError:
+         raise ImportError("Plotly required.")
+
+     # Collect all unique keywords
+     valid_topics = [t for t in topics if t.topic_id >= 0]
+     all_keywords = set()
+     for topic in valid_topics:
+         all_keywords.update(topic.keywords[:n_top_keywords])
+
+     all_keywords = sorted(all_keywords)
+
+     # Build matrix
+     matrix = np.zeros((len(valid_topics), len(all_keywords)))
+
+     for i, topic in enumerate(valid_topics):
+         for j, kw in enumerate(all_keywords):
+             if kw in topic.keywords:
+                 idx = topic.keywords.index(kw)
+                 if idx < len(topic.keyword_scores):
+                     matrix[i, j] = topic.keyword_scores[idx]
+
+     # Create heatmap
+     topic_names = [t.label or f"Topic {t.topic_id}" for t in valid_topics]
+
+     fig = go.Figure(data=go.Heatmap(
+         z=matrix,
+         x=all_keywords,
+         y=topic_names,
+         colorscale='Blues'
+     ))
+
+     fig.update_layout(
+         title="Keyword Importance by Topic",
+         xaxis_title="Keywords",
+         yaxis_title="Topics",
+         width=max(800, len(all_keywords) * 20),
+         height=max(400, len(valid_topics) * 30),
+         template="plotly_white"
+     )
+
+     return fig
+
+
+ def _reduce_dimensions(
+     embeddings: np.ndarray,
+     method: str = "umap",
+     n_components: int = 2,
+     **kwargs
+ ) -> np.ndarray:
+     """Reduce embedding dimensions for visualization."""
+
+     if method.lower() == "umap":
+         try:
+             from umap import UMAP
+             reducer = UMAP(
+                 n_components=n_components,
+                 n_neighbors=kwargs.get("n_neighbors", 15),
+                 min_dist=kwargs.get("min_dist", 0.1),
+                 metric=kwargs.get("metric", "cosine"),
+                 random_state=kwargs.get("random_state", 42)
+             )
+             return reducer.fit_transform(embeddings)
+         except ImportError:
+             print("UMAP not available, falling back to PCA")
+             method = "pca"
+
+     if method.lower() == "tsne":
+         from sklearn.manifold import TSNE
+         reducer = TSNE(
+             n_components=n_components,
+             perplexity=kwargs.get("perplexity", 30),
+             random_state=kwargs.get("random_state", 42)
+         )
+         return reducer.fit_transform(embeddings)
+
+     # Default: PCA
+     from sklearn.decomposition import PCA
+     reducer = PCA(n_components=n_components)
+     return reducer.fit_transform(embeddings)
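
Similarly, a hedged sketch of how the visualization helpers above might be called. The SimpleNamespace objects are hypothetical stand-ins providing only the Topic attributes these functions read (topic_id, label, keywords, keyword_scores); real Topic objects come from a fitted TriTopic model, and create_topic_hierarchy would additionally need centroid vectors. The module path tritopic.visualization is likewise an assumption.

    import numpy as np
    from types import SimpleNamespace
    from tritopic.visualization import create_topic_visualization, create_topic_barchart  # assumed path

    # Hypothetical stand-ins for Topic objects from a fitted model.
    topics = [
        SimpleNamespace(topic_id=0, label="sports", keywords=["game", "team", "score"],
                        keyword_scores=[0.9, 0.8, 0.7]),
        SimpleNamespace(topic_id=1, label="finance", keywords=["market", "stock", "bank"],
                        keyword_scores=[0.9, 0.7, 0.6]),
    ]
    embeddings = np.random.rand(6, 384)     # toy document embeddings
    labels = np.array([0, 0, 1, 1, 0, -1])  # -1 marks outlier documents
    documents = [f"document {i}" for i in range(6)]

    # "pca" avoids the optional UMAP dependency; plotly is still required.
    fig = create_topic_visualization(embeddings, labels, topics, documents, method="pca")
    fig.show()

    create_topic_barchart(topics, n_keywords=3).show()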