topic-stability 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,7 @@
1
+ """topic_stability — evaluate and visualize topic model stability."""
2
+
3
+ from .run import TopicRun
4
+ from .embeddings import DocumentEmbedder
5
+ from .analysis import StabilityAnalysis
6
+
7
+ __all__ = ["TopicRun", "DocumentEmbedder", "StabilityAnalysis"]
@@ -0,0 +1,14 @@
1
+ import numpy as np
2
+ from scipy.optimize import linear_sum_assignment
3
+
4
+
5
+ def _normalise_rows(m):
6
+ norms = np.linalg.norm(m, axis=1, keepdims=True)
7
+ return m / np.where(norms > 0, norms, 1.0)
8
+
9
+
10
def align_to_reference(ref_centroids, cur_centroids):
    """Match current topics to reference topics by centroid cosine similarity.

    Both centroid matrices are row-normalised, so their product holds cosine
    similarities (rows: reference topics, columns: current topics).  The
    Hungarian algorithm then selects the one-to-one matching with the highest
    total similarity.

    Returns perm where perm[k] = index in cur that best matches reference topic k.
    """
    ref_unit = _normalise_rows(ref_centroids)
    cur_unit = _normalise_rows(cur_centroids)
    cosine = ref_unit @ cur_unit.T
    # linear_sum_assignment minimises cost, so negate to maximise similarity.
    assignment = linear_sum_assignment(-cosine)
    return assignment[1]
@@ -0,0 +1,35 @@
1
+ import numpy as np
2
+
3
+
4
+ def _normalise(v):
5
+ s = v.sum()
6
+ return v / s if s > 0 else np.full_like(v, 1.0 / len(v))
7
+
8
+
9
def js_divergence(p, q):
    """Jensen-Shannon divergence in [0, 1] (log base 2).

    ``p`` and ``q`` are arrays of equal shape.  The result is the mean of the
    two KL divergences from each input to their element-wise midpoint.
    """
    midpoint = (p + q) / 2.0

    def _kl_to_midpoint(dist):
        # Entries where dist is zero contribute nothing and are skipped, which
        # also avoids evaluating log2(0).
        support = dist > 0
        kept = dist[support]
        return np.sum(kept * np.log2(kept / midpoint[support]))

    return (_kl_to_midpoint(p) + _kl_to_midpoint(q)) / 2.0
18
+
19
+
20
def pairwise_stability(profiles):
    """Mean pairwise (1 - JS) for a list of normalised column vectors.

    Every unordered pair of profiles is compared once.  With fewer than two
    profiles there are no pairs, and 1.0 is returned.
    """
    count = len(profiles)
    agreements = [
        1.0 - js_divergence(profiles[a], profiles[b])
        for a in range(count)
        for b in range(a + 1, count)
    ]
    if not agreements:
        return 1.0
    return float(np.mean(agreements))
27
+
28
+
29
def doc_profiles(run, permutation):
    """
    Normalised document-profile columns, reordered so column k corresponds to
    reference topic k.  Returns ndarray of shape (n_topics, n_docs).

    ``permutation`` selects and orders the columns of ``run.doc_topic``; each
    selected column is then normalised independently with ``_normalise``.
    """
    aligned = run.doc_topic[:, permutation]
    per_topic = aligned.T.copy()  # (n_topics, n_docs)
    normalised = [_normalise(row) for row in per_topic]
    return np.array(normalised)
@@ -0,0 +1,174 @@
1
+ """StabilityAnalysis: multi-run topic alignment and stability scoring."""
2
+ from __future__ import annotations
3
+
4
+ import numpy as np
5
+
6
+ from ._align import align_to_reference
7
+ from ._metrics import doc_profiles, pairwise_stability
8
+
9
+
10
class StabilityAnalysis:
    """Compare multiple topic model runs over the same document set.

    Parameters
    ----------
    runs: list of TopicRun objects, all with the same n_docs and n_topics
    embeddings: ndarray (n_docs, embedding_dim), or a DocumentEmbedder whose
        cache to load.  Required — used for both alignment and UMAP layout.
    doc_ids: explicit ID order to reconcile runs whose rows differ in order.
        If None, run rows are assumed to already be aligned.

    Workflow
    --------
    analysis = StabilityAnalysis(runs, embeddings)
    analysis.align()                      # match topics across runs
    scores = analysis.topic_stability()   # per-topic stability in [0, 1]
    analysis.visualize("out.png")         # small-multiples UMAP plot
    """

    def __init__(self, runs, embeddings, *, doc_ids=None):
        self.runs = list(runs)
        self.embeddings = _resolve_embeddings(embeddings)
        # Filled in by align(): one permutation per run, in run order.
        self._permutations: list[np.ndarray] | None = None
        # Filled in lazily by umap_projection().
        self._umap_coords: np.ndarray | None = None

        if doc_ids is not None:
            self.runs = [_reorder_run(run, doc_ids) for run in self.runs]

    # ── alignment ────────────────────────────────────────────────────────────

    def _topic_centroids(self, run) -> np.ndarray:
        """Weighted embedding centroid for each topic, L2-normalised.

        centroid_k = sum_d(theta_{dk} * e_d) / |...|
        Returns ndarray (n_topics, dim).  A centroid of zero length is left
        as zeros instead of being divided by zero.
        """
        weights = run.doc_topic                      # (n_docs, n_topics)
        centroids = weights.T @ self.embeddings      # (n_topics, dim)
        norms = np.linalg.norm(centroids, axis=1, keepdims=True)
        return centroids / np.where(norms > 0, norms, 1.0)

    def align(self, reference: int = 0) -> None:
        """Align all runs to a reference run via embedding-centroid cosine similarity.

        For each run, computes the weighted embedding centroid of each topic
        and finds the best bijective mapping to reference topics via the
        Hungarian algorithm.  The reference run itself gets the identity
        permutation.

        This is model-agnostic: no shared vocabulary is required.

        reference: index into ``runs``; negative indices count from the end.
        Raises IndexError if the index is out of range.
        """
        # Resolve negative indices up front; otherwise ``i == reference`` below
        # would never match and the reference run would be re-aligned to
        # itself instead of receiving the identity permutation.
        reference = range(len(self.runs))[reference]

        K = self.runs[reference].n_topics
        ref_centroids = self._topic_centroids(self.runs[reference])
        perms = []
        for i, run in enumerate(self.runs):
            if i == reference:
                perms.append(np.arange(K))
            else:
                cur_centroids = self._topic_centroids(run)
                perms.append(align_to_reference(ref_centroids, cur_centroids))
        self._permutations = perms

    # ── stability metrics ─────────────────────────────────────────────────────

    def topic_stability(self) -> np.ndarray:
        """Per-topic stability as mean pairwise (1 - JS divergence) in [0, 1].

        For each topic k (in reference labeling), collects the normalised
        document-profile column from every run — treating theta[:,k] as a
        distribution over documents — then averages pairwise (1 - JS) scores.

        1.0 = perfectly stable across runs; 0.0 = maximally unstable.

        Raises RuntimeError if align() has not been called.
        """
        if self._permutations is None:
            raise RuntimeError("Call align() before topic_stability()")

        K = self.runs[0].n_topics
        all_profiles = [
            doc_profiles(run, perm)
            for run, perm in zip(self.runs, self._permutations)
        ]  # each: (K, n_docs)

        return np.array([
            pairwise_stability([profiles[k] for profiles in all_profiles])
            for k in range(K)
        ])

    def overall_stability(self) -> float:
        """Mean stability across all topics."""
        return float(self.topic_stability().mean())

    @staticmethod
    def _scores_in_run_labelling(stability, permutation) -> np.ndarray:
        """Re-index reference-labelled scores by a run's own topic numbers.

        ``permutation[k]`` is the run's topic matched to reference topic k, so
        the score of the run's topic ``permutation[k]`` is ``stability[k]``.
        If the permutation does not have one entry per score, the scores are
        returned unchanged.
        """
        scores = np.asarray(stability, dtype=float)
        perm = np.asarray(permutation, dtype=int)
        if perm.shape != scores.shape:
            return scores
        relabelled = np.empty_like(scores)
        relabelled[perm] = scores
        return relabelled

    # ── UMAP and visualization ────────────────────────────────────────────────

    def umap_projection(self, **umap_kwargs) -> np.ndarray:
        """Project embeddings to 2D with UMAP.  Result is cached on the object.

        Once a projection has been computed it is returned as-is on later
        calls, regardless of the keyword arguments passed.

        Requires umap-learn: pip install topic-stability[umap]
        """
        if self._umap_coords is not None:
            return self._umap_coords
        try:
            import umap as umap_lib
        except ImportError as e:
            raise ImportError(
                "umap-learn is required for projection. "
                "Install with: pip install topic-stability[umap]"
            ) from e
        params = dict(n_neighbors=15, min_dist=0.1, metric="cosine",
                      n_components=2, random_state=42)
        params.update(umap_kwargs)
        self._umap_coords = umap_lib.UMAP(**params).fit_transform(self.embeddings)
        return self._umap_coords

    def visualize(
        self,
        output_path: str,
        *,
        reference_run: int = 0,
        umap_coords: np.ndarray | None = None,
    ) -> None:
        """Produce a small-multiples UMAP PNG for the reference run.

        Topics are laid out in a grid whose positions reflect where their
        high-weight documents cluster in the UMAP projection.  Stability
        scores are shown in each panel title when align() has been called.

        reference_run: index of the run to draw.
        umap_coords: pass a precomputed (n_docs, 2) array to skip UMAP.
        """
        from .visualization import plot_topic_grid

        if umap_coords is None:
            umap_coords = self.umap_projection()

        stability = None
        if self._permutations is not None:
            # topic_stability() is indexed by reference-topic label, while the
            # drawn run uses its own topic numbering; translate between them.
            # NOTE(review): assumes plot_topic_grid indexes stability_scores by
            # the drawn run's topic index — confirm against its implementation.
            stability = self._scores_in_run_labelling(
                self.topic_stability(), self._permutations[reference_run]
            )

        plot_topic_grid(
            umap_coords,
            self.runs[reference_run],
            output_path,
            stability_scores=stability,
        )
