dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,593 @@
1
+ from typing import Union
2
+
3
+
4
# Public API of this module: the names listed here are the ones exported
# by a star-import. The underscore-prefixed base classes are intentionally
# left out.
__all__ = [
    # --- Metrics Formats ---
    "FormatRegressionMetrics",
    "FormatMultiTargetRegressionMetrics",
    "FormatBinaryClassificationMetrics",
    "FormatMultiClassClassificationMetrics",
    "FormatBinaryImageClassificationMetrics",
    "FormatMultiClassImageClassificationMetrics",
    "FormatMultiLabelBinaryClassificationMetrics",
    "FormatBinarySegmentationMetrics",
    "FormatMultiClassSegmentationMetrics",
    "FormatSequenceValueMetrics",
    "FormatSequenceSequenceMetrics",
]
18
+
19
+
20
+ # --- Private base classes ---
21
+
22
class _BaseClassificationFormat:
    """
    [PRIVATE] Shared plot-formatting settings for single-label classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        font_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Store the formatting options used when plotting single-label classification metrics.

        Args:
            cmap (str): Matplotlib colormap name for the confusion matrix
                and report heatmap.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform options: 'viridis', 'plasma', 'inferno'
                - Diverging options: 'coolwarm'

            ROC_PR_line (str): Color name or hex code for the line drawn
                on the ROC and Precision-Recall curves.
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            calibration_bins (int): Number of bins for the calibration
                (reliability) plot.

            xtick_size (int): Font size for x-axis tick labels.

            ytick_size (int): Font size for y-axis tick labels.

            legend_size (int): Font size for plot legends.

            font_size (int): Base font size applied to the plots.

            cm_font_size (int): Font size for the confusion matrix.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.calibration_bins = calibration_bins
        self.font_size = font_size
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size
        self.legend_size = legend_size
        self.cm_font_size = cm_font_size

    def __repr__(self) -> str:
        # String-valued settings are rendered inside single quotes, numeric ones bare.
        quoted_fields = ("cmap", "ROC_PR_line")
        bare_fields = (
            "calibration_bins",
            "font_size",
            "xtick_size",
            "ytick_size",
            "legend_size",
            "cm_font_size",
        )
        rendered = [f"{name}='{getattr(self, name)}'" for name in quoted_fields]
        rendered += [f"{name}={getattr(self, name)}" for name in bare_fields]
        return f"{self.__class__.__name__}({', '.join(rendered)})"
91
+
92
+
93
class _BaseMultiLabelFormat:
    """
    [PRIVATE] Shared plot-formatting settings for multi-label binary classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        font_size: int = 25,
        xtick_size: int = 20,
        ytick_size: int = 20,
        legend_size: int = 23,
    ) -> None:
        """
        Store the formatting options used when plotting multi-label classification metrics.

        Args:
            cmap (str): Matplotlib colormap name for the per-label
                confusion matrices.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform options: 'viridis', 'plasma', 'inferno'
                - Diverging options: 'coolwarm'

            ROC_PR_line (str): Color name or hex code for the line drawn
                on the ROC and Precision-Recall curves (one for each label).
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            font_size (int): Base font size applied to the plots.

            xtick_size (int): Font size for x-axis tick labels.

            ytick_size (int): Font size for y-axis tick labels.

            legend_size (int): Font size for plot legends.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.font_size = font_size
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size
        self.legend_size = legend_size

    def __repr__(self) -> str:
        # String-valued settings are rendered inside single quotes, numeric ones bare.
        quoted_fields = ("cmap", "ROC_PR_line")
        bare_fields = ("font_size", "xtick_size", "ytick_size", "legend_size")
        rendered = [f"{name}='{getattr(self, name)}'" for name in quoted_fields]
        rendered += [f"{name}={getattr(self, name)}" for name in bare_fields]
        return f"{self.__class__.__name__}({', '.join(rendered)})"
151
+
152
+
153
class _BaseRegressionFormat:
    """
    [PRIVATE] Base configuration for regression metrics.
    """
    def __init__(self,
                 font_size: int=26,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto',
                 xtick_size: int=22,
                 ytick_size: int=22) -> None:
        """
        Initializes the formatting configuration for regression metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
                - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot.
                - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot.
                - Common color names: 'red', 'blue', 'k', '#4682B4'
            hist_bins (int | str): The number of bins for the residuals histogram,
                or a bin-selection rule name. Defaults to 'auto' to use seaborn's
                automatic bin selection.
                - Options: 'auto', 'sqrt', 10, 20
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size

    def __repr__(self) -> str:
        # `!r` quotes strings and leaves numbers bare, so an integer
        # `hist_bins` is shown as `hist_bins=20` rather than `hist_bins='20'`.
        field_names = (
            "font_size",
            "scatter_color",
            "scatter_alpha",
            "ideal_line_color",
            "residual_line_color",
            "hist_bins",
            "xtick_size",
            "ytick_size",
        )
        parts = [f"{name}={getattr(self, name)!r}" for name in field_names]
        return f"{self.__class__.__name__}({', '.join(parts)})"
