dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +726 -32
  9. ml_tools/ML_datasetmaster.py +235 -280
  10. ml_tools/ML_evaluation.py +160 -42
  11. ml_tools/ML_evaluation_multi.py +103 -35
  12. ml_tools/ML_inference.py +290 -208
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +219 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1342 -386
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +120 -72
  24. ml_tools/ML_vision_evaluation.py +30 -6
  25. ml_tools/ML_vision_inference.py +129 -152
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py CHANGED
@@ -24,7 +24,14 @@ import warnings
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys, PyTorchLogKeys
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import (RegressionMetricsFormat,
+                               BinaryClassificationMetricsFormat,
+                               MultiClassClassificationMetricsFormat,
+                               BinaryImageClassificationMetricsFormat,
+                               MultiClassImageClassificationMetricsFormat,
+                               _BaseClassificationFormat,
+                               _BaseRegressionFormat)
 
 
 __all__ = [
@@ -41,6 +48,7 @@ DPI_value = 250
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
 
     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
@@ -48,6 +56,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
 
     if not train_loss and not val_loss:
         _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
@@ -55,22 +64,44 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
 
     fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
 
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = [] # To store line objects for the legend
+
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
 
     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
 
     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.legend()
-    ax.grid(True)
-    plt.tight_layout()
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx() # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
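Note: the reworked plot_losses still only needs a history dict keyed with the PyTorchLogKeys constants read above; the learning-rate overlay is optional and drawn on a second y-axis via twinx. A minimal usage sketch, assuming the wheel is importable as ml_tools (the key constants and module paths come from this diff; the rest is illustrative):

    import numpy as np
    from ml_tools._keys import PyTorchLogKeys
    from ml_tools.ML_evaluation import plot_losses

    # Fake 20-epoch history using the same keys plot_losses() reads.
    history = {
        PyTorchLogKeys.TRAIN_LOSS: list(np.linspace(1.0, 0.2, 20)),
        PyTorchLogKeys.VAL_LOSS: list(np.linspace(1.1, 0.3, 20)),
        PyTorchLogKeys.LEARNING_RATE: [1e-3 * 0.9**i for i in range(20)],  # optional
    }

    # Writes loss_plot.svg (loss on the left axis, learning rate on the right).
    plot_losses(history, save_dir="evaluation_output")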
@@ -84,24 +115,30 @@ def classification_metrics(save_dir: Union[str, Path],
                            y_true: np.ndarray,
                            y_pred: np.ndarray,
                            y_prob: Optional[np.ndarray] = None,
-                           cmap: str = "Blues",
-                           class_map: Optional[dict[str,int]]=None,
-                           ROC_PR_line: str='darkorange',
-                           calibration_bins: int=15,
-                           font_size: int=16):
+                           class_map: Optional[dict[str,int]] = None,
+                           config: Optional[Union[BinaryClassificationMetricsFormat,
+                                                  MultiClassClassificationMetricsFormat,
+                                                  BinaryImageClassificationMetricsFormat,
+                                                  MultiClassImageClassificationMetricsFormat]] = None):
     """
     Saves classification metrics and plots.
 
     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
-        cmap (str): Colormap for the confusion matrix.
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (object): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
-        class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
     """
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseClassificationFormat()
+    else:
+        format_config = config
+
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': format_config.font_size})
 
     # print("--- Classification Report ---")
 
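Note: the per-call styling arguments (cmap, ROC_PR_line, calibration_bins, font_size) are gone; styling now arrives bundled in a config object from ml_tools.ML_configuration, and classification_metrics falls back to _BaseClassificationFormat() when config is None. The diff only shows which attributes get read, not the real constructors, so the dataclass below is an illustrative stand-in rather than the library's actual class:

    from dataclasses import dataclass

    # Hypothetical shape of a classification metrics format object; the attribute
    # names match what classification_metrics() reads (format_config.font_size,
    # .cmap, .ROC_PR_line, .calibration_bins), but the real classes live in
    # ML_configuration.py and may take different arguments.
    @dataclass
    class ExampleClassificationFormat:
        font_size: int = 16
        cmap: str = "Blues"
        ROC_PR_line: str = "darkorange"
        calibration_bins: int = 15

The defaults shown here mirror the old keyword defaults this hunk removes, so omitting config should keep the previous styling.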
@@ -132,16 +169,29 @@ def classification_metrics(save_dir: Union[str, Path],
 
     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        # Create DataFrame from report
+        report_df = pd.DataFrame(report_dict)
+
+        # 1. Drop the 'accuracy' column (single float)
+        if 'accuracy' in report_df.columns:
+            report_df = report_df.drop(columns=['accuracy'])
+
+        # 2. Select all metric rows *except* the last one ('support')
+        # 3. Transpose the DataFrame
+        plot_df = report_df.iloc[:-1, :].T
+
+        fig_height = max(5.0, len(plot_df.index) * 0.5 + 2.0)
+        plt.figure(figsize=(7, fig_height), dpi=DPI_value)
+
         sns.set_theme(font_scale=1.2) # Scale seaborn font
-        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+        sns.heatmap(plot_df,
                     annot=True,
-                    cmap=cmap,
+                    cmap=format_config.cmap,
                     fmt='.2f',
                     vmin=0.0,
                     vmax=1.0)
         sns.set_theme(font_scale=1.0) # Reset seaborn scale
-        plt.title("Classification Report")
+        plt.title("Classification Report Heatmap")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
         plt.savefig(heatmap_path)
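Note: the heatmap is now built from an explicit DataFrame pipeline (drop the scalar 'accuracy' entry, drop the 'support' row, transpose, then size the figure by row count) instead of a one-liner. A self-contained sketch of that transform on toy data, using only scikit-learn and pandas:

    import pandas as pd
    from sklearn.metrics import classification_report

    y_true = [0, 0, 1, 1, 1, 0]
    y_pred = [0, 1, 1, 1, 0, 0]

    report_dict = classification_report(y_true, y_pred, output_dict=True)
    report_dict.pop('accuracy', None)      # scalar entry, not a per-class column
    report_df = pd.DataFrame(report_dict)  # rows: precision, recall, f1-score, support

    # Drop the 'support' row and transpose so each class / average becomes a row.
    plot_df = report_df.iloc[:-1, :].T
    print(plot_df)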
@@ -158,7 +208,7 @@ def classification_metrics(save_dir: Union[str, Path],
     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
     disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
                                                     y_pred,
-                                                    cmap=cmap,
+                                                    cmap=format_config.cmap,
                                                     ax=ax_cm,
                                                     normalize='true',
                                                     labels=plot_labels,
@@ -171,7 +221,7 @@ def classification_metrics(save_dir: Union[str, Path],
 
     # Manually update font size of cell texts
     for text in ax_cm.texts:
-        text.set_fontsize(font_size)
+        text.set_fontsize(format_config.font_size)
 
     fig_cm.tight_layout()
 
@@ -196,10 +246,10 @@ def classification_metrics(save_dir: Union[str, Path],
         class_indices_to_plot = [1]
         plot_titles = [""] # No extra title
         save_suffixes = [""] # No extra suffix
-        _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")
+        _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
 
     elif num_classes > 2:
-        _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
+        _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
         # Multiclass case: Plot for every class (One-vs-Rest)
         class_indices_to_plot = list(range(num_classes))
 
@@ -234,17 +284,54 @@ def classification_metrics(save_dir: Union[str, Path],
 
         # Binarize y_true for the current class
         y_true_binary = (y_true == class_index).astype(int)
-
+
         # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(y_true_binary, y_score)
+        fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold{save_suffix}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # Get the class name for the report
+            class_name = ""
+            # Check if we have display labels and the current index is valid
+            if map_display_labels and class_index < len(map_display_labels):
+                class_name = map_display_labels[class_index]
+                if num_classes > 2:
+                    # Add 'vs. Rest' for multiclass one-vs-rest plots
+                    class_name += " (vs. Rest)"
+            else:
+                # Fallback to the generic title or default binary name
+                class_name = plot_title.strip() or "Binary Positive Class"
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
 
         # Calculate AUC.
-        # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
-        # Here we calculate the specific AUC for the binarized problem.
         auc = roc_auc_score(y_true_binary, y_score)
 
         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
         ax_roc.set_xlabel('False Positive Rate')
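Note: the new best_threshold*.txt report picks the operating point that maximizes Youden's J statistic, J = TPR - FPR, over the thresholds returned by roc_curve. A standalone sketch of the same calculation on toy data (NumPy and scikit-learn only):

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array([0, 0, 0, 1, 1, 1, 0, 1])
    y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.9, 0.5, 0.7])

    fpr, tpr, thresholds = roc_curve(y_true, y_prob)

    # Youden's J: vertical distance of each ROC point above the chance diagonal.
    J = tpr - fpr
    best_index = int(np.argmax(J))

    print(f"Optimal threshold: {thresholds[best_index]:.3f}")
    print(f"TPR at optimum:    {tpr[best_index]:.3f}")
    print(f"FPR at optimum:    {fpr[best_index]:.3f}")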
@@ -259,7 +346,7 @@ def classification_metrics(save_dir: Union[str, Path],
         precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
         ap_score = average_precision_score(y_true_binary, y_score)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
         ax_pr.set_xlabel('Recall')
         ax_pr.set_ylabel('Precision')
@@ -278,7 +365,7 @@ def classification_metrics(save_dir: Union[str, Path],
         cal_display_temp = CalibrationDisplay.from_predictions(
             y_true_binary, # Use binarized labels
             y_score,
-            n_bins=calibration_bins,
+            n_bins=format_config.calibration_bins,
             ax=ax_temp,
             name="temp" # Add a name to suppress potential warnings
         )
@@ -294,9 +381,9 @@ def classification_metrics(save_dir: Union[str, Path],
             y=line_y,
             ax=ax_cal,
             scatter=False,
-            label=f"Calibration Curve ({calibration_bins} bins)",
+            label=f"Calibration Curve ({format_config.calibration_bins} bins)",
             line_kws={
-                'color': ROC_PR_line,
+                'color': format_config.ROC_PR_line,
                 'linestyle': '--',
                 'linewidth': 2,
             }
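Note: calibration_bins (now format_config.calibration_bins) controls how predicted probabilities are bucketed for the reliability curve that CalibrationDisplay draws. A minimal sketch of the same idea using sklearn.calibration.calibration_curve directly, on synthetic, roughly calibrated predictions:

    import numpy as np
    from sklearn.calibration import calibration_curve

    rng = np.random.default_rng(0)
    y_prob = rng.uniform(0.0, 1.0, size=1000)
    # Label is 1 with probability y_prob, i.e. a well-calibrated model.
    y_true = (rng.uniform(size=1000) < y_prob).astype(int)

    # Observed positive rate vs. mean predicted probability, per bin.
    frac_pos, mean_pred = calibration_curve(y_true, y_prob, n_bins=15)
    for mp, fp in zip(mean_pred, frac_pos):
        print(f"mean predicted {mp:.2f} -> observed positive rate {fp:.2f}")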
@@ -324,7 +411,12 @@ def classification_metrics(save_dir: Union[str, Path],
     plt.rcParams.update(original_rc_params)
 
 
-def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
+def regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
+):
     """
     Saves regression metrics and plots.
 
@@ -332,7 +424,21 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
@@ -354,11 +460,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_res.scatter(y_pred, residuals, alpha=0.6)
-    ax_res.axhline(0, color='red', linestyle='--')
+    ax_res.scatter(y_pred, residuals,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -369,10 +477,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)
 
-    # Save true vs predicted plot
+    # --- Save true vs predicted plot ---
    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_tvp.scatter(y_true, y_pred, alpha=0.6)
-    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=format_config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -383,9 +496,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)
 
-    # Save Histogram of Residuals
+    # --- Save Histogram of Residuals ---
     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    sns.histplot(residuals, kde=True, ax=ax_hist)
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=format_config.hist_bins,
+                 color=format_config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
@@ -396,6 +511,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
 
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+
 
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
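Note: regression_metrics keeps its old three-argument call shape; the new config parameter is optional and falls back to _BaseRegressionFormat(), so existing call sites are unaffected. A minimal usage sketch, assuming the wheel is importable as ml_tools (the RegressionMetricsFormat constructor is not shown in this diff):

    import numpy as np
    from ml_tools.ML_evaluation import regression_metrics

    rng = np.random.default_rng(42)
    y_true = rng.normal(size=200)
    y_pred = y_true + rng.normal(scale=0.3, size=200)  # noisy predictions

    # Writes the text report, residual plot, true-vs-predicted plot, and residual
    # histogram into the directory; pass a RegressionMetricsFormat to override
    # scatter color/alpha, line colors, histogram bins, or font size.
    regression_metrics(y_true=y_true, y_pred=y_pred, save_dir="regression_eval")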
ml_tools/ML_evaluation_multi.py CHANGED
@@ -19,13 +19,17 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from ._keys import SHAPKeys
+from .ML_configuration import (MultiTargetRegressionMetricsFormat,
+                               _BaseRegressionFormat,
+                               MultiLabelBinaryClassificationMetricsFormat,
+                               _BaseMultiLabelFormat)
 
 
 __all__ = [
@@ -41,7 +45,8 @@ def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[MultiTargetRegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -55,6 +60,7 @@
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (object): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -68,8 +74,19 @@
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
 
-    _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")
 
     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
@@ -93,8 +110,12 @@
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
-        ax_res.axhline(0, color='red', linestyle='--')
+        ax_res.scatter(pred_i, residuals,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -106,8 +127,15 @@
 
         # --- Save True vs. Predicted Plot ---
         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
-        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=format_config.ideal_line_color) # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -122,17 +150,18 @@
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
 
 
 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-    threshold: float = 0.5,
-    ROC_PR_line: str='darkorange',
-    cmap: str = "Blues",
-    font_size: int = 16
+    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
 ):
     """
     Calculates and saves classification metrics for each label individually.
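Note: multi_label_classification_metrics no longer thresholds y_prob internally; the threshold parameter is removed and the caller passes binary y_pred alongside y_prob. A sketch of the thresholding step that moves to the call site, optionally with per-label cutoffs such as the ones written to the best_threshold_*.txt reports:

    import numpy as np

    # (n_samples, n_labels) predicted probabilities from the model.
    y_prob = np.array([[0.9, 0.2],
                       [0.4, 0.7],
                       [0.6, 0.1]])

    # Single global cutoff (the removed parameter defaulted to 0.5)...
    y_pred = (y_prob >= 0.5).astype(int)

    # ...or one tuned threshold per label (illustrative values).
    per_label_thresholds = np.array([0.55, 0.30])
    y_pred_tuned = (y_prob >= per_label_thresholds).astype(int)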
@@ -143,17 +172,17 @@
 
     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-        threshold (float): The probability threshold to convert probabilities into
-            binary predictions for metrics like the confusion matrix.
+        config (object): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -161,22 +190,28 @@
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
-    # Generate binary predictions from probabilities
-    y_pred = (y_prob >= threshold).astype(int)
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseMultiLabelFormat()
+    else:
+        format_config = config
+
+    # y_pred is now passed in directly, no threshold needed.
 
     # --- Save current RC params and update font size ---
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': format_config.font_size})
 
-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+        f"Overall Multi-Label Metrics:\n" # No threshold to report here
         f"--------------------------------------------------\n"
        f"Hamming Loss: {h_loss:.4f}\n"
        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -191,20 +226,20 @@
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i] # Use passed-in y_pred
+        prob_i = y_prob[:, i] # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)
 
-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text) # type: ignore
 
-        # --- Save Confusion Matrix ---
+        # --- Save Confusion Matrix (uses y_pred) ---
         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
         disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                         pred_i,
-                                                        cmap=cmap,
+                                                        cmap=format_config.cmap, # Use config cmap
                                                         ax=ax_cm,
                                                         normalize='true',
                                                         labels=[0, 1],
@@ -217,7 +252,7 @@
 
         # Manually update font size of cell texts
         for text in ax_cm.texts:
-            text.set_fontsize(font_size)
+            text.set_fontsize(format_config.font_size) # Use config font_size
 
         fig_cm.tight_layout()
 
@@ -226,11 +261,44 @@
         plt.savefig(cm_path)
         plt.close(fig_cm)
 
-        # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(true_i, prob_i)
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
         auc = roc_auc_score(true_i, prob_i)
         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -239,11 +307,11 @@
         plt.savefig(roc_path)
         plt.close(fig_roc)
 
-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)