148
+
149
+
150
+ # ── helpers ───────────────────────────────────────────────────────────────────
151
+
152
+
153
+ def _resolve_embeddings(embeddings) -> np.ndarray:
154
+ if isinstance(embeddings, np.ndarray):
155
+ return embeddings
156
+ if hasattr(embeddings, "load"):
157
+ arr, _ = embeddings.load()
158
+ return arr
159
+ return np.asarray(embeddings, dtype=float)
160
+
161
+
162
def _reorder_run(run, target_ids):
    """Return a new TopicRun whose rows follow the order of ``target_ids``.

    Only the document rows (and doc_ids) change; word_topic and vocab are
    passed through untouched.

    Raises ValueError if the run carries no doc_ids, and KeyError if a target
    ID is not present in the run.
    """
    from .run import TopicRun

    if run.doc_ids is None:
        raise ValueError("Run has no doc_ids; cannot reorder by target_ids")

    # Map each document ID to its row; a repeated ID keeps its last row.
    position_of = {}
    for position, doc_id in enumerate(run.doc_ids):
        position_of[doc_id] = position
    row_order = [position_of[doc_id] for doc_id in target_ids]

    return TopicRun(
        doc_topic=run.doc_topic[row_order],
        doc_ids=list(target_ids),
        word_topic=run.word_topic,
        vocab=run.vocab,
    )
