dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/_core/_ML_evaluation.py
DELETED
@@ -1,867 +0,0 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.calibration import CalibrationDisplay
from sklearn.metrics import (
    classification_report,
    ConfusionMatrixDisplay,
    roc_curve,
    roc_auc_score,
    mean_squared_error,
    mean_absolute_error,
    r2_score,
    median_absolute_error,
    precision_recall_curve,
    average_precision_score
)
import torch
import shap
from pathlib import Path
from typing import Union, Optional, List, Literal
import warnings

from ._path_manager import make_fullpath, sanitize_filename
from ._logger import get_logger
from ._script_info import _script_info
from ._keys import SHAPKeys, PyTorchLogKeys, _EvaluationConfig
from ._ML_configuration import (RegressionMetricsFormat,
                                BinaryClassificationMetricsFormat,
                                MultiClassClassificationMetricsFormat,
                                BinaryImageClassificationMetricsFormat,
                                MultiClassImageClassificationMetricsFormat,
                                _BaseClassificationFormat,
                                _BaseRegressionFormat)


_LOGGER = get_logger("Evaluation")


__all__ = [
    "plot_losses",
    "classification_metrics",
    "regression_metrics",
    "shap_summary_plot",
    "plot_attention_importance"
]


DPI_value = _EvaluationConfig.DPI
REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE


