dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,629 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import seaborn as sns
|
|
5
|
+
from sklearn.calibration import CalibrationDisplay
|
|
6
|
+
from sklearn.metrics import (
|
|
7
|
+
classification_report,
|
|
8
|
+
ConfusionMatrixDisplay,
|
|
9
|
+
roc_curve,
|
|
10
|
+
roc_auc_score,
|
|
11
|
+
precision_recall_curve,
|
|
12
|
+
average_precision_score,
|
|
13
|
+
hamming_loss,
|
|
14
|
+
jaccard_score
|
|
15
|
+
)
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Union, Optional
|
|
18
|
+
|
|
19
|
+
from ..ML_configuration._metrics import (_BaseMultiLabelFormat,
|
|
20
|
+
_BaseClassificationFormat,
|
|
21
|
+
FormatBinaryClassificationMetrics,
|
|
22
|
+
FormatMultiClassClassificationMetrics,
|
|
23
|
+
FormatBinaryImageClassificationMetrics,
|
|
24
|
+
FormatMultiClassImageClassificationMetrics,
|
|
25
|
+
FormatMultiLabelBinaryClassificationMetrics)
|
|
26
|
+
|
|
27
|
+
from ..path_manager import make_fullpath, sanitize_filename
|
|
28
|
+
from .._core import get_logger
|
|
29
|
+
from ..keys._keys import _EvaluationConfig
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Logger for this module; messages are tagged "Classification Metrics".
_LOGGER = get_logger("Classification Metrics")


# Names exported by `from ... import *`.
__all__ = ["classification_metrics", "multi_label_classification_metrics"]


# Figure resolution and default plot size, both taken from the shared evaluation config.
DPI_value = _EvaluationConfig.DPI
CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _clf_parse_class_map(
    class_map: Optional[dict[str, int]],
) -> tuple[Optional[list[int]], Optional[list[str]]]:
    """
    Split a ``{display_name: class_index}`` mapping into index-ordered lists.

    Args:
        class_map (dict[str, int] | None): Mapping of class display names to class indices.

    Returns:
        tuple: ``(labels, display_labels)`` sorted by class index, so that
        ``labels[i]`` and ``display_labels[i]`` describe the same class.
        ``(None, None)`` when no map is given or it cannot be parsed.
    """
    if not class_map:
        return None, None
    try:
        # Sort by the index so label order matches the model's output columns.
        sorted_items = sorted(class_map.items(), key=lambda item: item[1])
    except Exception as e:
        _LOGGER.warning(f"Could not parse 'class_map': {e}")
        return None, None
    labels = [index for _, index in sorted_items]
    display_labels = [name for name, _ in sorted_items]
    return labels, display_labels


def _clf_save_report_heatmap(report_dict: dict,
                             save_dir_path: Path,
                             format_config: "_BaseClassificationFormat",
                             cm_font_size,
                             cm_tick_size) -> None:
    """
    Save the classification report (precision / recall / f1 per row) as an annotated heatmap SVG.

    Best-effort: any failure is logged and swallowed. The figure is always closed.
    """
    fig_heat = None
    try:
        report_df = pd.DataFrame(report_dict)

        # Drop by name, not position: 'accuracy' is a scalar entry of the report dict,
        # and the 'support' row holds counts rather than [0, 1] scores.
        report_df = report_df.drop(columns=['accuracy'], errors='ignore')
        if 'support' in report_df.index:
            report_df = report_df.drop(index='support')

        # Transpose: rows = classes / averages, columns = metrics.
        plot_df = report_df.T

        # Fixed width; height grows 0.5 in per row on a 4 in base (minimum 5 in).
        fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
        fig_width = 8.0
        fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)

        sns.heatmap(plot_df,
                    annot=True,
                    cmap=format_config.cmap,
                    fmt='.2f',
                    vmin=0.0,
                    vmax=1.0,
                    cbar_kws={'shrink': 0.9},  # shrink colorbar slightly to fit better
                    ax=ax_heat)

        ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)

        # Resize the cell annotations.
        for text in ax_heat.texts:
            text.set_fontsize(cm_tick_size)

        # Resize the colorbar tick labels.
        cbar = ax_heat.collections[0].colorbar
        cbar.ax.tick_params(labelsize=cm_tick_size - 4)  # type: ignore

        ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
        # Keep the y (class) labels horizontal.
        ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0)

        fig_heat.tight_layout()

        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
        fig_heat.savefig(heatmap_path)
        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")

    except Exception as e:
        _LOGGER.error(f"Could not generate classification report heatmap: {e}")
    finally:
        # Close the figure on both the success and the failure path.
        if fig_heat is not None:
            plt.close(fig_heat)


