dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
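Two patterns account for nearly all of the churn above: every flat module (e.g. `ml_tools/IO_tools.py`) becomes a package rooted at `<name>/__init__.py` with private `_*.py` implementation submodules plus a per-package `_imprimir.py`, and the old `ml_tools/_core` monolith is deleted. Since renaming `X.py` to `X/__init__.py` preserves the dotted import path, public imports should keep resolving across the upgrade. A minimal sketch, assuming `ML_evaluation/__init__.py` re-exports the names listed in the `__all__` of the `_classification.py` diff below:

# Sketch, not taken from the diff: the import assumes the package
# __init__.py re-exports its public API, as the module -> package
# renames above suggest.
from ml_tools.ML_evaluation import classification_metrics

# In 20.0.0 the implementation lives in a private submodule:
print(classification_metrics.__module__)  # e.g. ml_tools.ML_evaluation._classification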
ml_tools/ML_evaluation/_classification.py (new file)
@@ -0,0 +1,629 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.calibration import CalibrationDisplay
+from sklearn.metrics import (
+    classification_report,
+    ConfusionMatrixDisplay,
+    roc_curve,
+    roc_auc_score,
+    precision_recall_curve,
+    average_precision_score,
+    hamming_loss,
+    jaccard_score
+)
+from pathlib import Path
+from typing import Union, Optional
+
+from ..ML_configuration._metrics import (_BaseMultiLabelFormat,
+                                         _BaseClassificationFormat,
+                                         FormatBinaryClassificationMetrics,
+                                         FormatMultiClassClassificationMetrics,
+                                         FormatBinaryImageClassificationMetrics,
+                                         FormatMultiClassImageClassificationMetrics,
+                                         FormatMultiLabelBinaryClassificationMetrics)
+
+from ..path_manager import make_fullpath, sanitize_filename
+from .._core import get_logger
+from ..keys._keys import _EvaluationConfig
+
+
+_LOGGER = get_logger("Classification Metrics")
+
+
+__all__ = [
+    "classification_metrics",
+    "multi_label_classification_metrics",
+]
+
+
+DPI_value = _EvaluationConfig.DPI
+CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
+
+
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           class_map: Optional[dict[str,int]] = None,
+                           config: Optional[Union[FormatBinaryClassificationMetrics,
+                                                  FormatMultiClassClassificationMetrics,
+                                                  FormatBinaryImageClassificationMetrics,
+                                                  FormatMultiClassImageClassificationMetrics]] = None):
+    """
+    Saves classification metrics and plots.
+
+    Args:
+        y_true (np.ndarray): Ground truth labels.
+        y_pred (np.ndarray): Predicted labels.
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (object): Formatting configuration object.
+        save_dir (str | Path): Directory to save plots.
+    """
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseClassificationFormat()
+    else:
+        format_config = config
+
+    # --- Set Font Sizes ---
+    xtick_size = format_config.xtick_size
+    ytick_size = format_config.ytick_size
+    legend_size = format_config.legend_size
+
+    # config font size for heatmap
+    cm_font_size = format_config.cm_font_size
+    cm_tick_size = cm_font_size - 4
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
+
+    # Generate report as both text and dictionary
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    # Save text report
+    report_path = save_dir_path / "classification_report.txt"
+    report_path.write_text(report_text, encoding="utf-8")
+    _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")
+
+    # --- Save Classification Report Heatmap ---
+    try:
+        # Create DataFrame from report
+        report_df = pd.DataFrame(report_dict)
+
+        # 1. Robust Cleanup: Drop by name, not position
+        # Remove 'accuracy' column if it exists (handles the scalar value issue)
+        report_df = report_df.drop(columns=['accuracy'], errors='ignore')
+
+        # Remove 'support' row explicitly (safer than iloc[:-1])
+        if 'support' in report_df.index:
+            report_df = report_df.drop(index='support')
+
+        # 2. Transpose: Rows = Classes, Cols = Metrics
+        plot_df = report_df.T
+
+        # 3. Dynamic Height Calculation
+        # (Base height of 4 + 0.5 inches per class row)
+        fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
+        fig_width = 8.0 # Set a fixed width
+
+        # --- Use calculated dimensions, not the config constant ---
+        fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
+
+        # sns.set_theme(font_scale=1.4)
+        sns.heatmap(plot_df,
+                    annot=True,
+                    cmap=format_config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0,
+                    cbar_kws={'shrink': 0.9}) # Shrink colorbar slightly to fit better
+
+        # sns.set_theme(font_scale=1.0)
+
+        ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+        # manually increase the font size of the elements
+        for text in ax_heat.texts:
+            text.set_fontsize(cm_tick_size)
+
+        # manually increase the size of the colorbar ticks
+        cbar = ax_heat.collections[0].colorbar
+        cbar.ax.tick_params(labelsize=cm_tick_size - 4) # type: ignore
+
+        # Update Ticks
+        ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
+        ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0) # Ensure Y labels are horizontal
+
+        plt.tight_layout()
+
+        heatmap_path = save_dir_path / "classification_report_heatmap.svg"
+        plt.savefig(heatmap_path)
+        _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
+        plt.close(fig_heat)
+
+    except Exception as e:
+        _LOGGER.error(f"Could not generate classification report heatmap: {e}")
+
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
+    # 1. DYNAMIC SIZE CALCULATION
+    # Calculate figure size based on number of classes.
+    n_classes = len(plot_labels) if plot_labels is not None else len(np.unique(y_true))
+    # Ensure a minimum size so very small matrices aren't tiny
+    fig_w = max(9, n_classes * 0.8 + 3)
+    fig_h = max(8, n_classes * 0.8 + 2)
+
+    # Use the calculated size instead of CLASSIFICATION_PLOT_SIZE
+    fig_cm, ax_cm = plt.subplots(figsize=(fig_w, fig_h), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=format_config.cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels,
+                                                    colorbar=False)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # 2. CHECK FOR FONT CLASH
+    # If matrix is huge, force text smaller. If small, allow user config.
+    final_font_size = cm_font_size + 2
+    if n_classes > 2:
+        final_font_size = cm_font_size - n_classes # Decrease font size for larger matrices
+
+    for text in ax_cm.texts:
+        text.set_fontsize(final_font_size)
+
+    # Update Ticks for Confusion Matrix
+    ax_cm.tick_params(axis='x', labelsize=cm_tick_size)
+    ax_cm.tick_params(axis='y', labelsize=cm_tick_size)
+
+    # if more than 3 classes, rotate x ticks
+    if n_classes > 3:
+        plt.setp(ax_cm.get_xticklabels(), rotation=45, ha='right', rotation_mode="anchor")
+
+    # Set titles and labels with padding
+    ax_cm.set_title("Confusion Matrix", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size + 2)
+    ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+    ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
+
+    # --- ADJUST COLORBAR FONT & SIZE ---
+    # Manually add the colorbar with the 'shrink' parameter
+    cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
+
+    # Update the tick size on the new cbar object
+    cbar.ax.tick_params(labelsize=cm_tick_size)
+
+    # (Optional) add a label to the bar itself (e.g. "Probability")
+    # cbar.set_label('Probability', fontsize=12)
+
+    fig_cm.tight_layout()
+
+    cm_path = save_dir_path / "confusion_matrix.svg"
+    plt.savefig(cm_path)
+    _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
+    plt.close(fig_cm)
+
+
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]
+
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""] # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
+
+        elif num_classes > 2:
+            _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
+
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
+
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
+
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+            try:
+                # Calculate Youden's J statistic (tpr - fpr)
+                J = tpr - fpr
+                # Find the index of the best threshold
+                best_index = np.argmax(J)
+                optimal_threshold = thresholds[best_index]
+
+                # Define the filename
+                threshold_filename = f"best_threshold{save_suffix}.txt"
+                threshold_path = save_dir_path / threshold_filename
+
+                # Get the class name for the report
+                class_name = ""
+                # Check if we have display labels and the current index is valid
+                if map_display_labels and class_index < len(map_display_labels):
+                    class_name = map_display_labels[class_index]
+                    if num_classes > 2:
+                        # Add 'vs. Rest' for multiclass one-vs-rest plots
+                        class_name += " (vs. Rest)"
+                else:
+                    # Fallback to the generic title or default binary name
+                    class_name = plot_title.strip() or "Binary Positive Class"
+
+                # Create content for the file
+                file_content = (
+                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                    f"Class: {class_name}\n"
+                    f"--------------------------------------------------\n"
+                    f"Threshold: {optimal_threshold:.6f}\n"
+                    f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                    f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+                )
+
+                threshold_path.write_text(file_content, encoding="utf-8")
+                _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+            except Exception as e:
+                _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
+
+            # Calculate AUC.
+            auc = roc_auc_score(y_true_binary, y_score)
+
+            fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
+            ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+            ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+
+            # Apply Ticks and Legend sizing
+            ax_roc.tick_params(axis='x', labelsize=xtick_size)
+            ax_roc.tick_params(axis='y', labelsize=ytick_size)
+            ax_roc.legend(loc='lower right', fontsize=legend_size)
+
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+
+            plt.tight_layout()
+
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
+            ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+            ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+
+            # Apply Ticks and Legend sizing
+            ax_pr.tick_params(axis='x', labelsize=xtick_size)
+            ax_pr.tick_params(axis='y', labelsize=ytick_size)
+            ax_pr.legend(loc='lower left', fontsize=legend_size)
+
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+
+            plt.tight_layout()
+
+            plt.savefig(pr_path)
+            plt.close(fig_pr)
+
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=format_config.calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Model calibration",
+                line_kws={
+                    'color': format_config.ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
+            ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+            ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            # Apply Ticks and Legend sizing
+            ax_cal.tick_params(axis='x', labelsize=xtick_size)
+            ax_cal.tick_params(axis='y', labelsize=ytick_size)
+            ax_cal.legend(loc='lower right', fontsize=legend_size)
+
+            ax_cal.grid(True)
+            plt.tight_layout()
+
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
+            plt.savefig(cal_path)
+            plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+
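The function above writes the text report, report heatmap, and confusion matrix, and, when `y_prob` is given, per-class ROC, Precision-Recall, and calibration plots, plus a `best_threshold*.txt` recording the ROC point that maximizes Youden's J (TPR minus FPR). A minimal usage sketch on synthetic binary data; the import path assumes the package re-export and every input below is illustrative:

# Sketch, not part of the package diff.
import numpy as np
from ml_tools.ML_evaluation import classification_metrics  # assumed re-export

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=200)        # ground-truth labels
y_prob = rng.random((200, 2))
y_prob /= y_prob.sum(axis=1, keepdims=True)  # rows sum to 1, as probabilities
y_pred = y_prob.argmax(axis=1)               # hard predictions

# A two-column y_prob takes the binary branch: one ROC/PR/calibration set
# for the positive class, written to save_dir as SVG files.
classification_metrics(
    save_dir="eval_output",
    y_true=y_true,
    y_pred=y_pred,
    y_prob=y_prob,
    class_map={"negative": 0, "positive": 1},
)
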
+def multi_label_classification_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    y_prob: np.ndarray,
+    target_names: list[str],
+    save_dir: Union[str, Path],
+    config: Optional[FormatMultiLabelBinaryClassificationMetrics] = None
+):
+    """
+    Calculates and saves classification metrics for each label individually.
+
+    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
+    and then iterates through each label to generate and save individual reports,
+    confusion matrices, ROC curves, and Precision-Recall curves.
+
+    Args:
+        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
+        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
+        target_names (List[str]): A list of names for the labels.
+        save_dir (str | Path): Directory to save plots and reports.
+        config (object): Formatting configuration object.
+    """
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
+        raise ValueError()
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
+        raise ValueError()
+    if y_true.shape[1] != len(target_names):
+        _LOGGER.error("Number of target names must match the number of columns in y_true.")
+        raise ValueError()
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseMultiLabelFormat()
+    else:
+        format_config = config
+
+    # y_pred is now passed in directly, no threshold needed.
+
+    # ticks and legend font sizes
+    xtick_size = format_config.xtick_size
+    ytick_size = format_config.ytick_size
+    legend_size = format_config.legend_size
+    base_font_size = format_config.font_size
+
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
+    h_loss = hamming_loss(y_true, y_pred)
+    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
+    j_score_macro = jaccard_score(y_true, y_pred, average='macro')
+
+    overall_report = (
+        f"Overall Multi-Label Metrics:\n" # No threshold to report here
+        f"--------------------------------------------------\n"
+        f"Hamming Loss: {h_loss:.4f}\n"
+        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
+        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
+        f"--------------------------------------------------\n"
+    )
+    # print(overall_report)
+    overall_report_path = save_dir_path / "classification_report.txt"
+    overall_report_path.write_text(overall_report)
+
+    # --- Per-Label Metrics and Plots ---
+    for i, name in enumerate(target_names):
+        print(f" -> Evaluating label: '{name}'")
+        true_i = y_true[:, i]
+        pred_i = y_pred[:, i] # Use passed-in y_pred
+        prob_i = y_prob[:, i] # Use passed-in y_prob
+        sanitized_name = sanitize_filename(name)
+
+        # --- Save Classification Report for the label (uses y_pred) ---
+        report_text = classification_report(true_i, pred_i)
+        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
+        report_path.write_text(report_text) # type: ignore
+
+        # --- Save Confusion Matrix (uses y_pred) ---
+        fig_cm, ax_cm = plt.subplots(figsize=_EvaluationConfig.CM_SIZE, dpi=_EvaluationConfig.DPI)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=format_config.cmap, # Use config cmap
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"],
+                                                        colorbar=False)
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(base_font_size + 2) # Use config font_size
+
+        # Apply ticks
+        ax_cm.tick_params(axis='x', labelsize=xtick_size)
+        ax_cm.tick_params(axis='y', labelsize=ytick_size)
+
+        # Set titles and labels with padding
+        ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+        ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+        # --- ADJUST COLORBAR FONT & SIZE ---
+        # Manually add the colorbar with the 'shrink' parameter
+        cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
+
+        # Update the tick size on the new cbar object
+        cbar.ax.tick_params(labelsize=ytick_size) # type: ignore
+
+        plt.tight_layout()
+
+        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
+        plt.savefig(cm_path)
+        plt.close(fig_cm)
+
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
+        auc = roc_auc_score(true_i, prob_i)
+        fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
+        ax_roc.plot([0, 1], [0, 1], 'k--')
+
+        ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+        ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+        # Apply ticks and legend font size
+        ax_roc.tick_params(axis='x', labelsize=xtick_size)
+        ax_roc.tick_params(axis='y', labelsize=ytick_size)
+        ax_roc.legend(loc='lower right', fontsize=legend_size)
+
+        ax_roc.grid(True, linestyle='--', alpha=0.6)
+
+        plt.tight_layout()
+
+        roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
+        plt.savefig(roc_path)
+        plt.close(fig_roc)
+
+        # --- Save Precision-Recall Curve (uses y_prob) ---
+        precision, recall, _ = precision_recall_curve(true_i, prob_i)
+        ap_score = average_precision_score(true_i, prob_i)
+        fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
+        ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
+        ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+        ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
+
+        # Apply ticks and legend font size
+        ax_pr.tick_params(axis='x', labelsize=xtick_size)
+        ax_pr.tick_params(axis='y', labelsize=ytick_size)
+        ax_pr.legend(loc='lower left', fontsize=legend_size)
+
+        ax_pr.grid(True, linestyle='--', alpha=0.6)
+
+        fig_pr.tight_layout()
+
+        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
+        plt.savefig(pr_path)
+        plt.close(fig_pr)
+
+    _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
+
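To close out the file, a matching sketch for `multi_label_classification_metrics`; note that the function expects `y_pred` already thresholded by the caller (per the "y_pred is now passed in directly" comment above). The import again assumes the package re-export, and the label names and data are illustrative:

# Sketch, not part of the package diff.
import numpy as np
from ml_tools.ML_evaluation import multi_label_classification_metrics  # assumed re-export

rng = np.random.default_rng(0)
n_samples, n_labels = 300, 3
y_true = (rng.random((n_samples, n_labels)) < 0.4).astype(int)  # binary label matrix
y_prob = rng.random((n_samples, n_labels))                      # independent per-label probabilities
y_pred = (y_prob >= 0.5).astype(int)                            # caller applies the threshold

# Writes overall Hamming loss / Jaccard scores, then a report, confusion
# matrix, ROC curve, PR curve, and best-threshold file per label.
multi_label_classification_metrics(
    y_true=y_true,
    y_pred=y_pred,
    y_prob=y_prob,
    target_names=["label_a", "label_b", "label_c"],
    save_dir="eval_output_multilabel",
)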