211
+
212
+
213
class _BaseSegmentationFormat:
    """
    [PRIVATE] Shared plot-formatting settings for segmentation metrics.
    """
    def __init__(
        self,
        heatmap_cmap: str = "BuGn",
        cm_cmap: str = "Purples",
        font_size: int = 16,
    ) -> None:
        """
        Store the formatting options used when plotting segmentation metrics.

        Args:
            heatmap_cmap (str): Matplotlib colormap name for the per-class
                metrics heatmap.
                - Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
                - Diverging options: 'coolwarm', 'bwr', 'seismic'
            cm_cmap (str): Matplotlib colormap name for the pixel-level
                confusion matrix.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
            font_size (int): Base font size applied to the plots.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
        """
        self.heatmap_cmap = heatmap_cmap
        self.cm_cmap = cm_cmap
        self.font_size = font_size

    def __repr__(self) -> str:
        # Colormap names are rendered inside single quotes, the font size bare.
        rendered = [f"{name}='{getattr(self, name)}'" for name in ("heatmap_cmap", "cm_cmap")]
        rendered.append(f"font_size={self.font_size}")
        return f"{self.__class__.__name__}({', '.join(rendered)})"
249
+
250
+
251
class _BaseSequenceValueFormat:
    """
    [PRIVATE] Base configuration for sequence to value metrics.
    """
    def __init__(self,
                 font_size: int=25,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Initializes the formatting configuration for sequence to value metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
                - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot.
                - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot.
                - Common color names: 'red', 'blue', 'k', '#4682B4'
            hist_bins (int | str): The number of bins for the residuals histogram,
                or a bin-selection rule name. Defaults to 'auto' to use seaborn's
                automatic bin selection.
                - Options: 'auto', 'sqrt', 10, 20

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins

    def __repr__(self) -> str:
        # `!r` quotes strings and leaves numbers bare, so an integer
        # `hist_bins` is shown as `hist_bins=20` rather than `hist_bins='20'`.
        field_names = (
            "font_size",
            "scatter_color",
            "scatter_alpha",
            "ideal_line_color",
            "residual_line_color",
            "hist_bins",
        )
        parts = [f"{name}={getattr(self, name)!r}" for name in field_names]
        return f"{self.__class__.__name__}({', '.join(parts)})"
301
+
302
+
303
class _BaseSequenceSequenceFormat:
    """
    [PRIVATE] Base configuration for sequence-to-sequence metrics.
    """
    def __init__(self,
                 font_size: int = 25,
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--') -> None:
        """
        Initializes the formatting configuration for seq-to-seq metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            grid_style (str): Matplotlib linestyle for the plot grid.
                - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
            rmse_color (str): Matplotlib color for the RMSE line.
                - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            rmse_marker (str): Matplotlib marker style for the RMSE line.
                - Options: 'o-' (circle), 's--' (square), '^:' (triangle), 'x' (x marker)
            mae_color (str): Matplotlib color for the MAE line.
                - Common color names: 'tab:orange', 'purple', 'black', '#FF6347'
            mae_marker (str): Matplotlib marker style for the MAE line.
                - Options: 's--', 'o-', 'v:', '+' (plus marker)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)

        <br>

        ### [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)

        <br>

        ### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
        """
        self.font_size = font_size
        self.grid_style = grid_style
        self.rmse_color = rmse_color
        self.rmse_marker = rmse_marker
        self.mae_color = mae_color
        self.mae_marker = mae_marker

    def __repr__(self) -> str:
        # Report every stored setting, in constructor order. The marker
        # fields are included so two configs that differ only in marker
        # style do not print identically.
        field_names = (
            "font_size",
            "grid_style",
            "rmse_color",
            "rmse_marker",
            "mae_color",
            "mae_marker",
        )
        parts = [f"{name}={getattr(self, name)!r}" for name in field_names]
        return f"{self.__class__.__name__}({', '.join(parts)})"
357
+
358
+
359
+ # ----------------------------
360
+ # Metrics Configurations
361
+ # ----------------------------
362
+
363
+ # Regression
364
class FormatRegressionMetrics(_BaseRegressionFormat):
    """
    Plot-formatting configuration for single-target regression metrics.
    """
    def __init__(
        self,
        font_size: int = 26,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
        xtick_size: int = 22,
        ytick_size: int = 22,
    ) -> None:
        """
        Forward every option, by keyword, to the regression base configuration.

        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Color of the scatter plot points.
            scatter_alpha (float): Alpha transparency of the scatter plot points.
            ideal_line_color (str): Color of the 'ideal' line.
            residual_line_color (str): Color of the residual reference line.
            hist_bins (int | str): Bin count or bin rule for the residuals histogram.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
        )
385
+
386
+
387
+ # Multitarget regression
388
class FormatMultiTargetRegressionMetrics(_BaseRegressionFormat):
    """
    Plot-formatting configuration for multi-target regression metrics.
    """
    def __init__(
        self,
        font_size: int = 26,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
        xtick_size: int = 22,
        ytick_size: int = 22,
    ) -> None:
        """
        Forward every option, by keyword, to the regression base configuration.

        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Color of the scatter plot points.
            scatter_alpha (float): Alpha transparency of the scatter plot points.
            ideal_line_color (str): Color of the 'ideal' line.
            residual_line_color (str): Color of the residual reference line.
            hist_bins (int | str): Bin count or bin rule for the residuals histogram.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
        )
409
+
410
+
411
+ # Classification
412
class FormatBinaryClassificationMetrics(_BaseClassificationFormat):
    """
    Plot-formatting configuration for binary classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Forward every option, by keyword, to the classification base configuration.

        Args:
            cmap (str): Matplotlib colormap name.
            ROC_PR_line (str): Color of the ROC / Precision-Recall curve line.
            calibration_bins (int): Number of bins for the calibration plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
434
+
435
+
436
class FormatMultiClassClassificationMetrics(_BaseClassificationFormat):
    """
    Plot-formatting configuration for multi-class classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Forward every option, by keyword, to the classification base configuration.

        Args:
            cmap (str): Matplotlib colormap name.
            ROC_PR_line (str): Color of the ROC / Precision-Recall curve line.
            calibration_bins (int): Number of bins for the calibration plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
458
+
459
+
460
class FormatBinaryImageClassificationMetrics(_BaseClassificationFormat):
    """
    Plot-formatting configuration for binary image classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Forward every option, by keyword, to the classification base configuration.

        Args:
            cmap (str): Matplotlib colormap name.
            ROC_PR_line (str): Color of the ROC / Precision-Recall curve line.
            calibration_bins (int): Number of bins for the calibration plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
482
+
483
+
484
class FormatMultiClassImageClassificationMetrics(_BaseClassificationFormat):
    """
    Plot-formatting configuration for multi-class image classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Forward every option, by keyword, to the classification base configuration.

        Args:
            cmap (str): Matplotlib colormap name.
            ROC_PR_line (str): Color of the ROC / Precision-Recall curve line.
            calibration_bins (int): Number of bins for the calibration plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
506
+
507
+
508
+ # Multi-Label classification
509
class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
    """
    Plot-formatting configuration for multi-label binary classification metrics.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        font_size: int = 25,
        xtick_size: int = 20,
        ytick_size: int = 20,
        legend_size: int = 23,
    ) -> None:
        """
        Forward every option, by keyword, to the multi-label base configuration.

        Args:
            cmap (str): Matplotlib colormap name.
            ROC_PR_line (str): Color of the ROC / Precision-Recall curve lines.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
        )
