ebm4subjects 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,17 +39,17 @@ class EbmLogger:
39
39
  else:
40
40
  self.logger.setLevel(logging.NOTSET)
41
41
 
42
- # Create a file handler to log messages to a file
43
- log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
44
- log_file_handler.setFormatter(
45
- logging.Formatter(
46
- "%(asctime)s %(levelname)s: %(message)s",
47
- "%Y-%m-%d %H:%M:%S",
42
+ # Create a file handler to log messages to a file
43
+ if not self.logger.handlers:
44
+ log_file_handler = logging.FileHandler(f"{log_path}/ebm.log")
45
+ log_file_handler.setFormatter(
46
+ logging.Formatter(
47
+ "%(asctime)s %(levelname)s: %(message)s",
48
+ "%Y-%m-%d %H:%M:%S",
49
+ )
48
50
  )
49
- )
50
51
 
51
- # Add the file handler to the logger
52
- self.logger.addHandler(log_file_handler)
52
+ self.logger.addHandler(log_file_handler)
53
53
 
54
54
  def get_logger(self) -> logging.Logger:
55
55
  """
ebm4subjects/ebm_model.py CHANGED
@@ -16,7 +16,7 @@ from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
16
16
  from ebm4subjects.embedding_generator import (
17
17
  EmbeddingGeneratorHuggingFaceTEI,
18
18
  EmbeddingGeneratorMock,
19
- EmbeddingGeneratorOfflineInference,
19
+ EmbeddingGeneratorInProcess,
20
20
  EmbeddingGeneratorOpenAI,
21
21
  )
22
22
 
@@ -44,12 +44,13 @@ class EbmModel:
44
44
  use_altLabels: bool = True,
45
45
  hnsw_index_params: dict | str | None = None,
46
46
  embedding_model_name: str | None = None,
47
- embedding_model_deployment: str = "offline-inference",
47
+ embedding_model_deployment: str = "mock",
48
48
  embedding_model_args: dict | str | None = None,
49
49
  encode_args_vocab: dict | str | None = None,
50
50
  encode_args_documents: dict | str | None = None,
51
51
  log_path: str | None = None,
52
52
  logger: logging.Logger | None = None,
53
+ logging_level: str = "info",
53
54
  ) -> None:
54
55
  """
55
56
  A class representing an Embedding-Based-Matching (EBM) model
@@ -100,7 +101,7 @@ class EbmModel:
100
101
 
101
102
  # Parameters for embedding generator
102
103
  self.generator = None
103
- self.embedding_model_deployment = embedding_model_deployment
104
+ self.embedding_model_deployment = embedding_model_deployment.lower()
104
105
  self.embedding_model_name = embedding_model_name
105
106
  self.embedding_dimensions = int(embedding_dimensions)
106
107
  if isinstance(embedding_model_args, str) or not embedding_model_args:
@@ -139,7 +140,7 @@ class EbmModel:
139
140
  self.train_jobs = int(xgb_jobs)
140
141
 
141
142
  # Initiliaze logging
142
- self.init_logger(log_path, logger)
143
+ self.init_logger(log_path, logger, logging_level)
143
144
 
144
145
  # Initialize EBM model
145
146
  self.model = None
@@ -180,36 +181,41 @@ class EbmModel:
180
181
  None
181
182
  """
182
183
  if self.generator is None:
183
- if self.embedding_model_deployment == "offline-inference":
184
- self.logger.info("initializing offline-inference embedding generator")
185
- self.generator = EmbeddingGeneratorOfflineInference(
184
+ if self.embedding_model_deployment == "in-process":
185
+ self.logger.info("initializing in-process embedding generator")
186
+ self.generator = EmbeddingGeneratorInProcess(
186
187
  model_name=self.embedding_model_name,
187
188
  embedding_dimensions=self.embedding_dimensions,
189
+ logger=self.logger,
188
190
  **self.embedding_model_args,
189
191
  )
190
192
  elif self.embedding_model_deployment == "mock":
191
193
  self.logger.info("initializing mock embedding generator")
192
194
  self.generator = EmbeddingGeneratorMock(self.embedding_dimensions)
193
- elif self.embedding_model_deployment == "HuggingFaceTEI":
195
+ elif self.embedding_model_deployment == "huggingfacetei":
194
196
  self.logger.info("initializing API embedding generator")
195
197
  self.generator = EmbeddingGeneratorHuggingFaceTEI(
196
198
  model_name=self.embedding_model_name,
197
199
  embedding_dimensions=self.embedding_dimensions,
200
+ logger=self.logger,
198
201
  **self.embedding_model_args,
199
202
  )
200
- elif self.embedding_model_deployment == "OpenAI":
203
+ elif self.embedding_model_deployment == "openai":
201
204
  self.logger.info("initializing API embedding generator")
202
205
  self.generator = EmbeddingGeneratorOpenAI(
203
206
  model_name=self.embedding_model_name,
204
207
  embedding_dimensions=self.embedding_dimensions,
208
+ logger=self.logger,
205
209
  **self.embedding_model_args,
206
210
  )
207
211
  else:
208
- self.logger.error("unsupportet API for embedding generator")
209
- raise NotImplementedError
212
+ raise NotImplementedError("Unsupportet API for embedding generator")
210
213
 
211
214
  def init_logger(
212
- self, log_path: str | None = None, logger: logging.Logger | None = None
215
+ self,
216
+ log_path: str | None = None,
217
+ logger: logging.Logger | None = None,
218
+ logging_level: str = "info",
213
219
  ) -> None:
214
220
  """
215
221
  Initializes the logging for the EBM model.
@@ -218,7 +224,7 @@ class EbmModel:
218
224
  None
219
225
  """
220
226
  if log_path:
221
- self.logger = EbmLogger(log_path, "info").get_logger()
227
+ self.logger = EbmLogger(log_path, logging_level).get_logger()
222
228
  self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
223
229
  self.xgb_callbacks = [self.xgb_logger]
224
230
  elif logger:
@@ -663,7 +669,7 @@ class EbmModel:
663
669
  )
664
670
  self.logger.info("training successful finished")
665
671
  except xgb.core.XGBoostError:
666
- self.logger.critical(
672
+ self.logger.warn(
667
673
  "XGBoost can't train with candidates equal to gold standard "
668
674
  "or candidates with no match to gold standard at all - "
669
675
  "Check if your training data and gold standard are correct"
@@ -769,7 +775,4 @@ class EbmModel:
769
775
  Returns:
770
776
  EbmModel: The loaded EBM model instance.
771
777
  """
772
- ebm_model = joblib.load(input_path)
773
- ebm_model.init_logger()
774
-
775
- return ebm_model
778
+ return joblib.load(input_path)
@@ -1,8 +1,9 @@
1
+ import logging
1
2
  import os
2
3
 
3
4
  import numpy as np
4
5
  import requests
5
- from sentence_transformers import SentenceTransformer
6
+ from openai import BadRequestError, NotFoundError, OpenAI
6
7
  from tqdm import tqdm
7
8
 
8
9
 
@@ -32,25 +33,29 @@ class EmbeddingGenerator:
32
33
  pass
33
34
 
34
35
 
35
- class EmbeddingGeneratorAPI(EmbeddingGenerator):
36
+ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGenerator):
36
37
  """
37
- A base class for API embedding generators.
38
-
39
- Attributes:
40
- embedding_dimensions (int): The dimensionality of the generated embeddings.
38
+ A class for generating embeddings using the HuggingFaceTEI API.
41
39
  """
