ebm4subjects 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,17 +39,17 @@ class EbmLogger:
39
39
  else:
40
40
  self.logger.setLevel(logging.NOTSET)
41
41
 
42
- # Create a file handler to log messages to a file
43
- log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
44
- log_file_handler.setFormatter(
45
- logging.Formatter(
46
- "%(asctime)s %(levelname)s: %(message)s",
47
- "%Y-%m-%d %H:%M:%S",
42
+ # Create a file handler to log messages to a file
43
+ if not self.logger.handlers:
44
+ log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
45
+ log_file_handler.setFormatter(
46
+ logging.Formatter(
47
+ "%(asctime)s %(levelname)s: %(message)s",
48
+ "%Y-%m-%d %H:%M:%S",
49
+ )
48
50
  )
49
- )
50
51
 
51
- # Add the file handler to the logger
52
- self.logger.addHandler(log_file_handler)
52
+ self.logger.addHandler(log_file_handler)
53
53
 
54
54
  def get_logger(self) -> logging.Logger:
55
55
  """
ebm4subjects/ebm_model.py CHANGED
@@ -15,8 +15,9 @@ from ebm4subjects.duckdb_client import Duckdb_client
15
15
  from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
16
16
  from ebm4subjects.embedding_generator import (
17
17
  EmbeddingGeneratorHuggingFaceTEI,
18
- EmbeddingGeneratorOfflineInference,
19
18
  EmbeddingGeneratorMock,
19
+ EmbeddingGeneratorInProcess,
20
+ EmbeddingGeneratorOpenAI,
20
21
  )
21
22
 
22
23
 
@@ -43,12 +44,13 @@ class EbmModel:
43
44
  use_altLabels: bool = True,
44
45
  hnsw_index_params: dict | str | None = None,
45
46
  embedding_model_name: str | None = None,
46
- embedding_model_type: str = "offline-inference",
47
+ embedding_model_deployment: str = "offline-inference",
47
48
  embedding_model_args: dict | str | None = None,
48
49
  encode_args_vocab: dict | str | None = None,
49
50
  encode_args_documents: dict | str | None = None,
50
51
  log_path: str | None = None,
51
52
  logger: logging.Logger | None = None,
53
+ logging_level: str = "info",
52
54
  ) -> None:
53
55
  """
54
56
  A class representing an Embedding-Based-Matching (EBM) model
@@ -99,7 +101,7 @@ class EbmModel:
99
101
 
100
102
  # Parameters for embedding generator
101
103
  self.generator = None
102
- self.embedding_model_type = embedding_model_type
104
+ self.embedding_model_deployment = embedding_model_deployment
103
105
  self.embedding_model_name = embedding_model_name
104
106
  self.embedding_dimensions = int(embedding_dimensions)
105
107
  if isinstance(embedding_model_args, str) or not embedding_model_args:
@@ -138,7 +140,7 @@ class EbmModel:
138
140
  self.train_jobs = int(xgb_jobs)
139
141
 
140
142
  # Initialize logging
141
- self.init_logger(log_path, logger)
143
+ self.init_logger(log_path, logger, logging_level)
142
144
 
143
145
  # Initialize EBM model
144
146
  self.model = None
@@ -179,20 +181,31 @@ class EbmModel:
179
181
  None
180
182
  """
181
183
  if self.generator is None:
182
- if self.embedding_model_type == "offline-inference":
184
+ if self.embedding_model_deployment == "in-process":
183
185
  self.logger.info("initializing offline-inference embedding generator")
184
- self.generator = EmbeddingGeneratorOfflineInference(
186
+ self.generator = EmbeddingGeneratorInProcess(
185
187
  model_name=self.embedding_model_name,
186
188
  embedding_dimensions=self.embedding_dimensions,
189
+ logger=self.logger,
187
190
  **self.embedding_model_args,
188
191
  )
189
- elif self.embedding_model_type == "mock":
192
+ elif self.embedding_model_deployment == "mock":
190
193
  self.logger.info("initializing mock embedding generator")
191
194
  self.generator = EmbeddingGeneratorMock(self.embedding_dimensions)
192
- elif self.embedding_model_type == "HuggingFaceTEI":
195
+ elif self.embedding_model_deployment == "HuggingFaceTEI":
193
196
  self.logger.info("initializing API embedding generator")
194
197
  self.generator = EmbeddingGeneratorHuggingFaceTEI(
198
+ model_name=self.embedding_model_name,
199
+ embedding_dimensions=self.embedding_dimensions,
200
+ logger=self.logger,
201
+ **self.embedding_model_args,
202
+ )
203
+ elif self.embedding_model_deployment == "OpenAI":
204
+ self.logger.info("initializing API embedding generator")
205
+ self.generator = EmbeddingGeneratorOpenAI(
206
+ model_name=self.embedding_model_name,
195
207
  embedding_dimensions=self.embedding_dimensions,
208
+ logger=self.logger,
196
209
  **self.embedding_model_args,
197
210
  )
198
211
  else:
@@ -200,7 +213,10 @@ class EbmModel:
200
213
  raise NotImplementedError
201
214
 
202
215
  def init_logger(
203
- self, log_path: str | None = None, logger: logging.Logger | None = None
216
+ self,
217
+ log_path: str | None = None,
218
+ logger: logging.Logger | None = None,
219
+ logging_level: str = "info",
204
220
  ) -> None:
205
221
  """
206
222
  Initializes the logging for the EBM model.
@@ -209,7 +225,7 @@ class EbmModel:
209
225
  None
210
226
  """
211
227
  if log_path:
212
- self.logger = EbmLogger(log_path, "info").get_logger()
228
+ self.logger = EbmLogger(log_path, logging_level).get_logger()
213
229
  self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
214
230
  self.xgb_callbacks = [self.xgb_logger]
215
231
  elif logger:
@@ -760,7 +776,4 @@ class EbmModel:
760
776
  Returns:
761
777
  EbmModel: The loaded EBM model instance.
762
778
  """
763
- ebm_model = joblib.load(input_path)
764
- ebm_model.init_logger()
765
-
766
- return ebm_model
779
+ return joblib.load(input_path)
@@ -1,8 +1,10 @@
1
+ import logging
1
2
  import os
2
3
 
3
4
  import numpy as np
4
5
  import requests
5
- from sentence_transformers import SentenceTransformer
6
+ from openai import BadRequestError, NotFoundError, OpenAI
7
+ from tqdm import tqdm
6
8
 
7
9
 
8
10
  class EmbeddingGenerator:
@@ -31,37 +33,67 @@ class EmbeddingGenerator:
31
33
  pass
32
34
 
33
35
 
34
- class EmbeddingGeneratorAPI(EmbeddingGenerator):
36
+ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGenerator):
35
37
  """
36
- A base class for API embedding generators.
37
-
38
- Attributes:
39
- embedding_dimensions (int): The dimensionality of the generated embeddings.
38
+ A class for generating embeddings using the HuggingFaceTEI API.
40
39
  """
