topic-stability 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. topic_stability-0.1.0/.github/workflows/ci.yml +23 -0
  2. topic_stability-0.1.0/.github/workflows/publish.yml +32 -0
  3. topic_stability-0.1.0/.gitignore +29 -0
  4. topic_stability-0.1.0/LICENSE +21 -0
  5. topic_stability-0.1.0/PKG-INFO +215 -0
  6. topic_stability-0.1.0/README.md +160 -0
  7. topic_stability-0.1.0/embed_documents.py +65 -0
  8. topic_stability-0.1.0/estimate_from_states.py +188 -0
  9. topic_stability-0.1.0/example.tsv +3 -0
  10. topic_stability-0.1.0/import_data.sh +20 -0
  11. topic_stability-0.1.0/project_umap.py +69 -0
  12. topic_stability-0.1.0/pyproject.toml +47 -0
  13. topic_stability-0.1.0/stoplist.txt +172 -0
  14. topic_stability-0.1.0/tests/test_stability.py +123 -0
  15. topic_stability-0.1.0/topic_stability/__init__.py +7 -0
  16. topic_stability-0.1.0/topic_stability/_align.py +14 -0
  17. topic_stability-0.1.0/topic_stability/_metrics.py +35 -0
  18. topic_stability-0.1.0/topic_stability/analysis.py +174 -0
  19. topic_stability-0.1.0/topic_stability/cli/__init__.py +0 -0
  20. topic_stability-0.1.0/topic_stability/cli/embed.py +30 -0
  21. topic_stability-0.1.0/topic_stability/cli/estimate.py +49 -0
  22. topic_stability-0.1.0/topic_stability/cli/project.py +45 -0
  23. topic_stability-0.1.0/topic_stability/cli/visualize.py +41 -0
  24. topic_stability-0.1.0/topic_stability/embeddings.py +87 -0
  25. topic_stability-0.1.0/topic_stability/integrations/__init__.py +0 -0
  26. topic_stability-0.1.0/topic_stability/integrations/bertopic.py +163 -0
  27. topic_stability-0.1.0/topic_stability/io.py +153 -0
  28. topic_stability-0.1.0/topic_stability/run.py +103 -0
  29. topic_stability-0.1.0/topic_stability/visualization.py +128 -0
  30. topic_stability-0.1.0/train_model.sh +37 -0
  31. topic_stability-0.1.0/uv.lock +2182 -0
  32. topic_stability-0.1.0/visualize_topics.py +276 -0