def _clf_save_confusion_matrix(y_true: np.ndarray,
                               y_pred: np.ndarray,
                               labels: Optional[list[int]],
                               display_labels: Optional[list[str]],
                               save_dir_path: Path,
                               format_config: "_BaseClassificationFormat",
                               cm_font_size,
                               cm_tick_size) -> None:
    """
    Save a row-normalized (``normalize='true'``) confusion matrix SVG.

    The figure size and cell font size scale with the number of classes.
    Exceptions propagate, but the figure is always closed.
    """
    # Size the figure from the class count, with a floor so small matrices aren't tiny.
    n_classes = len(labels) if labels is not None else len(np.unique(y_true))
    fig_w = max(9, n_classes * 0.8 + 3)
    fig_h = max(8, n_classes * 0.8 + 2)

    fig_cm, ax_cm = plt.subplots(figsize=(fig_w, fig_h), dpi=DPI_value)
    try:
        disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
                                                        y_pred,
                                                        cmap=format_config.cmap,
                                                        ax=ax_cm,
                                                        normalize='true',
                                                        labels=labels,
                                                        display_labels=display_labels,
                                                        colorbar=False)

        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
        ax_cm.grid(False)

        # Shrink the cell text as the matrix grows so the numbers don't collide.
        if n_classes > 2:
            final_font_size = cm_font_size - n_classes
        else:
            final_font_size = cm_font_size + 2
        for text in ax_cm.texts:
            text.set_fontsize(final_font_size)

        ax_cm.tick_params(axis='x', labelsize=cm_tick_size)
        ax_cm.tick_params(axis='y', labelsize=cm_tick_size)

        # Rotate the x tick labels when there are more than 3 classes.
        if n_classes > 3:
            plt.setp(ax_cm.get_xticklabels(), rotation=45, ha='right', rotation_mode="anchor")

        ax_cm.set_title("Confusion Matrix", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size + 2)
        ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
        ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)

        # Add the colorbar manually so it can be shrunk, then size its ticks.
        cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
        cbar.ax.tick_params(labelsize=cm_tick_size)

        fig_cm.tight_layout()

        cm_path = save_dir_path / "confusion_matrix.svg"
        fig_cm.savefig(cm_path)
        _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
    finally:
        plt.close(fig_cm)


def _clf_curve_targets(num_classes: int,
                       map_display_labels: Optional[list[str]]) -> tuple[list[int], list[str], list[str]]:
    """
    Decide which probability columns get ROC / PR / calibration plots.

    Args:
        num_classes (int): Number of probability columns (must be >= 2).
        map_display_labels (list[str] | None): Class names ordered by class index.

    Returns:
        tuple: ``(class_indices, plot_title_suffixes, filename_suffixes)``.
        Binary gets only the positive class (index 1), with no suffixes.
        Multiclass gets every class as one-vs-rest.
    """
    if num_classes == 2:
        _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
        return [1], [""], [""]

    _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
    class_indices = list(range(num_classes))

    # Prefer the class_map names when they cover every column.
    if map_display_labels and len(map_display_labels) == num_classes:
        try:
            # Make the names safe to use in filenames.
            safe_names = [sanitize_filename(name) for name in map_display_labels]
            plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
            save_suffixes = [f"_{safe_names[i]}" for i in class_indices]
            return class_indices, plot_titles, save_suffixes
        except Exception as e:
            _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")

    plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices]
    save_suffixes = [f"_class_{i}" for i in class_indices]
    return class_indices, plot_titles, save_suffixes