41
40
 
42
41
  def __init__(
43
42
  self,
43
+ model_name: str,
44
44
  embedding_dimensions: int,
45
+ logger: logging.Logger,
45
46
  **kwargs,
46
47
  ) -> None:
47
48
  """
48
- Initializes the API EmbeddingGenerator.
49
+ Initializes the HuggingFaceTEI API EmbeddingGenerator.
49
50
 
50
51
  Sets the embedding dimensions, and initializes and
51
52
  prepares a session with the API.
53
+
54
+ Args:
55
+ model_name (str): The name of the embedding model.
56
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
57
+ logger (Logger): A logger for the embedding generator.
58
+ **kwargs: Additional keyword arguments to pass to the model.
52
59
  """
53
60
 
54
61
  self.embedding_dimensions = embedding_dimensions
55
-
62
+ self.model_name = model_name
56
63
  self.session = requests.Session()
57
64
  self.api_address = kwargs.get("api_address")
58
65
  self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
59
66
 
67
+ self.logger = logger
68
+ self._test_api()
60
69
 
61
- class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
62
- """
63
- A class for generating embeddings using the HuggingFaceTEI API.
64
- """
70
+ def _test_api(self):
71
+ """
72
+ Tests if the API is working with the given parameters
73
+ """
74
+ response = self.session.post(
75
+ self.api_address,
76
+ headers=self.headers,
77
+ json={"inputs": "This is a test request!", "truncate": True},
78
+ )
79
+ if response.status_code == 200:
80
+ self.logger.debug(
81
+ "API call successful. Everything seems to be working fine."
82
+ )
83
+ elif response.status_code == 404:
84
+ self.logger.error(
85
+ "API not found under given adress! Please check the corresponding parameter!"
86
+ )
87
+ raise RuntimeError(
88
+ "API not found under given adress! Please check the corresponding parameter!"
89
+ )
90
+ else:
91
+ self.logger.error(
92
+ "Request to API not possible! Please check the corresponding parameters!"
93
+ )
94
+ raise RuntimeError(
95
+ "Request to API not possible! Please check the corresponding parameters!"
96
+ )
65
97
 