42
40
 
43
41
  def __init__(
44
42
  self,
45
43
  model_name: str,
46
44
  embedding_dimensions: int,
45
+ logger: logging.Logger,
47
46
  **kwargs,
48
47
  ) -> None:
49
48
  """
50
- Initializes the API EmbeddingGenerator.
49
+ Initializes the HuggingFaceTEI API EmbeddingGenerator.
51
50
 
52
51
  Sets the embedding dimensions, and initiliazes and
53
52
  prepares a session with the API.
53
+
54
+ Args:
55
+ model_name (str): The name of the SentenceTransformer model.
56
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
57
+ logger (Logger): A logger for the embedding generator.
58
+ **kwargs: Additional keyword arguments to pass to the model.
54
59
  """
55
60
 
56
61
  self.embedding_dimensions = embedding_dimensions
@@ -59,11 +64,26 @@ class EmbeddingGeneratorAPI(EmbeddingGenerator):
59
64
  self.api_address = kwargs.get("api_address")
60
65
  self.headers = kwargs.get("headers", {"Content-Type": "application/json"})
61
66
 
67
+ self.logger = logger
68
+ self._test_api()
62
69
 
63
- class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
64
- """
65
- A class for generating embeddings using the HuggingFaceTEI API.
66
- """
70
+ def _test_api(self):
71
+ """
72
+ Tests if the API is working with the given parameters
73
+ """
74
+ response = self.session.post(
75
+ self.api_address,
76
+ headers=self.headers,
77
+ json={"inputs": "This is a test request!", "truncate": True},
78
+ )
79
+ if response.status_code == 200:
80
+ self.logger.debug(
81
+ "API call successful. Everything seems to be working fine."
82
+ )
83
+ else:
84
+ raise RuntimeError(
85
+ "Request to API not possible! Please check the corresponding parameters!"
86
+ )
67
87
 
