tritopic-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tritopic/__init__.py +46 -0
- tritopic/core/__init__.py +17 -0
- tritopic/core/clustering.py +331 -0
- tritopic/core/embeddings.py +222 -0
- tritopic/core/graph_builder.py +493 -0
- tritopic/core/keywords.py +337 -0
- tritopic/core/model.py +810 -0
- tritopic/labeling/__init__.py +5 -0
- tritopic/labeling/llm_labeler.py +279 -0
- tritopic/utils/__init__.py +13 -0
- tritopic/utils/metrics.py +254 -0
- tritopic/visualization/__init__.py +5 -0
- tritopic/visualization/plotter.py +523 -0
- tritopic-0.1.0.dist-info/METADATA +400 -0
- tritopic-0.1.0.dist-info/RECORD +18 -0
- tritopic-0.1.0.dist-info/WHEEL +5 -0
- tritopic-0.1.0.dist-info/licenses/LICENSE +21 -0
- tritopic-0.1.0.dist-info/top_level.txt +1 -0
tritopic/__init__.py
ADDED
@@ -0,0 +1,46 @@
"""
TriTopic: Tri-Modal Graph Topic Modeling with Iterative Refinement
===================================================================

A state-of-the-art topic modeling library that combines:
- Semantic embeddings (Sentence-BERT, Instructor, BGE)
- Lexical similarity (BM25)
- Metadata context (optional)

With advanced techniques:
- Leiden clustering with consensus
- Mutual kNN + SNN graph construction
- Iterative refinement loop
- LLM-powered topic labeling

Basic usage:
-----------
>>> from tritopic import TriTopic
>>> model = TriTopic()
>>> topics = model.fit_transform(documents)
>>> model.visualize()

Author: Roman Egger
License: MIT
"""

__version__ = "0.1.0"
__author__ = "Roman Egger"

from tritopic.core.model import TriTopic
from tritopic.core.graph_builder import GraphBuilder
from tritopic.core.clustering import ConsensusLeiden
from tritopic.core.embeddings import EmbeddingEngine
from tritopic.core.keywords import KeywordExtractor
from tritopic.labeling.llm_labeler import LLMLabeler
from tritopic.visualization.plotter import TopicVisualizer

__all__ = [
    "TriTopic",
    "GraphBuilder",
    "ConsensusLeiden",
    "EmbeddingEngine",
    "KeywordExtractor",
    "LLMLabeler",
    "TopicVisualizer",
]
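The quick-start from the module docstring, written out as a runnable script. This is a minimal sketch, not part of the package; it uses only the calls shown in the docstring, and a handful of documents is for illustration only — graph-based clustering needs a realistically sized corpus.

# usage_sketch.py — illustrative only, not shipped with tritopic
from tritopic import TriTopic

documents = [
    "The hotel room was spotless and the staff were friendly.",
    "Great beaches and excellent local food on this trip.",
    "The workshop covered graph construction and embeddings.",
    # ... in practice, hundreds of documents or more
]

model = TriTopic()                        # default tri-modal configuration
topics = model.fit_transform(documents)   # per-document topic assignments
model.visualize()                         # topic plot, per the docstring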
tritopic/core/__init__.py
ADDED
@@ -0,0 +1,17 @@
"""Core components for TriTopic."""

from tritopic.core.model import TriTopic, TriTopicConfig, TopicInfo
from tritopic.core.graph_builder import GraphBuilder
from tritopic.core.clustering import ConsensusLeiden
from tritopic.core.embeddings import EmbeddingEngine
from tritopic.core.keywords import KeywordExtractor

__all__ = [
    "TriTopic",
    "TriTopicConfig",
    "TopicInfo",
    "GraphBuilder",
    "ConsensusLeiden",
    "EmbeddingEngine",
    "KeywordExtractor",
]
tritopic/core/clustering.py
ADDED
@@ -0,0 +1,331 @@
"""
Consensus Leiden Clustering
===========================

Robust community detection with:
- Leiden algorithm (better than Louvain)
- Consensus clustering for stability
- Resolution parameter tuning
"""

from __future__ import annotations

from typing import Any

import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster
from sklearn.metrics import adjusted_rand_score
from collections import Counter


class ConsensusLeiden:
    """
    Leiden clustering with consensus for stability.

    Runs multiple Leiden clusterings with different seeds and combines
    results using consensus clustering. This dramatically improves
    reproducibility and reduces sensitivity to random initialization.

    Parameters
    ----------
    resolution : float
        Resolution parameter for Leiden. Higher = more clusters. Default: 1.0
    n_runs : int
        Number of consensus runs. Default: 10
    random_state : int
        Random seed for reproducibility. Default: 42
    consensus_threshold : float
        Minimum agreement ratio for consensus. Default: 0.5
    """

    def __init__(
        self,
        resolution: float = 1.0,
        n_runs: int = 10,
        random_state: int = 42,
        consensus_threshold: float = 0.5,
    ):
        self.resolution = resolution
        self.n_runs = n_runs
        self.random_state = random_state
        self.consensus_threshold = consensus_threshold

        self.labels_: np.ndarray | None = None
        self.stability_score_: float | None = None
        self._all_partitions: list[np.ndarray] = []

    def fit_predict(
        self,
        graph: "igraph.Graph",
        min_cluster_size: int = 5,
        resolution: float | None = None,
    ) -> np.ndarray:
        """
        Fit Leiden clustering with consensus.

        Parameters
        ----------
        graph : igraph.Graph
            Input graph with edge weights.
        min_cluster_size : int
            Minimum cluster size. Smaller clusters become outliers.
        resolution : float, optional
            Override default resolution.

        Returns
        -------
        labels : np.ndarray
            Cluster assignments. -1 for outliers.
        """
        import leidenalg as la

        res = resolution or self.resolution
        n_nodes = graph.vcount()

        # Run multiple Leiden clusterings
        self._all_partitions = []

        for run in range(self.n_runs):
            seed = self.random_state + run

            # Run Leiden
            partition = la.find_partition(
                graph,
                la.RBConfigurationVertexPartition,
                weights="weight",
                resolution_parameter=res,
                seed=seed,
            )

            # Convert to labels
            labels = np.array(partition.membership)
            self._all_partitions.append(labels)

        # Compute consensus
        self.labels_ = self._compute_consensus(self._all_partitions)

        # Handle small clusters as outliers
        self.labels_ = self._handle_small_clusters(self.labels_, min_cluster_size)

        # Compute stability score
        self.stability_score_ = self._compute_stability()

        return self.labels_

    def _compute_consensus(self, partitions: list[np.ndarray]) -> np.ndarray:
        """
        Compute consensus partition from multiple runs.

        Uses co-occurrence matrix and hierarchical clustering.
        """
        n_nodes = len(partitions[0])
        n_runs = len(partitions)

        # Build co-occurrence matrix
        # co_occur[i,j] = fraction of runs where i and j are in same cluster
        co_occur = np.zeros((n_nodes, n_nodes))

        for partition in partitions:
            for cluster_id in np.unique(partition):
                members = np.where(partition == cluster_id)[0]
                for i in members:
                    for j in members:
                        co_occur[i, j] += 1

        co_occur /= n_runs

        # Convert co-occurrence to distance
        distance = 1 - co_occur

        # Hierarchical clustering on distance matrix
        # Use condensed form for linkage
        condensed = []
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                condensed.append(distance[i, j])
        condensed = np.array(condensed)

        # Average linkage tends to work well for consensus
        Z = linkage(condensed, method="average")

        # Cut at threshold that matches approximate number of clusters
        # from the most frequent partition
        n_clusters_list = [len(np.unique(p)) for p in partitions]
        median_n_clusters = int(np.median(n_clusters_list))

        # Find optimal cut
        best_labels = None
        best_score = -1

        for n_clusters in range(max(2, median_n_clusters - 2), median_n_clusters + 3):
            try:
                labels = fcluster(Z, n_clusters, criterion="maxclust")
                labels = labels - 1  # 0-indexed

                # Score by average ARI with original partitions
                ari_scores = [adjusted_rand_score(labels, p) for p in partitions]
                avg_ari = np.mean(ari_scores)

                if avg_ari > best_score:
                    best_score = avg_ari
                    best_labels = labels
            except Exception:
                continue

        if best_labels is None:
            # Fallback to most common partition
            best_labels = partitions[0]

        return best_labels

    def _handle_small_clusters(
        self,
        labels: np.ndarray,
        min_size: int,
    ) -> np.ndarray:
        """Mark small clusters as outliers (-1)."""
        result = labels.copy()

        for cluster_id in np.unique(labels):
            if cluster_id == -1:
                continue

            size = np.sum(labels == cluster_id)
            if size < min_size:
                result[labels == cluster_id] = -1

        # Relabel to consecutive integers
        unique_labels = sorted([l for l in np.unique(result) if l != -1])
        label_map = {old: new for new, old in enumerate(unique_labels)}
        label_map[-1] = -1

        result = np.array([label_map[l] for l in result])

        return result

    def _compute_stability(self) -> float:
        """Compute stability score as average pairwise ARI."""
        if len(self._all_partitions) < 2:
            return 1.0

        ari_scores = []
        for i in range(len(self._all_partitions)):
            for j in range(i + 1, len(self._all_partitions)):
                ari = adjusted_rand_score(
                    self._all_partitions[i],
                    self._all_partitions[j]
                )
                ari_scores.append(ari)

        return float(np.mean(ari_scores))

    def find_optimal_resolution(
        self,
        graph: "igraph.Graph",
        resolution_range: tuple[float, float] = (0.1, 2.0),
        n_steps: int = 10,
        target_n_topics: int | None = None,
    ) -> float:
        """
        Find optimal resolution parameter.

        Parameters
        ----------
        graph : igraph.Graph
            Input graph.
        resolution_range : tuple
            Range of resolutions to search.
        n_steps : int
            Number of resolutions to try.
        target_n_topics : int, optional
            If provided, find resolution closest to this number of topics.

        Returns
        -------
        optimal_resolution : float
            Best resolution parameter.
        """
        import leidenalg as la

        resolutions = np.linspace(resolution_range[0], resolution_range[1], n_steps)
        results = []

        for res in resolutions:
            partition = la.find_partition(
                graph,
                la.RBConfigurationVertexPartition,
                weights="weight",
                resolution_parameter=res,
                seed=self.random_state,
            )

            n_clusters = len(set(partition.membership))
            modularity = partition.modularity

            results.append({
                "resolution": res,
                "n_clusters": n_clusters,
                "modularity": modularity,
            })

        if target_n_topics is not None:
            # Find closest to target
            best = min(results, key=lambda x: abs(x["n_clusters"] - target_n_topics))
        else:
            # Find highest modularity
            best = max(results, key=lambda x: x["modularity"])

        return best["resolution"]


class HDBSCANClusterer:
    """
    Alternative clustering using HDBSCAN.

    Useful for datasets with varying density or many outliers.
    """

    def __init__(
        self,
        min_cluster_size: int = 10,
        min_samples: int = 5,
        metric: str = "euclidean",
    ):
        self.min_cluster_size = min_cluster_size
        self.min_samples = min_samples
        self.metric = metric

        self.labels_: np.ndarray | None = None
        self.probabilities_: np.ndarray | None = None

    def fit_predict(
        self,
        embeddings: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """
        Fit HDBSCAN clustering.

        Parameters
        ----------
        embeddings : np.ndarray
            Document embeddings (optionally reduced with UMAP first).

        Returns
        -------
        labels : np.ndarray
            Cluster assignments. -1 for outliers.
        """
        import hdbscan

        clusterer = hdbscan.HDBSCAN(
            min_cluster_size=self.min_cluster_size,
            min_samples=self.min_samples,
            metric=self.metric,
            **kwargs,
        )

        self.labels_ = clusterer.fit_predict(embeddings)
        self.probabilities_ = clusterer.probabilities_

        return self.labels_
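A minimal usage sketch (not part of the package) showing ConsensusLeiden on a small weighted igraph graph. It assumes python-igraph and leidenalg are installed and uses igraph's built-in Zachary karate-club graph purely as toy input; the calls and parameters match the class's own docstrings.

# consensus_leiden_sketch.py — illustrative only, not shipped with tritopic
import igraph as ig

from tritopic.core.clustering import ConsensusLeiden

graph = ig.Graph.Famous("Zachary")            # 34-node toy graph
graph.es["weight"] = [1.0] * graph.ecount()   # fit_predict reads the "weight" edge attribute

clusterer = ConsensusLeiden(resolution=1.0, n_runs=10, random_state=42)
labels = clusterer.fit_predict(graph, min_cluster_size=3)   # -1 marks outliers

print("labels:", labels)
print("stability (mean pairwise ARI):", clusterer.stability_score_)

# Optionally pick a resolution that yields roughly a target number of clusters:
best_res = clusterer.find_optimal_resolution(graph, target_n_topics=4)
print("suggested resolution:", best_res)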
tritopic/core/embeddings.py
ADDED
@@ -0,0 +1,222 @@
"""
Embedding Engine for TriTopic
=============================

Handles document embedding with support for multiple models:
- Sentence-BERT models (default)
- Instructor models (task-specific)
- BGE models (multilingual)
"""

from __future__ import annotations

from typing import Any, Literal

import numpy as np
from tqdm import tqdm


class EmbeddingEngine:
    """
    Generate document embeddings using transformer models.

    Supports various embedding models optimized for different use cases.

    Parameters
    ----------
    model_name : str
        Name of the sentence-transformers model. Popular choices:
        - "all-MiniLM-L6-v2": Fast, good quality (default)
        - "all-mpnet-base-v2": Higher quality, slower
        - "BAAI/bge-base-en-v1.5": State-of-the-art
        - "BAAI/bge-m3": Multilingual
        - "hkunlp/instructor-large": Task-specific (use with instruction)
    batch_size : int
        Batch size for encoding. Default: 32
    device : str or None
        Device to use ("cuda", "cpu", or None for auto).
    show_progress : bool
        Show progress bar. Default: True
    """

    def __init__(
        self,
        model_name: str = "all-MiniLM-L6-v2",
        batch_size: int = 32,
        device: str | None = None,
        show_progress: bool = True,
    ):
        self.model_name = model_name
        self.batch_size = batch_size
        self.device = device
        self.show_progress = show_progress

        self._model = None
        self._is_instructor = "instructor" in model_name.lower()

    def _load_model(self):
        """Lazy load the embedding model."""
        if self._model is not None:
            return

        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(
            self.model_name,
            device=self.device,
        )

    def encode(
        self,
        documents: list[str],
        instruction: str | None = None,
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode documents to embeddings.

        Parameters
        ----------
        documents : list[str]
            List of document texts.
        instruction : str, optional
            Instruction for Instructor models (e.g., "Represent the topic of this document:").
        normalize : bool
            Whether to L2-normalize embeddings. Default: True

        Returns
        -------
        embeddings : np.ndarray
            Document embeddings of shape (n_docs, embedding_dim).
        """
        self._load_model()

        # Handle instructor models
        if self._is_instructor and instruction:
            documents = [[instruction, doc] for doc in documents]

        # Encode in batches
        embeddings = self._model.encode(
            documents,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress,
            normalize_embeddings=normalize,
            convert_to_numpy=True,
        )

        return embeddings

    def encode_with_pooling(
        self,
        documents: list[str],
        pooling: Literal["mean", "max", "cls"] = "mean",
    ) -> np.ndarray:
        """
        Encode with custom pooling strategy.

        Parameters
        ----------
        documents : list[str]
            Document texts.
        pooling : str
            Pooling strategy: "mean", "max", or "cls".

        Returns
        -------
        embeddings : np.ndarray
            Pooled embeddings.
        """
        # For now, use default pooling from model
        # Custom pooling would require access to token-level embeddings
        return self.encode(documents)

    @property
    def embedding_dim(self) -> int:
        """Get embedding dimension."""
        self._load_model()
        return self._model.get_sentence_embedding_dimension()

    def similarity(
        self,
        embeddings1: np.ndarray,
        embeddings2: np.ndarray | None = None,
    ) -> np.ndarray:
        """
        Compute cosine similarity between embeddings.

        Parameters
        ----------
        embeddings1 : np.ndarray
            First set of embeddings.
        embeddings2 : np.ndarray, optional
            Second set. If None, compute pairwise similarity of embeddings1.

        Returns
        -------
        similarity : np.ndarray
            Similarity matrix.
        """
        from sklearn.metrics.pairwise import cosine_similarity

        if embeddings2 is None:
            return cosine_similarity(embeddings1)
        return cosine_similarity(embeddings1, embeddings2)


class MultiModelEmbedding:
    """
    Combine embeddings from multiple models.

    Useful for ensemble approaches where different models capture
    different aspects of document semantics.
    """

    def __init__(
        self,
        model_names: list[str],
        weights: list[float] | None = None,
        batch_size: int = 32,
    ):
        self.model_names = model_names
        self.weights = weights or [1.0 / len(model_names)] * len(model_names)
        self.batch_size = batch_size

        self._engines = [
            EmbeddingEngine(name, batch_size=batch_size)
            for name in model_names
        ]

    def encode(
        self,
        documents: list[str],
        normalize: bool = True,
    ) -> np.ndarray:
        """
        Encode using all models and combine.

        Parameters
        ----------
        documents : list[str]
            Document texts.
        normalize : bool
            Normalize final embeddings.

        Returns
        -------
        embeddings : np.ndarray
            Combined embeddings (concatenated).
        """
        all_embeddings = []

        for engine, weight in zip(self._engines, self.weights):
            emb = engine.encode(documents, normalize=True)
            all_embeddings.append(emb * weight)

        # Concatenate
        combined = np.hstack(all_embeddings)

        if normalize:
            norms = np.linalg.norm(combined, axis=1, keepdims=True)
            combined = combined / (norms + 1e-10)

        return combined
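A minimal usage sketch (not part of the package) for EmbeddingEngine and MultiModelEmbedding. It assumes sentence-transformers is installed and that the named models can be downloaded; all calls and parameters come from the class docstrings above.

# embedding_sketch.py — illustrative only, not shipped with tritopic
from tritopic.core.embeddings import EmbeddingEngine, MultiModelEmbedding

docs = [
    "Graph-based topic models cluster documents on a similarity graph.",
    "Leiden community detection finds well-connected communities.",
    "The beach resort offered snorkeling and kayak tours.",
]

engine = EmbeddingEngine(model_name="all-MiniLM-L6-v2", show_progress=False)
emb = engine.encode(docs)          # shape: (3, engine.embedding_dim), L2-normalized
sims = engine.similarity(emb)      # 3x3 cosine similarity matrix

# Ensemble of two models: weighted, concatenated, then re-normalized
multi = MultiModelEmbedding(["all-MiniLM-L6-v2", "all-mpnet-base-v2"])
combined = multi.encode(docs)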