66
98
  def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
67
99
  """
@@ -70,8 +102,7 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
70
102
 
71
103
  Args:
72
104
  texts (list[str]): A list of input texts.
73
- **kwargs: Additional keyword arguments to pass to the
74
- SentenceTransformer model.
105
+ **kwargs: Additional keyword arguments to pass to the API.
75
106
 
76
107
  Returns:
77
108
  np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
@@ -85,45 +116,177 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
85
116
  # If empty, return an empty numpy array with the correct shape
86
117
  return np.empty((0, self.embedding_dimensions))
87
118
 
88
- # process each text
89
- for text in texts:
119
+ # Process in smaller batches to avoid memory overload
120
+ batch_size = min(32, len(texts)) # HuggingFaceTEI has a limit of 32 as default
121
+
122
+ for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
123
+ batch_texts = texts[i : i + batch_size]
90
124
  # send a request to the HuggingFaceTEI API
91
- data = {"inputs": text}
125
+ data = {"inputs": batch_texts, "truncate": True}
92
126
  response = self.session.post(
93
127
  self.api_address, headers=self.headers, json=data
94
128
  )
95
129
 
96
130
  # add generated embeddings to return list if request was successful
97
131
  if response.status_code == 200:
98
- embeddings.append(response.json()[0])
132
+ embeddings.extend(response.json())
99
133
  else:
100
- embeddings.append([0 for _ in range(self.embedding_dimensions)])
134
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
135
+ for _ in batch_texts:
136
+ embeddings.append(
137
+ [
138
+ 0
139
+ for _ in range(
140
+ min(
141
+ self.embedding_dimensions,
142
+ kwargs.get("truncate_prompt_tokens", float("inf")),
143
+ ),
144
+ )
145
+ ]
146
+ )
147
+
148
+ return np.array(embeddings)
149
+
150
+
151
+ class EmbeddingGeneratorOpenAI(EmbeddingGenerator):
152
+ """
153
+ A class for generating embeddings using any OpenAI compatible API.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ model_name: str,
159
+ embedding_dimensions: int,
160
+ logger: logging.Logger,
161
+ **kwargs,
162
+ ) -> None:
163
+ """
164
+ Initializes the OpenAI API EmbeddingGenerator.
165
+
166
+ Sets the embedding dimensions, and initializes and
167
+ prepares a session with the API.
168
+
169
+ Args:
170
+ model_name (str): The name of the embedding model requested from the API.
171
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
172
+ logger (Logger): A logger for the embedding generator.
173
+ **kwargs: Additional keyword arguments to pass to the model.
174
+ """
175
+
176
+ self.embedding_dimensions = embedding_dimensions
177
+ self.model_name = model_name
178
+
179
+ if not (api_key := os.environ.get("OPENAI_API_KEY")):
180
+ api_key = ""
181
+
182
+ self.client = OpenAI(api_key=api_key, base_url=kwargs.get("api_address"))
183
+
184
+ self.logger = logger
185
+ self._test_api()
186
+
187
+ def _test_api(self):
188
+ """
189
+ Tests if the API is working with the given parameters
190
+ """
191
+ try:
192
+ _ = self.client.embeddings.create(
193
+ input="This is a test request!",
194
+ model=self.model_name,
195
+ encoding_format="float",
196
+ )
197
+ self.logger.debug(
198
+ "API call successful. Everything seems to be working fine."
199
+ )
200
+ except NotFoundError:
201
+ self.logger.error(
202
+ "API not found under given adress! Please check the corresponding parameter!"
203
+ )
204
+ raise RuntimeError(
205
+ "API not found under given adress! Please check the corresponding parameter!"
206
+ )
207
+ except BadRequestError:
208
+ self.logger.error(
209
+ "Request to API not possible! Please check the corresponding parameters!"
210
+ )
211
+ raise RuntimeError(
212
+ "Request to API not possible! Please check the corresponding parameters!"
213
+ )
214
+
215
+ def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
216
+ """
217
+ Generates embeddings for a list of input texts using a model
218
+ via an OpenAI compatible API.
219
+
220
+ Args:
221
+ texts (list[str]): A list of input texts.
222
+ **kwargs: Additional keyword arguments to pass to the API.
223
+
224
+ Returns:
225
+ np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
226
+ containing the generated embeddings.
227
+ """
228
+ # prepare list for return
229
+ embeddings = []
230
+
231
+ # Check if the input list is empty
232
+ if not texts:
233
+ # If empty, return an empty numpy array with the correct shape
234
+ return np.empty((0, self.embedding_dimensions))
235
+
236
+ # Process in smaller batches to avoid memory overload
237
+ batch_size = min(200, len(texts))
238
+ embeddings = []
239
+
240
+ for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
241
+ batch_texts = texts[i : i + batch_size]
242
+
243
+ # Try to get embeddings for the batch from the API
244
+ try:
245
+ embedding_response = self.client.embeddings.create(
246
+ input=batch_texts,
247
+ model=self.model_name,
248
+ encoding_format="float",
249
+ extra_body={**kwargs},
250
+ )
251
+
252
+ # Process all embeddings from the batch response
253
+ for i, _ in enumerate(batch_texts):
254
+ embeddings.append(embedding_response.data[i].embedding)
255
+ except (NotFoundError, BadRequestError):
256
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
257
+ for _ in batch_texts:
258
+ embeddings.append([0 for _ in range(self.embedding_dimensions)])
101
259
 
