tritopic 0.1.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,523 +0,0 @@
1
- """
2
- Topic Visualization
3
- ====================
4
-
5
- Interactive visualizations for topic models using:
6
- - UMAP / PaCMAP for dimensionality reduction
7
- - Plotly for interactive plots
8
- """
9
-
10
- from __future__ import annotations
11
-
12
- from typing import Any, Literal
13
-
14
- import numpy as np
15
- import plotly.express as px
16
- import plotly.graph_objects as go
17
- from plotly.subplots import make_subplots
18
-
19
-
20
class TopicVisualizer:
    """
    Visualize topics and documents.

    Provides Plotly-based views for exploring topic models: 2D document maps,
    per-topic keyword bar charts, topic dendrograms, topic similarity heatmaps,
    and topic shares over time. Topic id ``-1`` is treated as the outlier
    topic throughout.

    Parameters
    ----------
    method : str
        Dimensionality reduction method: "umap" or "pacmap". If "pacmap" is
        requested but the ``pacmap`` package is not installed, UMAP (with the
        same settings as the "umap" method) is used instead and a
        ``UserWarning`` is emitted.
    random_state : int
        Random seed for reproducibility.
    """

    # Document previews in hover labels are truncated to this many characters.
    _MAX_HOVER_CHARS = 200

    def __init__(
        self,
        method: Literal["umap", "pacmap"] = "umap",
        random_state: int = 42,
    ):
        self.method = method
        self.random_state = random_state

        self._reducer = None
        self._reduced_embeddings = None
        # Reducer actually used by the last reduction. Differs from
        # ``method`` when PaCMAP was requested but UMAP had to be used.
        self._effective_method = method

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _topic_name_map(topics: list | None) -> dict[Any, str]:
        """
        Map topic ids to display names.

        Each topic maps to its ``label`` or, when the label is empty, to
        ``"Topic {topic_id}"``. If several entries share an id, the first
        one wins. ``None`` or an empty list yields an empty mapping.
        """
        names: dict[Any, str] = {}
        for t in topics or ():
            names.setdefault(t.topic_id, t.label or f"Topic {t.topic_id}")
        return names

    @staticmethod
    def _align_topic_embeddings(
        topic_embeddings: np.ndarray,
        topics: list,
        valid_topics: list,
    ) -> np.ndarray:
        """
        Return the rows of ``topic_embeddings`` that belong to ``valid_topics``.

        Two layouts are accepted:

        - one row per non-outlier topic, which is returned unchanged;
        - one row per entry of ``topics``, in which case the rows of outlier
          topics (``topic_id == -1``) are dropped.

        Either way, the returned rows line up with ``valid_topics``.

        Raises
        ------
        ValueError
            If the number of rows matches neither layout.
        """
        emb = np.asarray(topic_embeddings)
        if len(emb) == len(valid_topics):
            return emb
        if len(emb) == len(topics):
            keep = [i for i, t in enumerate(topics) if t.topic_id != -1]
            return emb[keep]
        raise ValueError(
            f"topic_embeddings has {len(emb)} rows; expected {len(valid_topics)} "
            f"(one per non-outlier topic) or {len(topics)} (one per topic)."
        )

    @staticmethod
    def _vertical_spacing(n_rows: int, preferred: float = 0.08) -> float:
        """
        Return the gap between stacked subplots as a fraction of figure height.

        Plotly's ``make_subplots`` rejects a ``vertical_spacing`` greater than
        ``1 / (rows - 1)``. The gap is therefore capped so that all gaps
        together take at most half of the figure height, which keeps every
        row visible.
        """
        if n_rows <= 1:
            return preferred
        return min(preferred, 0.5 / (n_rows - 1))

    @staticmethod
    def _message_figure(text: str, title: str, width: int, height: int) -> go.Figure:
        """Return an empty figure with ``text`` centered in the plot area."""
        fig = go.Figure()
        fig.add_annotation(
            text=text,
            xref="paper", yref="paper",
            x=0.5, y=0.5,
            showarrow=False,
        )
        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            template="plotly_white",
        )
        return fig

    def _make_umap(self, n_components: int):
        """Build the UMAP reducer used for "umap" and as the PaCMAP fallback."""
        from umap import UMAP

        return UMAP(
            n_components=n_components,
            n_neighbors=15,
            min_dist=0.1,
            metric="cosine",
            random_state=self.random_state,
        )

    def _reduce_dimensions(
        self,
        embeddings: np.ndarray,
        n_components: int = 2,
    ) -> np.ndarray:
        """
        Reduce embeddings to ``n_components`` dimensions for visualization.

        The fitted reducer and the reduced coordinates are stored on the
        instance (``_reducer`` and ``_reduced_embeddings``). The name of the
        reducer actually used is stored in ``_effective_method``.

        Raises
        ------
        ValueError
            If ``self.method`` is neither "umap" nor "pacmap".
        """
        import warnings

        if self.method == "umap":
            reducer = self._make_umap(n_components)
            effective = "umap"
        elif self.method == "pacmap":
            try:
                from pacmap import PaCMAP
            except ImportError:
                # Keep the fallback identical to the "umap" method so plots
                # do not silently change metric/neighbourhood settings.
                warnings.warn(
                    "pacmap is not installed; falling back to UMAP.",
                    stacklevel=2,
                )
                reducer = self._make_umap(n_components)
                effective = "umap"
            else:
                reducer = PaCMAP(
                    n_components=n_components,
                    random_state=self.random_state,
                )
                effective = "pacmap"
        else:
            raise ValueError(
                f"Unknown method {self.method!r}; expected 'umap' or 'pacmap'."
            )

        self._reducer = reducer
        self._effective_method = effective
        self._reduced_embeddings = reducer.fit_transform(embeddings)

        return self._reduced_embeddings

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def plot_documents(
        self,
        embeddings: np.ndarray,
        labels: np.ndarray,
        documents: list[str] | None = None,
        topics: list | None = None,
        show_outliers: bool = True,
        interactive: bool = True,
        title: str = "Topic Document Map",
        width: int = 900,
        height: int = 700,
        **kwargs,
    ) -> go.Figure:
        """
        Plot documents in 2D space colored by topic.

        Parameters
        ----------
        embeddings : np.ndarray
            Document embeddings, one row per document.
        labels : np.ndarray
            Topic assignment per document; ``-1`` marks outliers.
        documents : list[str], optional
            Document texts used for hover info. Texts are truncated to 200
            characters.
        topics : list[TopicInfo], optional
            Topic info objects (``topic_id``, ``label``) used for names.
        show_outliers : bool
            Whether to show outlier documents (label ``-1``).
        interactive : bool
            Currently unused; the result is always a Plotly figure.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.
        **kwargs
            Currently unused; accepted for API compatibility.

        Returns
        -------
        fig : go.Figure
            Plotly figure with one scatter trace per topic.
        """
        labels = np.asarray(labels)

        # Reduce dimensions
        coords = self._reduce_dimensions(embeddings, n_components=2)

        # Select the documents to draw
        if show_outliers:
            mask = np.ones(len(labels), dtype=bool)
        else:
            mask = labels != -1

        x = coords[mask, 0]
        y = coords[mask, 1]
        topic_labels = labels[mask]

        names = self._topic_name_map(topics)

        # Hover text, aligned with the selected documents
        if documents is not None and len(documents) > 0:
            hover_texts = []
            for idx in np.flatnonzero(mask):
                doc = documents[idx]
                if len(doc) > self._MAX_HOVER_CHARS:
                    doc = doc[: self._MAX_HOVER_CHARS] + "..."
                topic_id = labels[idx]
                topic_name = names.get(topic_id, f"Topic {topic_id}")
                hover_texts.append(f"<b>{topic_name}</b><br>{doc}")
        else:
            hover_texts = [f"Topic {l}" for l in topic_labels]

        # Color mapping: outliers gray, topics cycle through the palette
        unique_labels = sorted(np.unique(topic_labels))
        colors = px.colors.qualitative.Set2 + px.colors.qualitative.Set3
        color_map = {}
        color_idx = 0
        for label in unique_labels:
            if label == -1:
                color_map[-1] = "lightgray"
            else:
                color_map[label] = colors[color_idx % len(colors)]
                color_idx += 1

        fig = go.Figure()

        # One trace per topic so each gets its own legend entry
        for label in unique_labels:
            is_outlier = label == -1
            member_idx = np.flatnonzero(topic_labels == label)
            default_name = "Outliers" if is_outlier else f"Topic {label}"
            topic_name = names.get(label, default_name)

            fig.add_trace(go.Scatter(
                x=x[member_idx],
                y=y[member_idx],
                mode="markers",
                name=topic_name,
                marker=dict(
                    color=color_map[label],
                    size=4 if is_outlier else 6,
                    opacity=0.3 if is_outlier else 0.7,
                ),
                text=[hover_texts[i] for i in member_idx],
                hovertemplate="%{text}<extra></extra>",
            ))

        # Name the axes after the reducer that was actually used
        method_name = self._effective_method.upper()
        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(title=f"{method_name} 1", showgrid=False),
            yaxis=dict(title=f"{method_name} 2", showgrid=False),
            legend=dict(
                orientation="v",
                yanchor="top",
                y=1,
                xanchor="left",
                x=1.02,
            ),
            template="plotly_white",
        )

        return fig

    def plot_topics(
        self,
        topics: list,
        n_keywords: int = 8,
        title: str = "Topics Overview",
        width: int = 900,
        height: int | None = None,
    ) -> go.Figure:
        """
        Plot topics as horizontal bar charts of keywords.

        Outlier topics (``topic_id == -1``) are skipped. The remaining topics
        are ordered by decreasing ``size``. Keyword scores are normalized by
        each topic's maximum score.

        Parameters
        ----------
        topics : list[TopicInfo]
            Topic information objects (``topic_id``, ``label``, ``size``,
            ``keywords``, ``keyword_scores``).
        n_keywords : int
            Number of keywords to show per topic.
        title : str
            Plot title.
        width, height : int
            Figure dimensions. By default the height is
            ``max(400, 80 * n_topics)``.

        Returns
        -------
        fig : go.Figure
            Plotly figure with one subplot per topic. If there are no
            non-outlier topics, a figure containing only a message is
            returned.
        """
        # Filter out outliers and sort by size (largest first)
        valid_topics = sorted(
            (t for t in topics if t.topic_id != -1),
            key=lambda t: -t.size,
        )

        n_topics = len(valid_topics)
        if height is None:
            height = max(400, 80 * n_topics)

        if n_topics == 0:
            # make_subplots cannot build a zero-row grid
            return self._message_figure("No topics to display", title, width, height)

        fig = make_subplots(
            rows=n_topics,
            cols=1,
            subplot_titles=[
                f"{t.label or f'Topic {t.topic_id}'} (n={t.size})"
                for t in valid_topics
            ],
            vertical_spacing=self._vertical_spacing(n_topics),
        )

        colors = px.colors.qualitative.Set2

        for i, topic in enumerate(valid_topics):
            keywords = list(topic.keywords[:n_keywords])
            # list(...) also covers numpy arrays, whose truthiness is ambiguous
            scores = list(topic.keyword_scores[:n_keywords])

            # Normalize scores; guard against empty lists and an all-zero max
            max_score = max(scores, default=1) or 1
            scores = [s / max_score for s in scores]

            fig.add_trace(
                go.Bar(
                    x=scores[::-1],
                    y=keywords[::-1],
                    orientation="h",
                    marker_color=colors[i % len(colors)],
                    showlegend=False,
                ),
                row=i + 1,
                col=1,
            )

        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            template="plotly_white",
        )

        return fig

    def plot_hierarchy(
        self,
        topic_embeddings: np.ndarray,
        topics: list,
        title: str = "Topic Hierarchy",
        width: int = 800,
        height: int = 500,
        linkage_method: str = "average",
    ) -> go.Figure:
        """
        Plot topic hierarchy as a dendrogram.

        Topics are clustered hierarchically on pairwise cosine distances
        between their embeddings. Outlier topics (``topic_id == -1``) are
        excluded.

        Parameters
        ----------
        topic_embeddings : np.ndarray
            Centroid embeddings. Either one row per non-outlier topic, or one
            row per entry of ``topics`` (outlier rows are then dropped).
        topics : list[TopicInfo]
            Topic information objects.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.
        linkage_method : str
            Method passed to ``scipy.cluster.hierarchy.linkage``. The default
            "average" is valid for cosine distances. "ward", "centroid" and
            "median" are only correct for Euclidean distances.

        Returns
        -------
        fig : go.Figure
            Plotly figure. With fewer than two topics, a figure containing
            only a message is returned.

        Raises
        ------
        ValueError
            If the rows of ``topic_embeddings`` cannot be aligned with the
            topics.
        """
        from scipy.cluster.hierarchy import dendrogram, linkage
        from scipy.spatial.distance import pdist

        valid_topics = [t for t in topics if t.topic_id != -1]

        if len(valid_topics) < 2:
            # Not enough topics for a hierarchy
            return self._message_figure(
                "Need at least 2 topics for hierarchy", title, width, height
            )

        # Keep rows and leaf labels in the same order and length
        embeddings = self._align_topic_embeddings(topic_embeddings, topics, valid_topics)

        distances = pdist(embeddings, metric="cosine")
        Z = linkage(distances, method=linkage_method)

        labels = [t.label or f"Topic {t.topic_id}" for t in valid_topics]

        # Use scipy only for the dendrogram geometry; draw it with Plotly
        dendro = dendrogram(Z, labels=labels, no_plot=True)

        fig = go.Figure()

        for xs, ys in zip(dendro["icoord"], dendro["dcoord"]):
            fig.add_trace(go.Scatter(
                x=xs,
                y=ys,
                mode="lines",
                line=dict(color="#636EFA", width=2),
                showlegend=False,
            ))

        # scipy places leaf i at x = 10 * i + 5
        leaves = dendro["ivl"]
        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(
                ticktext=leaves,
                tickvals=[10 * i + 5 for i in range(len(leaves))],
                tickangle=45,
            ),
            yaxis=dict(title="Distance"),
            template="plotly_white",
        )

        return fig

    def plot_topic_similarity(
        self,
        topic_embeddings: np.ndarray,
        topics: list,
        title: str = "Topic Similarity",
        width: int = 600,
        height: int = 600,
    ) -> go.Figure:
        """
        Plot pairwise cosine similarity between topics as a heatmap.

        Outlier topics (``topic_id == -1``) are excluded.

        Parameters
        ----------
        topic_embeddings : np.ndarray
            Centroid embeddings. Either one row per non-outlier topic, or one
            row per entry of ``topics`` (outlier rows are then dropped).
        topics : list[TopicInfo]
            Topic information objects.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.

        Returns
        -------
        fig : go.Figure
            Plotly figure. With no non-outlier topics, a figure containing
            only a message is returned.

        Raises
        ------
        ValueError
            If the rows of ``topic_embeddings`` cannot be aligned with the
            topics.
        """
        from sklearn.metrics.pairwise import cosine_similarity

        valid_topics = [t for t in topics if t.topic_id != -1]
        if not valid_topics:
            return self._message_figure("No topics to display", title, width, height)

        # Matrix rows/columns must correspond to the axis labels
        embeddings = self._align_topic_embeddings(topic_embeddings, topics, valid_topics)
        sim_matrix = cosine_similarity(embeddings)

        labels = [t.label or f"Topic {t.topic_id}" for t in valid_topics]

        fig = go.Figure(data=go.Heatmap(
            z=sim_matrix,
            x=labels,
            y=labels,
            colorscale="RdBu",
            zmid=0.5,
            text=np.round(sim_matrix, 2),
            texttemplate="%{text}",
            textfont={"size": 10},
            hovertemplate="Similarity: %{z:.3f}<extra></extra>",
        ))

        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(tickangle=45),
            template="plotly_white",
        )

        return fig

    def plot_topic_over_time(
        self,
        labels: np.ndarray,
        timestamps: list,
        topics: list | None = None,
        title: str = "Topics Over Time",
        width: int = 900,
        height: int = 500,
    ) -> go.Figure:
        """
        Plot each topic's monthly share of documents as a stacked area chart.

        Outlier documents (label ``-1``) are excluded before the shares are
        computed. Each month's shares therefore sum to 100%.

        Parameters
        ----------
        labels : np.ndarray
            Topic assignment per document.
        timestamps : list
            Timestamps for each document. Any input accepted by
            ``pd.to_datetime`` works.
        topics : list[TopicInfo], optional
            Topic information used for trace names.
        title : str
            Plot title.
        width, height : int
            Figure dimensions.

        Returns
        -------
        fig : go.Figure
            Plotly figure. If every document is an outlier, a figure
            containing only a message is returned.
        """
        import pandas as pd

        df = pd.DataFrame({
            "topic": np.asarray(labels),
            "timestamp": pd.to_datetime(timestamps),
        })

        # Drop outliers. Take an explicit copy so the column assignment below
        # does not act on a slice (SettingWithCopyWarning).
        df = df.loc[df["topic"] != -1].copy()
        if df.empty:
            return self._message_figure(
                "No non-outlier documents to display", title, width, height
            )

        # Count documents per (month, topic)
        df["period"] = df["timestamp"].dt.to_period("M").dt.to_timestamp()
        counts = df.groupby(["period", "topic"]).size().unstack(fill_value=0)

        # Normalize each month to percentages
        counts = counts.div(counts.sum(axis=1), axis=0) * 100

        fig = go.Figure()

        colors = px.colors.qualitative.Set2
        names = self._topic_name_map(topics)

        for i, topic_id in enumerate(counts.columns):
            fig.add_trace(go.Scatter(
                x=counts.index,
                y=counts[topic_id],
                name=names.get(topic_id, f"Topic {topic_id}"),
                mode="lines",
                stackgroup="one",
                line=dict(color=colors[i % len(colors)]),
            ))

        fig.update_layout(
            title=dict(text=title, font=dict(size=16)),
            width=width,
            height=height,
            xaxis=dict(title="Time"),
            yaxis=dict(title="Topic Share (%)", range=[0, 100]),
            template="plotly_white",
        )

        return fig
