ebm4subjects 0.4.1__py3-none-any.whl
- ebm4subjects/__init__.py +0 -0
- ebm4subjects/analyzer.py +57 -0
- ebm4subjects/chunker.py +173 -0
- ebm4subjects/duckdb_client.py +329 -0
- ebm4subjects/ebm_logging.py +203 -0
- ebm4subjects/ebm_model.py +715 -0
- ebm4subjects/embedding_generator.py +63 -0
- ebm4subjects/prepare_data.py +82 -0
- ebm4subjects-0.4.1.dist-info/METADATA +134 -0
- ebm4subjects-0.4.1.dist-info/RECORD +12 -0
- ebm4subjects-0.4.1.dist-info/WHEEL +4 -0
- ebm4subjects-0.4.1.dist-info/licenses/LICENSE +287 -0
@@ -0,0 +1,715 @@
from __future__ import annotations

import ast
from pathlib import Path

import joblib
import polars as pl
import xgboost as xgb

from ebm4subjects import prepare_data
from ebm4subjects.chunker import Chunker
from ebm4subjects.duckdb_client import Duckdb_client
from ebm4subjects.ebm_logging import EbmLogger, NullLogger, XGBLogging
from ebm4subjects.embedding_generator import EmbeddingGenerator


class EbmModel:
    def __init__(
        self,
        db_path: str,
        collection_name: str,
        use_altLabels: bool,
        duckdb_threads: int | str,
        embedding_model_name: str,
        embedding_dimensions: int | str,
        chunk_tokenizer: str,
        max_chunks: int | str,
        max_chunk_size: int | str,
        chunking_jobs: int | str,
        max_sentences: int | str,
        max_query_hits: int | str,
        query_top_k: int | str,
        query_jobs: int | str,
        xgb_shrinkage: float | str,
        xgb_interaction_depth: int | str,
        xgb_subsample: float | str,
        xgb_rounds: int | str,
        xgb_jobs: int | str,
        hnsw_index_params: dict | str | None = None,
        model_args: dict | str | None = None,
        encode_args_vocab: dict | str | None = None,
        encode_args_documents: dict | str | None = None,
        log_path: str | None = None,
    ) -> None:
        """
        A class representing an Embedding-Based-Matching (EBM) model
        for automated subject indexing of texts.

        The EBM model integrates multiple components, including:
        - A DuckDB client for database operations
        - An EmbeddingGenerator for generating embeddings from text data
        - A Chunker for splitting text into smaller pieces
        - An XGBoost ranker model for ranking candidate labels

        The EBM model provides methods for creating a vector database,
        preparing training data, training the XGBoost ranker model,
        and making predictions on generated candidate labels.

        Attributes:
            client (Duckdb_client): The DuckDB client instance
            generator (EmbeddingGenerator): The EmbeddingGenerator instance
            chunker (Chunker): The Chunker instance
            model (xgb.Booster): The trained XGBoost ranker model

        Methods:
            create_vector_db: Creates a vector database by loading an existing
                              vocabulary with embeddings or generating a new
                              vocabulary with embeddings
            prepare_train: Prepares the training data for the EBM model
            train: Trains the XGBoost ranker model using the provided training data
            predict: Generates predictions for given candidates using the trained model
            save: Saves the current state of the EBM model to a file
            load: Loads an EBM model from a file

        Notes:
            Parameters annotated as 'TYPE | str' expect a value of type TYPE
            but also accept it as a string, which is then cast to the
            required type.
        """
        # Parameters for duckdb
        self.client = None
        self.db_path = db_path
        self.collection_name = collection_name
        self.use_altLabels = use_altLabels
        self.duckdb_threads = int(duckdb_threads)
        if isinstance(hnsw_index_params, str) or not hnsw_index_params:
            hnsw_index_params = (
                ast.literal_eval(hnsw_index_params) if hnsw_index_params else {}
            )
        self.hnsw_index_params = hnsw_index_params

        # Parameters for embedding generator
        self.generator = None
        self.embedding_model_name = embedding_model_name
        self.embedding_dimensions = int(embedding_dimensions)
        if isinstance(model_args, str) or not model_args:
            model_args = ast.literal_eval(model_args) if model_args else {}
        self.model_args = model_args
        if isinstance(encode_args_vocab, str) or not encode_args_vocab:
            encode_args_vocab = (
                ast.literal_eval(encode_args_vocab) if encode_args_vocab else {}
            )
        self.encode_args_vocab = encode_args_vocab
        if isinstance(encode_args_documents, str) or not encode_args_documents:
            encode_args_documents = (
                ast.literal_eval(encode_args_documents) if encode_args_documents else {}
            )
        self.encode_args_documents = encode_args_documents

        # Parameters for chunker
        self.chunk_tokenizer = chunk_tokenizer
        self.max_chunks = int(max_chunks)
        self.max_chunk_size = int(max_chunk_size)
        self.max_sentences = int(max_sentences)
        self.chunking_jobs = int(chunking_jobs)

        # Parameters for vector search
        self.max_query_hits = int(max_query_hits)
        self.query_top_k = int(query_top_k)
        self.query_jobs = int(query_jobs)

        # Parameters for XGBoost ranker
        self.train_shrinkage = float(xgb_shrinkage)
        self.train_interaction_depth = int(xgb_interaction_depth)
        self.train_subsample = float(xgb_subsample)
        self.train_rounds = int(xgb_rounds)
        self.train_jobs = int(xgb_jobs)

        # Parameters for logger
        # Only create a logger if a path to a log file is set
        self.logger = None
        self.xgb_logger = None
        self.xgb_callbacks = None
        if log_path:
            self.logger = EbmLogger(log_path, "info").get_logger()
            self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
            self.xgb_callbacks = [self.xgb_logger]
        else:
            self.logger = NullLogger()

        # Initialize EBM model
        self.model = None
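    # Editor's sketch (not part of the released source): a minimal construction
    # of the model. All literal values below are hypothetical placeholders,
    # not defaults shipped with ebm4subjects; string forms such as "4"
    # illustrate the documented TYPE | str casting.
    #
    #   model = EbmModel(
    #       db_path="ebm.duckdb",
    #       collection_name="vocab",
    #       use_altLabels=True,
    #       duckdb_threads="4",
    #       embedding_model_name="some/sentence-embedding-model",
    #       embedding_dimensions=768,
    #       chunk_tokenizer="some-tokenizer",
    #       max_chunks=50,
    #       max_chunk_size=256,
    #       chunking_jobs=4,
    #       max_sentences=10,
    #       max_query_hits=100,
    #       query_top_k=20,
    #       query_jobs=4,
    #       xgb_shrinkage=0.1,
    #       xgb_interaction_depth=6,
    #       xgb_subsample=0.8,
    #       xgb_rounds=100,
    #       xgb_jobs=4,
    #   )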
    def _init_duckdb_client(self) -> None:
        """
        Initializes the DuckDB client if it does not already exist.

        This method creates a new DuckDB client if it is not already
        initialized and configures it with the provided database path,
        thread settings, and HNSW index parameters.

        Returns:
            None
        """
        if self.client is None:
            self.logger.info("Initializing DuckDB client")

            self.client = Duckdb_client(
                db_path=self.db_path,
                config={
                    "hnsw_enable_experimental_persistence": True,
                    "threads": self.duckdb_threads,
                },
                hnsw_index_params=self.hnsw_index_params,
            )
    def _init_generator(self) -> None:
        """
        Initializes the embedding generator if it does not already exist.

        If the generator is not initialized, it creates a new EmbeddingGenerator
        with the specified model name, embedding dimensions, and model arguments.

        Returns:
            None
        """
        if self.generator is None:
            self.logger.info("Initializing embedding generator")

            self.generator = EmbeddingGenerator(
                model_name=self.embedding_model_name,
                embedding_dimensions=self.embedding_dimensions,
                **self.model_args,
            )
    def create_vector_db(
        self,
        vocab_in_path: str | None = None,
        vocab_out_path: str | None = None,
        force: bool = False,
    ) -> None:
        """
        Creates a vector database by either loading an existing vocabulary
        with embeddings or generating a new vocabulary with embeddings from scratch.

        If a vocabulary with embeddings already exists at the specified output path,
        it will be loaded. Otherwise, a new vocabulary will be generated from the input
        vocabulary path, and the resulting vocabulary with embeddings will be saved to
        the output path if specified.

        Args:
            vocab_in_path (optional): The path to the input vocabulary file.
            vocab_out_path (optional): The path to the output vocabulary file
                with embeddings.
            force: Whether to overwrite an existing output file (default: False).

        Returns:
            None

        Raises:
            ValueError: If no vocabulary is provided.
        """
        # Check if the output path exists and load the existing vocabulary if so
        if vocab_out_path and Path(vocab_out_path).exists():
            self.logger.info(
                f"Loading vocabulary with embeddings from {vocab_out_path}"
            )
            collection_df = pl.read_ipc(vocab_out_path)
        # Parse the input vocabulary if provided
        elif vocab_in_path:
            self.logger.info("Parsing vocabulary")
            vocab = prepare_data.parse_vocab(
                vocab_path=vocab_in_path,
                use_altLabels=self.use_altLabels,
            )

            # Initialize the generator and add embeddings to the vocabulary
            self._init_generator()
            self.logger.info("Adding embeddings to vocabulary")
            collection_df = prepare_data.add_vocab_embeddings(
                vocab=vocab,
                generator=self.generator,
                encode_args=self.encode_args_vocab,
            )

            # Save the vocabulary to the output path if specified
            if vocab_out_path:
                # Check if the file already exists and warn if so
                if Path(vocab_out_path).exists() and not force:
                    self.logger.warn(
                        f"""Can't save vocabulary to {vocab_out_path}.
                        File already exists"""
                    )
                else:
                    self.logger.info(f"Saving vocabulary to {vocab_out_path}")
                    collection_df.write_ipc(vocab_out_path)
        else:
            # If no existing vocabulary and no input vocabulary is provided,
            # raise an error
            raise ValueError("Vocabulary path is required")

        # Initialize the DuckDB client and create the collection
        self._init_duckdb_client()
        self.logger.info("Creating collection")
        self.client.create_collection(
            collection_df=collection_df,
            collection_name=self.collection_name,
            embedding_dimensions=self.embedding_dimensions,
            hnsw_index_name="hnsw_index",
            hnsw_metric="cosine",
            force=force,
        )
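    # Editor's sketch (not part of the released source): a typical first-run
    # flow with hypothetical paths. On the first call the vocabulary is parsed
    # and embedded, then cached via write_ipc (Arrow IPC); later calls with the
    # same vocab_out_path load the cache instead of recomputing embeddings.
    #
    #   model.create_vector_db(
    #       vocab_in_path="vocab_file",            # hypothetical input vocabulary
    #       vocab_out_path="vocab_embedded.arrow", # hypothetical embedding cache
    #       force=False,
    #   )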
    def prepare_train(
        self,
        doc_ids: list[str],
        label_ids: list[str],
        texts: list[str],
        train_candidates: pl.DataFrame | None = None,
        n_jobs: int = 0,
    ) -> pl.DataFrame:
        """
        Prepares the training data for the EBM model.

        This function generates candidate training data and a gold standard
        data frame. It then compares the candidates to the gold standard,
        computes the necessary features, and returns the resulting training data.

        Args:
            doc_ids (list[str]): A list of document IDs.
            label_ids (list[str]): A list of label IDs.
            texts (list[str]): A list of text data.
            train_candidates (pl.DataFrame, optional): Pre-computed candidate
                training data (default: None).
            n_jobs (int, optional): The number of jobs to use for parallel
                processing (default: 0).

        Returns:
            pl.DataFrame: The prepared training data.
        """

        self.logger.info("Preparing training data")
        # Check if pre-computed candidate training data is provided
        if not train_candidates:
            # If not, generate candidate training data in batches.
            # If n_jobs is 0, use the parameters of the EBM model;
            # otherwise, use the given number of jobs
            if not n_jobs:
                train_candidates = self.generate_candidates_batch(
                    texts=texts,
                    doc_ids=doc_ids,
                )
            else:
                train_candidates = self.generate_candidates_batch(
                    texts=texts,
                    doc_ids=doc_ids,
                    chunking_jobs=n_jobs,
                    query_jobs=n_jobs,
                )

        # Create a gold standard data frame from the provided doc IDs and label IDs
        self.logger.info("Preparing gold standard")
        gold_standard = pl.DataFrame(
            {
                "doc_id": doc_ids,
                "label_id": label_ids,
            }
        ).with_columns(
            pl.col("doc_id").cast(pl.String), pl.col("label_id").cast(pl.String)
        )

        # Compare the candidate training data to the gold standard
        # and prepare the data for training the XGBoost ranker model
        self.logger.info("Preparing training data and gold standard for training")
        training_data = (
            self._compare_to_gold_standard(train_candidates, gold_standard)
            .with_columns(pl.when(pl.col("gold")).then(1).otherwise(0).alias("gold"))
            .filter(pl.col("doc_id").is_not_null())
            .select(
                [
                    "score",
                    "occurrences",
                    "min_cosine_similarity",
                    "max_cosine_similarity",
                    "first_occurence",
                    "last_occurence",
                    "spread",
                    "is_prefLabel",
                    "n_chunks",
                    "gold",
                ]
            )
        )

        # Return the prepared training data
        return training_data
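    # Editor's sketch (not part of the released source): preparing training
    # data from hypothetical parallel lists that pair documents with their
    # gold-standard labels, then fitting the ranker.
    #
    #   train_df = model.prepare_train(
    #       doc_ids=doc_ids,      # parallel lists pairing documents
    #       label_ids=label_ids,  # with their gold-standard labels
    #       texts=texts,
    #   )
    #   model.train(train_df)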
    def _compare_to_gold_standard(
        self,
        candidates: pl.DataFrame,
        gold_standard: pl.DataFrame,
    ) -> pl.DataFrame:
        """
        Compare the model's suggested labels to the gold standard labels.

        This method joins the model's suggested labels with the gold standard labels
        on the 'doc_id' and 'label_id' columns, filling any missing values with False.
        It then filters the resulting DataFrame to only include suggested labels.

        Args:
            candidates (pl.DataFrame): The model's suggested labels.
            gold_standard (pl.DataFrame): The gold standard labels.

        Returns:
            pl.DataFrame: A DataFrame of all suggested labels, each flagged with
                whether it matches the gold standard.
        """
        return (
            # Mark suggested candidates and gold standard labels,
            # then join candidates and gold standard
            candidates.with_columns(pl.lit(True).alias("suggested"))
            .join(
                other=gold_standard.with_columns(pl.lit(True).alias("gold")),
                on=["doc_id", "label_id"],
                how="outer",
            )
            # Fill the dataframe so that suggested labels which are not part of
            # the gold standard and gold standard labels which were not
            # suggested are marked False
            .with_columns(
                pl.col("suggested").fill_null(False),
                pl.col("gold").fill_null(False),
            )
            # Keep only suggested labels
            .filter(pl.col("suggested"))
        )
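    # Editor's worked example (hypothetical data): for doc "d1" with candidate
    # labels {"a", "b"} and gold labels {"b", "c"}, the outer join conceptually
    # yields rows for a (suggested, not gold), b (suggested and gold), and
    # c (gold only); after fill_null the final filter keeps only a and b, so
    # gold labels that were never suggested do not reach the training data.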
    def generate_candidates(
        self, text: str, doc_id: int, n_jobs: int = 0
    ) -> pl.DataFrame:
        """
        Generates candidate labels for a given text and document ID.

        This method chunks the input text, generates embeddings for each chunk,
        and then uses vector search to find similar documents in the database.

        Args:
            text (str): The input text.
            doc_id (int): The document ID.
            n_jobs (int, optional): The number of jobs to use for parallel
                processing (default: 0).

        Returns:
            pl.DataFrame: A DataFrame containing the generated candidate labels.
        """
        # Check if n_jobs is provided; if not, use the number of jobs
        # specified in the model parameters
        if not n_jobs:
            n_jobs = self.query_jobs

        # Create a Chunker instance with the specified parameters
        self.logger.info("Chunking text")
        chunker = Chunker(
            tokenizer_name=self.chunk_tokenizer,
            max_chunks=self.max_chunks,
            max_chunk_size=self.max_chunk_size,
            max_sentences=self.max_sentences,
        )
        # Chunk the input text
        text_chunks = chunker.chunk_text(text)

        # Initialize the generator
        self._init_generator()
        self.logger.info("Creating embeddings for text chunks")
        # Generate embeddings for the text chunks
        embeddings = self.generator.generate_embeddings(
            # Use the text chunks as input
            texts=text_chunks,
            # Use the encode arguments for documents if provided
            **(
                self.encode_args_documents
                if self.encode_args_documents is not None
                else {}
            ),
        )

        # Create a query DataFrame
        self.logger.info("Creating query dataframe")
        query_df = pl.DataFrame(
            {
                # Create a column for the query ID
                "query_id": [i + 1 for i in range(len(text_chunks))],
                # Create a column for the query document ID
                "query_doc_id": [doc_id for _ in range(len(text_chunks))],
                # Create a column for the chunk position
                "chunk_position": [i + 1 for i in range(len(text_chunks))],
                # Create a column for the number of chunks
                "n_chunks": [len(text_chunks) for _ in range(len(text_chunks))],
                # Create a column for the embeddings
                "embeddings": embeddings,
            }
        )

        # Initialize the DuckDB client
        self._init_duckdb_client()
        self.logger.info("Running vector search and creating candidates")
        # Perform vector search using the query DataFrame,
        # the parameters specified for the EBM model,
        # and the optimal chunk size for DuckDB
        candidates = self.client.vector_search(
            query_df=query_df,
            collection_name=self.collection_name,
            embedding_dimensions=self.embedding_dimensions,
            n_jobs=n_jobs,
            n_hits=self.max_query_hits,
            chunk_size=1024,
            top_k=self.query_top_k,
            hnsw_metric_function="array_cosine_distance",
        )

        # Return the generated candidates
        return candidates
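    # Editor's sketch (not part of the released source): candidates for a
    # single document; note that doc_id is an int here, matching the
    # signature above.
    #
    #   candidates = model.generate_candidates("full text of the document", doc_id=1)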
    def generate_candidates_batch(
        self,
        texts: list[str],
        doc_ids: list[int],
        chunking_jobs: int = 0,
        query_jobs: int = 0,
    ) -> pl.DataFrame:
        """
        Generates candidate labels for a batch of given texts and document IDs.

        This method chunks the input texts, generates embeddings for each chunk,
        and then uses vector search to find similar documents in the database.

        Args:
            texts (list[str]): The input texts.
            doc_ids (list[int]): The document IDs.
            chunking_jobs (int, optional): The number of jobs to use for parallel
                chunking (default: 0).
            query_jobs (int, optional): The number of jobs to use for parallel
                querying (default: 0).

        Returns:
            pl.DataFrame: A DataFrame containing the generated candidate labels.
        """
        # Check if the numbers of jobs are provided; if not, use the numbers
        # of jobs specified in the model parameters
        if not chunking_jobs:
            chunking_jobs = self.chunking_jobs
        if not query_jobs:
            query_jobs = self.query_jobs

        # Create a Chunker instance with the specified parameters
        self.logger.info("Chunking texts in batches")
        chunker = Chunker(
            tokenizer_name=self.chunk_tokenizer,
            max_chunks=self.max_chunks,
            max_chunk_size=self.max_chunk_size,
            max_sentences=self.max_sentences,
        )
        # Chunk the input texts
        text_chunks, chunk_index = chunker.chunk_batches(texts, doc_ids, chunking_jobs)

        # Initialize the generator and the chunk index
        self._init_generator()
        chunk_index = pl.concat(chunk_index).with_row_index("query_id")
        self.logger.info("Creating embeddings for text chunks and query dataframe")
        embeddings = self.generator.generate_embeddings(
            texts=text_chunks,
            **(
                self.encode_args_documents
                if self.encode_args_documents is not None
                else {}
            ),
        )

        # Initialize the DuckDB client
        self._init_duckdb_client()
        # Extend chunk_index by a list column containing the embeddings
        query_df = chunk_index.with_columns(pl.Series("embeddings", embeddings))

        # Perform vector search using the query DataFrame,
        # the parameters specified for the EBM model,
        # and the optimal chunk size for DuckDB
        self.logger.info("Running vector search and creating candidates")
        candidates = self.client.vector_search(
            query_df=query_df,
            collection_name=self.collection_name,
            embedding_dimensions=self.embedding_dimensions,
            n_jobs=query_jobs,
            n_hits=self.max_query_hits,
            chunk_size=1024,
            top_k=self.query_top_k,
            hnsw_metric_function="array_cosine_distance",
        )

        # Return the generated candidates
        return candidates
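    # Editor's sketch (not part of the released source): generating candidates
    # for a small hypothetical batch. Each text is chunked, each chunk is
    # embedded, and every chunk's nearest vocabulary labels come back as one
    # candidates frame keyed by doc_id.
    #
    #   candidates = model.generate_candidates_batch(
    #       texts=["First document ...", "Second document ..."],
    #       doc_ids=[1, 2],
    #   )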
    def train(self, train_data: pl.DataFrame, n_jobs: int = 0) -> None:
        """
        Trains the XGBoost ranker model using the provided training data.

        Args:
            train_data: The data to be used for training.
            n_jobs (int, optional): The number of jobs to use for parallel
                processing (default: 0).

        Returns:
            None

        Raises:
            XGBoostError: If XGBoost is unable to train with the candidates.
        """
        # Check if n_jobs is provided; if not, use the number of jobs
        # specified in the model parameters
        if not n_jobs:
            n_jobs = self.train_jobs

        # Select the required columns from the train_data DataFrame,
        # convert to a Pandas DataFrame and then to a training matrix
        self.logger.info("Creating training matrix")
        matrix = xgb.DMatrix(
            train_data.select(
                [
                    "score",
                    "occurrences",
                    "min_cosine_similarity",
                    "max_cosine_similarity",
                    "first_occurence",
                    "last_occurence",
                    "spread",
                    "is_prefLabel",
                    "n_chunks",
                ]
            ).to_pandas(),
            # Use the gold standard as the target
            train_data.to_pandas()["gold"],
        )

        try:
            # Train the XGBoost model with the specified parameters
            self.logger.info("Starting training of XGBoost ranker")
            model = xgb.train(
                params={
                    "objective": "binary:logistic",  # Objective function to minimize
                    "eval_metric": "logloss",  # Evaluation metric
                    "eta": self.train_shrinkage,  # Learning rate
                    "max_depth": self.train_interaction_depth,  # Maximum tree depth
                    "subsample": self.train_subsample,  # Sampling ratio
                    "nthread": n_jobs,  # Number of threads to use
                },
                # Use the training matrix as the input data
                dtrain=matrix,
                # Disable verbose evaluation
                verbose_eval=False,
                # Evaluate the model on the training data
                evals=[(matrix, "train")],
                # Specify the number of boosting rounds
                num_boost_round=self.train_rounds,
                # Use the specified callbacks
                callbacks=self.xgb_callbacks,
            )
            self.logger.info("Training finished successfully")
        except xgb.core.XGBoostError:
            self.logger.critical(
                "XGBoost can't train with candidates equal to the gold standard "
                "or candidates with no match to the gold standard at all - "
                "check that your training data and gold standard are correct"
            )
            raise
        else:
            # Store the trained model
            self.model = model
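    # Editor's sketch (not part of the released source): the end-to-end
    # training flow implied by prepare_train() and train(), with hypothetical
    # inputs and output path.
    #
    #   train_df = model.prepare_train(doc_ids, label_ids, texts)
    #   model.train(train_df)           # fits the booster on the feature columns
    #   model.save("ebm_model.joblib")  # hypothetical output path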
    def predict(self, candidates: pl.DataFrame) -> list[pl.DataFrame]:
        """
        Generates predictions for the given candidates using the trained model.

        This method creates a matrix from the candidates DataFrame, makes predictions
        using the trained model, and returns a list of DataFrames containing the
        predicted scores and top-k labels for each document.

        Args:
            candidates (pl.DataFrame): A DataFrame containing the candidates to
                generate predictions for.

        Returns:
            list[pl.DataFrame]: A list of DataFrames, where each DataFrame contains
                the predicted scores and top-k labels for a document.
        """
        # Select the relevant columns from the candidates DataFrame to create a
        # matrix for the trained model to make predictions on
        self.logger.info("Creating matrix of candidates to generate predictions")
        matrix = xgb.DMatrix(
            candidates.select(
                [
                    "score",
                    "occurrences",
                    "min_cosine_similarity",
                    "max_cosine_similarity",
                    "first_occurence",
                    "last_occurence",
                    "spread",
                    "is_prefLabel",
                    "n_chunks",
                ]
            )
        )

        # Use the trained model to make predictions on the created matrix
        self.logger.info("Making predictions for candidates")
        predictions = self.model.predict(matrix)

        # Transform the predictions into a list of DataFrames containing the
        # predicted scores and top-k labels for each document
        return (
            # Add a new column with the predicted scores to the candidates DataFrame
            candidates.with_columns(pl.Series(predictions).alias("score"))
            # Select the relevant columns from the updated DataFrame
            .select(["doc_id", "label_id", "score"])
            # Sort the DataFrame by document ID (ascending) and score (descending)
            .sort(["doc_id", "score"], descending=[False, True])
            # Group the DataFrame by document ID and aggregate the top-k labels
            # and scores for each group
            .group_by("doc_id")
            .agg(pl.all().head(self.query_top_k))
            # Explode the aggregated DataFrame to create separate rows for each
            # label and score
            .explode(["label_id", "score"])
            # Partition the DataFrame by document ID
            .partition_by("doc_id")
        )
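    # Editor's sketch (not part of the released source): scoring candidates
    # and reading off the ranked labels per document from the partitioned
    # result frames.
    #
    #   for doc_frame in model.predict(candidates):
    #       doc_id = doc_frame["doc_id"][0]
    #       ranked = doc_frame.select(["label_id", "score"])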
    def save(self, output_path: str) -> None:
        """
        Saves the current state of the EBM model to a file using joblib.

        This method serializes the model instance and writes it to the
        specified output path, allowing for later deserialization and
        restoration of the model's state.

        Args:
            output_path: The file path where the serialized model will be written.

        Notes:
            The model's client and generator attributes are reset to None
            before serialization; they are re-initialized lazily on first
            use after loading.
        """
        self.client = None
        self.generator = None
        joblib.dump(self, output_path)
    @staticmethod
    def load(input_path: str) -> EbmModel:
        """
        Loads an EBM model from a joblib serialized file.

        Args:
            input_path (str): Path to the joblib serialized file containing
                the EBM model.

        Returns:
            EbmModel: The loaded EBM model instance.
        """
        return joblib.load(input_path)
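# Editor's sketch (not part of the released source): a save/load round trip
# with a hypothetical path. The DuckDB client and embedding generator are
# dropped on save and re-created on first use after loading.
#
#   model.save("ebm_model.joblib")
#   restored = EbmModel.load("ebm_model.joblib")
#   predictions = restored.predict(candidates)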