102
260
  return np.array(embeddings)
103
261
 
104
262
 
105
- class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
263
+ class EmbeddingGeneratorInProcess(EmbeddingGenerator):
106
264
  """
107
265
  A class for generating embeddings using a given SentenceTransformer model
108
- loaded offline with SentenceTransformer.
266
+ loaded in-process with SentenceTransformer.
109
267
 
110
268
  Args:
111
269
  model_name (str): The name of the SentenceTransformer model.
112
270
  embedding_dimensions (int): The dimensionality of the generated embeddings.
271
+ logger (Logger): A logger for the embedding generator.
113
272
  **kwargs: Additional keyword arguments to pass to the model.
114
-
115
- Attributes:
116
- model_name (str): The name of the SentenceTransformer model.
117
- embedding_dimensions (int): The dimensionality of the generated embeddings.
118
273
  """
119
274
 
120
- def __init__(self, model_name: str, embedding_dimensions: int, **kwargs) -> None:
275
+ def __init__(
276
+ self,
277
+ model_name: str,
278
+ embedding_dimensions: int,
279
+ logger: logging.Logger,
280
+ **kwargs,
281
+ ) -> None:
121
282
  """
122
- Initializes the EmbeddingGenerator in offline inference mode.
283
+ Initializes the EmbeddingGenerator in 'in-process' mode.
123
284
 
124
285
  Sets the model name, embedding dimensions, and creates a
125
286
  SentenceTransformer model instance.
126
287
  """