def _clf_save_optimal_threshold(fpr: np.ndarray,
                                tpr: np.ndarray,
                                thresholds: np.ndarray,
                                class_index: int,
                                num_classes: int,
                                plot_title: str,
                                save_suffix: str,
                                map_display_labels: Optional[list[str]],
                                save_dir_path: Path) -> None:
    """
    Write the threshold that maximizes Youden's J (TPR - FPR) to ``best_threshold{suffix}.txt``.

    Best-effort: any failure is logged as a warning and swallowed.
    """
    try:
        J = tpr - fpr
        best_index = np.argmax(J)
        optimal_threshold = thresholds[best_index]

        threshold_path = save_dir_path / f"best_threshold{save_suffix}.txt"

        # Human-readable class name for the report.
        if map_display_labels and class_index < len(map_display_labels):
            class_name = map_display_labels[class_index]
            if num_classes > 2:
                # Label multiclass one-vs-rest entries explicitly.
                class_name += " (vs. Rest)"
        else:
            # Fall back to the generic title, or the default binary name.
            class_name = plot_title.strip() or "Binary Positive Class"

        file_content = (
            f"Optimal Classification Threshold (Youden's J Statistic)\n"
            f"Class: {class_name}\n"
            f"--------------------------------------------------\n"
            f"Threshold: {optimal_threshold:.6f}\n"
            f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
            f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
        )

        threshold_path.write_text(file_content, encoding="utf-8")
        _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")

    except Exception as e:
        _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")


def _clf_save_curves(y_true_binary: np.ndarray,
                     y_score: np.ndarray,
                     fpr: np.ndarray,
                     tpr: np.ndarray,
                     plot_title: str,
                     save_suffix: str,
                     save_dir_path: Path,
                     format_config: "_BaseClassificationFormat") -> None:
    """
    Save the ROC, Precision-Recall and reliability (calibration) SVGs for one binarized target.

    ``y_true_binary`` must contain both 0s and 1s, otherwise ``roc_auc_score`` raises.
    Every figure is closed even if plotting raises.
    """
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    legend_size = format_config.legend_size

    # --- ROC curve ---
    auc = roc_auc_score(y_true_binary, y_score)
    fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
    try:
        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
        ax_roc.plot([0, 1], [0, 1], 'k--')
        ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
        ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
        ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
        ax_roc.tick_params(axis='x', labelsize=xtick_size)
        ax_roc.tick_params(axis='y', labelsize=ytick_size)
        ax_roc.legend(loc='lower right', fontsize=legend_size)
        ax_roc.grid(True)
        fig_roc.tight_layout()
        fig_roc.savefig(save_dir_path / f"roc_curve{save_suffix}.svg")
    finally:
        plt.close(fig_roc)

    # --- Precision-Recall curve ---
    precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
    ap_score = average_precision_score(y_true_binary, y_score)
    fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
    try:
        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
        ax_pr.set_title(f'Precision-Recall Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
        ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
        ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
        ax_pr.tick_params(axis='x', labelsize=xtick_size)
        ax_pr.tick_params(axis='y', labelsize=ytick_size)
        ax_pr.legend(loc='lower left', fontsize=legend_size)
        ax_pr.grid(True)
        fig_pr.tight_layout()
        fig_pr.savefig(save_dir_path / f"pr_curve{save_suffix}.svg")
    finally:
        plt.close(fig_pr)

    # --- Calibration (reliability) curve ---
    fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
    try:
        # Step 1: let CalibrationDisplay compute the bins on a throw-away, non-interactive figure.
        with plt.ioff():
            fig_temp, ax_temp = plt.subplots()
            try:
                cal_display_temp = CalibrationDisplay.from_predictions(
                    y_true_binary,
                    y_score,
                    n_bins=format_config.calibration_bins,
                    ax=ax_temp,
                    name="temp"  # a name avoids potential sklearn warnings
                )
                # (x, y) coordinates of the binned data.
                line_x, line_y = cal_display_temp.line_.get_data()  # type: ignore
            finally:
                plt.close(fig_temp)

        # Step 2: build the real plot from the binned points.
        ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
        sns.regplot(
            x=line_x,
            y=line_y,
            ax=ax_cal,
            scatter=False,
            label="Model calibration",
            line_kws={
                'color': format_config.ROC_PR_line,
                'linestyle': '--',
                'linewidth': 2,
            }
        )

        ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
        ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
        ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)

        # Step 3: fix the limits *after* plotting (regplot may have expanded them).
        ax_cal.set_ylim(0.0, 1.0)
        ax_cal.set_xlim(0.0, 1.0)

        ax_cal.tick_params(axis='x', labelsize=xtick_size)
        ax_cal.tick_params(axis='y', labelsize=ytick_size)
        ax_cal.legend(loc='lower right', fontsize=legend_size)
        ax_cal.grid(True)
        fig_cal.tight_layout()
        fig_cal.savefig(save_dir_path / f"calibration_plot{save_suffix}.svg")
    finally:
        plt.close(fig_cal)


