ebm4subjects 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ebm4subjects/ebm_model.py +15 -6
- ebm4subjects/embedding_generator.py +72 -6
- {ebm4subjects-0.5.3.dist-info → ebm4subjects-0.5.4.dist-info}/METADATA +1 -1
- {ebm4subjects-0.5.3.dist-info → ebm4subjects-0.5.4.dist-info}/RECORD +6 -6
- {ebm4subjects-0.5.3.dist-info → ebm4subjects-0.5.4.dist-info}/WHEEL +0 -0
- {ebm4subjects-0.5.3.dist-info → ebm4subjects-0.5.4.dist-info}/licenses/LICENSE +0 -0
ebm4subjects/ebm_model.py
CHANGED
|
@@ -15,8 +15,9 @@ from ebm4subjects.duckdb_client import Duckdb_client
|
|
|
15
15
|
from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
|
|
16
16
|
from ebm4subjects.embedding_generator import (
|
|
17
17
|
EmbeddingGeneratorHuggingFaceTEI,
|
|
18
|
-
EmbeddingGeneratorOfflineInference,
|
|
19
18
|
EmbeddingGeneratorMock,
|
|
19
|
+
EmbeddingGeneratorOfflineInference,
|
|
20
|
+
EmbeddingGeneratorOpenAI,
|
|
20
21
|
)
|
|
21
22
|
|
|
22
23
|
|
|
@@ -43,7 +44,7 @@ class EbmModel:
|
|
|
43
44
|
use_altLabels: bool = True,
|
|
44
45
|
hnsw_index_params: dict | str | None = None,
|
|
45
46
|
embedding_model_name: str | None = None,
|
|
46
|
-
|
|
47
|
+
embedding_model_deployment: str = "offline-inference",
|
|
47
48
|
embedding_model_args: dict | str | None = None,
|
|
48
49
|
encode_args_vocab: dict | str | None = None,
|
|
49
50
|
encode_args_documents: dict | str | None = None,
|
|
@@ -99,7 +100,7 @@ class EbmModel:
|
|
|
99
100
|
|
|
100
101
|
# Parameters for embedding generator
|
|
101
102
|
self.generator = None
|
|
102
|
-
self.
|
|
103
|
+
self.embedding_model_deployment = embedding_model_deployment
|
|
103
104
|
self.embedding_model_name = embedding_model_name
|
|
104
105
|
self.embedding_dimensions = int(embedding_dimensions)
|
|
105
106
|
if isinstance(embedding_model_args, str) or not embedding_model_args:
|
|
@@ -179,19 +180,27 @@ class EbmModel:
|
|
|
179
180
|
None
|
|
180
181
|
"""
|
|
181
182
|
if self.generator is None:
|
|
182
|
-
if self.
|
|
183
|
+
if self.embedding_model_deployment == "offline-inference":
|
|
183
184
|
self.logger.info("initializing offline-inference embedding generator")
|
|
184
185
|
self.generator = EmbeddingGeneratorOfflineInference(
|
|
185
186
|
model_name=self.embedding_model_name,
|
|
186
187
|
embedding_dimensions=self.embedding_dimensions,
|
|
187
188
|
**self.embedding_model_args,
|
|
188
189
|
)
|
|
189
|
-
elif self.
|
|
190
|
+
elif self.embedding_model_deployment == "mock":
|
|
190
191
|
self.logger.info("initializing mock embedding generator")
|
|
191
192
|
self.generator = EmbeddingGeneratorMock(self.embedding_dimensions)
|
|
192
|
-
elif self.
|
|
193
|
+
elif self.embedding_model_deployment == "HuggingFaceTEI":
|
|
193
194
|
self.logger.info("initializing API embedding generator")
|
|
194
195
|
self.generator = EmbeddingGeneratorHuggingFaceTEI(
|
|
196
|
+
model_name=self.embedding_model_name,
|
|
197
|
+
embedding_dimensions=self.embedding_dimensions,
|
|
198
|
+
**self.embedding_model_args,
|
|
199
|
+
)
|
|
200
|
+
elif self.embedding_model_deployment == "OpenAI":
|
|
201
|
+
self.logger.info("initializing API embedding generator")
|
|
202
|
+
self.generator = EmbeddingGeneratorOpenAI(
|
|
203
|
+
model_name=self.embedding_model_name,
|
|
195
204
|
embedding_dimensions=self.embedding_dimensions,
|
|
196
205
|
**self.embedding_model_args,
|
|
197
206
|
)
|
|
@@ -3,6 +3,7 @@ import os
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import requests
|
|
5
5
|
from sentence_transformers import SentenceTransformer
|
|
6
|
+
from tqdm import tqdm
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class EmbeddingGenerator:
|
|
@@ -41,6 +42,7 @@ class EmbeddingGeneratorAPI(EmbeddingGenerator):
|
|
|
41
42
|
|
|
42
43
|
def __init__(
|
|
43
44
|
self,
|
|
45
|
+
model_name: str,
|
|
44
46
|
embedding_dimensions: int,
|
|
45
47
|
**kwargs,
|
|
46
48
|
) -> None:
|
|
@@ -52,7 +54,7 @@ class EmbeddingGeneratorAPI(EmbeddingGenerator):
|
|
|
52
54
|
"""
|
|
53
55
|
|
|
54
56
|
self.embedding_dimensions = embedding_dimensions
|
|
55
|
-
|
|
57
|
+
self.model_name = model_name
|
|
56
58
|
self.session = requests.Session()
|
|
57
59
|
self.api_address = kwargs.get("api_address")
|
|
58
60
|
self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
|
|
@@ -85,19 +87,83 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
|
|
|
85
87
|
# If empty, return an empty numpy array with the correct shape
|
|
86
88
|
return np.empty((0, self.embedding_dimensions))
|
|
87
89
|
|
|
88
|
-
#
|
|
89
|
-
|
|
90
|
+
# Process in smaller batches to avoid memory overload
|
|
91
|
+
batch_size = min(32, len(texts)) # HuggingFaceTEI has a limit of 32 as default
|
|
92
|
+
|
|
93
|
+
for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
|
|
94
|
+
batch_texts = texts[i : i + batch_size]
|
|
90
95
|
# send a request to the HuggingFaceTEI API
|
|
91
|
-
data = {"inputs":
|
|
96
|
+
data = {"inputs": batch_texts, "truncate": True}
|
|
92
97
|
response = self.session.post(
|
|
93
98
|
self.api_address, headers=self.headers, json=data
|
|
94
99
|
)
|
|
95
100
|
|
|
96
101
|
# add generated embeddings to return list if request was successful
|
|
97
102
|
if response.status_code == 200:
|
|
98
|
-
embeddings.
|
|
103
|
+
embeddings.extend(response.json())
|
|
104
|
+
else:
|
|
105
|
+
# TODO: write warning to logger
|
|
106
|
+
for _ in batch_texts:
|
|
107
|
+
# TODO: ensure same format as true case and truncate dim
|
|
108
|
+
embeddings.append([0 for _ in range(self.embedding_dimensions)])
|
|
109
|
+
|
|
110
|
+
return np.array(embeddings)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
|
|
114
|
+
"""
|
|
115
|
+
A class for generating embeddings using any OpenAI-compatible API.
|
|
116
|
+
"""
|
|
117
|
+
|
|
118
|
+
def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
|
|
119
|
+
"""
|
|
120
|
+
Generates embeddings for a list of input texts using a model
|
|
121
|
+
via an OpenAI compatible API.
|
|
122
|
+
|
|
123
|
+
Args:
|
|
124
|
+
texts (list[str]): A list of input texts.
|
|
125
|
+
**kwargs: Additional keyword arguments merged into the JSON
|
|
126
|
+
request payload sent to the API.
|
|
127
|
+
|
|
128
|
+
Returns:
|
|
129
|
+
np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
|
|
130
|
+
containing the generated embeddings.
|
|
131
|
+
"""
|
|
132
|
+
# prepare list for return
|
|
133
|
+
embeddings = []
|
|
134
|
+
|
|
135
|
+
# Check if the input list is empty
|
|
136
|
+
if not texts:
|
|
137
|
+
# If empty, return an empty numpy array with the correct shape
|
|
138
|
+
return np.empty((0, self.embedding_dimensions))
|
|
139
|
+
|
|
140
|
+
# Process in smaller batches to avoid memory overload
|
|
141
|
+
batch_size = min(200, len(texts))
|
|
142
|
+
embeddings = []
|
|
143
|
+
|
|
144
|
+
for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
|
|
145
|
+
batch_texts = texts[i : i + batch_size]
|
|
146
|
+
data = {
|
|
147
|
+
"input": batch_texts,
|
|
148
|
+
"model": self.model_name,
|
|
149
|
+
"encoding_format": "float",
|
|
150
|
+
**kwargs,
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
response = self.session.post(
|
|
154
|
+
self.api_address, headers=self.headers, json=data
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# Process all embeddings from the batch response
|
|
158
|
+
if response.status_code == 200:
|
|
159
|
+
response_data = response.json()
|
|
160
|
+
for i, _ in enumerate(batch_texts):
|
|
161
|
+
embedding = response_data["data"][i]["embedding"]
|
|
162
|
+
embeddings.append(embedding)
|
|
99
163
|
else:
|
|
100
|
-
|
|
164
|
+
# TODO: write warning to logger
|
|
165
|
+
for _ in batch_texts:
|
|
166
|
+
embeddings.append([0 for _ in range(self.embedding_dimensions)])
|
|
101
167
|
|
|
102
168
|
return np.array(embeddings)
|
|
103
169
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ebm4subjects
|
|
3
|
-
Version: 0.5.
|
|
3
|
+
Version: 0.5.4
|
|
4
4
|
Summary: Embedding Based Matching for Automated Subject Indexing
|
|
5
5
|
Author: Deutsche Nationalbibliothek
|
|
6
6
|
Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
|
|
@@ -3,10 +3,10 @@ ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
|
|
|
3
3
|
ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
|
|
4
4
|
ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
|
|
5
5
|
ebm4subjects/ebm_logging.py,sha256=xkbqeVhSCNuhMwkx2yoIX8_D3z9DcsauZEmHhR1gaS0,5962
|
|
6
|
-
ebm4subjects/ebm_model.py,sha256=
|
|
7
|
-
ebm4subjects/embedding_generator.py,sha256=
|
|
6
|
+
ebm4subjects/ebm_model.py,sha256=lzGx_HLkKyTPVhtU4117DOEDz1rduNdzltvCYSbHQPg,30780
|
|
7
|
+
ebm4subjects/embedding_generator.py,sha256=LKZ_YAe4Th8foI_8-v-3tYFj0KGJ90XJ3OPuMXaqgSQ,9274
|
|
8
8
|
ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
|
|
9
|
-
ebm4subjects-0.5.
|
|
10
|
-
ebm4subjects-0.5.
|
|
11
|
-
ebm4subjects-0.5.
|
|
12
|
-
ebm4subjects-0.5.
|
|
9
|
+
ebm4subjects-0.5.4.dist-info/METADATA,sha256=OmMMh0pGAdv3YTkTork55wuj2gA0Ac8zV9ad3cDCIks,8274
|
|
10
|
+
ebm4subjects-0.5.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
11
|
+
ebm4subjects-0.5.4.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
|
|
12
|
+
ebm4subjects-0.5.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|