ebm4subjects 0.5.3__tar.gz → 0.5.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/PKG-INFO +4 -2
  2. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/pyproject.toml +4 -2
  3. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/ebm_logging.py +9 -9
  4. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/ebm_model.py +27 -14
  5. ebm4subjects-0.5.5/src/ebm4subjects/embedding_generator.py +367 -0
  6. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/tests/test_prepare_data.py +3 -1
  7. ebm4subjects-0.5.3/src/ebm4subjects/embedding_generator.py +0 -202
  8. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/.gitignore +0 -0
  9. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/.python-version +0 -0
  10. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/LICENSE +0 -0
  11. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/README.md +0 -0
  12. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/docs/Makefile +0 -0
  13. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/docs/make.bat +0 -0
  14. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/docs/source/README.md +0 -0
  15. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/docs/source/conf.py +0 -0
  16. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/docs/source/ebm4subjects.rst +0 -0
  17. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/docs/source/index.rst +0 -0
  18. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/ebm-sketch.svg +0 -0
  19. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/__init__.py +0 -0
  20. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/analyzer.py +0 -0
  21. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/chunker.py +0 -0
  22. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/duckdb_client.py +0 -0
  23. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/src/ebm4subjects/prepare_data.py +0 -0
  24. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/tests/__init__.py +0 -0
  25. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/tests/data/vocab.ttl +0 -0
  26. {ebm4subjects-0.5.3 → ebm4subjects-0.5.5}/tests/test_hello.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ebm4subjects
3
- Version: 0.5.3
3
+ Version: 0.5.5
4
4
  Summary: Embedding Based Matching for Automated Subject Indexing
5
5
  Author: Deutsche Nationalbibliothek
6
6
  Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
@@ -14,12 +14,14 @@ Classifier: Programming Language :: Python :: 3
14
14
  Requires-Python: >=3.10
15
15
  Requires-Dist: duckdb>=1.3.0
16
16
  Requires-Dist: nltk~=3.9.1
17
+ Requires-Dist: openai>=2.15.0
17
18
  Requires-Dist: polars>=1.30.0
18
19
  Requires-Dist: pyarrow>=21.0.0
19
20
  Requires-Dist: pyoxigraph>=0.4.11
20
21
  Requires-Dist: rdflib~=7.1.3
21
- Requires-Dist: sentence-transformers>=5.0.0
22
22
  Requires-Dist: xgboost>=3.0.2
23
+ Provides-Extra: in-process
24
+ Requires-Dist: sentence-transformers>=5.0.0; extra == 'in-process'
23
25
  Description-Content-Type: text/markdown
24
26
 
25
27
  # Embedding Based Matching for Automated Subject Indexing
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ebm4subjects"
3
- version = "0.5.3"
3
+ version = "0.5.5"
4
4
  description = "Embedding Based Matching for Automated Subject Indexing"
