dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -1,544 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import matplotlib.pyplot as plt
4
- import seaborn as sns
5
- import torch
6
- import shap
7
- from sklearn.metrics import (
8
- classification_report,
9
- ConfusionMatrixDisplay,
10
- roc_curve,
11
- roc_auc_score,
12
- precision_recall_curve,
13
- average_precision_score,
14
- mean_squared_error,
15
- mean_absolute_error,
16
- r2_score,
17
- median_absolute_error,
18
- hamming_loss,
19
- jaccard_score
20
- )
21
- from pathlib import Path
22
- from typing import Union, List, Literal, Optional
23
- import warnings
24
-
25
- from ._path_manager import make_fullpath, sanitize_filename
26
- from ._logger import get_logger
27
- from ._script_info import _script_info
28
- from ._keys import SHAPKeys, _EvaluationConfig
29
- from ._ML_configuration import (MultiTargetRegressionMetricsFormat,
30
- _BaseRegressionFormat,
31
- MultiLabelBinaryClassificationMetricsFormat,
32
- _BaseMultiLabelFormat)
33
-
34
-
35
# Module-level logger for this evaluation module.
_LOGGER = get_logger("Evaluation Multi")


# Public API of this module; also consumed by info() below.
__all__ = [
    "multi_target_regression_metrics",
    "multi_label_classification_metrics",
    "multi_target_shap_summary_plot",
]


