dragon-ml-toolbox 5.3.0-py3-none-any.whl → 6.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 5.3.0
+ Version: 6.0.0
  Summary: A collection of tools for data science and machine learning projects.
  Author-email: Karl Loza <luigiloza@gmail.com>
  License-Expression: MIT
@@ -141,19 +141,22 @@ pip install "dragon-ml-toolbox[pytorch]"
  ```bash
  custom_logger
  data_exploration
- ensemble_learning
+ ensemble_evaluation
  ensemble_inference
+ ensemble_learning
  ETL_engineering
- ML_datasetmaster
- ML_models
  ML_callbacks
+ ML_datasetmaster
  ML_evaluation
- ML_trainer
  ML_inference
+ ML_models
+ ML_optimization
+ ML_trainer
+ optimization_tools
  path_manager
  PSO_optimization
- SQL
  RNN_forecast
+ SQL
  utilities
  ```
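The module list above shows the shape of the 6.0.0 reorganization: `ensemble_evaluation`, `ML_optimization`, and `optimization_tools` are new entries, and the list is now alphabetized. A minimal import sketch of the migration, assuming the modules remain importable from the `ml_tools` package as the RECORD section below indicates (the exact public names inside `ensemble_evaluation` are not shown in this diff):

```python
# Hedged migration sketch: module names come from the 6.0.0 list above;
# the functions inside each module are not visible in this diff.

# 5.3.0 -- evaluation helpers appear to have lived inside ensemble_learning
# (its RECORD entry shrinks from ~35 KB to ~22 KB in 6.0.0):
# from ml_tools import ensemble_learning

# 6.0.0 -- evaluation is split into its own module, and two
# optimization modules are added:
from ml_tools import ensemble_evaluation   # split out of ensemble_learning
from ml_tools import ensemble_learning     # training-focused remainder
from ml_tools import ML_optimization       # new in 6.0.0
from ml_tools import optimization_tools    # new in 6.0.0
```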
dist-info/RECORD CHANGED
@@ -1,16 +1,16 @@
- dragon_ml_toolbox-5.3.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
- dragon_ml_toolbox-5.3.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
+ dragon_ml_toolbox-6.0.0.dist-info/licenses/LICENSE,sha256=2uUFNy7D0TLgHim1K5s3DIJ4q_KvxEXVilnU20cWliY,1066
+ dragon_ml_toolbox-6.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=lY4_rJPnLnMu7YBQaY-_iz1JRDcLdQzNCyeLAF1glJY,1837
  ml_tools/ETL_engineering.py,sha256=4wwZXi9_U7xfCY70jGBaKniOeZ0m75ppxWpQBd_DmLc,39369
  ml_tools/GUI_tools.py,sha256=n4ZZ5kEjwK5rkOCFJE41HeLFfjhpJVLUSzk9Kd9Kr_0,45410
  ml_tools/MICE_imputation.py,sha256=oFHg-OytOzPYTzBR_wIRHhP71cMn3aupDeT59ABsXlQ,11576
- ml_tools/ML_callbacks.py,sha256=eOCSc-1_e5vC2dQN1ydHGKDLeJ3DqB-eLRLuXp2DpFM,13257
+ ml_tools/ML_callbacks.py,sha256=FEJ80TSEtY0-hdnOsAWeVApQt1mdzTdOntqtoWmMAzE,13310
  ml_tools/ML_datasetmaster.py,sha256=bbKCNA_b_uDIfxP9YIYKZm-VSfUSD15LvegFxpE9DIQ,34315
- ml_tools/ML_evaluation.py,sha256=4dVqe6JF1Ukmk1sAcY8E5EG1oB1_oy2HXE5OT-pZwCs,10273
+ ml_tools/ML_evaluation.py,sha256=A7AlEjy4ZOcdQMh9M3TJIDvCOXqzAbhgLxyhli8S_WY,13593
  ml_tools/ML_inference.py,sha256=Fh-X2UQn3AznWBjf-7iPSxwE-EzkGQm1VEIRUAkURmE,5336
  ml_tools/ML_models.py,sha256=SJhKHGAN2VTBqzcHUOpFWuVZ2Y7U1M4P_axG_LNYWcI,6460
  ml_tools/ML_optimization.py,sha256=zGKpWW4SL1-3iiHglDP-dkuADL73T0kxs3Dc-Lyishs,9671
- ml_tools/ML_trainer.py,sha256=t58Ka6ryaYm0Fi5xje-e-fkmz9DwDLIeJLbh04n_gDg,15034
- ml_tools/PSO_optimization.py,sha256=stH2Ux1sftQgX5EwLc85kHcoT4Rmz6zv7sH2yzf4Zrw,22710
+ ml_tools/ML_trainer.py,sha256=1q_CDXuMfndRsPuNofUn2mg2TlhG6MYuGqjWxTDgN9c,15112
+ ml_tools/PSO_optimization.py,sha256=9Y074d-B5h4Wvp9YPiy6KAeXM-Yv6Il3gWalKvOLVgo,22705
  ml_tools/RNN_forecast.py,sha256=2CyjBLSYYc3xLHxwLXUmP5Qv8AmV1OB_EndETNX1IBk,1956
  ml_tools/SQL.py,sha256=9zzS6AFEJM9aj6nE31hDe8S9TqLonk-J1amwZoiHNbk,10468
  ml_tools/VIF_factor.py,sha256=2nUMupfUoogf8o6ghoFZk_OwWhFXU0R3C9Gj0HOlI14,10415
@@ -19,14 +19,15 @@ ml_tools/_logger.py,sha256=TpgYguxO-CWYqqgLW0tqFjtwZ58PE_W2OCfWNGZr0n0,1175
  ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
  ml_tools/custom_logger.py,sha256=njM_0XPbQ1S-x5LeSQAaTo2if-XVOR_pQSGg4EDeiTU,4603
  ml_tools/data_exploration.py,sha256=P4f8OpRa7Q4i-11nkppxXw5Lx2lwlpn20GwWBbN_xbM,23901
- ml_tools/ensemble_inference.py,sha256=0SNX3YAz5bpvtwYmqEwqyWeIJP2Pb-v-bemENRSO7qg,9426
- ml_tools/ensemble_learning.py,sha256=Zi1oy6G2FWnTI5hBwjlexwF3JKALFS2FN6F8HAlVi_s,35391
+ ml_tools/ensemble_evaluation.py,sha256=ywpBCvmVImocZAcKv52mSdQKKHdLswozknoev39l4Yo,24682
+ ml_tools/ensemble_inference.py,sha256=rtU7eUaQne615n2g7IHZCJI-OvrBCcjxbTkEIvtCGFQ,9414
+ ml_tools/ensemble_learning.py,sha256=dAyFgSTyvxJWjc_enJ_8EUoWwiekBeoNyJNxVY-kcUU,21868
  ml_tools/handle_excel.py,sha256=J9iwIqMZemoxK49J5osSwp9Ge0h9YTKyYGbOm53hcno,13007
- ml_tools/keys.py,sha256=kK9UF-hek2VcPGFILCKl5geoN6flmMOu7IzhdEA6z5Y,1068
+ ml_tools/keys.py,sha256=HtPG8-MWh89C32A7eIlfuuA-DLwkxGkoDfwR2TGN9CQ,1074
  ml_tools/optimization_tools.py,sha256=MuT4OG7_r1QqLUti-yYix7QeCpglezD0oe9BDCq0QXk,5086
  ml_tools/path_manager.py,sha256=Z8e7w3MPqQaN8xmTnKuXZS6CIW59BFwwqGhGc00sdp4,13692
- ml_tools/utilities.py,sha256=T5xbxzBr14odUj7KncSeg-tJzqjmSDLOOmxEaGYLLi4,18447
- dragon_ml_toolbox-5.3.0.dist-info/METADATA,sha256=Lu_JBMfkCPssLk-a2v4b-oZu86cFK1OIB4HtHspVRIk,6643
- dragon_ml_toolbox-5.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-5.3.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-5.3.0.dist-info/RECORD,,
+ ml_tools/utilities.py,sha256=LqXXTovaHbA5AOKRk6Ru6DgAPAM0wPfYU70kUjYBryo,19231
+ dragon_ml_toolbox-6.0.0.dist-info/METADATA,sha256=v7JMG994i_tGqZJmN87pWxswxJEGQTsH2m2fQ_qz0C0,6698
+ dragon_ml_toolbox-6.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-6.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-6.0.0.dist-info/RECORD,,
ml_tools/ML_callbacks.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
  import torch
  from tqdm.auto import tqdm
  from .path_manager import make_fullpath
- from .keys import LogKeys
+ from .keys import PyTorchLogKeys
  from ._logger import _LOGGER
  from typing import Optional
  from ._script_info import _script_info
@@ -96,14 +96,14 @@ class TqdmProgressBar(Callback):
      def on_batch_end(self, batch, logs=None):
          self.batch_bar.update(1) # type: ignore
          if logs:
-             self.batch_bar.set_postfix(loss=f"{logs.get(LogKeys.BATCH_LOSS, 0):.4f}") # type: ignore
+             self.batch_bar.set_postfix(loss=f"{logs.get(PyTorchLogKeys.BATCH_LOSS, 0):.4f}") # type: ignore
 
      def on_epoch_end(self, epoch, logs=None):
          self.batch_bar.close() # type: ignore
          self.epoch_bar.update(1) # type: ignore
          if logs:
-             train_loss_str = f"{logs.get(LogKeys.TRAIN_LOSS, 0):.4f}"
-             val_loss_str = f"{logs.get(LogKeys.VAL_LOSS, 0):.4f}"
+             train_loss_str = f"{logs.get(PyTorchLogKeys.TRAIN_LOSS, 0):.4f}"
+             val_loss_str = f"{logs.get(PyTorchLogKeys.VAL_LOSS, 0):.4f}"
              self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}") # type: ignore
 
      def on_train_end(self, logs=None):
@@ -124,7 +124,7 @@ class EarlyStopping(Callback):
              inferred from the name of the monitored quantity.
          verbose (int): Verbosity mode.
      """
-     def __init__(self, monitor: str=LogKeys.VAL_LOSS, min_delta=0.0, patience=3, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
+     def __init__(self, monitor: str=PyTorchLogKeys.VAL_LOSS, min_delta: float=0.0, patience: int=5, mode: Literal['auto', 'min', 'max']='auto', verbose: int=1):
          super().__init__()
          self.monitor = monitor
          self.patience = patience
@@ -201,8 +201,8 @@ class ModelCheckpoint(Callback):
          mode (str): One of {'auto', 'min', 'max'}.
          verbose (int): Verbosity mode.
      """
-     def __init__(self, save_dir: Union[str,Path], monitor: str = LogKeys.VAL_LOSS,
-                  save_best_only: bool = False, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 1):
+     def __init__(self, save_dir: Union[str,Path], monitor: str = PyTorchLogKeys.VAL_LOSS,
+                  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
          super().__init__()
          self.save_dir = make_fullpath(save_dir, make=True, enforce="directory")
          if not self.save_dir.is_dir():
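Net effect for callers: `LogKeys` is now `PyTorchLogKeys`, `EarlyStopping` defaults to `patience=5` (was 3), and `ModelCheckpoint` defaults to `save_best_only=True, verbose=0` (was `False, 1`). A hedged usage sketch; the package-level import paths are assumed from the file layout in the RECORD section, since inside the library these are relative imports:

```python
# Sketch of the 6.0.0 callback API shown in this diff. Import paths are
# assumptions based on the ml_tools/ layout, not confirmed by the diff.
from ml_tools.ML_callbacks import EarlyStopping, ModelCheckpoint
from ml_tools.keys import PyTorchLogKeys  # renamed from LogKeys in 6.0.0

# patience default changed from 3 to 5; passing it explicitly keeps
# behavior stable across both versions.
early_stop = EarlyStopping(monitor=PyTorchLogKeys.VAL_LOSS, patience=5)

# save_best_only now defaults to True and verbose to 0, so a bare call
# keeps only the best checkpoint and stays quiet.
checkpoint = ModelCheckpoint(save_dir="checkpoints/")
```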
ml_tools/ML_evaluation.py CHANGED
@@ -1,6 +1,8 @@
  import numpy as np
  import pandas as pd
  import matplotlib.pyplot as plt
+ import seaborn as sns
+ from sklearn.calibration import CalibrationDisplay
  from sklearn.metrics import (
      classification_report,
      ConfusionMatrixDisplay,
@@ -9,7 +11,9 @@ from sklearn.metrics import (
      mean_squared_error,
      mean_absolute_error,
      r2_score,
-     median_absolute_error
+     median_absolute_error,
+     precision_recall_curve,
+     average_precision_score
  )
  import torch
  import shap
@@ -28,13 +32,13 @@ __all__ = [
  ]
 
 
- def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
+ def plot_losses(history: dict, save_dir: Union[str, Path]):
      """
      Plots training & validation loss curves from a history object.
 
      Args:
          history (dict): A dictionary containing 'train_loss' and 'val_loss'.
-         save_dir (str | Path | None): Directory to save the plot image.
+         save_dir (str | Path): Directory to save the plot image.
      """
      train_loss = history.get('train_loss', [])
      val_loss = history.get('val_loss', [])
@@ -62,86 +66,123 @@ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
      ax.grid(True)
      plt.tight_layout()
 
-     if save_dir:
-         save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-         save_path = save_dir_path / "loss_plot.svg"
-         plt.savefig(save_path)
-         _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
-     else:
-         plt.show()
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+     save_path = save_dir_path / "loss_plot.svg"
+     plt.savefig(save_path)
+     _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
+
      plt.close(fig)
 
 
- def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
-                            cmap: str = "Blues", save_dir: Optional[Union[str, Path]] = None):
+ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
+                            cmap: str = "Blues"):
      """
-     Displays and optionally saves classification metrics and plots.
+     Saves classification metrics and plots.
 
      Args:
          y_true (np.ndarray): Ground truth labels.
          y_pred (np.ndarray): Predicted labels.
          y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
          cmap (str): Colormap for the confusion matrix.
-         save_dir (str | Path | None): Directory to save plots. If None, plots are shown not saved.
+         save_dir (str | Path): Directory to save plots.
      """
      print("--- Classification Report ---")
-     report: str = classification_report(y_true, y_pred) # type: ignore
-     print(report)
+     # Generate report as both text and dictionary
+     report_text: str = classification_report(y_true, y_pred) # type: ignore
+     report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
+     print(report_text)
 
-     if save_dir:
-         save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-         # Save text report
-         report_path = save_dir_path / "classification_report.txt"
-         report_path.write_text(report, encoding="utf-8")
-         _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+     # Save text report
+     report_path = save_dir_path / "classification_report.txt"
+     report_path.write_text(report_text, encoding="utf-8")
+     _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")
+
+     # --- Save Classification Report Heatmap ---
+     try:
+         plt.figure(figsize=(8, 6), dpi=100)
+         sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T, annot=True, cmap='viridis', fmt='.2f')
+         plt.title("Classification Report")
+         plt.tight_layout()
+         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
+         plt.savefig(heatmap_path)
+         _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
+         plt.close()
+     except Exception as e:
+         _LOGGER.error(f"❌ Could not generate classification report heatmap: {e}")
 
-         # Save Confusion Matrix
-         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-         ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
-         ax_cm.set_title("Confusion Matrix")
-         cm_path = save_dir_path / "confusion_matrix.svg"
-         plt.savefig(cm_path)
-         _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
-         plt.close(fig_cm)
+     # Save Confusion Matrix
+     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
+     ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
+     ax_cm.set_title("Confusion Matrix")
+     cm_path = save_dir_path / "confusion_matrix.svg"
+     plt.savefig(cm_path)
+     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
+     plt.close(fig_cm)
 
-         # Save ROC Curve
-         if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-             fpr, tpr, _ = roc_curve(y_true, y_prob[:, 1])
-             auc = roc_auc_score(y_true, y_prob[:, 1])
-             fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-             ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
-             ax_roc.plot([0, 1], [0, 1], 'k--')
-             ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
-             ax_roc.set_xlabel('False Positive Rate')
-             ax_roc.set_ylabel('True Positive Rate')
-             ax_roc.legend(loc='lower right')
-             ax_roc.grid(True)
-             roc_path = save_dir_path / "roc_curve.svg"
-             plt.savefig(roc_path)
-             _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
-             plt.close(fig_roc)
-     else:
-         # Show plots if not saving
-         ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap)
-         plt.show()
-         if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-             fpr, tpr, _ = roc_curve(y_true, y_prob[:, 1])
-             auc = roc_auc_score(y_true, y_prob[:, 1])
-             plt.figure()
-             plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
-             plt.plot([0, 1], [0, 1], 'k--')
-             plt.title('ROC Curve')
-             plt.show()
+     # Plotting logic for ROC and PR Curves
+     if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
+         # Use probabilities of the positive class
+         y_score = y_prob[:, 1]
+
+         # --- Save ROC Curve ---
+         fpr, tpr, _ = roc_curve(y_true, y_score)
+         auc = roc_auc_score(y_true, y_score)
+         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
+         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+         ax_roc.plot([0, 1], [0, 1], 'k--')
+         ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
+         ax_roc.set_xlabel('False Positive Rate')
+         ax_roc.set_ylabel('True Positive Rate')
+         ax_roc.legend(loc='lower right')
+         ax_roc.grid(True)
+         roc_path = save_dir_path / "roc_curve.svg"
+         plt.savefig(roc_path)
+         _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
+         plt.close(fig_roc)
+
+         # --- Save Precision-Recall Curve ---
+         precision, recall, _ = precision_recall_curve(y_true, y_score)
+         ap_score = average_precision_score(y_true, y_score)
+         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
+         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+         ax_pr.set_title('Precision-Recall Curve')
+         ax_pr.set_xlabel('Recall')
+         ax_pr.set_ylabel('Precision')
+         ax_pr.legend(loc='lower left')
+         ax_pr.grid(True)
+         pr_path = save_dir_path / "pr_curve.svg"
+         plt.savefig(pr_path)
+         _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
+         plt.close(fig_pr)
+
+         # --- Save Calibration Plot ---
+         if y_prob.ndim > 1 and y_prob.shape[1] >= 2:
+             y_score = y_prob[:, 1] # Use probabilities of the positive class
+
+             fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=100)
+             CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_cal)
+
+             ax_cal.set_title('Calibration Plot (Reliability Curve)')
+             ax_cal.set_xlabel('Mean Predicted Probability')
+             ax_cal.set_ylabel('Fraction of Positives')
+             ax_cal.grid(True)
+             plt.tight_layout()
+
+             cal_path = save_dir_path / "calibration_plot.svg"
+             plt.savefig(cal_path)
+             _LOGGER.info(f"✅ Calibration plot saved as '{cal_path.name}'")
+             plt.close(fig_cal)
 
 
- def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
+ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
      """
-     Displays regression metrics and optionally saves plots and report.
+     Saves regression metrics and plots.
 
      Args:
          y_true (np.ndarray): Ground truth values.
          y_pred (np.ndarray): Predicted values.
-         save_dir (str | None): Directory to save plots and report.
+         save_dir (str | Path): Directory to save plots and report.
      """
      rmse = np.sqrt(mean_squared_error(y_true, y_pred))
      mae = mean_absolute_error(y_true, y_pred)
@@ -158,44 +199,56 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
      report_string = "\n".join(report_lines)
      print(report_string)
 
-     if save_dir:
-         save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-         # Save text report
-         report_path = save_dir_path / "regression_report.txt"
-         report_path.write_text(report_string)
-         _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+     # Save text report
+     report_path = save_dir_path / "regression_report.txt"
+     report_path.write_text(report_string)
+     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-         # Save residual plot
-         residuals = y_true - y_pred
-         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
-         ax_res.scatter(y_pred, residuals, alpha=0.6)
-         ax_res.axhline(0, color='red', linestyle='--')
-         ax_res.set_xlabel("Predicted Values")
-         ax_res.set_ylabel("Residuals")
-         ax_res.set_title("Residual Plot")
-         ax_res.grid(True)
-         plt.tight_layout()
-         res_path = save_dir_path / "residual_plot.svg"
-         plt.savefig(res_path)
-         _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
-         plt.close(fig_res)
+     # Save residual plot
+     residuals = y_true - y_pred
+     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+     ax_res.scatter(y_pred, residuals, alpha=0.6)
+     ax_res.axhline(0, color='red', linestyle='--')
+     ax_res.set_xlabel("Predicted Values")
+     ax_res.set_ylabel("Residuals")
+     ax_res.set_title("Residual Plot")
+     ax_res.grid(True)
+     plt.tight_layout()
+     res_path = save_dir_path / "residual_plot.svg"
+     plt.savefig(res_path)
+     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
+     plt.close(fig_res)
 
-         # Save true vs predicted plot
-         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
-         ax_tvp.scatter(y_true, y_pred, alpha=0.6)
-         ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
-         ax_tvp.set_xlabel('True Values')
-         ax_tvp.set_ylabel('Predictions')
-         ax_tvp.set_title('True vs. Predicted Values')
-         ax_tvp.grid(True)
-         plt.tight_layout()
-         tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
-         plt.savefig(tvp_path)
-         _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
-         plt.close(fig_tvp)
+     # Save true vs predicted plot
+     fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+     ax_tvp.scatter(y_true, y_pred, alpha=0.6)
+     ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+     ax_tvp.set_xlabel('True Values')
+     ax_tvp.set_ylabel('Predictions')
+     ax_tvp.set_title('True vs. Predicted Values')
+     ax_tvp.grid(True)
+     plt.tight_layout()
+     tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
+     plt.savefig(tvp_path)
+     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
+     plt.close(fig_tvp)
+
+     # Save Histogram of Residuals
+     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
+     sns.histplot(residuals, kde=True, ax=ax_hist)
+     ax_hist.set_xlabel("Residual Value")
+     ax_hist.set_ylabel("Frequency")
+     ax_hist.set_title("Distribution of Residuals")
+     ax_hist.grid(True)
+     plt.tight_layout()
+     hist_path = save_dir_path / "residuals_histogram.svg"
+     plt.savefig(hist_path)
+     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
+     plt.close(fig_hist)
 
 
- def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain: torch.Tensor,
+ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], instances_to_explain: Union[torch.Tensor,np.ndarray],
                        feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
      """
      Calculates SHAP values and saves summary plots and data.
@@ -207,24 +260,54 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
          feature_names (list of str | None): Names of the features for plot labeling.
          save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
      """
+     # everything to numpy
+     if isinstance(background_data, np.ndarray):
+         background_data_np = background_data
+     else:
+         background_data_np = background_data.numpy()
+
+     if isinstance(instances_to_explain, np.ndarray):
+         instances_to_explain_np = instances_to_explain
+     else:
+         instances_to_explain_np = instances_to_explain.numpy()
+
+     # --- Data Validation Step ---
+     if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+         _LOGGER.error("❌ Input data for SHAP contains NaN values. Aborting explanation.")
+         return
+
      print("\n--- SHAP Value Explanation ---")
-     print("Calculating SHAP values... ")
 
      model.eval()
      model.cpu()
 
-     explainer = shap.DeepExplainer(model, background_data)
-     shap_values = explainer.shap_values(instances_to_explain)
-
-     shap_values_for_plot = shap_values[1] if isinstance(shap_values, list) else shap_values
-     if isinstance(shap_values, list):
-         _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
+     # 1. Summarize the background data.
+     # Summarize the background data using k-means. 10-50 clusters is a good starting point.
+     background_summary = shap.kmeans(background_data_np, 30)
+
+     # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
+     def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+         # Convert numpy data to torch tensor
+         x_torch = torch.from_numpy(x_np).float()
+         with torch.no_grad():
+             # Get model output
+             output = model(x_torch)
+         # Return as numpy array
+         return output.cpu().numpy().flatten()
 
+     # 3. Create the KernelExplainer
+     explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+
+     print("Calculating SHAP values with KernelExplainer...")
+     shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+
      if save_dir:
          save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         plt.ioff()
+
          # Save Bar Plot
          bar_path = save_dir_path / "shap_bar_plot.svg"
-         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
+         shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
          plt.title("SHAP Feature Importance")
          plt.tight_layout()
          plt.savefig(bar_path)
@@ -233,7 +316,7 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
 
          # Save Dot Plot
          dot_path = save_dir_path / "shap_dot_plot.svg"
-         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot", show=False)
+         shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
          plt.title("SHAP Feature Importance")
          plt.tight_layout()
          plt.savefig(dot_path)
@@ -242,18 +325,25 @@ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain
 
          # Save Summary Data to CSV
          summary_path = save_dir_path / "shap_summary.csv"
-         mean_abs_shap = np.abs(shap_values_for_plot).mean(axis=0)
+         # Ensure the array is 1D before creating the DataFrame
+         mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
+
          if feature_names is None:
              feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+
          summary_df = pd.DataFrame({
              'feature': feature_names,
              'mean_abs_shap_value': mean_abs_shap
          }).sort_values('mean_abs_shap_value', ascending=False)
+
          summary_df.to_csv(summary_path, index=False)
+
         _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
+         plt.ion()
+
      else:
          _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
-         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
+         shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot")
 
 
  def info():
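The breaking change for callers is that `save_dir` is now required: `plot_losses` and `regression_metrics` take it as a mandatory argument, and `classification_metrics` moves it to the first position. The `shap_summary_plot` rewrite also swaps `shap.DeepExplainer` for a `shap.KernelExplainer` over a k-means background summary, and accepts NumPy arrays as well as tensors. A call-site sketch with placeholder data and paths, assuming package-level imports:

```python
import numpy as np
from ml_tools.ML_evaluation import (
    classification_metrics, regression_metrics, plot_losses,
)

# Placeholder arrays standing in for real model outputs.
y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])
y_prob = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.9, 0.1], [0.2, 0.8]])

# 5.3.0: classification_metrics(y_true, y_pred, y_prob, save_dir=None)
# 6.0.0: save_dir is first and mandatory; plots are saved, never shown.
classification_metrics("reports/", y_true, y_pred, y_prob)

# regression_metrics and plot_losses now also require the directory.
regression_metrics(y_true.astype(float), y_pred.astype(float), "reports/")
plot_losses({"train_loss": [1.0, 0.5], "val_loss": [1.1, 0.7]}, save_dir="reports/")
```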
ml_tools/ML_trainer.py CHANGED
@@ -8,16 +8,16 @@ import numpy as np
  from .ML_callbacks import Callback, History, TqdmProgressBar
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
  from ._script_info import _script_info
- from .keys import LogKeys
+ from .keys import PyTorchLogKeys
  from ._logger import _LOGGER
 
 
  __all__ = [
-     "MyTrainer"
+     "MLTrainer"
  ]
 
 
- class MyTrainer:
+ class MLTrainer:
      def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
                   kind: Literal["regression", "classification"],
                   criterion: nn.Module, optimizer: torch.optim.Optimizer,
@@ -95,14 +95,16 @@ class MyTrainer:
              batch_size=batch_size,
              shuffle=shuffle,
              num_workers=loader_workers,
-             pin_memory=(self.device.type == "cuda")
+             pin_memory=("cuda" in self.device.type),
+             drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
          )
+
          self.test_loader = DataLoader(
              dataset=self.test_dataset,
              batch_size=batch_size,
              shuffle=False,
              num_workers=loader_workers,
-             pin_memory=(self.device.type == "cuda")
+             pin_memory=("cuda" in self.device.type)
          )
 
      def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
@@ -159,8 +161,8 @@ class MyTrainer:
          for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
              # Create a log dictionary for the batch
              batch_logs = {
-                 LogKeys.BATCH_INDEX: batch_idx,
-                 LogKeys.BATCH_SIZE: features.size(0)
+                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                 PyTorchLogKeys.BATCH_SIZE: features.size(0)
              }
              self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
 
@@ -178,11 +180,11 @@ class MyTrainer:
              running_loss += batch_loss * features.size(0)
 
              # Add the batch loss to the logs and call the end-of-batch hook
-             batch_logs[LogKeys.BATCH_LOSS] = batch_loss
+             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
              self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
 
          # Return the average loss for the entire epoch
-         return {LogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
 
      def _validation_step(self):
          self.model.eval()
@@ -195,7 +197,7 @@ class MyTrainer:
                  output = output.view_as(target)
                  loss = self.criterion(output, target)
                  running_loss += loss.item() * features.size(0)
-         logs = {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
          return logs
 
      def _predict_for_eval(self, dataloader: DataLoader):
@@ -230,14 +232,14 @@ class MyTrainer:
 
              yield y_pred_batch, y_prob_batch, y_true_batch
 
-     def evaluate(self, save_dir: Optional[Union[str,Path]], data: Optional[Union[DataLoader, Dataset]] = None):
+     def evaluate(self, save_dir: Union[str,Path], data: Optional[Union[DataLoader, Dataset]] = None):
          """
          Evaluates the model on the given data.
 
          Args:
              data (DataLoader | Dataset | None ): The data to evaluate on.
                  Can be a DataLoader or a Dataset. If None, defaults to the trainer's internal test_dataset.
-             save_dir (str | Path | None): Directory to save all reports and plots. If None, metrics are shown but not saved.
+             save_dir (str | Path): Directory to save all reports and plots.
          """
          eval_loader = None
          if isinstance(data, DataLoader):
@@ -273,14 +275,14 @@ class MyTrainer:
          y_prob = np.concatenate(all_probs) if self.kind == "classification" else None
 
          if self.kind == "classification":
-             classification_metrics(y_true, y_pred, y_prob, save_dir=save_dir)
+             classification_metrics(save_dir, y_true, y_pred, y_prob)
          else:
-             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir=save_dir)
+             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
 
          print("\n--- Training History ---")
          plot_losses(self.history, save_dir=save_dir)
 
-     def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
+     def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 1000,
                  feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
          """
          Explains model predictions using SHAP and saves all artifacts.
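For downstream code, the headline change is the rename `MyTrainer` → `MLTrainer`; beyond that, `evaluate()` now requires `save_dir`, the training loader drops the last incomplete batch (`drop_last=True`), and `explain()` defaults to `n_samples=1000` (was 100). A self-contained sketch, under the assumption that the constructor arguments beyond those visible in this hunk keep their 5.3.0 defaults:

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset
from ml_tools.ML_trainer import MLTrainer  # renamed from MyTrainer in 6.0.0

# Toy regression data so the sketch is runnable end to end.
X = torch.randn(128, 4)
y = X.sum(dim=1, keepdim=True)
model = nn.Linear(4, 1)

trainer = MLTrainer(
    model=model,
    train_dataset=TensorDataset(X[:96], y[:96]),
    test_dataset=TensorDataset(X[96:], y[96:]),
    kind="regression",
    criterion=nn.MSELoss(),
    optimizer=torch.optim.Adam(model.parameters()),
)

# drop_last=True now applies to the training loader, so pick a batch
# size that does not leave a tiny final batch (96 / 16 = 6 full batches).
trainer.fit(epochs=5, batch_size=16)

# evaluate() now requires a save directory; reports and plots are
# written there instead of being displayed.
trainer.evaluate(save_dir="results/")
```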