@@ -0,0 +1,23 @@
1
+ name: CI
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ pull_request:
7
+ branches: [main]
8
+
9
+ jobs:
10
+ test:
11
+ runs-on: ubuntu-latest
12
+ strategy:
13
+ matrix:
14
+ python-version: ["3.10", "3.11", "3.12"]
15
+ steps:
16
+ - uses: actions/checkout@v4
17
+
18
+ - uses: astral-sh/setup-uv@v5
19
+ with:
20
+ python-version: ${{ matrix.python-version }}
21
+
22
+ - name: Run tests
23
+ run: uv run --with pytest --with scikit-learn pytest tests/ -v
@@ -0,0 +1,32 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ jobs:
9
+ build:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - uses: actions/checkout@v4
13
+ - uses: astral-sh/setup-uv@v5
14
+ - run: uv run --with pytest --with scikit-learn pytest tests/ -v
15
+ - run: uv build
16
+ - uses: actions/upload-artifact@v4
17
+ with:
18
+ name: dist
19
+ path: dist/
20
+
21
+ publish:
22
+ needs: build
23
+ runs-on: ubuntu-latest
24
+ environment: pypi
25
+ permissions:
26
+ id-token: write # required for PyPI trusted publishing
27
+ steps:
28
+ - uses: actions/download-artifact@v4
29
+ with:
30
+ name: dist
31
+ path: dist/
32
+ - uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,29 @@
1
+ # macOS
2
+ .DS_Store
3
+
4
+ # Python
5
+ __pycache__/
6
+ *.py[cod]
7
+ *.egg-info/
8
+ dist/
9
+ .venv/
10
+
11
+ # Input data (not example.tsv)
12
+ cs_cl.tsv
13
+
14
+ # Generated corpus binaries
15
+ *.mallet
16
+ *.corp
17
+
18
+ # Generated embeddings
19
+ embeddings.npy
20
+ embeddings.npy.ids
21
+
22
+ # Generated UMAP projection
23
+ umap_2d.csv
24
+
25
+ # All model output directories
26
+ model_*/
27
+
28
+ # Training logs
29
+ *.log
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 mimno
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,215 @@
1
+ Metadata-Version: 2.4
2
+ Name: topic-stability
3
+ Version: 0.1.0
4
+ Summary: Measure and visualize topic model stability across multiple runs
5
+ Project-URL: Homepage, https://github.com/mimno/TopicStability
6
+ Project-URL: Repository, https://github.com/mimno/TopicStability
7
+ Project-URL: Issues, https://github.com/mimno/TopicStability/issues
8
+ Author-email: David Mimno <mimno@cornell.edu>
9
+ License: MIT License
10
+
11
+ Copyright (c) 2026 mimno
12
+
13
+ Permission is hereby granted, free of charge, to any person obtaining a copy
14
+ of this software and associated documentation files (the "Software"), to deal
15
+ in the Software without restriction, including without limitation the rights
16
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17
+ copies of the Software, and to permit persons to whom the Software is
18
+ furnished to do so, subject to the following conditions:
19
+
20
+ The above copyright notice and this permission notice shall be included in all
21
+ copies or substantial portions of the Software.
22
+
23
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29
+ SOFTWARE.
30
+ License-File: LICENSE
31
+ Keywords: BERTopic,LDA,NLP,stability,text analysis,topic modeling
32
+ Classifier: Development Status :: 3 - Alpha
33
+ Classifier: Intended Audience :: Science/Research
34
+ Classifier: License :: OSI Approved :: MIT License
35
+ Classifier: Programming Language :: Python :: 3
36
+ Classifier: Programming Language :: Python :: 3.10
37
+ Classifier: Programming Language :: Python :: 3.11
38
+ Classifier: Programming Language :: Python :: 3.12
39
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
40
+ Classifier: Topic :: Text Processing :: Linguistic
41
+ Requires-Python: >=3.10
42
+ Requires-Dist: numpy
43
+ Requires-Dist: scipy
44
+ Provides-Extra: all
45
+ Requires-Dist: matplotlib; extra == 'all'
46
+ Requires-Dist: sentence-transformers; extra == 'all'
47
+ Requires-Dist: umap-learn; extra == 'all'
48
+ Provides-Extra: embed
49
+ Requires-Dist: sentence-transformers; extra == 'embed'
50
+ Provides-Extra: umap
51
+ Requires-Dist: umap-learn; extra == 'umap'
52
+ Provides-Extra: viz
53
+ Requires-Dist: matplotlib; extra == 'viz'
54
+ Description-Content-Type: text/markdown
55
+
56
+ # topic-stability
57
+
58
+ Measure and visualize the stability of topic models across multiple runs.
59
+
60
+ Topic models are stochastic: two runs with identical settings typically produce differently labelled topics in a different order. **topic-stability** aligns topics across runs using sentence-embedding centroids and scores each topic by how consistently the same documents are assigned to it (measured with Jensen-Shannon divergence). The result is a per-topic stability score in [0, 1] and a small-multiples UMAP visualization with stability annotated on each panel.
61
+
62
+ Works with any topic model that produces a document-topic matrix — LDA, NMF, BERTopic, and more.
63
+
64
+ ## Install
65
+
66
+ ```bash
67
+ pip install topic-stability # core (numpy + scipy only)
68
+ pip install "topic-stability[embed]" # + sentence-transformers
69
+ pip install "topic-stability[umap,viz]" # + UMAP + matplotlib
70
+ pip install "topic-stability[all]" # everything
71
+ ```
72
+
73
+ ## Quick start
74
+
75
+ ### sklearn (LDA, NMF, …)
76
+
77
+ ```python
78
+ from sklearn.decomposition import LatentDirichletAllocation
79
+ from topic_stability import TopicRun, StabilityAnalysis, DocumentEmbedder
80
+
81
+ # Embed documents once and cache to disk
82
+ embedder = DocumentEmbedder(cache_path="embeddings.npy")
83
+ embeddings = embedder.embed(texts, ids=doc_ids)
84
+
85
+ # Train several runs
86
+ runs = [TopicRun.from_sklearn(
87
+ LatentDirichletAllocation(n_components=20).fit(X), X
88
+ ) for _ in range(5)]
89
+
90
+ analysis = StabilityAnalysis(runs, embeddings=embeddings)
91
+ analysis.align()
92
+
93
+ print(analysis.topic_stability()) # array of shape (n_topics,)
94
+ print(analysis.overall_stability()) # scalar
95
+
96
+ analysis.visualize("topics.png") # requires topic-stability[umap,viz]
97
+ ```
98
+
99
+ ### Pass precomputed embeddings (e.g. from BERTopic)
100
+
101
+ ```python
102
+ from topic_stability.integrations.bertopic import from_bertopic
103
+
104
+ run, embeddings = from_bertopic(model, embeddings=precomputed_embeddings)
105
+ ```
106
+
107
+ See [BERTopic notes](#bertopic) below for important differences.
108
+
109
+ ### From files (Mallet / CSV pipeline)
110
+
111
+ ```python
112
+ runs = [
113
+ TopicRun.from_csv(
114
+ f"model_42_run{i}/doc_topic_avg.csv",
115
+ word_topic_path=f"model_42_run{i}/word_topic_avg.csv",
116
+ )
117
+ for i in range(1, 6)
118
+ ]
119
+
120
+ embedder = DocumentEmbedder(cache_path="embeddings.npy")
121
+ embeddings, _ = embedder.load()
122
+
123
+ analysis = StabilityAnalysis(runs, embeddings=embeddings)
124
+ analysis.align()
125
+ analysis.visualize("topics.png", umap_coords=precomputed_umap)
126
+ ```
127
+
128
+ ## API
129
+
130
+ ### `TopicRun`
131
+
132
+ One run's topic distributions.
133
+
134
+ | Constructor | Use when |
135
+ |---|---|
136
+ | `TopicRun.from_matrix(doc_topic, *, doc_ids, word_topic, vocab)` | You have numpy arrays |
137
+ | `TopicRun.from_sklearn(model, X, *, doc_ids, vocab)` | sklearn `transform()` interface |
138
+ | `TopicRun.from_csv(doc_topic_path, *, word_topic_path)` | CSV files from the CLI pipeline |
139
+ | `TopicRun.from_mallet_states(model_dir, *, iterations, tsv_path)` | Mallet `.gz` state files |
140
+
141
+ ### `DocumentEmbedder`
142
+
143
+ ```python
144
+ embedder = DocumentEmbedder(model="all-MiniLM-L6-v2", cache_path="embeddings.npy")
145
+ embeddings = embedder.embed(texts, ids=doc_ids) # computes and caches
146
+ embeddings, ids = embedder.load() # load from cache
147
+ ```
148
+
149
+ Pass the returned array directly to `StabilityAnalysis(runs, embeddings=embeddings)`.
150
+
151
+ ### `StabilityAnalysis`
152
+
153
+ ```python
154
+ analysis = StabilityAnalysis(runs, embeddings, *, doc_ids=None)
155
+ analysis.align(reference=0) # must call before scoring
156
+ analysis.topic_stability() # ndarray (n_topics,) in [0, 1]
157
+ analysis.overall_stability() # float
158
+ analysis.umap_projection(**kwargs) # ndarray (n_docs, 2)
159
+ analysis.visualize(path, *, reference_run=0, umap_coords=None)
160
+ ```
161
+
162
+ **Alignment** uses cosine similarity of per-topic embedding centroids
163
+ (`centroid_k = Σ_d θ_dk · e_d`, normalised) matched with the Hungarian
164
+ algorithm. No shared vocabulary is required, so runs from different model
165
+ types can be compared.
166
+
167
+ **Stability score** for topic k: mean pairwise `1 − JS(p, q)` where p and
168
+ q are the normalised document-profile columns `θ[:,k]` (treated as a
169
+ distribution over documents) from each pair of aligned runs.
170
+
171
+ ## BERTopic
172
+
173
+ ```python
174
+ from topic_stability.integrations.bertopic import from_bertopic
175
+
176
+ run, embeddings = from_bertopic(model, docs=None, *, embeddings=None, doc_ids=None)
177
+ ```
178
+
179
+ Returns `(TopicRun, embeddings_array)`.
180
+
181
+ **Key differences from LDA/NMF:**
182
+
183
+ - BERTopic assigns each document to exactly one cluster (hard assignment). The
184
+ `doc_topic` matrix is binary: 1 for the assigned topic, 0 elsewhere.
185
+ Documents that HDBSCAN assigns to topic `-1` (outliers) get an all-zero row.
186
+ - `model.probabilities_` contains HDBSCAN soft-membership scores, not
187
+ topic-weight distributions. We do not use them — they are a geometric
188
+ property of the embedding space, not comparable to LDA posterior weights.
189
+ - Word representations come from c-TF-IDF scores, not a generative word
190
+ distribution. Cross-model word-based comparison is not meaningful.
191
+ - Stability scores measure whether the *same documents* cluster together
192
+ across runs, not whether the same word distributions recur.
193
+
194
+ ## CLI pipeline (Mallet / RustMallet)
195
+
196
+ The package includes CLI wrappers for a full file-based workflow:
197
+
198
+ ```bash
199
+ # 1. Embed documents
200
+ topic-stability-embed corpus.tsv embeddings.npy
201
+
202
+ # 2. Project to 2D
203
+ topic-stability-project embeddings.npy umap_2d.csv
204
+
205
+ # 3. Estimate distributions from Mallet states
206
+ topic-stability-estimate model_42_run1/ 42 corpus.tsv
207
+
208
+ # 4. Visualize a single run
209
+ topic-stability-visualize umap_2d.csv model_42_run1/doc_topic_avg.csv \
210
+ model_42_run1/word_topic_avg.csv topics.png
211
+ ```
212
+
213
+ ## License
214
+
215
+ MIT
@@ -0,0 +1,160 @@
1
+ # topic-stability
2
+
3
+ Measure and visualize the stability of topic models across multiple runs.
4
+
5
+ Topic models are stochastic: two runs with identical settings typically produce differently labelled topics in a different order. **topic-stability** aligns topics across runs using sentence-embedding centroids and scores each topic by how consistently the same documents are assigned to it (measured with Jensen-Shannon divergence). The result is a per-topic stability score in [0, 1] and a small-multiples UMAP visualization with stability annotated on each panel.
6
+
7
+ Works with any topic model that produces a document-topic matrix — LDA, NMF, BERTopic, and more.
8
+
9
+ ## Install
10
+
11
+ ```bash
12
+ pip install topic-stability # core (numpy + scipy only)
13
+ pip install "topic-stability[embed]" # + sentence-transformers
14
+ pip install "topic-stability[umap,viz]" # + UMAP + matplotlib
15
+ pip install "topic-stability[all]" # everything
16
+ ```
17
+
18
+ ## Quick start
19
+
20
+ ### sklearn (LDA, NMF, …)
21
+
22
+ ```python
23
+ from sklearn.decomposition import LatentDirichletAllocation
24
+ from topic_stability import TopicRun, StabilityAnalysis, DocumentEmbedder
25
+
26
+ # Embed documents once and cache to disk
27
+ embedder = DocumentEmbedder(cache_path="embeddings.npy")
28
+ embeddings = embedder.embed(texts, ids=doc_ids)
29
+
30
+ # Train several runs
31
+ runs = [TopicRun.from_sklearn(
32
+ LatentDirichletAllocation(n_components=20).fit(X), X
33
+ ) for _ in range(5)]
34
+
35
+ analysis = StabilityAnalysis(runs, embeddings=embeddings)
36
+ analysis.align()
37
+
38
+ print(analysis.topic_stability()) # array of shape (n_topics,)
39
+ print(analysis.overall_stability()) # scalar
40
+
41
+ analysis.visualize("topics.png") # requires topic-stability[umap,viz]
42
+ ```
43
+
44
+ ### Pass precomputed embeddings (e.g. from BERTopic)
45
+
46
+ ```python
47
+ from topic_stability.integrations.bertopic import from_bertopic
48
+
49
+ run, embeddings = from_bertopic(model, embeddings=precomputed_embeddings)
50
+ ```
51
+
52
+ See [BERTopic notes](#bertopic) below for important differences.
53
+
54
+ ### From files (Mallet / CSV pipeline)
55
+
56
+ ```python
57
+ runs = [
58
+ TopicRun.from_csv(
59
+ f"model_42_run{i}/doc_topic_avg.csv",
60
+ word_topic_path=f"model_42_run{i}/word_topic_avg.csv",
61
+ )
62
+ for i in range(1, 6)
63
+ ]
64
+
65
+ embedder = DocumentEmbedder(cache_path="embeddings.npy")
66
+ embeddings, _ = embedder.load()
67
+
68
+ analysis = StabilityAnalysis(runs, embeddings=embeddings)
69
+ analysis.align()
70
+ analysis.visualize("topics.png", umap_coords=precomputed_umap)
71
+ ```
72
+
73
+ ## API
74
+
75
+ ### `TopicRun`
76
+
77
+ One run's topic distributions.
78
+
79
+ | Constructor | Use when |
80
+ |---|---|
81
+ | `TopicRun.from_matrix(doc_topic, *, doc_ids, word_topic, vocab)` | You have numpy arrays |
82
+ | `TopicRun.from_sklearn(model, X, *, doc_ids, vocab)` | sklearn `transform()` interface |
83
+ | `TopicRun.from_csv(doc_topic_path, *, word_topic_path)` | CSV files from the CLI pipeline |
84
+ | `TopicRun.from_mallet_states(model_dir, *, iterations, tsv_path)` | Mallet `.gz` state files |
85
+
86
+ ### `DocumentEmbedder`
87
+
88
+ ```python
89
+ embedder = DocumentEmbedder(model="all-MiniLM-L6-v2", cache_path="embeddings.npy")
90
+ embeddings = embedder.embed(texts, ids=doc_ids) # computes and caches
91
+ embeddings, ids = embedder.load() # load from cache
92
+ ```
93
+
94
+ Pass the returned array directly to `StabilityAnalysis(runs, embeddings=embeddings)`.
95
+
96
+ ### `StabilityAnalysis`
97
+
98
+ ```python
99
+ analysis = StabilityAnalysis(runs, embeddings, *, doc_ids=None)
100
+ analysis.align(reference=0) # must call before scoring
101
+ analysis.topic_stability() # ndarray (n_topics,) in [0, 1]
102
+ analysis.overall_stability() # float
103
+ analysis.umap_projection(**kwargs) # ndarray (n_docs, 2)
104
+ analysis.visualize(path, *, reference_run=0, umap_coords=None)
105
+ ```
106
+
107
+ **Alignment** uses cosine similarity of per-topic embedding centroids
108
+ (`centroid_k = Σ_d θ_dk · e_d`, normalised) matched with the Hungarian
109
+ algorithm. No shared vocabulary is required, so runs from different model
110
+ types can be compared.
111
+
112
+ **Stability score** for topic k: mean pairwise `1 − JS(p, q)` where p and
113
+ q are the normalised document-profile columns `θ[:,k]` (treated as a
114
+ distribution over documents) from each pair of aligned runs.
115
+
116
+ ## BERTopic
117
+
118
+ ```python
119
+ from topic_stability.integrations.bertopic import from_bertopic
120
+
121
+ run, embeddings = from_bertopic(model, docs=None, *, embeddings=None, doc_ids=None)
122
+ ```
123
+
124
+ Returns `(TopicRun, embeddings_array)`.
125
+
126
+ **Key differences from LDA/NMF:**
127
+
128
+ - BERTopic assigns each document to exactly one cluster (hard assignment). The
129
+ `doc_topic` matrix is binary: 1 for the assigned topic, 0 elsewhere.
130
+ Documents that HDBSCAN assigns to topic `-1` (outliers) get an all-zero row.
131
+ - `model.probabilities_` contains HDBSCAN soft-membership scores, not
132
+ topic-weight distributions. We do not use them — they are a geometric
133
+ property of the embedding space, not comparable to LDA posterior weights.
134
+ - Word representations come from c-TF-IDF scores, not a generative word
135
+ distribution. Cross-model word-based comparison is not meaningful.
136
+ - Stability scores measure whether the *same documents* cluster together
137
+ across runs, not whether the same word distributions recur.
138
+
139
+ ## CLI pipeline (Mallet / RustMallet)
140
+
141
+ The package includes CLI wrappers for a full file-based workflow:
142
+
143
+ ```bash
144
+ # 1. Embed documents
145
+ topic-stability-embed corpus.tsv embeddings.npy
146
+
147
+ # 2. Project to 2D
148
+ topic-stability-project embeddings.npy umap_2d.csv
149
+
150
+ # 3. Estimate distributions from Mallet states
151
+ topic-stability-estimate model_42_run1/ 42 corpus.tsv
152
+
153
+ # 4. Visualize a single run
154
+ topic-stability-visualize umap_2d.csv model_42_run1/doc_topic_avg.csv \
155
+ model_42_run1/word_topic_avg.csv topics.png
156
+ ```
157
+
158
+ ## License
159
+
160
+ MIT
@@ -0,0 +1,65 @@
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = ["sentence-transformers", "numpy"]
4
+ # ///
5
+ """
6
+ Generate sentence-transformer embeddings for the documents in a TSV corpus (e.g. cs_cl.tsv).
7
+
8
+ Usage: uv run embed_documents.py <tsv_path> <output_npy> [model_name]
9
+
10
+ Reads document IDs and text from a three-column TSV (id, date, text),
11
+ generates one embedding per document, and saves the matrix to a .npy file.
12
+ A companion <output_npy>.ids file records the document ID for each row.
13
+
14
+ Default model: all-MiniLM-L6-v2 (fast, 384-dim, good semantic similarity).
15
+ """
16
+
17
+ import sys
18
+ import os
19
+ import numpy as np
20
+ from sentence_transformers import SentenceTransformer
21
+
22
DEFAULT_MODEL = "all-MiniLM-L6-v2"


def read_tsv(path):
    """Read document IDs and texts from a three-column TSV file.

    Each non-blank line is expected to hold ``id<TAB>date<TAB>text``. Only the
    first two tabs are treated as column separators, so tabs that occur inside
    the text column are kept instead of silently truncating the document.

    Args:
        path: Path to a UTF-8 encoded TSV file.

    Returns:
        A ``(ids, texts)`` tuple of parallel lists of strings, in file order.

    Raises:
        ValueError: If a non-blank line has fewer than three columns.
    """
    ids, texts = [], []
    with open(path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            line = line.rstrip("\n")
            if not line.strip():
                # Skip blank lines (e.g. a trailing empty line) rather than
                # crashing with an IndexError on parts[2].
                continue
            # maxsplit=2: the id and date never contain tabs, but free text may.
            parts = line.split("\t", 2)
            if len(parts) < 3:
                raise ValueError(
                    f"{path}:{line_number}: expected 3 tab-separated columns "
                    f"(id, date, text), got {len(parts)}"
                )
            ids.append(parts[0])
            texts.append(parts[2])
    return ids, texts
33
+
34
+
35
def main():
    """Command-line entry point: embed every document in a TSV file.

    Usage: embed_documents.py <tsv_path> <output_npy> [model_name]

    Reads documents with ``read_tsv``, encodes them with a sentence-transformers
    model (``DEFAULT_MODEL`` unless ``model_name`` is given), saves the matrix
    to ``<output_npy>`` and writes one document ID per line, in row order, to
    ``<output_npy>.ids``. Exits with status 1 when arguments are missing.
    """
    if len(sys.argv) < 3:
        print(
            "Usage: embed_documents.py <tsv_path> <output_npy> [model_name]",
            file=sys.stderr,
        )
        sys.exit(1)

    tsv_path = sys.argv[1]
    output_npy = sys.argv[2]
    model_name = sys.argv[3] if len(sys.argv) > 3 else DEFAULT_MODEL

    # numpy.save appends ".npy" when the path lacks it; normalise up front so
    # the ".ids" companion file sits next to the file that is actually written.
    if not output_npy.endswith(".npy"):
        output_npy += ".npy"

    print(f"Reading {tsv_path} ...")
    ids, texts = read_tsv(tsv_path)
    print(f" {len(texts)} documents")

    print(f"Loading model {model_name} ...")
    model = SentenceTransformer(model_name)

    print("Encoding ...")
    embeddings = model.encode(texts, show_progress_bar=True, convert_to_numpy=True)
    print(f" Embedding shape: {embeddings.shape}")

    np.save(output_npy, embeddings)
    ids_path = output_npy + ".ids"
    # Explicit UTF-8 matches how the TSV was read. One "<id>\n" per row (rather
    # than "\n".join + "\n") avoids writing a spurious blank line when empty.
    with open(ids_path, "w", encoding="utf-8") as f:
        f.writelines(f"{doc_id}\n" for doc_id in ids)

    print(f"Saved embeddings to {output_npy}")
    print(f"Saved document IDs to {ids_path}")


if __name__ == "__main__":
    main()