# Shared plot-formatting constants, sourced from the central evaluation config
# so every evaluation module renders figures consistently.
DPI_value = _EvaluationConfig.DPI
REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
48
-
49
-
50
def multi_target_regression_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path],
    config: Optional[MultiTargetRegressionMetricsFormat] = None
):
    """
    Calculates and saves regression metrics for each target individually.

    For each target, this function saves a residual plot and a true vs. predicted plot
    (both as SVG files named after the sanitized target name). It also saves a single
    CSV file ("regression_report_multi.csv") containing the key metrics
    (RMSE, MAE, R², MedAE) for all targets.

    Args:
        y_true (np.ndarray): Ground truth values, shape (n_samples, n_targets).
        y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
        target_names (List[str]): A list of names for the target variables; its
            length must equal the number of target columns.
        save_dir (str | Path): Directory to save plots and the report
            (created if missing).
        config (MultiTargetRegressionMetricsFormat | None): Formatting
            configuration object; a default `_BaseRegressionFormat` is used
            when None.

    Raises:
        ValueError: If inputs are not 2D, shapes mismatch, or the number of
            target names does not match the number of columns. The specific
            reason is logged (errors are raised with no message by convention
            in this module).
    """
    # --- Input validation: log the reason, then raise (module convention) ---
    if y_true.ndim != 2 or y_pred.ndim != 2:
        _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
        raise ValueError()
    if y_true.shape != y_pred.shape:
        _LOGGER.error("Shapes of y_true and y_pred must match.")
        raise ValueError()
    if y_true.shape[1] != len(target_names):
        _LOGGER.error("Number of target names must match the number of columns in y_true.")
        raise ValueError()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    metrics_summary = []  # one dict per target; becomes the CSV report at the end

    # --- Parse Config or use defaults ---
    if config is None:
        # Create a default config if one wasn't provided
        format_config = _BaseRegressionFormat()
    else:
        format_config = config

    # --- Set Matplotlib font size ---
    # NOTE: global rcParams mutation was deliberately disabled in favor of
    # per-artist fontsize arguments below.
    # original_rc_params = plt.rcParams.copy()
    # plt.rcParams.update({'font.size': format_config.font_size})

    # Ticks font sizes pulled once from the config.
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    base_font_size = format_config.font_size

    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

    for i, name in enumerate(target_names):
        # print(f" -> Evaluating target: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]
        # Sanitized name is used in every output filename for this target.
        sanitized_name = sanitize_filename(name)

        # --- Calculate Metrics ---
        rmse = np.sqrt(mean_squared_error(true_i, pred_i))
        mae = mean_absolute_error(true_i, pred_i)
        r2 = r2_score(true_i, pred_i)
        medae = median_absolute_error(true_i, pred_i)
        metrics_summary.append({
            'Target': name,
            'RMSE': rmse,
            'MAE': mae,
            'MedAE': medae,
            'R2-score': r2,
        })

        # --- Save Residual Plot (predicted vs. residuals) ---
        residuals = true_i - pred_i
        fig_res, ax_res = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
        ax_res.scatter(pred_i, residuals,
                       alpha=format_config.scatter_alpha,
                       edgecolors='k',
                       s=50,
                       color=format_config.scatter_color)  # Use config color
        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')  # Use config color
        ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_res.set_title(f"Residual Plot for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)

        # Apply Ticks
        ax_res.tick_params(axis='x', labelsize=xtick_size)
        ax_res.tick_params(axis='y', labelsize=ytick_size)

        ax_res.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        res_path = save_dir_path / f"residual_plot_{sanitized_name}.svg"
        plt.savefig(res_path)
        plt.close(fig_res)  # close explicitly to avoid figure accumulation in loops

        # --- Save True vs. Predicted Plot ---
        fig_tvp, ax_tvp = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
        ax_tvp.scatter(true_i, pred_i,
                       alpha=format_config.scatter_alpha,
                       edgecolors='k',
                       s=50,
                       color=format_config.scatter_color)  # Use config color
        # Identity line (y = x) marking perfect predictions.
        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
                    linestyle='--',
                    lw=2,
                    color=format_config.ideal_line_color)  # Use config color
        ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_tvp.set_ylabel('Predicted Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_tvp.set_title(f"True vs. Predicted for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)

        # Apply Ticks
        ax_tvp.tick_params(axis='x', labelsize=xtick_size)
        ax_tvp.tick_params(axis='y', labelsize=ytick_size)

        ax_tvp.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        tvp_path = save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg"
        plt.savefig(tvp_path)
        plt.close(fig_tvp)

    # --- Save Summary Report (all targets in one CSV) ---
    summary_df = pd.DataFrame(metrics_summary)
    report_path = save_dir_path / "regression_report_multi.csv"
    summary_df.to_csv(report_path, index=False)
    _LOGGER.info(f"Full regression report saved to '{report_path.name}'")

    # --- Restore RC params ---
    # plt.rcParams.update(original_rc_params)
177
-
178
-
179
def multi_label_classification_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path],
    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
):
    """
    Calculates and saves classification metrics for each label individually.

    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
    written to "classification_report.txt", and then iterates through each label to
    generate and save individual reports, confusion matrices, ROC curves (including an
    optimal-threshold text file based on Youden's J statistic), and Precision-Recall
    curves.

    Args:
        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
            Thresholding is assumed to have been applied by the caller.
        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
        target_names (List[str]): A list of names for the labels; its length must
            equal the number of label columns.
        save_dir (str | Path): Directory to save plots and reports (created if missing).
        config (MultiLabelBinaryClassificationMetricsFormat | None): Formatting
            configuration object; a default `_BaseMultiLabelFormat` is used when None.

    Raises:
        ValueError: If inputs are not 2D, shapes mismatch, or the number of target
            names does not match the number of columns. The specific reason is
            logged (errors are raised with no message by convention in this module).
    """
    # --- Input validation: log the reason, then raise (module convention) ---
    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
        raise ValueError()
    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
        raise ValueError()
    if y_true.shape[1] != len(target_names):
        _LOGGER.error("Number of target names must match the number of columns in y_true.")
        raise ValueError()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # --- Parse Config or use defaults ---
    if config is None:
        # Create a default config if one wasn't provided
        format_config = _BaseMultiLabelFormat()
    else:
        format_config = config

    # y_pred is now passed in directly, no threshold needed.

    # --- Save current RC params and update font size ---
    # NOTE: global rcParams mutation was deliberately disabled in favor of
    # per-artist fontsize arguments below.
    # original_rc_params = plt.rcParams.copy()
    # plt.rcParams.update({'font.size': format_config.font_size})

    # Ticks and legend font sizes pulled once from the config.
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    legend_size = format_config.legend_size
    base_font_size = format_config.font_size

    # --- Calculate and Save Overall Metrics (using y_pred) ---
    h_loss = hamming_loss(y_true, y_pred)
    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
    j_score_macro = jaccard_score(y_true, y_pred, average='macro')

    overall_report = (
        f"Overall Multi-Label Metrics:\n"  # No threshold to report here
        f"--------------------------------------------------\n"
        f"Hamming Loss: {h_loss:.4f}\n"
        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
        f"--------------------------------------------------\n"
    )
    # print(overall_report)
    overall_report_path = save_dir_path / "classification_report.txt"
    overall_report_path.write_text(overall_report)

    # --- Per-Label Metrics and Plots ---
    for i, name in enumerate(target_names):
        print(f" -> Evaluating label: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]  # Use passed-in y_pred
        prob_i = y_prob[:, i]  # Use passed-in y_prob
        # Sanitized name is used in every output filename for this label.
        sanitized_name = sanitize_filename(name)

        # --- Save Classification Report for the label (uses y_pred) ---
        report_text = classification_report(true_i, pred_i)
        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
        report_path.write_text(report_text)  # type: ignore

        # --- Save Confusion Matrix (uses y_pred) ---
        fig_cm, ax_cm = plt.subplots(figsize=_EvaluationConfig.CM_SIZE, dpi=_EvaluationConfig.DPI)
        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                        pred_i,
                                                        cmap=format_config.cmap,  # Use config cmap
                                                        ax=ax_cm,
                                                        normalize='true',
                                                        labels=[0, 1],
                                                        display_labels=["Negative", "Positive"],
                                                        colorbar=False)

        # Fix the color scale to [0, 1] so normalized matrices compare across labels.
        disp_.im_.set_clim(vmin=0.0, vmax=1.0)

        # Turn off gridlines
        ax_cm.grid(False)

        # Manually update font size of cell texts
        for text in ax_cm.texts:
            text.set_fontsize(base_font_size + 2)  # Use config font_size

        # Apply ticks
        ax_cm.tick_params(axis='x', labelsize=xtick_size)
        ax_cm.tick_params(axis='y', labelsize=ytick_size)

        # Set titles and labels with padding
        ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        # --- ADJUST COLORBAR FONT & SIZE ---
        # Manually add the colorbar with the 'shrink' parameter
        # (colorbar=False above so we control its size and tick fonts here).
        cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)

        # Update the tick size on the new cbar object
        cbar.ax.tick_params(labelsize=ytick_size)  # type: ignore

        plt.tight_layout()

        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
        plt.savefig(cm_path)
        plt.close(fig_cm)

        # --- Save ROC Curve (uses y_prob) ---
        fpr, tpr, thresholds = roc_curve(true_i, prob_i)

        # Best-effort: derive and persist the optimal threshold; failures only warn.
        try:
            # Calculate Youden's J statistic (tpr - fpr)
            J = tpr - fpr
            # Find the index of the best threshold
            best_index = np.argmax(J)
            optimal_threshold = thresholds[best_index]
            best_tpr = tpr[best_index]
            best_fpr = fpr[best_index]

            # Define the filename
            threshold_filename = f"best_threshold_{sanitized_name}.txt"
            threshold_path = save_dir_path / threshold_filename

            # The class name is the target_name for this label
            class_name = name

            # Create content for the file
            file_content = (
                f"Optimal Classification Threshold (Youden's J Statistic)\n"
                f"Class/Label: {class_name}\n"
                f"--------------------------------------------------\n"
                f"Threshold: {optimal_threshold:.6f}\n"
                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
            )

            threshold_path.write_text(file_content, encoding="utf-8")
            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")

        except Exception as e:
            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")

        auc = roc_auc_score(true_i, prob_i)
        fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)  # Use config color
        ax_roc.plot([0, 1], [0, 1], 'k--')  # chance diagonal

        ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        # Apply ticks and legend font size
        ax_roc.tick_params(axis='x', labelsize=xtick_size)
        ax_roc.tick_params(axis='y', labelsize=ytick_size)
        ax_roc.legend(loc='lower right', fontsize=legend_size)

        ax_roc.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()

        roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
        plt.savefig(roc_path)
        plt.close(fig_roc)

        # --- Save Precision-Recall Curve (uses y_prob) ---
        precision, recall, _ = precision_recall_curve(true_i, prob_i)
        ap_score = average_precision_score(true_i, prob_i)
        fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line)  # Use config color
        ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        # Apply ticks and legend font size
        ax_pr.tick_params(axis='x', labelsize=xtick_size)
        ax_pr.tick_params(axis='y', labelsize=ytick_size)
        ax_pr.legend(loc='lower left', fontsize=legend_size)

        ax_pr.grid(True, linestyle='--', alpha=0.6)

        fig_pr.tight_layout()

        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
        plt.savefig(pr_path)
        plt.close(fig_pr)

    # restore RC params
    # plt.rcParams.update(original_rc_params)

    _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