288
+ from sentence_transformers import SentenceTransformer
289
+
127
290
  self.model_name = model_name
128
291
  self.embedding_dimensions = embedding_dimensions
129
292
 
@@ -132,6 +295,8 @@ class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
132
295
  self.model = SentenceTransformer(
133
296
  model_name, truncate_dim=embedding_dimensions, **kwargs
134
297
  )
298
+ self.logger = logger
299
+ self.logger.debug(f"SentenceTransfomer model running on {self.model.device}")
135
300
 
136
301
  # Disable parallelism for tokenizer
137
302
  # Needed because process might be already parallelized
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ebm4subjects
3
- Version: 0.5.3
3
+ Version: 0.5.5
4
4
  Summary: Embedding Based Matching for Automated Subject Indexing
5
5
  Author: Deutsche Nationalbibliothek
6
6
  Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
@@ -14,12 +14,14 @@ Classifier: Programming Language :: Python :: 3
14
14
  Requires-Python: >=3.10
15
15
  Requires-Dist: duckdb>=1.3.0
16
16
  Requires-Dist: nltk~=3.9.1
17
+ Requires-Dist: openai>=2.15.0
17
18
  Requires-Dist: polars>=1.30.0
18
19
  Requires-Dist: pyarrow>=21.0.0
19
20
  Requires-Dist: pyoxigraph>=0.4.11
20
21
  Requires-Dist: rdflib~=7.1.3
21
- Requires-Dist: sentence-transformers>=5.0.0
22
22
  Requires-Dist: xgboost>=3.0.2
23
+ Provides-Extra: in-process
24
+ Requires-Dist: sentence-transformers>=5.0.0; extra == 'in-process'
23
25
  Description-Content-Type: text/markdown
24
26
 
25
27
  # Embedding Based Matching for Automated Subject Indexing
@@ -0,0 +1,12 @@
1
+ ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
3
+ ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
4
+ ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
5
+ ebm4subjects/ebm_logging.py,sha256=vGMa3xSm6T7ZQ94XeNGJVGCTl3zytt4sbunwXc6qF5U,5987
6
+ ebm4subjects/ebm_model.py,sha256=mnnRqdO3vlRF0MTNx7JvHgTXPCNg-8YWJm1kOtHinak,30929
7
+ ebm4subjects/embedding_generator.py,sha256=q5HP36q11EMkH_yomduXa176ays7mtRvBvL0f78NFIE,12909
8
+ ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
9
+ ebm4subjects-0.5.5.dist-info/METADATA,sha256=oekB-uWB3p53odPkbtx-CqzxL_AHf6Az3RJcNhw1xhY,8354
10
+ ebm4subjects-0.5.5.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ ebm4subjects-0.5.5.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
12
+ ebm4subjects-0.5.5.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
3
- ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
4
- ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
5
- ebm4subjects/ebm_logging.py,sha256=xkbqeVhSCNuhMwkx2yoIX8_D3z9DcsauZEmHhR1gaS0,5962
6
- ebm4subjects/ebm_model.py,sha256=oVLNQv7IVb7KhhExb8o38z1xS3na_DzL-uoIK2A7IW0,30269
7
- ebm4subjects/embedding_generator.py,sha256=VXnZ2mqu2emmyIUkW-pw-7I_Zikc2LqsyiGcg2sxMuc,6703
8
- ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
9
- ebm4subjects-0.5.3.dist-info/METADATA,sha256=uIuPMpcd4GH4sCCn5mbTPUGkjodQBBoTD0cmBt64_9Q,8274
10
- ebm4subjects-0.5.3.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- ebm4subjects-0.5.3.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
12
- ebm4subjects-0.5.3.dist-info/RECORD,,