dragon-ml-toolbox 3.12.6__py3-none-any.whl → 4.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- dragon_ml_toolbox-4.0.0.dist-info/METADATA +230 -0
- dragon_ml_toolbox-4.0.0.dist-info/RECORD +29 -0
- ml_tools/ETL_engineering.py +2 -2
- ml_tools/GUI_tools.py +2 -2
- ml_tools/MICE_imputation.py +4 -3
- ml_tools/ML_callbacks.py +8 -4
- ml_tools/ML_evaluation.py +11 -6
- ml_tools/ML_inference.py +131 -0
- ml_tools/ML_trainer.py +17 -8
- ml_tools/PSO_optimization.py +7 -12
- ml_tools/RNN_forecast.py +5 -0
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_logger.py +36 -0
- ml_tools/_pytorch_models.py +1 -1
- ml_tools/_script_info.py +8 -0
- ml_tools/{logger.py → custom_logger.py} +4 -66
- ml_tools/data_exploration.py +2 -66
- ml_tools/datasetmaster.py +3 -2
- ml_tools/ensemble_inference.py +249 -0
- ml_tools/ensemble_learning.py +40 -294
- ml_tools/handle_excel.py +3 -2
- ml_tools/keys.py +13 -2
- ml_tools/path_manager.py +194 -31
- ml_tools/utilities.py +2 -180
- dragon_ml_toolbox-3.12.6.dist-info/METADATA +0 -137
- dragon_ml_toolbox-3.12.6.dist-info/RECORD +0 -26
- ml_tools/ML_tutorial.py +0 -300
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-3.12.6.dist-info → dragon_ml_toolbox-4.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ensemble_learning.py
CHANGED
@@ -1,18 +1,16 @@
 import pandas as pd
 import numpy as np
-import
-import seaborn as sns
+import seaborn # Use plot styling
 import matplotlib.pyplot as plt
 from matplotlib.colors import Colormap
 from matplotlib import rcdefaults

 from pathlib import Path
-from typing import Literal, Union, Optional, Iterator, Tuple
+from typing import Literal, Union, Optional, Iterator, Tuple

 from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
 from imblearn.under_sampling import RandomUnderSampler

-from sklearn.ensemble import HistGradientBoostingClassifier, HistGradientBoostingRegressor
 import xgboost as xgb
 import lightgbm as lgb

@@ -20,9 +18,11 @@ from sklearn.model_selection import train_test_split
 from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay, mean_absolute_error, mean_squared_error, r2_score, roc_curve, roc_auc_score
 import shap

-from .utilities import yield_dataframes_from_dir,
+from .utilities import yield_dataframes_from_dir, serialize_object
+from .path_manager import sanitize_filename, make_fullpath
+from ._script_info import _script_info
 from .keys import ModelSaveKeys
-from .
+from ._logger import _LOGGER

 import warnings # Ignore warnings
 warnings.filterwarnings('ignore', category=DeprecationWarning)
@@ -41,8 +41,6 @@ __all__ = [
     "get_shap_values",
     "train_test_pipeline",
     "run_ensemble_pipeline",
-    "InferenceHandler",
-    "model_report"
 ]

 ## Type aliases
@@ -81,7 +79,7 @@ def dataset_yielder(
 class RegressionTreeModels:
     """
     A factory class for creating and configuring multiple gradient boosting regression models
-    with unified hyperparameters. This includes XGBoost
+    with unified hyperparameters. This includes XGBoost and LightGBM.

     Use the `__call__`, `()` method.
@@ -111,12 +109,6 @@ class RegressionTreeModels:
     colsample_bytree : float [0.3 - 1.0]
         Fraction of features per tree; useful for regularization (used by XGBoost and LightGBM).

-    min_samples_leaf : int [10 - 100]
-        Minimum samples per leaf; higher = less overfitting (used in HistGB).
-
-    max_iter : int [100 - 2000]
-        Maximum number of iterations (used in HistGB).
-
     min_child_weight : float [0.1 - 10.0]
         Minimum sum of instance weight (hessian) needed in a child; larger values make the algorithm more conservative (used in XGBoost).

@@ -130,20 +122,19 @@ class RegressionTreeModels:
         Minimum number of data points in a leaf; increasing may prevent overfitting (used in LightGBM).
     """
     def __init__(self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-                 min_data_in_leaf: int = 40):
+                 random_state: int = 101,
+                 learning_rate: float = 0.005,
+                 L1_regularization: float = 1.0,
+                 L2_regularization: float = 1.0,
+                 n_estimators: int = 1000,
+                 max_depth: int = 8,
+                 subsample: float = 0.8,
+                 colsample_bytree: float = 0.8,
+                 min_child_weight: float = 3.0,
+                 gamma: float = 1.0,
+                 num_leaves: int = 31,
+                 min_data_in_leaf: int = 40):
+
         # General config
         self.random_state = random_state
         self.lr = learning_rate
@@ -165,16 +156,11 @@ class RegressionTreeModels:
         self.num_leaves = num_leaves
         self.min_data_in_leaf = min_data_in_leaf

-        # HistGB specific
-        self.max_iter = max_iter
-        self.min_samples_leaf = min_samples_leaf
-
     def __call__(self) -> dict[str, object]:
         """
         Returns a dictionary with new instances of:
         - "XGBoost": XGBRegressor
         - "LightGBM": LGBMRegressor
-        - "HistGB": HistGradientBoostingRegressor
         """
         # XGBoost Regressor
         xgb_model = xgb.XGBRegressor(
@@ -209,23 +195,9 @@ class RegressionTreeModels:
             min_data_in_leaf=self.min_data_in_leaf
         )

-        # HistGradientBoosting Regressor
-        hist_model = HistGradientBoostingRegressor(
-            max_iter=self.max_iter,
-            learning_rate=self.lr,
-            max_depth=self.max_depth,
-            min_samples_leaf=self.min_samples_leaf,
-            random_state=self.random_state,
-            l2_regularization=self.L2,
-            scoring='neg_mean_squared_error',
-            early_stopping=True,
-            validation_fraction=0.1
-        )
-
         return {
             "XGBoost": xgb_model,
-            "LightGBM": lgb_model,
-            "HistGB": hist_model
+            "LightGBM": lgb_model
         }

     def __str__(self):
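For context, the 4.0.0 factory is exercised like this. A minimal usage sketch against the signatures shown above; the random training data is illustrative only:

    import numpy as np
    from ml_tools.ensemble_learning import RegressionTreeModels

    # Shared hyperparameters are configured once; each call returns fresh, unfitted models.
    factory = RegressionTreeModels(learning_rate=0.01, n_estimators=500)
    models = factory()  # {"XGBoost": XGBRegressor, "LightGBM": LGBMRegressor}

    X, y = np.random.rand(200, 5), np.random.rand(200)  # illustrative data only
    for name, model in models.items():
        model.fit(X, y)
        print(name, model.predict(X[:3]))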
@@ -235,7 +207,7 @@ class RegressionTreeModels:
 class ClassificationTreeModels:
     """
     A factory class for creating and configuring multiple gradient boosting classification models
-    with unified hyperparameters. This includes
+    with unified hyperparameters. This includes XGBoost and LightGBM.

     Use the `__call__`, `()` method.
@@ -265,12 +237,6 @@ class ClassificationTreeModels:
     colsample_bytree : float [0.3 - 1.0]
         Fraction of features per tree; useful for regularization (used by XGBoost and LightGBM).

-    min_samples_leaf : int [10 - 100]
-        Minimum number of samples required to be at a leaf node; higher = less overfitting (used in HistGB).
-
-    max_iter : int [100 - 2000]
-        Maximum number of boosting iterations (used in HistGB).
-
     min_child_weight : float [0.1 - 10.0]
         Minimum sum of instance weight (Hessian) in a child node; larger values make the algorithm more conservative (used in XGBoost).

@@ -289,20 +255,19 @@ class ClassificationTreeModels:
         Indicates whether to apply class balancing strategies internally. Can be overridden at runtime via the `__call__` method.
     """
     def __init__(self,
-
-
-
-
-
-
-
-
-
-
-
-
-
-                 min_data_in_leaf: int = 40):
+                 random_state: int = 101,
+                 learning_rate: float = 0.005,
+                 L1_regularization: float = 1.0,
+                 L2_regularization: float = 1.0,
+                 n_estimators: int = 1000,
+                 max_depth: int = 8,
+                 subsample: float = 0.8,
+                 colsample_bytree: float = 0.8,
+                 min_child_weight: float = 3.0,
+                 gamma: float = 1.0,
+                 num_leaves: int = 31,
+                 min_data_in_leaf: int = 40):
+
         # General config
         self.random_state = random_state
         self.lr = learning_rate
@@ -327,16 +292,11 @@ class ClassificationTreeModels:
         self.num_leaves = num_leaves
         self.min_data_in_leaf = min_data_in_leaf

-        # HistGB specific
-        self.max_iter = max_iter
-        self.min_samples_leaf = min_samples_leaf
-
     def __call__(self, use_model_balance: Optional[bool]=None) -> dict[str, object]:
         """
         Returns a dictionary with new instances of:
         - "XGBoost": XGBClassifier
         - "LightGBM": LGBMClassifier
-        - "HistGB": HistGradientBoostingClassifier
         """
         if use_model_balance is not None:
             self.use_model_balance = use_model_balance
@@ -376,24 +336,9 @@ class ClassificationTreeModels:
             class_weight='balanced' if self.use_model_balance else None
         )

-        # HistGradientBoosting Classifier
-        hist_model = HistGradientBoostingClassifier(
-            max_iter=self.max_iter,
-            learning_rate=self.lr,
-            max_depth=self.max_depth,
-            min_samples_leaf=self.min_samples_leaf,
-            random_state=self.random_state,
-            l2_regularization=self.L2,
-            early_stopping=True,
-            validation_fraction=0.1,
-            class_weight='balanced' if self.use_model_balance else None,
-            scoring='balanced_accuracy' if self.use_model_balance else 'loss'
-        )
-
         return {
             "XGBoost": xgb_model,
-            "LightGBM": lgb_model,
-            "HistGB": hist_model
+            "LightGBM": lgb_model
         }

     def __str__(self):
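The classification factory above mirrors the regression one, with class balancing overridable at call time per its docstring. A sketch using the signature shown in this diff:

    from ml_tools.ensemble_learning import ClassificationTreeModels

    factory = ClassificationTreeModels(learning_rate=0.01)
    balanced = factory(use_model_balance=True)   # e.g. LGBMClassifier gets class_weight='balanced'
    plain = factory(use_model_balance=False)     # same hyperparameters, no balancing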
@@ -577,7 +522,7 @@ def evaluate_model_classification(

         fig.tight_layout()
         fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
-        fig.savefig(fig_path, format="svg", bbox_inches="tight")
+        fig.savefig(fig_path, format="svg", bbox_inches="tight") # type: ignore
         plt.close(fig)

     return y_pred
@@ -621,8 +566,8 @@ def plot_roc_curve(
     # Determine predicted probabilities
     if isinstance(probabilities_or_model, np.ndarray):
         # Input is already probabilities
-        if probabilities_or_model.ndim == 2:
-            y_score = probabilities_or_model[:, 1]
+        if probabilities_or_model.ndim == 2: # type: ignore
+            y_score = probabilities_or_model[:, 1] # type: ignore
         else:
             y_score = probabilities_or_model

@@ -661,7 +606,7 @@ def plot_roc_curve(
     save_path = make_fullpath(save_directory, make=True)
     sanitized_target_name = sanitize_filename(target_name)
     full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
-    fig.savefig(full_save_path, bbox_inches="tight", format="svg")
+    fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore

     return fig

@@ -943,204 +888,5 @@ def run_ensemble_pipeline(datasets_dir: Union[str,Path], save_dir: Union[str,Path]
     _LOGGER.info("✅ Training and evaluation complete.")


-###### 6. Inference ######
-class InferenceHandler:
-    """
-    Handles loading ensemble models and performing inference for either regression or classification tasks.
-    """
-    def __init__(self,
-                 models_dir: Union[str,Path],
-                 task: TaskType,
-                 verbose: bool=True) -> None:
-        """
-        Initializes the handler by loading all models from a directory.
-
-        Args:
-            models_dir (Path): The directory containing the saved .joblib model files.
-            task ("regression" | "classification"): The type of task the models perform.
-        """
-        self.models: Dict[str, Any] = dict()
-        self.task: str = task
-        self.verbose = verbose
-        self._feature_names: Optional[List[str]] = None
-
-        model_files = list_files_by_extension(directory=models_dir, extension="joblib")
-
-        for fname, fpath in model_files.items():
-            try:
-                full_object: dict
-                full_object = deserialize_object(filepath=fpath,
-                                                 verbose=self.verbose,
-                                                 raise_on_error=True) # type: ignore
-
-                model: Any = full_object[ModelSaveKeys.MODEL]
-                target_name: str = full_object[ModelSaveKeys.TARGET]
-                feature_names_list: List[str] = full_object[ModelSaveKeys.FEATURES]
-
-                # Check that feature names match
-                if self._feature_names is None:
-                    # Store the feature names from the first model loaded.
-                    self._feature_names = feature_names_list
-                elif self._feature_names != feature_names_list:
-                    # Add a warning if subsequent models have different feature names.
-                    _LOGGER.warning(f"⚠️ Mismatched feature names in {fname}. Using feature order from the first model loaded.")
-
-                self.models[target_name] = model
-                if self.verbose:
-                    _LOGGER.info(f"✅ Loaded model for target: {target_name}")
-
-            except Exception as e:
-                _LOGGER.warning(f"⚠️ Failed to load or parse {fname}: {e}")
-
-    @property
-    def feature_names(self) -> List[str]:
-        """
-        Getter for the list of feature names the models expect.
-        Returns an empty list if no models were loaded.
-        """
-        return self._feature_names if self._feature_names is not None else []
-
-    def predict(self, features: np.ndarray) -> Dict[str, Any]:
-        """
-        Predicts on a single feature vector.
-
-        Args:
-            features (np.ndarray): A 1D or 2D NumPy array for a single sample.
-
-        Returns:
-            Dict[str, Any]: A dictionary where keys are target names.
-            - For regression: The value is the single predicted float.
-            - For classification: The value is another dictionary {'label': ..., 'probabilities': ...}.
-        """
-        if features.ndim == 1:
-            features = features.reshape(1, -1)
-
-        if features.shape[0] != 1:
-            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
-
-        results: Dict[str, Any] = dict()
-        for target_name, model in self.models.items():
-            if self.task == "regression":
-                prediction = model.predict(features)
-                results[target_name] = prediction.item()
-            else: # Classification
-                label = model.predict(features)[0]
-                probabilities = model.predict_proba(features)[0]
-                results[target_name] = {ModelSaveKeys.CLASSIFICATION_LABEL: label,
-                                        ModelSaveKeys.CLASSIFICATION_PROBABILITIES: probabilities}
-
-        if self.verbose:
-            _LOGGER.info("✅ Inference process complete.")
-        return results
-
-    def predict_batch(self, features: np.ndarray) -> Dict[str, Any]:
-        """
-        Predicts on a batch of feature vectors.
-
-        Args:
-            features (np.ndarray): A 2D NumPy array where each row is a sample.
-
-        Returns:
-            Dict[str, Any]: A dictionary where keys are target names.
-            - For regression: The value is a NumPy array of predictions.
-            - For classification: The value is another dictionary {'labels': ..., 'probabilities': ...}.
-        """
-        if features.ndim != 2:
-            raise ValueError("Input for batch prediction must be a 2D array.")
-
-        results: Dict[str, Any] = dict()
-        for target_name, model in self.models.items():
-            if self.task == "regression":
-                results[target_name] = model.predict(features)
-            else: # Classification
-                labels = model.predict(features)
-                probabilities = model.predict_proba(features)
-                results[target_name] = {"labels": labels, "probabilities": probabilities}
-
-        if self.verbose:
-            _LOGGER.info("✅ Inference process complete.")
-
-        return results
-
-
-###### 7. Save Model info report ######
-def model_report(
-    model_path: Union[str,Path],
-    output_dir: Optional[Union[str,Path]] = None,
-    verbose: bool = True
-) -> Dict[str, Any]:
-    """
-    Deserializes a model and generates a summary report.
-
-    This function loads a serialized model object (joblib), prints a summary to the
-    console (if verbose), and saves a detailed JSON report.
-
-    Args:
-        model_path (str): The path to the serialized model file.
-        output_dir (str, optional): Directory to save the JSON report.
-            If None, it defaults to the same directory as the model file.
-        verbose (bool, optional): If True, prints summary information
-            to the console. Defaults to True.
-
-    Returns:
-        (Dict[str, Any]): A dictionary containing the model metadata.
-
-    Raises:
-        FileNotFoundError: If the model_path does not exist.
-        KeyError: If the deserialized object is missing required keys from `ModelSaveKeys`.
-    """
-    # 1. Convert to Path object
-    model_p = make_fullpath(model_path)
-
-    # --- 2. Deserialize and Extract Info ---
-    try:
-        full_object: dict = deserialize_object(model_p) # type: ignore
-        model = full_object[ModelSaveKeys.MODEL]
-        target = full_object[ModelSaveKeys.TARGET]
-        features = full_object[ModelSaveKeys.FEATURES]
-    except FileNotFoundError:
-        _LOGGER.error(f"❌ Model file not found at '{model_p}'")
-        raise
-    except (KeyError, TypeError) as e:
-        _LOGGER.error(
-            f"❌ The serialized object is missing required keys '{ModelSaveKeys.MODEL}', '{ModelSaveKeys.TARGET}', '{ModelSaveKeys.FEATURES}'"
-        )
-        raise e
-
-    # --- 3. Print Summary to Console (if verbose) ---
-    if verbose:
-        print("\n--- 📝 Model Summary ---")
-        print(f"Source File: {model_p.name}")
-        print(f"Model Type: {type(model).__name__}")
-        print(f"Target: {target}")
-        print(f"Feature Count: {len(features)}")
-        print("-----------------------")
-
-    # --- 4. Generate JSON Report ---
-    report_data = {
-        "source_file": model_p.name,
-        "model_type": str(type(model)),
-        "target_name": target,
-        "feature_count": len(features),
-        "feature_names": features
-    }
-
-    # Determine output path
-    output_p = make_fullpath(output_dir, make=True) if output_dir else model_p.parent
-    json_filename = model_p.stem + "_info.json"
-    json_filepath = output_p / json_filename
-
-    try:
-        with open(json_filepath, 'w') as f:
-            json.dump(report_data, f, indent=4)
-        if verbose:
-            _LOGGER.info(f"✅ JSON report saved to: '{json_filepath}'")
-    except PermissionError:
-        _LOGGER.error(f"❌ Permission denied to write JSON report at '{json_filepath}'")

-    # --- 5. Return the extracted data ---
-    return report_data
-
-
 def info():
     _script_info(__all__)
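For reference, the removed InferenceHandler was driven roughly as follows. A sketch against the 3.12.6 code above; judging by the file list, its replacement now lives in the new ml_tools/ensemble_inference.py:

    import numpy as np
    from ml_tools.ensemble_learning import InferenceHandler  # 3.12.6 location

    handler = InferenceHandler(models_dir="models/", task="regression")
    sample = np.zeros(len(handler.feature_names))           # a single 1D sample
    print(handler.predict(sample))                          # {target_name: float, ...}
    print(handler.predict_batch(np.tile(sample, (4, 1))))   # {target_name: ndarray, ...}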
ml_tools/handle_excel.py
CHANGED
@@ -2,8 +2,9 @@ from pathlib import Path
 from openpyxl import load_workbook, Workbook
 import pandas as pd
 from typing import List, Optional, Union
-from .
-from .
+from .path_manager import sanitize_filename, make_fullpath
+from ._script_info import _script_info
+from ._logger import _LOGGER


 __all__ = [
ml_tools/keys.py
CHANGED
@@ -16,7 +16,7 @@ class LogKeys:

 class ModelSaveKeys:
     """
-    Used internally
+    Used internally by ensemble_learning.
     """
     # Serializing a trained model metadata.
     MODEL = "model"
@@ -24,11 +24,22 @@ class ModelSaveKeys:
     TARGET = "target_name"

     # Classification keys
-    CLASSIFICATION_LABEL = "
+    CLASSIFICATION_LABEL = "labels"
     CLASSIFICATION_PROBABILITIES = "probabilities"


+class PyTorchInferenceKeys:
+    """Keys for the output dictionaries of PyTorchInferenceHandler."""
+    # For regression tasks
+    PREDICTIONS = "predictions"
+
+    # For classification tasks
+    LABELS = "labels"
+    PROBABILITIES = "probabilities"
+
+
 class _OneHotOtherPlaceholder:
+    """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
     OTHER_MODEL = "one hot OTHER placeholder"
     OTHER_DICT = {OTHER_GUI: OTHER_MODEL}
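Taken together, these key classes pin down the dictionary layouts shared across modules. A sketch of both; the shape of the PyTorch handler output is an assumption inferred from the key names and the new ml_tools/ML_inference.py:

    from ml_tools.keys import ModelSaveKeys, PyTorchInferenceKeys

    # Layout of a serialized ensemble-model bundle, as read back by ensemble_learning:
    bundle = {ModelSaveKeys.MODEL: None,             # the fitted estimator
              ModelSaveKeys.TARGET: "target_name",   # name of the predicted column
              ModelSaveKeys.FEATURES: ["f1", "f2"]}  # ordered feature names

    # Presumed shape of PyTorchInferenceHandler results, keyed by task:
    regression_out = {PyTorchInferenceKeys.PREDICTIONS: 0.42}
    classification_out = {PyTorchInferenceKeys.LABELS: 1,
                          PyTorchInferenceKeys.PROBABILITIES: [0.2, 0.8]}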
|