dragon-ml-toolbox 12.11.0__py3-none-any.whl → 12.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/METADATA +1 -1
- {dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/RECORD +12 -12
- ml_tools/ML_callbacks.py +33 -32
- ml_tools/ML_datasetmaster.py +21 -5
- ml_tools/ML_evaluation.py +90 -44
- ml_tools/ML_evaluation_multi.py +103 -32
- ml_tools/ML_models.py +7 -7
- ml_tools/ML_trainer.py +15 -4
- {dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/top_level.txt +0 -0
{dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/RECORD
CHANGED

@@ -1,19 +1,19 @@
-dragon_ml_toolbox-12.
-dragon_ml_toolbox-12.
+dragon_ml_toolbox-12.13.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-12.13.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
 ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
 ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
 ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
 ml_tools/MICE_imputation.py,sha256=X273Qlgoqqg7KTmoKd75YDyAPB0UIbTzGP3xsCmRh3E,11717
-ml_tools/ML_callbacks.py,sha256
-ml_tools/ML_datasetmaster.py,sha256=
-ml_tools/ML_evaluation.py,sha256=
-ml_tools/ML_evaluation_multi.py,sha256=
+ml_tools/ML_callbacks.py,sha256=2ZazJjlbClP-ALc8q0ru2oalkugbhO3TFwPg4RFZpck,14056
+ml_tools/ML_datasetmaster.py,sha256=kedCGneR3S2zui0_JFZN6TBL5e69XWkdpkE_QohyqSM,31433
+ml_tools/ML_evaluation.py,sha256=h7fAtk0lS4gTqQ46fiVjucTvFlX4rsufKnEtate6Nu0,18381
+ml_tools/ML_evaluation_multi.py,sha256=Kn9n5lfxo7A0TvgIDMx8UHZCvzTqv1ViezzwJBF-ypM,15970
 ml_tools/ML_inference.py,sha256=ymFvncFsU10PExq87xnEj541DKV5ck0nMuK8ToJHzVQ,23067
-ml_tools/ML_models.py,sha256=
+ml_tools/ML_models.py,sha256=G64NPhYZfYvHTIUwkIrMrNLgfDTKJwqdc8jwesPqB9E,28090
 ml_tools/ML_optimization.py,sha256=es3TlQbY7RYgJMZnznkjYGbUxFnAqzZxE_g3_qLK9Q8,22960
 ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
 ml_tools/ML_simple_optimization.py,sha256=W2mce1XFCuiOHTOjOsCNbETISHn5MwYlYsTIXH5hMMo,18177
-ml_tools/ML_trainer.py,sha256=
+ml_tools/ML_trainer.py,sha256=UmCuKr_GzQGYqhEZ-kaRv9Buj44DsNyuOzmOM7Fw8N0,24569
 ml_tools/ML_utilities.py,sha256=EnKpPTnJ2qjZmz7kvows4Uu5CfSA7ByRmI1v2-KarKw,9337
 ml_tools/PSO_optimization.py,sha256=fVHeemqilBS0zrGV25E5yKwDlGdd2ZKa18d8CZ6Q6Fk,22961
 ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
@@ -35,7 +35,7 @@ ml_tools/optimization_tools.py,sha256=P074YCuZzkqkONnAsM-Zb9DTX_i8cRkkJLpwAWz6CR
 ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
 ml_tools/serde.py,sha256=ll2mVC0sO2jIEdG3K6xMcgEN13N4YSb8VjviGvw_ers,4949
 ml_tools/utilities.py,sha256=OcAyV1tEcYAfOWlGjRgopsjDLxU3DcI5EynzvWV4q3A,15754
-dragon_ml_toolbox-12.
-dragon_ml_toolbox-12.
-dragon_ml_toolbox-12.
-dragon_ml_toolbox-12.
+dragon_ml_toolbox-12.13.0.dist-info/METADATA,sha256=p3-oOSqq1hhJj13KjIXeFnwBu3UTfBJu5mTDL9MCpdU,6167
+dragon_ml_toolbox-12.13.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-12.13.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-12.13.0.dist-info/RECORD,,
ml_tools/ML_callbacks.py
CHANGED
@@ -113,18 +113,19 @@ class TqdmProgressBar(Callback):
 class EarlyStopping(Callback):
     """
     Stop training when a monitored metric has stopped improving.
-
-    Args:
-        monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
-        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
-        patience (int): Number of epochs with no improvement after which training will be stopped.
-        mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
-            monitored has stopped decreasing; in 'max' mode it will stop when the quantity
-            monitored has stopped increasing; in 'auto' mode, the direction is automatically
-            inferred from the name of the monitored quantity.
-        verbose (int): Verbosity mode.
     """
     def __init__(self, monitor: str=PyTorchLogKeys.VAL_LOSS, min_delta: float=0.0, patience: int=5, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
+        """
+        Args:
+            monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
+            min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
+            patience (int): Number of epochs with no improvement after which training will be stopped.
+            mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
+                monitored has stopped decreasing; in 'max' mode it will stop when the quantity
+                monitored has stopped increasing; in 'auto' mode, the direction is automatically
+                inferred from the name of the monitored quantity.
+            verbose (int): Verbosity mode.
+        """
         super().__init__()
         self.monitor = monitor
         self.patience = patience

@@ -188,22 +189,23 @@ class EarlyStopping(Callback):
 
 class ModelCheckpoint(Callback):
     """
-    Saves the model to a directory with automated filename generation and rotation.
-
-    - If `save_best_only` is True, it saves the single best model, deleting the
-      previous best.
-    - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
-      deleting the oldest ones automatically.
-
-    Args:
-        save_dir (str): Directory where checkpoint files will be saved.
-        monitor (str): Metric to monitor for `save_best_only=True`.
-        save_best_only (bool): If true, save only the best model.
-        mode (str): One of {'auto', 'min', 'max'}.
-        verbose (int): Verbosity mode.
+    Saves the model weights to a directory with automated filename generation and rotation.
     """
     def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
                  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
+        """
+        - If `save_best_only` is True, it saves the single best model, deleting the previous best.
+        - If `save_best_only` is False, it keeps the 3 most recent checkpoints, deleting the oldest ones automatically.
+
+        Args:
+            save_dir (str): Directory where checkpoint files will be saved.
+            checkpoint_name (str| None): If None, the filename will include the epoch and score.
+            monitor (str): Metric to monitor for `save_best_only=True`.
+            save_best_only (bool): If true, save only the best model.
+            mode (str): One of {'auto', 'min', 'max'}.
+            verbose (int): Verbosity mode.
+        """
+
         super().__init__()
         self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
         if not self.save_dir.is_dir():

@@ -306,17 +308,16 @@ class ModelCheckpoint(Callback):
 class LRScheduler(Callback):
     """
     Callback to manage a PyTorch learning rate scheduler.
-
-    This callback automatically calls the scheduler's `step()` method at the
-    end of each epoch. It also logs a message when the learning rate changes.
-
-    Args:
-        scheduler: An initialized PyTorch learning rate scheduler.
-        monitor (str, optional): The metric to monitor for schedulers that
-            require it, like `ReduceLROnPlateau`.
-            Should match a key in the logs (e.g., 'val_loss').
     """
     def __init__(self, scheduler, monitor: Optional[str] = None):
+        """
+        This callback automatically calls the scheduler's `step()` method at the
+        end of each epoch. It also logs a message when the learning rate changes.
+
+        Args:
+            scheduler: An initialized PyTorch learning rate scheduler.
+            monitor (str, optional): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
+        """
         super().__init__()
         self.scheduler = scheduler
         self.monitor = monitor
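The constructor signatures above are the public interface of these callbacks; only the docstring placement changed. As a rough usage sketch (the classes and arguments are taken from the diff, but how the callback list is handed to a training loop is not shown here and is assumed):

import torch
from torch import nn
from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint, LRScheduler

model = nn.Linear(10, 1)                                    # stand-in model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

callbacks = [
    # Stop after 10 epochs without at least 1e-3 improvement in val_loss
    EarlyStopping(monitor="val_loss", min_delta=1e-3, patience=10, mode="min", verbose=1),
    # checkpoint_name=None lets the filename include the epoch and score
    ModelCheckpoint(save_dir="checkpoints/", checkpoint_name=None,
                    monitor="val_loss", save_best_only=True, mode="min"),
    # Steps the wrapped scheduler each epoch; 'monitor' is needed for ReduceLROnPlateau
    LRScheduler(torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min"),
                monitor="val_loss"),
]
# `callbacks` would then be passed to the trainer (wiring not part of this diff).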
ml_tools/ML_datasetmaster.py
CHANGED
@@ -81,8 +81,7 @@ class _PytorchDataset(Dataset):
             _LOGGER.error(f"Dataset {self.__class__} has not been initialized with any target names.")
 
 
-# --- Abstract Base Class
-# --- Abstract Base Class (Corrected) ---
+# --- Abstract Base Class ---
 class _BaseDatasetMaker(ABC):
     """
     Abstract base class for dataset makers. Contains shared logic for

@@ -150,6 +149,14 @@ class _BaseDatasetMaker(ABC):
     @property
     def target_names(self) -> list[str]:
         return self._target_names
+
+    @property
+    def number_of_features(self) -> int:
+        return len(self._feature_names)
+
+    @property
+    def number_of_targets(self) -> int:
+        return len(self._target_names)
 
     @property
     def id(self) -> Optional[str]:

@@ -180,14 +187,14 @@
                              filename=DatasetKeys.TARGET_NAMES,
                              verbose=verbose)
 
-    def save_scaler(self,
+    def save_scaler(self, directory: Union[str, Path], verbose: bool=True) -> None:
         """
         Saves the fitted PytorchScaler's state to a .pth file.
 
         The filename is automatically generated based on the dataset id.
 
         Args:
-
+            directory (str | Path): The directory where the scaler will be saved.
         """
         if not self.scaler:
             _LOGGER.error("No scaler was fitted or provided.")

@@ -195,7 +202,7 @@
         if not self.id:
             _LOGGER.error("Must set the dataset `id` before saving scaler.")
             raise ValueError()
-        save_path = make_fullpath(
+        save_path = make_fullpath(directory, make=True, enforce="directory")
         sanitized_id = sanitize_filename(self.id)
         filename = f"{DatasetKeys.SCALER_PREFIX}{sanitized_id}.pth"
         filepath = save_path / filename

@@ -203,6 +210,15 @@
         if verbose:
             _LOGGER.info(f"Scaler for dataset '{self.id}' saved as '{filepath.name}'.")
 
+    def save_artifacts(self, directory: Union[str, Path], verbose: bool=True) -> None:
+        """
+        Convenience method to save feature names, target names, and the scaler (if a scaler was fitted)
+        """
+        self.save_feature_names(directory=directory, verbose=verbose)
+        self.save_target_names(directory=directory, verbose=verbose)
+        if self.scaler is not None:
+            self.save_scaler(directory=directory, verbose=verbose)
+
 
 # Single target dataset
 class DatasetMaker(_BaseDatasetMaker):
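The new `save_artifacts` wrapper collapses the three export calls into one, and the new properties expose the feature/target counts. A sketch under the assumption that `maker` is an already-built `DatasetMaker` instance (its constructor, and how the dataset `id` is set, are outside this diff):

# `maker` is assumed to exist; only the members shown in the diff are used.
print(maker.number_of_features)    # new property: len(feature_names)
print(maker.number_of_targets)     # new property: len(target_names)

# Previously three separate calls:
#   maker.save_feature_names("artifacts/")
#   maker.save_target_names("artifacts/")
#   maker.save_scaler("artifacts/")
# Now one call saves all of them (the scaler only if one was fitted).
# Note: save_scaler() still requires the dataset `id` to be set first.
maker.save_artifacts("artifacts/", verbose=True)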
ml_tools/ML_evaluation.py
CHANGED
@@ -18,7 +18,7 @@ from sklearn.metrics import (
 import torch
 import shap
 from pathlib import Path
-from typing import Union, Optional, List
+from typing import Union, Optional, List, Literal
 
 from .path_manager import make_fullpath
 from ._logger import _LOGGER

@@ -249,13 +249,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     plt.savefig(hist_path)
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
-
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
                       instances_to_explain: Union[torch.Tensor,np.ndarray],
                       feature_names: Optional[list[str]],
-                      save_dir: Union[str, Path]
+                      save_dir: Union[str, Path],
+                      device: torch.device = torch.device('cpu'),
+                      explainer_type: Literal['deep', 'kernel'] = 'deep'):
     """
     Calculates SHAP values and saves summary plots and data.
 

@@ -265,48 +267,85 @@ def shap_summary_plot(model,
         instances_to_explain (torch.Tensor): The specific data instances to explain.
         feature_names (list of str | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+              PyTorch models.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
+              slow and memory-intensive.
     """
-    # everything to numpy
-    if isinstance(background_data, np.ndarray):
-        background_data_np = background_data
-    else:
-        background_data_np = background_data.numpy()
-
-    if isinstance(instances_to_explain, np.ndarray):
-        instances_to_explain_np = instances_to_explain
-    else:
-        instances_to_explain_np = instances_to_explain.numpy()
 
-
-    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-        _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-        return
-
-    print("\n--- SHAP Value Explanation ---")
+    print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
 
     model.eval()
-    model.cpu()
-
-    # 1. Summarize the background data.
-    # Summarize the background data using k-means. 10-50 clusters is a good starting point.
-    background_summary = shap.kmeans(background_data_np, 30)
-
-    # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
-    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
-        # Convert numpy data to torch tensor
-        x_torch = torch.from_numpy(x_np).float()
-        with torch.no_grad():
-            # Get model output
-            output = model(x_torch)
-        # Return as numpy array
-        return output.cpu().numpy().flatten()
-
-    # 3. Create the KernelExplainer
-    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+    # model.cpu() # Run explanations on CPU
 
-
-
+    shap_values = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer (Preferred) ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        explainer = shap.DeepExplainer(model, background_data)
+        # print("Calculating SHAP values with DeepExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        _LOGGER.warning(
+            "Using KernelExplainer. This is memory-intensive and slow. "
+            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+        )
+
+        # Ensure data is np.ndarray
+        if isinstance(background_data, torch.Tensor):
+            background_data_np = background_data.cpu().numpy()
+        else:
+            background_data_np = background_data
+
+        if isinstance(instances_to_explain, torch.Tensor):
+            instances_to_explain_np = instances_to_explain.cpu().numpy()
+        else:
+            instances_to_explain_np = instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        # Summarize background data
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            # Return as numpy array
+            return output.cpu().numpy()
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        # print("Calculating SHAP values with KernelExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
 
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError()
+
+    # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 

@@ -326,8 +365,9 @@ def shap_summary_plot(model,
     shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
-
-
+    if plt.gcf().axes and len(plt.gcf().axes) > 1:
+        cb = plt.gcf().axes[-1]
+        cb.set_ylabel("", size=1)
     plt.title("SHAP Feature Importance")
     plt.tight_layout()
     plt.savefig(dot_path)

@@ -337,8 +377,14 @@ def shap_summary_plot(model,
     # Save Summary Data to CSV
     shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
     summary_path = save_dir_path / shap_summary_filename
-
-
+
+    # Handle multi-class (list of arrays) vs. regression (single array)
+    if isinstance(shap_values, list):
+        mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
+    else:
+        mean_abs_shap = np.abs(shap_values).mean(axis=0)
+
+    mean_abs_shap = mean_abs_shap.flatten()
 
     if feature_names is None:
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]

@@ -351,7 +397,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
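Based only on the new signature above, a minimal call with placeholder data might look like this (the import path follows the RECORD listing; model and tensors are toy stand-ins, not part of the package):

import torch
from torch import nn
from ml_tools.ML_evaluation import shap_summary_plot

model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1))  # toy regressor
background = torch.randn(100, 5)   # background sample for the explainer
instances = torch.randn(20, 5)     # rows to explain

# Default path: DeepExplainer on the requested device
shap_summary_plot(model,
                  background_data=background,
                  instances_to_explain=instances,
                  feature_names=[f"f{i}" for i in range(5)],
                  save_dir="shap_output/",
                  device=torch.device("cpu"),
                  explainer_type="deep")
# explainer_type="kernel" selects the model-agnostic fallback, which the
# function itself warns is much slower and more memory-intensive.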
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -19,11 +19,12 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List
+from typing import Union, List, Literal
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from .keys import SHAPKeys
 
 
 __all__ = [

@@ -231,10 +232,12 @@ def multi_target_shap_summary_plot(
     instances_to_explain: Union[torch.Tensor, np.ndarray],
     feature_names: List[str],
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    device: torch.device = torch.device('cpu'),
+    explainer_type: Literal['deep', 'kernel'] = 'deep'
 ):
     """
-    Calculates SHAP values for a multi-target model and saves summary plots for each target.
+    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
 
     Args:
         model (torch.nn.Module): The trained PyTorch model.

@@ -243,40 +246,91 @@
         feature_names (List[str]): Names of the features for plot labeling.
         target_names (List[str]): Names of the output targets.
         save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
-
-    background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
-    instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
-
-    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-        _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-        return
-
-    _LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
+    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
     model.eval()
-    model.cpu()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # model.cpu()
+
+    shap_values_list = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer (Preferred) ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        explainer = shap.DeepExplainer(model, background_data)
+        print("Calculating SHAP values with DeepExplainer...")
+        # DeepExplainer returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        _LOGGER.warning(
+            "Using KernelExplainer. This is memory-intensive and slow. "
+            "Consider reducing 'n_samples' if the process terminates."
+        )
+
+        # Convert all data to numpy
+        background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
+        instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            return output.cpu().numpy() # Return full multi-output array
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        print("Calculating SHAP values with KernelExplainer...")
+        # KernelExplainer also returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
+
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError("Invalid explainer_type")
+
+    # --- 3. Plotting and Saving (Common Logic) ---
+
+    if shap_values_list is None or instances_to_explain_np is None:
+        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
+        return
+
+    # Ensure number of SHAP value arrays matches number of target names
+    if len(shap_values_list) != len(target_names):
+        _LOGGER.error(
+            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
+            f"outputs, but {len(target_names)} target_names were provided."
+        )
+        return
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
-    #
+    # Iterate through each target's SHAP values and generate plots.
     for i, target_name in enumerate(target_names):
         print(f" -> Generating SHAP plots for target: '{target_name}'")
         shap_values_for_target = shap_values_list[i]

@@ -293,11 +347,28 @@ def multi_target_shap_summary_plot(
         # Save Dot Plot for the target
         shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
         plt.title(f"SHAP Feature Importance for '{target_name}'")
+        if plt.gcf().axes and len(plt.gcf().axes) > 1:
+            cb = plt.gcf().axes[-1]
+            cb.set_ylabel("", size=1)
         plt.tight_layout()
         dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
         plt.savefig(dot_path)
         plt.close()
-
+
+        # --- Save Summary Data to CSV for this target ---
+        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
+        summary_path = save_dir_path / shap_summary_filename
+
+        # For a specific target, shap_values_for_target is just a 2D array
+        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(summary_path, index=False)
+
     plt.ion()
     _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
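The multi-target variant takes the same two new arguments plus `target_names`, and now also writes one SHAP summary CSV per target. A sketch with a toy three-output model (placeholder data; import path per the RECORD listing):

import torch
from torch import nn
from ml_tools.ML_evaluation_multi import multi_target_shap_summary_plot

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 3))  # toy model, 3 targets

multi_target_shap_summary_plot(
    model,
    background_data=torch.randn(200, 8),
    instances_to_explain=torch.randn(30, 8),
    feature_names=[f"feature_{i}" for i in range(8)],
    target_names=["t1", "t2", "t3"],       # must match the number of model outputs
    save_dir="shap_multi/",
    device=torch.device("cpu"),
    explainer_type="deep",                  # or "kernel" for the slow, model-agnostic fallback
)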
ml_tools/ML_models.py
CHANGED
@@ -304,7 +304,7 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
     def __init__(self, *,
                  in_features: int,
                  out_targets: int,
-
+                 categorical_index_map: Dict[int, int],
                  embedding_dim: int = 32,
                  num_heads: int = 8,
                  num_layers: int = 6,

@@ -313,7 +313,7 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
         Args:
             in_features (int): The total number of columns in the input data (features).
             out_targets (int): Number of output targets (1 for regression).
-
+            categorical_index_map (Dict[int, int]): Maps categorical column index to its cardinality (number of unique categories).
             embedding_dim (int): The dimension for all feature embeddings. Must be divisible by num_heads.
             num_heads (int): The number of heads in the multi-head attention mechanism.
             num_layers (int): The number of sub-encoder-layers in the transformer encoder.

@@ -340,20 +340,20 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
         super().__init__()
 
         # --- Validation ---
-        if
-        _LOGGER.error(f"A categorical index ({max(
+        if categorical_index_map and max(categorical_index_map.keys()) >= in_features:
+            _LOGGER.error(f"A categorical index ({max(categorical_index_map.keys())}) is out of bounds for the provided input features ({in_features}).")
             raise ValueError()
 
         # --- Derive numerical indices ---
         all_indices = set(range(in_features))
-        categorical_indices_set = set(
+        categorical_indices_set = set(categorical_index_map.keys())
         numerical_indices = sorted(list(all_indices - categorical_indices_set))
 
         # --- Save configuration ---
         self.in_features = in_features
         self.out_targets = out_targets
         self.numerical_indices = numerical_indices
-        self.categorical_map =
+        self.categorical_map = categorical_index_map
         self.embedding_dim = embedding_dim
         self.num_heads = num_heads
         self.num_layers = num_layers

@@ -362,7 +362,7 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
         # --- 1. Feature Tokenizer ---
         self.tokenizer = _FeatureTokenizer(
             numerical_indices=numerical_indices,
-            categorical_map=
+            categorical_map=categorical_index_map,
             embedding_dim=embedding_dim
         )
 
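With the `categorical_index_map` argument, construction would look roughly like this (toy sizes; the forward-pass contract is not shown in this diff, so the call is left commented out):

import torch
from ml_tools.ML_models import TabularTransformer

# Columns 0 and 3 are categorical; each value is that column's cardinality.
cat_map = {0: 4, 3: 12}

model = TabularTransformer(
    in_features=10,                  # total number of input columns
    out_targets=1,                   # 1 for regression
    categorical_index_map=cat_map,
    embedding_dim=32,                # must be divisible by num_heads
    num_heads=8,
    num_layers=6,
)

x = torch.randn(16, 10)
x[:, 0] = torch.randint(0, 4, (16,)).float()    # integer category codes for column 0
x[:, 3] = torch.randint(0, 12, (16,)).float()   # integer category codes for column 3
# y = model(x)   # forward() is not part of this diff; shown here only as intent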
ml_tools/ML_trainer.py
CHANGED
@@ -340,9 +340,10 @@ class MLTrainer:
     def explain(self,
                 save_dir: Union[str,Path],
                 explain_dataset: Optional[Dataset] = None,
-                n_samples: int =
+                n_samples: int = 300,
                 feature_names: Optional[List[str]] = None,
-                target_names: Optional[List[str]] = None
+                target_names: Optional[List[str]] = None,
+                explainer_type: Literal['deep', 'kernel'] = 'deep'):
         """
         Explains model predictions using SHAP and saves all artifacts.
 

@@ -359,6 +360,9 @@
             feature_names (list[str] | None): Feature names.
             target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
+            explainer_type (Literal['deep', 'kernel']): The explainer to use.
+                - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
+                - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
         """
         # Internal helper to create a dataloader and get a random sample
         def _get_random_sample(dataset: Dataset, num_samples: int):

@@ -410,6 +414,9 @@
         else:
             _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
             raise ValueError()
+
+        # move model to device
+        self.model.to(self.device)
 
         # 3. Call the plotting function
         if self.kind in ["regression", "classification"]:

@@ -418,7 +425,9 @@
                 background_data=background_data,
                 instances_to_explain=instances_to_explain,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                explainer_type=explainer_type,
+                device=self.device
             )
         elif self.kind in ["multi_target_regression", "multi_label_classification"]:
             # try to get target names

@@ -442,7 +451,9 @@
                 instances_to_explain=instances_to_explain,
                 feature_names=feature_names, # type: ignore
                 target_names=target_names, # type: ignore
-                save_dir=save_dir
+                save_dir=save_dir,
+                explainer_type=explainer_type,
+                device=self.device
             )
 
     def _attention_helper(self, dataloader: DataLoader):
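At the trainer level the new options are simply forwarded to the SHAP plotting functions. A sketch assuming `trainer` is an already-fitted `MLTrainer` instance (its construction and training loop are not part of this diff):

# Only arguments visible in the diff are used; `trainer` is assumed to exist.
trainer.explain(
    save_dir="shap_artifacts/",
    explain_dataset=None,        # optional Dataset to sample from (behaviour when None is not shown here)
    n_samples=300,               # new default shown in the signature
    feature_names=None,          # taken from the dataset if it exposes a `feature_names` attribute
    explainer_type="deep",       # 'kernel' is available but far slower; the docstring advises n_samples < 100
)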
{dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/WHEEL
RENAMED
File without changes

{dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/licenses/LICENSE
RENAMED
File without changes

{dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md
RENAMED
File without changes

{dragon_ml_toolbox-12.11.0.dist-info → dragon_ml_toolbox-12.13.0.dist-info}/top_level.txt
RENAMED
File without changes