ebm4subjects 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,715 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ from pathlib import Path
5
+
6
+ import joblib
7
+ import polars as pl
8
+ import xgboost as xgb
9
+
10
+ from ebm4subjects import prepare_data
11
+ from ebm4subjects.chunker import Chunker
12
+ from ebm4subjects.duckdb_client import Duckdb_client
13
+ from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
14
+ from ebm4subjects.embedding_generator import EmbeddingGenerator
15
+
16
+
17
+ class EbmModel:
18
+ def __init__(
19
+ self,
20
+ db_path: str,
21
+ collection_name: str,
22
+ use_altLabels: bool,
23
+ duckdb_threads: int | str,
24
+ embedding_model_name: str,
25
+ embedding_dimensions: int | str,
26
+ chunk_tokenizer: str,
27
+ max_chunks: int | str,
28
+ max_chunk_size: int | str,
29
+ chunking_jobs: int | str,
30
+ max_sentences: int | str,
31
+ max_query_hits: int | str,
32
+ query_top_k: int | str,
33
+ query_jobs: int | str,
34
+ xgb_shrinkage: float | str,
35
+ xgb_interaction_depth: int | str,
36
+ xgb_subsample: float | str,
37
+ xgb_rounds: int | str,
38
+ xgb_jobs: int | str,
39
+ hnsw_index_params: dict | str | None = None,
40
+ model_args: dict | str | None = None,
41
+ encode_args_vocab: dict | str | None = None,
42
+ encode_args_documents: dict | str | None = None,
43
+ log_path: str | None = None,
44
+ ) -> None:
45
+ """
46
+ A class representing an Embedding-Based-Matching (EBM) model
47
+ for automated subject indexing of texts.
48
+
49
+ The EBM model integrates multiple components, including:
50
+ - A DuckDB client for database operations
51
+ - An EmbeddingGenerator for generating embeddings from text data
52
+ - A Chunker for chunking text into smaller pieces
53
+ - An XGBoost Ranker model for ranking candidate labels
54
+
55
+ The EBM model provides methods for creating a vector database,
56
+ preparing training data, training the XGBoost Ranker model,
57
+ and making predictions on generated candidate labels.
58
+
59
+ Attributes:
60
+ client (DuckDB client): The DuckDB client instance
61
+ generator (EmbeddingGenerator): The EmbeddingGenerator instance
62
+ chunker (Chunker): The Chunker instance
63
+ model (XGBoost Ranker model): The trained XGBoost Ranker model
64
+
65
+ Methods:
66
+ create_vector_db: Creates a vector database by loading an existing
67
+ vocabulary with embeddings or generating a new
68
+ vocabulary with embeddings
69
+ prepare_train: Prepares the training data for the EBM model
70
+ train: Trains the XGBoost Ranker model using the provided training data
71
+ predict: Generates predictions for given candidates using the trained model
72
+ save: Saves the current state of the EBM model to a file
73
+ load: Loads an EBM model from a file
74
+
75
+ Notes:
76
+ All parameters with type hints like 'TYPE | str' expect a value of
77
+ type TYPE but also accept it as a string, in which case it is cast
78
+ to the required type.
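+
+ Example:
+ A minimal usage sketch. All values below (paths, model and
+ tokenizer names, and hyperparameters) are illustrative
+ placeholders, not recommended defaults:
+
+ >>> model = EbmModel(
+ ...     db_path="vectors.duckdb",
+ ...     collection_name="subjects",
+ ...     use_altLabels=True,
+ ...     duckdb_threads=4,
+ ...     embedding_model_name="some-sentence-embedding-model",
+ ...     embedding_dimensions=768,
+ ...     chunk_tokenizer="some-tokenizer",
+ ...     max_chunks=50,
+ ...     max_chunk_size=256,
+ ...     chunking_jobs=4,
+ ...     max_sentences=10,
+ ...     max_query_hits=100,
+ ...     query_top_k=25,
+ ...     query_jobs=4,
+ ...     xgb_shrinkage=0.1,
+ ...     xgb_interaction_depth=6,
+ ...     xgb_subsample=0.8,
+ ...     xgb_rounds=100,
+ ...     xgb_jobs=4,
+ ... )
+ >>> model.create_vector_db(vocab_in_path="vocab.ttl")
+ >>> train_data = model.prepare_train(doc_ids, label_ids, texts)
+ >>> model.train(train_data)
+ >>> candidates = model.generate_candidates_batch(texts, doc_ids)
+ >>> predictions = model.predict(candidates)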
79
+ """
80
+ # Parameters for duckdb
81
+ self.client = None
82
+ self.db_path = db_path
83
+ self.collection_name = collection_name
84
+ self.use_altLabels = use_altLabels
85
+ self.duckdb_threads = int(duckdb_threads)
86
+ if isinstance(hnsw_index_params, str) or not hnsw_index_params:
87
+ hnsw_index_params = (
88
+ ast.literal_eval(hnsw_index_params) if hnsw_index_params else {}
89
+ )
90
+ self.hnsw_index_params = hnsw_index_params
91
+
92
+ # Parameters for embedding generator
93
+ self.generator = None
94
+ self.embedding_model_name = embedding_model_name
95
+ self.embedding_dimensions = int(embedding_dimensions)
96
+ if isinstance(model_args, str) or not model_args:
97
+ model_args = ast.literal_eval(model_args) if model_args else {}
98
+ self.model_args = model_args
99
+ if isinstance(encode_args_vocab, str) or not encode_args_vocab:
100
+ encode_args_vocab = (
101
+ ast.literal_eval(encode_args_vocab) if encode_args_vocab else {}
102
+ )
103
+ self.encode_args_vocab = encode_args_vocab
104
+ if isinstance(encode_args_documents, str) or not encode_args_documents:
105
+ encode_args_documents = (
106
+ ast.literal_eval(encode_args_documents) if encode_args_documents else {}
107
+ )
108
+ self.encode_args_documents = encode_args_documents
109
+
110
+ # Parameters for chunker
111
+ self.chunk_tokenizer = chunk_tokenizer
112
+ self.max_chunks = int(max_chunks)
113
+ self.max_chunk_size = int(max_chunk_size)
114
+ self.max_sentences = int(max_sentences)
115
+ self.chunking_jobs = int(chunking_jobs)
116
+
117
+ # Parameters for vector search
118
+ self.max_query_hits = int(max_query_hits)
119
+ self.query_top_k = int(query_top_k)
120
+ self.query_jobs = int(query_jobs)
121
+
122
+ # Parameters for XGBoost ranker
123
+ self.train_shrinkage = float(xgb_shrinkage)
124
+ self.train_interaction_depth = int(xgb_interaction_depth)
125
+ self.train_subsample = float(xgb_subsample)
126
+ self.train_rounds = int(xgb_rounds)
127
+ self.train_jobs = int(xgb_jobs)
128
+
129
+ # Parameters for logger
130
+ # Only create logger if path to log file is set
131
+ self.logger = None
132
+ self.xgb_logger = None
133
+ self.xgb_callbacks = None
134
+ if log_path:
135
+ self.logger = EbmLogger(log_path, "info").get_logger()
136
+ self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
137
+ self.xgb_callbacks = [self.xgb_logger]
138
+ else:
139
+ self.logger = NullLogger()
140
+
141
+ # Initialize EBM model
142
+ self.model = None
143
+
144
+ def _init_duckdb_client(self) -> None:
145
+ """
146
+ Initializes the DuckDB client if it does not already exist.
147
+
148
+ This method creates a new DuckDB client if it is not already
149
+ initialized and configures it with the provided database path,
150
+ thread settings, and HNSW index parameters.
151
+
152
+ Returns:
153
+ None
154
+ """
155
+ if self.client is None:
156
+ self.logger.info("Initializing DuckDB client")
157
+
158
+ self.client = Duckdb_client(
159
+ db_path=self.db_path,
160
+ config={
161
+ "hnsw_enable_experimental_persistence": True,
162
+ "threads": self.duckdb_threads,
163
+ },
164
+ hnsw_index_params=self.hnsw_index_params,
165
+ )
166
+
167
+ def _init_generator(self) -> None:
168
+ """
169
+ Initializes the embedding generator if it does not already exist.
170
+
171
+ If the generator is not initialized, it creates a new EmbeddingGenerator
172
+ with the specified model name, embedding dimensions, and model arguments.
173
+
174
+ Returns:
175
+ None
176
+ """
177
+ if self.generator is None:
178
+ self.logger.info("Initializing embedding generator")
179
+
180
+ self.generator = EmbeddingGenerator(
181
+ model_name=self.embedding_model_name,
182
+ embedding_dimensions=self.embedding_dimensions,
183
+ **self.model_args,
184
+ )
185
+
186
+ def create_vector_db(
187
+ self,
188
+ vocab_in_path: str | None = None,
189
+ vocab_out_path: str | None = None,
190
+ force: bool = False,
191
+ ) -> None:
192
+ """
193
+ Creates a vector database by either loading an existing vocabulary
194
+ with embeddings or generating a new vocabulary with embeddings from scratch.
195
+
196
+ If a vocabulary with embeddings already exists at the specified output path,
197
+ it will be loaded. Otherwise, a new vocabulary will be generated from the input
198
+ vocabulary path, and the resulting vocabulary with embeddings will be saved to
199
+ the output path if specified.
200
+
201
+ Args:
202
+ vocab_in_path (optional): The path to the input vocabulary file.
203
+ vocab_out_path (optional): The path to the output vocabulary file
204
+ with embeddings.
205
+ force: Whether to overwrite an existing output file (default: False).
206
+
207
+ Returns:
208
+ None
209
+
210
+ Raises:
211
+ ValueError: If no vocabulary is provided.
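+
+ Example:
+ A minimal sketch (file names are placeholders). Passing an output
+ path caches the vocabulary with its embeddings, so a later call
+ can reuse the cached file instead of re-embedding:
+
+ >>> model.create_vector_db(
+ ...     vocab_in_path="vocab.ttl",
+ ...     vocab_out_path="vocab_embedded.arrow",
+ ... )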
212
+ """
213
+ # Check if output path exists and load existing vocabulary if so
214
+ if vocab_out_path and Path(vocab_out_path).exists():
215
+ self.logger.info(
216
+ f"Loading vocabulary with embeddings from {vocab_out_path}"
217
+ )
218
+ collection_df = pl.read_ipc(vocab_out_path)
219
+ # Parse input vocabulary if provided
220
+ elif vocab_in_path:
221
+ self.logger.info("Parsing vocabulary")
222
+ vocab = prepare_data.parse_vocab(
223
+ vocab_path=vocab_in_path,
224
+ use_altLabels=self.use_altLabels,
225
+ )
226
+
227
+ # Initialize generator and add embeddings to vocabulary
228
+ self._init_generator()
229
+ self.logger.info("Adding embeddings to vocabulary")
230
+ collection_df = prepare_data.add_vocab_embeddings(
231
+ vocab=vocab,
232
+ generator=self.generator,
233
+ encode_args=self.encode_args_vocab,
234
+ )
235
+
236
+ # Save vocabulary to output path if specified
237
+ if vocab_out_path:
238
+ # Check if file already exists and warn if so
239
+ if Path(vocab_out_path).exists() and not force:
240
+ self.logger.warn(
241
+ f"""Cant't save vocabulary to {vocab_out_path}.
242
+ File already exists"""
243
+ )
244
+ else:
245
+ self.logger.info(f"Saving vocabulary to {vocab_out_path}")
246
+ collection_df.write_ipc(vocab_out_path)
247
+ else:
248
+ # If no existing vocabulary and no input vocabulary is provided,
249
+ # raise an error
250
+ raise ValueError("Vocabulary path is required")
251
+
252
+ # Initialize DuckDB client and create collection
253
+ self._init_duckdb_client()
254
+ self.logger.info("Creating collection")
255
+ self.client.create_collection(
256
+ collection_df=collection_df,
257
+ collection_name=self.collection_name,
258
+ embedding_dimensions=self.embedding_dimensions,
259
+ hnsw_index_name="hnsw_index",
260
+ hnsw_metric="cosine",
261
+ force=force,
262
+ )
263
+
264
+ def prepare_train(
265
+ self,
266
+ doc_ids: list[str],
267
+ label_ids: list[str],
268
+ texts: list[str],
269
+ train_candidates: pl.DataFrame | None = None,
270
+ n_jobs: int = 0,
271
+ ) -> pl.DataFrame:
272
+ """
273
+ Prepares the training data for the EBM model.
274
+
275
+ This function generates candidate training data and a gold standard
276
+ data frame. It then compares the candidates to the gold standard,
277
+ computes the necessary features, and returns the resulting training data.
278
+
279
+ Args:
280
+ doc_ids (list[str]): A list of document IDs.
281
+ label_ids (list[str]): A list of label IDs.
282
+ texts (list[str]): A list of text data.
283
+ train_candidates (pl.DataFrame, optional): Pre-computed candidate training data (default: None).
284
+ n_jobs (int, optional): The number of jobs to use for parallel processing (default: 0).
285
+
286
+ Returns:
287
+ pl.DataFrame: The prepared training data.
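+
+ Example:
+ A minimal sketch, assuming doc_ids, label_ids and texts are
+ parallel lists supplied by the caller:
+
+ >>> train_data = model.prepare_train(doc_ids, label_ids, texts)
+ >>> model.train(train_data)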
288
+ """
289
+
290
+ self.logger.info("Preparing training data")
291
+ # Check if pre-computed candidate training data is provided
292
+ if train_candidates is None:
293
+ # If not, generate candidate training data in batches
294
+ # If n_jobs is 0, use parameter of EBM model; otherwise, use given number of jobs
295
+ if not n_jobs:
296
+ train_candidates = self.generate_candidates_batch(
297
+ texts=texts,
298
+ doc_ids=doc_ids,
299
+ )
300
+ else:
301
+ train_candidates = self.generate_candidates_batch(
302
+ texts=texts,
303
+ doc_ids=doc_ids,
304
+ chunking_jobs=n_jobs,
305
+ query_jobs=n_jobs,
306
+ )
307
+
308
+ # Create a gold standard data frame from the provided doc IDs and label IDs
309
+ self.logger.info("Preparing gold standard")
310
+ gold_standard = pl.DataFrame(
311
+ {
312
+ "doc_id": doc_ids,
313
+ "label_id": label_ids,
314
+ }
315
+ ).with_columns(
316
+ pl.col("doc_id").cast(pl.String), pl.col("label_id").cast(pl.String)
317
+ )
318
+
319
+ # Compare the candidate training data to the gold standard
320
+ # and prepare data for the training of the XGB ranker model
321
+ self.logger.info("Prepare training data and gold standard for training")
322
+ training_data = (
323
+ self._compare_to_gold_standard(train_candidates, gold_standard)
324
+ .with_columns(pl.when(pl.col("gold")).then(1).otherwise(0).alias("gold"))
325
+ .filter(pl.col("doc_id").is_not_null())
326
+ .select(
327
+ [
328
+ "score",
329
+ "occurrences",
330
+ "min_cosine_similarity",
331
+ "max_cosine_similarity",
332
+ "first_occurence",
333
+ "last_occurence",
334
+ "spread",
335
+ "is_prefLabel",
336
+ "n_chunks",
337
+ "gold",
338
+ ]
339
+ )
340
+ )
341
+
342
+ # Return the prepared training data
343
+ return training_data
344
+
345
+ def _compare_to_gold_standard(
346
+ self,
347
+ candidates: pl.DataFrame,
348
+ gold_standard: pl.DataFrame,
349
+ ) -> pl.DataFrame:
350
+ """
351
+ Compare the model's suggested labels to the gold standard labels.
352
+
353
+ This method joins the model's suggested labels with the gold standard labels
354
+ on the 'doc_id' and 'label_id' columns, filling any missing values with False.
355
+ It then filters the resulting DataFrame to only include suggested labels.
356
+
357
+ Args:
358
+ candidates (pl.DataFrame): The model's suggested labels.
359
+ gold_standard (pl.DataFrame): The gold standard labels.
360
+
361
+ Returns:
362
+ pl.DataFrame: A DataFrame containing the model's suggested labels that match
363
+ the gold standard labels.
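+
+ Example:
+ A minimal sketch of the intended behaviour, using only the join
+ columns (real candidate frames carry additional feature columns):
+
+ >>> candidates = pl.DataFrame(
+ ...     {"doc_id": ["d1", "d1"], "label_id": ["l1", "l2"]}
+ ... )
+ >>> gold = pl.DataFrame({"doc_id": ["d1"], "label_id": ["l1"]})
+ >>> compared = model._compare_to_gold_standard(candidates, gold)
+
+ Here the (d1, l1) row ends up with gold=True, the (d1, l2) row
+ with gold=False, and gold standard rows that were never suggested
+ are removed by the final filter on 'suggested'.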
364
+ """
365
+ return (
366
+ # Mark suggested candidates and gold standard labels
367
+ # Join candidates and gold standard
368
+ candidates.with_columns(pl.lit(True).alias("suggested"))
369
+ .join(
370
+ other=gold_standard.with_columns(pl.lit(True).alias("gold")),
371
+ on=["doc_id", "label_id"],
372
+ how="outer",
373
+ )
374
+ # Fill nulls so that suggested labels that are not part of the
375
+ # gold standard and gold standard labels that were not suggested
376
+ # are marked as False
377
+ .with_columns(
378
+ pl.col("suggested").fill_null(False),
379
+ pl.col("gold").fill_null(False),
380
+ )
381
+ # Keep only suggested labels
382
+ .filter(pl.col("suggested"))
383
+ )
384
+
385
+ def generate_candidates(
386
+ self, text: str, doc_id: int, n_jobs: int = 0
387
+ ) -> pl.DataFrame:
388
+ """
389
+ Generates candidate labels for a given text and document ID.
390
+
391
+ This method chunks the input text, generates embeddings for each chunk,
392
+ and then uses vector search to find similar vocabulary labels in the database.
393
+
394
+ Args:
395
+ text (str): The input text.
396
+ doc_id (int): The document ID.
397
+ n_jobs (int, optional): The number of jobs to use for parallel
398
+ processing (default: 0).
399
+
400
+ Returns:
401
+ pl.DataFrame: A DataFrame containing the generated candidate labels.
402
+ """
403
+ # Check if n_jobs is provided, if not use number of jobs
404
+ # specified in model parameters
405
+ if not n_jobs:
406
+ n_jobs = self.query_jobs
407
+
408
+ # Create a Chunker instance with specified parameters
409
+ self.logger.info("Chunking text")
410
+ chunker = Chunker(
411
+ tokenizer_name=self.chunk_tokenizer,
412
+ max_chunks=self.max_chunks,
413
+ max_chunk_size=self.max_chunk_size,
414
+ max_sentences=self.max_sentences,
415
+ )
416
+ # Chunk the input text
417
+ text_chunks = chunker.chunk_text(text)
418
+
419
+ # Initialize the generator
420
+ self._init_generator()
421
+ self.logger.info("Creating embeddings for text chunks")
422
+ # Generate embeddings for the text chunks
423
+ embeddings = self.generator.generate_embeddings(
424
+ # Use the text chunks as input
425
+ texts=text_chunks,
426
+ # Use the encode arguments for documents if provided
427
+ **(
428
+ self.encode_args_documents
429
+ if self.encode_args_documents is not None
430
+ else {}
431
+ ),
432
+ )
433
+
434
+ # Create a query DataFrame
435
+ self.logger.info("Creating query dataframe")
436
+ query_df = pl.DataFrame(
437
+ {
438
+ # Create a column for the query ID
439
+ "query_id": [i + 1 for i in range(len(text_chunks))],
440
+ # Create a column for the query document ID
441
+ "query_doc_id": [doc_id for _ in range(len(text_chunks))],
442
+ # Create a column for the chunk position
443
+ "chunk_position": [i + 1 for i in range(len(text_chunks))],
444
+ # Create a column for the number of chunks
445
+ "n_chunks": [len(text_chunks) for _ in range(len(text_chunks))],
446
+ # Create a column for the embeddings
447
+ "embeddings": embeddings,
448
+ }
449
+ )
450
+
451
+ # Initialize the DuckDB client
452
+ self._init_duckdb_client()
453
+ self.logger.info("Running vector search and creating candidates")
454
+ # Perform vector search using the query DataFrame
455
+ # Using the parameters specified for the EBM model
456
+ # and the optimal chunk size for the DuckDB
457
+ candidates = self.client.vector_search(
458
+ query_df=query_df,
459
+ collection_name=self.collection_name,
460
+ embedding_dimensions=self.embedding_dimensions,
461
+ n_jobs=n_jobs,
462
+ n_hits=self.max_query_hits,
463
+ chunk_size=1024,
464
+ top_k=self.query_top_k,
465
+ hnsw_metric_function="array_cosine_distance",
466
+ )
467
+
468
+ # Return generated candidates
469
+ return candidates
470
+
471
+ def generate_candidates_batch(
472
+ self,
473
+ texts: list[str],
474
+ doc_ids: list[int],
475
+ chunking_jobs: int = 0,
476
+ query_jobs: int = 0,
477
+ ) -> pl.DataFrame:
478
+ """
479
+ Generates candidate labels for a batch of given texts and document IDs.
480
+
481
+ This method chunks the input texts, generates embeddings for each chunk,
482
+ and then uses vector search to find similar vocabulary labels in the database.
483
+
484
+ Args:
485
+ text (str): The input text.
486
+ doc_ids (list[int]): The document IDs.
487
+ chunking_jobs (int, optional): The number of jobs to use for parallel
488
+ chunking (default: 0).
489
+ query_jobs (int, optional): The number of jobs to use for parallel
490
+ querying (default: 0).
491
+
492
+ Returns:
493
+ pl.DataFrame: A DataFrame containing the generated candidate labels.
494
+ """
495
+ # Check if the numbers of jobs are provided, if not use the numbers of jobs
496
+ # specified in model parameters
497
+ if not chunking_jobs:
498
+ chunking_jobs = self.chunking_jobs
499
+ if not query_jobs:
500
+ query_jobs = self.query_jobs
501
+
502
+ # Create a Chunker instance with specified parameters
503
+ self.logger.info("Chunking texts in batches")
504
+ chunker = Chunker(
505
+ tokenizer_name=self.chunk_tokenizer,
506
+ max_chunks=self.max_chunks,
507
+ max_chunk_size=self.max_chunk_size,
508
+ max_sentences=self.max_sentences,
509
+ )
510
+ # Chunk the input texts
511
+ text_chunks, chunk_index = chunker.chunk_batches(texts, doc_ids, chunking_jobs)
512
+
513
+ # Initialize the generator and chunk index
514
+ self._init_generator()
515
+ chunk_index = pl.concat(chunk_index).with_row_index("query_id")
516
+ self.logger.info("Creating embeddings for text chunks and query dataframe")
517
+ embeddings = self.generator.generate_embeddings(
518
+ texts=text_chunks,
519
+ **(
520
+ self.encode_args_documents
521
+ if self.encode_args_documents is not None
522
+ else {}
523
+ ),
524
+ )
525
+
526
+ # Initialize the DuckDB client
527
+ self._init_duckdb_client()
528
+ # Extend chunk_index by a list column containing the embeddings
529
+ query_df = chunk_index.with_columns(pl.Series("embeddings", embeddings))
530
+
531
+ # Perform vector search using the query DataFrame
532
+ # Using the parameters specified for the EBM model
533
+ # and the optimal chunk size for the DuckDB
534
+ self.logger.info("Running vector search and creating candidates")
535
+ candidates = self.client.vector_search(
536
+ query_df=query_df,
537
+ collection_name=self.collection_name,
538
+ embedding_dimensions=self.embedding_dimensions,
539
+ n_jobs=query_jobs,
540
+ n_hits=self.max_query_hits,
541
+ chunk_size=1024,
542
+ top_k=self.query_top_k,
543
+ hnsw_metric_function="array_cosine_distance",
544
+ )
545
+
546
+ # Return generated candidates
547
+ return candidates
548
+
549
+ def train(self, train_data: pl.DataFrame, n_jobs: int = 0) -> None:
550
+ """
551
+ Trains the XGBoost Ranker model using the provided training data.
552
+
553
+ Args:
554
+ train_data: The data to be used for training.
555
+ n_jobs (int, optional): The number of jobs to use for parallel
556
+ processing (default: 0).
557
+
558
+ Returns:
559
+ None
560
+
561
+ Raises:
562
+ XGBoostError: If XGBoost is unable to train with candidates.
563
+ """
564
+ # Check if n_jobs is provided, if not use number of jobs
565
+ # specified in model parameters
566
+ if not n_jobs:
567
+ n_jobs = self.train_jobs
568
+
569
+ # Select the required columns from the train_data DataFrame,
570
+ # convert to a pandas DataFrame and then to a training matrix
571
+ self.logger.info("Creating training matrix")
572
+ matrix = xgb.DMatrix(
573
+ train_data.select(
574
+ [
575
+ "score",
576
+ "occurrences",
577
+ "min_cosine_similarity",
578
+ "max_cosine_similarity",
579
+ "first_occurence",
580
+ "last_occurence",
581
+ "spread",
582
+ "is_prefLabel",
583
+ "n_chunks",
584
+ ]
585
+ ).to_pandas(),
586
+ # Use the gold standard as the target
587
+ train_data.to_pandas()["gold"],
588
+ )
589
+
590
+ try:
591
+ # Train the XGBoost model with the specified parameters
592
+ self.logger.info("Starting training of XGBoost Ranker")
593
+ model = xgb.train(
594
+ # Booster hyperparameters
595
+ params={
596
+ "objective": "binary:logistic", # Objective function to minimize
597
+ "eval_metric": "logloss", # Evaluation metric
598
+ "eta": self.train_shrinkage, # Learning rate
599
+ "max_depth": self.train_interaction_depth, # Maximum tree depth
600
+ "subsample": self.train_subsample, # Sampling ratio
601
+ "nthread": n_jobs, # Number of threads to use
602
+ },
603
+ # Use the training matrix as the input data
604
+ dtrain=matrix,
605
+ # Disable verbose evaluation
606
+ verbose_eval=False,
607
+ # Evaluate the model on the training data
608
+ evals=[(matrix, "train")],
609
+ # Specify the number of boosting rounds
610
+ num_boost_round=self.train_rounds,
611
+ # Use the specified callbacks
612
+ callbacks=self.xgb_callbacks,
613
+ )
614
+ self.logger.info("Training successful finished")
615
+ except xgb.core.XGBoostError:
616
+ self.logger.critical(
617
+ "XGBoost can't train with candidates equal to gold standard "
618
+ "or candidates with no match to gold standard at all - "
619
+ "Check if your training data and gold standard are correct"
620
+ )
621
+ raise
622
+ else:
623
+ # Store the trained model
624
+ self.model = model
625
+
626
+ def predict(self, candidates: pl.DataFrame) -> list[pl.DataFrame]:
627
+ """
628
+ Generates predictions for the given candidates using the trained model.
629
+
630
+ This method creates a matrix from the candidates DataFrame, makes predictions
631
+ using the trained model, and returns a list of DataFrames containing the
632
+ predicted scores and top-k labels for each document.
633
+
634
+ Args:
635
+ candidates (pl.DataFrame): A DataFrame containing the candidates to
636
+ generate predictions for.
637
+
638
+ Returns:
639
+ list[pl.DataFrame]: A list of DataFrames, where each DataFrame contains
640
+ the predicted scores and top-k labels for a document.
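+
+ Example:
+ A minimal sketch of consuming the result, assuming the model has
+ been trained and 'candidates' was produced by one of the
+ generate_candidates methods:
+
+ >>> per_doc = model.predict(candidates)
+ >>> for doc_df in per_doc:
+ ...     print(doc_df["doc_id"][0], doc_df["label_id"].to_list())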
641
+ """
642
+ # Select relevant columns from the candidates DataFrame to create a matrix
643
+ # for the trained model to make predictions
644
+ self.logger.info("Creating matrix of candidates to generate predictions")
645
+ matrix = xgb.DMatrix(
646
+ candidates.select(
647
+ [
648
+ "score",
649
+ "occurrences",
650
+ "min_cosine_similarity",
651
+ "max_cosine_similarity",
652
+ "first_occurence",
653
+ "last_occurence",
654
+ "spread",
655
+ "is_prefLabel",
656
+ "n_chunks",
657
+ ]
658
+ )
659
+ )
660
+
661
+ # Use the trained model to make predictions on the created matrix
662
+ self.logger.info("Making predictions for candidates")
663
+ predictions = self.model.predict(matrix)
664
+
665
+ # Transform the predictions into a list of DataFrames containing the
666
+ # predicted scores and top-k labels for each document
667
+ return (
668
+ # Add a new column with the predicted scores to the candidates DataFrame
669
+ candidates.with_columns(pl.Series(predictions).alias("score"))
670
+ # Select the relevant columns from the updated DataFrame
671
+ .select(["doc_id", "label_id", "score"])
672
+ # Sort the DataFrame by document ID and score in ascending and
673
+ # descending order, respectively
674
+ .sort(["doc_id", "score"], descending=[False, True])
675
+ # Group the DataFrame by document ID and aggregate the top-k labels
676
+ # and scores for each group
677
+ .group_by("doc_id")
678
+ .agg(pl.all().head(self.query_top_k))
679
+ # Explode the aggregated DataFrame to create separate rows for each
680
+ # label and score
681
+ .explode(["label_id", "score"])
682
+ # Partition the DataFrame by document ID
683
+ .partition_by("doc_id")
684
+ )
685
+
686
+ def save(self, output_path: str) -> None:
687
+ """
688
+ Saves the current state of the EBM model to a file using joblib.
689
+
690
+ This method serializes the model instance and writes it to the
691
+ specified output path, allowing for later deserialization and
692
+ restoration of the model's state.
693
+
694
+ Args:
695
+ output_path: The file path where the serialized model will be written.
696
+
697
+ Notes:
698
+ The model's client and generator attributes are reset to None.
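+
+ Example:
+ A minimal save/load round trip (the path is a placeholder):
+
+ >>> model.save("ebm_model.joblib")
+ >>> restored = EbmModel.load("ebm_model.joblib")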
699
+ """
700
+ self.client = None
701
+ self.generator = None
702
+ joblib.dump(self, output_path)
703
+
704
+ @staticmethod
705
+ def load(input_path: str) -> EbmModel:
706
+ """
707
+ Loads an EBM model from a joblib serialized file.
708
+
709
+ Args:
710
+ input_path (str): Path to the joblib serialized file containing the EBM model.
711
+
712
+ Returns:
713
+ EbmModel: The loaded EBM model instance.
714
+ """
715
+ return joblib.load(input_path)