68
88
  def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
69
89
  """
@@ -72,8 +92,7 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
72
92
 
73
93
  Args:
74
94
  texts (list[str]): A list of input texts.
75
- **kwargs: Additional keyword arguments to pass to the
76
- SentenceTransformer model.
95
+ **kwargs: Additional keyword arguments to pass to the API.
77
96
 
78
97
  Returns:
79
98
  np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
@@ -102,19 +121,70 @@ class EmbeddingGeneratorHuggingFaceTEI(EmbeddingGeneratorAPI):
102
121
  if response.status_code == 200:
103
122
  embeddings.extend(response.json())
104
123
  else:
105
- # TODO: write warning to logger
124
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
106
125
  for _ in batch_texts:
107
- # TODO: ensure same format as true case and truncate dim
108
- embeddings.append([0 for _ in range(self.embedding_dimensions)])
126
+ embeddings.append(
127
+ [
128
+ 0
129
+ for _ in range(
130
+ min(
131
+ self.embedding_dimensions,
132
+ kwargs.get("truncate_prompt_tokens", float("inf")),
133
+ ),
134
+ )
135
+ ]
136
+ )
109
137
 
110
138
  return np.array(embeddings)
111
139
 
112
140
 
113
- class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
141
+ class EmbeddingGeneratorOpenAI(EmbeddingGenerator):
114
142
  """
115
- A class for generating embeddings using any OpenAI compatibleAPI.
143
+ A class for generating embeddings using any OpenAI compatible API.
116
144
  """
117
145
 
146
+ def __init__(
147
+ self,
148
+ model_name: str,
149
+ embedding_dimensions: int,
150
+ logger: logging.Logger,
151
+ **kwargs,
152
+ ) -> None:
153
+ """
154
+ Initializes the OpenAI API EmbeddingGenerator.
155
+
156
+ Sets the embedding dimensions, and initiliazes and
157
+ prepares a session with the API.
158
+
159
+ Args:
160
+ model_name (str): The name of the SentenceTransformer model.
161
+ embedding_dimensions (int): The dimensionality of the generated embeddings.
162
+ logger (Logger): A logger for the embedding generator.
163
+ **kwargs: Additional keyword arguments to pass to the model.
164
+ """
165
+
166
+ self.embedding_dimensions = embedding_dimensions
167
+ self.model_name = model_name
168
+
169
+ if not (api_key := os.environ.get("OPENAI_API_KEY")):
170
+ api_key = ""
171
+
172
+ self.client = OpenAI(api_key=api_key, base_url=kwargs.get("api_address"))
173
+
174
+ self.logger = logger
175
+ self._test_api()
176
+
177
+ def _test_api(self):
178
+ """
179
+ Tests if the API is working with the given parameters
180
+ """
181
+ _ = self.client.embeddings.create(
182
+ input="This is a test request!",
183
+ model=self.model_name,
184
+ encoding_format="float",
185
+ )
186
+ self.logger.debug("API call successful. Everything seems to be working fine.")
187
+
118
188
  def generate_embeddings(self, texts: list[str], **kwargs) -> np.ndarray:
119
189
  """
