ebm4subjects 0.4.1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between two versions of a package that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
ebm4subjects/analyzer.py CHANGED
@@ -32,8 +32,9 @@ class EbmAnalyzer:
              nltk.data.find(tokenizer_name)
          # If the tokenizer is not found, try to download it
          except LookupError as error:
-             if tokenizer_name in str(error):
-                 nltk.download(tokenizer_name)
+             if "punkt" in str(error):
+                 nltk.download("punkt")
+                 nltk.download("punkt_tab")
              else:
                  raise

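Newer NLTK releases split the Punkt model into the `punkt` and `punkt_tab` resources, so the lookup now downloads both instead of echoing the configured tokenizer name back to `nltk.download`. A minimal standalone sketch of the same bootstrap logic; the default resource path `tokenizers/punkt` is an assumption, not taken from the package:

```python
import nltk


def ensure_punkt(tokenizer_name: str = "tokenizers/punkt") -> None:
    """Make sure the Punkt sentence-tokenizer data is available locally."""
    # "tokenizers/punkt" above is an assumed default, not a value from ebm4subjects
    try:
        # Look the resource up in the local NLTK data path
        nltk.data.find(tokenizer_name)
    except LookupError as error:
        if "punkt" in str(error):
            # Recent NLTK versions ship the data as two separate resources
            nltk.download("punkt")
            nltk.download("punkt_tab")
        else:
            raise
```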
 
ebm4subjects/chunker.py CHANGED
@@ -1,6 +1,6 @@
  from concurrent.futures import ProcessPoolExecutor
  from math import ceil
- from typing import Tuple
+ from typing import Any, Tuple

  import polars as pl

@@ -17,9 +17,9 @@ class Chunker:

      Attributes:
          tokenizer (EbmAnalyzer): The tokenizer used for tokenizing sentences.
-         max_chunks (int): The maximum number of chunks to generate.
-         max_chunk_size (int): The maximum size of each chunk in characters.
-         max_sentences (int): The maximum number of sentences to consider.
+         max_chunk_count (int): The maximum number of chunks to generate.
+         max_chunk_length (int): The maximum size of each chunk in characters.
+         max_sentence_count (int): The maximum number of sentences to consider.

      Methods:
      - chunk_text: Chunks a given text into smaller sections
@@ -28,25 +28,30 @@ class Chunker:

      def __init__(
          self,
-         tokenizer_name: str,
-         max_chunks: int | None,
-         max_chunk_size: int | None,
-         max_sentences: int | None,
+         tokenizer: Any,
+         max_chunk_count: int | None,
+         max_chunk_length: int | None,
+         max_sentence_count: int | None,
      ):
          """
          Initializes the Chunker.

          Args:
-             tokenizer_name (str): The name of the tokenizer to use.
-             max_chunks (int | None): The maximum number of chunks to generate.
-             max_chunk_size (int | None): The maximum size of each chunk in characters.
-             max_sentences (int | None): The maximum number of sentences to consider.
+             tokenizer (Any): The name of the tokenizer to use or the tokenizer itself.
+             max_chunk_count (int | None): The maximum number of chunks to generate.
+             max_chunk_length (int | None): The maximum size of each chunk in characters.
+             max_sentence_count (int | None): The maximum number of sentences to consider.
          """
-         self.max_chunks = max_chunks if max_chunks else float("inf")
-         self.max_chunk_size = max_chunk_size if max_chunk_size else float("inf")
-         self.max_sentences = max_sentences if max_sentences else float("inf")
+         self.max_chunk_count = max_chunk_count if max_chunk_count else float("inf")
+         self.max_chunk_length = max_chunk_length if max_chunk_length else float("inf")
+         self.max_sentence_count = (
+             max_sentence_count if max_sentence_count else float("inf")
+         )

-         self.tokenizer = EbmAnalyzer(tokenizer_name)
+         if type(tokenizer) is str:
+             self.tokenizer = EbmAnalyzer(tokenizer)
+         else:
+             self.tokenizer = tokenizer

      def chunk_text(self, text: str) -> list[str]:
          """
@@ -63,7 +68,7 @@ class Chunker:

          # Tokenize the text into sentences
          sentences = self.tokenizer.tokenize_sentences(text)
-         sentences = sentences[: self.max_sentences]
+         sentences = sentences[: self.max_sentence_count]

          # Initialize an empty list to store the current chunk
          current_chunk = []
@@ -71,18 +76,18 @@ class Chunker:
          # Iterate over the sentences
          for sentence in sentences:
              # If the current chunk is not full, add the sentence to it
-             if len(" ".join(current_chunk)) < self.max_chunk_size:
+             if len(" ".join(current_chunk)) < self.max_chunk_length:
                  current_chunk.append(sentence)
              # Otherwise, add the current chunk to the list of chunks
              # and start a new chunk
              else:
                  chunks.append(" ".join(current_chunk))
                  current_chunk = [sentence]
-                 if len(chunks) == self.max_chunks:
+                 if len(chunks) == self.max_chunk_count:
                      break

          # If the maximum number of chunks is reached, break the loop
-         if current_chunk and len(chunks) < self.max_chunks:
+         if current_chunk and len(chunks) < self.max_chunk_count:
              chunks.append(" ".join(current_chunk))

          # Return the chunked text
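
With this change a `Chunker` can be built either from a tokenizer name (an `EbmAnalyzer` is then created internally) or from an already constructed tokenizer object. A hedged usage sketch; the tokenizer name and the limits are illustrative values, not package defaults:

```python
from ebm4subjects.analyzer import EbmAnalyzer
from ebm4subjects.chunker import Chunker

# Option 1: pass a tokenizer name; the Chunker builds an EbmAnalyzer internally.
chunker = Chunker(
    tokenizer="tokenizers/punkt",  # illustrative name
    max_chunk_count=10,
    max_chunk_length=2000,
    max_sentence_count=500,
)

# Option 2: reuse an already constructed analyzer across several Chunker instances.
shared_analyzer = EbmAnalyzer("tokenizers/punkt")
chunker_with_shared_tokenizer = Chunker(shared_analyzer, 10, 2000, 500)

chunks = chunker.chunk_text("First sentence. Second sentence. Third sentence.")
```
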
ebm4subjects/duckdb_client.py CHANGED
@@ -37,8 +37,8 @@ class Duckdb_client:
              (default: {"M": 32, "ef_construction": 256, "ef_search": 256}).

          Notes:
-             'hnsw_enable_experimental_persistence' needs to be set to 'True' in order
-             to store and query the index later
+             'hnsw_enable_experimental_persistence' needs to be set to 'True' in order
+             to store and query the index later
          """
          # Establish a connection to the DuckDB database
          self.connection = duckdb.connect(
@@ -76,10 +76,10 @@ class Duckdb_client:
                  (default: "cosine")
              force (bool, optional): Whether to replace the existing collection if it
                  already exists (default: False).
-
+
          Notes:
-             If 'hnsw_metric' is changed in this function 'hnsw_metric_function' in
-             the vector_search function needs to be changed accordingly in order
+             If 'hnsw_metric' is changed in this function 'hnsw_metric_function' in
+             the vector_search function needs to be changed accordingly in order
              for the index to work properly.
          """
          # Determine whether to replace the existing collection
@@ -147,10 +147,10 @@ class Duckdb_client:
              pl.DataFrame: The result of the vector search.

          Notes:
-             If 'hnsw_metric_function' is changed in this function 'hnsw_metric' in
-             the create_collection function needs to be changed accordingly in order
+             If 'hnsw_metric_function' is changed in this function 'hnsw_metric' in
+             the create_collection function needs to be changed accordingly in order
              for the index to work properly.
-             The argument 'chunk_size' is already set to the optimal value for the
+             The argument 'chunk_size' is already set to the optimal value for the
              query processing with DuckDB. Only change it if necessary.
          """
          # Create a temporary table to store the search results
ebm4subjects/ebm_logging.py CHANGED
@@ -7,7 +7,7 @@ class EbmLogger:
      """
      A custom logger class.

-     This class provides a way to log messages at different levels
+     This class provides a way to log messages at different levels
      (error, warning, info, debug) to a file.
      It also provides a way to get the logger instance.

@@ -16,6 +16,7 @@ class EbmLogger:
          log_path (str): The path to the log file.
          level (str): The log level (default: "info").
      """
+
      def __init__(self, log_path: str, level: str = "info") -> None:
          """
          Initializes the logger.
@@ -66,6 +67,7 @@ class NullLogger:

      This class is used when no logging is needed.
      """
+
      def __init__(self) -> None:
          """
          Initializes the null logger.
@@ -136,16 +138,17 @@ class NullLogger:
  class XGBLogging(xgboost.callback.TrainingCallback):
      """
      Custom XGBoost training callback for logging model performance during training.
-
+
      Args:
          logger (logging.Logger): Logger instance to use for logging.
          epoch_log_interval (int, optional): Interval at which to log model performance
              (default: 100).
-
+
      Attributes:
          logger (logging.Logger): Logger instance used for logging.
          epoch_log_interval (int): Interval at which to log model performance.
      """
+
      def __init__(
          self,
          logger: logging.Logger,
@@ -153,10 +156,10 @@ class XGBLogging(xgboost.callback.TrainingCallback):
      ) -> None:
          """
          Initializes the XGBLogger.
-
+
          Args:
              logger (logging.Logger): Logger instance to use for logging.
-             epoch_log_interval (int, optional): Interval at which to log model
+             epoch_log_interval (int, optional): Interval at which to log model
                  performance (default: to 100).
          """
          # Logger instance used for logging
@@ -172,14 +175,14 @@ class XGBLogging(xgboost.callback.TrainingCallback):
      ) -> bool:
          """
          Callback function called after each iteration of the XGBoost training process.
-
+
          Logs model performance at the specified interval.
-
+
          Args:
              model (xgboost.Booster): XGBoost model instance.
              epoch (int): Current epoch number.
              evals_log (dict): Dictionary containing evaluation metrics.
-
+
          Returns:
              bool: Always returns False, as specified by the XGBoost callback API.
          """
ebm4subjects/ebm_model.py CHANGED
@@ -1,7 +1,9 @@
  from __future__ import annotations

  import ast
+ import logging
  from pathlib import Path
+ from typing import Any

  import joblib
  import polars as pl
@@ -17,30 +19,31 @@ from ebm4subjects.embedding_generator import EmbeddingGenerator
  class EbmModel:
      def __init__(
          self,
-         db_path: str,
-         collection_name: str,
-         use_altLabels: bool,
-         duckdb_threads: int | str,
-         embedding_model_name: str,
+         embedding_model_name: str | Any,
          embedding_dimensions: int | str,
-         chunk_tokenizer: str,
-         max_chunks: int | str,
-         max_chunk_size: int | str,
+         chunk_tokenizer: str | Any,
+         max_chunk_count: int | str,
+         max_chunk_length: int | str,
          chunking_jobs: int | str,
-         max_sentences: int | str,
-         max_query_hits: int | str,
-         query_top_k: int | str,
+         max_sentence_count: int | str,
+         candidates_per_chunk: int | str,
+         candidates_per_doc: int | str,
          query_jobs: int | str,
          xgb_shrinkage: float | str,
          xgb_interaction_depth: int | str,
          xgb_subsample: float | str,
          xgb_rounds: int | str,
          xgb_jobs: int | str,
+         duckdb_threads: int | str,
+         db_path: str,
+         collection_name: str = "my_collection",
+         use_altLabels: bool = True,
          hnsw_index_params: dict | str | None = None,
          model_args: dict | str | None = None,
          encode_args_vocab: dict | str | None = None,
          encode_args_documents: dict | str | None = None,
          log_path: str | None = None,
+         logger: logging.Logger | None = None,
      ) -> None:
          """
          A class representing an Embedding-Based-Matching (EBM) model
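
The constructor reorders its parameters, renames the chunking and candidate limits (`max_chunks` → `max_chunk_count`, `max_chunk_size` → `max_chunk_length`, `max_sentences` → `max_sentence_count`, `max_query_hits` → `candidates_per_chunk`, `query_top_k` → `candidates_per_doc`), gives `collection_name` and `use_altLabels` defaults, and accepts an external `logging.Logger`. A hedged instantiation sketch; the embedding model name, paths and numeric values are placeholders, not recommended settings:

```python
from ebm4subjects.ebm_model import EbmModel

model = EbmModel(
    embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # placeholder model
    embedding_dimensions=384,
    chunk_tokenizer="tokenizers/punkt",  # tokenizer name or tokenizer object
    max_chunk_count=10,
    max_chunk_length=2000,
    chunking_jobs=4,
    max_sentence_count=500,
    candidates_per_chunk=50,
    candidates_per_doc=20,
    query_jobs=4,
    xgb_shrinkage=0.1,
    xgb_interaction_depth=6,
    xgb_subsample=0.8,
    xgb_rounds=100,
    xgb_jobs=4,
    duckdb_threads=4,
    db_path="ebm.duckdb",  # placeholder path
    log_path="ebm.log",    # optional: write logs to a file
)
```
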
@@ -109,14 +112,14 @@ class EbmModel:

          # Parameters for chunker
          self.chunk_tokenizer = chunk_tokenizer
-         self.max_chunks = int(max_chunks)
-         self.max_chunk_size = int(max_chunk_size)
-         self.max_sentences = int(max_sentences)
+         self.max_chunk_count = int(max_chunk_count)
+         self.max_chunk_length = int(max_chunk_length)
+         self.max_sentence_count = int(max_sentence_count)
          self.chunking_jobs = int(chunking_jobs)

          # Parameters for vector search
-         self.max_query_hits = int(max_query_hits)
-         self.query_top_k = int(query_top_k)
+         self.candidates_per_chunk = int(candidates_per_chunk)
+         self.candidates_per_doc = int(candidates_per_doc)
          self.query_jobs = int(query_jobs)

          # Parameters for XGB boost ranker
@@ -126,17 +129,8 @@ class EbmModel:
          self.train_rounds = int(xgb_rounds)
          self.train_jobs = int(xgb_jobs)

-         # Parameters for logger
-         # Only create logger if path to log file is set
-         self.logger = None
-         self.xgb_logger = None
-         self.xgb_callbacks = None
-         if log_path:
-             self.logger = EbmLogger(log_path, "info").get_logger()
-             self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
-             self.xgb_callbacks = [self.xgb_logger]
-         else:
-             self.logger = NullLogger()
+         # Initiliaze logging
+         self.init_logger(log_path, logger)

          # Initialize EBM model
          self.model = None
@@ -153,7 +147,9 @@ class EbmModel:
              None
          """
          if self.client is None:
-             self.logger.info("Initializing DuckDB client")
+             self.logger.info(
+                 f"initializing DuckDB client with duckdb_threads: {self.duckdb_threads}"
+             )

              self.client = Duckdb_client(
                  db_path=self.db_path,
@@ -175,14 +171,35 @@ class EbmModel:
              None
          """
          if self.generator is None:
-             self.logger.info("Initializing embedding generator")
-
+             self.logger.info("initializing embedding generator")
              self.generator = EmbeddingGenerator(
                  model_name=self.embedding_model_name,
                  embedding_dimensions=self.embedding_dimensions,
                  **self.model_args,
              )

+     def init_logger(
+         self, log_path: str | None = None, logger: logging.Logger | None = None
+     ) -> None:
+         """
+         Initializes the logging for the EBM model.
+
+         Returns:
+             None
+         """
+         if log_path:
+             self.logger = EbmLogger(log_path, "info").get_logger()
+             self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
+             self.xgb_callbacks = [self.xgb_logger]
+         elif logger:
+             self.logger = logger
+             self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
+             self.xgb_callbacks = [self.xgb_logger]
+         else:
+             self.logger = NullLogger()
+             self.xgb_logger = None
+             self.xgb_callbacks = None
+
      def create_vector_db(
          self,
          vocab_in_path: str | None = None,
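
Logging setup now lives in `init_logger`: a `log_path` takes precedence, an injected `logging.Logger` is used otherwise, and a `NullLogger` is the fallback. A hedged sketch of injecting an existing logger, assuming `model` is the instance from the sketch above:

```python
import logging

# Configure any standard library logger ...
logger = logging.getLogger("ebm4subjects")
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler())

# ... and either pass it at construction time (logger=logger) or wire it up later.
# init_logger also rebuilds the XGBoost training callbacks around this logger.
model.init_logger(logger=logger)
```
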
@@ -213,12 +230,12 @@ class EbmModel:
          # Check if output path exists and load existing vocabulary if so
          if vocab_out_path and Path(vocab_out_path).exists():
              self.logger.info(
-                 f"Loading vocabulary with embeddings from {vocab_out_path}"
+                 f"loading vocabulary with embeddings from {vocab_out_path}"
              )
              collection_df = pl.read_ipc(vocab_out_path)
          # Parse input vocabulary if provided
          elif vocab_in_path:
-             self.logger.info("Parsing vocabulary")
+             self.logger.info("parsing vocabulary")
              vocab = prepare_data.parse_vocab(
                  vocab_path=vocab_in_path,
                  use_altLabels=self.use_altLabels,
@@ -226,7 +243,7 @@ class EbmModel:

              # Initialize generator and add embeddings to vocabulary
              self._init_generator()
-             self.logger.info("Adding embeddings to vocabulary")
+             self.logger.info("adding embeddings to vocabulary")
              collection_df = prepare_data.add_vocab_embeddings(
                  vocab=vocab,
                  generator=self.generator,
@@ -238,20 +255,20 @@ class EbmModel:
                  # Check if file already exists and warn if so
                  if Path(vocab_out_path).exists() and not force:
                      self.logger.warn(
-                         f"""Cant't save vocabulary to {vocab_out_path}.
+                         f"""cant't save vocabulary to {vocab_out_path}.
                          File already exists"""
                      )
                  else:
-                     self.logger.info(f"Saving vocabulary to {vocab_out_path}")
+                     self.logger.info(f"saving vocabulary to {vocab_out_path}")
                      collection_df.write_ipc(vocab_out_path)
          else:
              # If no existing vocabulary and no input vocabulary is provided,
              # raise an error
-             raise ValueError("Vocabulary path is required")
+             raise ValueError("vocabulary path is required")

          # Initialize DuckDB client and create collection
          self._init_duckdb_client()
-         self.logger.info("Creating collection")
+         self.logger.info("creating collection")
          self.client.create_collection(
              collection_df=collection_df,
              collection_name=self.collection_name,
@@ -286,8 +303,6 @@ class EbmModel:
          Returns:
              pl.DataFrame: The prepared training data.
          """
-
-         self.logger.info("Preparing training data")
          # Check if pre-computed candidate training data is provided
          if not train_candidates:
              # If not, generate candidate training data in batches
@@ -306,7 +321,6 @@ class EbmModel:
              )

          # Create a gold standard data frame from the provided doc IDs and label IDs
-         self.logger.info("Preparing gold standard")
          gold_standard = pl.DataFrame(
              {
                  "doc_id": doc_ids,
@@ -318,7 +332,7 @@ class EbmModel:

          # Compare the candidate training data to the gold standard
          # and prepare data for the training of the XGB ranker model
-         self.logger.info("Prepare training data and gold standard for training")
+         self.logger.info("prepare training data and gold standard for training")
          training_data = (
              self._compare_to_gold_standard(train_candidates, gold_standard)
              .with_columns(pl.when(pl.col("gold")).then(1).otherwise(0).alias("gold"))
@@ -406,19 +420,19 @@ class EbmModel:
              n_jobs = self.query_jobs

          # Create a Chunker instance with specified parameters
-         self.logger.info("Chunking text")
+         self.logger.info("chunking text")
          chunker = Chunker(
-             tokenizer_name=self.chunk_tokenizer,
-             max_chunks=self.max_chunks,
-             max_chunk_size=self.max_chunk_size,
-             max_sentences=self.max_sentences,
+             tokenizer=self.chunk_tokenizer,
+             max_chunk_count=self.max_chunk_count,
+             max_chunk_length=self.max_chunk_length,
+             max_sentence_count=self.max_sentence_count,
          )
          # Chunk the input text
          text_chunks = chunker.chunk_text(text)

          # Initialize the generator
          self._init_generator()
-         self.logger.info("Creating embeddings for text chunks")
+         self.logger.info("creating embeddings for text chunks")
          # Generate embeddings for the text chunks
          embeddings = self.generator.generate_embeddings(
              # Use the text chunks as input
@@ -432,7 +446,7 @@ class EbmModel:
          )

          # Create a query DataFrame
-         self.logger.info("Creating query dataframe")
+         self.logger.info("creating query dataframe")
          query_df = pl.DataFrame(
              {
                  # Create a column for the query ID
@@ -450,7 +464,9 @@ class EbmModel:

          # Initialize the DuckDB client
          self._init_duckdb_client()
-         self.logger.info("Running vector search and creating candidates")
+         self.logger.info(
+             f"running vector search and creating candidates with query_jobs: {n_jobs}"
+         )
          # Perform vector search using the query DataFrame
          # Using the parameters specified for the EBM model
          # and the optimal chunk size for the DuckDB
@@ -459,9 +475,9 @@ class EbmModel:
              collection_name=self.collection_name,
              embedding_dimensions=self.embedding_dimensions,
              n_jobs=n_jobs,
-             n_hits=self.max_query_hits,
+             n_hits=self.candidates_per_chunk,
              chunk_size=1024,
-             top_k=self.query_top_k,
+             top_k=self.candidates_per_doc,
              hnsw_metric_function="array_cosine_distance",
          )

@@ -500,20 +516,20 @@ class EbmModel:
              query_jobs = self.query_jobs

          # Create a Chunker instance with specified parameters
-         self.logger.info("Chunking texts in batches")
          chunker = Chunker(
-             tokenizer_name=self.chunk_tokenizer,
-             max_chunks=self.max_chunks,
-             max_chunk_size=self.max_chunk_size,
-             max_sentences=self.max_sentences,
+             tokenizer=self.chunk_tokenizer,
+             max_chunk_count=self.max_chunk_count,
+             max_chunk_length=self.max_chunk_length,
+             max_sentence_count=self.max_sentence_count,
          )
          # Chunk the input texts
+         self.logger.info(f"chunking texts with chunking_jobs: {chunking_jobs}")
          text_chunks, chunk_index = chunker.chunk_batches(texts, doc_ids, chunking_jobs)

          # Initialize the generator and chunk index
          self._init_generator()
          chunk_index = pl.concat(chunk_index).with_row_index("query_id")
-         self.logger.info("Creating embeddings for text chunks and query dataframe")
+         self.logger.info("creating embeddings for text chunks and query dataframe")
          embeddings = self.generator.generate_embeddings(
              texts=text_chunks,
              **(
@@ -531,15 +547,17 @@ class EbmModel:
          # Perform vector search using the query DataFrame
          # Using the parameters specified for the EBM model
          # and the optimal chunk size for the DuckDB
-         self.logger.info("Running vector search and creating candidates")
+         self.logger.info(
+             f"running vector search and creating candidates with query_jobs: {query_jobs}"
+         )
          candidates = self.client.vector_search(
              query_df=query_df,
              collection_name=self.collection_name,
              embedding_dimensions=self.embedding_dimensions,
              n_jobs=query_jobs,
-             n_hits=self.max_query_hits,
+             n_hits=self.candidates_per_chunk,
              chunk_size=1024,
-             top_k=self.query_top_k,
+             top_k=self.candidates_per_doc,
              hnsw_metric_function="array_cosine_distance",
          )

@@ -567,8 +585,8 @@ class EbmModel:
              n_jobs = self.train_jobs

          # Select the required columns from the train_data DataFrame,
-         # convert to a Pandas DataFrame and afterwards to training matrix
-         self.logger.info("Creating training matrix")
+         # convert to a numpy array and afterwards to training matrix
+         self.logger.info("creating training matrix")
          matrix = xgb.DMatrix(
              train_data.select(
                  [
@@ -582,14 +600,16 @@ class EbmModel:
                      "is_prefLabel",
                      "n_chunks",
                  ]
-             ).to_pandas(),
+             ).to_numpy(),
              # Use the gold standard as the target
-             train_data.to_pandas()["gold"],
+             train_data.select("gold").to_numpy(),
          )

          try:
              # Train the XGBoost model with the specified parameters
-             self.logger.info("Starting training of XGBoost Ranker")
+             self.logger.info(
+                 f"starting training of XGBoost Ranker with xgb_jobs: {n_jobs}"
+             )
              model = xgb.train(
                  # Train the XGBoost model with the specified parameters
                  params={
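
Dropping the pandas conversion works because `xgboost.DMatrix` accepts NumPy arrays directly, and polars DataFrames convert via `to_numpy()`. A hedged, standalone sketch of the same pattern; the column names and values below are illustrative, not the package's real feature set:

```python
import polars as pl
import xgboost as xgb

# Illustrative stand-in for the candidate features ("gold" is the target column)
train_data = pl.DataFrame(
    {
        "cosine_sim": [0.9, 0.4, 0.7],
        "n_chunks": [3, 3, 3],
        "gold": [1, 0, 1],
    }
)

# polars -> NumPy -> DMatrix, without going through pandas
matrix = xgb.DMatrix(
    train_data.select(["cosine_sim", "n_chunks"]).to_numpy(),
    label=train_data.get_column("gold").to_numpy(),
)
```
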
@@ -611,7 +631,7 @@ class EbmModel:
                  # Use the specified callbacks
                  callbacks=self.xgb_callbacks,
              )
-             self.logger.info("Training successful finished")
+             self.logger.info("training successful finished")
          except xgb.core.XGBoostError:
              self.logger.critical(
                  "XGBoost can't train with candidates equal to gold standard "
@@ -641,7 +661,7 @@ class EbmModel:
          """
          # Select relevant columns from the candidates DataFrame to create a matrix
          # for the trained model to make predictions
-         self.logger.info("Creating matrix of candidates to generate predictions")
+         self.logger.info("creating matrix of candidates to generate predictions")
          matrix = xgb.DMatrix(
              candidates.select(
                  [
@@ -659,7 +679,7 @@ class EbmModel:
          )

          # Use the trained model to make predictions on the created matrix
-         self.logger.info("Making predictions for candidates")
+         self.logger.info("making predictions for candidates")
          predictions = self.model.predict(matrix)

          # Transform the predictions into a list of DataFrames containing the
@@ -671,11 +691,12 @@ class EbmModel:
              .select(["doc_id", "label_id", "score"])
              # Sort the DataFrame by document ID and score in ascending and
              # descending order, respectively
+             .with_columns(pl.col("doc_id").cast(pl.Int64))
              .sort(["doc_id", "score"], descending=[False, True])
              # Group the DataFrame by document ID and aggregate the top-k labels
              # and scores for each group
              .group_by("doc_id")
-             .agg(pl.all().head(self.query_top_k))
+             .agg(pl.all().head(self.candidates_per_doc))
              # Explode the aggregated DataFrame to create separate rows for each
              # label and score
              .explode(["label_id", "score"])
@@ -683,7 +704,7 @@ class EbmModel:
              .partition_by("doc_id")
          )

-     def save(self, output_path: str) -> None:
+     def save(self, output_path: str) -> list[str]:
          """
          Saves the current state of the EBM model to a file using joblib.

@@ -694,12 +715,18 @@ class EbmModel:
          Args:
              output_path: The file path where the serialized model will be written.

+         Returns:
+             list[str]: Output path of model file.
+
          Notes:
-             The model's client and generator attributes are reset to None.
+             The model's client, generator and loggers are reset to None.
          """
          self.client = None
          self.generator = None
-         joblib.dump(self, output_path)
+
+         self.init_logger()
+
+         return joblib.dump(self, output_path)

      @staticmethod
      def load(input_path: str) -> EbmModel:
@@ -707,9 +734,12 @@ class EbmModel:
          Loads an EBM model from a joblib serialized file.

          Args:
-             input_path (str): Path to the joblib serialized file containing the EBM model.
+             input_path (str): Path to the joblib serialized file containing the EBM model.

          Returns:
-             EbmModel: The loaded EBM model instance.
+             EbmModel: The loaded EBM model instance.
          """
-         return joblib.load(input_path)
+         ebm_model = joblib.load(input_path)
+         ebm_model.init_logger()
+
+         return ebm_model
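
`save` now also resets the loggers before pickling and returns the file list from `joblib.dump`, and `load` re-initializes logging on the restored instance (leaving it with a `NullLogger` until logging is configured again). A hedged round-trip sketch, reusing the `model` and `EbmModel` names from the sketches above; the paths are placeholders:

```python
# Persist: the DuckDB client, the embedding generator and the loggers are dropped,
# then joblib.dump writes the model and its list of written file names is returned.
saved_files = model.save("ebm_model.joblib")

# Restore: load() unpickles the model and calls init_logger(), which installs a
# NullLogger until logging is configured again.
restored = EbmModel.load("ebm_model.joblib")
restored.init_logger(log_path="ebm.log")  # optional: re-attach file logging
```
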
ebm4subjects/embedding_generator.py CHANGED
@@ -9,7 +9,8 @@ class EmbeddingGenerator:
      A class for generating embeddings using a given SentenceTransformer model.

      Args:
-         model_name (str): The name of the SentenceTransformer model to use.
+         model_name (str, SentenceTransformer): The name of the SentenceTransformer
+             model or an SentenceTransformer model to use.
          embedding_dimensions (int): The dimensionality of the generated embeddings.
          **kwargs: Additional keyword arguments to pass to the model.

@@ -19,7 +20,9 @@ class EmbeddingGenerator:
          model (SentenceTransformer): The SentenceTransformer model instance.
      """

-     def __init__(self, model_name: str, embedding_dimensions: int, **kwargs) -> None:
+     def __init__(
+         self, model_name: str | SentenceTransformer, embedding_dimensions: int, **kwargs
+     ) -> None:
          """
          Initializes the EmbeddingGenerator.

@@ -31,9 +34,13 @@ class EmbeddingGenerator:

          # Create a SentenceTransformer model instance with the given
          # model name and embedding dimensions
-         self.model = SentenceTransformer(
-             model_name, truncate_dim=embedding_dimensions, **kwargs
-         )
+         # or set model to the given SentenceTransformer
+         if type(model_name) is str:
+             self.model = SentenceTransformer(
+                 model_name, truncate_dim=embedding_dimensions, **kwargs
+             )
+         else:
+             self.model = model_name

          # Disabel parallelism for tokenizer
          # Needed because process might be already parallelized
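
`EmbeddingGenerator` (and, through `embedding_model_name`, `EbmModel`) now also accepts a preloaded `SentenceTransformer` instead of a model name; note that `truncate_dim` is only applied when a name is passed. A hedged sketch; the transformer name is a placeholder:

```python
from sentence_transformers import SentenceTransformer

from ebm4subjects.embedding_generator import EmbeddingGenerator

# Load and configure the transformer yourself (placeholder model name) ...
st_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# ... then hand the instance over instead of a name.
generator = EmbeddingGenerator(model_name=st_model, embedding_dimensions=384)
embeddings = generator.generate_embeddings(texts=["A short example text."])
```
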
{ebm4subjects-0.4.1.dist-info → ebm4subjects-0.5.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ebm4subjects
- Version: 0.4.1
+ Version: 0.5.1
  Summary: Embedding Based Matching for Automated Subject Indexing
  Author: Deutsche Nationalbibliothek
  Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
@@ -13,9 +13,7 @@ Classifier: Operating System :: OS Independent
  Classifier: Programming Language :: Python :: 3
  Requires-Python: >=3.10
  Requires-Dist: duckdb>=1.3.0
- Requires-Dist: flash-attn>=2.8.2
  Requires-Dist: nltk~=3.9.1
- Requires-Dist: pandas>=2.3.0
  Requires-Dist: polars>=1.30.0
  Requires-Dist: pyarrow>=21.0.0
  Requires-Dist: pyoxigraph>=0.4.11
@@ -56,6 +54,7 @@ This design borrows a lot of ideas from lexical matching like Maui [1], Kea [2]

  [2] Frank, E., Paynter, G. W., Witten, I. H., Gutwin, C., & Nevill-Manning, C. G. (1999). Domain-Specific Keyphrase Extraction. Proceedings of the 16 Th International Joint Conference on Artifical Intelligence (IJCAI99), 668–673.

+ ![Embedding Based Matching Sketch](ebm-sketch.svg)

  ## Why embedding based matching

ebm4subjects-0.5.1.dist-info/RECORD ADDED
@@ -0,0 +1,12 @@
+ ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
+ ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
+ ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
+ ebm4subjects/ebm_logging.py,sha256=xkbqeVhSCNuhMwkx2yoIX8_D3z9DcsauZEmHhR1gaS0,5962
+ ebm4subjects/ebm_model.py,sha256=PVFtljF3oZK8u0lA6df82lsTdAD8H1Y9CHvWq1jWF2M,29125
+ ebm4subjects/embedding_generator.py,sha256=DZhZxkjcsy_4NA62_2V-4UPbIUkg5qMPat_cIgsoIAA,2609
+ ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
+ ebm4subjects-0.5.1.dist-info/METADATA,sha256=QkOBvOAI49_AUipc3yAH6RVG9OVUs_8jO64Bjfy561U,8274
+ ebm4subjects-0.5.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ ebm4subjects-0.5.1.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
+ ebm4subjects-0.5.1.dist-info/RECORD,,

ebm4subjects-0.4.1.dist-info/RECORD DELETED
@@ -1,12 +0,0 @@
- ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- ebm4subjects/analyzer.py,sha256=kHsM2ZPzOIHp93UbdWtlgWARoH5ZbDueLsw9FJxpomM,1635
- ebm4subjects/chunker.py,sha256=5LMOAHAxm_VlwSQnmVJjBxb4Vrdv7N-ioW8wcC-VvF0,6545
- ebm4subjects/duckdb_client.py,sha256=JS6yyBe2p01cX_apFXjpYtT-w4Ow41HVhF3z9lKvvww,13046
- ebm4subjects/ebm_logging.py,sha256=0tvodIHXdAGPzOXHwQF5lNBZYZTHD33mZrogr1btqV4,6001
- ebm4subjects/ebm_model.py,sha256=sZI1QwKAH6wPPIxKbdLudD6rIJj7RNsDVJhV0fPBICw,28097
- ebm4subjects/embedding_generator.py,sha256=jC4rz4W50tKndxYezD7Kaoqysl8zhN-TbWirxA_WIQc,2354
- ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
- ebm4subjects-0.4.1.dist-info/METADATA,sha256=Oo_YR6zYDnhxWZa7Gp_HZuK7qIFIQWlA3dAbDsze_YE,8285
- ebm4subjects-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- ebm4subjects-0.4.1.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
- ebm4subjects-0.4.1.dist-info/RECORD,,