ebm4subjects 0.5.4__tar.gz → 0.5.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/PKG-INFO +4 -2
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/pyproject.toml +4 -2
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/ebm_logging.py +9 -9
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/ebm_model.py +14 -10
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/embedding_generator.py +143 -44
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/.gitignore +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/.python-version +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/LICENSE +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/README.md +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/Makefile +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/make.bat +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/README.md +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/conf.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/ebm4subjects.rst +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/index.rst +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/ebm-sketch.svg +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/__init__.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/analyzer.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/chunker.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/duckdb_client.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/prepare_data.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/__init__.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/data/vocab.ttl +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/test_hello.py +0 -0
- {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/test_prepare_data.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ebm4subjects
|
|
3
|
-
Version: 0.5.4
|
|
3
|
+
Version: 0.5.5
|
|
4
4
|
Summary: Embedding Based Matching for Automated Subject Indexing
|
|
5
5
|
Author: Deutsche Nationalbibliothek
|
|
6
6
|
Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
|
|
@@ -14,12 +14,14 @@ Classifier: Programming Language :: Python :: 3
|
|
|
14
14
|
Requires-Python: >=3.10
|
|
15
15
|
Requires-Dist: duckdb>=1.3.0
|
|
16
16
|
Requires-Dist: nltk~=3.9.1
|
|
17
|
+
Requires-Dist: openai>=2.15.0
|
|
17
18
|
Requires-Dist: polars>=1.30.0
|
|
18
19
|
Requires-Dist: pyarrow>=21.0.0
|
|
19
20
|
Requires-Dist: pyoxigraph>=0.4.11
|
|
20
21
|
Requires-Dist: rdflib~=7.1.3
|
|
21
|
-
Requires-Dist: sentence-transformers>=5.0.0
|
|
22
22
|
Requires-Dist: xgboost>=3.0.2
|
|
23
|
+
Provides-Extra: in-process
|
|
24
|
+
Requires-Dist: sentence-transformers>=5.0.0; extra == 'in-process'
|
|
23
25
|
Description-Content-Type: text/markdown
|
|
24
26
|
|
|
25
27
|
# Embedding Based Matching for Automated Subject Indexing
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "ebm4subjects"
|
|
3
|
-
version = "0.5.4"
|
|
3
|
+
version = "0.5.5"
|
|
4
4
|
description = "Embedding Based Matching for Automated Subject Indexing"
|
|
5
5
|
authors = [
|
|
6
6
|
{name = "Deutsche Nationalbibliothek"},
|
|
@@ -29,13 +29,15 @@ requires-python = ">=3.10"
|
|
|
29
29
|
dependencies = [
|
|
30
30
|
"duckdb>=1.3.0",
|
|
31
31
|
"nltk~=3.9.1",
|
|
32
|
+
"openai>=2.15.0",
|
|
32
33
|
"polars>=1.30.0",
|
|
33
34
|
"pyarrow>=21.0.0",
|
|
34
35
|
"pyoxigraph>=0.4.11",
|
|
35
36
|
"rdflib~=7.1.3",
|
|
36
|
-
"sentence-transformers>=5.0.0",
|
|
37
37
|
"xgboost>=3.0.2",
|
|
38
38
|
]
|
|
39
|
+
[project.optional-dependencies]
|
|
40
|
+
in-process=["sentence-transformers>=5.0.0"]
|
|
39
41
|
|
|
40
42
|
[build-system]
|
|
41
43
|
requires = ["hatchling"]
|
|
@@ -39,17 +39,17 @@ class EbmLogger:
|
|
|
39
39
|
else:
|
|
40
40
|
self.logger.setLevel(logging.NOTSET)
|
|
41
41
|
|
|
42
|
-
# Create a file handler to log messages to a file
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
42
|
+
# Create a file handler to log messages to a file
|
|
43
|
+
if not self.logger.handlers:
|
|
44
|
+
log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
|
|
45
|
+
log_file_handler.setFormatter(
|
|
46
|
+
logging.Formatter(
|
|
47
|
+
"%(asctime)s %(levelname)s: %(message)s",
|
|
48
|
+
"%Y-%m-%d %H:%M:%S",
|
|
49
|
+
)
|
|
48
50
|
)
|
|
49
|
-
)
|
|
50
51
|
|
|
51
|
-
|
|
52
|
-
self.logger.addHandler(log_file_handler)
|
|
52
|
+
self.logger.addHandler(log_file_handler)
|
|
53
53
|
|
|
54
54
|
def get_logger(self) -> logging.Logger:
|
|
55
55
|
"""
|
|
@@ -16,7 +16,7 @@ from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
|
|
|
16
16
|
from ebm4subjects.embedding_generator import (
|
|
17
17
|
EmbeddingGeneratorHuggingFaceTEI,
|
|
18
18
|
EmbeddingGeneratorMock,
|
|
19
|
-
|
|
19
|
+
EmbeddingGeneratorInProcess,
|
|
20
20
|
EmbeddingGeneratorOpenAI,
|
|
21
21
|
)
|
|
22
22
|
|
|
@@ -50,6 +50,7 @@ class EbmModel:
|
|
|
50
50
|
encode_args_documents: dict | str | None = None,
|
|
51
51
|
log_path: str | None = None,
|
|
52
52
|
logger: logging.Logger | None = None,
|
|
53
|
+
logging_level: str = "info",
|
|
53
54
|
) -> None:
|
|
54
55
|
"""
|
|
55
56
|
A class representing an Embedding-Based-Matching (EBM) model
|
|
@@ -139,7 +140,7 @@ class EbmModel:
|
|
|
139
140
|
self.train_jobs = int(xgb_jobs)
|
|
140
141
|
|
|
141
142
|
# Initiliaze logging
|
|
142
|
-
self.init_logger(log_path, logger)
|
|
143
|
+
self.init_logger(log_path, logger, logging_level)
|
|
143
144
|
|
|
144
145
|
# Initialize EBM model
|
|
145
146
|
self.model = None
|
|
@@ -180,11 +181,12 @@ class EbmModel:
|
|
|
180
181
|
None
|
|
181
182
|
"""
|
|
182
183
|
if self.generator is None:
|
|
183
|
-
if self.embedding_model_deployment == "
|
|
184
|
+
if self.embedding_model_deployment == "in-process":
|
|
184
185
|
self.logger.info("initializing offline-inference embedding generator")
|
|
185
|
-
self.generator =
|
|
186
|
+
self.generator = EmbeddingGeneratorInProcess(
|
|
186
187
|
model_name=self.embedding_model_name,
|
|
187
188
|
embedding_dimensions=self.embedding_dimensions,
|
|
189
|
+
logger=self.logger,
|
|
188
190
|
**self.embedding_model_args,
|
|
189
191
|
)
|
|
190
192
|
elif self.embedding_model_deployment == "mock":
|
|
@@ -195,6 +197,7 @@ class EbmModel:
|
|
|
195
197
|
self.generator = EmbeddingGeneratorHuggingFaceTEI(
|
|
196
198
|
model_name=self.embedding_model_name,
|
|
197
199
|
embedding_dimensions=self.embedding_dimensions,
|
|
200
|
+
logger=self.logger,
|
|
198
201
|
**self.embedding_model_args,
|
|
199
202
|
)
|
|
200
203
|
elif self.embedding_model_deployment == "OpenAI":
|
|
@@ -202,6 +205,7 @@ class EbmModel:
|
|
|
202
205
|
self.generator = EmbeddingGeneratorOpenAI(
|
|
203
206
|
model_name=self.embedding_model_name,
|
|
204
207
|
embedding_dimensions=self.embedding_dimensions,
|
|
208
|
+
logger=self.logger,
|
|
205
209
|
**self.embedding_model_args,
|
|
206
210
|
)
|
|
207
211
|
else:
|
|
@@ -209,7 +213,10 @@ class EbmModel:
|
|
|
209
213
|
raise NotImplementedError
|
|
210
214
|
|
|
211
215
|
def init_logger(
|
|
212
|
-
self,
|
|
216
|
+
self,
|
|
217
|
+
log_path: str | None = None,
|
|
218
|
+
logger: logging.Logger | None = None,
|
|
219
|
+
logging_level: str = "info",
|
|
213
220
|
) -> None:
|
|
214
221
|
"""
|
|
215
222
|
Initializes the logging for the EBM model.
|
|
@@ -218,7 +225,7 @@ class EbmModel:
|
|
|
218
225
|
None
|
|
219
226
|
"""
|
|
220
227
|
if log_path:
|
|
221
|
-
self.logger = EbmLogger(log_path,
|
|
228
|
+
self.logger = EbmLogger(log_path, logging_level).get_logger()
|
|
222
229
|
self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
|
|
223
230
|
self.xgb_callbacks = [self.xgb_logger]
|
|
224
231
|
elif logger:
|
|
@@ -769,7 +776,4 @@ class EbmModel:
|
|
|
769
776
|
Returns:
|
|
770
777
|
EbmModel: The loaded EBM model instance.
|
|
771
778
|
"""
|
|
772
|
-
|
|
773
|
-
ebm_model.init_logger()
|
|
774
|
-
|
|
775
|
-
return ebm_model
|
|
779
|
+
return joblib.load(input_path)
|
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
|
|
3
4
|
import numpy as np
|
|
4
5
|
import requests
|
|
5
|
-
from
|
|
6
|
+
from openai import BadRequestError, NotFoundError, OpenAI
|
|
6
7
|
from tqdm import tqdm
|
|
7
8
|
|
|
8
9
|
|
|
@@ -32,25 +33,29 @@ class EmbeddingGenerator:
|
|
|
32
33
|
pass
|
|
33
34
|
|
|
34
35
|
|
|
35
|
-
class EmbeddingGeneratorAPI(EmbeddingGenerator):
|
|
36
|
+
class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGenerator):
|
|
36
37
|
"""
|
|
37
|
-
A
|
|
38
|
-
|
|
39
|
-
Attributes:
|
|
40
|
-
embedding_dimensions (int): The dimensionality of the generated embeddings.
|
|
38
|
+
A class for generating embeddings using the HuggingFaceTEI API.
|
|
41
39
|
"""
|
|
42
40
|
|
|
43
41
|
def __init__(
|
|
44
42
|
self,
|
|
45
43
|
model_name: str,
|
|
46
44
|
embedding_dimensions: int,
|
|
45
|
+
logger: logging.Logger,
|
|
47
46
|
**kwargs,
|
|
48
47
|
) -> None:
|
|
49
48
|
"""
|
|
50
|
-
Initializes the API EmbeddingGenerator.
|
|
49
|
+
Initializes the HuggingFaceTEI API EmbeddingGenerator.
|
|
51
50
|
|
|
52
51
|
Sets the embedding dimensions, and initiliazes and
|
|
53
52
|
prepares a session with the API.
|
|
53
|
+
|
|
54
|
+
Args:
|
|
55
|
+
model_name (str): The name of the SentenceTransformer model.
|
|
56
|
+
embedding_dimensions (int): The dimensionality of the generated embeddings.
|
|
57
|
+
logger (Logger): A logger for the embedding generator.
|
|
58
|
+
**kwargs: Additional keyword arguments to pass to the model.
|
|
54
59
|
"""
|
|
55
60
|
|
|
56
61
|
self.embedding_dimensions = embedding_dimensions
|
|
@@ -59,11 +64,36 @@ class EmbeddingGeneratorAPI(EmbeddingGenerator):
|
|
|
59
64
|
self.api_address = kwargs.get("api_address")
|
|
60
65
|
self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
|
|
61
66
|
|
|
67
|
+
self.logger = logger
|
|
68
|
+
self._test_api()
|
|
62
69
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
70
|
+
def _test_api(self):
|
|
71
|
+
"""
|
|
72
|
+
Tests if the API is working with the given parameters
|
|
73
|
+
"""
|
|
74
|
+
response = self.session.post(
|
|
75
|
+
self.api_address,
|
|
76
|
+
headers=self.headers,
|
|
77
|
+
json={"inputs": "This is a test request!", "truncate": True},
|
|
78
|
+
)
|
|
79
|
+
if response.status_code == 200:
|
|
80
|
+
self.logger.debug(
|
|
81
|
+
"API call successful. Everything seems to be working fine."
|
|
82
|
+
)
|
|
83
|
+
elif response.status_code == 404:
|
|
84
|
+
self.logger.error(
|
|
85
|
+
"API not found under given adress! Please check the corresponding parameter!"
|
|
86
|
+
)
|
|
87
|
+
raise RuntimeError(
|
|
88
|
+
"API not found under given adress! Please check the corresponding parameter!"
|
|
89
|
+
)
|
|
90
|
+
else:
|
|
91
|
+
self.logger.error(
|
|
92
|
+
"Request to API not possible! Please check the corresponding parameters!"
|
|
93
|
+
)
|
|
94
|
+
raise RuntimeError(
|
|
95
|
+
"Request to API not possible! Please check the corresponding parameters!"
|
|
96
|
+
)
|
|
67
97
|
|
|
68
98
|
def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
|
|
69
99
|
"""
|
|
@@ -72,8 +102,7 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
|
|
|
72
102
|
|
|
73
103
|
Args:
|
|
74
104
|
texts (list[str]): A list of input texts.
|
|
75
|
-
**kwargs: Additional keyword arguments to pass to the
|
|
76
|
-
SentenceTransformer model.
|
|
105
|
+
**kwargs: Additional keyword arguments to pass to the API.
|
|
77
106
|
|
|
78
107
|
Returns:
|
|
79
108
|
np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
|
|
@@ -102,19 +131,87 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
|
|
|
102
131
|
if response.status_code == 200:
|
|
103
132
|
embeddings.extend(response.json())
|
|
104
133
|
else:
|
|
105
|
-
|
|
134
|
+
self.logger.warn("Call to API NOT successful! Returning 0's.")
|
|
106
135
|
for _ in batch_texts:
|
|
107
|
-
|
|
108
|
-
|
|
136
|
+
embeddings.append(
|
|
137
|
+
[
|
|
138
|
+
0
|
|
139
|
+
for _ in range(
|
|
140
|
+
min(
|
|
141
|
+
self.embedding_dimensions,
|
|
142
|
+
kwargs.get("truncate_prompt_tokens", float("inf")),
|
|
143
|
+
),
|
|
144
|
+
)
|
|
145
|
+
]
|
|
146
|
+
)
|
|
109
147
|
|
|
110
148
|
return np.array(embeddings)
|
|
111
149
|
|
|
112
150
|
|
|
113
|
-
class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
|
|
151
|
+
class EmbeddingGeneratorOpenAI(EmbeddingGenerator):
|
|
114
152
|
"""
|
|
115
|
-
A class for generating embeddings using any OpenAI
|
|
153
|
+
A class for generating embeddings using any OpenAI compatible API.
|
|
116
154
|
"""
|
|
117
155
|
|
|
156
|
+
def __init__(
|
|
157
|
+
self,
|
|
158
|
+
model_name: str,
|
|
159
|
+
embedding_dimensions: int,
|
|
160
|
+
logger: logging.Logger,
|
|
161
|
+
**kwargs,
|
|
162
|
+
) -> None:
|
|
163
|
+
"""
|
|
164
|
+
Initializes the OpenAI API EmbeddingGenerator.
|
|
165
|
+
|
|
166
|
+
Sets the embedding dimensions, and initiliazes and
|
|
167
|
+
prepares a session with the API.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
model_name (str): The name of the SentenceTransformer model.
|
|
171
|
+
embedding_dimensions (int): The dimensionality of the generated embeddings.
|
|
172
|
+
logger (Logger): A logger for the embedding generator.
|
|
173
|
+
**kwargs: Additional keyword arguments to pass to the model.
|
|
174
|
+
"""
|
|
175
|
+
|
|
176
|
+
self.embedding_dimensions = embedding_dimensions
|
|
177
|
+
self.model_name = model_name
|
|
178
|
+
|
|
179
|
+
if not (api_key := os.environ.get("OPENAI_API_KEY")):
|
|
180
|
+
api_key = ""
|
|
181
|
+
|
|
182
|
+
self.client = OpenAI(api_key=api_key, base_url=kwargs.get("api_address"))
|
|
183
|
+
|
|
184
|
+
self.logger = logger
|
|
185
|
+
self._test_api()
|
|
186
|
+
|
|
187
|
+
def _test_api(self):
|
|
188
|
+
"""
|
|
189
|
+
Tests if the API is working with the given parameters
|
|
190
|
+
"""
|
|
191
|
+
try:
|
|
192
|
+
_ = self.client.embeddings.create(
|
|
193
|
+
input="This is a test request!",
|
|
194
|
+
model=self.model_name,
|
|
195
|
+
encoding_format="float",
|
|
196
|
+
)
|
|
197
|
+
self.logger.debug(
|
|
198
|
+
"API call successful. Everything seems to be working fine."
|
|
199
|
+
)
|
|
200
|
+
except NotFoundError:
|
|
201
|
+
self.logger.error(
|
|
202
|
+
"API not found under given adress! Please check the corresponding parameter!"
|
|
203
|
+
)
|
|
204
|
+
raise RuntimeError(
|
|
205
|
+
"API not found under given adress! Please check the corresponding parameter!"
|
|
206
|
+
)
|
|
207
|
+
except BadRequestError:
|
|
208
|
+
self.logger.error(
|
|
209
|
+
"Request to API not possible! Please check the corresponding parameters!"
|
|
210
|
+
)
|
|
211
|
+
raise RuntimeError(
|
|
212
|
+
"Request to API not possible! Please check the corresponding parameters!"
|
|
213
|
+
)
|
|
214
|
+
|
|
118
215
|
def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
|
|
119
216
|
"""
|
|
120
217
|
Generates embeddings for a list of input texts using a model
|
|
@@ -122,8 +219,7 @@ class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
|
|
|
122
219
|
|
|
123
220
|
Args:
|
|
124
221
|
texts (list[str]): A list of input texts.
|
|
125
|
-
**kwargs: Additional keyword arguments to pass to the
|
|
126
|
-
SentenceTransformer model.
|
|
222
|
+
**kwargs: Additional keyword arguments to pass to the API.
|
|
127
223
|
|
|
128
224
|
Returns:
|
|
129
225
|
np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
|
|
@@ -143,53 +239,54 @@ class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
|
|
|
143
239
|
|
|
144
240
|
for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
|
|
145
241
|
batch_texts = texts[i : i + batch_size]
|
|
146
|
-
data = {
|
|
147
|
-
"input": batch_texts,
|
|
148
|
-
"model": self.model_name,
|
|
149
|
-
"encoding_format": "float",
|
|
150
|
-
**kwargs,
|
|
151
|
-
}
|
|
152
242
|
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
243
|
+
# Try to get embeddings for the batch from the API
|
|
244
|
+
try:
|
|
245
|
+
embedding_response = self.client.embeddings.create(
|
|
246
|
+
input=batch_texts,
|
|
247
|
+
model=self.model_name,
|
|
248
|
+
encoding_format="float",
|
|
249
|
+
extra_body={**kwargs},
|
|
250
|
+
)
|
|
156
251
|
|
|
157
|
-
|
|
158
|
-
if response.status_code == 200:
|
|
159
|
-
response_data = response.json()
|
|
252
|
+
# Process all embeddings from the batch response
|
|
160
253
|
for i, _ in enumerate(batch_texts):
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
# TODO: write warning to logger
|
|
254
|
+
embeddings.append(embedding_response.data[i].embedding)
|
|
255
|
+
except (NotFoundError, BadRequestError):
|
|
256
|
+
self.logger.warn("Call to API NOT successful! Returning 0's.")
|
|
165
257
|
for _ in batch_texts:
|
|
166
258
|
embeddings.append([0 for _ in range(self.embedding_dimensions)])
|
|
167
259
|
|
|
168
260
|
return np.array(embeddings)
|
|
169
261
|
|
|
170
262
|
|
|
171
|
-
class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
|
|
263
|
+
class EmbeddingGeneratorInProcess(EmbeddingGenerator):
|
|
172
264
|
"""
|
|
173
265
|
A class for generating embeddings using a given SentenceTransformer model
|
|
174
|
-
loaded
|
|
266
|
+
loaded in-process with SentenceTransformer.
|
|
175
267
|
|
|
176
268
|
Args:
|
|
177
269
|
model_name (str): The name of the SentenceTransformer model.
|
|
178
270
|
embedding_dimensions (int): The dimensionality of the generated embeddings.
|
|
271
|
+
logger (Logger): A logger for the embedding generator.
|
|
179
272
|
**kwargs: Additional keyword arguments to pass to the model.
|
|
180
|
-
|
|
181
|
-
Attributes:
|
|
182
|
-
model_name (str): The name of the SentenceTransformer model.
|
|
183
|
-
embedding_dimensions (int): The dimensionality of the generated embeddings.
|
|
184
273
|
"""
|
|
185
274
|
|
|
186
|
-
def __init__(
|
|
275
|
+
def __init__(
|
|
276
|
+
self,
|
|
277
|
+
model_name: str,
|
|
278
|
+
embedding_dimensions: int,
|
|
279
|
+
logger: logging.Logger,
|
|
280
|
+
**kwargs,
|
|
281
|
+
) -> None:
|
|
187
282
|
"""
|
|
188
|
-
Initializes the EmbeddingGenerator in
|
|
283
|
+
Initializes the EmbeddingGenerator in 'in-process' mode.
|
|
189
284
|
|
|
190
285
|
Sets the model name, embedding dimensions, and creates a
|
|
191
286
|
SentenceTransformer model instance.
|
|
192
287
|
"""
|
|
288
|
+
from sentence_transformers import SentenceTransformer
|
|
289
|
+
|
|
193
290
|
self.model_name = model_name
|
|
194
291
|
self.embedding_dimensions = embedding_dimensions
|
|
195
292
|
|
|
@@ -198,6 +295,8 @@ class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
|
|
|
198
295
|
self.model = SentenceTransformer(
|
|
199
296
|
model_name, truncate_dim=embedding_dimensions, **kwargs
|
|
200
297
|
)
|
|
298
|
+
self.logger = logger
|
|
299
|
+
self.logger.debug(f"SentenceTransfomer model running on {self.model.device}")
|
|
201
300
|
|
|
202
301
|
# Disabel parallelism for tokenizer
|
|
203
302
|
# Needed because process might be already parallelized
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|