@@ -1,18 +0,0 @@
1
- tritopic/__init__.py,sha256=KNtwfPUJANQtRLf-PUkoglz4u8IkHQYC8IQYPnEBf7I,1232
2
- tritopic/core/__init__.py,sha256=vCIaW9iG-to_9Z7J4EpMFXQJnlyBuRUsDImo7rZGprk,476
3
- tritopic/core/clustering.py,sha256=MFaBb_-6qgBdfX3iz8d0etpaSNgkVcsbSksfvqzN84I,10281
4
- tritopic/core/embeddings.py,sha256=F0ceeD0IfpIQUVByglFqR1IahTm9EKBS2VSpRoOMv4s,6320
5
- tritopic/core/graph_builder.py,sha256=PCRC-W_RYuiMOFfzKojGTFkU8ZTyieTXp6fy_LdF5zQ,16568
6
- tritopic/core/keywords.py,sha256=yHMa5QF0tzD2tgj6GBXvRy9yyN3lgO-kiWNn8uQ0HG4,10861
7
- tritopic/core/model.py,sha256=c9Fh72kNh1-fnQzxMKI6inc4VrWUw_66nIMovHrXMtg,28645
8
- tritopic/labeling/__init__.py,sha256=cKLYRklMA4yl_7RS6KiHLrAFqyXaqyMPCVH_Wck1mmc,125
9
- tritopic/labeling/llm_labeler.py,sha256=ZQkA0v-BWEChEGe5jkTdnC4pqjHt1UOCq9bY84zqsg4,8588
10
- tritopic/utils/__init__.py,sha256=R4PPNkUxEBtwzsu52kRKfqHUUayhdcObL9mvIRBLhg8,238
11
- tritopic/utils/metrics.py,sha256=Wr_L7_1TS1Eow485t-so2cLZ5ef6xrVAfVWJXZOcOiA,6938
12
- tritopic/visualization/__init__.py,sha256=bgNdgO5c_4fXv78mPH2X-trx5hWMNiXwVGSnrMzZyUk,136
13
- tritopic/visualization/plotter.py,sha256=cqfg8JbwUHnHDxW0FBuEVhDtJ1OIZ12bLLPVoN-aZHk,15491
14
- tritopic-0.1.0.dist-info/licenses/LICENSE,sha256=jX__n4_wnFJ18weIv0wXDsXnDzsTvMUp94gDDuZTFKE,1068
15
- tritopic-0.1.0.dist-info/METADATA,sha256=V8oOMWVIXoKWqV2DMWWwttTPBU6oDbGM3HFuYrgcMEo,12118
16
- tritopic-0.1.0.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
17
- tritopic-0.1.0.dist-info/top_level.txt,sha256=9PASbqQyi0-wa7E2Hl3Z0u1ae7MwLcfgFliFE1ioFBA,9
18
- tritopic-0.1.0.dist-info/RECORD,,
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2025 Roman Egger
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.