dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +704 -24
  9. ml_tools/ML_datasetmaster.py +235 -280
  10. ml_tools/ML_evaluation.py +144 -39
  11. ml_tools/ML_evaluation_multi.py +103 -35
  12. ml_tools/ML_inference.py +290 -208
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +219 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1342 -386
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +120 -72
  24. ml_tools/ML_vision_evaluation.py +30 -6
  25. ml_tools/ML_vision_inference.py +129 -152
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
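One practical consequence visible throughout the diffs below: `ml_tools/keys.py` was renamed to the private module `ml_tools/_keys.py`, so downstream code importing from the old path breaks on upgrade. A minimal compatibility sketch, assuming the same symbols are exposed under both names (which is what the import changes below suggest):

```python
# Hypothetical compatibility shim for the keys.py -> _keys.py rename.
# Assumes SHAPKeys and PyTorchLogKeys exist under both module paths.
try:
    from ml_tools._keys import SHAPKeys, PyTorchLogKeys  # 16.x layout
except ImportError:
    from ml_tools.keys import SHAPKeys, PyTorchLogKeys   # 14.x layout
```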
ml_tools/ML_evaluation.py CHANGED
@@ -24,7 +24,14 @@ import warnings
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys, PyTorchLogKeys
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import (RegressionMetricsFormat,
+                               BinaryClassificationMetricsFormat,
+                               MultiClassClassificationMetricsFormat,
+                               BinaryImageClassificationMetricsFormat,
+                               MultiClassImageClassificationMetricsFormat,
+                               _BaseClassificationFormat,
+                               _BaseRegressionFormat)


 __all__ = [
@@ -41,6 +48,7 @@ DPI_value = 250
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.

     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
@@ -48,6 +56,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])

     if not train_loss and not val_loss:
         _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
@@ -55,22 +64,44 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):

     fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)

+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = [] # To store line objects for the legend
+
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)

     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)

     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.legend()
-    ax.grid(True)
-    plt.tight_layout()
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx() # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
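The reworked `plot_losses` above draws the learning rate on a twinned right-hand axis and merges both legends. A self-contained sketch of that matplotlib pattern, with illustrative data only:

```python
# Standalone sketch of the dual-axis pattern used in plot_losses:
# loss on the left y-axis, learning rate on a twinned right y-axis,
# and one legend combining handles from both axes.
import matplotlib.pyplot as plt

epochs = list(range(1, 11))
train_loss = [1.0 / e for e in epochs]          # illustrative data
lr_history = [1e-3 * 0.8 ** e for e in epochs]  # illustrative data

fig, ax = plt.subplots(figsize=(10, 5))
line1, = ax.plot(epochs, train_loss, 'o-', color='tab:blue', label='Training Loss')
ax.set_ylabel('Loss', color='tab:blue')

ax2 = ax.twinx()  # second y-axis sharing the same x-axis
line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
ax2.set_ylabel('Learning Rate', color='g')
ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

ax.legend(handles=[line1, line3], loc='best')
fig.tight_layout()
```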
@@ -84,24 +115,30 @@ def classification_metrics(save_dir: Union[str, Path],
                            y_true: np.ndarray,
                            y_pred: np.ndarray,
                            y_prob: Optional[np.ndarray] = None,
-                           cmap: str = "Blues",
-                           class_map: Optional[dict[str,int]]=None,
-                           ROC_PR_line: str='darkorange',
-                           calibration_bins: int=15,
-                           font_size: int=16):
+                           class_map: Optional[dict[str,int]] = None,
+                           config: Optional[Union[BinaryClassificationMetricsFormat,
+                                                  MultiClassClassificationMetricsFormat,
+                                                  BinaryImageClassificationMetricsFormat,
+                                                  MultiClassImageClassificationMetricsFormat]] = None):
     """
     Saves classification metrics and plots.

     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
-        cmap (str): Colormap for the confusion matrix.
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (object): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
-        class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
     """
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseClassificationFormat()
+    else:
+        format_config = config
+
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': format_config.font_size})

     # print("--- Classification Report ---")

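As before, `classification_metrics` snapshots `plt.rcParams` and restores it at the end so global matplotlib state does not leak into later figures. The idiom in isolation (the `try/finally` is an extra safeguard in this sketch, not part of the library code):

```python
# The rcParams save/restore idiom used by these evaluation functions.
import matplotlib.pyplot as plt

original_rc_params = plt.rcParams.copy()
plt.rcParams.update({'font.size': 16})
try:
    pass  # ... create and save plots at the enlarged font size ...
finally:
    # Restore global state so unrelated figures keep their styling
    plt.rcParams.update(original_rc_params)
```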
@@ -136,7 +173,7 @@ def classification_metrics(save_dir: Union[str, Path],
     sns.set_theme(font_scale=1.2) # Scale seaborn font
     sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
                 annot=True,
-                cmap=cmap,
+                cmap=format_config.cmap,
                 fmt='.2f',
                 vmin=0.0,
                 vmax=1.0)
@@ -158,7 +195,7 @@ def classification_metrics(save_dir: Union[str, Path],
     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
     disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
                                                     y_pred,
-                                                    cmap=cmap,
+                                                    cmap=format_config.cmap,
                                                     ax=ax_cm,
                                                     normalize='true',
                                                     labels=plot_labels,
@@ -171,7 +208,7 @@ def classification_metrics(save_dir: Union[str, Path],

     # Manually update font size of cell texts
     for text in ax_cm.texts:
-        text.set_fontsize(font_size)
+        text.set_fontsize(format_config.font_size)

     fig_cm.tight_layout()

@@ -196,10 +233,10 @@ def classification_metrics(save_dir: Union[str, Path],
         class_indices_to_plot = [1]
         plot_titles = [""] # No extra title
         save_suffixes = [""] # No extra suffix
-        _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")
+        _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")

     elif num_classes > 2:
-        _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
+        _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
         # Multiclass case: Plot for every class (One-vs-Rest)
         class_indices_to_plot = list(range(num_classes))

@@ -234,17 +271,54 @@ def classification_metrics(save_dir: Union[str, Path],

         # Binarize y_true for the current class
         y_true_binary = (y_true == class_index).astype(int)
-
+
         # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(y_true_binary, y_score)
+        fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold{save_suffix}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # Get the class name for the report
+            class_name = ""
+            # Check if we have display labels and the current index is valid
+            if map_display_labels and class_index < len(map_display_labels):
+                class_name = map_display_labels[class_index]
+                if num_classes > 2:
+                    # Add 'vs. Rest' for multiclass one-vs-rest plots
+                    class_name += " (vs. Rest)"
+            else:
+                # Fallback to the generic title or default binary name
+                class_name = plot_title.strip() or "Binary Positive Class"
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")

         # Calculate AUC.
-        # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
-        # Here we calculate the specific AUC for the binarized problem.
         auc = roc_auc_score(y_true_binary, y_score)

         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
         ax_roc.set_xlabel('False Positive Rate')
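The new block above persists the threshold that maximizes Youden's J statistic (J = TPR - FPR) to a text file alongside the ROC plot. A tiny self-contained illustration of the selection logic:

```python
# Minimal illustration of the Youden's J threshold selection added above:
# J = TPR - FPR, maximized over the thresholds sklearn's roc_curve returns.
import numpy as np
from sklearn.metrics import roc_curve

y_true_binary = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20, 0.90, 0.45])

fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
best_index = np.argmax(tpr - fpr)  # index of the maximal J statistic
print(f"Threshold: {thresholds[best_index]:.6f}, "
      f"TPR: {tpr[best_index]:.6f}, FPR: {fpr[best_index]:.6f}")
```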
@@ -259,7 +333,7 @@ def classification_metrics(save_dir: Union[str, Path],
         precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
         ap_score = average_precision_score(y_true_binary, y_score)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
         ax_pr.set_xlabel('Recall')
         ax_pr.set_ylabel('Precision')
@@ -278,7 +352,7 @@ def classification_metrics(save_dir: Union[str, Path],
         cal_display_temp = CalibrationDisplay.from_predictions(
             y_true_binary, # Use binarized labels
             y_score,
-            n_bins=calibration_bins,
+            n_bins=format_config.calibration_bins,
             ax=ax_temp,
             name="temp" # Add a name to suppress potential warnings
         )
@@ -294,9 +368,9 @@ def classification_metrics(save_dir: Union[str, Path],
             y=line_y,
             ax=ax_cal,
             scatter=False,
-            label=f"Calibration Curve ({calibration_bins} bins)",
+            label=f"Calibration Curve ({format_config.calibration_bins} bins)",
             line_kws={
-                'color': ROC_PR_line,
+                'color': format_config.ROC_PR_line,
                 'linestyle': '--',
                 'linewidth': 2,
             }
@@ -324,7 +398,12 @@ def classification_metrics(save_dir: Union[str, Path],
     plt.rcParams.update(original_rc_params)


-def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
+def regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
+):
     """
     Saves regression metrics and plots.

@@ -332,7 +411,21 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
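Both `classification_metrics` and `regression_metrics` now resolve a format object up front, falling back to a private default class. A hedged sketch of that pattern; the field names below are inferred from the attributes the new code reads (`format_config.font_size`, `.scatter_alpha`, and so on), not taken from the actual `ML_configuration.py`:

```python
# Assumed shape of a regression format config, inferred from the
# attributes accessed in the diff; not the library's real definition.
from dataclasses import dataclass
from typing import Optional

@dataclass
class RegressionFormatSketch:
    font_size: int = 16
    scatter_alpha: float = 0.6
    scatter_color: str = "tab:blue"
    residual_line_color: str = "red"
    ideal_line_color: str = "black"
    hist_bins: int = 30

def resolve_config(config: Optional[RegressionFormatSketch]) -> RegressionFormatSketch:
    # Fall back to defaults when the caller passes no config
    return config if config is not None else RegressionFormatSketch()
```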
@@ -354,11 +447,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")

-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_res.scatter(y_pred, residuals, alpha=0.6)
-    ax_res.axhline(0, color='red', linestyle='--')
+    ax_res.scatter(y_pred, residuals,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -369,10 +464,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)

-    # Save true vs predicted plot
+    # --- Save true vs predicted plot ---
     fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_tvp.scatter(y_true, y_pred, alpha=0.6)
-    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=format_config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -383,9 +483,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)

-    # Save Histogram of Residuals
+    # --- Save Histogram of Residuals ---
     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    sns.histplot(residuals, kde=True, ax=ax_hist)
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=format_config.hist_bins,
+                 color=format_config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
@@ -396,6 +498,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)

+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+

 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
ml_tools/ML_evaluation_multi.py CHANGED
@@ -19,13 +19,17 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings

 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from ._keys import SHAPKeys
+from .ML_configuration import (MultiTargetRegressionMetricsFormat,
+                               _BaseRegressionFormat,
+                               MultiLabelBinaryClassificationMetricsFormat,
+                               _BaseMultiLabelFormat)


 __all__ = [
@@ -41,7 +45,8 @@ def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[MultiTargetRegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -55,6 +60,7 @@
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (object): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -68,8 +74,19 @@

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})

-    _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
@@ -93,8 +110,12 @@
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
-        ax_res.axhline(0, color='red', linestyle='--')
+        ax_res.scatter(pred_i, residuals,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -106,8 +127,15 @@
         # --- Save True vs. Predicted Plot ---
         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
-        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=format_config.ideal_line_color) # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -122,17 +150,18 @@
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)


 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-    threshold: float = 0.5,
-    ROC_PR_line: str='darkorange',
-    cmap: str = "Blues",
-    font_size: int = 16
+    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -143,17 +172,17 @@

     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-        threshold (float): The probability threshold to convert probabilities into
-                           binary predictions for metrics like the confusion matrix.
+        config (object): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -161,22 +190,28 @@

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

-    # Generate binary predictions from probabilities
-    y_pred = (y_prob >= threshold).astype(int)
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseMultiLabelFormat()
+    else:
+        format_config = config
+
+    # y_pred is now passed in directly, no threshold needed.

     # --- Save current RC params and update font size ---
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': format_config.font_size})

-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")

-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')

     overall_report = (
-        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+        f"Overall Multi-Label Metrics:\n" # No threshold to report here
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -191,20 +226,20 @@
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i] # Use passed-in y_pred
+        prob_i = y_prob[:, i] # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)

-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text) # type: ignore

-        # --- Save Confusion Matrix ---
+        # --- Save Confusion Matrix (uses y_pred) ---
         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
         disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                         pred_i,
-                                                        cmap=cmap,
+                                                        cmap=format_config.cmap, # Use config cmap
                                                         ax=ax_cm,
                                                         normalize='true',
                                                         labels=[0, 1],
@@ -217,7 +252,7 @@

         # Manually update font size of cell texts
         for text in ax_cm.texts:
-            text.set_fontsize(font_size)
+            text.set_fontsize(format_config.font_size) # Use config font_size

         fig_cm.tight_layout()

@@ -226,11 +261,44 @@
         plt.savefig(cm_path)
         plt.close(fig_cm)

-        # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(true_i, prob_i)
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
         auc = roc_auc_score(true_i, prob_i)
         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -239,11 +307,11 @@
         plt.savefig(roc_path)
         plt.close(fig_roc)

-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
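The breaking change in `multi_label_classification_metrics` is that thresholding moved to the caller: `y_pred` is now a required argument instead of being derived internally from `y_prob`. Callers can reproduce the removed behavior explicitly at the call site (0.5 matches the old default):

```python
# Reproducing the removed internal thresholding before calling the
# new signature; the threshold value here is illustrative.
import numpy as np

y_prob = np.array([[0.91, 0.18],
                   [0.42, 0.73]])   # illustrative probabilities
threshold = 0.5                     # the old default
y_pred = (y_prob >= threshold).astype(int)  # binary predictions per label
# y_pred is then passed alongside y_prob to multi_label_classification_metrics.
```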