def classification_metrics(save_dir: Union[str, Path],
                           y_true: np.ndarray,
                           y_pred: np.ndarray,
                           y_prob: Optional[np.ndarray] = None,
                           class_map: Optional[dict[str,int]] = None,
                           config: Optional[Union[FormatBinaryClassificationMetrics,
                                                  FormatMultiClassClassificationMetrics,
                                                  FormatBinaryImageClassificationMetrics,
                                                  FormatMultiClassImageClassificationMetrics]] = None):
    """
    Saves classification metrics and plots.

    Artifacts written to ``save_dir``:
      - ``classification_report.txt`` and ``classification_report_heatmap.svg``
      - ``confusion_matrix.svg`` (row-normalized)
      - if ``y_prob`` is 2D with at least 2 columns: ``roc_curve*.svg``,
        ``pr_curve*.svg``, ``calibration_plot*.svg`` and ``best_threshold*.txt``.
        Binary problems get these for the positive class (index 1); multiclass
        problems get them for every class, one-vs-rest.

    A class whose one-vs-rest ground truth contains a single value gets no
    curves or threshold file. ROC AUC is undefined for it, so the class is
    skipped with a warning instead of aborting the whole evaluation.

    Args:
        save_dir (str | Path): Directory to save plots.
        y_true (np.ndarray): Ground truth labels. Curve plots compare them
            against the column index of ``y_prob``, so they should be integer class indices.
        y_pred (np.ndarray): Predicted labels.
        y_prob (np.ndarray | None): Predicted probabilities, shape (n_samples, n_classes).
            Needed for the ROC / PR / calibration plots.
        class_map (dict[str, int] | None): Mapping of class display names to class indices.
            Used for report / matrix labels and plot titles.
        config (object | None): Formatting configuration object. Defaults are used when None.
    """
    # --- Parse config or use defaults ---
    format_config = _BaseClassificationFormat() if config is None else config

    # Font sizes for the heatmap / confusion matrix.
    cm_font_size = format_config.cm_font_size
    cm_tick_size = cm_font_size - 4

    y_true = np.asarray(y_true)
    map_labels, map_display_labels = _clf_parse_class_map(class_map)

    # Generate the report both as text and as a dictionary.
    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels)  # type: ignore
    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels)  # type: ignore

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    report_path = save_dir_path / "classification_report.txt"
    report_path.write_text(report_text, encoding="utf-8")
    _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")

    _clf_save_report_heatmap(report_dict, save_dir_path, format_config, cm_font_size, cm_tick_size)
    _clf_save_confusion_matrix(y_true, y_pred, map_labels, map_display_labels,
                               save_dir_path, format_config, cm_font_size, cm_tick_size)

    # --- ROC, PR and calibration curves (need probabilities) ---
    if y_prob is None:
        return
    y_prob = np.asarray(y_prob)
    if y_prob.ndim != 2 or y_prob.shape[1] < 2:
        _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
        return

    num_classes = y_prob.shape[1]
    class_indices, plot_titles, save_suffixes = _clf_curve_targets(num_classes, map_display_labels)

    saved_sets = 0
    for class_index, plot_title, save_suffix in zip(class_indices, plot_titles, save_suffixes):
        y_score = y_prob[:, class_index]
        # Binarize the ground truth for the current class (one-vs-rest).
        y_true_binary = (y_true == class_index).astype(int)

        # Skip degenerate targets: roc_auc_score raises ValueError when only one class is present.
        if np.unique(y_true_binary).size < 2:
            _LOGGER.warning(
                f"Class index {class_index}{plot_title}: ground truth contains only one class, "
                f"ROC AUC is undefined. Skipping ROC/PR/Calibration plots for it."
            )
            continue

        fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
        _clf_save_optimal_threshold(fpr, tpr, thresholds, class_index, num_classes,
                                    plot_title, save_suffix, map_display_labels, save_dir_path)
        _clf_save_curves(y_true_binary, y_score, fpr, tpr, plot_title, save_suffix,
                         save_dir_path, format_config)
        saved_sets += 1

    _LOGGER.info(f"📈 Saved {saved_sets} sets of ROC, Precision-Recall, and Calibration plots.")
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def multi_label_classification_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    target_names: list[str],
    save_dir: Union[str, Path],
    config: Optional[FormatMultiLabelBinaryClassificationMetrics] = None
):
    """
    Calculates and saves classification metrics for each label individually.

    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
    and then iterates through each label to generate and save individual reports,
    confusion matrices, ROC curves, an optimal-threshold file (Youden's J statistic),
    and Precision-Recall curves.

    Labels whose ground truth contains only one class (all 0 or all 1) have no defined
    ROC curve / ROC AUC. For those labels the ROC plot and threshold file are skipped
    with a warning, instead of aborting the evaluation of the remaining labels.

    Args:
        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
        target_names (list[str]): A list of names for the labels, one per column.
        save_dir (str | Path): Directory to save plots and reports (created if needed).
        config (FormatMultiLabelBinaryClassificationMetrics | None): Formatting
            configuration. If None, a default format object is used.

    Raises:
        ValueError: If any input is not 2D, the input shapes differ, or the number of
            target names does not match the number of columns in y_true.
    """
    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
        raise ValueError()
    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
        raise ValueError()
    if y_true.shape[1] != len(target_names):
        _LOGGER.error("Number of target names must match the number of columns in y_true.")
        raise ValueError()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # --- Parse Config or use defaults ---
    if config is None:
        # Create a default config if one wasn't provided
        format_config = _BaseMultiLabelFormat()
    else:
        format_config = config

    # y_pred is passed in directly, so no threshold is applied here.

    # ticks and legend font sizes
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    legend_size = format_config.legend_size
    base_font_size = format_config.font_size

    # --- Calculate and Save Overall Metrics (using y_pred) ---
    h_loss = hamming_loss(y_true, y_pred)
    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
    j_score_macro = jaccard_score(y_true, y_pred, average='macro')

    overall_report = (
        f"Overall Multi-Label Metrics:\n"  # No threshold to report here
        f"--------------------------------------------------\n"
        f"Hamming Loss: {h_loss:.4f}\n"
        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
        f"--------------------------------------------------\n"
    )
    overall_report_path = save_dir_path / "classification_report.txt"
    overall_report_path.write_text(overall_report, encoding="utf-8")

    # --- Per-Label Metrics and Plots ---
    for i, name in enumerate(target_names):
        print(f"  -> Evaluating label: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]  # Use passed-in y_pred
        prob_i = y_prob[:, i]  # Use passed-in y_prob
        sanitized_name = sanitize_filename(name)

        # --- Save Classification Report for the label (uses y_pred) ---
        report_text = classification_report(true_i, pred_i)
        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
        report_path.write_text(report_text, encoding="utf-8")  # type: ignore

        # --- Save Confusion Matrix (uses y_pred) ---
        fig_cm, ax_cm = plt.subplots(figsize=_EvaluationConfig.CM_SIZE, dpi=_EvaluationConfig.DPI)
        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                        pred_i,
                                                        cmap=format_config.cmap,  # Use config cmap
                                                        ax=ax_cm,
                                                        normalize='true',
                                                        labels=[0, 1],
                                                        display_labels=["Negative", "Positive"],
                                                        colorbar=False)

        # Row-normalized values live in [0, 1]; pin the color scale so plots are comparable.
        disp_.im_.set_clim(vmin=0.0, vmax=1.0)

        # Turn off gridlines
        ax_cm.grid(False)

        # Manually update font size of cell texts
        for text in ax_cm.texts:
            text.set_fontsize(base_font_size + 2)  # Use config font_size

        # Apply ticks
        ax_cm.tick_params(axis='x', labelsize=xtick_size)
        ax_cm.tick_params(axis='y', labelsize=ytick_size)

        # Set titles and labels with padding
        ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        # --- ADJUST COLORBAR FONT & SIZE---
        # Manually add the colorbar with the 'shrink' parameter
        cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)

        # Update the tick size on the new cbar object
        cbar.ax.tick_params(labelsize=ytick_size)  # type: ignore

        fig_cm.tight_layout()

        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
        fig_cm.savefig(cm_path)
        plt.close(fig_cm)

        # --- Save ROC Curve and optimal threshold (uses y_prob) ---
        # ROC (and therefore ROC AUC and Youden's J) is undefined when y_true holds a
        # single class: roc_auc_score raises ValueError and roc_curve yields NaN rates.
        # Skip this label's ROC artifacts instead of aborting the remaining labels.
        if np.unique(true_i).size < 2:
            _LOGGER.warning(f"Only one class present in y_true for '{name}'. Skipping ROC curve and optimal threshold.")
        else:
            fpr, tpr, thresholds = roc_curve(true_i, prob_i)

            try:
                # Calculate Youden's J statistic (tpr - fpr)
                J = tpr - fpr
                # Find the index of the best threshold
                best_index = np.argmax(J)
                optimal_threshold = thresholds[best_index]
                best_tpr = tpr[best_index]
                best_fpr = fpr[best_index]

                # Define the filename
                threshold_filename = f"best_threshold_{sanitized_name}.txt"
                threshold_path = save_dir_path / threshold_filename

                # The class name is the target_name for this label
                class_name = name

                # Create content for the file
                file_content = (
                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
                    f"Class/Label: {class_name}\n"
                    f"--------------------------------------------------\n"
                    f"Threshold: {optimal_threshold:.6f}\n"
                    f"True Positive Rate (TPR): {best_tpr:.6f}\n"
                    f"False Positive Rate (FPR): {best_fpr:.6f}\n"
                )

                threshold_path.write_text(file_content, encoding="utf-8")
                _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")

            except Exception as e:
                # Best-effort: a failure here must not prevent the plots from being saved.
                _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")

            auc = roc_auc_score(true_i, prob_i)
            fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)  # Use config color
            ax_roc.plot([0, 1], [0, 1], 'k--')

            ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
            ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
            ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

            # Apply ticks and legend font size
            ax_roc.tick_params(axis='x', labelsize=xtick_size)
            ax_roc.tick_params(axis='y', labelsize=ytick_size)
            ax_roc.legend(loc='lower right', fontsize=legend_size)

            ax_roc.grid(True, linestyle='--', alpha=0.6)

            fig_roc.tight_layout()

            roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
            fig_roc.savefig(roc_path)
            plt.close(fig_roc)

        # --- Save Precision-Recall Curve (uses y_prob) ---
        precision, recall, _ = precision_recall_curve(true_i, prob_i)
        ap_score = average_precision_score(true_i, prob_i)
        fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line)  # Use config color
        ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        # Apply ticks and legend font size
        ax_pr.tick_params(axis='x', labelsize=xtick_size)
        ax_pr.tick_params(axis='y', labelsize=ytick_size)
        ax_pr.legend(loc='lower left', fontsize=legend_size)

        ax_pr.grid(True, linestyle='--', alpha=0.6)

        fig_pr.tight_layout()

        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
        fig_pr.savefig(pr_path)
        plt.close(fig_pr)

    _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
|
|
629
|
+
|