dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation_multi.py
@@ -19,13 +19,17 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from ._keys import SHAPKeys
+from .ML_configuration import (MultiTargetRegressionMetricsFormat,
+                               _BaseRegressionFormat,
+                               MultiLabelBinaryClassificationMetricsFormat,
+                               _BaseMultiLabelFormat)
 
 
 __all__ = [
@@ -34,12 +38,15 @@ __all__ = [
     "multi_target_shap_summary_plot",
 ]
 
+DPI_value = 250
+
 
 def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[MultiTargetRegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -53,6 +60,7 @@ def multi_target_regression_metrics(
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (object): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -66,8 +74,19 @@ def multi_target_regression_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
 
-    _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")
 
     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
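The hunk above keeps the new config argument fully optional: when it is omitted, a default _BaseRegressionFormat is instantiated, so existing call sites continue to work unchanged. Below is a minimal, hedged sketch of how a caller might use the new keyword; the constructor arguments of MultiTargetRegressionMetricsFormat are assumptions inferred from the attributes the function body reads (the class itself lives in the new ML_configuration.py, which is not shown in this excerpt).

    import numpy as np
    from ml_tools.ML_configuration import MultiTargetRegressionMetricsFormat
    from ml_tools.ML_evaluation_multi import multi_target_regression_metrics

    # Toy predictions for two targets, shape (n_samples, n_targets).
    y_true = np.random.rand(100, 2)
    y_pred = y_true + np.random.normal(scale=0.1, size=y_true.shape)

    # Hypothetical keyword arguments; check ML_configuration.py for the real fields.
    fmt = MultiTargetRegressionMetricsFormat(font_size=12, scatter_color="steelblue")

    multi_target_regression_metrics(
        y_true=y_true,
        y_pred=y_pred,
        target_names=["target_a", "target_b"],
        save_dir="reports/regression_multi",
        config=fmt,  # omit to fall back to the library defaults
    )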
@@ -90,9 +109,13 @@ def multi_target_regression_metrics(
 
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
-        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
-        ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
-        ax_res.axhline(0, color='red', linestyle='--')
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+        ax_res.scatter(pred_i, residuals,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color)  # Use config color
+        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')  # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -103,9 +126,16 @@ def multi_target_regression_metrics(
         plt.close(fig_res)
 
         # --- Save True vs. Predicted Plot ---
-        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
-        ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
-        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color)  # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=format_config.ideal_line_color)  # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -120,14 +150,18 @@ def multi_target_regression_metrics(
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
 
 
 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-    threshold: float = 0.5
+    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
 ):
     """
     Calculates and saves classification metrics for each label individually.
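Because the threshold parameter is gone, callers of multi_label_classification_metrics now binarize the probabilities themselves and pass the result as y_pred. A minimal migration sketch that reproduces the 0.5 cut-off the previous version applied internally (per-label thresholds, such as the Youden's J values the function now saves, can be substituted):

    import numpy as np

    # y_prob: (n_samples, n_labels) predicted probabilities from the model.
    y_prob = np.array([[0.9, 0.2],
                       [0.4, 0.7]])

    # Same rule the old code applied internally with threshold=0.5.
    y_pred = (y_prob >= 0.5).astype(int)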
@@ -138,17 +172,17 @@ def multi_label_classification_metrics(
 
     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-        threshold (float): The probability threshold to convert probabilities into
-            binary predictions for metrics like the confusion matrix.
+        config (object): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -156,25 +190,35 @@ def multi_label_classification_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
-    # Generate binary predictions from probabilities
-    y_pred = (y_prob >= threshold).astype(int)
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseMultiLabelFormat()
+    else:
+        format_config = config
+
+    # y_pred is now passed in directly, no threshold needed.
+
+    # --- Save current RC params and update font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
 
-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+        f"Overall Multi-Label Metrics:\n"  # No threshold to report here
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
         f"--------------------------------------------------\n"
     )
-    print(overall_report)
+    # print(overall_report)
     overall_report_path = save_dir_path / "classification_report_overall.txt"
     overall_report_path.write_text(overall_report)
 
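For reference, the overall metrics written to classification_report_overall.txt come straight from scikit-learn; the small self-contained example below (not code from the package) shows how they behave on a toy multi-label array:

    import numpy as np
    from sklearn.metrics import hamming_loss, jaccard_score

    y_true = np.array([[1, 0, 1],
                       [0, 1, 0]])
    y_pred = np.array([[1, 0, 0],
                       [0, 1, 0]])

    # Fraction of individually wrong label cells: 1 of 6 -> ~0.1667
    print(hamming_loss(y_true, y_pred))
    # Pooled TP / (TP + FP + FN) = 2 / 3 -> ~0.6667
    print(jaccard_score(y_true, y_pred, average="micro"))
    # Mean of per-label Jaccard scores (1.0, 1.0, 0.0) -> ~0.6667
    print(jaccard_score(y_true, y_pred, average="macro"))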
@@ -182,28 +226,79 @@ def multi_label_classification_metrics(
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i]  # Use passed-in y_pred
+        prob_i = y_prob[:, i]  # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)
 
-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text)  # type: ignore
 
-        # --- Save Confusion Matrix ---
-        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-        ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
+        # --- Save Confusion Matrix (uses y_pred) ---
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=format_config.cmap,  # Use config cmap
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"])
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(format_config.font_size)  # Use config font_size
+
+        fig_cm.tight_layout()
+
         ax_cm.set_title(f"Confusion Matrix for '{name}'")
         cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
         plt.savefig(cm_path)
         plt.close(fig_cm)
 
-        # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(true_i, prob_i)
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
         auc = roc_auc_score(true_i, prob_i)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)  # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
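The new try block above selects a per-label operating point with Youden's J statistic, J = TPR - FPR (equivalently sensitivity + specificity - 1), evaluated at every candidate threshold returned by roc_curve; the maximising threshold is written to best_threshold_<label>.txt. A standalone sketch of the same calculation on toy data:

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array([0, 0, 1, 1, 1, 0])
    y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.7, 0.2])

    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    J = tpr - fpr                 # Youden's J at each candidate threshold
    best_index = np.argmax(J)     # threshold that maximises TPR - FPR
    print(thresholds[best_index], tpr[best_index], fpr[best_index])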
@@ -212,17 +307,20 @@ def multi_label_classification_metrics(
         plt.savefig(roc_path)
         plt.close(fig_roc)
 
-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line)  # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
         pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
         plt.savefig(pr_path)
         plt.close(fig_pr)
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
     _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
 
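Both evaluation functions now bracket their plotting with a manual copy and restore of plt.rcParams, so a caller's global Matplotlib settings are left untouched. For comparison only (this is not the package's code), matplotlib.pyplot.rc_context gives the same guarantee as a context manager and also restores the settings if plotting raises:

    import matplotlib.pyplot as plt

    # The font size applies only inside the block and is reverted automatically.
    with plt.rc_context({'font.size': 12}):
        fig, ax = plt.subplots(figsize=(6, 6), dpi=250)
        ax.plot([0, 1], [0, 1])
        fig.savefig("example.svg")
        plt.close(fig)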
@@ -235,7 +333,7 @@ def multi_target_shap_summary_plot(
     target_names: List[str],
     save_dir: Union[str, Path],
     device: torch.device = torch.device('cpu'),
-    explainer_type: Literal['deep', 'kernel'] = 'deep'
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
 ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +347,7 @@ def multi_target_shap_summary_plot(
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +358,7 @@ def multi_target_shap_summary_plot(
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer (Preferred) ---
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -285,10 +383,9 @@ def multi_target_shap_summary_plot(
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "Using KernelExplainer. This is memory-intensive and slow. "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Convert all data to numpy
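The reworded warning refers to shap.KernelExplainer, which this version also makes the default explainer_type. A hedged sketch of the general pattern it relies on, with a toy model standing in for the real network (names and sizes are illustrative, not the package's internals):

    import numpy as np
    import shap
    import torch
    import torch.nn as nn

    # Toy two-output model in place of the real multi-target network.
    model = nn.Linear(4, 2)
    background = np.random.rand(50, 4).astype(np.float32)  # small background set keeps memory bounded
    instances = np.random.rand(10, 4).astype(np.float32)   # instances to explain

    def predict_numpy(x_np: np.ndarray) -> np.ndarray:
        # NumPy in, NumPy out: KernelExplainer treats the model as a black box.
        with torch.no_grad():
            return model(torch.from_numpy(x_np).float()).numpy()

    explainer = shap.KernelExplainer(predict_numpy, background)
    # Per-output SHAP values (a list or a stacked array, depending on the shap version).
    shap_values = explainer.shap_values(instances)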