dragon-ml-toolbox 3.12.5__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- dragon_ml_toolbox-4.0.0.dist-info/METADATA +230 -0
- dragon_ml_toolbox-4.0.0.dist-info/RECORD +29 -0
- ml_tools/ETL_engineering.py +2 -2
- ml_tools/GUI_tools.py +28 -12
- ml_tools/MICE_imputation.py +4 -3
- ml_tools/ML_callbacks.py +8 -4
- ml_tools/ML_evaluation.py +11 -6
- ml_tools/ML_inference.py +131 -0
- ml_tools/ML_trainer.py +17 -8
- ml_tools/PSO_optimization.py +7 -12
- ml_tools/RNN_forecast.py +5 -0
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_logger.py +36 -0
- ml_tools/_pytorch_models.py +1 -1
- ml_tools/_script_info.py +8 -0
- ml_tools/{logger.py → custom_logger.py} +4 -66
- ml_tools/data_exploration.py +2 -66
- ml_tools/datasetmaster.py +3 -2
- ml_tools/ensemble_inference.py +249 -0
- ml_tools/ensemble_learning.py +40 -294
- ml_tools/handle_excel.py +3 -2
- ml_tools/keys.py +13 -2
- ml_tools/path_manager.py +194 -31
- ml_tools/utilities.py +2 -180
- dragon_ml_toolbox-3.12.5.dist-info/METADATA +0 -137
- dragon_ml_toolbox-3.12.5.dist-info/RECORD +0 -26
- ml_tools/ML_tutorial.py +0 -300
- {dragon_ml_toolbox-3.12.5.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-3.12.5.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-3.12.5.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-3.12.5.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer.py
CHANGED
|
@@ -7,9 +7,9 @@ import numpy as np
|
|
|
7
7
|
|
|
8
8
|
from .ML_callbacks import Callback, History, TqdmProgressBar
|
|
9
9
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
|
|
10
|
-
from .
|
|
10
|
+
from ._script_info import _script_info
|
|
11
11
|
from .keys import LogKeys
|
|
12
|
-
from .
|
|
12
|
+
from ._logger import _LOGGER
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
@@ -105,7 +105,7 @@ class MyTrainer:
|
|
|
105
105
|
pin_memory=(self.device.type == "cuda")
|
|
106
106
|
)
|
|
107
107
|
|
|
108
|
-
def fit(self, epochs: int = 10, batch_size: int =
|
|
108
|
+
def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
|
|
109
109
|
"""
|
|
110
110
|
Starts the training-validation process of the model.
|
|
111
111
|
|
|
@@ -113,6 +113,13 @@ class MyTrainer:
|
|
|
113
113
|
epochs (int): The total number of epochs to train for.
|
|
114
114
|
batch_size (int): The number of samples per batch.
|
|
115
115
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
116
|
+
|
|
117
|
+
Note:
|
|
118
|
+
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
119
|
+
automatically aligns the model's output tensor with the target tensor's
|
|
120
|
+
shape using `output.view_as(target)`. This handles the common case
|
|
121
|
+
where a model outputs a shape of `[batch_size, 1]` and the target has a
|
|
122
|
+
shape of `[batch_size]`.
|
|
116
123
|
"""
|
|
117
124
|
self.epochs = epochs
|
|
118
125
|
self._create_dataloaders(batch_size, shuffle)
|
|
@@ -189,9 +196,10 @@ class MyTrainer:
|
|
|
189
196
|
logs = {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
|
|
190
197
|
return logs
|
|
191
198
|
|
|
192
|
-
def
|
|
199
|
+
def _predict_for_eval(self, dataloader: DataLoader):
|
|
193
200
|
"""
|
|
194
|
-
|
|
201
|
+
Private method to yield model predictions batch by batch for evaluation.
|
|
202
|
+
This is used internally by the `evaluate` method.
|
|
195
203
|
|
|
196
204
|
Args:
|
|
197
205
|
dataloader (DataLoader): The dataloader to predict on.
|
|
@@ -213,13 +221,14 @@ class MyTrainer:
|
|
|
213
221
|
preds = torch.argmax(probs, dim=1)
|
|
214
222
|
y_pred_batch = preds.numpy()
|
|
215
223
|
y_prob_batch = probs.numpy()
|
|
224
|
+
# regression
|
|
216
225
|
else:
|
|
217
226
|
y_pred_batch = output.numpy()
|
|
218
227
|
y_prob_batch = None
|
|
219
228
|
|
|
220
229
|
yield y_pred_batch, y_prob_batch, y_true_batch
|
|
221
230
|
|
|
222
|
-
def evaluate(self,
|
|
231
|
+
def evaluate(self, save_dir: Optional[Union[str,Path]], data: Optional[Union[DataLoader, Dataset]] = None):
|
|
223
232
|
"""
|
|
224
233
|
Evaluates the model on the given data.
|
|
225
234
|
|
|
@@ -251,7 +260,7 @@ class MyTrainer:
|
|
|
251
260
|
|
|
252
261
|
# Collect results from the predict generator
|
|
253
262
|
all_preds, all_probs, all_true = [], [], []
|
|
254
|
-
for y_pred_b, y_prob_b, y_true_b in self.
|
|
263
|
+
for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
|
|
255
264
|
all_preds.append(y_pred_b)
|
|
256
265
|
if y_prob_b is not None:
|
|
257
266
|
all_probs.append(y_prob_b)
|
|
@@ -270,7 +279,7 @@ class MyTrainer:
|
|
|
270
279
|
plot_losses(self.history, save_dir=save_dir)
|
|
271
280
|
|
|
272
281
|
def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
|
|
273
|
-
feature_names: Optional[List[str]] = None, save_dir: Optional[str] = None):
|
|
282
|
+
feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
|
|
274
283
|
"""
|
|
275
284
|
Explains model predictions using SHAP and saves all artifacts.
|
|
276
285
|
|
ml_tools/PSO_optimization.py
CHANGED
|
@@ -2,28 +2,23 @@ import numpy as np
|
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
import xgboost as xgb
|
|
4
4
|
import lightgbm as lgb
|
|
5
|
-
from sklearn.ensemble import HistGradientBoostingRegressor
|
|
6
|
-
from sklearn.base import ClassifierMixin
|
|
7
5
|
from typing import Literal, Union, Tuple, Dict, Optional
|
|
8
6
|
import pandas as pd
|
|
9
7
|
from copy import deepcopy
|
|
10
8
|
from .utilities import (
|
|
11
|
-
_script_info,
|
|
12
|
-
list_csv_paths,
|
|
13
9
|
threshold_binary_values,
|
|
14
10
|
threshold_binary_values_batch,
|
|
15
|
-
deserialize_object,
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
yield_dataframes_from_dir,
|
|
20
|
-
sanitize_filename)
|
|
11
|
+
deserialize_object,
|
|
12
|
+
save_dataframe,
|
|
13
|
+
yield_dataframes_from_dir)
|
|
14
|
+
from .path_manager import sanitize_filename, make_fullpath, list_files_by_extension, list_csv_paths
|
|
21
15
|
import torch
|
|
22
16
|
from tqdm import trange
|
|
23
17
|
import matplotlib.pyplot as plt
|
|
24
18
|
import seaborn as sns
|
|
25
|
-
from .
|
|
19
|
+
from ._logger import _LOGGER
|
|
26
20
|
from .keys import ModelSaveKeys
|
|
21
|
+
from ._script_info import _script_info
|
|
27
22
|
|
|
28
23
|
|
|
29
24
|
__all__ = [
|
|
@@ -125,7 +120,7 @@ class ObjectiveFunction():
|
|
|
125
120
|
return features_array * noise
|
|
126
121
|
|
|
127
122
|
def check_model(self):
|
|
128
|
-
if isinstance(self.model,
|
|
123
|
+
if isinstance(self.model, xgb.XGBClassifier) or isinstance(self.model, lgb.LGBMClassifier):
|
|
129
124
|
raise ValueError(f"[Model Check Failed] ❌\nThe loaded model ({type(self.model).__name__}) is a Classifier.\nOptimization is not suitable for standard classification tasks.")
|
|
130
125
|
if self.model is None:
|
|
131
126
|
raise ValueError("Loaded model is None")
|
ml_tools/RNN_forecast.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch import nn
|
|
3
3
|
import numpy as np
|
|
4
|
+
from ._script_info import _script_info
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
6
7
|
"rnn_forecast"
|
|
@@ -47,3 +48,7 @@ def rnn_forecast(model: nn.Module, start_sequence: torch.Tensor, steps: int, dev
|
|
|
47
48
|
|
|
48
49
|
# Concatenate all predictions and flatten the array for easy use
|
|
49
50
|
return np.concatenate(predictions).flatten()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def info():
    """Pass this module's public names (`__all__`) to the `_script_info` helper."""
    # The previous body was the bare expression `_script_info`, which only
    # evaluates the function object and discards it; the helper has to be
    # called, as the other ml_tools modules do.
    _script_info(__all__)
|
ml_tools/VIF_factor.py
CHANGED
|
@@ -7,9 +7,10 @@ from statsmodels.stats.outliers_influence import variance_inflation_factor
|
|
|
7
7
|
from statsmodels.tools.tools import add_constant
|
|
8
8
|
import warnings
|
|
9
9
|
from pathlib import Path
|
|
10
|
-
from .utilities import
|
|
11
|
-
from .
|
|
12
|
-
|
|
10
|
+
from .utilities import yield_dataframes_from_dir, save_dataframe
|
|
11
|
+
from .path_manager import sanitize_filename, make_fullpath
|
|
12
|
+
from ._logger import _LOGGER
|
|
13
|
+
from ._script_info import _script_info
|
|
13
14
|
|
|
14
15
|
__all__ = [
|
|
15
16
|
"compute_vif",
|
ml_tools/_logger.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _get_logger(name: str = "ml_tools", level: int = logging.INFO) -> logging.Logger:
    """
    Initializes and returns a configured logger instance.

    Args:
        name (str): Name handed to `logging.getLogger`.
        level (int): Level set on the logger; it is re-applied on every call.

    Returns:
        logging.Logger: The named logger with propagation disabled. A stdout
        handler is attached only if the logger had no handlers yet.

    Usage guide:
    - `logger.info()`
    - `logger.warning()`
    - `logger.error()` the program can potentially recover.
    - `logger.critical()` the program is going to crash.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Prevents adding handlers multiple times if the function is called again
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)

        # Define the format string and the date format separately
        log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        date_format = '%Y-%m-%d %H:%M'  # Format: Year-Month-Day Hour:Minute

        handler.setFormatter(logging.Formatter(log_format, datefmt=date_format))
        logger.addHandler(handler)

    # Records stop here instead of also being passed up to ancestor loggers.
    logger.propagate = False

    return logger
|
|
34
|
+
|
|
35
|
+
# Create a single logger instance to be imported by other modules
|
|
36
|
+
_LOGGER = _get_logger()
|
ml_tools/_pytorch_models.py
CHANGED
ml_tools/_script_info.py
ADDED
ml_tools/{logger.py → custom_logger.py}
CHANGED
|
@@ -2,12 +2,11 @@ from pathlib import Path
|
|
|
2
2
|
from datetime import datetime
|
|
3
3
|
from typing import Union, List, Dict, Any
|
|
4
4
|
import pandas as pd
|
|
5
|
-
from openpyxl.styles import Font, PatternFill
|
|
6
5
|
import traceback
|
|
7
6
|
import json
|
|
8
|
-
from .
|
|
9
|
-
import
|
|
10
|
-
import
|
|
7
|
+
from .path_manager import sanitize_filename, make_fullpath
|
|
8
|
+
from ._script_info import _script_info
|
|
9
|
+
from ._logger import _LOGGER
|
|
11
10
|
|
|
12
11
|
|
|
13
12
|
__all__ = [
|
|
@@ -38,9 +37,6 @@ def custom_logger(
|
|
|
38
37
|
- dict[str, scalar] → .json
|
|
39
38
|
Dictionary is treated as structured data and serialized as JSON.
|
|
40
39
|
|
|
41
|
-
- pandas.DataFrame → .xlsx
|
|
42
|
-
Written to an Excel file with styled headers.
|
|
43
|
-
|
|
44
40
|
- str → .log
|
|
45
41
|
Plain text string is written to a .log file.
|
|
46
42
|
|
|
@@ -72,9 +68,6 @@ def custom_logger(
|
|
|
72
68
|
else:
|
|
73
69
|
_log_dict_to_json(data, base_path.with_suffix(".json"))
|
|
74
70
|
|
|
75
|
-
elif isinstance(data, pd.DataFrame):
|
|
76
|
-
_log_dataframe_to_xlsx(data, base_path.with_suffix(".xlsx"))
|
|
77
|
-
|
|
78
71
|
elif isinstance(data, str):
|
|
79
72
|
_log_string_to_log(data, base_path.with_suffix(".log"))
|
|
80
73
|
|
|
@@ -117,27 +110,6 @@ def _log_dict_to_csv(data: Dict[Any, List[Any]], path: Path) -> None:
|
|
|
117
110
|
df.to_csv(path, index=False)
|
|
118
111
|
|
|
119
112
|
|
|
120
|
-
def _log_dataframe_to_xlsx(data: pd.DataFrame, path: Path) -> None:
|
|
121
|
-
writer = pd.ExcelWriter(path, engine='openpyxl')
|
|
122
|
-
data.to_excel(writer, index=True, sheet_name='Data')
|
|
123
|
-
|
|
124
|
-
workbook = writer.book
|
|
125
|
-
worksheet = writer.sheets['Data']
|
|
126
|
-
|
|
127
|
-
header_font = Font(bold=True)
|
|
128
|
-
header_fill = PatternFill(
|
|
129
|
-
start_color="ADD8E6", # Light blue
|
|
130
|
-
end_color="ADD8E6",
|
|
131
|
-
fill_type="solid"
|
|
132
|
-
)
|
|
133
|
-
|
|
134
|
-
for cell in worksheet[1]:
|
|
135
|
-
cell.font = header_font
|
|
136
|
-
cell.fill = header_fill
|
|
137
|
-
|
|
138
|
-
writer.close()
|
|
139
|
-
|
|
140
|
-
|
|
141
113
|
def _log_string_to_log(data: str, path: Path) -> None:
|
|
142
114
|
with open(path, 'w', encoding='utf-8') as f:
|
|
143
115
|
f.write(data.strip() + '\n')
|
|
@@ -155,38 +127,4 @@ def _log_dict_to_json(data: Dict[Any, Any], path: Path) -> None:
|
|
|
155
127
|
|
|
156
128
|
|
|
157
129
|
def info():
|
|
158
|
-
_script_info(__all__)
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def _get_logger(name: str = "ml_tools", level: int = logging.INFO):
|
|
162
|
-
"""
|
|
163
|
-
Initializes and returns a configured logger instance.
|
|
164
|
-
|
|
165
|
-
- `logger.info()`
|
|
166
|
-
- `logger.warning()`
|
|
167
|
-
- `logger.error()` the program can potentially recover.
|
|
168
|
-
- `logger.critical()` the program is going to crash.
|
|
169
|
-
"""
|
|
170
|
-
logger = logging.getLogger(name)
|
|
171
|
-
logger.setLevel(level)
|
|
172
|
-
|
|
173
|
-
# Prevents adding handlers multiple times if the function is called again
|
|
174
|
-
if not logger.handlers:
|
|
175
|
-
handler = logging.StreamHandler(sys.stdout)
|
|
176
|
-
|
|
177
|
-
# Define the format string and the date format separately
|
|
178
|
-
log_format = '\n🐉%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
179
|
-
date_format = '%Y-%m-%d %H:%M' # Format: Year-Month-Day Hour:Minute
|
|
180
|
-
|
|
181
|
-
# Pass both the format and the date format to the Formatter
|
|
182
|
-
formatter = logging.Formatter(log_format, datefmt=date_format)
|
|
183
|
-
|
|
184
|
-
handler.setFormatter(formatter)
|
|
185
|
-
logger.addHandler(handler)
|
|
186
|
-
|
|
187
|
-
logger.propagate = False
|
|
188
|
-
|
|
189
|
-
return logger
|
|
190
|
-
|
|
191
|
-
# Create a single logger instance to be imported by other modules
|
|
192
|
-
_LOGGER = _get_logger()
|
|
130
|
+
_script_info(__all__)
|
ml_tools/data_exploration.py
CHANGED
|
@@ -3,12 +3,10 @@ from pandas.api.types import is_numeric_dtype
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import matplotlib.pyplot as plt
|
|
5
5
|
import seaborn as sns
|
|
6
|
-
from IPython import get_ipython
|
|
7
|
-
from IPython.display import clear_output
|
|
8
|
-
import time
|
|
9
6
|
from typing import Union, Literal, Dict, Tuple, List, Optional
|
|
10
7
|
from pathlib import Path
|
|
11
|
-
from .
|
|
8
|
+
from .path_manager import sanitize_filename, make_fullpath
|
|
9
|
+
from ._script_info import _script_info
|
|
12
10
|
import re
|
|
13
11
|
|
|
14
12
|
|
|
@@ -22,7 +20,6 @@ __all__ = [
|
|
|
22
20
|
"drop_columns_with_missing_data",
|
|
23
21
|
"split_continuous_binary",
|
|
24
22
|
"plot_correlation_heatmap",
|
|
25
|
-
"check_value_distributions",
|
|
26
23
|
"plot_value_distributions",
|
|
27
24
|
"clip_outliers_single",
|
|
28
25
|
"clip_outliers_multi",
|
|
@@ -343,63 +340,6 @@ def plot_correlation_heatmap(df: pd.DataFrame,
|
|
|
343
340
|
plt.close()
|
|
344
341
|
|
|
345
342
|
|
|
346
|
-
def check_value_distributions(df: pd.DataFrame, view_frequencies: bool=True, bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
|
|
347
|
-
"""
|
|
348
|
-
Analyzes value counts for each column in a DataFrame, optionally plots distributions,
|
|
349
|
-
and saves them as .png files in the specified directory.
|
|
350
|
-
|
|
351
|
-
Args:
|
|
352
|
-
df (pd.DataFrame): The dataset to analyze.
|
|
353
|
-
view_frequencies (bool): Print relative frequencies instead of value counts.
|
|
354
|
-
bin_threshold (int): Threshold of unique values to start using bins.
|
|
355
|
-
skip_cols_with_key (str | None): Skip column names containing the key. If None, don't skip any column.
|
|
356
|
-
|
|
357
|
-
Notes:
|
|
358
|
-
- Binning is adaptive: if quantile binning results in ≤ 2 unique bins, raw values are used instead.
|
|
359
|
-
"""
|
|
360
|
-
# cherry-pick columns
|
|
361
|
-
if skip_cols_with_key is not None:
|
|
362
|
-
columns = [col for col in df.columns if skip_cols_with_key not in col]
|
|
363
|
-
else:
|
|
364
|
-
columns = df.columns.to_list()
|
|
365
|
-
|
|
366
|
-
for col in columns:
|
|
367
|
-
if _is_notebook():
|
|
368
|
-
clear_output(wait=False)
|
|
369
|
-
if pd.api.types.is_numeric_dtype(df[col]) and df[col].nunique() > bin_threshold:
|
|
370
|
-
bins_number = 10
|
|
371
|
-
binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
|
|
372
|
-
while binned.nunique() <= 2:
|
|
373
|
-
bins_number -= 1
|
|
374
|
-
binned = pd.qcut(df[col], q=bins_number, duplicates='drop')
|
|
375
|
-
if bins_number <= 2:
|
|
376
|
-
break
|
|
377
|
-
|
|
378
|
-
if binned.nunique() <= 2:
|
|
379
|
-
view_std = df[col].value_counts(ascending=False)
|
|
380
|
-
else:
|
|
381
|
-
view_std = binned.value_counts(sort=False)
|
|
382
|
-
|
|
383
|
-
else:
|
|
384
|
-
view_std = df[col].value_counts(ascending=False)
|
|
385
|
-
|
|
386
|
-
view_std.name = col
|
|
387
|
-
|
|
388
|
-
# unlikely scenario where the series is empty
|
|
389
|
-
if view_std.sum() == 0:
|
|
390
|
-
view_freq = view_std
|
|
391
|
-
else:
|
|
392
|
-
view_freq = view_std / view_std.sum()
|
|
393
|
-
# view_freq = df[col].value_counts(normalize=True, bins=10) # relative percentages
|
|
394
|
-
view_freq.name = col
|
|
395
|
-
|
|
396
|
-
# Print value counts
|
|
397
|
-
print(view_freq if view_frequencies else view_std)
|
|
398
|
-
|
|
399
|
-
time.sleep(1)
|
|
400
|
-
user_input_ = input("Press enter to continue")
|
|
401
|
-
|
|
402
|
-
|
|
403
343
|
def plot_value_distributions(df: pd.DataFrame, save_dir: Union[str, Path], bin_threshold: int=10, skip_cols_with_key: Union[str, None]=None):
|
|
404
344
|
"""
|
|
405
345
|
Plots and saves the value distributions for all (or selected) columns in a DataFrame,
|
|
@@ -691,9 +631,5 @@ def standardize_percentages(
|
|
|
691
631
|
return df_copy
|
|
692
632
|
|
|
693
633
|
|
|
694
|
-
def _is_notebook():
|
|
695
|
-
return get_ipython() is not None
|
|
696
|
-
|
|
697
|
-
|
|
698
634
|
def info():
|
|
699
635
|
_script_info(__all__)
|
ml_tools/datasetmaster.py
CHANGED
|
@@ -13,8 +13,9 @@ from torchvision.datasets import ImageFolder
|
|
|
13
13
|
from torchvision import transforms
|
|
14
14
|
import matplotlib.pyplot as plt
|
|
15
15
|
from pathlib import Path
|
|
16
|
-
from .
|
|
17
|
-
from .
|
|
16
|
+
from .path_manager import make_fullpath
|
|
17
|
+
from ._logger import _LOGGER
|
|
18
|
+
from ._script_info import _script_info
|
|
18
19
|
|
|
19
20
|
|
|
20
21
|
# --- public-facing API ---
|
|
ml_tools/ensemble_inference.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
1
|
+
from ._script_info import _script_info
|
|
2
|
+
from ._logger import _LOGGER
|
|
3
|
+
from .path_manager import make_fullpath, list_files_by_extension
|
|
4
|
+
from .keys import ModelSaveKeys
|
|
5
|
+
|
|
6
|
+
from typing import Union, Literal, Dict, Any, Optional, List
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
import json
|
|
9
|
+
|
|
10
|
+
import joblib
|
|
11
|
+
import numpy as np
|
|
12
|
+
# Inference models
|
|
13
|
+
import xgboost
|
|
14
|
+
import lightgbm
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"InferenceHandler",
|
|
19
|
+
"model_report"
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class InferenceHandler:
    """
    Handles loading ensemble models and performing inference for either regression or classification tasks.

    Models are stored in `self.models`, keyed by the target name saved with each model.
    """
    def __init__(self,
                 models_dir: Union[str,Path],
                 task: Literal["classification", "regression"],
                 verbose: bool=True) -> None:
        """
        Initializes the handler by loading all models from a directory.

        Files that fail to load or lack the expected keys are skipped with a warning
        rather than aborting the whole load.

        Args:
            models_dir (Path): The directory containing the saved .joblib model files.
            task ("regression" | "classification"): The type of task the models perform.
            verbose (bool): If True, logs each loaded model and each completed inference call.
        """
        self.models: Dict[str, Any] = dict()
        self.task: str = task
        self.verbose = verbose
        self._feature_names: Optional[List[str]] = None

        model_files = list_files_by_extension(directory=models_dir, extension="joblib")

        for fname, fpath in model_files.items():
            try:
                full_object: dict
                full_object = _deserialize_object(filepath=fpath,
                                                  verbose=self.verbose,
                                                  raise_on_error=True) # type: ignore

                model: Any = full_object[ModelSaveKeys.MODEL]
                target_name: str = full_object[ModelSaveKeys.TARGET]
                feature_names_list: List[str] = full_object[ModelSaveKeys.FEATURES]

                # Check that feature names match
                if self._feature_names is None:
                    # Store the feature names from the first model loaded.
                    self._feature_names = feature_names_list
                elif self._feature_names != feature_names_list:
                    # Add a warning if subsequent models have different feature names.
                    _LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")

                # Two files saved for the same target would otherwise replace
                # each other without any trace; the last one loaded still wins.
                if target_name in self.models:
                    _LOGGER.warning(f"⚠️ Duplicate target '{target_name}' in {fname}. Replacing the previously loaded model.")

                self.models[target_name] = model
                if self.verbose:
                    _LOGGER.info(f"✅ Loaded model for target: {target_name}")

            except Exception as e:
                # Deliberate best-effort: one bad file must not block the others.
                _LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")

    @staticmethod
    def _ensure_array(features: Any) -> Any:
        """Convert plain sequences to a NumPy array; array-like inputs that already expose `ndim` pass through untouched."""
        if hasattr(features, "ndim"):
            return features
        return np.asarray(features)

    @property
    def feature_names(self) -> List[str]:
        """
        Getter for the list of feature names the models expect.
        Returns an empty list if no models were loaded.
        """
        return self._feature_names if self._feature_names is not None else []

    def predict(self, features: np.ndarray) -> Dict[str, Any]:
        """
        Predicts on a single feature vector.

        Args:
            features (np.ndarray): A 1D or 2D NumPy array for a single sample.
                Plain Python sequences are converted to an array first.

        Returns:
            Dict[str, Any]: A dictionary where keys are target names.
                - For regression: The value is the single predicted float.
                - For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.

        Raises:
            ValueError: If more than one sample is passed.
        """
        features = self._ensure_array(features)

        if features.ndim == 1:
            features = features.reshape(1, -1)

        if features.shape[0] != 1:
            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")

        results: Dict[str, Any] = dict()
        for target_name, model in self.models.items():
            if self.task == "regression":
                prediction = model.predict(features)
                results[target_name] = prediction.item()
            else: # Classification
                label = model.predict(features)[0]
                probabilities = model.predict_proba(features)[0]
                results[target_name] = {ModelSaveKeys.CLASSIFICATION_LABEL: label,
                                        ModelSaveKeys.CLASSIFICATION_PROBABILITIES: probabilities}

        if self.verbose:
            _LOGGER.info("✅ Inference process complete.")
        return results

    def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
        """
        Predicts on a batch of feature vectors.

        Args:
            features (np.ndarray): A 2D NumPy array where each row is a sample.
                Plain Python sequences are converted to an array first.

        Returns:
            Dict[str, Any]: A dictionary where keys are target names.
                - For regression: The value is a NumPy array of predictions.
                - For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.

        Raises:
            ValueError: If the input is not 2D.
        """
        features = self._ensure_array(features)

        if features.ndim != 2:
            raise ValueError("Input for batch prediction must be a 2D array.")

        results: Dict[str, Any] = dict()
        for target_name, model in self.models.items():
            if self.task == "regression":
                results[target_name] = model.predict(features)
            else: # Classification
                labels = model.predict(features)
                probabilities = model.predict_proba(features)
                # Literal keys kept as-is: callers index the batch result by these strings.
                results[target_name] = {"labels": labels, "probabilities": probabilities}

        if self.verbose:
            _LOGGER.info("✅ Inference process complete.")

        return results
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def model_report(
        model_path: Union[str,Path],
        output_dir: Optional[Union[str,Path]] = None,
        verbose: bool = True
    ) -> Dict[str, Any]:
    """
    Deserializes a model and generates a summary report.

    This function loads a serialized model object (joblib), prints a summary to the
    console (if verbose), and saves a detailed JSON report named
    `<model file stem>_info.json`.

    Args:
        model_path (str): The path to the serialized model file.
        output_dir (str, optional): Directory to save the JSON report.
            If None, it defaults to the same directory as the model file.
        verbose (bool, optional): If True, prints summary information
            to the console. Defaults to True.

    Returns:
        (Dict[str, Any]): A dictionary containing the model metadata
            (source file, model type, target name, feature count, feature names).

    Raises:
        FileNotFoundError: If the model_path does not exist.
        KeyError: If the deserialized object is missing required keys from `ModelSaveKeys`.
        TypeError: If the deserialized object cannot be indexed by those keys.

    Note:
        A `PermissionError` while writing the JSON file is logged, not raised;
        the metadata dictionary is still returned.
    """
    # 1. Convert to Path object
    model_p = make_fullpath(model_path)

    # --- 2. Deserialize and Extract Info ---
    try:
        full_object: dict = _deserialize_object(model_p) # type: ignore
        model = full_object[ModelSaveKeys.MODEL]
        target = full_object[ModelSaveKeys.TARGET]
        features = full_object[ModelSaveKeys.FEATURES]
    except FileNotFoundError:
        _LOGGER.error(f"❌ Model file not found at '{model_p}'")
        raise
    except (KeyError, TypeError) as e:
        # Include the underlying error: a TypeError means the object is not a
        # mapping at all, which the generic "missing keys" text would hide.
        _LOGGER.error(
            f"❌ The serialized object is missing required keys '{ModelSaveKeys.MODEL}', '{ModelSaveKeys.TARGET}', '{ModelSaveKeys.FEATURES}' ({type(e).__name__}: {e})"
        )
        # Bare raise re-raises the active exception with its traceback intact.
        raise

    # --- 3. Print Summary to Console (if verbose) ---
    if verbose:
        print("\n--- 📝 Model Summary ---")
        print(f"Source File:    {model_p.name}")
        print(f"Model Type:     {type(model).__name__}")
        print(f"Target:         {target}")
        print(f"Feature Count:  {len(features)}")
        print("-----------------------")

    # --- 4. Generate JSON Report ---
    report_data = {
        "source_file": model_p.name,
        "model_type": str(type(model)),
        "target_name": target,
        "feature_count": len(features),
        "feature_names": features
    }

    # Determine output path
    output_p = make_fullpath(output_dir, make=True) if output_dir else model_p.parent
    json_filename = model_p.stem + "_info.json"
    json_filepath = output_p / json_filename

    try:
        # Explicit encoding so the report does not depend on the platform default.
        with open(json_filepath, 'w', encoding='utf-8') as f:
            json.dump(report_data, f, indent=4)
        if verbose:
            _LOGGER.info(f"✅ JSON report saved to: '{json_filepath}'")
    except PermissionError:
        _LOGGER.error(f"❌ Permission denied to write JSON report at '{json_filepath}'")

    # --- 5. Return the extracted data ---
    return report_data
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# Local implementation to avoid calling utilities' dependencies
|
|
221
|
+
def _deserialize_object(filepath: Union[str,Path], verbose: bool=True, raise_on_error: bool=True) -> Optional[Any]:
    """
    Loads a serialized object from a .joblib file.

    Parameters:
        filepath (str | Path): Full path to the serialized .joblib file.
        verbose (bool): If True, prints the type of the loaded object.
        raise_on_error (bool): If True, a load failure is re-raised as `Exception`
            with the original error attached as its cause. If False, the failure
            is printed and None is returned.

    Returns:
        (Any | None): The deserialized Python object, or None if loading fails
            and `raise_on_error` is False.

    Raises:
        Exception: If loading fails and `raise_on_error` is True.
    """
    true_filepath = make_fullpath(filepath)

    try:
        # SECURITY NOTE: joblib files are pickle-based, so loading one can run
        # arbitrary code. Only load model files from trusted sources.
        obj = joblib.load(true_filepath)
    except (IOError, OSError, EOFError, TypeError, ValueError) as e:
        message = f"❌ Failed to deserialize object from '{true_filepath}': {e}"
        if raise_on_error:
            # Chain the original error so its type and traceback stay visible.
            raise Exception(message) from e
        print(message)
        return None

    if verbose:
        print(f"\n✅ Loaded object of type '{type(obj)}'")
    return obj
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def info():
|
|
249
|
+
# Hands this module's public names (`__all__`) to the `_script_info` helper imported at the top of the file.
_script_info(__all__)
|