def plot_losses(history: dict, save_dir: Union[str, Path]):
    """
    Plots training & validation loss curves from a history object.
    Also plots the learning rate if available in the history.

    Args:
        history (dict): A dictionary containing 'train_loss' and 'val_loss'.
        save_dir (str | Path): Directory to save the plot image.
    """
    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])

    if not train_loss and not val_loss:
        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
        return

    fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=DPI_value)

    # --- Plot Losses (Left Y-axis) ---
    line_handles = [] # To store line objects for the legend

    # Plot training loss only if data for it exists
    if train_loss:
        epochs = range(1, len(train_loss) + 1)
        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
        line_handles.append(line1)

    # Plot validation loss only if data for it exists
    if val_loss:
        epochs = range(1, len(val_loss) + 1)
        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
        line_handles.append(line2)

    ax.set_title('Training and Validation Loss', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE + 2, pad=_EvaluationConfig.LABEL_PADDING)
    ax.set_xlabel('Epochs', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
    ax.set_ylabel('Loss', color='tab:blue', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
    ax.tick_params(axis='y', labelcolor='tab:blue', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
    ax.tick_params(axis='x', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
    ax.grid(True, linestyle='--')

    # --- Plot Learning Rate (Right Y-axis) ---
    if lr_history:
        ax2 = ax.twinx() # Create a second y-axis
        epochs = range(1, len(lr_history) + 1)
        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
        line_handles.append(line3)

        ax2.set_ylabel('Learning Rate', color='g', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
        ax2.tick_params(axis='y', labelcolor='g', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
        # Use scientific notation if the LR is very small
        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
        # increase the size of the scientific notation
        ax2.yaxis.get_offset_text().set_fontsize(_EvaluationConfig.LOSS_PLOT_TICK_SIZE - 2)
        # remove grid from second y-axis
        ax2.grid(False)

    # Combine legends from both axes
    ax.legend(handles=line_handles, loc='best', fontsize=_EvaluationConfig.LOSS_PLOT_LEGEND_SIZE)

    # ax.grid(True)
    plt.tight_layout()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    save_path = save_dir_path / "loss_plot.svg"
    plt.savefig(save_path)
    _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")

    plt.close(fig)


def classification_metrics(save_dir: Union[str, Path],
                           y_true: np.ndarray,
                           y_pred: np.ndarray,
                           y_prob: Optional[np.ndarray] = None,
                           class_map: Optional[dict[str,int]] = None,
                           config: Optional[Union[BinaryClassificationMetricsFormat,
                                                  MultiClassClassificationMetricsFormat,
                                                  BinaryImageClassificationMetricsFormat,
                                                  MultiClassImageClassificationMetricsFormat]] = None):
    """
    Saves classification metrics and plots.

    Args:
        y_true (np.ndarray): Ground truth labels.
        y_pred (np.ndarray): Predicted labels.
        y_prob (np.ndarray): Predicted probabilities for ROC curve.
        config (object): Formatting configuration object.
        save_dir (str | Path): Directory to save plots.
    """
    # --- Parse Config or use defaults ---
    if config is None:
        # Create a default config if one wasn't provided
        format_config = _BaseClassificationFormat()
    else:
        format_config = config

    # original_rc_params = plt.rcParams.copy()
    # plt.rcParams.update({'font.size': format_config.font_size})

    # --- Set Font Sizes ---
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    legend_size = format_config.legend_size

    # config font size for heatmap
    cm_font_size = format_config.cm_font_size
    cm_tick_size = cm_font_size - 4

    # --- Parse class_map ---
    map_labels = None
    map_display_labels = None
    if class_map:
        # Sort the map by its values (the indices) to ensure correct order
        try:
            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
            map_labels = [item[1] for item in sorted_items]
            map_display_labels = [item[0] for item in sorted_items]
        except Exception as e:
            _LOGGER.warning(f"Could not parse 'class_map': {e}")
            map_labels = None
            map_display_labels = None

    # Generate report as both text and dictionary
    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
    # print(report_text)

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    # Save text report
    report_path = save_dir_path / "classification_report.txt"
    report_path.write_text(report_text, encoding="utf-8")
    _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")

    # --- Save Classification Report Heatmap ---
    try:
        # Create DataFrame from report
        report_df = pd.DataFrame(report_dict)

        # 1. Robust Cleanup: Drop by name, not position
        # Remove 'accuracy' column if it exists (handles the scalar value issue)
        report_df = report_df.drop(columns=['accuracy'], errors='ignore')

        # Remove 'support' row explicitly (safer than iloc[:-1])
        if 'support' in report_df.index:
            report_df = report_df.drop(index='support')

        # 2. Transpose: Rows = Classes, Cols = Metrics
        plot_df = report_df.T

        # 3. Dynamic Height Calculation
        # (Base height of 4 + 0.5 inches per class row)
        fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
        fig_width = 8.0 # Set a fixed width

        # --- Use calculated dimensions, not the config constant ---
        fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)

        # sns.set_theme(font_scale=1.4)
        sns.heatmap(plot_df,
                    annot=True,
                    cmap=format_config.cmap,
                    fmt='.2f',
                    vmin=0.0,
                    vmax=1.0,
                    cbar_kws={'shrink': 0.9}) # Shrink colorbar slightly to fit better

        # sns.set_theme(font_scale=1.0)

        ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)

        # manually increase the font size of the elements
        for text in ax_heat.texts:
            text.set_fontsize(cm_tick_size)

        # manually increase the size of the colorbar ticks
        cbar = ax_heat.collections[0].colorbar
        cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore

        # Update Ticks
        ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
        ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0) # Ensure Y labels are horizontal

        plt.tight_layout()

        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
        plt.savefig(heatmap_path)
        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
        plt.close(fig_heat)

    except Exception as e:
        _LOGGER.error(f"Could not generate classification report heatmap: {e}")

    # --- labels for Confusion Matrix ---
    plot_labels = map_labels
    plot_display_labels = map_display_labels

    # 1. DYNAMIC SIZE CALCULATION
    # Calculate figure size based on number of classes.
    n_classes = len(plot_labels) if plot_labels is not None else len(np.unique(y_true))
    # Ensure a minimum size so very small matrices aren't tiny
    fig_w = max(9, n_classes * 0.8 + 3)
    fig_h = max(8, n_classes * 0.8 + 2)

    # Use the calculated size instead of CLASSIFICATION_PLOT_SIZE
    fig_cm, ax_cm = plt.subplots(figsize=(fig_w, fig_h), dpi=DPI_value)
    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
                                                    y_pred,
                                                    cmap=format_config.cmap,
                                                    ax=ax_cm,
                                                    normalize='true',
                                                    labels=plot_labels,
                                                    display_labels=plot_display_labels,
                                                    colorbar=False)

    disp_.im_.set_clim(vmin=0.0, vmax=1.0)

    # Turn off gridlines
    ax_cm.grid(False)

    # 2. CHECK FOR FONT CLASH
    # If matrix is huge, force text smaller. If small, allow user config.
    final_font_size = cm_font_size + 2
    if n_classes > 2:
        final_font_size = cm_font_size - n_classes # Decrease font size for larger matrices

    for text in ax_cm.texts:
        text.set_fontsize(final_font_size)

    # Update Ticks for Confusion Matrix
    ax_cm.tick_params(axis='x', labelsize=cm_tick_size)
    ax_cm.tick_params(axis='y', labelsize=cm_tick_size)

    # if more than 3 classes, rotate x ticks
    if n_classes > 3:
        plt.setp(ax_cm.get_xticklabels(), rotation=45, ha='right', rotation_mode="anchor")

    # Set titles and labels with padding
    ax_cm.set_title("Confusion Matrix", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size + 2)
    ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
    ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)

    # --- ADJUST COLORBAR FONT & SIZE ---
    # Manually add the colorbar with the 'shrink' parameter
    cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)

    # Update the tick size on the new cbar object
    cbar.ax.tick_params(labelsize=cm_tick_size)

    # (Optional) add a label to the bar itself (e.g. "Probability")
    # cbar.set_label('Probability', fontsize=12)

    fig_cm.tight_layout()

    cm_path = save_dir_path / "confusion_matrix.svg"
    plt.savefig(cm_path)
    _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
    plt.close(fig_cm)


    # Plotting logic for ROC, PR, and Calibration Curves
    if y_prob is not None and y_prob.ndim == 2:
        num_classes = y_prob.shape[1]

        # --- Determine which classes to loop over ---
        class_indices_to_plot = []
        plot_titles = []
        save_suffixes = []

        if num_classes == 2:
            # Binary case: Only plot for the positive class (index 1)
            class_indices_to_plot = [1]
            plot_titles = [""] # No extra title
            save_suffixes = [""] # No extra suffix
            _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")

        elif num_classes > 2:
            _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
            # Multiclass case: Plot for every class (One-vs-Rest)
            class_indices_to_plot = list(range(num_classes))

            # --- Use class_map names if available ---
            use_generic_names = True
            if map_display_labels and len(map_display_labels) == num_classes:
                try:
                    # Ensure labels are safe for filenames
                    safe_names = [sanitize_filename(name) for name in map_display_labels]
                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
                    use_generic_names = False
                except Exception as e:
                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
                    use_generic_names = True

            if use_generic_names:
                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]

        else:
            # Should not happen, but good to check
            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")

        # --- Loop and generate plots ---
        for i, class_index in enumerate(class_indices_to_plot):
            plot_title = plot_titles[i]
            save_suffix = save_suffixes[i]

            # Get scores for the current class
            y_score = y_prob[:, class_index]

            # Binarize y_true for the current class
            y_true_binary = (y_true == class_index).astype(int)

            # --- Save ROC Curve ---
            fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)

            try:
                # Calculate Youden's J statistic (tpr - fpr)
                J = tpr - fpr
                # Find the index of the best threshold
                best_index = np.argmax(J)
                optimal_threshold = thresholds[best_index]

                # Define the filename
                threshold_filename = f"best_threshold{save_suffix}.txt"
                threshold_path = save_dir_path / threshold_filename

                # Get the class name for the report
                class_name = ""
                # Check if we have display labels and the current index is valid
                if map_display_labels and class_index < len(map_display_labels):
                    class_name = map_display_labels[class_index]
                    if num_classes > 2:
                        # Add 'vs. Rest' for multiclass one-vs-rest plots
                        class_name += " (vs. Rest)"
                else:
                    # Fallback to the generic title or default binary name
                    class_name = plot_title.strip() or "Binary Positive Class"

                # Create content for the file
                file_content = (
                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
                    f"Class: {class_name}\n"
                    f"--------------------------------------------------\n"
                    f"Threshold: {optimal_threshold:.6f}\n"
                    f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
                    f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
                )

                threshold_path.write_text(file_content, encoding="utf-8")
                _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")

            except Exception as e:
                _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")

            # Calculate AUC.
            auc = roc_auc_score(y_true_binary, y_score)

            fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
            ax_roc.plot([0, 1], [0, 1], 'k--')
            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
            ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
            ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)

            # Apply Ticks and Legend sizing
            ax_roc.tick_params(axis='x', labelsize=xtick_size)
            ax_roc.tick_params(axis='y', labelsize=ytick_size)
            ax_roc.legend(loc='lower right', fontsize=legend_size)

            ax_roc.grid(True)
            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"

            plt.tight_layout()

            plt.savefig(roc_path)
            plt.close(fig_roc)

            # --- Save Precision-Recall Curve ---
            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
            ap_score = average_precision_score(y_true_binary, y_score)
            fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
            ax_pr.set_title(f'Precision-Recall Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
            ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
            ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)

            # Apply Ticks and Legend sizing
            ax_pr.tick_params(axis='x', labelsize=xtick_size)
            ax_pr.tick_params(axis='y', labelsize=ytick_size)
            ax_pr.legend(loc='lower left', fontsize=legend_size)

            ax_pr.grid(True)
            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"

            plt.tight_layout()

            plt.savefig(pr_path)
            plt.close(fig_pr)

            # --- Save Calibration Plot ---
            fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)

            # --- Step 1: Get binned data *without* plotting ---
            with plt.ioff(): # Suppress showing the temporary plot
                fig_temp, ax_temp = plt.subplots()
                cal_display_temp = CalibrationDisplay.from_predictions(
                    y_true_binary, # Use binarized labels
                    y_score,
                    n_bins=format_config.calibration_bins,
                    ax=ax_temp,
                    name="temp" # Add a name to suppress potential warnings
                )
                # Get the x, y coordinates of the binned data
                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
                plt.close(fig_temp) # Close the temporary plot

            # --- Step 2: Build the plot from scratch ---
            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')

            sns.regplot(
                x=line_x,
                y=line_y,
                ax=ax_cal,
                scatter=False,
                label=f"Model calibration",
                line_kws={
                    'color': format_config.ROC_PR_line,
                    'linestyle': '--',
                    'linewidth': 2,
                }
            )

            ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
            ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
            ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)

            # --- Step 3: Set final limits *after* plotting ---
            ax_cal.set_ylim(0.0, 1.0)
            ax_cal.set_xlim(0.0, 1.0)

            # Apply Ticks and Legend sizing
            ax_cal.tick_params(axis='x', labelsize=xtick_size)
            ax_cal.tick_params(axis='y', labelsize=ytick_size)
            ax_cal.legend(loc='lower right', fontsize=legend_size)

            ax_cal.grid(True)
            plt.tight_layout()

            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
            plt.savefig(cal_path)
            plt.close(fig_cal)

        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")

    # restore RC params
    # plt.rcParams.update(original_rc_params)