120
190
  Generates embeddings for a list of input texts using a model
@@ -122,8 +192,7 @@ class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
122
192
 
123
193
  Args:
124
194
  texts (list[str]): A list of input texts.
125
- **kwargs: Additional keyword arguments to pass to the
126
- SentenceTransformer model.
195
+ **kwargs: Additional keyword arguments to pass to the API.
127
196
 
128
197
  Returns:
129
198
  np.ndarray: A numpy array of shape (len(texts), embedding_dimensions)
@@ -143,53 +212,54 @@ class EmbeddingGeneratorOpenAI(EmbeddingGeneratorAPI):
143
212
 
144
213
  for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
145
214
  batch_texts = texts[i : i + batch_size]
146
- data = {
147
- "input": batch_texts,
148
- "model": self.model_name,
149
- "encoding_format": "float",
150
- **kwargs,
151
- }
152
215
 
153
- response = self.session.post(
154
- self.api_address, headers=self.headers, json=data
155
- )
216
+ # Try to get embeddings for the batch from the API
217
+ try:
218
+ embedding_response = self.client.embeddings.create(
219
+ input=batch_texts,
220
+ model=self.model_name,
221
+ encoding_format="float",
222
+ extra_body={**kwargs},
223
+ )
156
224
 
157
- # Process all embeddings from the batch response
158
- if response.status_code == 200:
159
- response_data = response.json()
225
+ # Process all embeddings from the batch response
160
226
  for i, _ in enumerate(batch_texts):
161
- embedding = response_data["data"][i]["embedding"]
162
- embeddings.append(embedding)
163
- else:
164
- # TODO: write warning to logger
227
+ embeddings.append(embedding_response.data[i].embedding)
228
+ except (NotFoundError, BadRequestError):
229
+ self.logger.warn("Call to API NOT successful! Returning 0's.")
165
230
  for _ in batch_texts:
166
231
  embeddings.append([0 for _ in range(self.embedding_dimensions)])
167
232
 
168
233
  return np.array(embeddings)
169
234
 
170
235
 
171
- class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
236
+ class EmbeddingGeneratorInProcess(EmbeddingGenerator):
172
237
  """
173
238
  A class for generating embeddings using a given SentenceTransformer model
174
- loaded offline with SentenceTransformer.
239
+ loaded in-process with SentenceTransformer.
175
240
 
176
241
  Args:
177
242
  model_name (str): The name of the SentenceTransformer model.
178
243
  embedding_dimensions (int): The dimensionality of the generated embeddings.
244
+ logger (Logger): A logger for the embedding generator.
179
245
  **kwargs: Additional keyword arguments to pass to the model.
180
-
181
- Attributes:
182
- model_name (str): The name of the SentenceTransformer model.
183
- embedding_dimensions (int): The dimensionality of the generated embeddings.
184
246
  """
185
247
 
186
- def __init__(self, model_name: str, embedding_dimensions: int, **kwargs) -> None:
248
+ def __init__(
249
+ self,
250
+ model_name: str,
251
+ embedding_dimensions: int,
252
+ logger: logging.Logger,
253
+ **kwargs,
254
+ ) -> None:
187
255
  """
188
- Initializes the EmbeddingGenerator in offline inference mode.
256
+ Initializes the EmbeddingGenerator in 'in-process' mode.
189
257
 
190
258
  Sets the model name, embedding dimensions, and creates a
191
259
  SentenceTransformer model instance.
192
260
  """
