ebm4subjects 0.5.4__tar.gz → 0.5.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/PKG-INFO +4 -2
  2. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/pyproject.toml +4 -2
  3. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/ebm_logging.py +9 -9
  4. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/ebm_model.py +14 -10
  5. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/embedding_generator.py +143 -44
  6. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/.gitignore +0 -0
  7. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/.python-version +0 -0
  8. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/LICENSE +0 -0
  9. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/README.md +0 -0
  10. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/Makefile +0 -0
  11. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/make.bat +0 -0
  12. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/README.md +0 -0
  13. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/conf.py +0 -0
  14. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/ebm4subjects.rst +0 -0
  15. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/docs/source/index.rst +0 -0
  16. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/ebm-sketch.svg +0 -0
  17. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/__init__.py +0 -0
  18. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/analyzer.py +0 -0
  19. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/chunker.py +0 -0
  20. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/duckdb_client.py +0 -0
  21. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/src/ebm4subjects/prepare_data.py +0 -0
  22. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/__init__.py +0 -0
  23. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/data/vocab.ttl +0 -0
  24. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/test_hello.py +0 -0
  25. {ebm4subjects-0.5.4 → ebm4subjects-0.5.5}/tests/test_prepare_data.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ebm4subjects
3
- Version: 0.5.4
3
+ Version: 0.5.5
4
4
  Summary: Embedding Based Matching for Automated Subject Indexing
5
5
  Author: Deutsche Nationalbibliothek
6
6
  Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
@@ -14,12 +14,14 @@ Classifier: Programming Language :: Python :: 3
14
14
  Requires-Python: >=3.10
15
15
  Requires-Dist: duckdb>=1.3.0
16
16
  Requires-Dist: nltk~=3.9.1
17
+ Requires-Dist: openai>=2.15.0
17
18
  Requires-Dist: polars>=1.30.0
18
19
  Requires-Dist: pyarrow>=21.0.0
19
20
  Requires-Dist: pyoxigraph>=0.4.11
20
21
  Requires-Dist: rdflib~=7.1.3
21
- Requires-Dist: sentence-transformers>=5.0.0
22
22
  Requires-Dist: xgboost>=3.0.2
23
+ Provides-Extra: in-process
24
+ Requires-Dist: sentence-transformers>=5.0.0; extra == 'in-process'
23
25
  Description-Content-Type: text/markdown
24
26
 
25
27
  # Embedding Based Matching for Automated Subject Indexing
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "ebm4subjects"
3
- version = "0.5.4"
3
+ version = "0.5.5"
4
4
  description = "Embedding Based Matching for Automated Subject Indexing"