527
+
528
+
529
+ # Segmentation
530
class FormatBinarySegmentationMetrics(_BaseSegmentationFormat):
    """
    Plot-formatting configuration for binary segmentation metrics.
    """
    def __init__(
        self,
        heatmap_cmap: str = "BuGn",
        cm_cmap: str = "Purples",
        font_size: int = 16,
    ) -> None:
        """
        Forward every option, by keyword, to the segmentation base configuration.

        Args:
            heatmap_cmap (str): Matplotlib colormap name for the metrics heatmap.
            cm_cmap (str): Matplotlib colormap name for the confusion matrix.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
541
+
542
+
543
class FormatMultiClassSegmentationMetrics(_BaseSegmentationFormat):
    """
    Plot-formatting configuration for multi-class segmentation metrics.
    """
    def __init__(
        self,
        heatmap_cmap: str = "BuGn",
        cm_cmap: str = "Purples",
        font_size: int = 16,
    ) -> None:
        """
        Forward every option, by keyword, to the segmentation base configuration.

        Args:
            heatmap_cmap (str): Matplotlib colormap name for the metrics heatmap.
            cm_cmap (str): Matplotlib colormap name for the confusion matrix.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
554
+
555
+
556
+ # Sequence
557
class FormatSequenceValueMetrics(_BaseSequenceValueFormat):
    """
    Plot-formatting configuration for sequence-to-value prediction metrics.
    """
    def __init__(
        self,
        font_size: int = 25,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
    ) -> None:
        """
        Forward every option, by keyword, to the sequence-to-value base configuration.

        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Color of the scatter plot points.
            scatter_alpha (float): Alpha transparency of the scatter plot points.
            ideal_line_color (str): Color of the 'ideal' line.
            residual_line_color (str): Color of the residual reference line.
            hist_bins (int | str): Bin count or bin rule for the residuals histogram.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
574
+
575
+
576
class FormatSequenceSequenceMetrics(_BaseSequenceSequenceFormat):
    """
    Configuration for sequence-to-sequence prediction.
    """
    def __init__(self,
                 font_size: int = 25,
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--') -> None:
        """
        Forward every option, by keyword, to the sequence-to-sequence base configuration.

        Args:
            font_size (int): The base font size to apply to the plots.
            grid_style (str): Matplotlib linestyle for the plot grid.
            rmse_color (str): Matplotlib color for the RMSE line.
            rmse_marker (str): Matplotlib marker style for the RMSE line.
            mae_color (str): Matplotlib color for the MAE line.
            mae_marker (str): Matplotlib marker style for the MAE line.
        """
        super().__init__(font_size=font_size,
                         grid_style=grid_style,
                         rmse_color=rmse_color,
                         rmse_marker=rmse_marker,
                         mae_color=mae_color,
                         mae_marker=mae_marker)
593
+