def regression_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    save_dir: Union[str, Path],
    config: Optional[RegressionMetricsFormat] = None
):
    """
    Saves regression metrics and plots.

    Args:
        y_true (np.ndarray): Ground truth values.
        y_pred (np.ndarray): Predicted values.
        save_dir (str | Path): Directory to save plots and report.
        config (RegressionMetricsFormat, optional): Formatting configuration object.
    """

    # --- Parse Config or use defaults ---
    if config is None:
        # Create a default config if one wasn't provided
        format_config = _BaseRegressionFormat()
    else:
        format_config = config

    # --- Set Matplotlib font size ---
    # original_rc_params = plt.rcParams.copy()
    # plt.rcParams.update({'font.size': format_config.font_size})

    # --- Resolve Font Sizes ---
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    base_font_size = format_config.font_size

    # --- Calculate Metrics ---
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    report_lines = [
        "--- Regression Report ---",
        f" Root Mean Squared Error (RMSE): {rmse:.4f}",
        f" Mean Absolute Error (MAE): {mae:.4f}",
        f" Median Absolute Error (MedAE): {medae:.4f}",
        f" Coefficient of Determination (R²): {r2:.4f}"
    ]
    report_string = "\n".join(report_lines)
    # print(report_string)

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    # Save text report
    report_path = save_dir_path / "regression_report.txt"
    report_path.write_text(report_string)
    _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")

    # --- Save residual plot ---
    residuals = y_true - y_pred
    fig_res, ax_res = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
    ax_res.scatter(y_pred, residuals,
                   alpha=format_config.scatter_alpha,
                   color=format_config.scatter_color)
    ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
    ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
    ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
    ax_res.set_title("Residual Plot", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)

    # Apply Ticks
    ax_res.tick_params(axis='x', labelsize=xtick_size)
    ax_res.tick_params(axis='y', labelsize=ytick_size)

    ax_res.grid(True)
    plt.tight_layout()
    res_path = save_dir_path / "residual_plot.svg"
    plt.savefig(res_path)
    _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
    plt.close(fig_res)

    # --- Save true vs predicted plot ---
    fig_tvp, ax_tvp = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
    ax_tvp.scatter(y_true, y_pred,
                   alpha=format_config.scatter_alpha,
                   color=format_config.scatter_color)
    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
                linestyle='--',
                lw=2,
                color=format_config.ideal_line_color)
    ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
    ax_tvp.set_ylabel('Predictions', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
    ax_tvp.set_title('True vs. Predicted Values', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)

    # Apply Ticks
    ax_tvp.tick_params(axis='x', labelsize=xtick_size)
    ax_tvp.tick_params(axis='y', labelsize=ytick_size)

    ax_tvp.grid(True)
    plt.tight_layout()
    tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
    plt.savefig(tvp_path)
    _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
    plt.close(fig_tvp)

    # --- Save Histogram of Residuals ---
    fig_hist, ax_hist = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
    sns.histplot(residuals, kde=True, ax=ax_hist,
                 bins=format_config.hist_bins,
                 color=format_config.scatter_color)
    ax_hist.set_xlabel("Residual Value", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
    ax_hist.set_ylabel("Frequency", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
    ax_hist.set_title("Distribution of Residuals", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)

    # Apply Ticks
    ax_hist.tick_params(axis='x', labelsize=xtick_size)
    ax_hist.tick_params(axis='y', labelsize=ytick_size)

    ax_hist.grid(True)
    plt.tight_layout()
    hist_path = save_dir_path / "residuals_histogram.svg"
    plt.savefig(hist_path)
    _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
    plt.close(fig_hist)

    # --- Restore RC params ---
    # plt.rcParams.update(original_rc_params)


def shap_summary_plot(model,
                      background_data: Union[torch.Tensor,np.ndarray],
                      instances_to_explain: Union[torch.Tensor,np.ndarray],
                      feature_names: Optional[list[str]],
                      save_dir: Union[str, Path],
                      device: torch.device = torch.device('cpu'),
                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
    """
    Calculates SHAP values and saves summary plots and data.

    Args:
        model (nn.Module): The trained PyTorch model.
        background_data (torch.Tensor): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor): The specific data instances to explain.
        feature_names (list of str | None): Names of the features for plot labeling.
        save_dir (str | Path): Directory to save SHAP artifacts.
        device (torch.device): The torch device for SHAP calculations.
        explainer_type (Literal['deep', 'kernel']): The explainer to use.
            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
              PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
              slow and memory-intensive.
    """

    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")

    model.eval()
    # model.cpu() # Run explanations on CPU

    shap_values = None
    instances_to_explain_np = None

    if explainer_type == 'deep':
        # --- 1. Use DeepExplainer ---

        # Ensure data is torch.Tensor
        if isinstance(background_data, np.ndarray):
            background_data = torch.from_numpy(background_data).float()
        if isinstance(instances_to_explain, np.ndarray):
            instances_to_explain = torch.from_numpy(instances_to_explain).float()

        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        background_data = background_data.to(device)
        instances_to_explain = instances_to_explain.to(device)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            explainer = shap.DeepExplainer(model, background_data)

            # print("Calculating SHAP values with DeepExplainer...")
            shap_values = explainer.shap_values(instances_to_explain)
            instances_to_explain_np = instances_to_explain.cpu().numpy()

    elif explainer_type == 'kernel':
        # --- 2. Use KernelExplainer ---
        _LOGGER.warning(
            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )

        # Ensure data is np.ndarray
        if isinstance(background_data, torch.Tensor):
            background_data_np = background_data.cpu().numpy()
        else:
            background_data_np = background_data

        if isinstance(instances_to_explain, torch.Tensor):
            instances_to_explain_np = instances_to_explain.cpu().numpy()
        else:
            instances_to_explain_np = instances_to_explain

        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        # Summarize background data
        background_summary = shap.kmeans(background_data_np, 30)

        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
            x_torch = torch.from_numpy(x_np).float().to(device)
            with torch.no_grad():
                output = model(x_torch)
            # Return as numpy array
            return output.cpu().numpy()

        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
        # print("Calculating SHAP values with KernelExplainer...")
        shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
        # instances_to_explain_np is already set

    else:
        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
        raise ValueError()

    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
        shap_values = shap_values.squeeze(-1) # type: ignore

    # --- 3. Plotting and Saving ---
    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    plt.ioff()

    # Convert instances to a DataFrame. Robust way to ensure SHAP correctly maps values to feature names.
    if feature_names is None:
        # Create generic names if none were provided
        num_features = instances_to_explain_np.shape[1]
        feature_names = [f'feature_{i}' for i in range(num_features)]

    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)

    # Save Bar Plot
    bar_path = save_dir_path / "shap_bar_plot.svg"
    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
    ax = plt.gca()
    ax.set_xlabel("SHAP Value Impact", labelpad=10)
    plt.title("SHAP Feature Importance")
    plt.tight_layout()
    plt.savefig(bar_path)
    _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
    plt.close()

    # Save Dot Plot
    dot_path = save_dir_path / "shap_dot_plot.svg"
    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
    ax = plt.gca()
    ax.set_xlabel("SHAP Value Impact", labelpad=10)
    if plt.gcf().axes and len(plt.gcf().axes) > 1:
        cb = plt.gcf().axes[-1]
        cb.set_ylabel("", size=1)
    plt.title("SHAP Feature Importance")
    plt.tight_layout()
    plt.savefig(dot_path)
    _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
    plt.close()

    # Save Summary Data to CSV
    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
    summary_path = save_dir_path / shap_summary_filename

    # Handle multi-class (list of arrays) vs. regression (single array)
    if isinstance(shap_values, list):
        mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
    else:
        mean_abs_shap = np.abs(shap_values).mean(axis=0)

    mean_abs_shap = mean_abs_shap.flatten()

    summary_df = pd.DataFrame({
        SHAPKeys.FEATURE_COLUMN: feature_names,
        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

    summary_df.to_csv(summary_path, index=False)

    _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
    plt.ion()


def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
    """
    Aggregates attention weights and plots global feature importance.

    The plot shows the mean attention for each feature as a bar, with the
    standard deviation represented by error bars.

    Args:
        weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
        feature_names (List[str] | None): Names of the features for plot labeling.
        save_dir (str | Path): Directory to save the plot and summary CSV.
        top_n (int): The number of top features to display in the plot.
    """
    if not weights:
        _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
        return

    # --- Step 1: Aggregate data ---
    # Concatenate the list of tensors into a single large tensor
    full_weights_tensor = torch.cat(weights, dim=0)

    # Calculate mean and std dev across the batch dimension (dim=0)
    mean_weights = full_weights_tensor.mean(dim=0)
    std_weights = full_weights_tensor.std(dim=0)

    # --- Step 2: Create and save summary DataFrame ---
    if feature_names is None:
        feature_names = [f'feature_{i}' for i in range(len(mean_weights))]

    summary_df = pd.DataFrame({
        'feature': feature_names,
        'mean_attention': mean_weights.numpy(),
        'std_attention': std_weights.numpy()
    }).sort_values('mean_attention', ascending=False)

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    summary_path = save_dir_path / "attention_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")

    # --- Step 3: Create and save the plot for top N features ---
    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)

    plt.figure(figsize=(10, 8), dpi=DPI_value)

    # Create horizontal bar plot with error bars
    plt.barh(
        y=plot_df['feature'],
        width=plot_df['mean_attention'],
        xerr=plot_df['std_attention'],
        align='center',
        alpha=0.7,
        ecolor='grey',
        capsize=3,
        color='cornflowerblue'
    )

    plt.title('Top Features by Attention')
    plt.xlabel('Average Attention Weight')
    plt.ylabel('Feature')
    plt.grid(axis='x', linestyle='--', alpha=0.6)
    plt.tight_layout()

    plot_path = save_dir_path / "attention_importance.svg"
    plt.savefig(plot_path)
    _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
    plt.close()


def info():
    _script_info(__all__)