5
5
  authors = [
6
6
  {name = "Deutsche Nationalbibliothek"},
@@ -29,13 +29,15 @@ requires-python = ">=3.10"
29
29
  dependencies = [
30
30
  "duckdb>=1.3.0",
31
31
  "nltk~=3.9.1",
32
+ "openai>=2.15.0",
32
33
  "polars>=1.30.0",
33
34
  "pyarrow>=21.0.0",
34
35
  "pyoxigraph>=0.4.11",
35
36
  "rdflib~=7.1.3",
36
- "sentence-transformers>=5.0.0",
37
37
  "xgboost>=3.0.2",
38
38
  ]
39
+ [project.optional-dependencies]
40
+ in-process=["sentence-transformers>=5.0.0"]
39
41
 
40
42
  [build-system]
41
43
  requires = ["hatchling"]
@@ -39,17 +39,17 @@ class EbmLogger:
39
39
  else:
40
40
  self.logger.setLevel(logging.NOTSET)
41
41
 
42
- # Create a file handler to log messages to a file
43
- log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
44
- log_file_handler.setFormatter(
45
- logging.Formatter(
46
- "%(asctime)s %(levelname)s: %(message)s",
47
- "%Y-%m-%d %H:%M:%S",
42
+ # Create a file handler to log messages to a file
43
+ if not self.logger.handlers:
44
+ log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
45
+ log_file_handler.setFormatter(
46
+ logging.Formatter(
47
+ "%(asctime)s %(levelname)s: %(message)s",
48
+ "%Y-%m-%d %H:%M:%S",
49
+ )
48
50
  )
49
- )
50
51
 
51
- # Add the file handler to the logger
52
- self.logger.addHandler(log_file_handler)
52
+ self.logger.addHandler(log_file_handler)
53
53
 
54
54
  def get_logger(self) -> logging.Logger:
55
55
  """
@@ -16,7 +16,7 @@ from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
16
16
  from ebm4subjects.embedding_generator import (
17
17
  EmbeddingGeneratorHuggingFaceTEI,
18
18
  EmbeddingGeneratorMock,
19
- EmbeddingGeneratorOfflineInference,
19
+ EmbeddingGeneratorInProcess,
20
20
  EmbeddingGeneratorOpenAI,
21
21
  )
22
22
 
@@ -50,6 +50,7 @@ class EbmModel:
50
50
  encode_args_documents: dict | str | None = None,
51
51
  log_path: str | None = None,
52
52
  logger: logging.Logger | None = None,
53
+ logging_level: str = "info",
53
54
  ) -> None:
54
55
  """
55
56
  A class representing an Embedding-Based-Matching (EBM) model
@@ -139,7 +140,7 @@ class EbmModel:
139
140
  self.train_jobs = int(xgb_jobs)
140
141
 
141
142
  # Initialize logging
142
- self.init_logger(log_path, logger)
143
+ self.init_logger(log_path, logger, logging_level)
143
144
 
144
145
  # Initialize EBM model
145
146
  self.model = None
@@ -180,11 +181,12 @@ class EbmModel:
180
181
  None
181
182
  """
182
183
  if self.generator is None:
183
- if self.embedding_model_deployment == "offline-inference":
184
+ if self.embedding_model_deployment == "in-process":
184
185
  self.logger.info("initializing offline-inference embedding generator")
185
- self.generator = EmbeddingGeneratorOfflineInference(
186
+ self.generator = EmbeddingGeneratorInProcess(
186
187
  model_name=self.embedding_model_name,
187
188
  embedding_dimensions=self.embedding_dimensions,
189
+ logger=self.logger,
188
190
  **self.embedding_model_args,
189
191
  )
190
192
  elif self.embedding_model_deployment == "mock":
@@ -195,6 +197,7 @@ class EbmModel:
195
197
  self.generator = EmbeddingGeneratorHuggingFaceTEI(
196
198
  model_name=self.embedding_model_name,
197
199
  embedding_dimensions=self.embedding_dimensions,
200
+ logger=self.logger,
198
201
  **self.embedding_model_args,
199
202
  )
200
203
  elif self.embedding_model_deployment == "OpenAI":
@@ -202,6 +205,7 @@ class EbmModel:
202
205
  self.generator = EmbeddingGeneratorOpenAI(
203
206
  model_name=self.embedding_model_name,
204
207
  embedding_dimensions=self.embedding_dimensions,
208
+ logger=self.logger,
205
209
  **self.embedding_model_args,
206
210
  )
207
211
  else:
@@ -209,7 +213,10 @@ class EbmModel:
209
213
  raise NotImplementedError
210
214
 
211
215
  def init_logger(
212
- self, log_path: str | None = None, logger: logging.Logger | None = None
216
+ self,
217
+ log_path: str | None = None,
218
+ logger: logging.Logger | None = None,
219
+ logging_level: str = "info",
213
220
  ) -> None:
214
221
  """
215
222
  Initializes the logging for the EBM model.
@@ -218,7 +225,7 @@ class EbmModel:
218
225
  None
219
226
  """
220
227
  if log_path:
221
- self.logger = EbmLogger(log_path, "info").get_logger()
228
+ self.logger = EbmLogger(log_path, logging_level).get_logger()
222
229
  self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
223
230
  self.xgb_callbacks = [self.xgb_logger]
224
231
  elif logger:
@@ -769,7 +776,4 @@ class EbmModel:
769
776
  Returns:
770
777
  EbmModel: The loaded EBM model instance.
771
778
  """
772
- ebm_model = joblib.load(input_path)
773
- ebm_model.init_logger()
774
-
775
- return ebm_model
779
+ return joblib.load(input_path)
@@ -1,8 +1,9 @@
1
+ import logging
1
2
  import os
2
3
 
3
4
  import numpy as np
4
5
  import requests
5
- from sentence_transformers import SentenceTransformer
6
+ from openai import BadRequestError, NotFoundError, OpenAI
6
7
  from tqdm import tqdm
7
8
 
8
9
 
@@ -32,25 +33,29 @@ class EmbeddingGenerator:
32
33
  pass
33
34
 
34
35
 
35
- class EmbeddingGeneratorAPI(EmbeddingGenerator):
36
+ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGenerator):
36
37
  """
37
- A base class for API embedding generators.
38
-
39
- Attributes:
40
- embedding_dimensions (int): The dimensionality of the generated embeddings.
38
+ A class for generating embeddings using the HuggingFaceTEI API.
41
39
  """
42
40
 
43
41
  def __init__(
44
42
  self,
45
43
  model_name: str,
46
44
  embedding_dimensions: int,
45
+ logger: logging.Logger,
47
46
  **kwargs,
48
47
  ) -> None:
49
48
  """
50
- Initializes the API EmbeddingGenerator.
49
+ Initializes the HuggingFaceTEI API EmbeddingGenerator.
51
50
 
52
51
  Sets the embedding dimensions, and initializes and
53
52
  prepares a session with the API.
53
+
54
+ Args:
55
+ model_name (str): The name of the embedding model.
56
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
57
+ logger (Logger): A logger for the embedding generator.
58
+ **kwargs: Additional keyword arguments to pass to the model.
54
59
  """
55
60
 
56
61
  self.embedding_dimensions = embedding_dimensions
@@ -59,11 +64,36 @@ class EmbeddingGeneratorAPI(EmbeddingGenerator):
59
64
  self.api_address = kwargs.get("api_address")
60
65
  self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
61
66
 
67
+ self.logger = logger
68
+ self._test_api()
62
69
 
63
- class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
64
- """
65
- A class for generating embeddings using the HuggingFaceTEI API.
66
- """
70
+ def _test_api(self):
71
+ """
72
+ Tests if the API is working with the given parameters
73
+ """
74
+ response = self.session.post(
75
+ self.api_address,
76
+ headers=self.headers,
77
+ json={"inputs": "This is a test request!", "truncate": True},
78
+ )
79
+ if response.status_code == 200:
80
+ self.logger.debug(
81
+ "API call successful. Everything seems to be working fine."
82
+ )
83
+ elif response.status_code == 404:
84
+ self.logger.error(
85
+ "API not found under given adress! Please check the corresponding parameter!"
86
+ )
87
+ raise RuntimeError(
88
+ "API not found under given adress! Please check the corresponding parameter!"
89
+ )
90
+ else:
91
+ self.logger.error(
92
+ "Request to API not possible! Please check the corresponding parameters!"
93
+ )
94
+ raise RuntimeError(
95
+ "Request to API not possible! Please check the corresponding parameters!"
96
+ )
67
97
 
68
98
  def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
69
99
  """
@@ -72,8 +102,7 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
72
102
 
73
103
  Args:
74
104
  texts (list[str]): A list of input texts.
75
- **kwargs: Additional keyword arguments to pass to the
76
- SentenceTransformer model.
105
+ **kwargs: Additional keyword arguments to pass to the API.
77
106
 
78
107
  Returns:
79
108
  np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
@@ -102,19 +131,87 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
102
131
  if response.status_code == 200:
103
132
  embeddings.extend(response.json())
104
133
  else:
105
- # TODO: write warning to logger
134
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
106
135
  for _ in batch_texts:
107
- # TODO: ensure same format as true case and truncate dim
108
- embeddings.append([0 for _ in range(self.embedding_dimensions)])
136
+ embeddings.append(
137
+ [
138
+ 0
139
+ for _ in range(
140
+ min(
141
+ self.embedding_dimensions,
142
+ kwargs.get("truncate_prompt_tokens", float("inf")),
143
+ ),
144
+ )
145
+ ]
146
+ )
109
147
 
110
148
  return np.array(embeddings)
111
149
 
112
150
 
113
- class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
151
+ class EmbeddingGeneratorOpenAI(EmbeddingGenerator):
114
152
  """
115
- A class for generating embeddings using any OpenAI compatibleAPI.
153
+ A class for generating embeddings using any OpenAI compatible API.
116
154
  """
117
155
 
156
+ def __init__(
157
+ self,
158
+ model_name: str,
159
+ embedding_dimensions: int,
160
+ logger: logging.Logger,
161
+ **kwargs,
162
+ ) -> None:
163
+ """
164
+ Initializes the OpenAI API EmbeddingGenerator.
165
+
166
+ Sets the embedding dimensions, and initializes and
167
+ prepares a session with the API.
168
+
169
+ Args:
170
+ model_name (str): The name of the embedding model requested from the API.
171
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
172
+ logger (Logger): A logger for the embedding generator.
173
+ **kwargs: Additional keyword arguments to pass to the model.
174
+ """
175
+
176
+ self.embedding_dimensions = embedding_dimensions
177
+ self.model_name = model_name
178
+
179
+ if not (api_key := os.environ.get("OPENAI_API_KEY")):
180
+ api_key = ""
181
+
182
+ self.client = OpenAI(api_key=api_key, base_url=kwargs.get("api_address"))
183
+
184
+ self.logger = logger
185
+ self._test_api()
186
+
187
+ def _test_api(self):
188
+ """
189
+ Tests if the API is working with the given parameters
190
+ """
191
+ try:
192
+ _ = self.client.embeddings.create(
193
+ input="This is a test request!",
194
+ model=self.model_name,
195
+ encoding_format="float",
196
+ )
197
+ self.logger.debug(
198
+ "API call successful. Everything seems to be working fine."
199
+ )
200
+ except NotFoundError:
201
+ self.logger.error(
202
+ "API not found under given adress! Please check the corresponding parameter!"
203
+ )
204
+ raise RuntimeError(
205
+ "API not found under given adress! Please check the corresponding parameter!"
206
+ )
207
+ except BadRequestError:
208
+ self.logger.error(
209
+ "Request to API not possible! Please check the corresponding parameters!"
210
+ )
211
+ raise RuntimeError(
212
+ "Request to API not possible! Please check the corresponding parameters!"
213
+ )
214
+
118
215
  def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
119
216
  """
120
217
  Generates embeddings for a list of input texts using a model
@@ -122,8 +219,7 @@ class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
122
219
 
123
220
  Args:
124
221
  texts (list[str]): A list of input texts.
125
- **kwargs: Additional keyword arguments to pass to the
126
- SentenceTransformer model.
222
+ **kwargs: Additional keyword arguments to pass to the API.
127
223
 
128
224
  Returns:
129
225
  np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
@@ -143,53 +239,54 @@ class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
143
239
 
144
240
  for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
145
241
  batch_texts = texts[i : i + batch_size]
146
- data = {
147
- "input": batch_texts,
148
- "model": self.model_name,
149
- "encoding_format": "float",
150
- **kwargs,
151
- }
152
242
 
153
- response = self.session.post(
154
- self.api_address, headers=self.headers, json=data
155
- )
243
+ # Try to get embeddings for the batch from the API
244
+ try:
245
+ embedding_response = self.client.embeddings.create(
246
+ input=batch_texts,
247
+ model=self.model_name,
248
+ encoding_format="float",
249
+ extra_body={**kwargs},
250
+ )
156
251
 
157
- # Process all embeddings from the batch response
158
- if response.status_code == 200:
159
- response_data = response.json()
252
+ # Process all embeddings from the batch response
160
253
  for i, _ in enumerate(batch_texts):
161
- embedding = response_data["data"][i]["embedding"]
162
- embeddings.append(embedding)
163
- else:
164
- # TODO: write warning to logger
254
+ embeddings.append(embedding_response.data[i].embedding)
255
+ except (NotFoundError, BadRequestError):
256
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
165
257
  for _ in batch_texts:
166
258
  embeddings.append([0 for _ in range(self.embedding_dimensions)])
167
259
 
168
260
  return np.array(embeddings)
169
261
 
170
262
 
171
- class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
263
+ class EmbeddingGeneratorInProcess(EmbeddingGenerator):
172
264
  """
173
265
  A class for generating embeddings using a given SentenceTransformer model
174
- loaded offline with SentenceTransformer.
266
+ loaded in-process with SentenceTransformer.
175
267
 
176
268
  Args:
177
269
  model_name (str): The name of the SentenceTransformer model.
178
270
  embedding_dimensions (int): The dimensionality of the generated embeddings.
271
+ logger (Logger): A logger for the embedding generator.
179
272
  **kwargs: Additional keyword arguments to pass to the model.
180
-
181
- Attributes:
182
- model_name (str): The name of the SentenceTransformer model.
183
- embedding_dimensions (int): The dimensionality of the generated embeddings.
184
273
  """
185
274
 
186
- def __init__(self, model_name: str, embedding_dimensions: int, **kwargs) -> None:
275
+ def __init__(
276
+ self,
277
+ model_name: str,
278
+ embedding_dimensions: int,
279
+ logger: logging.Logger,
280
+ **kwargs,
281
+ ) -> None:
187
282
  """
188
- Initializes the EmbeddingGenerator in offline inference mode.
283
+ Initializes the EmbeddingGenerator in 'in-process' mode.
189
284
 
190
285
  Sets the model name, embedding dimensions, and creates a
191
286
  SentenceTransformer model instance.
192
287
  """
288
+ from sentence_transformers import SentenceTransformer
289
+
193
290
  self.model_name = model_name
194
291
  self.embedding_dimensions = embedding_dimensions
195
292
 
@@ -198,6 +295,8 @@ class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
198
295
  self.model = SentenceTransformer(
199
296
  model_name, truncate_dim=embedding_dimensions, **kwargs
200
297
  )
298
+ self.logger = logger
299
+ self.logger.debug(f"SentenceTransfomer model running on {self.model.device}")
201
300
 
202
301
  # Disable parallelism for tokenizer
203
302
  # Needed because process might be already parallelized
File without changes
File without changes
File without changes
File without changes
File without changes