File without changes
@@ -0,0 +1,30 @@
1
+ """CLI: generate and cache document embeddings."""
2
+ import argparse
3
+ import sys
4
+
5
+
6
def main():
    """Entry point: embed the text column of a TSV corpus and cache the result.

    Reads a three-column TSV (id, date, text), encodes the text column with
    DocumentEmbedder, and relies on the embedder to write the .npy cache and
    the parallel .ids file.  Blank lines are skipped; a line with fewer than
    three columns aborts with a usage error naming the offending line.
    """
    parser = argparse.ArgumentParser(
        description="Generate sentence-transformer embeddings for a TSV corpus."
    )
    parser.add_argument("tsv_path", help="Three-column TSV: id, date, text")
    parser.add_argument("output_npy", help="Output .npy path")
    parser.add_argument("--model", default="all-MiniLM-L6-v2",
                        help="sentence-transformers model name")
    args = parser.parse_args()

    from ..embeddings import DocumentEmbedder

    ids, texts = [], []
    with open(args.tsv_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                # A trailing or stray empty line is not a document.
                continue
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 3:
                # Report the location instead of failing with a bare IndexError.
                parser.error(
                    f"{args.tsv_path}:{line_no}: expected 3 tab-separated "
                    f"columns (id, date, text), found {len(parts)}"
                )
            ids.append(parts[0])
            texts.append(parts[2])

    print(f"Read {len(texts)} documents from {args.tsv_path}")
    embedder = DocumentEmbedder(model=args.model, cache_path=args.output_npy)
    embeddings = embedder.embed(texts, ids=ids)
    print(f"Saved {embeddings.shape} embeddings to {args.output_npy}")
    print(f"Saved IDs to {args.output_npy}.ids")
@@ -0,0 +1,49 @@
1
+ """CLI: estimate distributions from Mallet state files."""
2
+ import argparse
3
+ import os
4
+
5
+
6
def main():
    """Entry point: average Mallet state files and write two CSVs.

    Writes ``doc_topic_avg.csv`` (id, topic_0, ...) and ``word_topic_avg.csv``
    (word, topic_0, ...) into ``model_dir``.

    Raises ValueError if the number of documents in the state files differs
    from the number of IDs read from the TSV, or if the number of topics in
    the state files differs from ``num_topics``.
    """
    parser = argparse.ArgumentParser(
        description="Estimate doc-topic and word-topic distributions from Mallet state files."
    )
    parser.add_argument("model_dir")
    parser.add_argument("num_topics", type=int)
    parser.add_argument("tsv_path")
    args = parser.parse_args()

    import csv
    from ..io import read_mallet_states

    doc_topic, word_topic, vocab, doc_ids = read_mallet_states(
        args.model_dir, tsv_path=args.tsv_path
    )

    if doc_ids and doc_topic.shape[0] != len(doc_ids):
        raise ValueError(
            f"State files have {doc_topic.shape[0]} docs "
            f"but {args.tsv_path} has {len(doc_ids)}"
        )

    # The header is built from num_topics while the rows come from the matrix;
    # a mismatch would write a CSV whose header and rows disagree in width.
    if doc_topic.shape[1] != args.num_topics:
        raise ValueError(
            f"State files have {doc_topic.shape[1]} topics "
            f"but num_topics={args.num_topics} was requested"
        )

    topic_cols = [f"topic_{k}" for k in range(args.num_topics)]

    # doc_topic_avg.csv
    dt_path = os.path.join(args.model_dir, "doc_topic_avg.csv")
    with open(dt_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id"] + topic_cols)
        # Fall back to row numbers when no IDs were recovered from the TSV.
        id_list = doc_ids if doc_ids else [str(i) for i in range(len(doc_topic))]
        for id_, row in zip(id_list, doc_topic):
            writer.writerow([id_] + [f"{v:.8f}" for v in row])

    # word_topic_avg.csv
    wt_path = os.path.join(args.model_dir, "word_topic_avg.csv")
    with open(wt_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["word"] + topic_cols)
        for word, row in zip(vocab, word_topic):
            writer.writerow([word] + [f"{v:.8f}" for v in row])

    print(f"Wrote {dt_path}")
    print(f"Wrote {wt_path}")
@@ -0,0 +1,45 @@
1
+ """CLI: project embeddings to 2D with UMAP."""
2
+ import argparse
3
+
4
+
5
def main():
    """Entry point: load cached embeddings, run UMAP, write an ``id,x,y`` CSV.

    IDs come from the embedder's .ids file when present; otherwise row numbers
    are used.  Coordinates are written with six decimal places.
    """
    parser = argparse.ArgumentParser(
        description="Project document embeddings to 2D using UMAP."
    )
    for positional in ("embeddings_npy", "output_csv"):
        parser.add_argument(positional)
    parser.add_argument("--neighbors", type=int, default=15)
    parser.add_argument("--min-dist", type=float, default=0.1)
    parser.add_argument("--metric", default="cosine")
    args = parser.parse_args()

    try:
        import umap as umap_lib
    except ImportError as e:
        raise ImportError(
            "umap-learn is required for projection. "
            "Install with: pip install topic-stability[umap]"
        ) from e

    from ..embeddings import DocumentEmbedder

    embeddings, ids = DocumentEmbedder(cache_path=args.embeddings_npy).load()
    print(f"Loaded {embeddings.shape[0]} embeddings ({embeddings.shape[1]} dim)")

    umap_settings = {
        "n_neighbors": args.neighbors,
        "min_dist": args.min_dist,
        "metric": args.metric,
        "n_components": 2,
        "random_state": 42,
    }
    coords = umap_lib.UMAP(**umap_settings).fit_transform(embeddings)

    if ids:
        id_list = ids
    else:
        id_list = [str(i) for i in range(len(coords))]

    with open(args.output_csv, "w", encoding="utf-8") as f:
        f.write("id,x,y\n")
        f.writelines(
            f"{id_},{x:.6f},{y:.6f}\n" for id_, (x, y) in zip(id_list, coords)
        )

    print(f"Saved 2D projection to {args.output_csv}")
@@ -0,0 +1,41 @@
1
+ """CLI: produce small-multiples UMAP topic visualization."""
2
+ import argparse
3
+
4
+
5
def main():
    """Entry point: draw the small-multiples topic grid from CSV inputs.

    Reads 2D coordinates (id, x, y) and a doc-topic CSV, reorders the run's
    rows to the coordinate file's ID order when the two differ, and renders
    the grid to ``output_png``.

    Raises KeyError if the coordinate file names an ID the run does not have.
    """
    parser = argparse.ArgumentParser(
        description="Small-multiples UMAP visualization coloured by topic weight."
    )
    parser.add_argument("umap_csv", help="CSV with columns id, x, y")
    parser.add_argument("doc_topic_csv", help="doc-topic CSV (id, topic_0, ...)")
    parser.add_argument("word_topic_csv", help="word-topic CSV (word, topic_0, ...)")
    parser.add_argument("output_png")
    args = parser.parse_args()

    import numpy as np
    from ..run import TopicRun
    from ..visualization import plot_topic_grid

    umap_ids, rows = [], []
    with open(args.umap_csv, encoding="utf-8") as f:
        next(f, None)  # header; tolerate a completely empty file
        for line in f:
            line = line.strip()
            if not line:
                continue  # ignore blank lines
            # x and y are always the last two fields, so split from the right:
            # an ID that itself contains commas stays in one piece.
            doc_id, x, y = line.rsplit(",", 2)
            umap_ids.append(doc_id)
            rows.append([float(x), float(y)])
    umap_coords = np.array(rows)

    run = TopicRun.from_csv(args.doc_topic_csv, word_topic_path=args.word_topic_csv)

    if run.doc_ids and run.doc_ids != umap_ids:
        # Bring the run's rows into the same order as the coordinates.
        index = {id_: i for i, id_ in enumerate(run.doc_ids)}
        order = [index[id_] for id_ in umap_ids]
        run = TopicRun(
            doc_topic=run.doc_topic[order],
            doc_ids=umap_ids,
            word_topic=run.word_topic,
            vocab=run.vocab,
        )

    plot_topic_grid(umap_coords, run, args.output_png)
    print(f"Saved to {args.output_png}")
@@ -0,0 +1,87 @@
1
+ """Sentence-embedding generation with optional disk caching."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+
6
+ import numpy as np
7
+
8
DEFAULT_MODEL = "all-MiniLM-L6-v2"


class DocumentEmbedder:
    """Generate or load sentence embeddings, with optional .npy cache.

    Parameters
    ----------
    model: sentence-transformers model name
    cache_path: if given, load embeddings from <cache_path> on disk if they
        exist; otherwise compute and save them there after encoding.
        A parallel <cache_path>.ids file stores document IDs.

    Usage
    -----
    # Generate (and cache) from raw text:
    embedder = DocumentEmbedder(cache_path="embeddings.npy")
    embeddings = embedder.embed(texts, ids=doc_ids)

    # Load previously cached embeddings:
    embedder = DocumentEmbedder(cache_path="embeddings.npy")
    embeddings, ids = embedder.load()

    # Pass the numpy array directly to StabilityAnalysis:
    analysis = StabilityAnalysis(runs, embeddings=embeddings)
    """

    def __init__(self, model: str = DEFAULT_MODEL, *, cache_path: str | None = None):
        self._model_name = model
        self._model = None  # lazy-loaded
        self.cache_path = cache_path

    def embed(self, texts: list[str], *, ids: list[str] | None = None) -> np.ndarray:
        """Encode texts; return (n_docs, dim) array.

        If cache_path is set and the cache file already exists, the cached
        embeddings are returned without re-encoding (``texts`` and ``ids`` are
        not consulted).  If the cache does not exist, embeddings are computed
        and saved to exactly ``cache_path``; ``ids``, when given, are written
        one per line to ``cache_path + ".ids"``.

        Raises ImportError if sentence-transformers is needed but missing.
        """
        if self.cache_path and os.path.exists(self.cache_path):
            embeddings, _ = self.load()
            return embeddings

        if self._model is None:
            try:
                from sentence_transformers import SentenceTransformer
            except ImportError as e:
                raise ImportError(
                    "sentence-transformers is required for embedding. "
                    "Install with: pip install topic-stability[embed]"
                ) from e
            self._model = SentenceTransformer(self._model_name)

        embeddings = self._model.encode(texts, show_progress_bar=True, convert_to_numpy=True)

        if self.cache_path:
            # Save through an open handle: np.save(<str path>, ...) appends
            # ".npy" to names that lack it, which would put the cache at a
            # different path from the one load() and the check above use.
            with open(self.cache_path, "wb") as f:
                np.save(f, embeddings)
            if ids is not None:
                with open(self.cache_path + ".ids", "w", encoding="utf-8") as f:
                    f.write("\n".join(ids) + "\n")

        return embeddings

    def load(self) -> tuple[np.ndarray, list[str] | None]:
        """Load embeddings and IDs from cache.

        Returns (embeddings, ids).  ids is None if no .ids file was saved.
        Blank lines in the .ids file are ignored.

        Raises ValueError if no cache_path is set and FileNotFoundError if the
        cache file does not exist.
        """
        if not self.cache_path:
            raise ValueError("No cache_path set")
        if not os.path.exists(self.cache_path):
            raise FileNotFoundError(self.cache_path)

        embeddings = np.load(self.cache_path)
        ids_path = self.cache_path + ".ids"
        ids = None
        if os.path.exists(ids_path):
            with open(ids_path, encoding="utf-8") as f:
                ids = [line.rstrip("\n") for line in f if line.strip()]
        return embeddings, ids
File without changes
@@ -0,0 +1,163 @@
1
+ """BERTopic integration for topic_stability.
2
+
3
+ Structural differences from probabilistic topic models (LDA, NMF)
4
+ ------------------------------------------------------------------
5
+ BERTopic and LDA/NMF occupy different positions in the topic-model landscape,
6
+ and the differences matter for how stability scores should be interpreted.
7
+
8
+ Representation
9
+ BERTopic defines topics by clustering in sentence-embedding space; a topic
10
+ is a centroid, not a word distribution. LDA/NMF define topics as
11
+ distributions over a fixed vocabulary learned from token co-occurrence.
12
+
13
+ Assignment
14
+ BERTopic default: hard — each document belongs to exactly one topic.
15
+ LDA/NMF: soft — each document has a probability distribution over all topics.
16
+ Soft BERTopic probabilities require calculate_probabilities=True at fit time.
17
+
18
+ Outliers
19
+ BERTopic assigns documents that don't fit any cluster to topic -1. These
20
+ documents are present in the corpus but absent from every topic. This
21
+ function represents them as an all-zero row in doc_topic.
22
+
23
+ Vocabulary
24
+ BERTopic uses c-TF-IDF to label topics post-hoc; the word-topic matrix is
25
+ not a generative distribution. Word-based alignment (cosine similarity of
26
+ word distributions) is therefore not equivalent across model types. Use
27
+ centroid-based alignment (the default in StabilityAnalysis) when mixing
28
+ BERTopic and LDA/NMF runs.
29
+
30
+ Stability interpretation
31
+ For LDA/NMF, a stable topic means the same semantic region of the word
32
+ distribution recurs across runs. For BERTopic, stability means the same
33
+ cluster of documents recurs — an inherently geometric notion. Comparing
34
+ stability scores between the two model types is not straightforward.
35
+
36
+ When to mix model types
37
+ Cross-model comparison can be informative as a sensitivity check: do runs
38
+ trained with completely different methods agree on which documents belong
39
+ together? But the JS divergence of document profiles measures agreement on
40
+ *which documents* are in a topic, not agreement on *what the topic means*.
41
+ """
42
+ from __future__ import annotations
43
+
44
+ import numpy as np
45
+
46
+
47
def from_bertopic(
    model,
    docs: list[str] | None = None,
    *,
    embeddings=None,
    doc_ids: list[str] | None = None,
) -> tuple:
    """Extract a TopicRun and embeddings from a fitted BERTopic model.

    Parameters
    ----------
    model: fitted BERTopic instance
    docs: original document texts, needed only if embeddings are not
        stored on the model and embeddings= is not passed
    embeddings: precomputed ndarray (n_docs, dim) — use when you passed
        embeddings directly to model.fit_transform() and BERTopic
        did not cache them on model._embeddings
    doc_ids: document identifiers, parallel to docs / embeddings rows

    Returns
    -------
    (TopicRun, embeddings) where embeddings is an ndarray (n_docs, dim)

    Raises
    ------
    ValueError if no embeddings can be obtained (none passed, none cached on
    the model, and no docs to re-encode).

    Notes
    -----
    doc_topic is a hard one-hot encoding of model.topics_.  The soft scores
    in model.probabilities_ are deliberately not used.  Documents assigned
    to topic -1 (outliers) receive an all-zero row.

    word_topic is populated from model.get_topics() as a c-TF-IDF matrix.
    This is a relevance score, not a probability, and is not directly
    comparable to LDA/NMF word-topic distributions.
    """
    from ..run import TopicRun

    embeddings = _get_embeddings(model, docs, embeddings)
    # The second element (the sorted topic labels) is not needed here.
    doc_topic, _ = _get_doc_topic(model)
    word_topic, vocab = _get_word_topic(model)

    run = TopicRun(
        doc_topic=doc_topic,
        doc_ids=doc_ids,
        word_topic=word_topic,
        vocab=vocab,
    )
    return run, embeddings
94
+
95
+
96
+ # ── internals ─────────────────────────────────────────────────────────────────
97
+
98
+
99
+ def _get_embeddings(model, docs, precomputed):
100
+ if precomputed is not None:
101
+ return np.asarray(precomputed, dtype=float)
102
+ if getattr(model, "_embeddings", None) is not None:
103
+ return np.asarray(model._embeddings, dtype=float)
104
+ if docs is None:
105
+ raise ValueError(
106
+ "No embeddings available: pass embeddings= (precomputed ndarray), "
107
+ "docs= (texts to re-encode), or fit with a model that caches _embeddings."
108
+ )
109
+ return np.asarray(model.embedding_model.embed(docs), dtype=float)
110
+
111
+
112
+ def _get_doc_topic(model):
113
+ """Hard (binary) document-topic matrix from BERTopic's cluster assignments.
114
+
115
+ BERTopic assigns each document to exactly one cluster; documents that
116
+ don't fit any cluster are labelled -1 (outliers). We represent this as
117
+ a binary matrix: 1.0 where a document is assigned, 0.0 everywhere else.
118
+
119
+ We deliberately ignore model.probabilities_ here. Those are HDBSCAN soft
120
+ membership scores that measure how strongly a point sits inside its
121
+ cluster — a geometric property of the embedding space, not a probability
122
+ distribution over topics. Using them would produce apparent gradations
123
+ that are not directly comparable to LDA/NMF topic weights. Hard 0/1
124
+ assignment makes BERTopic's discrete nature explicit.
125
+ """
126
+ topics = np.array(model.topics_)
127
+ valid = sorted(set(topics) - {-1})
128
+ K = len(valid)
129
+ label_to_col = {t: i for i, t in enumerate(valid)}
130
+ n_docs = len(topics)
131
+
132
+ doc_topic = np.zeros((n_docs, K))
133
+ for d, t in enumerate(topics):
134
+ if t != -1:
135
+ doc_topic[d, label_to_col[t]] = 1.0
136
+
137
+ return doc_topic, valid
138
+
139
+
140
+ def _get_word_topic(model):
141
+ """Return (vocab, word_topic) from c-TF-IDF topic representations."""
142
+ topics = model.get_topics()
143
+ if not topics:
144
+ return None, None
145
+
146
+ # Collect all words across topics to build a shared vocabulary.
147
+ all_words = []
148
+ for words_scores in topics.values():
149
+ for word, _ in words_scores:
150
+ if word not in all_words:
151
+ all_words.append(word)
152
+ word_index = {w: i for i, w in enumerate(all_words)}
153
+
154
+ valid_ids = sorted(k for k in topics if k != -1)
155
+ K = len(valid_ids)
156
+ W = len(all_words)
157
+ word_topic = np.zeros((W, K))
158
+ for col, tid in enumerate(valid_ids):
159
+ for word, score in topics[tid]:
160
+ if word in word_index:
161
+ word_topic[word_index[word], col] = max(score, 0.0)
162
+
163
+ return word_topic, all_words
topic_stability/io.py ADDED
@@ -0,0 +1,153 @@
1
+ """File I/O helpers for topic distributions."""
2
+ from __future__ import annotations
3
+
4
+ import csv
5
+ import gzip
6
+ import os
7
+
8
+ import numpy as np
9
+
10
+ _MALLET_ITERATIONS = [1000, 1050, 1100, 1150, 1200]
11
+
12
+
13
+ # ── CSV (current pipeline format) ────────────────────────────────────────────
14
+
15
+
16
def read_doc_topic_csv(path) -> tuple[list[str], np.ndarray]:
    """Read a doc-topic CSV (id, topic_0, topic_1, ...).

    The first row is treated as a header.  Empty rows are skipped.

    Returns (doc_ids, doc_topic) where doc_topic has shape (n_docs, n_topics).
    A file with no data rows yields an empty ID list and a (0, n_topics)
    array, with n_topics taken from the header (0 if the file is empty).
    """
    with open(path, encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        header = next(reader, None)  # None for a completely empty file
        n_topics = max(len(header) - 1, 0) if header else 0
        ids, rows = [], []
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing one).
                continue
            ids.append(row[0])
            rows.append([float(v) for v in row[1:]])

    if not rows:
        # np.array([]) would be 1-D; keep the documented 2-D shape.
        return ids, np.empty((0, n_topics))
    return ids, np.array(rows)
29
+
30
+
31
def read_word_topic_csv(path) -> tuple[list[str], np.ndarray]:
    """Read a word-topic CSV (word, topic_0, topic_1, ...).

    Returns (vocab, word_topic) where word_topic has shape (n_words, n_topics).

    The number of topics is one less than the number of header fields.  On
    each data line the last n_topics comma-separated fields are the numeric
    values and everything before them, re-joined with commas, is the word —
    so words containing commas (e.g. "specifically,we") survive intact.
    Lines are split on raw commas: CSV quoting is not interpreted, so any
    quote characters remain part of the word.  Blank lines are skipped.
    """
    vocab, rows = [], []
    with open(path, encoding="utf-8") as f:
        header_fields = f.readline().rstrip("\n").split(",")
        n_topics = len(header_fields) - 1
        for raw in f:
            if raw.strip() == "":
                continue
            fields = raw.rstrip("\n").split(",")
            rows.append([float(v) for v in fields[-n_topics:]])
            vocab.append(",".join(fields[:-n_topics]))
    return vocab, np.array(rows)
52
+
53
+
54
+ # ── Mallet state files ────────────────────────────────────────────────────────
55
+
56
+
57
+ def _read_one_state(path):
58
+ """Parse a single Mallet .gz state file.
59
+
60
+ Returns (doc_topic_counts, topic_word_counts, vocab, alpha, beta).
61
+ """
62
+ alpha = beta = None
63
+ vocab: dict[str, int] = {}
64
+
65
+ with gzip.open(path, "rt", encoding="utf-8") as f:
66
+ for line in f:
67
+ if line.startswith("#alpha"):
68
+ alpha = [float(v) for v in line.strip().split()[2:]]
69
+ elif line.startswith("#beta"):
70
+ beta = float(line.strip().split()[2])
71
+ elif not line.startswith("#"):
72
+ word = line.split()[4]
73
+ if word not in vocab:
74
+ vocab[word] = len(vocab)
75
+
76
+ if alpha is None or beta is None:
77
+ raise ValueError(f"Could not parse alpha/beta from {path}")
78
+
79
+ num_topics = len(alpha)
80
+ W = len(vocab)
81
+ topic_word = np.zeros((num_topics, W), dtype=np.int32)
82
+ doc_topic_dict: dict[int, np.ndarray] = {}
83
+
84
+ with gzip.open(path, "rt", encoding="utf-8") as f:
85
+ for line in f:
86
+ if line.startswith("#"):
87
+ continue
88
+ parts = line.split()
89
+ doc_id = int(parts[0])
90
+ word = parts[4]
91
+ topic = int(parts[5])
92
+ topic_word[topic, vocab[word]] += 1
93
+ if doc_id not in doc_topic_dict:
94
+ doc_topic_dict[doc_id] = np.zeros(num_topics, dtype=np.int32)
95
+ doc_topic_dict[doc_id][topic] += 1
96
+
97
+ num_docs = max(doc_topic_dict) + 1
98
+ doc_topic = np.zeros((num_docs, num_topics), dtype=np.int32)
99
+ for did, counts in doc_topic_dict.items():
100
+ doc_topic[did] = counts
101
+
102
+ return doc_topic, topic_word, list(vocab.keys()), alpha, beta
103
+
104
+
105
+ def _smooth_doc_topic(counts, alpha):
106
+ a = np.array(alpha)
107
+ return (counts + a) / (counts.sum(axis=1, keepdims=True) + a.sum())
108
+
109
+
110
+ def _smooth_word_topic(topic_word, beta):
111
+ W = topic_word.shape[1]
112
+ phi = (topic_word + beta) / (topic_word.sum(axis=1, keepdims=True) + W * beta)
113
+ return phi.T # (n_words, n_topics)
114
+
115
+
116
def read_mallet_states(
    model_dir,
    *,
    iterations: list[int] | None = None,
    tsv_path: str | None = None,
) -> tuple[np.ndarray, np.ndarray, list[str], list[str] | None]:
    """Average doc-topic and word-topic distributions over multiple Mallet states.

    For every iteration ``it`` the file ``state_<it>.gz`` in ``model_dir`` is
    read if it exists (missing files are skipped), smoothed with that state's
    alpha/beta, and the smoothed matrices are averaged element-wise.

    Returns (doc_topic, word_topic, vocab, doc_ids).
    doc_ids is the first tab-separated field of each non-blank line of
    tsv_path if provided, otherwise None.

    Raises FileNotFoundError if none of the state files exist, and ValueError
    if two state files disagree on the vocabulary.
    """
    if iterations is None:
        iterations = _MALLET_ITERATIONS

    doc_ids = None
    if tsv_path is not None:
        with open(tsv_path, encoding="utf-8") as f:
            # Strip the line ending first so a line without tabs does not
            # yield an ID ending in "\n"; blank lines are not documents.
            doc_ids = [
                line.rstrip("\n").split("\t")[0]
                for line in f
                if line.strip()
            ]

    all_dt, all_wt = [], []
    shared_vocab = None

    for it in iterations:
        path = os.path.join(model_dir, f"state_{it}.gz")
        if not os.path.exists(path):
            continue
        dt_counts, tw_counts, vocab, alpha, beta = _read_one_state(path)
        if shared_vocab is None:
            shared_vocab = vocab
        elif vocab != shared_vocab:
            # Averaging is element-wise by row, so rows must mean the same
            # word in every state.
            raise ValueError(
                f"Vocabulary in {path} differs from earlier state files"
            )
        all_dt.append(_smooth_doc_topic(dt_counts, alpha))
        all_wt.append(_smooth_word_topic(tw_counts, beta))

    if not all_dt:
        raise FileNotFoundError(f"No state files found in {model_dir}")

    doc_topic = np.mean(all_dt, axis=0)
    word_topic = np.mean(all_wt, axis=0)
    return doc_topic, word_topic, shared_vocab, doc_ids
topic_stability/run.py ADDED
@@ -0,0 +1,103 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+ from dataclasses import dataclass, field
5
+
6
+
7
@dataclass
class TopicRun:
    """Topic distributions produced by a single model run.

    doc_topic:  ndarray (n_docs, n_topics) — required
    doc_ids:    list of document identifiers, parallel to doc_topic rows
    word_topic: ndarray (n_words, n_topics) — optional; used for word-based inspection
    vocab:      list of words parallel to word_topic rows
    """

    doc_topic: np.ndarray
    doc_ids: list[str] | None = None
    word_topic: np.ndarray | None = None
    vocab: list[str] | None = None

    @property
    def n_docs(self) -> int:
        """Number of rows of doc_topic."""
        rows, _ = self.doc_topic.shape[0], None
        return rows

    @property
    def n_topics(self) -> int:
        """Number of columns of doc_topic."""
        return self.doc_topic.shape[1]

    # ── constructors ─────────────────────────────────────────────────────────

    @classmethod
    def from_matrix(
        cls,
        doc_topic,
        *,
        doc_ids=None,
        word_topic=None,
        vocab=None,
    ) -> "TopicRun":
        """Construct from array-likes, converting them to float ndarrays / lists.

        doc_topic:  array-like (n_docs, n_topics)
        word_topic: array-like (n_words, n_topics), optional
        doc_ids, vocab: optional iterables, copied into lists
        """
        dt_array = np.asarray(doc_topic, dtype=float)
        ids = None if doc_ids is None else list(doc_ids)
        wt_array = None if word_topic is None else np.asarray(word_topic, dtype=float)
        words = None if vocab is None else list(vocab)
        return cls(doc_topic=dt_array, doc_ids=ids, word_topic=wt_array, vocab=words)

    @classmethod
    def from_sklearn(cls, model, X, *, doc_ids=None, vocab=None) -> "TopicRun":
        """Construct from a fitted sklearn-compatible topic model.

        Calls model.transform(X) to obtain the doc-topic matrix.  If the model
        has a ``components_`` attribute (n_topics, n_features), its transpose
        is used as word-topic.

        Compatible with LatentDirichletAllocation, NMF, and any model that
        follows the sklearn transformer interface.
        """
        doc_topic = model.transform(X)
        if hasattr(model, "components_"):
            word_topic = model.components_.T  # → (n_features, n_topics)
        else:
            word_topic = None
        return cls.from_matrix(
            doc_topic, doc_ids=doc_ids, word_topic=word_topic, vocab=vocab
        )

    @classmethod
    def from_csv(cls, doc_topic_path, *, word_topic_path=None) -> "TopicRun":
        """Construct from the CSV files produced by the topic-stability pipeline.

        doc_topic_path:  CSV with columns id, topic_0, topic_1, ...
        word_topic_path: CSV with columns word, topic_0, topic_1, ... (optional)
        """
        from .io import read_doc_topic_csv, read_word_topic_csv

        doc_ids, doc_topic = read_doc_topic_csv(doc_topic_path)
        if word_topic_path is None:
            vocab, word_topic = None, None
        else:
            vocab, word_topic = read_word_topic_csv(word_topic_path)
        return cls(doc_topic=doc_topic, doc_ids=doc_ids, word_topic=word_topic, vocab=vocab)

    @classmethod
    def from_mallet_states(
        cls,
        model_dir,
        *,
        iterations=None,
        tsv_path=None,
    ) -> "TopicRun":
        """Construct by averaging over multiple Mallet Gibbs sampling states.

        model_dir:  directory containing state_<iter>.gz files
        iterations: list of iteration numbers to include (default: 1000–1200)
        tsv_path:   original corpus TSV to recover document IDs (optional)
        """
        from .io import read_mallet_states

        loaded = read_mallet_states(
            model_dir, iterations=iterations, tsv_path=tsv_path
        )
        doc_topic, word_topic, vocab, doc_ids = loaded
        return cls(doc_topic=doc_topic, doc_ids=doc_ids, word_topic=word_topic, vocab=vocab)
@@ -0,0 +1,128 @@
1
+ """Small-multiples UMAP topic visualization."""
2
+ from __future__ import annotations
3
+
4
+ import math
5
+
6
+ import numpy as np
7
+
8
+
9
def _topic_centroids_2d(xy, weights):
    """Weighted-average 2D position per topic. Returns (n_topics, 2).

    xy:      (n_docs, 2) document coordinates.
    weights: (n_docs, n_topics) document-topic weights.

    Each topic's column of ``weights`` is normalised to sum to one and used
    to average the document coordinates. A topic whose column sums to zero
    (no document carries any weight for it) falls back to uniform weights,
    i.e. the plain mean of all document positions, so the result never
    contains NaN.
    """
    weights = np.asarray(weights, dtype=float)
    totals = weights.sum(axis=0, keepdims=True)
    empty = totals == 0

    # Divide by 1 in the empty columns to avoid 0/0; those entries are
    # overwritten with the uniform weight just below.
    safe_totals = np.where(empty, 1.0, totals)
    uniform = 1.0 / max(weights.shape[0], 1)
    w_norm = np.where(empty, uniform, weights / safe_totals)
    return w_norm.T @ xy
13
+
14
+
15
def _assign_grid(centroids, nrows, ncols):
    """Place each topic in its own grid cell via Hungarian matching.

    The bounding box of the centroids is divided into an ``nrows`` x ``ncols``
    lattice of cell centres (row 0 at the top, i.e. at the largest y), and
    topics are matched to cells so that the total Euclidean distance between
    centroids and cell centres is minimal.

    Returns dict {topic_index: (row, col)}.
    """
    from scipy.optimize import linear_sum_assignment

    xs, ys = centroids[:, 0], centroids[:, 1]
    x_lo, x_hi = xs.min(), xs.max()
    y_lo, y_hi = ys.min(), ys.max()

    # Fractional position of every cell centre along each axis.
    col_frac = (np.arange(ncols) + 0.5) / ncols
    row_frac = (np.arange(nrows) + 0.5) / nrows
    col_centres = x_lo + col_frac * (x_hi - x_lo)
    row_centres = y_hi - row_frac * (y_hi - y_lo)

    grid_x, grid_y = np.meshgrid(col_centres, row_centres)
    cells = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    # (n_topics, n_cells) matrix of centroid-to-cell distances.
    offsets = centroids[:, np.newaxis, :] - cells[np.newaxis, :, :]
    distances = np.sqrt((offsets ** 2).sum(axis=2))
    topics, assigned = linear_sum_assignment(distances)

    placement = {}
    for topic, cell in zip(topics, assigned):
        # Cells are numbered row-major, so divmod recovers (row, col).
        placement[topic] = divmod(int(cell), ncols)
    return placement
34
+
35
+
36
def plot_topic_grid(
    umap_coords: np.ndarray,
    run,
    output_path: str,
    *,
    stability_scores: np.ndarray | None = None,
    top_n_words: int = 8,
) -> None:
    """Render a small-multiples grid coloured by per-topic document weight.

    One panel is drawn per topic. Every panel scatters all documents at
    their 2D coordinates, coloured by that topic's weight for each document,
    and the panel's grid cell is chosen to mirror where the topic's weighted
    centroid lies in the projection.

    Parameters
    ----------
    umap_coords:
        2D coordinates, one row per document (column 0 is x, column 1 is y).
    run:
        Object exposing ``doc_topic`` (documents x topics) and ``n_topics``;
        ``word_topic`` and ``vocab`` are read to build panel titles.
    output_path:
        Destination handed to ``Figure.savefig``.
    stability_scores:
        Optional per-topic scores indexed by topic. When given, a
        ``stability <score>`` line is appended to each panel title.
    top_n_words:
        Number of top words per topic shown in the title (first four on one
        line, the rest on a second line). Values <= 0 suppress the words.

    Raises
    ------
    ImportError
        If matplotlib is missing: pip install topic-stability[viz]
    """
    try:
        import matplotlib.pyplot as plt
        import matplotlib.colors as mcolors
    except ImportError as e:
        raise ImportError(
            "matplotlib is required for visualization. "
            "Install with: pip install topic-stability[viz]"
        ) from e

    xy = umap_coords
    weights = run.doc_topic  # (n_docs, n_topics)
    K = run.n_topics

    # Near-square layout with at least K cells.
    ncols = math.ceil(math.sqrt(K))
    nrows = math.ceil(K / ncols)

    centroids_2d = _topic_centroids_2d(xy, weights)
    grid_pos = _assign_grid(centroids_2d, nrows, ncols)

    top_words = _top_words(run, top_n_words)

    cmap = mcolors.LinearSegmentedColormap.from_list(
        "topic_weight", ["#d8d8d8", "#f5a623", "#c0392b"]
    )

    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(ncols * 2.6, nrows * 3.0),
                             facecolor="white")
    # pyplot keeps every figure alive until it is closed explicitly, so the
    # close must run even when drawing or saving raises.
    try:
        # plt.subplots returns a bare Axes or a 1-D array for small grids.
        axes = np.array(axes).reshape(nrows, ncols)

        occupied = set()
        x, y = xy[:, 0], xy[:, 1]

        for k in range(K):
            row, col = grid_pos[k]
            occupied.add((row, col))
            ax = axes[row, col]
            w = weights[:, k]

            # Scale colours to the 95th percentile so a few very heavy
            # documents do not wash out the rest; fall back to the maximum
            # (or 1.0 for an all-zero topic) when that percentile is zero.
            vmax = np.percentile(w, 95)
            if vmax == 0:
                vmax = w.max() or 1.0
            norm_w = np.clip(w / vmax, 0, 1)
            # Draw low-weight points first so high-weight ones end up on top.
            order = np.argsort(norm_w)

            ax.scatter(x[order], y[order], c=norm_w[order], cmap=cmap,
                       vmin=0, vmax=1, s=2, linewidths=0, rasterized=True)

            words = top_words[k] if top_n_words > 0 and top_words else []
            # Collect only the non-empty title lines so that a title without
            # words does not start with a blank line.
            title_lines = []
            if words[:4]:
                title_lines.append(" ".join(words[:4]))
            if words[4:]:
                title_lines.append(" ".join(words[4:]))
            if stability_scores is not None:
                title_lines.append(f"stability {stability_scores[k]:.2f}")
            label = "\n".join(title_lines)

            ax.set_title(label, fontsize=6.5, pad=3, linespacing=1.4)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

        # Hide the grid cells that no topic was assigned to.
        for r in range(nrows):
            for c in range(ncols):
                if (r, c) not in occupied:
                    axes[r, c].set_visible(False)

        fig.tight_layout(pad=0.4)
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
119
+
120
+
121
def _top_words(run, n):
    """Return the ``n`` highest-weighted vocabulary entries for every topic.

    run: object exposing ``word_topic`` (n_words, n_topics), ``vocab``
         (indexable by word row) and ``n_topics``.
    n:   number of words per topic.

    Returns a list with one entry per topic, each a list of words ordered
    from highest to lowest weight. Every topic gets an empty list when the
    run has no word-topic matrix or vocabulary, or when ``n`` is not
    positive.
    """
    # n <= 0 rather than n == 0: a negative n would otherwise slice [:n]
    # and return all but the last |n| words instead of none.
    if run.word_topic is None or run.vocab is None or n <= 0:
        return [[] for _ in range(run.n_topics)]
    wt = run.word_topic  # (n_words, n_topics)
    return [
        [run.vocab[i] for i in np.argsort(wt[:, k])[::-1][:n]]
        for k in range(run.n_topics)
    ]
@@ -0,0 +1,215 @@
1
+ Metadata-Version: 2.4
2
+ Name: topic-stability
3
+ Version: 0.1.0
4
+ Summary: Measure and visualize topic model stability across multiple runs
5
+ Project-URL: Homepage, https://github.com/mimno/TopicStability
6
+ Project-URL: Repository, https://github.com/mimno/TopicStability
7
+ Project-URL: Issues, https://github.com/mimno/TopicStability/issues
8
+ Author-email: David Mimno <mimno@cornell.edu>
9
+ License: MIT License
10
+
11
+ Copyright (c) 2026 mimno
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy
14
+ of this software and associated documentation files (the "Software"), to deal
15
+ in the Software without restriction, including without limitation the rights
16
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17
+ copies of the Software, and to permit persons to whom the Software is
18
+ furnished to do so, subject to the following conditions:
19
+
20
+ The above copyright notice and this permission notice shall be included in all
21
+ copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29
+ SOFTWARE.
30
+ License-File: LICENSE
31
+ Keywords: BERTopic,LDA,NLP,stability,text analysis,topic modeling
32
+ Classifier: Development Status :: 3 - Alpha
33
+ Classifier: Intended Audience :: Science/Research
34
+ Classifier: License :: OSI Approved :: MIT License
35
+ Classifier: Programming Language :: Python :: 3
36
+ Classifier: Programming Language :: Python :: 3.10
37
+ Classifier: Programming Language :: Python :: 3.11
38
+ Classifier: Programming Language :: Python :: 3.12
39
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
40
+ Classifier: Topic :: Text Processing :: Linguistic
41
+ Requires-Python: >=3.10
42
+ Requires-Dist: numpy
43
+ Requires-Dist: scipy
44
+ Provides-Extra: all
45
+ Requires-Dist: matplotlib; extra == 'all'
46
+ Requires-Dist: sentence-transformers; extra == 'all'
47
+ Requires-Dist: umap-learn; extra == 'all'
48
+ Provides-Extra: embed
49
+ Requires-Dist: sentence-transformers; extra == 'embed'
50
+ Provides-Extra: umap
51
+ Requires-Dist: umap-learn; extra == 'umap'
52
+ Provides-Extra: viz
53
+ Requires-Dist: matplotlib; extra == 'viz'
54
+ Description-Content-Type: text/markdown
55
+
56
+ # topic-stability
57
+
58
+ Measure and visualize the stability of topic models across multiple runs.
59
+
60
+ Topic models are stochastic: two runs with the same settings produce differently-labelled topics in a different order. **topic-stability** aligns topics across runs using sentence-embedding centroids and scores each topic by how consistently the same documents are assigned to it (Jensen-Shannon divergence). The result is a per-topic stability score in [0, 1] and a small-multiples UMAP visualization with stability annotated on each panel.
61
+
62
+ Works with any topic model that produces a document-topic matrix — LDA, NMF, BERTopic, and more.
63
+
64
+ ## Install
65
+
66
+ ```bash
67
+ pip install topic-stability # core (numpy + scipy only)
68
+ pip install "topic-stability[embed]" # + sentence-transformers
69
+ pip install "topic-stability[umap,viz]" # + UMAP + matplotlib
70
+ pip install "topic-stability[all]" # everything
71
+ ```
72
+
73
+ ## Quick start
74
+
75
+ ### sklearn (LDA, NMF, …)
76
+
77
+ ```python
78
+ from sklearn.decomposition import LatentDirichletAllocation
79
+ from topic_stability import TopicRun, StabilityAnalysis, DocumentEmbedder
80
+
81
+ # Embed documents once and cache to disk
82
+ embedder = DocumentEmbedder(cache_path="embeddings.npy")
83
+ embeddings = embedder.embed(texts, ids=doc_ids)
84
+
85
+ # Train several runs
86
+ runs = [TopicRun.from_sklearn(
87
+ LatentDirichletAllocation(n_components=20).fit(X), X
88
+ ) for _ in range(5)]
89
+
90
+ analysis = StabilityAnalysis(runs, embeddings=embeddings)
91
+ analysis.align()
92
+
93
+ print(analysis.topic_stability()) # array of shape (n_topics,)
94
+ print(analysis.overall_stability()) # scalar
95
+
96
+ analysis.visualize("topics.png") # requires topic-stability[umap,viz]
97
+ ```
98
+
99
+ ### Pass precomputed embeddings (e.g. from BERTopic)
100
+
101
+ ```python
102
+ from topic_stability.integrations.bertopic import from_bertopic
103
+
104
+ run, embeddings = from_bertopic(model, embeddings=precomputed_embeddings)
105
+ ```
106
+
107
+ See [BERTopic notes](#bertopic) below for important differences.
108
+
109
+ ### From files (Mallet / CSV pipeline)
110
+
111
+ ```python
112
+ runs = [
113
+ TopicRun.from_csv(
114
+ f"model_42_run{i}/doc_topic_avg.csv",
115
+ word_topic_path=f"model_42_run{i}/word_topic_avg.csv",
116
+ )
117
+ for i in range(1, 6)
118
+ ]
119
+
120
+ embedder = DocumentEmbedder(cache_path="embeddings.npy")
121
+ embeddings, _ = embedder.load()
122
+
123
+ analysis = StabilityAnalysis(runs, embeddings=embeddings)
124
+ analysis.align()
125
+ analysis.visualize("topics.png", umap_coords=precomputed_umap)
126
+ ```
127
+
128
+ ## API
129
+
130
+ ### `TopicRun`
131
+
132
+ One run's topic distributions.
133
+
134
+ | Constructor | Use when |
135
+ |---|---|
136
+ | `TopicRun.from_matrix(doc_topic, *, doc_ids, word_topic, vocab)` | You have numpy arrays |
137
+ | `TopicRun.from_sklearn(model, X, *, doc_ids, vocab)` | sklearn `transform()` interface |
138
+ | `TopicRun.from_csv(doc_topic_path, *, word_topic_path)` | CSV files from the CLI pipeline |
139
+ | `TopicRun.from_mallet_states(model_dir, *, iterations, tsv_path)` | Mallet `.gz` state files |
140
+
141
+ ### `DocumentEmbedder`
142
+
143
+ ```python
144
+ embedder = DocumentEmbedder(model="all-MiniLM-L6-v2", cache_path="embeddings.npy")
145
+ embeddings = embedder.embed(texts, ids=doc_ids) # computes and caches
146
+ embeddings, ids = embedder.load() # load from cache
147
+ ```
148
+
149
+ Pass the returned array directly to `StabilityAnalysis(runs, embeddings=embeddings)`.
150
+
151
+ ### `StabilityAnalysis`
152
+
153
+ ```python
154
+ analysis = StabilityAnalysis(runs, embeddings, *, doc_ids=None)
155
+ analysis.align(reference=0) # must call before scoring
156
+ analysis.topic_stability() # ndarray (n_topics,) in [0, 1]
157
+ analysis.overall_stability() # float
158
+ analysis.umap_projection(**kwargs) # ndarray (n_docs, 2)
159
+ analysis.visualize(path, *, reference_run=0, umap_coords=None)
160
+ ```
161
+
162
+ **Alignment** uses cosine similarity of per-topic embedding centroids
163
+ (`centroid_k = Σ_d θ_dk · e_d`, normalised) matched with the Hungarian
164
+ algorithm. No shared vocabulary is required, so runs from different model
165
+ types can be compared.
166
+
167
+ **Stability score** for topic k: mean pairwise `1 − JS(p, q)` where p and
168
+ q are the normalised document-profile columns `θ[:,k]` (treated as a
169
+ distribution over documents) from each pair of aligned runs.
170
+
171
+ ## BERTopic
172
+
173
+ ```python
174
+ from topic_stability.integrations.bertopic import from_bertopic
175
+
176
+ run, embeddings = from_bertopic(model, docs=None, *, embeddings=None, doc_ids=None)
177
+ ```
178
+
179
+ Returns `(TopicRun, embeddings_array)`.
180
+
181
+ **Key differences from LDA/NMF:**
182
+
183
+ - BERTopic assigns each document to exactly one cluster (hard assignment). The
184
+ `doc_topic` matrix is binary: 1 for the assigned topic, 0 elsewhere.
185
+ Documents that HDBSCAN assigns to topic −1 (outliers) get an all-zero row.
186
+ - `model.probabilities_` contains HDBSCAN soft-membership scores, not
187
+ topic-weight distributions. We do not use them — they are a geometric
188
+ property of the embedding space, not comparable to LDA posterior weights.
189
+ - Word representations come from c-TF-IDF scores, not a generative word
190
+ distribution. Cross-model word-based comparison is not meaningful.
191
+ - Stability scores measure whether the *same documents* cluster together
192
+ across runs, not whether the same word distributions recur.
193
+
194
+ ## CLI pipeline (Mallet / RustMallet)
195
+
196
+ The package includes CLI wrappers for a full file-based workflow:
197
+
198
+ ```bash
199
+ # 1. Embed documents
200
+ topic-stability-embed corpus.tsv embeddings.npy
201
+
202
+ # 2. Project to 2D
203
+ topic-stability-project embeddings.npy umap_2d.csv
204
+
205
+ # 3. Estimate distributions from Mallet states
206
+ topic-stability-estimate model_42_run1/ 42 corpus.tsv
207
+
208
+ # 4. Visualize a single run
209
+ topic-stability-visualize umap_2d.csv model_42_run1/doc_topic_avg.csv \
210
+ model_42_run1/word_topic_avg.csv topics.png
211
+ ```
212
+
213
+ ## License
214
+
215
+ MIT
@@ -0,0 +1,20 @@
1
+ topic_stability/__init__.py,sha256=FQVi5oNJXBqvcTs4ozMDldzu9sUQPJMrXd7cTZEmLH4,245
2
+ topic_stability/_align.py,sha256=rRwUHeYe4G6vTGIeznak45Wf_t8l29ZrIpQOpd8ZqbA,472
3
+ topic_stability/_metrics.py,sha256=QWoiz2K2ByKCBM3-D1QIzKNRREfJvpGxptoy2XBGNtA,1027
4
+ topic_stability/analysis.py,sha256=tgTgJuCOa5TRJBJRsg0-ycEhmhgLJlr8VAm483sPrco,7052
5
+ topic_stability/embeddings.py,sha256=3ew_e5fV3FLe2BpCPr3Lr2g5SXs2MCAKXA0dufUvYVU,3171
6
+ topic_stability/io.py,sha256=FrM3BAR4i0p-JApUDKRdEQ8H_8opSLq_2KL0A1uHyAw,5221
7
+ topic_stability/run.py,sha256=5iknMtG6SOObIFOEYQi52JgCMHUgl_ixKHi5wJwZ8_E,3824
8
+ topic_stability/visualization.py,sha256=7a-qKxORrkyoMH2mZxPNNMmXF9sWNu1ChtVcxgWz-QE,4205
9
+ topic_stability/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ topic_stability/cli/embed.py,sha256=C_PeQXxdIFg_ogmEWPyzOIAQXvF4SMYcXtIA03dLYXU,1115
11
+ topic_stability/cli/estimate.py,sha256=xZKHuZ_00Nk0ayuwmR0jgy19ZJl8uM8gWRCx4gVI12Y,1687
12
+ topic_stability/cli/project.py,sha256=p2m8wSZB_L4ILSTq9VESHZL2xeU4ARqZMTTqTMNuqD0,1478
13
+ topic_stability/cli/visualize.py,sha256=Cd_Q9Rblt1TyuXrZR-RezmglbGPesAM1u7SKeVs5j7Y,1453
14
+ topic_stability/integrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ topic_stability/integrations/bertopic.py,sha256=DL_1oDKdS67Y2nFm_LBNAIu1vFwChxVnXr1HoSORLsQ,6457
16
+ topic_stability-0.1.0.dist-info/METADATA,sha256=jH10rLVS_ORnZLeQUgwLuBAqv7ZQDoSo2oeYf09CFAU,8228
17
+ topic_stability-0.1.0.dist-info/WHEEL,sha256=mffPy8wBnZQn2VnJUU5jE99KsxaSfiyMHV9Yt0aLVxs,87
18
+ topic_stability-0.1.0.dist-info/entry_points.txt,sha256=wKp4yuV1LSF0Bq84yuOHKfZRrVPcQ5oEKVenwM-2EM4,256
19
+ topic_stability-0.1.0.dist-info/licenses/LICENSE,sha256=XKuNUDXkpLe-Px7bBHxR1AtgIwSYNw3t_hE6m0Dkgq4,1062
20
+ topic_stability-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.30.1
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,5 @@
1
+ [console_scripts]
2
+ topic-stability-embed = topic_stability.cli.embed:main
3
+ topic-stability-estimate = topic_stability.cli.estimate:main
4
+ topic-stability-project = topic_stability.cli.project:main
5
+ topic-stability-visualize = topic_stability.cli.visualize:main
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 mimno
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.