dragon-ml-toolbox 14.8.0__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

Files changed (44)
  1. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +204 -11
  9. ml_tools/ML_datasetmaster.py +198 -280
  10. ml_tools/ML_evaluation.py +132 -41
  11. ml_tools/ML_evaluation_multi.py +96 -35
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1237 -354
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +73 -67
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +43 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.8.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
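Note on the keys.py → _keys.py rename: internal imports throughout the package now reference the private module (see the ML_evaluation.py hunks below). A hypothetical migration sketch for downstream code that imported the key constants directly — the exported names are assumed to be unchanged:

    # Old (14.8.0):
    # from ml_tools.keys import PyTorchLogKeys, SHAPKeys
    # New (16.0.0), assuming the same names are exported from the private module:
    from ml_tools._keys import PyTorchLogKeys, SHAPKeys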
ml_tools/ML_evaluation.py CHANGED
@@ -24,7 +24,8 @@ import warnings
  from .path_manager import make_fullpath, sanitize_filename
  from ._logger import _LOGGER
  from ._script_info import _script_info
- from .keys import SHAPKeys, PyTorchLogKeys
+ from ._keys import SHAPKeys, PyTorchLogKeys
+ from .ML_configuration import RegressionMetricsFormat, ClassificationMetricsFormat


  __all__ = [
@@ -41,6 +42,7 @@ DPI_value = 250
  def plot_losses(history: dict, save_dir: Union[str, Path]):
  """
  Plots training & validation loss curves from a history object.
+ Also plots the learning rate if available in the history.

  Args:
  history (dict): A dictionary containing 'train_loss' and 'val_loss'.
@@ -48,6 +50,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
  """
  train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
  val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+ lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])

  if not train_loss and not val_loss:
  _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
@@ -55,22 +58,44 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):

  fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)

+ # --- Plot Losses (Left Y-axis) ---
+ line_handles = [] # To store line objects for the legend
+
  # Plot training loss only if data for it exists
  if train_loss:
  epochs = range(1, len(train_loss) + 1)
- ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+ line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+ line_handles.append(line1)

  # Plot validation loss only if data for it exists
  if val_loss:
  epochs = range(1, len(val_loss) + 1)
- ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+ line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+ line_handles.append(line2)

  ax.set_title('Training and Validation Loss')
  ax.set_xlabel('Epochs')
- ax.set_ylabel('Loss')
- ax.legend()
- ax.grid(True)
- plt.tight_layout()
+ ax.set_ylabel('Loss', color='tab:blue')
+ ax.tick_params(axis='y', labelcolor='tab:blue')
+ ax.grid(True, linestyle='--')
+
+ # --- Plot Learning Rate (Right Y-axis) ---
+ if lr_history:
+ ax2 = ax.twinx() # Create a second y-axis
+ epochs = range(1, len(lr_history) + 1)
+ line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+ line_handles.append(line3)
+
+ ax2.set_ylabel('Learning Rate', color='g')
+ ax2.tick_params(axis='y', labelcolor='g')
+ # Use scientific notation if the LR is very small
+ ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+ # Combine legends from both axes
+ ax.legend(handles=line_handles, loc='best')
+
+ # ax.grid(True)
+ plt.tight_layout()

  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  save_path = save_dir_path / "loss_plot.svg"
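plot_losses now draws the learning-rate schedule on a secondary y-axis whenever the history dict contains it. A minimal usage sketch, assuming the history is keyed with the same PyTorchLogKeys attributes the function reads:

    from ml_tools.ML_evaluation import plot_losses
    from ml_tools._keys import PyTorchLogKeys

    # History keyed exactly as plot_losses reads it (TRAIN_LOSS, VAL_LOSS, LEARNING_RATE).
    history = {
        PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.44, 0.37],
        PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.55, 0.50],
        PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],  # optional; drawn on a twin axis
    }
    plot_losses(history, save_dir="results")  # writes results/loss_plot.svg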
@@ -84,34 +109,34 @@ def classification_metrics(save_dir: Union[str, Path],
  y_true: np.ndarray,
  y_pred: np.ndarray,
  y_prob: Optional[np.ndarray] = None,
- cmap: str = "Blues",
- class_map: Optional[dict[str,int]]=None,
- ROC_PR_line: str='darkorange',
- calibration_bins: int=15,
- font_size: int=16):
+ config: Optional[ClassificationMetricsFormat] = None):
  """
  Saves classification metrics and plots.

  Args:
  y_true (np.ndarray): Ground truth labels.
  y_pred (np.ndarray): Predicted labels.
- y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
- cmap (str): Colormap for the confusion matrix.
+ y_prob (np.ndarray): Predicted probabilities for ROC curve.
+ config (ClassificationMetricsFormat): Formatting configuration object.
  save_dir (str | Path): Directory to save plots.
- class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
  """
+ # --- Parse Config or use defaults ---
+ if config is None:
+ # Create a default config if one wasn't provided
+ config = ClassificationMetricsFormat()
+
  original_rc_params = plt.rcParams.copy()
- plt.rcParams.update({'font.size': font_size})
+ plt.rcParams.update({'font.size': config.font_size})

  # print("--- Classification Report ---")

  # --- Parse class_map ---
  map_labels = None
  map_display_labels = None
- if class_map:
+ if config.class_map:
  # Sort the map by its values (the indices) to ensure correct order
  try:
- sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+ sorted_items = sorted(config.class_map.items(), key=lambda item: item[1])
  map_labels = [item[1] for item in sorted_items]
  map_display_labels = [item[0] for item in sorted_items]
  except Exception as e:
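The per-argument styling knobs (cmap, class_map, ROC_PR_line, calibration_bins, font_size) are now bundled into a ClassificationMetricsFormat object from ml_tools.ML_configuration. A hedged usage sketch — the attribute names come from this diff, but constructing the object with keyword arguments is an assumption about the config class:

    from ml_tools.ML_configuration import ClassificationMetricsFormat
    from ml_tools.ML_evaluation import classification_metrics

    fmt = ClassificationMetricsFormat(   # keyword construction is assumed
        cmap="Blues",
        class_map={"negative": 0, "positive": 1},
        ROC_PR_line="darkorange",
        calibration_bins=15,
        font_size=16,
    )
    # y_true, y_pred, y_prob are NumPy arrays produced elsewhere.
    classification_metrics("results", y_true, y_pred, y_prob, config=fmt)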
@@ -136,7 +161,7 @@ def classification_metrics(save_dir: Union[str, Path],
  sns.set_theme(font_scale=1.2) # Scale seaborn font
  sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
  annot=True,
- cmap=cmap,
+ cmap=config.cmap,
  fmt='.2f',
  vmin=0.0,
  vmax=1.0)
@@ -158,7 +183,7 @@ def classification_metrics(save_dir: Union[str, Path],
  fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
  disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
  y_pred,
- cmap=cmap,
+ cmap=config.cmap,
  ax=ax_cm,
  normalize='true',
  labels=plot_labels,
@@ -171,7 +196,7 @@ def classification_metrics(save_dir: Union[str, Path],

  # Manually update font size of cell texts
  for text in ax_cm.texts:
- text.set_fontsize(font_size)
+ text.set_fontsize(config.font_size)

  fig_cm.tight_layout()

@@ -196,10 +221,10 @@ def classification_metrics(save_dir: Union[str, Path],
  class_indices_to_plot = [1]
  plot_titles = [""] # No extra title
  save_suffixes = [""] # No extra suffix
- _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")
+ _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")

  elif num_classes > 2:
- _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
+ _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
  # Multiclass case: Plot for every class (One-vs-Rest)
  class_indices_to_plot = list(range(num_classes))

@@ -234,17 +259,54 @@ def classification_metrics(save_dir: Union[str, Path],

  # Binarize y_true for the current class
  y_true_binary = (y_true == class_index).astype(int)
-
+
  # --- Save ROC Curve ---
- fpr, tpr, _ = roc_curve(y_true_binary, y_score)
+ fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+ try:
+ # Calculate Youden's J statistic (tpr - fpr)
+ J = tpr - fpr
+ # Find the index of the best threshold
+ best_index = np.argmax(J)
+ optimal_threshold = thresholds[best_index]
+
+ # Define the filename
+ threshold_filename = f"best_threshold{save_suffix}.txt"
+ threshold_path = save_dir_path / threshold_filename
+
+ # Get the class name for the report
+ class_name = ""
+ # Check if we have display labels and the current index is valid
+ if map_display_labels and class_index < len(map_display_labels):
+ class_name = map_display_labels[class_index]
+ if num_classes > 2:
+ # Add 'vs. Rest' for multiclass one-vs-rest plots
+ class_name += " (vs. Rest)"
+ else:
+ # Fallback to the generic title or default binary name
+ class_name = plot_title.strip() or "Binary Positive Class"
+
+ # Create content for the file
+ file_content = (
+ f"Optimal Classification Threshold (Youden's J Statistic)\n"
+ f"Class: {class_name}\n"
+ f"--------------------------------------------------\n"
+ f"Threshold: {optimal_threshold:.6f}\n"
+ f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+ f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+ )
+
+ threshold_path.write_text(file_content, encoding="utf-8")
+ _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+ except Exception as e:
+ _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")

  # Calculate AUC.
- # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
- # Here we calculate the specific AUC for the binarized problem.
  auc = roc_auc_score(y_true_binary, y_score)

  fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
- ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+ ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line)
  ax_roc.plot([0, 1], [0, 1], 'k--')
  ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
  ax_roc.set_xlabel('False Positive Rate')
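Both evaluation modules now report the classification threshold that maximizes Youden's J statistic (J = TPR − FPR) along the ROC curve. A self-contained sketch of the same calculation with scikit-learn, on toy data:

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true_binary = np.array([0, 0, 1, 1, 1, 0])
    y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20])

    fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
    J = tpr - fpr              # Youden's J for every candidate threshold
    best_index = np.argmax(J)  # index of the maximum J
    print(thresholds[best_index], tpr[best_index], fpr[best_index])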
@@ -259,7 +321,7 @@ def classification_metrics(save_dir: Union[str, Path],
  precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
  ap_score = average_precision_score(y_true_binary, y_score)
  fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
- ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+ ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=config.ROC_PR_line)
  ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
  ax_pr.set_xlabel('Recall')
  ax_pr.set_ylabel('Precision')
@@ -278,7 +340,7 @@ def classification_metrics(save_dir: Union[str, Path],
  cal_display_temp = CalibrationDisplay.from_predictions(
  y_true_binary, # Use binarized labels
  y_score,
- n_bins=calibration_bins,
+ n_bins=config.calibration_bins,
  ax=ax_temp,
  name="temp" # Add a name to suppress potential warnings
  )
@@ -294,9 +356,9 @@ def classification_metrics(save_dir: Union[str, Path],
  y=line_y,
  ax=ax_cal,
  scatter=False,
- label=f"Calibration Curve ({calibration_bins} bins)",
+ label=f"Calibration Curve ({config.calibration_bins} bins)",
  line_kws={
- 'color': ROC_PR_line,
+ 'color': config.ROC_PR_line,
  'linestyle': '--',
  'linewidth': 2,
  }
@@ -324,7 +386,12 @@ def classification_metrics(save_dir: Union[str, Path],
  plt.rcParams.update(original_rc_params)


- def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
+ def regression_metrics(
+ y_true: np.ndarray,
+ y_pred: np.ndarray,
+ save_dir: Union[str, Path],
+ config: Optional[RegressionMetricsFormat] = None
+ ):
  """
  Saves regression metrics and plots.

@@ -332,7 +399,19 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
  y_true (np.ndarray): Ground truth values.
  y_pred (np.ndarray): Predicted values.
  save_dir (str | Path): Directory to save plots and report.
+ config (RegressionMetricsFormat, optional): Formatting configuration object.
  """
+
+ # --- Parse Config or use defaults ---
+ if config is None:
+ # Create a default config if one wasn't provided
+ config = RegressionMetricsFormat()
+
+ # --- Set Matplotlib font size ---
+ original_rc_params = plt.rcParams.copy()
+ plt.rcParams.update({'font.size': config.font_size})
+
+ # --- Calculate Metrics ---
  rmse = np.sqrt(mean_squared_error(y_true, y_pred))
  mae = mean_absolute_error(y_true, y_pred)
  r2 = r2_score(y_true, y_pred)
@@ -354,11 +433,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
  report_path.write_text(report_string)
  _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")

- # Save residual plot
+ # --- Save residual plot ---
  residuals = y_true - y_pred
  fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
- ax_res.scatter(y_pred, residuals, alpha=0.6)
- ax_res.axhline(0, color='red', linestyle='--')
+ ax_res.scatter(y_pred, residuals,
+ alpha=config.scatter_alpha,
+ color=config.scatter_color)
+ ax_res.axhline(0, color=config.residual_line_color, linestyle='--')
  ax_res.set_xlabel("Predicted Values")
  ax_res.set_ylabel("Residuals")
  ax_res.set_title("Residual Plot")
@@ -369,10 +450,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
  _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
  plt.close(fig_res)

- # Save true vs predicted plot
+ # --- Save true vs predicted plot ---
  fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
- ax_tvp.scatter(y_true, y_pred, alpha=0.6)
- ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+ ax_tvp.scatter(y_true, y_pred,
+ alpha=config.scatter_alpha,
+ color=config.scatter_color)
+ ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+ linestyle='--',
+ lw=2,
+ color=config.ideal_line_color)
  ax_tvp.set_xlabel('True Values')
  ax_tvp.set_ylabel('Predictions')
  ax_tvp.set_title('True vs. Predicted Values')
@@ -383,9 +469,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
  _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
  plt.close(fig_tvp)

- # Save Histogram of Residuals
+ # --- Save Histogram of Residuals ---
  fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
- sns.histplot(residuals, kde=True, ax=ax_hist)
+ sns.histplot(residuals, kde=True, ax=ax_hist,
+ bins=config.hist_bins,
+ color=config.scatter_color)
  ax_hist.set_xlabel("Residual Value")
  ax_hist.set_ylabel("Frequency")
  ax_hist.set_title("Distribution of Residuals")
@@ -396,6 +484,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
  _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
  plt.close(fig_hist)

+ # --- Restore RC params ---
+ plt.rcParams.update(original_rc_params)
+

  def shap_summary_plot(model,
  background_data: Union[torch.Tensor,np.ndarray],
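regression_metrics (and its multi-target counterpart below) now accept an optional RegressionMetricsFormat. A hedged usage sketch — the attribute names (scatter_alpha, scatter_color, residual_line_color, ideal_line_color, hist_bins, font_size) are the ones read in this diff, while constructing the object with keywords is an assumption:

    import numpy as np
    from ml_tools.ML_configuration import RegressionMetricsFormat
    from ml_tools.ML_evaluation import regression_metrics

    fmt = RegressionMetricsFormat(   # keyword construction is assumed
        scatter_alpha=0.6,
        scatter_color="tab:blue",
        residual_line_color="red",
        ideal_line_color="black",
        hist_bins=30,
        font_size=16,
    )
    y_true = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = np.array([1.1, 1.9, 3.2, 3.8])
    regression_metrics(y_true, y_pred, "results/regression", config=fmt)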
ml_tools/ML_evaluation_multi.py CHANGED
@@ -19,13 +19,14 @@ from sklearn.metrics import (
  jaccard_score
  )
  from pathlib import Path
- from typing import Union, List, Literal
+ from typing import Union, List, Literal, Optional
  import warnings

  from .path_manager import make_fullpath, sanitize_filename
  from ._logger import _LOGGER
  from ._script_info import _script_info
- from .keys import SHAPKeys
+ from ._keys import SHAPKeys
+ from .ML_configuration import RegressionMetricsFormat, MultiClassificationMetricsFormat


  __all__ = [
@@ -41,7 +42,8 @@ def multi_target_regression_metrics(
  y_true: np.ndarray,
  y_pred: np.ndarray,
  target_names: List[str],
- save_dir: Union[str, Path]
+ save_dir: Union[str, Path],
+ config: Optional[RegressionMetricsFormat] = None
  ):
  """
  Calculates and saves regression metrics for each target individually.
@@ -55,6 +57,7 @@
  y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
  target_names (List[str]): A list of names for the target variables.
  save_dir (str | Path): Directory to save plots and the report.
+ config (RegressionMetricsFormat, optional): Formatting configuration object.
  """
  if y_true.ndim != 2 or y_pred.ndim != 2:
  _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -68,8 +71,17 @@

  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
  metrics_summary = []
+
+ # --- Parse Config or use defaults ---
+ if config is None:
+ # Create a default config if one wasn't provided
+ config = RegressionMetricsFormat()
+
+ # --- Set Matplotlib font size ---
+ original_rc_params = plt.rcParams.copy()
+ plt.rcParams.update({'font.size': config.font_size})

- _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+ _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

  for i, name in enumerate(target_names):
  print(f" -> Evaluating target: '{name}'")
@@ -93,8 +105,12 @@
  # --- Save Residual Plot ---
  residuals = true_i - pred_i
  fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
- ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
- ax_res.axhline(0, color='red', linestyle='--')
+ ax_res.scatter(pred_i, residuals,
+ alpha=config.scatter_alpha,
+ edgecolors='k',
+ s=50,
+ color=config.scatter_color) # Use config color
+ ax_res.axhline(0, color=config.residual_line_color, linestyle='--') # Use config color
  ax_res.set_xlabel("Predicted Values")
  ax_res.set_ylabel("Residuals (True - Predicted)")
  ax_res.set_title(f"Residual Plot for '{name}'")
@@ -106,8 +122,15 @@

  # --- Save True vs. Predicted Plot ---
  fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
- ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
- ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+ ax_tvp.scatter(true_i, pred_i,
+ alpha=config.scatter_alpha,
+ edgecolors='k',
+ s=50,
+ color=config.scatter_color) # Use config color
+ ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+ linestyle='--',
+ lw=2,
+ color=config.ideal_line_color) # Use config color
  ax_tvp.set_xlabel('True Values')
  ax_tvp.set_ylabel('Predicted Values')
  ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -122,17 +145,18 @@
  report_path = save_dir_path / "regression_report_multi.csv"
  summary_df.to_csv(report_path, index=False)
  _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+ # --- Restore RC params ---
+ plt.rcParams.update(original_rc_params)


  def multi_label_classification_metrics(
  y_true: np.ndarray,
+ y_pred: np.ndarray,
  y_prob: np.ndarray,
  target_names: List[str],
  save_dir: Union[str, Path],
- threshold: float = 0.5,
- ROC_PR_line: str='darkorange',
- cmap: str = "Blues",
- font_size: int = 16
+ config: Optional[MultiClassificationMetricsFormat] = None # Add config object
  ):
  """
  Calculates and saves classification metrics for each label individually.
@@ -143,17 +167,17 @@

  Args:
  y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+ y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
  y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
  target_names (List[str]): A list of names for the labels.
  save_dir (str | Path): Directory to save plots and reports.
- threshold (float): The probability threshold to convert probabilities into
- binary predictions for metrics like the confusion matrix.
+ config (MultiClassificationMetricsFormat, optional): Formatting configuration object.
  """
- if y_true.ndim != 2 or y_prob.ndim != 2:
- _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+ if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+ _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
  raise ValueError()
- if y_true.shape != y_prob.shape:
- _LOGGER.error("Shapes of y_true and y_prob must match.")
+ if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+ _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
  raise ValueError()
  if y_true.shape[1] != len(target_names):
  _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -161,22 +185,26 @@

  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

- # Generate binary predictions from probabilities
- y_pred = (y_prob >= threshold).astype(int)
+ # --- Parse Config or use defaults ---
+ if config is None:
+ # Create a default config if one wasn't provided
+ config = MultiClassificationMetricsFormat()
+
+ # y_pred is now passed in directly, no threshold needed.

  # --- Save current RC params and update font size ---
  original_rc_params = plt.rcParams.copy()
- plt.rcParams.update({'font.size': font_size})
+ plt.rcParams.update({'font.size': config.font_size})

- _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+ # _LOGGER.info("--- Multi-Label Classification Evaluation ---")

- # --- Calculate and Save Overall Metrics ---
+ # --- Calculate and Save Overall Metrics (using y_pred) ---
  h_loss = hamming_loss(y_true, y_pred)
  j_score_micro = jaccard_score(y_true, y_pred, average='micro')
  j_score_macro = jaccard_score(y_true, y_pred, average='macro')

  overall_report = (
- f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+ f"Overall Multi-Label Metrics:\n" # No threshold to report here
  f"--------------------------------------------------\n"
  f"Hamming Loss: {h_loss:.4f}\n"
  f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -191,20 +219,20 @@
  for i, name in enumerate(target_names):
  print(f" -> Evaluating label: '{name}'")
  true_i = y_true[:, i]
- pred_i = y_pred[:, i]
- prob_i = y_prob[:, i]
+ pred_i = y_pred[:, i] # Use passed-in y_pred
+ prob_i = y_prob[:, i] # Use passed-in y_prob
  sanitized_name = sanitize_filename(name)

- # --- Save Classification Report for the label ---
+ # --- Save Classification Report for the label (uses y_pred) ---
  report_text = classification_report(true_i, pred_i)
  report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
  report_path.write_text(report_text) # type: ignore

- # --- Save Confusion Matrix ---
+ # --- Save Confusion Matrix (uses y_pred) ---
  fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
  disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
  pred_i,
- cmap=cmap,
+ cmap=config.cmap, # Use config cmap
  ax=ax_cm,
  normalize='true',
  labels=[0, 1],
@@ -217,7 +245,7 @@

  # Manually update font size of cell texts
  for text in ax_cm.texts:
- text.set_fontsize(font_size)
+ text.set_fontsize(config.font_size) # Use config font_size

  fig_cm.tight_layout()

@@ -226,11 +254,44 @@
  plt.savefig(cm_path)
  plt.close(fig_cm)

- # --- Save ROC Curve ---
- fpr, tpr, _ = roc_curve(true_i, prob_i)
+ # --- Save ROC Curve (uses y_prob) ---
+ fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+ try:
+ # Calculate Youden's J statistic (tpr - fpr)
+ J = tpr - fpr
+ # Find the index of the best threshold
+ best_index = np.argmax(J)
+ optimal_threshold = thresholds[best_index]
+ best_tpr = tpr[best_index]
+ best_fpr = fpr[best_index]
+
+ # Define the filename
+ threshold_filename = f"best_threshold_{sanitized_name}.txt"
+ threshold_path = save_dir_path / threshold_filename
+
+ # The class name is the target_name for this label
+ class_name = name
+
+ # Create content for the file
+ file_content = (
+ f"Optimal Classification Threshold (Youden's J Statistic)\n"
+ f"Class/Label: {class_name}\n"
+ f"--------------------------------------------------\n"
+ f"Threshold: {optimal_threshold:.6f}\n"
+ f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+ f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+ )
+
+ threshold_path.write_text(file_content, encoding="utf-8")
+ _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+ except Exception as e:
+ _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
  auc = roc_auc_score(true_i, prob_i)
  fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
- ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+ ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line) # Use config color
  ax_roc.plot([0, 1], [0, 1], 'k--')
  ax_roc.set_title(f'ROC Curve for "{name}"')
  ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -239,11 +300,11 @@
  plt.savefig(roc_path)
  plt.close(fig_roc)

- # --- Save Precision-Recall Curve ---
+ # --- Save Precision-Recall Curve (uses y_prob) ---
  precision, recall, _ = precision_recall_curve(true_i, prob_i)
  ap_score = average_precision_score(true_i, prob_i)
  fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
- ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
+ ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=config.ROC_PR_line) # Use config color
  ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
  ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
  ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)