5
5
  authors = [
6
6
  {name = "Deutsche Nationalbibliothek"},
@@ -29,13 +29,15 @@ requires-python = ">=3.10"
29
29
  dependencies = [
30
30
  "duckdb>=1.3.0",
31
31
  "nltk~=3.9.1",
32
+ "openai>=2.15.0",
32
33
  "polars>=1.30.0",
33
34
  "pyarrow>=21.0.0",
34
35
  "pyoxigraph>=0.4.11",
35
36
  "rdflib~=7.1.3",
36
- "sentence-transformers>=5.0.0",
37
37
  "xgboost>=3.0.2",
38
38
  ]
39
+ [project.optional-dependencies]
40
+ in-process=["sentence-transformers>=5.0.0"]
39
41
 
40
42
  [build-system]
41
43
  requires = ["hatchling"]
@@ -39,17 +39,17 @@ class EbmLogger:
39
39
  else:
40
40
  self.logger.setLevel(logging.NOTSET)
41
41
 
42
- # Create a file handler to log messages to a file
43
- log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
44
- log_file_handler.setFormatter(
45
- logging.Formatter(
46
- "%(asctime)s %(levelname)s: %(message)s",
47
- "%Y-%m-%d %H:%M:%S",
42
+ # Create a file handler to log messages to a file
43
+ if not self.logger.handlers:
44
+ log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
45
+ log_file_handler.setFormatter(
46
+ logging.Formatter(
47
+ "%(asctime)s %(levelname)s: %(message)s",
48
+ "%Y-%m-%d %H:%M:%S",
49
+ )
48
50
  )
49
- )
50
51
 
51
- # Add the file handler to the logger
52
- self.logger.addHandler(log_file_handler)
52
+ self.logger.addHandler(log_file_handler)
53
53
 
54
54
  def get_logger(self) -> logging.Logger:
55
55
  """
@@ -15,8 +15,9 @@ from ebm4subjects.duckdb_client import Duckdb_client
15
15
  from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
16
16
  from ebm4subjects.embedding_generator import (
17
17
  EmbeddingGeneratorHuggingFaceTEI,
18
- EmbeddingGeneratorOfflineInference,
19
18
  EmbeddingGeneratorMock,
19
+ EmbeddingGeneratorInProcess,
20
+ EmbeddingGeneratorOpenAI,
20
21
  )
21
22
 
22
23
 
@@ -43,12 +44,13 @@ class EbmModel:
43
44
  use_altLabels: bool = True,
44
45
  hnsw_index_params: dict | str | None = None,
45
46
  embedding_model_name: str | None = None,
46
- embedding_model_type: str = "offline-inference",
47
+ embedding_model_deployment: str = "offline-inference",
47
48
  embedding_model_args: dict | str | None = None,
48
49
  encode_args_vocab: dict | str | None = None,
49
50
  encode_args_documents: dict | str | None = None,
50
51
  log_path: str | None = None,
51
52
  logger: logging.Logger | None = None,
53
+ logging_level: str = "info",
52
54
  ) -> None:
53
55
  """
54
56
  A class representing an Embedding-Based-Matching (EBM) model
@@ -99,7 +101,7 @@ class EbmModel:
99
101
 
100
102
  # Parameters for embedding generator
101
103
  self.generator = None
102
- self.embedding_model_type = embedding_model_type
104
+ self.embedding_model_deployment = embedding_model_deployment
103
105
  self.embedding_model_name = embedding_model_name
104
106
  self.embedding_dimensions = int(embedding_dimensions)
105
107
  if isinstance(embedding_model_args, str) or not embedding_model_args:
@@ -138,7 +140,7 @@ class EbmModel:
138
140
  self.train_jobs = int(xgb_jobs)
139
141
 
140
142
  # Initiliaze logging
141
- self.init_logger(log_path, logger)
143
+ self.init_logger(log_path, logger, logging_level)
142
144
 
143
145
  # Initialize EBM model
144
146
  self.model = None
@@ -179,20 +181,31 @@ class EbmModel:
179
181
  None
180
182
  """
181
183
  if self.generator is None:
182
- if self.embedding_model_type == "offline-inference":
184
+ if self.embedding_model_deployment == "in-process":
183
185
  self.logger.info("initializing offline-inference embedding generator")
184
- self.generator = EmbeddingGeneratorOfflineInference(
186
+ self.generator = EmbeddingGeneratorInProcess(
185
187
  model_name=self.embedding_model_name,
186
188
  embedding_dimensions=self.embedding_dimensions,
189
+ logger=self.logger,
187
190
  **self.embedding_model_args,
188
191
  )
189
- elif self.embedding_model_type == "mock":
192
+ elif self.embedding_model_deployment == "mock":
190
193
  self.logger.info("initializing mock embedding generator")
191
194
  self.generator = EmbeddingGeneratorMock(self.embedding_dimensions)
192
- elif self.embedding_model_type == "HuggingFaceTEI":
195
+ elif self.embedding_model_deployment == "HuggingFaceTEI":
193
196
  self.logger.info("initializing API embedding generator")
194
197
  self.generator = EmbeddingGeneratorHuggingFaceTEI(
198
+ model_name=self.embedding_model_name,
199
+ embedding_dimensions=self.embedding_dimensions,
200
+ logger=self.logger,
201
+ **self.embedding_model_args,
202
+ )
203
+ elif self.embedding_model_deployment == "OpenAI":
204
+ self.logger.info("initializing API embedding generator")
205
+ self.generator = EmbeddingGeneratorOpenAI(
206
+ model_name=self.embedding_model_name,
195
207
  embedding_dimensions=self.embedding_dimensions,
208
+ logger=self.logger,
196
209
  **self.embedding_model_args,
197
210
  )
198
211
  else:
@@ -200,7 +213,10 @@ class EbmModel:
200
213
  raise NotImplementedError
201
214
 
202
215
  def init_logger(
203
- self, log_path: str | None = None, logger: logging.Logger | None = None
216
+ self,
217
+ log_path: str | None = None,
218
+ logger: logging.Logger | None = None,
219
+ logging_level: str = "info",
204
220
  ) -> None:
205
221
  """
206
222
  Initializes the logging for the EBM model.
@@ -209,7 +225,7 @@ class EbmModel:
209
225
  None
210
226
  """
211
227
  if log_path:
212
- self.logger = EbmLogger(log_path, "info").get_logger()
228
+ self.logger = EbmLogger(log_path, logging_level).get_logger()
213
229
  self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
214
230
  self.xgb_callbacks = [self.xgb_logger]
215
231
  elif logger:
@@ -760,7 +776,4 @@ class EbmModel:
760
776
  Returns:
761
777
  EbmModel: The loaded EBM model instance.
762
778
  """
763
- ebm_model = joblib.load(input_path)
764
- ebm_model.init_logger()
765
-
766
- return ebm_model
779
+ return joblib.load(input_path)
@@ -0,0 +1,367 @@
1
+ import logging
2
+ import os
3
+
4
+ import numpy as np
5
+ import requests
6
+ from openai import BadRequestError, NotFoundError, OpenAI
7
+ from tqdm import tqdm
8
+
9
+
10
+ class EmbeddingGenerator:
11
+ """
12
+ A base class for embedding generators.
13
+ """
14
+
15
+ def __init__(self) -> None:
16
+ """
17
+ Base method for the initialization of an EmbeddingGenerator.
18
+ """
19
+ pass
20
+
21
+ def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
22
+ """
23
+ Base method for creating embeddings with an EmbeddingGenerator.
24
+
25
+ Args:
26
+ texts (list[str]): A list of input texts.
27
+ **kwargs: Additional keyword arguments.
28
+
29
+ Returns:
30
+ np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
31
+ containing the generated embeddings.
32
+ """
33
+ pass
34
+
35
+
36
+ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGenerator):
37
+ """
38
+ A class for generating embeddings using the HuggingFaceTEI API.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ model_name: str,
44
+ embedding_dimensions: int,
45
+ logger: logging.Logger,
46
+ **kwargs,
47
+ ) -> None:
48
+ """
49
+ Initializes the HuggingFaceTEI API EmbeddingGenerator.
50
+
51
+ Sets the embedding dimensions, and initializes and
52
+ prepares a session with the API.
53
+
54
+ Args:
55
+ model_name (str): The name of the SentenceTransformer model.
56
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
57
+ logger (Logger): A logger for the embedding generator.
58
+ **kwargs: Additional keyword arguments to pass to the model.
59
+ """
60
+
61
+ self.embedding_dimensions = embedding_dimensions
62
+ self.model_name = model_name
63
+ self.session = requests.Session()
64
+ self.api_address = kwargs.get("api_address")
65
+ self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
66
+
67
+ self.logger = logger
68
+ self._test_api()
69
+
70
+ def _test_api(self):
71
+ """
72
+ Tests if the API is working with the given parameters
73
+ """
74
+ response = self.session.post(
75
+ self.api_address,
76
+ headers=self.headers,
77
+ json={"inputs": "This is a test request!", "truncate": True},
78
+ )
79
+ if response.status_code == 200:
80
+ self.logger.debug(
81
+ "API call successful. Everything seems to be working fine."
82
+ )
83
+ elif response.status_code == 404:
84
+ self.logger.error(
85
+ "API not found under given adress! Please check the corresponding parameter!"
86
+ )
87
+ raise RuntimeError(
88
+ "API not found under given adress! Please check the corresponding parameter!"
89
+ )
90
+ else:
91
+ self.logger.error(
92
+ "Request to API not possible! Please check the corresponding parameters!"
93
+ )
94
+ raise RuntimeError(
95
+ "Request to API not possible! Please check the corresponding parameters!"
96
+ )
97
+
98
+ def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
99
+ """
100
+ Generates embeddings for a list of input texts using a model
101
+ via the HuggingFaceTEI API.
102
+
103
+ Args:
104
+ texts (list[str]): A list of input texts.
105
+ **kwargs: Additional keyword arguments to pass to the API.
106
+
107
+ Returns:
108
+ np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
109
+ containing the generated embeddings.
110
+ """
111
+ # prepare list for return
112
+ embeddings = []
113
+
114
+ # Check if the input list is empty
115
+ if not texts:
116
+ # If empty, return an empty numpy array with the correct shape
117
+ return np.empty((0, self.embedding_dimensions))
118
+
119
+ # Process in smaller batches to avoid memory overload
120
+ batch_size = min(32, len(texts)) # HuggingFaceTEI has a limit of 32 as default
121
+
122
+ for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
123
+ batch_texts = texts[i : i + batch_size]
124
+ # send a request to the HuggingFaceTEI API
125
+ data = {"inputs": batch_texts, "truncate": True}
126
+ response = self.session.post(
127
+ self.api_address, headers=self.headers, json=data
128
+ )
129
+
130
+ # add generated embeddings to return list if request was successful
131
+ if response.status_code == 200:
132
+ embeddings.extend(response.json())
133
+ else:
134
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
135
+ for _ in batch_texts:
136
+ embeddings.append(
137
+ [
138
+ 0
139
+ for _ in range(
140
+ min(
141
+ self.embedding_dimensions,
142
+ kwargs.get("truncate_prompt_tokens", float("inf")),
143
+ ),
144
+ )
145
+ ]
146
+ )
147
+
148
+ return np.array(embeddings)
149
+
150
+
151
+ class EmbeddingGeneratorOpenAI(EmbeddingGenerator):
152
+ """
153
+ A class for generating embeddings using any OpenAI compatible API.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ model_name: str,
159
+ embedding_dimensions: int,
160
+ logger: logging.Logger,
161
+ **kwargs,
162
+ ) -> None:
163
+ """
164
+ Initializes the OpenAI API EmbeddingGenerator.
165
+
166
+ Sets the embedding dimensions, and initializes and
167
+ prepares a session with the API.
168
+
169
+ Args:
170
+ model_name (str): The name of the SentenceTransformer model.
171
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
172
+ logger (Logger): A logger for the embedding generator.
173
+ **kwargs: Additional keyword arguments to pass to the model.
174
+ """
175
+
176
+ self.embedding_dimensions = embedding_dimensions
177
+ self.model_name = model_name
178
+
179
+ if not (api_key := os.environ.get("OPENAI_API_KEY")):
180
+ api_key = ""
181
+
182
+ self.client = OpenAI(api_key=api_key, base_url=kwargs.get("api_address"))
183
+
184
+ self.logger = logger
185
+ self._test_api()
186
+
187
+ def _test_api(self):
188
+ """
189
+ Tests if the API is working with the given parameters
190
+ """
191
+ try:
192
+ _ = self.client.embeddings.create(
193
+ input="This is a test request!",
194
+ model=self.model_name,
195
+ encoding_format="float",
196
+ )
197
+ self.logger.debug(
198
+ "API call successful. Everything seems to be working fine."
199
+ )
200
+ except NotFoundError:
201
+ self.logger.error(
202
+ "API not found under given adress! Please check the corresponding parameter!"
203
+ )
204
+ raise RuntimeError(
205
+ "API not found under given adress! Please check the corresponding parameter!"
206
+ )
207
+ except BadRequestError:
208
+ self.logger.error(
209
+ "Request to API not possible! Please check the corresponding parameters!"
210
+ )
211
+ raise RuntimeError(
212
+ "Request to API not possible! Please check the corresponding parameters!"
213
+ )
214
+
215
+ def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
216
+ """
217
+ Generates embeddings for a list of input texts using a model
218
+ via an OpenAI compatible API.
219
+
220
+ Args:
221
+ texts (list[str]): A list of input texts.
222
+ **kwargs: Additional keyword arguments to pass to the API.
223
+
224
+ Returns:
225
+ np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
226
+ containing the generated embeddings.
227
+ """
228
+ # prepare list for return
229
+ embeddings = []
230
+
231
+ # Check if the input list is empty
232
+ if not texts:
233
+ # If empty, return an empty numpy array with the correct shape
234
+ return np.empty((0, self.embedding_dimensions))
235
+
236
+ # Process in smaller batches to avoid memory overload
237
+ batch_size = min(200, len(texts))
238
+ embeddings = []
239
+
240
+ for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
241
+ batch_texts = texts[i : i + batch_size]
242
+
243
+ # Try to get embeddings for the batch from the API
244
+ try:
245
+ embedding_response = self.client.embeddings.create(
246
+ input=batch_texts,
247
+ model=self.model_name,
248
+ encoding_format="float",
249
+ extra_body={**kwargs},
250
+ )
251
+
252
+ # Process all embeddings from the batch response
253
+ for i, _ in enumerate(batch_texts):
254
+ embeddings.append(embedding_response.data[i].embedding)
255
+ except (NotFoundError, BadRequestError):
256
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
257
+ for _ in batch_texts:
258
+ embeddings.append([0 for _ in range(self.embedding_dimensions)])
259
+
260
+ return np.array(embeddings)
261
+
262
+
263
+ class EmbeddingGeneratorInProcess(EmbeddingGenerator):
264
+ """
265
+ A class for generating embeddings using a given SentenceTransformer model
266
+ loaded in-process with SentenceTransformer.
267
+
268
+ Args:
269
+ model_name (str): The name of the SentenceTransformer model.
270
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
271
+ logger (Logger): A logger for the embedding generator.
272
+ **kwargs: Additional keyword arguments to pass to the model.
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ model_name: str,
278
+ embedding_dimensions: int,
279
+ logger: logging.Logger,
280
+ **kwargs,
281
+ ) -> None:
282
+ """
283
+ Initializes the EmbeddingGenerator in 'in-process' mode.
284
+
285
+ Sets the model name, embedding dimensions, and creates a
286
+ SentenceTransformer model instance.
287
+ """
288
+ from sentence_transformers import SentenceTransformer
289
+
290
+ self.model_name = model_name
291
+ self.embedding_dimensions = embedding_dimensions
292
+
293
+ # Create a SentenceTransformer model instance with the given
294
+ # model name and embedding dimensions
295
+ self.model = SentenceTransformer(
296
+ model_name, truncate_dim=embedding_dimensions, **kwargs
297
+ )
298
+ self.logger = logger
299
+ self.logger.debug(f"SentenceTransfomer model running on {self.model.device}")
300
+
301
+ # Disable parallelism for tokenizer
302
+ # Needed because process might be already parallelized
303
+ # before embedding creation
304
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
305
+
306
+ def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
307
+ """
308
+ Generates embeddings for a list of input texts using the
309
+ SentenceTransformer model.
310
+
311
+ Args:
312
+ texts (list[str]): A list of input texts.
313
+ **kwargs: Additional keyword arguments to pass to the
314
+ SentenceTransformer model.
315
+
316
+ Returns:
317
+ np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
318
+ containing the generated embeddings.
319
+ """
320
+ # Check if the input list is empty
321
+ if not texts:
322
+ # If empty, return an empty numpy array with the correct shape
323
+ return np.empty((0, self.embedding_dimensions))
324
+
325
+ # Generate embeddings using the SentenceTransformer model and return them
326
+ return self.model.encode(texts, **kwargs)
327
+
328
+
329
+ class EmbeddingGeneratorMock(EmbeddingGenerator):
330
+ """
331
+ A mock class for generating fake embeddings. Used for testing.
332
+
333
+ Args:
334
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
335
+ **kwargs: Additional keyword arguments to pass to the model.
336
+
337
+ Attributes:
338
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
339
+ """
340
+
341
+ def __init__(self, embedding_dimensions: int, **kwargs) -> None:
342
+ """
343
+ Initializes the mock EmbeddingGenerator.
344
+
345
+ Sets the embedding dimensions.
346
+ """
347
+ self.embedding_dimensions = embedding_dimensions
348
+
349
+ def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
350
+ """
351
+ Generates embeddings for a list of input texts.
352
+
353
+ Args:
354
+ texts (list[str]): A list of input texts.
355
+ **kwargs: Additional keyword arguments.
356
+
357
+ Returns:
358
+ np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
359
+ containing the generated embeddings.
360
+ """
361
+ # Check if the input list is empty
362
+ if not texts:
363
+ # If empty, return an empty numpy array with the correct shape
364
+ return np.empty((0, self.embedding_dimensions))
365
+
366
+ # Generate mock embeddings and return them
367
+ return np.ones((len(texts), 1024))
@@ -1,8 +1,10 @@
1
- import polars as pl
2
1
  from pathlib import Path
3
2
 
3
+ import polars as pl
4
+
4
5
  from ebm4subjects.prepare_data import parse_vocab
5
6
 
7
+
6
8
  def test_parse_vocab_reads_ttl_and_returns_dataframe(tmp_path):
7
9
  # Copy the sample vocab.ttl to a temp location
8
10
  vocab_src = Path(__file__).parent / "data/vocab.ttl"
@@ -1,202 +0,0 @@
1
- import os
2
-
3
- import numpy as np
4
- import requests
5
- from sentence_transformers import SentenceTransformer
6
-
7
-
8
- class EmbeddingGenerator:
9
- """
10
- A base class for embedding generators.
11
- """
12
-
13
- def __init__(self) -> None:
14
- """
15
- Base method fot the initialization of an EmbeddingGenerator.
16
- """
17
- pass
18
-
19
- def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
20
- """
21
- Base method fot the creating embeddings with an EmbeddingGenerator.
22
-
23
- Args:
24
- texts (list[str]): A list of input texts.
25
- **kwargs: Additional keyword arguments.
26
-
27
- Returns:
28
- np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
29
- containing the generated embeddings.
30
- """
31
- pass
32
-
33
-
34
- class EmbeddingGeneratorAPI(EmbeddingGenerator):
35
- """
36
- A base class for API embedding generators.
37
-
38
- Attributes:
39
- embedding_dimensions (int): The dimensionality of the generated embeddings.
40
- """
41
-
42
- def __init__(
43
- self,
44
- embedding_dimensions: int,
45
- **kwargs,
46
- ) -> None:
47
- """
48
- Initializes the API EmbeddingGenerator.
49
-
50
- Sets the embedding dimensions, and initiliazes and
51
- prepares a session with the API.
52
- """
53
-
54
- self.embedding_dimensions = embedding_dimensions
55
-
56
- self.session = requests.Session()
57
- self.api_address = kwargs.get("api_address")
58
- self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
59
-
60
-
61
- class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
62
- """
63
- A class for generating embeddings using the HuggingFaceTEI API.
64
- """
65
-
66
- def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
67
- """
68
- Generates embeddings for a list of input texts using a model
69
- via the HuggingFaceTEI API.
70
-
71
- Args:
72
- texts (list[str]): A list of input texts.
73
- **kwargs: Additional keyword arguments to pass to the
74
- SentenceTransformer model.
75
-
76
- Returns:
77
- np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
78
- containing the generated embeddings.
79
- """
80
- # prepare list for return
81
- embeddings = []
82
-
83
- # Check if the input list is empty
84
- if not texts:
85
- # If empty, return an empty numpy array with the correct shape
86
- return np.empty((0, self.embedding_dimensions))
87
-
88
- # process each text
89
- for text in texts:
90
- # send a request to the HuggingFaceTEI API
91
- data = {"inputs": text}
92
- response = self.session.post(
93
- self.api_address, headers=self.headers, json=data
94
- )
95
-
96
- # add generated embeddings to return list if request was successfull
97
- if response.status_code == 200:
98
- embeddings.append(response.json()[0])
99
- else:
100
- embeddings.append([0 for _ in range(self.embedding_dimensions)])
101
-
102
- return np.array(embeddings)
103
-
104
-
105
- class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
106
- """
107
- A class for generating embeddings using a given SentenceTransformer model
108
- loaded offline with SentenceTransformer.
109
-
110
- Args:
111
- model_name (str): The name of the SentenceTransformer model.
112
- embedding_dimensions (int): The dimensionality of the generated embeddings.
113
- **kwargs: Additional keyword arguments to pass to the model.
114
-
115
- Attributes:
116
- model_name (str): The name of the SentenceTransformer model.
117
- embedding_dimensions (int): The dimensionality of the generated embeddings.
118
- """
119
-
120
- def __init__(self, model_name: str, embedding_dimensions: int, **kwargs) -> None:
121
- """
122
- Initializes the EmbeddingGenerator in offline inference mode.
123
-
124
- Sets the model name, embedding dimensions, and creates a
125
- SentenceTransformer model instance.
126
- """
127
- self.model_name = model_name
128
- self.embedding_dimensions = embedding_dimensions
129
-
130
- # Create a SentenceTransformer model instance with the given
131
- # model name and embedding dimensions
132
- self.model = SentenceTransformer(
133
- model_name, truncate_dim=embedding_dimensions, **kwargs
134
- )
135
-
136
- # Disabel parallelism for tokenizer
137
- # Needed because process might be already parallelized
138
- # before embedding creation
139
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
140
-
141
- def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
142
- """
143
- Generates embeddings for a list of input texts using the
144
- SentenceTransformer model.
145
-
146
- Args:
147
- texts (list[str]): A list of input texts.
148
- **kwargs: Additional keyword arguments to pass to the
149
- SentenceTransformer model.
150
-
151
- Returns:
152
- np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
153
- containing the generated embeddings.
154
- """
155
- # Check if the input list is empty
156
- if not texts:
157
- # If empty, return an empty numpy array with the correct shape
158
- return np.empty((0, self.embedding_dimensions))
159
-
160
- # Generate embeddings using the SentenceTransformer model and return them
161
- return self.model.encode(texts, **kwargs)
162
-
163
-
164
- class EmbeddingGeneratorMock(EmbeddingGenerator):
165
- """
166
- A mock class for generating fake embeddings. Used for testing.
167
-
168
- Args:
169
- embedding_dimensions (int): The dimensionality of the generated embeddings.
170
- **kwargs: Additional keyword arguments to pass to the model.
171
-
172
- Attributes:
173
- embedding_dimensions (int): The dimensionality of the generated embeddings.
174
- """
175
-
176
- def __init__(self, embedding_dimensions: int, **kwargs) -> None:
177
- """
178
- Initializes the mock EmbeddingGenerator.
179
-
180
- Sets the embedding dimensions.
181
- """
182
- self.embedding_dimensions = embedding_dimensions
183
-
184
- def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
185
- """
186
- Generates embeddings for a list of input texts.
187
-
188
- Args:
189
- texts (list[str]): A list of input texts.
190
- **kwargs: Additional keyword arguments.
191
-
192
- Returns:
193
- np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
194
- containing the generated embeddings.
195
- """
196
- # Check if the input list is empty
197
- if not texts:
198
- # If empty, return an empty numpy array with the correct shape
199
- return np.empty((0, self.embedding_dimensions))
200
-
201
- # Generate mock embeddings return them
202
- return np.ones((len(texts), 1024))
File without changes
File without changes
File without changes
File without changes
File without changes