261
+ from sentence_transformers import SentenceTransformer
262
+
193
263
  self.model_name = model_name
194
264
  self.embedding_dimensions = embedding_dimensions
195
265
 
@@ -198,6 +268,8 @@ class EmbeddingGeneratorOfflineInference(EmbeddingGenerator):
198
268
  self.model = SentenceTransformer(
199
269
  model_name, truncate_dim=embedding_dimensions, **kwargs
200
270
  )
271
+ self.logger = logger
272
+ self.logger.debug(f"SentenceTransfomer model running on {self.model.device}")
201
273
 
202
274
  # Disabel parallelism for tokenizer
203
275
  # Needed because process might be already parallelized
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ebm4subjects
3
- Version: 0.5.4
3
+ Version: 0.5.6
4
4
  Summary: Embedding Based Matching for Automated Subject Indexing
5
5
  Author: Deutsche Nationalbibliothek
6
6
  Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
@@ -14,12 +14,14 @@ Classifier: Programming Language :: Python :: 3
14
14
  Requires-Python: >=3.10
15
15
  Requires-Dist: duckdb>=1.3.0
16
16
  Requires-Dist: nltk~=3.9.1
17
+ Requires-Dist: openai>=2.15.0
17
18
  Requires-Dist: polars>=1.30.0
18
19
  Requires-Dist: pyarrow>=21.0.0
19
20
  Requires-Dist: pyoxigraph>=0.4.11
20
21
  Requires-Dist: rdflib~=7.1.3
21
- Requires-Dist: sentence-transformers>=5.0.0
22
22
  Requires-Dist: xgboost>=3.0.2
23
+ Provides-Extra: in-process
24
+ Requires-Dist: sentence-transformers>=5.0.0; extra == 'in-process'
23
25
  Description-Content-Type: text/markdown
24
26
 
25
27
  # Embedding Based Matching for Automated Subject Indexing
@@ -0,0 +1,12 @@
1
+ ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
3
+ ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
4
+ ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
5
+ ebm4subjects/ebm_logging.py,sha256=vGMa3xSm6T7ZQ94XeNGJVGCTl3zytt4sbunwXc6qF5U,5987
6
+ ebm4subjects/ebm_model.py,sha256=UTCIv_KCQ4HTJVbcVIAUv4S2j87oq8HXBeN5mfJmclQ,30879
7
+ ebm4subjects/embedding_generator.py,sha256=fk8rRhqBcRCknpCYoFolcXjoCwsx25Qd_UEOt-nUlv8,11774
8
+ ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
9
+ ebm4subjects-0.5.6.dist-info/METADATA,sha256=Dujb7SghFPo3j42yRAgkbqv-VSmwpocJIHW4NgJFhn0,8354
10
+ ebm4subjects-0.5.6.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
+ ebm4subjects-0.5.6.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
12
+ ebm4subjects-0.5.6.dist-info/RECORD,,
@@ -1,12 +0,0 @@
1
- ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
3
- ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
4
- ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
5
- ebm4subjects/ebm_logging.py,sha256=xkbqeVhSCNuhMwkx2yoIX8_D3z9DcsauZEmHhR1gaS0,5962
6
- ebm4subjects/ebm_model.py,sha256=lzGx_HLkKyTPVhtU4117DOEDz1rduNdzltvCYSbHQPg,30780
7
- ebm4subjects/embedding_generator.py,sha256=LKZ_YAe4Th8foI_8-v-3tYFj0KGJ90XJ3OPuMXaqgSQ,9274
8
- ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
9
- ebm4subjects-0.5.4.dist-info/METADATA,sha256=OmMMh0pGAdv3YTkTork55wuj2gA0Ac8zV9ad3cDCIks,8274
10
- ebm4subjects-0.5.4.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
11
- ebm4subjects-0.5.4.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
12
- ebm4subjects-0.5.4.dist-info/RECORD,,