tritopic 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,523 @@
1
+ """
2
+ Topic Visualization
3
+ ====================
4
+
5
+ Interactive visualizations for topic models using:
6
+ - UMAP / PaCMAP for dimensionality reduction
7
+ - Plotly for interactive plots
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Any, Literal
13
+
14
+ import numpy as np
15
+ import plotly.express as px
16
+ import plotly.graph_objects as go
17
+ from plotly.subplots import make_subplots
18
+
19
+
20
class TopicVisualizer:
    """
    Visualize topics and documents.

    Provides various visualization methods for exploring topic models.

    Parameters
    ----------
    method : str
        Dimensionality reduction method: "umap" or "pacmap". Any value other
        than "umap" selects PaCMAP; if the ``pacmap`` package is not installed,
        UMAP is used instead and a ``RuntimeWarning`` is emitted.
    random_state : int
        Random seed for reproducibility.
    """

    # Documents longer than this many characters are truncated in hover text.
    _HOVER_MAX_CHARS = 200

    def __init__(
        self,
        method: Literal["umap", "pacmap"] = "umap",
        random_state: int = 42,
    ):
        self.method = method
        self.random_state = random_state

        # Set by _reduce_dimensions(): the fitted reducer and its projected output.
        self._reducer = None
        self._reduced_embeddings = None
        # Reducer actually used by the last _reduce_dimensions() call; differs
        # from ``method`` when PaCMAP was requested but is not installed.
        self._effective_method = method

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _make_umap(self, n_components: int):
        """Build the UMAP reducer used for "umap" and for the PaCMAP fallback."""
        from umap import UMAP

        return UMAP(
            n_components=n_components,
            n_neighbors=15,
            min_dist=0.1,
            metric="cosine",
            random_state=self.random_state,
        )

    def _reduce_dimensions(
        self,
        embeddings: np.ndarray,
        n_components: int = 2,
    ) -> np.ndarray:
        """
        Fit a reducer on ``embeddings`` and return the projected coordinates.

        The fitted reducer and its output are stored on ``self._reducer`` and
        ``self._reduced_embeddings``; ``self._effective_method`` records which
        algorithm ("umap" or "pacmap") was actually used.
        """
        reducer = None
        effective = "umap"
        if self.method != "umap":
            try:
                from pacmap import PaCMAP
            except ImportError:
                import warnings

                # Fall back to UMAP (same configuration as method="umap"),
                # but tell the caller instead of silently switching.
                warnings.warn(
                    "pacmap is not installed; falling back to UMAP.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                reducer = PaCMAP(
                    n_components=n_components,
                    random_state=self.random_state,
                )
                effective = "pacmap"

        if reducer is None:
            reducer = self._make_umap(n_components)

        self._reducer = reducer
        self._effective_method = effective
        self._reduced_embeddings = reducer.fit_transform(embeddings)
        return self._reduced_embeddings

    @staticmethod
    def _topic_names(topics: list | None) -> dict:
        """
        Map ``topic_id`` -> ``label`` for TopicInfo-like objects.

        The first entry for a given id wins. Labels may be empty or None;
        ``_display_name`` supplies the fallback in that case.
        """
        names: dict = {}
        for t in topics if topics is not None else ():
            names.setdefault(t.topic_id, t.label)
        return names

    @staticmethod
    def _display_name(topic_id, names: dict, outlier_name: str = "Outliers") -> str:
        """Return the label for ``topic_id``; fall back to ``outlier_name`` for -1, else "Topic <id>"."""
        label = names.get(topic_id)
        if label:
            return label
        return outlier_name if topic_id == -1 else f"Topic {topic_id}"

    @staticmethod
    def _subplot_spacing(n_rows: int, preferred: float = 0.08) -> float:
        """
        Vertical gap (fraction of figure height) between ``n_rows`` subplots.

        ``make_subplots`` rejects ``vertical_spacing > 1 / (rows - 1)``, so for
        many rows the preferred gap is shrunk so that all gaps together take
        at most half of the figure height.
        """
        if n_rows <= 1:
            return preferred
        return min(preferred, 0.5 / (n_rows - 1))

    @staticmethod
    def _align_topic_embeddings(topic_embeddings, topics: list, valid_topics: list) -> np.ndarray:
        """
        Return the rows of ``topic_embeddings`` that correspond to ``valid_topics``.

        Accepts either one row per non-outlier topic, or one row per entry of
        ``topics`` (outlier topic included), in which case the outlier rows are
        dropped so rows line up with the plotted labels.

        Raises
        ------
        ValueError
            If the row count matches neither convention.
        """
        emb = np.asarray(topic_embeddings)
        if len(emb) == len(valid_topics):
            return emb
        if len(emb) == len(topics):
            keep = [i for i, t in enumerate(topics) if t.topic_id != -1]
            return emb[keep]
        raise ValueError(
            f"topic_embeddings has {len(emb)} rows, expected {len(valid_topics)} "
            f"(non-outlier topics) or {len(topics)} (all topics)"
        )

    @staticmethod
    def _l2_normalize(vectors) -> np.ndarray:
        """Scale each row to unit L2 norm; all-zero rows are left as zeros."""
        arr = np.asarray(vectors, dtype=float)
        norms = np.linalg.norm(arr, axis=1, keepdims=True)
        return arr / np.where(norms == 0, 1.0, norms)

    @staticmethod
    def _message_figure(message: str, title: str, width: int, height: int) -> go.Figure:
        """Return an empty figure showing ``message`` in the center."""
        fig = go.Figure()
        fig.add_annotation(
            text=message,
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
        )
        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            template="plotly_white",
        )
        return fig

    # ------------------------------------------------------------------
    # Public plots
    # ------------------------------------------------------------------

    def plot_documents(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
        documents: list[str] | None = None,
        topics: list | None = None,
        show_outliers: bool = True,
        interactive: bool = True,
        title: str = "Topic Document Map",
        width: int = 900,
        height: int = 700,
        **kwargs,
    ) -> go.Figure:
        """
        Plot documents in 2D space colored by topic.

        Parameters
        ----------
        embeddings : np.ndarray
            Document embeddings, one row per document.
        labels : np.ndarray
            Topic assignment per document (-1 marks outliers).
        documents : list[str], optional
            Document texts for hover info (truncated to 200 characters and
            HTML-escaped). An empty list is treated like None.
        topics : list[TopicInfo], optional
            Topic info (``topic_id``, ``label``) used for legend and hover names.
        show_outliers : bool
            Whether to show outlier documents (label -1).
        interactive : bool
            Currently unused; a Plotly figure is always returned.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.
        **kwargs
            Currently unused.

        Returns
        -------
        fig : go.Figure
            Plotly figure with one scatter trace per topic.

        Raises
        ------
        ValueError
            If ``labels`` (or non-empty ``documents``) do not have one entry
            per embedding row.
        """
        import html

        labels = np.asarray(labels)
        n_docs = len(embeddings)
        has_docs = documents is not None and len(documents) > 0

        # Validate before the (potentially slow) dimensionality reduction.
        if len(labels) != n_docs:
            raise ValueError(
                f"labels has {len(labels)} entries but embeddings has {n_docs} rows"
            )
        if has_docs and len(documents) != n_docs:
            raise ValueError(
                f"documents has {len(documents)} entries but embeddings has {n_docs} rows"
            )

        coords = self._reduce_dimensions(embeddings, n_components=2)

        # Indices of the documents that will be plotted.
        keep = np.arange(n_docs) if show_outliers else np.flatnonzero(labels != -1)
        x = coords[keep, 0]
        y = coords[keep, 1]
        topic_labels = labels[keep]

        names = self._topic_names(topics)

        # Hover text, aligned with ``keep``.
        if has_docs:
            hover_list = []
            for idx in keep:
                doc = documents[idx]
                if len(doc) > self._HOVER_MAX_CHARS:
                    doc = doc[: self._HOVER_MAX_CHARS] + "..."
                name = self._display_name(labels[idx], names)
                # Escape so Plotly's pseudo-HTML renderer shows '<', '>' and '&' literally.
                hover_list.append(f"<b>{name}</b><br>{html.escape(doc, quote=False)}")
        else:
            hover_list = [self._display_name(label, names) for label in topic_labels]
        hover = np.asarray(hover_list, dtype=object)

        palette = px.colors.qualitative.Set2 + px.colors.qualitative.Set3
        fig = go.Figure()

        # One trace per topic so each gets its own legend entry.
        n_colored = 0
        for label in np.unique(topic_labels):
            is_outlier = label == -1
            if is_outlier:
                color = "lightgray"
            else:
                color = palette[n_colored % len(palette)]
                n_colored += 1

            sel = topic_labels == label
            fig.add_trace(go.Scatter(
                x=x[sel],
                y=y[sel],
                mode="markers",
                name=self._display_name(label, names),
                marker=dict(
                    color=color,
                    size=4 if is_outlier else 6,
                    opacity=0.3 if is_outlier else 0.7,
                ),
                text=hover[sel].tolist(),
                hovertemplate="%{text}<extra></extra>",
            ))

        # Label axes after the reducer that was actually used.
        axis_prefix = self._effective_method.upper()
        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(title=f"{axis_prefix} 1", showgrid=False),
            yaxis=dict(title=f"{axis_prefix} 2", showgrid=False),
            legend=dict(
                orientation="v",
                yanchor="top",
                y=1,
                xanchor="left",
                x=1.02,
            ),
            template="plotly_white",
        )

        return fig

    def plot_topics(
        self,
        topics: list,
        n_keywords: int = 8,
        title: str = "Topics Overview",
        width: int = 900,
        height: int | None = None,
    ) -> go.Figure:
        """
        Plot topics as horizontal bar charts of keywords.

        Parameters
        ----------
        topics : list[TopicInfo]
            Topic information objects (``topic_id``, ``label``, ``size``,
            ``keywords``, ``keyword_scores``). The outlier topic (-1) is skipped.
        n_keywords : int
            Number of keywords to show per topic.
        title : str
            Plot title.
        width : int
            Figure width.
        height : int, optional
            Figure height; defaults to ``max(400, 80 * n_topics)``.

        Returns
        -------
        fig : go.Figure
            Plotly figure with one subplot per topic, largest topic first.
            Keyword scores are divided by the largest absolute score in each
            topic.
        """
        valid_topics = sorted(
            (t for t in topics if t.topic_id != -1),
            key=lambda t: -t.size,
        )

        n_topics = len(valid_topics)
        if height is None:
            height = max(400, 80 * n_topics)

        if n_topics == 0:
            # make_subplots cannot build a figure with zero rows.
            return self._message_figure("No topics to display", title, width, height)

        fig = make_subplots(
            rows=n_topics,
            cols=1,
            subplot_titles=[
                f"{t.label or f'Topic {t.topic_id}'} (n={t.size})"
                for t in valid_topics
            ],
            vertical_spacing=self._subplot_spacing(n_topics),
        )

        palette = px.colors.qualitative.Set2

        for i, topic in enumerate(valid_topics):
            keywords = list(topic.keywords[:n_keywords])
            scores = [float(s) for s in topic.keyword_scores[:n_keywords]]

            # Normalize by the largest magnitude; an all-zero or empty score
            # list uses a divisor of 1 instead of dividing by zero.
            max_score = max((abs(s) for s in scores), default=0.0) or 1.0
            scores = [s / max_score for s in scores]

            # Reverse so the top keyword is drawn at the top of the bar chart.
            fig.add_trace(
                go.Bar(
                    x=scores[::-1],
                    y=keywords[::-1],
                    orientation="h",
                    marker_color=palette[i % len(palette)],
                    showlegend=False,
                ),
                row=i + 1,
                col=1,
            )

        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            template="plotly_white",
        )

        return fig

    def plot_hierarchy(
        self,
        topic_embeddings: np.ndarray,
        topics: list,
        title: str = "Topic Hierarchy",
        width: int = 800,
        height: int = 500,
    ) -> go.Figure:
        """
        Plot topic hierarchy as a dendrogram.

        Parameters
        ----------
        topic_embeddings : np.ndarray
            Centroid embeddings, either one row per non-outlier topic or one
            row per entry of ``topics`` (outlier rows are then dropped).
        topics : list[TopicInfo]
            Topic information objects.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.

        Returns
        -------
        fig : go.Figure
            Plotly figure. With fewer than 2 non-outlier topics, a figure with
            an explanatory message is returned.

        Raises
        ------
        ValueError
            If the number of embedding rows does not match the topics.
        """
        from scipy.cluster.hierarchy import dendrogram, linkage
        from scipy.spatial.distance import pdist

        valid_topics = [t for t in topics if t.topic_id != -1]

        if len(valid_topics) < 2:
            return self._message_figure(
                "Need at least 2 topics for hierarchy", title, width, height
            )

        embeddings = self._align_topic_embeddings(topic_embeddings, topics, valid_topics)

        # Ward linkage is only valid for Euclidean distances. On unit-normalized
        # vectors, ||a - b||^2 = 2 - 2*cos(a, b), so Euclidean distance keeps
        # the cosine geometry while satisfying Ward's requirement.
        distances = pdist(self._l2_normalize(embeddings), metric="euclidean")
        Z = linkage(distances, method="ward")

        labels = [t.label or f"Topic {t.topic_id}" for t in valid_topics]

        # Let scipy lay out the dendrogram, then draw its segments with Plotly.
        dendro = dendrogram(Z, labels=labels, no_plot=True)

        fig = go.Figure()
        for xs, ys in zip(dendro["icoord"], dendro["dcoord"]):
            fig.add_trace(go.Scatter(
                x=xs,
                y=ys,
                mode="lines",
                line=dict(color="#636EFA", width=2),
                showlegend=False,
            ))

        # scipy places leaf i at x = 10 * i + 5.
        leaves = dendro["ivl"]
        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(
                ticktext=leaves,
                tickvals=[10 * i + 5 for i in range(len(leaves))],
                tickangle=45,
            ),
            yaxis=dict(title="Distance"),
            template="plotly_white",
        )

        return fig

    def plot_topic_similarity(
        self,
        topic_embeddings: np.ndarray,
        topics: list,
        title: str = "Topic Similarity",
        width: int = 600,
        height: int = 600,
    ) -> go.Figure:
        """
        Plot topic similarity as a heatmap.

        Parameters
        ----------
        topic_embeddings : np.ndarray
            Centroid embeddings, either one row per non-outlier topic or one
            row per entry of ``topics`` (outlier rows are then dropped).
        topics : list[TopicInfo]
            Topic information objects.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.

        Returns
        -------
        fig : go.Figure
            Plotly heatmap of pairwise cosine similarity.

        Raises
        ------
        ValueError
            If the number of embedding rows does not match the topics.
        """
        from sklearn.metrics.pairwise import cosine_similarity

        valid_topics = [t for t in topics if t.topic_id != -1]
        if not valid_topics:
            return self._message_figure("No topics to display", title, width, height)

        # Rows must line up with the axis labels below.
        embeddings = self._align_topic_embeddings(topic_embeddings, topics, valid_topics)
        sim_matrix = cosine_similarity(embeddings)

        labels = [t.label or f"Topic {t.topic_id}" for t in valid_topics]

        fig = go.Figure(data=go.Heatmap(
            z=sim_matrix,
            x=labels,
            y=labels,
            colorscale="RdBu",
            zmid=0.5,
            text=np.round(sim_matrix, 2),
            texttemplate="%{text}",
            textfont={"size": 10},
            hovertemplate="Similarity: %{z:.3f}<extra></extra>",
        ))

        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(tickangle=45),
            template="plotly_white",
        )

        return fig

    def plot_topic_over_time(
        self,
        labels: np.ndarray,
        timestamps: list,
        topics: list | None = None,
        title: str = "Topics Over Time",
        width: int = 900,
        height: int = 500,
        freq: str = "M",
    ) -> go.Figure:
        """
        Plot topic distribution over time.

        Parameters
        ----------
        labels : np.ndarray
            Topic assignments (-1 outliers are excluded).
        timestamps : list
            Timestamps for each document (anything ``pd.to_datetime`` accepts).
        topics : list[TopicInfo], optional
            Topic information for labels.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.
        freq : str
            Pandas period alias used to bucket timestamps (default "M",
            monthly; e.g. "W", "Q", "Y").

        Returns
        -------
        fig : go.Figure
            Stacked-area Plotly figure of each topic's percentage share per
            period.
        """
        import pandas as pd

        df = pd.DataFrame({
            "topic": labels,
            "timestamp": pd.to_datetime(timestamps),
        })

        # Drop outliers; copy so the column assignment below does not act on a
        # view of the original frame (SettingWithCopyWarning).
        df = df.loc[df["topic"] != -1].copy()
        if df.empty:
            return self._message_figure(
                "No topic assignments to display", title, width, height
            )

        # Bucket by period and count documents per (period, topic).
        df["period"] = df["timestamp"].dt.to_period(freq).dt.to_timestamp()
        counts = df.groupby(["period", "topic"]).size().unstack(fill_value=0)

        # Normalize each period to percentages.
        counts = counts.div(counts.sum(axis=1), axis=0) * 100

        names = self._topic_names(topics)
        palette = px.colors.qualitative.Set2
        fig = go.Figure()

        for i, topic_id in enumerate(counts.columns):
            fig.add_trace(go.Scatter(
                x=counts.index,
                y=counts[topic_id],
                name=self._display_name(topic_id, names),
                mode="lines",
                stackgroup="one",
                line=dict(color=palette[i % len(palette)]),
            ))

        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(title="Time"),
            yaxis=dict(title="Topic Share (%)", range=[0, 100]),
            template="plotly_white",
        )

        return fig