388
-
389
-
390
def multi_target_shap_summary_plot(
    model: torch.nn.Module,
    background_data: Union[torch.Tensor, np.ndarray],
    instances_to_explain: Union[torch.Tensor, np.ndarray],
    feature_names: List[str],
    target_names: List[str],
    save_dir: Union[str, Path],
    device: torch.device = torch.device('cpu'),
    explainer_type: Literal['deep', 'kernel'] = 'kernel'
):
    """
    DEPRECATED

    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.

    For every target this produces a bar plot, a dot plot (both SVG), and a CSV of
    mean-absolute SHAP values per feature. On recoverable problems (NaN inputs,
    output-count mismatch) the function logs an error and returns without raising.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
        feature_names (List[str]): Names of the features for plot labeling.
        target_names (List[str]): Names of the output targets.
        save_dir (str | Path): Directory to save SHAP artifacts (created if missing).
        device (torch.device): The torch device for SHAP calculations.
        explainer_type (Literal['deep', 'kernel']): The explainer to use.
            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.

    Raises:
        ValueError: If `explainer_type` is neither 'deep' nor 'kernel'.
    """
    _LOGGER.warning("This function is deprecated and may be removed in future versions. Use Captum module instead.")

    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
    model.eval()
    # model.cpu()

    # NOTE(review): assumes the installed shap version returns a *list* of
    # per-output arrays for multi-output models (see len() check below) —
    # newer shap releases may return a single 3D array; verify before reuse.
    shap_values_list = None
    instances_to_explain_np = None

    if explainer_type == 'deep':
        # --- 1. Use DeepExplainer ---

        # Ensure data is torch.Tensor
        if isinstance(background_data, np.ndarray):
            background_data = torch.from_numpy(background_data).float()
        if isinstance(instances_to_explain, np.ndarray):
            instances_to_explain = torch.from_numpy(instances_to_explain).float()

        # NaNs would silently poison the explanation; abort early instead.
        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        background_data = background_data.to(device)
        instances_to_explain = instances_to_explain.to(device)

        # DeepExplainer is known to emit noisy UserWarnings; suppress them locally.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            explainer = shap.DeepExplainer(model, background_data)

        # print("Calculating SHAP values with DeepExplainer...")
        # DeepExplainer returns a list of arrays for multi-output models
        shap_values_list = explainer.shap_values(instances_to_explain)
        instances_to_explain_np = instances_to_explain.cpu().numpy()

    elif explainer_type == 'kernel':
        # --- 2. Use KernelExplainer ---
        _LOGGER.warning(
            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )

        # Convert all data to numpy
        background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
        instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain

        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        # Summarize the background with k-means (30 clusters) to keep
        # KernelExplainer tractable.
        background_summary = shap.kmeans(background_data_np, 30)

        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
            # Bridge shap's numpy interface to the torch model on the requested device.
            x_torch = torch.from_numpy(x_np).float().to(device)
            with torch.no_grad():
                output = model(x_torch)
            return output.cpu().numpy()  # Return full multi-output array

        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
        # print("Calculating SHAP values with KernelExplainer...")
        # KernelExplainer also returns a list of arrays for multi-output models
        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
        # instances_to_explain_np is already set

    else:
        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
        raise ValueError("Invalid explainer_type")

    # --- 3. Plotting and Saving (Common Logic) ---

    if shap_values_list is None or instances_to_explain_np is None:
        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
        return

    # Ensure number of SHAP value arrays matches number of target names
    if len(shap_values_list) != len(target_names):
        _LOGGER.error(
            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
            f"outputs, but {len(target_names)} target_names were provided."
        )
        return

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    # Disable interactive mode while batch-saving figures; re-enabled at the end.
    plt.ioff()

    # Iterate through each target's SHAP values and generate plots.
    for i, target_name in enumerate(target_names):
        print(f" -> Generating SHAP plots for target: '{target_name}'")
        shap_values_for_target = shap_values_list[i]
        sanitized_target_name = sanitize_filename(target_name)

        # Save Bar Plot for the target
        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
        plt.title(f"SHAP Feature Importance for '{target_name}'")
        plt.tight_layout()
        bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
        plt.savefig(bar_path)
        plt.close()

        # Save Dot Plot for the target
        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
        plt.title(f"SHAP Feature Importance for '{target_name}'")
        # The last axes is the colorbar added by shap; blank its label so it
        # does not overlap in the tight layout.
        if plt.gcf().axes and len(plt.gcf().axes) > 1:
            cb = plt.gcf().axes[-1]
            cb.set_ylabel("", size=1)
        plt.tight_layout()
        dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
        plt.savefig(dot_path)
        plt.close()

        # --- Save Summary Data to CSV for this target ---
        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
        summary_path = save_dir_path / shap_summary_filename

        # For a specific target, shap_values_for_target is just a 2D array
        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()

        summary_df = pd.DataFrame({
            SHAPKeys.FEATURE_COLUMN: feature_names,
            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

        summary_df.to_csv(summary_path, index=False)

    plt.ion()
    _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
541
-
542
-
543
def info():
    """Print this module's public API (the names in ``__all__``)."""
    _script_info(__all__)