tritopic 0.1.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tritopic might be problematic; see the package's page on its registry for more details.
- tritopic/__init__.py +22 -32
- tritopic/config.py +289 -0
- tritopic/core/__init__.py +0 -17
- tritopic/core/clustering.py +229 -243
- tritopic/core/embeddings.py +151 -157
- tritopic/core/graph.py +435 -0
- tritopic/core/keywords.py +213 -249
- tritopic/core/refinement.py +231 -0
- tritopic/core/representatives.py +560 -0
- tritopic/labeling.py +313 -0
- tritopic/model.py +718 -0
- tritopic/multilingual/__init__.py +38 -0
- tritopic/multilingual/detection.py +208 -0
- tritopic/multilingual/stopwords.py +467 -0
- tritopic/multilingual/tokenizers.py +275 -0
- tritopic/visualization.py +371 -0
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/METADATA +91 -51
- tritopic-1.1.0.dist-info/RECORD +20 -0
- tritopic/core/graph_builder.py +0 -493
- tritopic/core/model.py +0 -810
- tritopic/labeling/__init__.py +0 -5
- tritopic/labeling/llm_labeler.py +0 -279
- tritopic/utils/__init__.py +0 -13
- tritopic/utils/metrics.py +0 -254
- tritopic/visualization/__init__.py +0 -5
- tritopic/visualization/plotter.py +0 -523
- tritopic-0.1.0.dist-info/RECORD +0 -18
- tritopic-0.1.0.dist-info/licenses/LICENSE +0 -21
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/WHEEL +0 -0
- {tritopic-0.1.0.dist-info → tritopic-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Multilingual Tokenizers Module
|
|
3
|
+
|
|
4
|
+
Provides language-specific tokenization for various languages including CJK.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Callable, Optional
|
|
8
|
+
import re
|
|
9
|
+
import warnings
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TokenizerFactory:
    """Factory for creating language-specific tokenizers.

    Every tokenizer is a callable ``tokenize(text) -> List[str]``. Tokenizers
    returned by :meth:`get_tokenizer` are cached per
    ``(tokenizer_type, language)``. Expensive backends (a MeCab tagger, a KoNLPy
    analyzer, a spaCy pipeline) are therefore built once, and fallback
    warnings are emitted once per combination.
    """

    # Cache of built tokenizers keyed by (lower-cased tokenizer_type, primary language subtag).
    _tokenizers = {}

    # A token made up only of whitespace and/or non-word characters (i.e. punctuation).
    _NON_WORD_RE = re.compile(r'^[\s\W]+$')

    # Characters the whitespace tokenizer replaces with spaces. Apostrophes and
    # hyphens are kept so contractions ("don't") and compounds ("well-known") survive.
    _PUNCT_RE = re.compile(r"[^\w\s'-]")

    # Runs of whitespace/non-word characters removed by the character-bigram fallback.
    _STRIP_RE = re.compile(r'[\s\W]+')

    @staticmethod
    def _normalize_language(language: Optional[str]) -> str:
        """Return the lower-cased primary subtag of a language tag.

        ``'zh-CN'``, ``'zh_tw'`` and ``'ZH'`` all become ``'zh'``; ``None`` becomes ``''``.
        """
        return (language or "").strip().lower().replace("_", "-").split("-")[0]

    @classmethod
    def get_tokenizer(cls, language: str, tokenizer_type: str = "auto") -> Callable[[str], List[str]]:
        """
        Get the appropriate tokenizer for a language.

        Parameters
        ----------
        language : str
            ISO 639-1 language code. Region/script suffixes such as ``'zh-CN'``
            or ``'ja_JP'`` are accepted; only the primary subtag is used.
        tokenizer_type : str
            Type of tokenizer: "auto", "whitespace", "spacy", "jieba", "fugashi",
            "mecab", "konlpy", "pythainlp" (case-insensitive).

        Returns
        -------
        Callable[[str], List[str]]
            A tokenizer function that takes text and returns tokens. Repeated
            calls with equivalent arguments return the same cached function.
        """
        lang = cls._normalize_language(language)
        kind = (tokenizer_type or "auto").lower()

        key = (kind, lang)
        cached = cls._tokenizers.get(key)
        if cached is not None:
            return cached

        if kind != "auto":
            tokenizer = cls._get_specific_tokenizer(kind, lang)
        # Auto-select based on language
        elif lang == 'zh':
            tokenizer = cls._get_jieba_tokenizer()
        elif lang == 'ja':
            tokenizer = cls._get_japanese_tokenizer()
        elif lang == 'ko':
            tokenizer = cls._get_korean_tokenizer()
        elif lang == 'th':
            tokenizer = cls._get_thai_tokenizer()
        else:
            tokenizer = cls._get_whitespace_tokenizer(lang)

        cls._tokenizers[key] = tokenizer
        return tokenizer

    @classmethod
    def _get_specific_tokenizer(cls, tokenizer_type: str, language: str) -> Callable[[str], List[str]]:
        """Get a specific tokenizer by name; unknown names warn and fall back to whitespace."""
        tokenizer_map = {
            "whitespace": lambda: cls._get_whitespace_tokenizer(language),
            "jieba": cls._get_jieba_tokenizer,
            "fugashi": cls._get_japanese_tokenizer,
            "mecab": cls._get_japanese_tokenizer,
            "konlpy": cls._get_korean_tokenizer,
            "pythainlp": cls._get_thai_tokenizer,
            "spacy": lambda: cls._get_spacy_tokenizer(language),
        }

        if tokenizer_type in tokenizer_map:
            return tokenizer_map[tokenizer_type]()
        warnings.warn(f"Unknown tokenizer '{tokenizer_type}', falling back to whitespace")
        return cls._get_whitespace_tokenizer(language)

    @classmethod
    def _get_whitespace_tokenizer(cls, language: str) -> Callable[[str], List[str]]:
        """Get a simple whitespace-based tokenizer.

        The text is lower-cased and punctuation/symbols are replaced by spaces.
        Apostrophes, hyphens and combining marks are kept. The text is then
        split on whitespace. Tokens without a single alphabetic character
        (pure numbers, stray "'" or "-") are dropped. "Alphabetic" is
        Unicode-aware, so Cyrillic, Greek, Thai, Devanagari, etc. are kept.
        ``language`` is currently unused.
        """
        import unicodedata

        punct = cls._PUNCT_RE

        def _blank_unless_mark(match) -> str:
            # Python's \w does not match combining marks (Unicode category M*), so
            # the pattern also hits Thai vowel signs, Devanagari matras, etc.
            # Keep those inside their word; blank real punctuation and symbols.
            ch = match.group()
            return ch if unicodedata.category(ch).startswith("M") else " "

        def tokenize(text: str) -> List[str]:
            tokens = punct.sub(_blank_unless_mark, text.lower()).split()
            return [t for t in tokens if any(ch.isalpha() for ch in t)]

        return tokenize

    @classmethod
    def _get_jieba_tokenizer(cls) -> Callable[[str], List[str]]:
        """Get Chinese tokenizer using jieba; falls back to character bigrams if jieba is missing."""
        try:
            import jieba
        except ImportError:
            warnings.warn(
                "jieba not installed. Install with 'pip install jieba' for Chinese tokenization. "
                "Falling back to character-based tokenization."
            )
            return cls._get_character_tokenizer()

        import logging

        jieba.setLogLevel(logging.INFO)  # Reduce verbosity: drop jieba messages below INFO
        non_word = cls._NON_WORD_RE

        def tokenize(text: str) -> List[str]:
            # cut_all=False selects jieba's accurate (non-full) segmentation mode.
            tokens = jieba.cut(text, cut_all=False)
            # Filter out whitespace and punctuation
            return [t.strip() for t in tokens if t.strip() and not non_word.match(t)]

        return tokenize

    @classmethod
    def _get_japanese_tokenizer(cls) -> Callable[[str], List[str]]:
        """Get Japanese tokenizer using fugashi (MeCab).

        Falls back to character bigrams if fugashi is missing or its tagger
        cannot be constructed (for example when no MeCab dictionary is installed).
        """
        try:
            import fugashi
        except ImportError:
            warnings.warn(
                "fugashi not installed. Install with 'pip install fugashi unidic-lite' for Japanese tokenization. "
                "Falling back to character-based tokenization."
            )
            return cls._get_character_tokenizer()

        try:
            tagger = fugashi.Tagger()
        except Exception as exc:  # best-effort: any backend start-up failure degrades to the fallback
            warnings.warn(
                f"fugashi could not initialize a MeCab tagger ({exc!r}). "
                "Install a dictionary with 'pip install unidic-lite'. "
                "Falling back to character-based tokenization."
            )
            return cls._get_character_tokenizer()

        non_word = cls._NON_WORD_RE

        def tokenize(text: str) -> List[str]:
            # Keep each word's surface form, skipping whitespace/punctuation-only words.
            return [
                word.surface
                for word in tagger(text)
                if word.surface.strip() and not non_word.match(word.surface)
            ]

        return tokenize

    @classmethod
    def _get_korean_tokenizer(cls) -> Callable[[str], List[str]]:
        """Get Korean tokenizer using KoNLPy's Okt morphological analyzer.

        Falls back to character bigrams if KoNLPy is missing or ``Okt()`` fails
        to initialize. KoNLPy may require Java, and a missing Java setup can
        surface here rather than at import time.
        """
        try:
            from konlpy.tag import Okt
        except ImportError:
            warnings.warn(
                "konlpy not installed. Install with 'pip install konlpy' for Korean tokenization. "
                "Note: KoNLPy may require Java. Falling back to character-based tokenization."
            )
            return cls._get_character_tokenizer()

        try:
            okt = Okt()
        except Exception as exc:  # best-effort: any backend start-up failure degrades to the fallback
            warnings.warn(
                f"KoNLPy's Okt analyzer could not be initialized ({exc!r}); it may require Java. "
                "Falling back to character-based tokenization."
            )
            return cls._get_character_tokenizer()

        non_word = cls._NON_WORD_RE

        def tokenize(text: str) -> List[str]:
            # Morphological analysis, then drop punctuation-only morphemes.
            return [t for t in okt.morphs(text) if t.strip() and not non_word.match(t)]

        return tokenize

    @classmethod
    def _get_thai_tokenizer(cls) -> Callable[[str], List[str]]:
        """Get Thai tokenizer using pythainlp's 'newmm' engine; falls back to whitespace."""
        try:
            from pythainlp.tokenize import word_tokenize
        except ImportError:
            warnings.warn(
                "pythainlp not installed. Install with 'pip install pythainlp' for Thai tokenization. "
                "Falling back to whitespace tokenization."
            )
            return cls._get_whitespace_tokenizer('th')

        non_word = cls._NON_WORD_RE

        def tokenize(text: str) -> List[str]:
            tokens = word_tokenize(text, engine='newmm')
            # Filter out whitespace and punctuation
            return [t.strip() for t in tokens if t.strip() and not non_word.match(t)]

        return tokenize

    @classmethod
    def _get_spacy_tokenizer(cls, language: str) -> Callable[[str], List[str]]:
        """Get a spaCy tokenizer for the given language.

        Languages without a mapped model use ``en_core_web_sm``. A missing spaCy
        install or model falls back to the whitespace tokenizer with a warning.
        """
        try:
            import spacy
        except ImportError:
            warnings.warn("spaCy not installed. Falling back to whitespace tokenization.")
            return cls._get_whitespace_tokenizer(language)

        # Map language codes to spaCy model names
        model_map = {
            'en': 'en_core_web_sm',
            'de': 'de_core_news_sm',
            'fr': 'fr_core_news_sm',
            'es': 'es_core_news_sm',
            'it': 'it_core_news_sm',
            'pt': 'pt_core_news_sm',
            'nl': 'nl_core_news_sm',
            'pl': 'pl_core_news_sm',
            'ru': 'ru_core_news_sm',
            'zh': 'zh_core_web_sm',
            'ja': 'ja_core_news_sm',
        }
        # Normalize so tags like 'zh-CN' find the Chinese model instead of English.
        model_name = model_map.get(cls._normalize_language(language), 'en_core_web_sm')

        try:
            nlp = spacy.load(model_name)
        except OSError:
            warnings.warn(f"spaCy model '{model_name}' not found. Falling back to whitespace tokenization.")
            return cls._get_whitespace_tokenizer(language)

        def tokenize(text: str) -> List[str]:
            doc = nlp(text)
            return [token.text.lower() for token in doc if not token.is_punct and not token.is_space]

        return tokenize

    @classmethod
    def _get_character_tokenizer(cls) -> Callable[[str], List[str]]:
        """Fallback tokenizer for CJK text without a proper segmenter.

        Whitespace and non-word characters are removed. The remaining string
        is emitted as overlapping character bigrams, so text shorter than two
        characters yields no tokens.
        """
        strip = cls._STRIP_RE

        def tokenize(text: str) -> List[str]:
            compact = strip.sub('', text)
            return [compact[i:i + 2] for i in range(len(compact) - 1)]

        return tokenize
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
def tokenize_documents(
    documents: List[str],
    language: str,
    tokenizer_type: str = "auto",
    min_length: int = 2,
    max_length: int = 50
) -> List[List[str]]:
    """
    Tokenize every document in a corpus.

    Parameters
    ----------
    documents : List[str]
        Documents to tokenize. Entries that are not ``str`` produce an empty
        token list, so the output stays aligned with the input.
    language : str
        ISO 639-1 language code, used to pick the tokenizer.
    tokenizer_type : str
        Tokenizer name passed to ``TokenizerFactory.get_tokenizer``.
    min_length : int
        Shortest token (in characters) to keep.
    max_length : int
        Longest token (in characters) to keep.

    Returns
    -------
    List[List[str]]
        One token list per input document, in the same order.
    """
    tokenize = TokenizerFactory.get_tokenizer(language, tokenizer_type)

    def _tokens_for(document) -> List[str]:
        # Non-string entries (for example None) yield no tokens.
        if not isinstance(document, str):
            return []
        return [tok for tok in tokenize(document) if min_length <= len(tok) <= max_length]

    return [_tokens_for(document) for document in documents]
|
|
@@ -0,0 +1,371 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization functions for TriTopic.
|
|
3
|
+
|
|
4
|
+
Provides interactive visualizations using Plotly.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from .model import Topic
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def create_topic_visualization(
    embeddings: np.ndarray,
    labels: np.ndarray,
    topics: List["Topic"],
    documents: List[str],
    method: str = "umap",
    width: int = 900,
    height: int = 700,
    point_size: int = 5,
    **kwargs
) -> Any:
    """
    Create a 2D visualization of documents colored by topic.

    Parameters
    ----------
    embeddings : np.ndarray
        Document embeddings.
    labels : np.ndarray
        Topic assignment per document (array-like; plain lists are accepted).
        Negative ids are drawn as faint gray outliers.
    topics : List[Topic]
        Topic objects with labels.
    documents : List[str]
        Original documents for hover text (truncated to 200 characters;
        non-string entries are shown via ``str()``).
    method : str
        Reduction method: "umap", "tsne", or "pca".
    width, height : int
        Figure dimensions.
    point_size : int
        Size of scatter points.
    **kwargs
        Forwarded to the dimensionality-reduction step (e.g. ``n_neighbors``,
        ``min_dist``, ``metric``, ``perplexity``, ``random_state``).

    Returns
    -------
    plotly.graph_objects.Figure
        Interactive scatter plot.

    Raises
    ------
    ImportError
        If Plotly is not installed.
    """
    try:
        import plotly.express as px
        import plotly.graph_objects as go
    except ImportError as err:
        raise ImportError("Plotly required. Install with: pip install plotly") from err

    # Reduce dimensions
    coords = _reduce_dimensions(embeddings, method, **kwargs)

    # Ensure elementwise comparison below even if a plain list was passed.
    labels = np.asarray(labels)

    # Legend names per topic id; ids missing from ``topics`` show as "Outlier".
    topic_map = {t.topic_id: t.label or f"Topic {t.topic_id}" for t in topics}

    # Hover text: first 200 characters of each document.
    hover_texts = []
    for doc in documents:
        text = doc if isinstance(doc, str) else str(doc)
        hover_texts.append(text[:200] + "..." if len(text) > 200 else text)

    fig = go.Figure()
    colors = px.colors.qualitative.Set3 + px.colors.qualitative.Pastel

    # One trace per topic so each topic gets its own legend entry.
    for i, topic_id in enumerate(np.unique(labels)):
        idx = np.flatnonzero(labels == topic_id)
        is_outlier = topic_id < 0

        fig.add_trace(go.Scatter(
            x=coords[idx, 0],
            y=coords[idx, 1],
            mode="markers",
            name=topic_map.get(topic_id, "Outlier"),
            marker=dict(
                size=point_size,
                color="lightgray" if is_outlier else colors[i % len(colors)],
                opacity=0.3 if is_outlier else 0.7
            ),
            text=[hover_texts[j] for j in idx],
            hovertemplate="<b>%{text}</b><extra></extra>"
        ))

    fig.update_layout(
        title="TriTopic Document Map",
        xaxis_title=f"{method.upper()} 1",
        yaxis_title=f"{method.upper()} 2",
        width=width,
        height=height,
        template="plotly_white",
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=1.02
        )
    )

    return fig
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def create_topic_barchart(
    topics: List["Topic"],
    n_keywords: int = 10,
    width: int = 800,
    height: Optional[int] = None
) -> Any:
    """
    Create horizontal bar charts showing top keywords per topic.

    Parameters
    ----------
    topics : List[Topic]
        Topic objects with ``keywords`` and ``keyword_scores``. Topics with a
        negative ``topic_id`` (outliers) are skipped.
    n_keywords : int
        Maximum number of keywords to show per topic. Fewer are shown when a
        topic has fewer keywords or scores.
    width : int
        Figure width.
    height : int, optional
        Figure height. Defaults to 250 px per subplot row.

    Returns
    -------
    plotly.graph_objects.Figure
        Bar chart figure with up to three subplots per row.

    Raises
    ------
    ImportError
        If Plotly is not installed.
    ValueError
        If there are no non-outlier topics.
    """
    try:
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots
    except ImportError as err:
        raise ImportError("Plotly required.") from err

    # Filter out outliers
    valid_topics = [t for t in topics if t.topic_id >= 0]
    n_topics = len(valid_topics)

    if n_topics == 0:
        raise ValueError("No valid topics to visualize")

    # Grid layout: at most three columns.
    n_cols = min(3, n_topics)
    n_rows = (n_topics + n_cols - 1) // n_cols

    if height is None:
        height = n_rows * 250

    # make_subplots raises if vertical_spacing > 1 / (rows - 1). Cap the gaps so
    # they never take more than half the figure height; layouts with up to four
    # rows keep the original 0.15.
    vertical_spacing = 0.15 if n_rows < 2 else min(0.15, 0.5 / (n_rows - 1))

    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[t.label or f"Topic {t.topic_id}" for t in valid_topics],
        horizontal_spacing=0.1,
        vertical_spacing=vertical_spacing
    )

    for i, topic in enumerate(valid_topics):
        row, col = divmod(i, n_cols)

        # Truncate keywords and scores to the same length so every bar has a label.
        n = min(n_keywords, len(topic.keywords), len(topic.keyword_scores))
        # Reverse so the top keyword is drawn at the top of the horizontal chart.
        keywords = topic.keywords[:n][::-1]
        scores = topic.keyword_scores[:n][::-1]

        fig.add_trace(
            go.Bar(
                x=scores,
                y=keywords,
                orientation='h',
                marker_color='steelblue',
                showlegend=False
            ),
            row=row + 1,
            col=col + 1
        )

    fig.update_layout(
        title="Topic Keywords",
        width=width,
        height=height,
        template="plotly_white"
    )

    return fig
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def create_topic_hierarchy(
    embeddings: np.ndarray,
    labels: np.ndarray,
    topics: List["Topic"],
    method: str = "ward",
    **kwargs
) -> Any:
    """
    Create a hierarchical clustering dendrogram of topics.

    The hierarchy is built from each topic's ``centroid``.

    Parameters
    ----------
    embeddings : np.ndarray
        Document embeddings (currently unused).
    labels : np.ndarray
        Topic assignments (currently unused).
    topics : List[Topic]
        Topic objects. Only non-outlier topics with a non-None ``centroid``
        are included.
    method : str
        SciPy linkage method. "ward", "centroid" and "median" are computed on
        Euclidean distances between L2-normalized centroids, because SciPy
        defines them only for Euclidean distances. All other methods use cosine
        distances between the raw centroids.
    **kwargs
        Optional ``width`` (default 800) and ``height``
        (default ``max(400, 30 * n_topics)``).

    Returns
    -------
    plotly.graph_objects.Figure
        Dendrogram figure.

    Raises
    ------
    ImportError
        If Plotly is not installed.
    ValueError
        If fewer than two topics have centroids.
    """
    try:
        import plotly.figure_factory as ff
    except ImportError as err:
        raise ImportError("Plotly required.") from err

    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import pdist

    valid_topics = [t for t in topics if t.topic_id >= 0 and t.centroid is not None]

    if len(valid_topics) < 2:
        raise ValueError("Need at least 2 topics for hierarchy")

    centroids = np.asarray([t.centroid for t in valid_topics], dtype=float)
    topic_labels = [t.label or f"Topic {t.topic_id}" for t in valid_topics]

    if method in ("ward", "centroid", "median"):
        # For unit vectors, Euclidean distance is sqrt(2 * cosine distance). The
        # tree therefore still reflects cosine similarity while staying valid for
        # these Euclidean-only linkages. Zero vectors are left unscaled.
        norms = np.linalg.norm(centroids, axis=1, keepdims=True)
        unit = centroids / np.where(norms > 0, norms, 1.0)
        distances = pdist(unit, metric="euclidean")
    else:
        distances = pdist(centroids, metric="cosine")
    Z = linkage(distances, method=method)

    # Hand Plotly the precomputed distances and linkage so it does not recompute them.
    fig = ff.create_dendrogram(
        centroids,
        orientation='left',
        labels=topic_labels,
        distfun=lambda _: distances,
        linkagefun=lambda _: Z
    )

    fig.update_layout(
        title="Topic Hierarchy",
        width=kwargs.get("width", 800),
        height=kwargs.get("height", max(400, len(valid_topics) * 30)),
        template="plotly_white"
    )

    return fig
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def create_topic_heatmap(
    topics: List["Topic"],
    documents: List[str],
    labels: np.ndarray,
    n_top_keywords: int = 20
) -> Any:
    """
    Create a heatmap showing keyword importance across topics.

    Columns are the union of each non-outlier topic's first ``n_top_keywords``
    keywords, sorted alphabetically. Each cell holds the topic's score for the
    keyword at its first occurrence in ``topic.keywords``. The cell is 0 if the
    topic lacks the keyword or has no score at that position.

    Parameters
    ----------
    topics : List[Topic]
        Topic objects with ``keywords`` and ``keyword_scores``.
    documents : List[str]
        Original documents (currently unused).
    labels : np.ndarray
        Topic assignments (currently unused).
    n_top_keywords : int
        Number of keywords per topic contributing columns.

    Returns
    -------
    plotly.graph_objects.Figure
        Heatmap figure.

    Raises
    ------
    ImportError
        If Plotly is not installed.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as err:
        raise ImportError("Plotly required.") from err

    valid_topics = [t for t in topics if t.topic_id >= 0]

    # Collect all unique keywords (column order is alphabetical).
    all_keywords = sorted({kw for t in valid_topics for kw in t.keywords[:n_top_keywords]})
    column_of = {kw: j for j, kw in enumerate(all_keywords)}

    matrix = np.zeros((len(valid_topics), len(all_keywords)))

    for i, topic in enumerate(valid_topics):
        # Index of each keyword's first occurrence (mirrors list.index semantics).
        first_index = {}
        for idx, kw in enumerate(topic.keywords):
            first_index.setdefault(kw, idx)

        n_scored = len(topic.keyword_scores)
        for kw, idx in first_index.items():
            j = column_of.get(kw)
            if j is not None and idx < n_scored:
                matrix[i, j] = topic.keyword_scores[idx]

    topic_names = [t.label or f"Topic {t.topic_id}" for t in valid_topics]

    fig = go.Figure(data=go.Heatmap(
        z=matrix,
        x=all_keywords,
        y=topic_names,
        colorscale='Blues'
    ))

    fig.update_layout(
        title="Keyword Importance by Topic",
        xaxis_title="Keywords",
        yaxis_title="Topics",
        width=max(800, len(all_keywords) * 20),
        height=max(400, len(valid_topics) * 30),
        template="plotly_white"
    )

    return fig
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _reduce_dimensions(
    embeddings: np.ndarray,
    method: str = "umap",
    n_components: int = 2,
    **kwargs
) -> np.ndarray:
    """Reduce embedding dimensions for visualization.

    Parameters
    ----------
    embeddings : np.ndarray
        Input vectors, one row per document.
    method : str
        "umap", "tsne" or "pca" (case-insensitive). Any other value uses PCA.
        If UMAP cannot be imported, a warning is issued and PCA is used.
    n_components : int
        Output dimensionality.
    **kwargs
        UMAP: ``n_neighbors`` (15), ``min_dist`` (0.1), ``metric`` ("cosine"),
        ``random_state`` (42). t-SNE: ``perplexity`` (30, clamped below the
        number of samples), ``random_state`` (42).

    Returns
    -------
    np.ndarray
        The reduced coordinates, as returned by the reducer's ``fit_transform``.
    """
    import warnings

    method = method.lower()

    if method == "umap":
        try:
            from umap import UMAP
        except ImportError:
            warnings.warn("UMAP not available, falling back to PCA")
            method = "pca"
        else:
            reducer = UMAP(
                n_components=n_components,
                n_neighbors=kwargs.get("n_neighbors", 15),
                min_dist=kwargs.get("min_dist", 0.1),
                metric=kwargs.get("metric", "cosine"),
                random_state=kwargs.get("random_state", 42)
            )
            return reducer.fit_transform(embeddings)

    if method == "tsne":
        from sklearn.manifold import TSNE

        # scikit-learn rejects perplexity >= n_samples; clamp so small corpora work.
        n_samples = len(embeddings)
        perplexity = min(kwargs.get("perplexity", 30), max(n_samples - 1, 1))
        reducer = TSNE(
            n_components=n_components,
            perplexity=perplexity,
            random_state=kwargs.get("random_state", 42)
        )
        return reducer.fit_transform(embeddings)

    # Default: PCA
    from sklearn.decomposition import PCA
    reducer = PCA(n_components=n_components)
    return reducer.fit_transform(embeddings)
|