ebm4subjects 0.4.1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ebm4subjects/analyzer.py +3 -2
- ebm4subjects/chunker.py +25 -20
- ebm4subjects/duckdb_client.py +8 -8
- ebm4subjects/ebm_logging.py +11 -8
- ebm4subjects/ebm_model.py +105 -75
- ebm4subjects/embedding_generator.py +12 -5
- {ebm4subjects-0.4.1.dist-info → ebm4subjects-0.5.1.dist-info}/METADATA +2 -3
- ebm4subjects-0.5.1.dist-info/RECORD +12 -0
- ebm4subjects-0.4.1.dist-info/RECORD +0 -12
- {ebm4subjects-0.4.1.dist-info → ebm4subjects-0.5.1.dist-info}/WHEEL +0 -0
- {ebm4subjects-0.4.1.dist-info → ebm4subjects-0.5.1.dist-info}/licenses/LICENSE +0 -0
ebm4subjects/analyzer.py
CHANGED
|
@@ -32,8 +32,9 @@ class EbmAnalyzer:
|
|
|
32
32
|
nltk.data.find(tokenizer_name)
|
|
33
33
|
# If the tokenizer is not found, try to download it
|
|
34
34
|
except LookupError as error:
|
|
35
|
-
if
|
|
36
|
-
nltk.download(
|
|
35
|
+
if "punkt" in str(error):
|
|
36
|
+
nltk.download("punkt")
|
|
37
|
+
nltk.download("punkt_tab")
|
|
37
38
|
else:
|
|
38
39
|
raise
|
|
39
40
|
|
ebm4subjects/chunker.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from concurrent.futures import ProcessPoolExecutor
|
|
2
2
|
from math import ceil
|
|
3
|
-
from typing import Tuple
|
|
3
|
+
from typing import Any, Tuple
|
|
4
4
|
|
|
5
5
|
import polars as pl
|
|
6
6
|
|
|
@@ -17,9 +17,9 @@ class Chunker:
|
|
|
17
17
|
|
|
18
18
|
Attributes:
|
|
19
19
|
tokenizer (EbmAnalyzer): The tokenizer used for tokenizing sentences.
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
20
|
+
max_chunk_count (int): The maximum number of chunks to generate.
|
|
21
|
+
max_chunk_length (int): The maximum size of each chunk in characters.
|
|
22
|
+
max_sentence_count (int): The maximum number of sentences to consider.
|
|
23
23
|
|
|
24
24
|
Methods:
|
|
25
25
|
- chunk_text: Chunks a given text into smaller sections
|
|
@@ -28,25 +28,30 @@ class Chunker:
|
|
|
28
28
|
|
|
29
29
|
def __init__(
|
|
30
30
|
self,
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
31
|
+
tokenizer: Any,
|
|
32
|
+
max_chunk_count: int | None,
|
|
33
|
+
max_chunk_length: int | None,
|
|
34
|
+
max_sentence_count: int | None,
|
|
35
35
|
):
|
|
36
36
|
"""
|
|
37
37
|
Initializes the Chunker.
|
|
38
38
|
|
|
39
39
|
Args:
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
40
|
+
tokenizer (Any): The name of the tokenizer to use or the tokenizer itself.
|
|
41
|
+
max_chunk_count (int | None): The maximum number of chunks to generate.
|
|
42
|
+
max_chunk_length (int | None): The maximum size of each chunk in characters.
|
|
43
|
+
max_sentence_count (int | None): The maximum number of sentences to consider.
|
|
44
44
|
"""
|
|
45
|
-
self.
|
|
46
|
-
self.
|
|
47
|
-
self.
|
|
45
|
+
self.max_chunk_count = max_chunk_count if max_chunk_count else float("inf")
|
|
46
|
+
self.max_chunk_length = max_chunk_length if max_chunk_length else float("inf")
|
|
47
|
+
self.max_sentence_count = (
|
|
48
|
+
max_sentence_count if max_sentence_count else float("inf")
|
|
49
|
+
)
|
|
48
50
|
|
|
49
|
-
|
|
51
|
+
if type(tokenizer) is str:
|
|
52
|
+
self.tokenizer = EbmAnalyzer(tokenizer)
|
|
53
|
+
else:
|
|
54
|
+
self.tokenizer = tokenizer
|
|
50
55
|
|
|
51
56
|
def chunk_text(self, text: str) -> list[str]:
|
|
52
57
|
"""
|
|
@@ -63,7 +68,7 @@ class Chunker:
|
|
|
63
68
|
|
|
64
69
|
# Tokenize the text into sentences
|
|
65
70
|
sentences = self.tokenizer.tokenize_sentences(text)
|
|
66
|
-
sentences = sentences[: self.
|
|
71
|
+
sentences = sentences[: self.max_sentence_count]
|
|
67
72
|
|
|
68
73
|
# Initialize an empty list to store the current chunk
|
|
69
74
|
current_chunk = []
|
|
@@ -71,18 +76,18 @@ class Chunker:
|
|
|
71
76
|
# Iterate over the sentences
|
|
72
77
|
for sentence in sentences:
|
|
73
78
|
# If the current chunk is not full, add the sentence to it
|
|
74
|
-
if len(" ".join(current_chunk)) < self.
|
|
79
|
+
if len(" ".join(current_chunk)) < self.max_chunk_length:
|
|
75
80
|
current_chunk.append(sentence)
|
|
76
81
|
# Otherwise, add the current chunk to the list of chunks
|
|
77
82
|
# and start a new chunk
|
|
78
83
|
else:
|
|
79
84
|
chunks.append(" ".join(current_chunk))
|
|
80
85
|
current_chunk = [sentence]
|
|
81
|
-
if len(chunks) == self.
|
|
86
|
+
if len(chunks) == self.max_chunk_count:
|
|
82
87
|
break
|
|
83
88
|
|
|
84
89
|
# Add the remaining chunk if the maximum number of chunks has not been reached
|
|
85
|
-
if current_chunk and len(chunks) < self.
|
|
90
|
+
if current_chunk and len(chunks) < self.max_chunk_count:
|
|
86
91
|
chunks.append(" ".join(current_chunk))
|
|
87
92
|
|
|
88
93
|
# Return the chunked text
|
ebm4subjects/duckdb_client.py
CHANGED
|
@@ -37,8 +37,8 @@ class Duckdb_client:
|
|
|
37
37
|
(default: {"M": 32, "ef_construction": 256, "ef_search": 256}).
|
|
38
38
|
|
|
39
39
|
Notes:
|
|
40
|
-
'hnsw_enable_experimental_persistence' needs to be set to 'True' in order
|
|
41
|
-
to store and query the index later
|
|
40
|
+
'hnsw_enable_experimental_persistence' needs to be set to 'True' in order
|
|
41
|
+
to store and query the index later
|
|
42
42
|
"""
|
|
43
43
|
# Establish a connection to the DuckDB database
|
|
44
44
|
self.connection = duckdb.connect(
|
|
@@ -76,10 +76,10 @@ class Duckdb_client:
|
|
|
76
76
|
(default: "cosine")
|
|
77
77
|
force (bool, optional): Whether to replace the existing collection if it
|
|
78
78
|
already exists (default: False).
|
|
79
|
-
|
|
79
|
+
|
|
80
80
|
Notes:
|
|
81
|
-
If 'hnsw_metric' is changed in this function 'hnsw_metric_function' in
|
|
82
|
-
the vector_search function needs to be changed accordingly in order
|
|
81
|
+
If 'hnsw_metric' is changed in this function 'hnsw_metric_function' in
|
|
82
|
+
the vector_search function needs to be changed accordingly in order
|
|
83
83
|
for the index to work properly.
|
|
84
84
|
"""
|
|
85
85
|
# Determine whether to replace the existing collection
|
|
@@ -147,10 +147,10 @@ class Duckdb_client:
|
|
|
147
147
|
pl.DataFrame: The result of the vector search.
|
|
148
148
|
|
|
149
149
|
Notes:
|
|
150
|
-
If 'hnsw_metric_function' is changed in this function 'hnsw_metric' in
|
|
151
|
-
the create_collection function needs to be changed accordingly in order
|
|
150
|
+
If 'hnsw_metric_function' is changed in this function 'hnsw_metric' in
|
|
151
|
+
the create_collection function needs to be changed accordingly in order
|
|
152
152
|
for the index to work properly.
|
|
153
|
-
The argument 'chunk_size' is already set to the optimal value for the
|
|
153
|
+
The argument 'chunk_size' is already set to the optimal value for the
|
|
154
154
|
query processing with DuckDB. Only change it if necessary.
|
|
155
155
|
"""
|
|
156
156
|
# Create a temporary table to store the search results
|
ebm4subjects/ebm_logging.py
CHANGED
|
@@ -7,7 +7,7 @@ class EbmLogger:
|
|
|
7
7
|
"""
|
|
8
8
|
A custom logger class.
|
|
9
9
|
|
|
10
|
-
This class provides a way to log messages at different levels
|
|
10
|
+
This class provides a way to log messages at different levels
|
|
11
11
|
(error, warning, info, debug) to a file.
|
|
12
12
|
It also provides a way to get the logger instance.
|
|
13
13
|
|
|
@@ -16,6 +16,7 @@ class EbmLogger:
|
|
|
16
16
|
log_path (str): The path to the log file.
|
|
17
17
|
level (str): The log level (default: "info").
|
|
18
18
|
"""
|
|
19
|
+
|
|
19
20
|
def __init__(self, log_path: str, level: str = "info") -> None:
|
|
20
21
|
"""
|
|
21
22
|
Initializes the logger.
|
|
@@ -66,6 +67,7 @@ class NullLogger:
|
|
|
66
67
|
|
|
67
68
|
This class is used when no logging is needed.
|
|
68
69
|
"""
|
|
70
|
+
|
|
69
71
|
def __init__(self) -> None:
|
|
70
72
|
"""
|
|
71
73
|
Initializes the null logger.
|
|
@@ -136,16 +138,17 @@ class NullLogger:
|
|
|
136
138
|
class XGBLogging(xgboost.callback.TrainingCallback):
|
|
137
139
|
"""
|
|
138
140
|
Custom XGBoost training callback for logging model performance during training.
|
|
139
|
-
|
|
141
|
+
|
|
140
142
|
Args:
|
|
141
143
|
logger (logging.Logger): Logger instance to use for logging.
|
|
142
144
|
epoch_log_interval (int, optional): Interval at which to log model performance
|
|
143
145
|
(default: 100).
|
|
144
|
-
|
|
146
|
+
|
|
145
147
|
Attributes:
|
|
146
148
|
logger (logging.Logger): Logger instance used for logging.
|
|
147
149
|
epoch_log_interval (int): Interval at which to log model performance.
|
|
148
150
|
"""
|
|
151
|
+
|
|
149
152
|
def __init__(
|
|
150
153
|
self,
|
|
151
154
|
logger: logging.Logger,
|
|
@@ -153,10 +156,10 @@ class XGBLogging(xgboost.callback.TrainingCallback):
|
|
|
153
156
|
) -> None:
|
|
154
157
|
"""
|
|
155
158
|
Initializes the XGBLogger.
|
|
156
|
-
|
|
159
|
+
|
|
157
160
|
Args:
|
|
158
161
|
logger (logging.Logger): Logger instance to use for logging.
|
|
159
|
-
epoch_log_interval (int, optional): Interval at which to log model
|
|
162
|
+
epoch_log_interval (int, optional): Interval at which to log model
|
|
160
163
|
performance (default: 100).
|
|
161
164
|
"""
|
|
162
165
|
# Logger instance used for logging
|
|
@@ -172,14 +175,14 @@ class XGBLogging(xgboost.callback.TrainingCallback):
|
|
|
172
175
|
) -> bool:
|
|
173
176
|
"""
|
|
174
177
|
Callback function called after each iteration of the XGBoost training process.
|
|
175
|
-
|
|
178
|
+
|
|
176
179
|
Logs model performance at the specified interval.
|
|
177
|
-
|
|
180
|
+
|
|
178
181
|
Args:
|
|
179
182
|
model (xgboost.Booster): XGBoost model instance.
|
|
180
183
|
epoch (int): Current epoch number.
|
|
181
184
|
evals_log (dict): Dictionary containing evaluation metrics.
|
|
182
|
-
|
|
185
|
+
|
|
183
186
|
Returns:
|
|
184
187
|
bool: Always returns False, as specified by the XGBoost callback API.
|
|
185
188
|
"""
|
ebm4subjects/ebm_model.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import ast
|
|
4
|
+
import logging
|
|
4
5
|
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
5
7
|
|
|
6
8
|
import joblib
|
|
7
9
|
import polars as pl
|
|
@@ -17,30 +19,31 @@ from ebm4subjects.embedding_generator import EmbeddingGenerator
|
|
|
17
19
|
class EbmModel:
|
|
18
20
|
def __init__(
|
|
19
21
|
self,
|
|
20
|
-
|
|
21
|
-
collection_name: str,
|
|
22
|
-
use_altLabels: bool,
|
|
23
|
-
duckdb_threads: int | str,
|
|
24
|
-
embedding_model_name: str,
|
|
22
|
+
embedding_model_name: str | Any,
|
|
25
23
|
embedding_dimensions: int | str,
|
|
26
|
-
chunk_tokenizer: str,
|
|
27
|
-
|
|
28
|
-
|
|
24
|
+
chunk_tokenizer: str | Any,
|
|
25
|
+
max_chunk_count: int | str,
|
|
26
|
+
max_chunk_length: int | str,
|
|
29
27
|
chunking_jobs: int | str,
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
|
|
28
|
+
max_sentence_count: int | str,
|
|
29
|
+
candidates_per_chunk: int | str,
|
|
30
|
+
candidates_per_doc: int | str,
|
|
33
31
|
query_jobs: int | str,
|
|
34
32
|
xgb_shrinkage: float | str,
|
|
35
33
|
xgb_interaction_depth: int | str,
|
|
36
34
|
xgb_subsample: float | str,
|
|
37
35
|
xgb_rounds: int | str,
|
|
38
36
|
xgb_jobs: int | str,
|
|
37
|
+
duckdb_threads: int | str,
|
|
38
|
+
db_path: str,
|
|
39
|
+
collection_name: str = "my_collection",
|
|
40
|
+
use_altLabels: bool = True,
|
|
39
41
|
hnsw_index_params: dict | str | None = None,
|
|
40
42
|
model_args: dict | str | None = None,
|
|
41
43
|
encode_args_vocab: dict | str | None = None,
|
|
42
44
|
encode_args_documents: dict | str | None = None,
|
|
43
45
|
log_path: str | None = None,
|
|
46
|
+
logger: logging.Logger | None = None,
|
|
44
47
|
) -> None:
|
|
45
48
|
"""
|
|
46
49
|
A class representing an Embedding-Based-Matching (EBM) model
|
|
@@ -109,14 +112,14 @@ class EbmModel:
|
|
|
109
112
|
|
|
110
113
|
# Parameters for chunker
|
|
111
114
|
self.chunk_tokenizer = chunk_tokenizer
|
|
112
|
-
self.
|
|
113
|
-
self.
|
|
114
|
-
self.
|
|
115
|
+
self.max_chunk_count = int(max_chunk_count)
|
|
116
|
+
self.max_chunk_length = int(max_chunk_length)
|
|
117
|
+
self.max_sentence_count = int(max_sentence_count)
|
|
115
118
|
self.chunking_jobs = int(chunking_jobs)
|
|
116
119
|
|
|
117
120
|
# Parameters for vector search
|
|
118
|
-
self.
|
|
119
|
-
self.
|
|
121
|
+
self.candidates_per_chunk = int(candidates_per_chunk)
|
|
122
|
+
self.candidates_per_doc = int(candidates_per_doc)
|
|
120
123
|
self.query_jobs = int(query_jobs)
|
|
121
124
|
|
|
122
125
|
# Parameters for XGB boost ranker
|
|
@@ -126,17 +129,8 @@ class EbmModel:
|
|
|
126
129
|
self.train_rounds = int(xgb_rounds)
|
|
127
130
|
self.train_jobs = int(xgb_jobs)
|
|
128
131
|
|
|
129
|
-
#
|
|
130
|
-
|
|
131
|
-
self.logger = None
|
|
132
|
-
self.xgb_logger = None
|
|
133
|
-
self.xgb_callbacks = None
|
|
134
|
-
if log_path:
|
|
135
|
-
self.logger = EbmLogger(log_path, "info").get_logger()
|
|
136
|
-
self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
|
|
137
|
-
self.xgb_callbacks = [self.xgb_logger]
|
|
138
|
-
else:
|
|
139
|
-
self.logger = NullLogger()
|
|
132
|
+
# Initialize logging
|
|
133
|
+
self.init_logger(log_path, logger)
|
|
140
134
|
|
|
141
135
|
# Initialize EBM model
|
|
142
136
|
self.model = None
|
|
@@ -153,7 +147,9 @@ class EbmModel:
|
|
|
153
147
|
None
|
|
154
148
|
"""
|
|
155
149
|
if self.client is None:
|
|
156
|
-
self.logger.info(
|
|
150
|
+
self.logger.info(
|
|
151
|
+
f"initializing DuckDB client with duckdb_threads: {self.duckdb_threads}"
|
|
152
|
+
)
|
|
157
153
|
|
|
158
154
|
self.client = Duckdb_client(
|
|
159
155
|
db_path=self.db_path,
|
|
@@ -175,14 +171,35 @@ class EbmModel:
|
|
|
175
171
|
None
|
|
176
172
|
"""
|
|
177
173
|
if self.generator is None:
|
|
178
|
-
self.logger.info("
|
|
179
|
-
|
|
174
|
+
self.logger.info("initializing embedding generator")
|
|
180
175
|
self.generator = EmbeddingGenerator(
|
|
181
176
|
model_name=self.embedding_model_name,
|
|
182
177
|
embedding_dimensions=self.embedding_dimensions,
|
|
183
178
|
**self.model_args,
|
|
184
179
|
)
|
|
185
180
|
|
|
181
|
+
def init_logger(
|
|
182
|
+
self, log_path: str | None = None, logger: logging.Logger | None = None
|
|
183
|
+
) -> None:
|
|
184
|
+
"""
|
|
185
|
+
Initializes the logging for the EBM model.
|
|
186
|
+
|
|
187
|
+
Returns:
|
|
188
|
+
None
|
|
189
|
+
"""
|
|
190
|
+
if log_path:
|
|
191
|
+
self.logger = EbmLogger(log_path, "info").get_logger()
|
|
192
|
+
self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
|
|
193
|
+
self.xgb_callbacks = [self.xgb_logger]
|
|
194
|
+
elif logger:
|
|
195
|
+
self.logger = logger
|
|
196
|
+
self.xgb_logger = XGBLogging(self.logger, epoch_log_interval=1)
|
|
197
|
+
self.xgb_callbacks = [self.xgb_logger]
|
|
198
|
+
else:
|
|
199
|
+
self.logger = NullLogger()
|
|
200
|
+
self.xgb_logger = None
|
|
201
|
+
self.xgb_callbacks = None
|
|
202
|
+
|
|
186
203
|
def create_vector_db(
|
|
187
204
|
self,
|
|
188
205
|
vocab_in_path: str | None = None,
|
|
@@ -213,12 +230,12 @@ class EbmModel:
|
|
|
213
230
|
# Check if output path exists and load existing vocabulary if so
|
|
214
231
|
if vocab_out_path and Path(vocab_out_path).exists():
|
|
215
232
|
self.logger.info(
|
|
216
|
-
f"
|
|
233
|
+
f"loading vocabulary with embeddings from {vocab_out_path}"
|
|
217
234
|
)
|
|
218
235
|
collection_df = pl.read_ipc(vocab_out_path)
|
|
219
236
|
# Parse input vocabulary if provided
|
|
220
237
|
elif vocab_in_path:
|
|
221
|
-
self.logger.info("
|
|
238
|
+
self.logger.info("parsing vocabulary")
|
|
222
239
|
vocab = prepare_data.parse_vocab(
|
|
223
240
|
vocab_path=vocab_in_path,
|
|
224
241
|
use_altLabels=self.use_altLabels,
|
|
@@ -226,7 +243,7 @@ class EbmModel:
|
|
|
226
243
|
|
|
227
244
|
# Initialize generator and add embeddings to vocabulary
|
|
228
245
|
self._init_generator()
|
|
229
|
-
self.logger.info("
|
|
246
|
+
self.logger.info("adding embeddings to vocabulary")
|
|
230
247
|
collection_df = prepare_data.add_vocab_embeddings(
|
|
231
248
|
vocab=vocab,
|
|
232
249
|
generator=self.generator,
|
|
@@ -238,20 +255,20 @@ class EbmModel:
|
|
|
238
255
|
# Check if file already exists and warn if so
|
|
239
256
|
if Path(vocab_out_path).exists() and not force:
|
|
240
257
|
self.logger.warn(
|
|
241
|
-
f"""
|
|
258
|
+
f"""cant't save vocabulary to {vocab_out_path}.
|
|
242
259
|
File already exists"""
|
|
243
260
|
)
|
|
244
261
|
else:
|
|
245
|
-
self.logger.info(f"
|
|
262
|
+
self.logger.info(f"saving vocabulary to {vocab_out_path}")
|
|
246
263
|
collection_df.write_ipc(vocab_out_path)
|
|
247
264
|
else:
|
|
248
265
|
# If no existing vocabulary and no input vocabulary is provided,
|
|
249
266
|
# raise an error
|
|
250
|
-
raise ValueError("
|
|
267
|
+
raise ValueError("vocabulary path is required")
|
|
251
268
|
|
|
252
269
|
# Initialize DuckDB client and create collection
|
|
253
270
|
self._init_duckdb_client()
|
|
254
|
-
self.logger.info("
|
|
271
|
+
self.logger.info("creating collection")
|
|
255
272
|
self.client.create_collection(
|
|
256
273
|
collection_df=collection_df,
|
|
257
274
|
collection_name=self.collection_name,
|
|
@@ -286,8 +303,6 @@ class EbmModel:
|
|
|
286
303
|
Returns:
|
|
287
304
|
pl.DataFrame: The prepared training data.
|
|
288
305
|
"""
|
|
289
|
-
|
|
290
|
-
self.logger.info("Preparing training data")
|
|
291
306
|
# Check if pre-computed candidate training data is provided
|
|
292
307
|
if not train_candidates:
|
|
293
308
|
# If not, generate candidate training data in batches
|
|
@@ -306,7 +321,6 @@ class EbmModel:
|
|
|
306
321
|
)
|
|
307
322
|
|
|
308
323
|
# Create a gold standard data frame from the provided doc IDs and label IDs
|
|
309
|
-
self.logger.info("Preparing gold standard")
|
|
310
324
|
gold_standard = pl.DataFrame(
|
|
311
325
|
{
|
|
312
326
|
"doc_id": doc_ids,
|
|
@@ -318,7 +332,7 @@ class EbmModel:
|
|
|
318
332
|
|
|
319
333
|
# Compare the candidate training data to the gold standard
|
|
320
334
|
# and prepare data for the training of the XGB ranker model
|
|
321
|
-
self.logger.info("
|
|
335
|
+
self.logger.info("prepare training data and gold standard for training")
|
|
322
336
|
training_data = (
|
|
323
337
|
self._compare_to_gold_standard(train_candidates, gold_standard)
|
|
324
338
|
.with_columns(pl.when(pl.col("gold")).then(1).otherwise(0).alias("gold"))
|
|
@@ -406,19 +420,19 @@ class EbmModel:
|
|
|
406
420
|
n_jobs = self.query_jobs
|
|
407
421
|
|
|
408
422
|
# Create a Chunker instance with specified parameters
|
|
409
|
-
self.logger.info("
|
|
423
|
+
self.logger.info("chunking text")
|
|
410
424
|
chunker = Chunker(
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
425
|
+
tokenizer=self.chunk_tokenizer,
|
|
426
|
+
max_chunk_count=self.max_chunk_count,
|
|
427
|
+
max_chunk_length=self.max_chunk_length,
|
|
428
|
+
max_sentence_count=self.max_sentence_count,
|
|
415
429
|
)
|
|
416
430
|
# Chunk the input text
|
|
417
431
|
text_chunks = chunker.chunk_text(text)
|
|
418
432
|
|
|
419
433
|
# Initialize the generator
|
|
420
434
|
self._init_generator()
|
|
421
|
-
self.logger.info("
|
|
435
|
+
self.logger.info("creating embeddings for text chunks")
|
|
422
436
|
# Generate embeddings for the text chunks
|
|
423
437
|
embeddings = self.generator.generate_embeddings(
|
|
424
438
|
# Use the text chunks as input
|
|
@@ -432,7 +446,7 @@ class EbmModel:
|
|
|
432
446
|
)
|
|
433
447
|
|
|
434
448
|
# Create a query DataFrame
|
|
435
|
-
self.logger.info("
|
|
449
|
+
self.logger.info("creating query dataframe")
|
|
436
450
|
query_df = pl.DataFrame(
|
|
437
451
|
{
|
|
438
452
|
# Create a column for the query ID
|
|
@@ -450,7 +464,9 @@ class EbmModel:
|
|
|
450
464
|
|
|
451
465
|
# Initialize the DuckDB client
|
|
452
466
|
self._init_duckdb_client()
|
|
453
|
-
self.logger.info(
|
|
467
|
+
self.logger.info(
|
|
468
|
+
f"running vector search and creating candidates with query_jobs: {n_jobs}"
|
|
469
|
+
)
|
|
454
470
|
# Perform vector search using the query DataFrame
|
|
455
471
|
# Using the parameters specified for the EBM model
|
|
456
472
|
# and the optimal chunk size for the DuckDB
|
|
@@ -459,9 +475,9 @@ class EbmModel:
|
|
|
459
475
|
collection_name=self.collection_name,
|
|
460
476
|
embedding_dimensions=self.embedding_dimensions,
|
|
461
477
|
n_jobs=n_jobs,
|
|
462
|
-
n_hits=self.
|
|
478
|
+
n_hits=self.candidates_per_chunk,
|
|
463
479
|
chunk_size=1024,
|
|
464
|
-
top_k=self.
|
|
480
|
+
top_k=self.candidates_per_doc,
|
|
465
481
|
hnsw_metric_function="array_cosine_distance",
|
|
466
482
|
)
|
|
467
483
|
|
|
@@ -500,20 +516,20 @@ class EbmModel:
|
|
|
500
516
|
query_jobs = self.query_jobs
|
|
501
517
|
|
|
502
518
|
# Create a Chunker instance with specified parameters
|
|
503
|
-
self.logger.info("Chunking texts in batches")
|
|
504
519
|
chunker = Chunker(
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
520
|
+
tokenizer=self.chunk_tokenizer,
|
|
521
|
+
max_chunk_count=self.max_chunk_count,
|
|
522
|
+
max_chunk_length=self.max_chunk_length,
|
|
523
|
+
max_sentence_count=self.max_sentence_count,
|
|
509
524
|
)
|
|
510
525
|
# Chunk the input texts
|
|
526
|
+
self.logger.info(f"chunking texts with chunking_jobs: {chunking_jobs}")
|
|
511
527
|
text_chunks, chunk_index = chunker.chunk_batches(texts, doc_ids, chunking_jobs)
|
|
512
528
|
|
|
513
529
|
# Initialize the generator and chunk index
|
|
514
530
|
self._init_generator()
|
|
515
531
|
chunk_index = pl.concat(chunk_index).with_row_index("query_id")
|
|
516
|
-
self.logger.info("
|
|
532
|
+
self.logger.info("creating embeddings for text chunks and query dataframe")
|
|
517
533
|
embeddings = self.generator.generate_embeddings(
|
|
518
534
|
texts=text_chunks,
|
|
519
535
|
**(
|
|
@@ -531,15 +547,17 @@ class EbmModel:
|
|
|
531
547
|
# Perform vector search using the query DataFrame
|
|
532
548
|
# Using the parameters specified for the EBM model
|
|
533
549
|
# and the optimal chunk size for the DuckDB
|
|
534
|
-
self.logger.info(
|
|
550
|
+
self.logger.info(
|
|
551
|
+
f"running vector search and creating candidates with query_jobs: {query_jobs}"
|
|
552
|
+
)
|
|
535
553
|
candidates = self.client.vector_search(
|
|
536
554
|
query_df=query_df,
|
|
537
555
|
collection_name=self.collection_name,
|
|
538
556
|
embedding_dimensions=self.embedding_dimensions,
|
|
539
557
|
n_jobs=query_jobs,
|
|
540
|
-
n_hits=self.
|
|
558
|
+
n_hits=self.candidates_per_chunk,
|
|
541
559
|
chunk_size=1024,
|
|
542
|
-
top_k=self.
|
|
560
|
+
top_k=self.candidates_per_doc,
|
|
543
561
|
hnsw_metric_function="array_cosine_distance",
|
|
544
562
|
)
|
|
545
563
|
|
|
@@ -567,8 +585,8 @@ class EbmModel:
|
|
|
567
585
|
n_jobs = self.train_jobs
|
|
568
586
|
|
|
569
587
|
# Select the required columns from the train_data DataFrame,
|
|
570
|
-
# convert to a
|
|
571
|
-
self.logger.info("
|
|
588
|
+
# convert to a numpy array and afterwards to training matrix
|
|
589
|
+
self.logger.info("creating training matrix")
|
|
572
590
|
matrix = xgb.DMatrix(
|
|
573
591
|
train_data.select(
|
|
574
592
|
[
|
|
@@ -582,14 +600,16 @@ class EbmModel:
|
|
|
582
600
|
"is_prefLabel",
|
|
583
601
|
"n_chunks",
|
|
584
602
|
]
|
|
585
|
-
).
|
|
603
|
+
).to_numpy(),
|
|
586
604
|
# Use the gold standard as the target
|
|
587
|
-
train_data.
|
|
605
|
+
train_data.select("gold").to_numpy(),
|
|
588
606
|
)
|
|
589
607
|
|
|
590
608
|
try:
|
|
591
609
|
# Train the XGBoost model with the specified parameters
|
|
592
|
-
self.logger.info(
|
|
610
|
+
self.logger.info(
|
|
611
|
+
f"starting training of XGBoost Ranker with xgb_jobs: {n_jobs}"
|
|
612
|
+
)
|
|
593
613
|
model = xgb.train(
|
|
594
614
|
# Train the XGBoost model with the specified parameters
|
|
595
615
|
params={
|
|
@@ -611,7 +631,7 @@ class EbmModel:
|
|
|
611
631
|
# Use the specified callbacks
|
|
612
632
|
callbacks=self.xgb_callbacks,
|
|
613
633
|
)
|
|
614
|
-
self.logger.info("
|
|
634
|
+
self.logger.info("training successful finished")
|
|
615
635
|
except xgb.core.XGBoostError:
|
|
616
636
|
self.logger.critical(
|
|
617
637
|
"XGBoost can't train with candidates equal to gold standard "
|
|
@@ -641,7 +661,7 @@ class EbmModel:
|
|
|
641
661
|
"""
|
|
642
662
|
# Select relevant columns from the candidates DataFrame to create a matrix
|
|
643
663
|
# for the trained model to make predictions
|
|
644
|
-
self.logger.info("
|
|
664
|
+
self.logger.info("creating matrix of candidates to generate predictions")
|
|
645
665
|
matrix = xgb.DMatrix(
|
|
646
666
|
candidates.select(
|
|
647
667
|
[
|
|
@@ -659,7 +679,7 @@ class EbmModel:
|
|
|
659
679
|
)
|
|
660
680
|
|
|
661
681
|
# Use the trained model to make predictions on the created matrix
|
|
662
|
-
self.logger.info("
|
|
682
|
+
self.logger.info("making predictions for candidates")
|
|
663
683
|
predictions = self.model.predict(matrix)
|
|
664
684
|
|
|
665
685
|
# Transform the predictions into a list of DataFrames containing the
|
|
@@ -671,11 +691,12 @@ class EbmModel:
|
|
|
671
691
|
.select(["doc_id", "label_id", "score"])
|
|
672
692
|
# Sort the DataFrame by document ID and score in ascending and
|
|
673
693
|
# descending order, respectively
|
|
694
|
+
.with_columns(pl.col("doc_id").cast(pl.Int64))
|
|
674
695
|
.sort(["doc_id", "score"], descending=[False, True])
|
|
675
696
|
# Group the DataFrame by document ID and aggregate the top-k labels
|
|
676
697
|
# and scores for each group
|
|
677
698
|
.group_by("doc_id")
|
|
678
|
-
.agg(pl.all().head(self.
|
|
699
|
+
.agg(pl.all().head(self.candidates_per_doc))
|
|
679
700
|
# Explode the aggregated DataFrame to create separate rows for each
|
|
680
701
|
# label and score
|
|
681
702
|
.explode(["label_id", "score"])
|
|
@@ -683,7 +704,7 @@ class EbmModel:
|
|
|
683
704
|
.partition_by("doc_id")
|
|
684
705
|
)
|
|
685
706
|
|
|
686
|
-
def save(self, output_path: str) ->
|
|
707
|
+
def save(self, output_path: str) -> list[str]:
|
|
687
708
|
"""
|
|
688
709
|
Saves the current state of the EBM model to a file using joblib.
|
|
689
710
|
|
|
@@ -694,12 +715,18 @@ class EbmModel:
|
|
|
694
715
|
Args:
|
|
695
716
|
output_path: The file path where the serialized model will be written.
|
|
696
717
|
|
|
718
|
+
Returns:
|
|
719
|
+
list[str]: Paths of the files in which the model was stored.
|
|
720
|
+
|
|
697
721
|
Notes:
|
|
698
|
-
The model's client and
|
|
722
|
+
The model's client, generator and loggers are reset to None.
|
|
699
723
|
"""
|
|
700
724
|
self.client = None
|
|
701
725
|
self.generator = None
|
|
702
|
-
|
|
726
|
+
|
|
727
|
+
self.init_logger()
|
|
728
|
+
|
|
729
|
+
return joblib.dump(self, output_path)
|
|
703
730
|
|
|
704
731
|
@staticmethod
|
|
705
732
|
def load(input_path: str) -> EbmModel:
|
|
@@ -707,9 +734,12 @@ class EbmModel:
|
|
|
707
734
|
Loads an EBM model from a joblib serialized file.
|
|
708
735
|
|
|
709
736
|
Args:
|
|
710
|
-
|
|
737
|
+
input_path (str): Path to the joblib serialized file containing the EBM model.
|
|
711
738
|
|
|
712
739
|
Returns:
|
|
713
|
-
|
|
740
|
+
EbmModel: The loaded EBM model instance.
|
|
714
741
|
"""
|
|
715
|
-
|
|
742
|
+
ebm_model = joblib.load(input_path)
|
|
743
|
+
ebm_model.init_logger()
|
|
744
|
+
|
|
745
|
+
return ebm_model
|
|
@@ -9,7 +9,8 @@ class EmbeddingGenerator:
|
|
|
9
9
|
A class for generating embeddings using a given SentenceTransformer model.
|
|
10
10
|
|
|
11
11
|
Args:
|
|
12
|
-
model_name (str): The name of the SentenceTransformer
|
|
12
|
+
model_name (str, SentenceTransformer): The name of the SentenceTransformer
|
|
13
|
+
model or a SentenceTransformer model to use.
|
|
13
14
|
embedding_dimensions (int): The dimensionality of the generated embeddings.
|
|
14
15
|
**kwargs: Additional keyword arguments to pass to the model.
|
|
15
16
|
|
|
@@ -19,7 +20,9 @@ class EmbeddingGenerator:
|
|
|
19
20
|
model (SentenceTransformer): The SentenceTransformer model instance.
|
|
20
21
|
"""
|
|
21
22
|
|
|
22
|
-
def __init__(
|
|
23
|
+
def __init__(
|
|
24
|
+
self, model_name: str | SentenceTransformer, embedding_dimensions: int, **kwargs
|
|
25
|
+
) -> None:
|
|
23
26
|
"""
|
|
24
27
|
Initializes the EmbeddingGenerator.
|
|
25
28
|
|
|
@@ -31,9 +34,13 @@ class EmbeddingGenerator:
|
|
|
31
34
|
|
|
32
35
|
# Create a SentenceTransformer model instance with the given
|
|
33
36
|
# model name and embedding dimensions
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
+
# or set model to the given SentenceTransformer
|
|
38
|
+
if type(model_name) is str:
|
|
39
|
+
self.model = SentenceTransformer(
|
|
40
|
+
model_name, truncate_dim=embedding_dimensions, **kwargs
|
|
41
|
+
)
|
|
42
|
+
else:
|
|
43
|
+
self.model = model_name
|
|
37
44
|
|
|
38
45
|
# Disable parallelism for tokenizer
|
|
39
46
|
# Needed because process might be already parallelized
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ebm4subjects
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.5.1
|
|
4
4
|
Summary: Embedding Based Matching for Automated Subject Indexing
|
|
5
5
|
Author: Deutsche Nationalbibliothek
|
|
6
6
|
Maintainer-email: Clemens Rietdorf <c.rietdorf@dnb.de>, Maximilian Kähler <m.kaehler@dnb.de>
|
|
@@ -13,9 +13,7 @@ Classifier: Operating System :: OS Independent
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3
|
|
14
14
|
Requires-Python: >=3.10
|
|
15
15
|
Requires-Dist: duckdb>=1.3.0
|
|
16
|
-
Requires-Dist: flash-attn>=2.8.2
|
|
17
16
|
Requires-Dist: nltk~=3.9.1
|
|
18
|
-
Requires-Dist: pandas>=2.3.0
|
|
19
17
|
Requires-Dist: polars>=1.30.0
|
|
20
18
|
Requires-Dist: pyarrow>=21.0.0
|
|
21
19
|
Requires-Dist: pyoxigraph>=0.4.11
|
|
@@ -56,6 +54,7 @@ This design borrows a lot of ideas from lexical matching like Maui [1], Kea [2]
|
|
|
56
54
|
|
|
57
55
|
[2] Frank, E., Paynter, G. W., Witten, I. H., Gutwin, C., & Nevill-Manning, C. G. (1999). Domain-Specific Keyphrase Extraction. Proceedings of the 16th International Joint Conference on Artificial Intelligence (IJCAI99), 668–673.
|
|
58
56
|
|
|
57
|
+

|
|
59
58
|
|
|
60
59
|
## Why embedding based matching
|
|
61
60
|
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
+
ebm4subjects/analyzer.py,sha256=lqX7AF8WsvwIavgtnmoVQ0i3wzBJJSeH47EiEwoLKGg,1664
|
|
3
|
+
ebm4subjects/chunker.py,sha256=HcEFJtKWHFYZL8DmZcHGXLPGEkCqHZhh_0kSqyYVsdE,6764
|
|
4
|
+
ebm4subjects/duckdb_client.py,sha256=8lDIpj2o2VTEtjHC_vTYrI5-RNXZnWMft45bS6z9B_k,13031
|
|
5
|
+
ebm4subjects/ebm_logging.py,sha256=xkbqeVhSCNuhMwkx2yoIX8_D3z9DcsauZEmHhR1gaS0,5962
|
|
6
|
+
ebm4subjects/ebm_model.py,sha256=PVFtljF3oZK8u0lA6df82lsTdAD8H1Y9CHvWq1jWF2M,29125
|
|
7
|
+
ebm4subjects/embedding_generator.py,sha256=DZhZxkjcsy_4NA62_2V-4UPbIUkg5qMPat_cIgsoIAA,2609
|
|
8
|
+
ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
|
|
9
|
+
ebm4subjects-0.5.1.dist-info/METADATA,sha256=QkOBvOAI49_AUipc3yAH6RVG9OVUs_8jO64Bjfy561U,8274
|
|
10
|
+
ebm4subjects-0.5.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
11
|
+
ebm4subjects-0.5.1.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
|
|
12
|
+
ebm4subjects-0.5.1.dist-info/RECORD,,
|
|
@@ -1,12 +0,0 @@
|
|
|
1
|
-
ebm4subjects/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
ebm4subjects/analyzer.py,sha256=kHsM2ZPzOIHp93UbdWtlgWARoH5ZbDueLsw9FJxpomM,1635
|
|
3
|
-
ebm4subjects/chunker.py,sha256=5LMOAHAxm_VlwSQnmVJjBxb4Vrdv7N-ioW8wcC-VvF0,6545
|
|
4
|
-
ebm4subjects/duckdb_client.py,sha256=JS6yyBe2p01cX_apFXjpYtT-w4Ow41HVhF3z9lKvvww,13046
|
|
5
|
-
ebm4subjects/ebm_logging.py,sha256=0tvodIHXdAGPzOXHwQF5lNBZYZTHD33mZrogr1btqV4,6001
|
|
6
|
-
ebm4subjects/ebm_model.py,sha256=sZI1QwKAH6wPPIxKbdLudD6rIJj7RNsDVJhV0fPBICw,28097
|
|
7
|
-
ebm4subjects/embedding_generator.py,sha256=jC4rz4W50tKndxYezD7Kaoqysl8zhN-TbWirxA_WIQc,2354
|
|
8
|
-
ebm4subjects/prepare_data.py,sha256=vQ-BdXkIP3iZJdPXol0WDlY8cRFMHkjzzL7oC7EbouE,3084
|
|
9
|
-
ebm4subjects-0.4.1.dist-info/METADATA,sha256=Oo_YR6zYDnhxWZa7Gp_HZuK7qIFIQWlA3dAbDsze_YE,8285
|
|
10
|
-
ebm4subjects-0.4.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
11
|
-
ebm4subjects-0.4.1.dist-info/licenses/LICENSE,sha256=RpvAZSjULHvoTR_esTlucJ08-zdQydnoqQLbqOh9Ub8,13826
|
|
12
|
-
ebm4subjects-0.4.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|