dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
Files changed (44)
  1. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +309 -0
  9. ml_tools/ML_datasetmaster.py +220 -260
  10. ml_tools/ML_evaluation.py +317 -81
  11. ml_tools/ML_evaluation_multi.py +127 -36
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1247 -338
  22. ml_tools/ML_utilities.py +51 -2
  23. ml_tools/ML_vision_datasetmaster.py +262 -118
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +15 -1
  27. ml_tools/ML_vision_transformers.py +233 -7
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -1
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +54 -11
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
@@ -19,13 +19,14 @@ from sklearn.metrics import (
      jaccard_score
  )
  from pathlib import Path
- from typing import Union, List, Literal
+ from typing import Union, List, Literal, Optional
  import warnings

  from .path_manager import make_fullpath, sanitize_filename
  from ._logger import _LOGGER
  from ._script_info import _script_info
- from .keys import SHAPKeys
+ from ._keys import SHAPKeys
+ from .ML_configuration import RegressionMetricsFormat, MultiClassificationMetricsFormat


  __all__ = [
@@ -34,12 +35,15 @@ __all__ = [
      "multi_target_shap_summary_plot",
  ]

+ DPI_value = 250
+

  def multi_target_regression_metrics(
      y_true: np.ndarray,
      y_pred: np.ndarray,
      target_names: List[str],
-     save_dir: Union[str, Path]
+     save_dir: Union[str, Path],
+     config: Optional[RegressionMetricsFormat] = None
  ):
      """
      Calculates and saves regression metrics for each target individually.
@@ -53,6 +57,7 @@ def multi_target_regression_metrics(
          y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
          target_names (List[str]): A list of names for the target variables.
          save_dir (str | Path): Directory to save plots and the report.
+         config (RegressionMetricsFormat, optional): Formatting configuration object.
      """
      if y_true.ndim != 2 or y_pred.ndim != 2:
          _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -66,8 +71,17 @@ def multi_target_regression_metrics(

      save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
      metrics_summary = []
+
+     # --- Parse Config or use defaults ---
+     if config is None:
+         # Create a default config if one wasn't provided
+         config = RegressionMetricsFormat()
+
+     # --- Set Matplotlib font size ---
+     original_rc_params = plt.rcParams.copy()
+     plt.rcParams.update({'font.size': config.font_size})

-     _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+     _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

      for i, name in enumerate(target_names):
          print(f" -> Evaluating target: '{name}'")
@@ -90,9 +104,13 @@ def multi_target_regression_metrics(

          # --- Save Residual Plot ---
          residuals = true_i - pred_i
-         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
-         ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
-         ax_res.axhline(0, color='red', linestyle='--')
+         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+         ax_res.scatter(pred_i, residuals,
+                        alpha=config.scatter_alpha,
+                        edgecolors='k',
+                        s=50,
+                        color=config.scatter_color) # Use config color
+         ax_res.axhline(0, color=config.residual_line_color, linestyle='--') # Use config color
          ax_res.set_xlabel("Predicted Values")
          ax_res.set_ylabel("Residuals (True - Predicted)")
          ax_res.set_title(f"Residual Plot for '{name}'")
@@ -103,9 +121,16 @@ def multi_target_regression_metrics(
          plt.close(fig_res)

          # --- Save True vs. Predicted Plot ---
-         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
-         ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
-         ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+         ax_tvp.scatter(true_i, pred_i,
+                        alpha=config.scatter_alpha,
+                        edgecolors='k',
+                        s=50,
+                        color=config.scatter_color) # Use config color
+         ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                     linestyle='--',
+                     lw=2,
+                     color=config.ideal_line_color) # Use config color
          ax_tvp.set_xlabel('True Values')
          ax_tvp.set_ylabel('Predicted Values')
          ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -120,14 +145,18 @@ def multi_target_regression_metrics(
      report_path = save_dir_path / "regression_report_multi.csv"
      summary_df.to_csv(report_path, index=False)
      _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+     # --- Restore RC params ---
+     plt.rcParams.update(original_rc_params)


  def multi_label_classification_metrics(
      y_true: np.ndarray,
+     y_pred: np.ndarray,
      y_prob: np.ndarray,
      target_names: List[str],
      save_dir: Union[str, Path],
-     threshold: float = 0.5
+     config: Optional[MultiClassificationMetricsFormat] = None # Add config object
  ):
      """
      Calculates and saves classification metrics for each label individually.
@@ -138,17 +167,17 @@ def multi_label_classification_metrics(

      Args:
          y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+         y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
          y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
          target_names (List[str]): A list of names for the labels.
          save_dir (str | Path): Directory to save plots and reports.
-         threshold (float): The probability threshold to convert probabilities into
-             binary predictions for metrics like the confusion matrix.
+         config (MultiClassificationMetricsFormat, optional): Formatting configuration object.
      """
-     if y_true.ndim != 2 or y_prob.ndim != 2:
-         _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+     if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+         _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
          raise ValueError()
-     if y_true.shape != y_prob.shape:
-         _LOGGER.error("Shapes of y_true and y_prob must match.")
+     if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+         _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
          raise ValueError()
      if y_true.shape[1] != len(target_names):
          _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -156,25 +185,33 @@ def multi_label_classification_metrics(

      save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

-     # Generate binary predictions from probabilities
-     y_pred = (y_prob >= threshold).astype(int)
+     # --- Parse Config or use defaults ---
+     if config is None:
+         # Create a default config if one wasn't provided
+         config = MultiClassificationMetricsFormat()
+
+     # y_pred is now passed in directly, no threshold needed.
+
+     # --- Save current RC params and update font size ---
+     original_rc_params = plt.rcParams.copy()
+     plt.rcParams.update({'font.size': config.font_size})

-     _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+     # _LOGGER.info("--- Multi-Label Classification Evaluation ---")

-     # --- Calculate and Save Overall Metrics ---
+     # --- Calculate and Save Overall Metrics (using y_pred) ---
      h_loss = hamming_loss(y_true, y_pred)
      j_score_micro = jaccard_score(y_true, y_pred, average='micro')
      j_score_macro = jaccard_score(y_true, y_pred, average='macro')

      overall_report = (
-         f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+         f"Overall Multi-Label Metrics:\n" # No threshold to report here
          f"--------------------------------------------------\n"
          f"Hamming Loss: {h_loss:.4f}\n"
          f"Jaccard Score (micro): {j_score_micro:.4f}\n"
          f"Jaccard Score (macro): {j_score_macro:.4f}\n"
          f"--------------------------------------------------\n"
      )
-     print(overall_report)
+     # print(overall_report)
      overall_report_path = save_dir_path / "classification_report_overall.txt"
      overall_report_path.write_text(overall_report)

@@ -182,28 +219,79 @@ def multi_label_classification_metrics(
      for i, name in enumerate(target_names):
          print(f" -> Evaluating label: '{name}'")
          true_i = y_true[:, i]
-         pred_i = y_pred[:, i]
-         prob_i = y_prob[:, i]
+         pred_i = y_pred[:, i] # Use passed-in y_pred
+         prob_i = y_prob[:, i] # Use passed-in y_prob
          sanitized_name = sanitize_filename(name)

-         # --- Save Classification Report for the label ---
+         # --- Save Classification Report for the label (uses y_pred) ---
          report_text = classification_report(true_i, pred_i)
          report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
          report_path.write_text(report_text) # type: ignore

-         # --- Save Confusion Matrix ---
-         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-         ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
+         # --- Save Confusion Matrix (uses y_pred) ---
+         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+         disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                         pred_i,
+                                                         cmap=config.cmap, # Use config cmap
+                                                         ax=ax_cm,
+                                                         normalize='true',
+                                                         labels=[0, 1],
+                                                         display_labels=["Negative", "Positive"])
+
+         disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+         # Turn off gridlines
+         ax_cm.grid(False)
+
+         # Manually update font size of cell texts
+         for text in ax_cm.texts:
+             text.set_fontsize(config.font_size) # Use config font_size
+
+         fig_cm.tight_layout()
+
          ax_cm.set_title(f"Confusion Matrix for '{name}'")
          cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
          plt.savefig(cm_path)
          plt.close(fig_cm)

-         # --- Save ROC Curve ---
-         fpr, tpr, _ = roc_curve(true_i, prob_i)
+         # --- Save ROC Curve (uses y_prob) ---
+         fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+         try:
+             # Calculate Youden's J statistic (tpr - fpr)
+             J = tpr - fpr
+             # Find the index of the best threshold
+             best_index = np.argmax(J)
+             optimal_threshold = thresholds[best_index]
+             best_tpr = tpr[best_index]
+             best_fpr = fpr[best_index]
+
+             # Define the filename
+             threshold_filename = f"best_threshold_{sanitized_name}.txt"
+             threshold_path = save_dir_path / threshold_filename
+
+             # The class name is the target_name for this label
+             class_name = name
+
+             # Create content for the file
+             file_content = (
+                 f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                 f"Class/Label: {class_name}\n"
+                 f"--------------------------------------------------\n"
+                 f"Threshold: {optimal_threshold:.6f}\n"
+                 f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                 f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+             )
+
+             threshold_path.write_text(file_content, encoding="utf-8")
+             _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+         except Exception as e:
+             _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
          auc = roc_auc_score(true_i, prob_i)
-         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line) # Use config color
          ax_roc.plot([0, 1], [0, 1], 'k--')
          ax_roc.set_title(f'ROC Curve for "{name}"')
          ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -212,17 +300,20 @@ def multi_label_classification_metrics(
          plt.savefig(roc_path)
          plt.close(fig_roc)

-         # --- Save Precision-Recall Curve ---
+         # --- Save Precision-Recall Curve (uses y_prob) ---
          precision, recall, _ = precision_recall_curve(true_i, prob_i)
          ap_score = average_precision_score(true_i, prob_i)
-         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=config.ROC_PR_line) # Use config color
          ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
          ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
          ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
          pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
          plt.savefig(pr_path)
          plt.close(fig_pr)
+
+     # restore RC params
+     plt.rcParams.update(original_rc_params)

      _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
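
For reference, a minimal usage sketch of the updated 16.0.0 signatures shown in the diff above. The module path ml_tools.ML_evaluation_multi is an assumption based on the changed-files list; the default-constructed format objects and the synthetic data are illustrative only.

# Sketch only: module path ml_tools.ML_evaluation_multi is assumed from the file list above.
import numpy as np
from ml_tools.ML_evaluation_multi import (
    multi_target_regression_metrics,
    multi_label_classification_metrics,
)
from ml_tools.ML_configuration import (
    RegressionMetricsFormat,
    MultiClassificationMetricsFormat,
)

rng = np.random.default_rng(0)

# Multi-target regression: formatting is now controlled by an optional config object.
y_true_reg = rng.normal(size=(100, 2))
y_pred_reg = y_true_reg + rng.normal(scale=0.1, size=(100, 2))
multi_target_regression_metrics(
    y_true=y_true_reg,
    y_pred=y_pred_reg,
    target_names=["target_a", "target_b"],
    save_dir="reports/regression",
    config=RegressionMetricsFormat(),  # defaults; omitting it behaves the same
)

# Multi-label classification: binary predictions are now passed in explicitly
# instead of being derived inside the function from a fixed 0.5 threshold.
y_prob = rng.uniform(size=(100, 3))
y_true_cls = (rng.uniform(size=(100, 3)) > 0.5).astype(int)
y_pred_cls = (y_prob >= 0.5).astype(int)
multi_label_classification_metrics(
    y_true=y_true_cls,
    y_pred=y_pred_cls,
    y_prob=y_prob,
    target_names=["label_a", "label_b", "label_c"],
    save_dir="reports/classification",
    config=MultiClassificationMetricsFormat(),
)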
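
The new per-label best_threshold_<label>.txt files are derived from Youden's J statistic (J = TPR - FPR) maximised over the ROC curve. A standalone sketch of that calculation with scikit-learn, independent of the toolbox:

# Illustrative toy data; mirrors the threshold-selection logic added in this release.
import numpy as np
from sklearn.metrics import roc_curve

y_true = np.array([0, 0, 1, 1, 0, 1, 1, 0])
y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.6, 0.3])

fpr, tpr, thresholds = roc_curve(y_true, y_prob)
J = tpr - fpr                      # Youden's J at each candidate threshold
best_index = int(np.argmax(J))     # index of the threshold that maximises J
print(f"Optimal threshold: {thresholds[best_index]:.4f} "
      f"(TPR={tpr[best_index]:.4f}, FPR={fpr[best_index]:.4f})")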