dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -1,1332 +0,0 @@
1
- from typing import Union, Optional, List, Any, Dict, Literal, Tuple
2
- from pathlib import Path
3
- from collections.abc import Mapping
4
- import numpy as np
5
-
6
- from ._schema import FeatureSchema
7
- from ._script_info import _script_info
8
- from ._logger import get_logger
9
- from ._path_manager import sanitize_filename, make_fullpath
10
- from ._keys import MLTaskKeys
11
-
12
-
13
# Module logger, obtained from the project's `get_logger` helper under the
# name "Configuration".
_LOGGER = get_logger("Configuration")


# Public names exported by `from <this module> import *`.
# NOTE(review): several names below (the *MetricsFormat and Finalize* classes,
# DragonTabNetParams, DragonTrainingConfig, DragonParetoConfig) are not defined
# in the portion of the module reviewed here -- confirm each is defined further
# down, otherwise a star-import raises AttributeError.
__all__ = [
    # --- Metrics Formats ---
    "RegressionMetricsFormat",
    "MultiTargetRegressionMetricsFormat",
    "BinaryClassificationMetricsFormat",
    "MultiClassClassificationMetricsFormat",
    "BinaryImageClassificationMetricsFormat",
    "MultiClassImageClassificationMetricsFormat",
    "MultiLabelBinaryClassificationMetricsFormat",
    "BinarySegmentationMetricsFormat",
    "MultiClassSegmentationMetricsFormat",
    "SequenceValueMetricsFormat",
    "SequenceSequenceMetricsFormat",

    # --- Finalize Configs ---
    "FinalizeBinaryClassification",
    "FinalizeBinarySegmentation",
    "FinalizeBinaryImageClassification",
    "FinalizeMultiClassClassification",
    "FinalizeMultiClassImageClassification",
    "FinalizeMultiClassSegmentation",
    "FinalizeMultiLabelBinaryClassification",
    "FinalizeMultiTargetRegression",
    "FinalizeRegression",
    "FinalizeObjectDetection",
    "FinalizeSequenceSequencePrediction",
    "FinalizeSequenceValuePrediction",

    # --- Model Parameter Configs ---
    "DragonMLPParams",
    "DragonAttentionMLPParams",
    "DragonMultiHeadAttentionNetParams",
    "DragonTabularTransformerParams",
    "DragonGateParams",
    "DragonNodeParams",
    "DragonTabNetParams",
    "DragonAutoIntParams",

    # --- Training Config ---
    "DragonTrainingConfig",
    "DragonParetoConfig"
]
58
-
59
-
60
- # --- Private base classes ---
61
-
62
- class _BaseClassificationFormat:
63
- """
64
- [PRIVATE] Base configuration for single-label classification metrics.
65
- """
66
- def __init__(self,
67
- cmap: str="BuGn",
68
- ROC_PR_line: str='darkorange',
69
- calibration_bins: int=15,
70
- xtick_size: int=22,
71
- ytick_size: int=22,
72
- legend_size: int=26,
73
- font_size: int=26,
74
- cm_font_size: int=26) -> None:
75
- """
76
- Initializes the formatting configuration for single-label classification metrics.
77
-
78
- Args:
79
- cmap (str): The matplotlib colormap name for the confusion matrix
80
- and report heatmap.
81
- - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
82
- - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
83
-
84
- ROC_PR_line (str): The color name or hex code for the line plotted
85
- on the ROC and Precision-Recall curves.
86
- - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
87
- - Hex codes: '#FF6347', '#4682B4'
88
-
89
- calibration_bins (int): The number of bins to use when
90
- creating the calibration (reliability) plot.
91
-
92
- font_size (int): The base font size to apply to the plots.
93
-
94
- xtick_size (int): Font size for x-axis tick labels.
95
-
96
- ytick_size (int): Font size for y-axis tick labels.
97
-
98
- legend_size (int): Font size for plot legends.
99
-
100
- cm_font_size (int): Font size for the confusion matrix.
101
-
102
- <br>
103
-
104
- ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
105
-
106
- <br>
107
-
108
- ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
109
- """
110
- self.cmap = cmap
111
- self.ROC_PR_line = ROC_PR_line
112
- self.calibration_bins = calibration_bins
113
- self.font_size = font_size
114
- self.xtick_size = xtick_size
115
- self.ytick_size = ytick_size
116
- self.legend_size = legend_size
117
- self.cm_font_size = cm_font_size
118
-
119
- def __repr__(self) -> str:
120
- parts = [
121
- f"cmap='{self.cmap}'",
122
- f"ROC_PR_line='{self.ROC_PR_line}'",
123
- f"calibration_bins={self.calibration_bins}",
124
- f"font_size={self.font_size}",
125
- f"xtick_size={self.xtick_size}",
126
- f"ytick_size={self.ytick_size}",
127
- f"legend_size={self.legend_size}",
128
- f"cm_font_size={self.cm_font_size}"
129
- ]
130
- return f"{self.__class__.__name__}({', '.join(parts)})"
131
-
132
-
133
- class _BaseMultiLabelFormat:
134
- """
135
- [PRIVATE] Base configuration for multi-label binary classification metrics.
136
- """
137
- def __init__(self,
138
- cmap: str = "BuGn",
139
- ROC_PR_line: str='darkorange',
140
- font_size: int = 25,
141
- xtick_size: int=20,
142
- ytick_size: int=20,
143
- legend_size: int=23) -> None:
144
- """
145
- Initializes the formatting configuration for multi-label classification metrics.
146
-
147
- Args:
148
- cmap (str): The matplotlib colormap name for the per-label
149
- confusion matrices.
150
- - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
151
- - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
152
-
153
- ROC_PR_line (str): The color name or hex code for the line plotted
154
- on the ROC and Precision-Recall curves (one for each label).
155
- - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
156
- - Hex codes: '#FF6347', '#4682B4'
157
-
158
- font_size (int): The base font size to apply to the plots.
159
-
160
- xtick_size (int): Font size for x-axis tick labels.
161
-
162
- ytick_size (int): Font size for y-axis tick labels.
163
-
164
- legend_size (int): Font size for plot legends.
165
-
166
- <br>
167
-
168
- ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
169
-
170
- <br>
171
-
172
- ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
173
- """
174
- self.cmap = cmap
175
- self.ROC_PR_line = ROC_PR_line
176
- self.font_size = font_size
177
- self.xtick_size = xtick_size
178
- self.ytick_size = ytick_size
179
- self.legend_size = legend_size
180
-
181
- def __repr__(self) -> str:
182
- parts = [
183
- f"cmap='{self.cmap}'",
184
- f"ROC_PR_line='{self.ROC_PR_line}'",
185
- f"font_size={self.font_size}",
186
- f"xtick_size={self.xtick_size}",
187
- f"ytick_size={self.ytick_size}",
188
- f"legend_size={self.legend_size}"
189
- ]
190
- return f"{self.__class__.__name__}({', '.join(parts)})"
191
-
192
-
193
- class _BaseRegressionFormat:
194
- """
195
- [PRIVATE] Base configuration for regression metrics.
196
- """
197
- def __init__(self,
198
- font_size: int=26,
199
- scatter_color: str='tab:blue',
200
- scatter_alpha: float=0.6,
201
- ideal_line_color: str='k',
202
- residual_line_color: str='red',
203
- hist_bins: Union[int, str] = 'auto',
204
- xtick_size: int=22,
205
- ytick_size: int=22) -> None:
206
- """
207
- Initializes the formatting configuration for regression metrics.
208
-
209
- Args:
210
- font_size (int): The base font size to apply to the plots.
211
- scatter_color (str): Matplotlib color for the scatter plot points.
212
- - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
213
- scatter_alpha (float): Alpha transparency for scatter plot points.
214
- ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
215
- True vs. Predicted plot.
216
- - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
217
- residual_line_color (str): Matplotlib color for the y=0 line in the
218
- Residual plot.
219
- - Common color names: 'red', 'blue', 'k', '#4682B4'
220
- hist_bins (int | str): The number of bins for the residuals histogram.
221
- Defaults to 'auto' to use seaborn's automatic bin selection.
222
- - Options: 'auto', 'sqrt', 10, 20
223
- xtick_size (int): Font size for x-axis tick labels.
224
- ytick_size (int): Font size for y-axis tick labels.
225
-
226
- <br>
227
-
228
- ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
229
- """
230
- self.font_size = font_size
231
- self.scatter_color = scatter_color
232
- self.scatter_alpha = scatter_alpha
233
- self.ideal_line_color = ideal_line_color
234
- self.residual_line_color = residual_line_color
235
- self.hist_bins = hist_bins
236
- self.xtick_size = xtick_size
237
- self.ytick_size = ytick_size
238
-
239
- def __repr__(self) -> str:
240
- parts = [
241
- f"font_size={self.font_size}",
242
- f"scatter_color='{self.scatter_color}'",
243
- f"scatter_alpha={self.scatter_alpha}",
244
- f"ideal_line_color='{self.ideal_line_color}'",
245
- f"residual_line_color='{self.residual_line_color}'",
246
- f"hist_bins='{self.hist_bins}'",
247
- f"xtick_size={self.xtick_size}",
248
- f"ytick_size={self.ytick_size}"
249
- ]
250
- return f"{self.__class__.__name__}({', '.join(parts)})"
251
-
252
-
253
- class _BaseSegmentationFormat:
254
- """
255
- [PRIVATE] Base configuration for segmentation metrics.
256
- """
257
- def __init__(self,
258
- heatmap_cmap: str = "BuGn",
259
- cm_cmap: str = "Purples",
260
- font_size: int = 16) -> None:
261
- """
262
- Initializes the formatting configuration for segmentation metrics.
263
-
264
- Args:
265
- heatmap_cmap (str): The matplotlib colormap name for the per-class
266
- metrics heatmap.
267
- - Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
268
- - Diverging options: 'coolwarm', 'bwr', 'seismic'
269
- cm_cmap (str): The matplotlib colormap name for the pixel-level
270
- confusion matrix.
271
- - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
272
- font_size (int): The base font size to apply to the plots.
273
-
274
- <br>
275
-
276
- ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
277
- """
278
- self.heatmap_cmap = heatmap_cmap
279
- self.cm_cmap = cm_cmap
280
- self.font_size = font_size
281
-
282
- def __repr__(self) -> str:
283
- parts = [
284
- f"heatmap_cmap='{self.heatmap_cmap}'",
285
- f"cm_cmap='{self.cm_cmap}'",
286
- f"font_size={self.font_size}"
287
- ]
288
- return f"{self.__class__.__name__}({', '.join(parts)})"
289
-
290
-
291
- class _BaseSequenceValueFormat:
292
- """
293
- [PRIVATE] Base configuration for sequence to value metrics.
294
- """
295
- def __init__(self,
296
- font_size: int=25,
297
- scatter_color: str='tab:blue',
298
- scatter_alpha: float=0.6,
299
- ideal_line_color: str='k',
300
- residual_line_color: str='red',
301
- hist_bins: Union[int, str] = 'auto') -> None:
302
- """
303
- Initializes the formatting configuration for sequence to value metrics.
304
-
305
- Args:
306
- font_size (int): The base font size to apply to the plots.
307
- scatter_color (str): Matplotlib color for the scatter plot points.
308
- - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
309
- scatter_alpha (float): Alpha transparency for scatter plot points.
310
- ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
311
- True vs. Predicted plot.
312
- - Common color names: 'k', 'red', 'darkgrey', '#FF6347'
313
- residual_line_color (str): Matplotlib color for the y=0 line in the
314
- Residual plot.
315
- - Common color names: 'red', 'blue', 'k', '#4682B4'
316
- hist_bins (int | str): The number of bins for the residuals histogram.
317
- Defaults to 'auto' to use seaborn's automatic bin selection.
318
- - Options: 'auto', 'sqrt', 10, 20
319
-
320
- <br>
321
-
322
- ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
323
- """
324
- self.font_size = font_size
325
- self.scatter_color = scatter_color
326
- self.scatter_alpha = scatter_alpha
327
- self.ideal_line_color = ideal_line_color
328
- self.residual_line_color = residual_line_color
329
- self.hist_bins = hist_bins
330
-
331
- def __repr__(self) -> str:
332
- parts = [
333
- f"font_size={self.font_size}",
334
- f"scatter_color='{self.scatter_color}'",
335
- f"scatter_alpha={self.scatter_alpha}",
336
- f"ideal_line_color='{self.ideal_line_color}'",
337
- f"residual_line_color='{self.residual_line_color}'",
338
- f"hist_bins='{self.hist_bins}'"
339
- ]
340
- return f"{self.__class__.__name__}({', '.join(parts)})"
341
-
342
-
343
- class _BaseSequenceSequenceFormat:
344
- """
345
- [PRIVATE] Base configuration for sequence-to-sequence metrics.
346
- """
347
- def __init__(self,
348
- font_size: int = 25,
349
- grid_style: str = '--',
350
- rmse_color: str = 'tab:blue',
351
- rmse_marker: str = 'o-',
352
- mae_color: str = 'tab:orange',
353
- mae_marker: str = 's--'):
354
- """
355
- Initializes the formatting configuration for seq-to-seq metrics.
356
-
357
- Args:
358
- font_size (int): The base font size to apply to the plots.
359
- grid_style (str): Matplotlib linestyle for the plot grid.
360
- - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
361
- rmse_color (str): Matplotlib color for the RMSE line.
362
- - Common color names: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
363
- rmse_marker (str): Matplotlib marker style for the RMSE line.
364
- - Options: 'o-' (circle), 's--' (square), '^:' (triangle), 'x' (x marker)
365
- mae_color (str): Matplotlib color for the MAE line.
366
- - Common color names: 'tab:orange', 'purple', 'black', '#FF6347'
367
- mae_marker (str): Matplotlib marker style for the MAE line.
368
- - Options: 's--', 'o-', 'v:', '+' (plus marker)
369
-
370
- <br>
371
-
372
- ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
373
-
374
- <br>
375
-
376
- ### [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)
377
-
378
- <br>
379
-
380
- ### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
381
- """
382
- self.font_size = font_size
383
- self.grid_style = grid_style
384
- self.rmse_color = rmse_color
385
- self.rmse_marker = rmse_marker
386
- self.mae_color = mae_color
387
- self.mae_marker = mae_marker
388
-
389
- def __repr__(self) -> str:
390
- parts = [
391
- f"font_size={self.font_size}",
392
- f"grid_style='{self.grid_style}'",
393
- f"rmse_color='{self.rmse_color}'",
394
- f"mae_color='{self.mae_color}'"
395
- ]
396
- return f"{self.__class__.__name__}({', '.join(parts)})"
397
-
398
-
399
- class _BaseModelParams(Mapping):
400
- """
401
- [PRIVATE] Base class for model parameter configs.
402
-
403
- Inherits from Mapping to behave like a dictionary, enabling
404
- `**params` unpacking directly into model constructors.
405
- """
406
- def __getitem__(self, key: str) -> Any:
407
- return self.__dict__[key]
408
-
409
- def __iter__(self):
410
- return iter(self.__dict__)
411
-
412
- def __len__(self) -> int:
413
- return len(self.__dict__)
414
-
415
- def __or__(self, other) -> Dict[str, Any]:
416
- """Allows merging with other Mappings using the | operator."""
417
- if isinstance(other, Mapping):
418
- return dict(self) | dict(other)
419
- return NotImplemented
420
-
421
- def __ror__(self, other) -> Dict[str, Any]:
422
- """Allows merging with other Mappings using the | operator."""
423
- if isinstance(other, Mapping):
424
- return dict(other) | dict(self)
425
- return NotImplemented
426
-
427
- def __repr__(self) -> str:
428
- """Returns a formatted multi-line string representation."""
429
- class_name = self.__class__.__name__
430
- # Format parameters for clean logging
431
- params = []
432
- for k, v in self.__dict__.items():
433
- # If value is huge (like FeatureSchema), use its own repr
434
- val_str = repr(v)
435
- params.append(f" {k}={val_str}")
436
-
437
- params_str = ",\n".join(params)
438
- return f"{class_name}(\n{params_str}\n)"
439
-
440
- def to_log(self) -> Dict[str, Any]:
441
- """
442
- Safely converts complex types (like FeatureSchema) to their string
443
- representation for cleaner JSON logging.
444
- """
445
- clean_dict = {}
446
- for k, v in self.__dict__.items():
447
- if isinstance(v, FeatureSchema):
448
- # Force the repr() string, otherwise json.dump treats it as a list
449
- clean_dict[k] = repr(v)
450
- elif isinstance(v, Path):
451
- # JSON cannot serialize Path objects, convert to string
452
- clean_dict[k] = str(v)
453
- else:
454
- clean_dict[k] = v
455
- return clean_dict
456
-
457
-
458
- # --- Public API classes ---
459
-
460
- # ----------------------------
461
- # Model Parameters Configurations
462
- # ----------------------------
463
-
464
- # --- Standard Models ---
465
-
466
class DragonMLPParams(_BaseModelParams):
    """
    Parameter config whose attributes are the constructor arguments below.

    As a `_BaseModelParams`, an instance can be unpacked with `**params`; the
    keys come out in signature order.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: List[int],
                 drop_out: float = 0.2) -> None:
        """
        Args:
            in_features (int): Presumably the number of input features — confirm against the model.
            out_targets (int): Presumably the number of output targets — confirm against the model.
            hidden_layers (List[int]): Presumably the width of each hidden layer; stored by reference.
            drop_out (float): Dropout setting forwarded to the model (default 0.2).
        """
        # Tuple targets are assigned left to right, so attribute (and mapping
        # key) order still follows the signature.
        self.in_features, self.out_targets = in_features, out_targets
        self.hidden_layers, self.drop_out = hidden_layers, drop_out
476
-
477
-
478
class DragonAttentionMLPParams(_BaseModelParams):
    """
    Parameter config whose attributes are the constructor arguments below.

    As a `_BaseModelParams`, an instance can be unpacked with `**params`; the
    keys come out in signature order.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: List[int],
                 drop_out: float = 0.2) -> None:
        """
        Args:
            in_features (int): Presumably the number of input features — confirm against the model.
            out_targets (int): Presumably the number of output targets — confirm against the model.
            hidden_layers (List[int]): Presumably the width of each hidden layer; stored by reference.
            drop_out (float): Dropout setting forwarded to the model (default 0.2).
        """
        # Left-to-right tuple assignment keeps the signature's key order.
        self.in_features, self.out_targets, self.hidden_layers, self.drop_out = (
            in_features, out_targets, hidden_layers, drop_out)
488
-
489
-
490
class DragonMultiHeadAttentionNetParams(_BaseModelParams):
    """
    Parameter config whose attributes are the constructor arguments below.

    As a `_BaseModelParams`, an instance can be unpacked with `**params`; the
    keys come out in signature order.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: List[int],
                 drop_out: float = 0.2,
                 num_heads: int = 4,
                 attention_dropout: float = 0.1) -> None:
        """
        Args:
            in_features (int): Presumably the number of input features — confirm against the model.
            out_targets (int): Presumably the number of output targets — confirm against the model.
            hidden_layers (List[int]): Presumably the width of each hidden layer; stored by reference.
            drop_out (float): Dropout setting forwarded to the model (default 0.2).
            num_heads (int): Attention head count forwarded to the model (default 4).
            attention_dropout (float): Attention dropout forwarded to the model (default 0.1).
        """
        # Shared MLP-style settings (left-to-right assignment keeps key order).
        self.in_features, self.out_targets = in_features, out_targets
        self.hidden_layers, self.drop_out = hidden_layers, drop_out
        # Attention-specific settings.
        self.num_heads, self.attention_dropout = num_heads, attention_dropout
504
-
505
-
506
class DragonTabularTransformerParams(_BaseModelParams):
    """
    Parameter config whose attributes are the keyword-only arguments below.

    As a `_BaseModelParams`, an instance can be unpacked with `**params`; the
    keys come out in signature order.
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 256,
                 num_heads: int = 8,
                 num_layers: int = 6,
                 dropout: float = 0.2) -> None:
        """
        Args:
            schema (FeatureSchema): Feature schema forwarded to the model.
            out_targets (int): Presumably the number of output targets — confirm against the model.
            embedding_dim (int): Embedding size forwarded to the model (default 256).
            num_heads (int): Attention head count forwarded to the model (default 8).
            num_layers (int): Layer count forwarded to the model (default 6).
            dropout (float): Dropout setting forwarded to the model (default 0.2).
        """
        # Left-to-right tuple assignment keeps the signature's key order.
        self.schema, self.out_targets = schema, out_targets
        self.embedding_dim, self.num_heads, self.num_layers = embedding_dim, num_heads, num_layers
        self.dropout = dropout
520
-
521
- # --- Advanced Models ---
522
-
523
class DragonGateParams(_BaseModelParams):
    """
    Parameter config whose attributes are the keyword-only arguments below.

    As a `_BaseModelParams`, an instance can be unpacked with `**params`; the
    keys come out in signature order. Values are stored as given (no validation).
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 16,
                 gflu_stages: int = 6,
                 gflu_dropout: float = 0.1,
                 num_trees: int = 20,
                 tree_depth: int = 4,
                 tree_dropout: float = 0.1,
                 chain_trees: bool = False,
                 tree_wise_attention: bool = True,
                 tree_wise_attention_dropout: float = 0.1,
                 binning_activation: Literal['entmoid', 'sparsemoid', 'sigmoid'] = "entmoid",
                 feature_mask_function: Literal['entmax', 'sparsemax', 'softmax', 't-softmax'] = "entmax",
                 share_head_weights: bool = True,
                 batch_norm_continuous: bool = True) -> None:
        # Inputs / outputs / embedding (left-to-right assignment keeps key order).
        self.schema, self.out_targets, self.embedding_dim = schema, out_targets, embedding_dim
        # gflu_* options
        self.gflu_stages, self.gflu_dropout = gflu_stages, gflu_dropout
        # Tree-related options
        self.num_trees, self.tree_depth, self.tree_dropout = num_trees, tree_depth, tree_dropout
        self.chain_trees = chain_trees
        self.tree_wise_attention, self.tree_wise_attention_dropout = (
            tree_wise_attention, tree_wise_attention_dropout)
        # Function choices (restricted by the Literal annotations above)
        self.binning_activation, self.feature_mask_function = binning_activation, feature_mask_function
        # Remaining flags
        self.share_head_weights, self.batch_norm_continuous = share_head_weights, batch_norm_continuous
555
-
556
-
557
- class DragonNodeParams(_BaseModelParams):
558
- def __init__(self, *,
559
- schema: FeatureSchema,
560
- out_targets: int,
561
- embedding_dim: int = 24,
562
- num_trees: int = 1024,
563
- num_layers: int = 2,
564
- tree_depth: int = 6,
565
- additional_tree_output_dim: int = 3,
566
- max_features: Optional[int] = None,
567
- input_dropout: float = 0.0,
568
- embedding_dropout: float = 0.0,
569
- choice_function: Literal['entmax', 'sparsemax', 'softmax'] = 'entmax',
570
- bin_function: Literal['entmoid', 'sparsemoid', 'sigmoid'] = 'entmoid',
571
- batch_norm_continuous: bool = False) -> None:
572
- self.schema = schema
573
- self.out_targets = out_targets
574
- self.embedding_dim = embedding_dim
575
- self.num_trees = num_trees
576
- self.num_layers = num_layers
577
- self.tree_depth = tree_depth
578
- self.additional_tree_output_dim = additional_tree_output_dim
579
- self.max_features = max_features
580
- self.input_dropout = input_dropout
581
- self.embedding_dropout = embedding_dropout
582
- self.choice_function = choice_function
583
- self.bin_function = bin_function
584
- self.batch_norm_continuous = batch_norm_continuous
585
-
586
-
587
- class DragonAutoIntParams(_BaseModelParams):
588
- def __init__(self, *,
589
- schema: FeatureSchema,
590
- out_targets: int,
591
- embedding_dim: int = 32,
592
- attn_embed_dim: int = 32,
593
- num_heads: int = 2,
594
- num_attn_blocks: int = 3,
595
- attn_dropout: float = 0.1,
596
- has_residuals: bool = True,
597
- attention_pooling: bool = True,
598
- deep_layers: bool = True,
599
- layers: str = "128-64-32",
600
- activation: str = "ReLU",
601
- embedding_dropout: float = 0.0,
602
- batch_norm_continuous: bool = False) -> None:
603
- self.schema = schema
604
- self.out_targets = out_targets
605
- self.embedding_dim = embedding_dim
606
- self.attn_embed_dim = attn_embed_dim
607
- self.num_heads = num_heads
608
- self.num_attn_blocks = num_attn_blocks
609
- self.attn_dropout = attn_dropout
610
- self.has_residuals = has_residuals
611
- self.attention_pooling = attention_pooling
612
- self.deep_layers = deep_layers
613
- self.layers = layers
614
- self.activation = activation
615
- self.embedding_dropout = embedding_dropout
616
- self.batch_norm_continuous = batch_norm_continuous
617
-
618
-
619
- class DragonTabNetParams(_BaseModelParams):
620
- def __init__(self, *,
621
- schema: FeatureSchema,
622
- out_targets: int,
623
- n_d: int = 8,
624
- n_a: int = 8,
625
- n_steps: int = 3,
626
- gamma: float = 1.3,
627
- n_independent: int = 2,
628
- n_shared: int = 2,
629
- virtual_batch_size: int = 128,
630
- momentum: float = 0.02,
631
- mask_type: Literal['sparsemax', 'entmax', 'softmax'] = 'sparsemax',
632
- batch_norm_continuous: bool = False) -> None:
633
- self.schema = schema
634
- self.out_targets = out_targets
635
- self.n_d = n_d
636
- self.n_a = n_a
637
- self.n_steps = n_steps
638
- self.gamma = gamma
639
- self.n_independent = n_independent
640
- self.n_shared = n_shared
641
- self.virtual_batch_size = virtual_batch_size
642
- self.momentum = momentum
643
- self.mask_type = mask_type
644
- self.batch_norm_continuous = batch_norm_continuous
645
-
646
-
647
- # --- Training Configuration ---
648
-
649
- class DragonTrainingConfig(_BaseModelParams):
650
- """
651
- Configuration object for the training process.
652
-
653
- Can be unpacked as a dictionary for logging or accessed as an object.
654
-
655
- Accepts arbitrary keyword arguments which are set as instance attributes.
656
- """
657
- def __init__(self,
658
- validation_size: float,
659
- test_size: float,
660
- initial_learning_rate: float,
661
- batch_size: int,
662
- random_state: int = 101,
663
- # early_stop_patience: Optional[int] = None,
664
- # scheduler_patience: Optional[int] = None,
665
- # scheduler_lr_factor: Optional[float] = None,
666
- **kwargs: Any) -> None:
667
- """
668
- Args:
669
- validation_size (float): Proportion of data for validation set.
670
- test_size (float): Proportion of data for test set.
671
- initial_learning_rate (float): Starting learning rate.
672
- batch_size (int): Number of samples per training batch.
673
- random_state (int): Seed for reproducibility.
674
- **kwargs: Additional training parameters as key-value pairs.
675
- """
676
- self.validation_size = validation_size
677
- self.test_size = test_size
678
- self.initial_learning_rate = initial_learning_rate
679
- self.batch_size = batch_size
680
- self.random_state = random_state
681
- # self.early_stop_patience = early_stop_patience
682
- # self.scheduler_patience = scheduler_patience
683
- # self.scheduler_lr_factor = scheduler_lr_factor
684
-
685
- # Process kwargs with validation
686
- for key, value in kwargs.items():
687
- # Python guarantees 'key' is a string for **kwargs
688
-
689
- # Allow None in value
690
- if value is None:
691
- setattr(self, key, value)
692
- continue
693
-
694
- if isinstance(value, dict):
695
- _LOGGER.error("Nested dictionaries are not supported, unpack them first.")
696
- raise TypeError()
697
-
698
- # Check if value is a number or a string or a JSON supported type, except dict
699
- if not isinstance(value, (str, int, float, bool, list, tuple)):
700
- _LOGGER.error(f"Invalid type for configuration '{key}': {type(value).__name__}")
701
- raise TypeError()
702
-
703
- setattr(self, key, value)
704
-
705
-
706
- class DragonParetoConfig(_BaseModelParams):
707
- """
708
- Configuration object for the Pareto Optimization process.
709
- """
710
- def __init__(self,
711
- save_directory: Union[str, Path],
712
- target_objectives: Dict[str, Literal["min", "max"]],
713
- continuous_bounds_map: Union[Dict[str, Tuple[float, float]], Dict[str, List[float]], str, Path],
714
- columns_to_round: Optional[List[str]] = None,
715
- population_size: int = 500,
716
- generations: int = 1000,
717
- solutions_filename: str = "NonDominatedSolutions",
718
- float_precision: int = 4,
719
- log_interval: int = 10,
720
- plot_size: Tuple[int, int] = (10, 7),
721
- plot_font_size: int = 16,
722
- discretize_start_at_zero: bool = True):
723
- """
724
- Configure the Pareto Optimizer.
725
-
726
- Args:
727
- save_directory (str | Path): Directory to save artifacts.
728
- target_objectives (Dict[str, "min"|"max"]): Dictionary mapping target names to optimization direction.
729
- Example: {"price": "max", "error": "min"}
730
- continuous_bounds_map (Dict): Bounds for continuous features {name: (min, max)}. Or a path/str to a directory containing the "optimization_bounds.json" file.
731
- columns_to_round (List[str] | None): List of continuous column names that should be rounded to the nearest integer.
732
- population_size (int): Size of the genetic population.
733
- generations (int): Number of generations to run.
734
- solutions_filename (str): Filename for saving Pareto solutions.
735
- float_precision (int): Number of decimal places to round standard float columns.
736
- log_interval (int): Interval for logging progress.
737
- plot_size (Tuple[int, int]): Size of the 2D plots.
738
- plot_font_size (int): Font size for plot text.
739
- discretize_start_at_zero (bool): Categorical encoding start index. True=0, False=1.
740
- """
741
- # Validate string or Path
742
- valid_save_dir = make_fullpath(save_directory, make=True, enforce="directory")
743
-
744
- if isinstance(continuous_bounds_map, (str, Path)):
745
- continuous_bounds_map = make_fullpath(continuous_bounds_map, make=False, enforce="directory")
746
-
747
- self.save_directory = valid_save_dir
748
- self.target_objectives = target_objectives
749
- self.continuous_bounds_map = continuous_bounds_map
750
- self.columns_to_round = columns_to_round
751
- self.population_size = population_size
752
- self.generations = generations
753
- self.solutions_filename = solutions_filename
754
- self.float_precision = float_precision
755
- self.log_interval = log_interval
756
- self.plot_size = plot_size
757
- self.plot_font_size = plot_font_size
758
- self.discretize_start_at_zero = discretize_start_at_zero
759
-
760
- # ----------------------------
761
- # Metrics Configurations
762
- # ----------------------------
763
-
764
- # Regression
765
- class RegressionMetricsFormat(_BaseRegressionFormat):
766
- """
767
- Configuration for single-target regression.
768
- """
769
- def __init__(self,
770
- font_size: int=26,
771
- scatter_color: str='tab:blue',
772
- scatter_alpha: float=0.6,
773
- ideal_line_color: str='k',
774
- residual_line_color: str='red',
775
- hist_bins: Union[int, str] = 'auto',
776
- xtick_size: int=22,
777
- ytick_size: int=22) -> None:
778
- super().__init__(font_size=font_size,
779
- scatter_color=scatter_color,
780
- scatter_alpha=scatter_alpha,
781
- ideal_line_color=ideal_line_color,
782
- residual_line_color=residual_line_color,
783
- hist_bins=hist_bins,
784
- xtick_size=xtick_size,
785
- ytick_size=ytick_size)
786
-
787
-
788
- # Multitarget regression
789
- class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
790
- """
791
- Configuration for multi-target regression.
792
- """
793
- def __init__(self,
794
- font_size: int=26,
795
- scatter_color: str='tab:blue',
796
- scatter_alpha: float=0.6,
797
- ideal_line_color: str='k',
798
- residual_line_color: str='red',
799
- hist_bins: Union[int, str] = 'auto',
800
- xtick_size: int=22,
801
- ytick_size: int=22) -> None:
802
- super().__init__(font_size=font_size,
803
- scatter_color=scatter_color,
804
- scatter_alpha=scatter_alpha,
805
- ideal_line_color=ideal_line_color,
806
- residual_line_color=residual_line_color,
807
- hist_bins=hist_bins,
808
- xtick_size=xtick_size,
809
- ytick_size=ytick_size)
810
-
811
-
812
- # Classification
813
- class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
814
- """
815
- Configuration for binary classification.
816
- """
817
- def __init__(self,
818
- cmap: str="BuGn",
819
- ROC_PR_line: str='darkorange',
820
- calibration_bins: int=15,
821
- font_size: int=26,
822
- xtick_size: int=22,
823
- ytick_size: int=22,
824
- legend_size: int=26,
825
- cm_font_size: int=26
826
- ) -> None:
827
- super().__init__(cmap=cmap,
828
- ROC_PR_line=ROC_PR_line,
829
- calibration_bins=calibration_bins,
830
- font_size=font_size,
831
- xtick_size=xtick_size,
832
- ytick_size=ytick_size,
833
- legend_size=legend_size,
834
- cm_font_size=cm_font_size)
835
-
836
-
837
- class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
838
- """
839
- Configuration for multi-class classification.
840
- """
841
- def __init__(self,
842
- cmap: str="BuGn",
843
- ROC_PR_line: str='darkorange',
844
- calibration_bins: int=15,
845
- font_size: int=26,
846
- xtick_size: int=22,
847
- ytick_size: int=22,
848
- legend_size: int=26,
849
- cm_font_size: int=26
850
- ) -> None:
851
- super().__init__(cmap=cmap,
852
- ROC_PR_line=ROC_PR_line,
853
- calibration_bins=calibration_bins,
854
- font_size=font_size,
855
- xtick_size=xtick_size,
856
- ytick_size=ytick_size,
857
- legend_size=legend_size,
858
- cm_font_size=cm_font_size)
859
-
860
- class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
861
- """
862
- Configuration for binary image classification.
863
- """
864
- def __init__(self,
865
- cmap: str="BuGn",
866
- ROC_PR_line: str='darkorange',
867
- calibration_bins: int=15,
868
- font_size: int=26,
869
- xtick_size: int=22,
870
- ytick_size: int=22,
871
- legend_size: int=26,
872
- cm_font_size: int=26
873
- ) -> None:
874
- super().__init__(cmap=cmap,
875
- ROC_PR_line=ROC_PR_line,
876
- calibration_bins=calibration_bins,
877
- font_size=font_size,
878
- xtick_size=xtick_size,
879
- ytick_size=ytick_size,
880
- legend_size=legend_size,
881
- cm_font_size=cm_font_size)
882
-
883
- class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
884
- """
885
- Configuration for multi-class image classification.
886
- """
887
- def __init__(self,
888
- cmap: str="BuGn",
889
- ROC_PR_line: str='darkorange',
890
- calibration_bins: int=15,
891
- font_size: int=26,
892
- xtick_size: int=22,
893
- ytick_size: int=22,
894
- legend_size: int=26,
895
- cm_font_size: int=26
896
- ) -> None:
897
- super().__init__(cmap=cmap,
898
- ROC_PR_line=ROC_PR_line,
899
- calibration_bins=calibration_bins,
900
- font_size=font_size,
901
- xtick_size=xtick_size,
902
- ytick_size=ytick_size,
903
- legend_size=legend_size,
904
- cm_font_size=cm_font_size)
905
-
906
- # Multi-Label classification
907
- class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
908
- """
909
- Configuration for multi-label binary classification.
910
- """
911
- def __init__(self,
912
- cmap: str = "BuGn",
913
- ROC_PR_line: str='darkorange',
914
- font_size: int = 25,
915
- xtick_size: int=20,
916
- ytick_size: int=20,
917
- legend_size: int=23
918
- ) -> None:
919
- super().__init__(cmap=cmap,
920
- ROC_PR_line=ROC_PR_line,
921
- font_size=font_size,
922
- xtick_size=xtick_size,
923
- ytick_size=ytick_size,
924
- legend_size=legend_size)
925
-
926
- # Segmentation
927
- class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
928
- """
929
- Configuration for binary segmentation.
930
- """
931
- def __init__(self,
932
- heatmap_cmap: str = "BuGn",
933
- cm_cmap: str = "Purples",
934
- font_size: int = 16) -> None:
935
- super().__init__(heatmap_cmap=heatmap_cmap,
936
- cm_cmap=cm_cmap,
937
- font_size=font_size)
938
-
939
-
940
- class MultiClassSegmentationMetricsFormat(_BaseSegmentationFormat):
941
- """
942
- Configuration for multi-class segmentation.
943
- """
944
- def __init__(self,
945
- heatmap_cmap: str = "BuGn",
946
- cm_cmap: str = "Purples",
947
- font_size: int = 16) -> None:
948
- super().__init__(heatmap_cmap=heatmap_cmap,
949
- cm_cmap=cm_cmap,
950
- font_size=font_size)
951
-
952
-
953
- # Sequence
954
- class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
955
- """
956
- Configuration for sequence-to-value prediction.
957
- """
958
- def __init__(self,
959
- font_size: int=25,
960
- scatter_color: str='tab:blue',
961
- scatter_alpha: float=0.6,
962
- ideal_line_color: str='k',
963
- residual_line_color: str='red',
964
- hist_bins: Union[int, str] = 'auto') -> None:
965
- super().__init__(font_size=font_size,
966
- scatter_color=scatter_color,
967
- scatter_alpha=scatter_alpha,
968
- ideal_line_color=ideal_line_color,
969
- residual_line_color=residual_line_color,
970
- hist_bins=hist_bins)
971
-
972
-
973
- class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
974
- """
975
- Configuration for sequence-to-sequence prediction.
976
- """
977
- def __init__(self,
978
- font_size: int = 25,
979
- grid_style: str = '--',
980
- rmse_color: str = 'tab:blue',
981
- rmse_marker: str = 'o-',
982
- mae_color: str = 'tab:orange',
983
- mae_marker: str = 's--'):
984
- super().__init__(font_size=font_size,
985
- grid_style=grid_style,
986
- rmse_color=rmse_color,
987
- rmse_marker=rmse_marker,
988
- mae_color=mae_color,
989
- mae_marker=mae_marker)
990
-
991
-
992
- # -------- Finalize classes --------
993
- class _FinalizeModelTraining:
994
- """
995
- Base class for finalizing model training.
996
-
997
- This class is not intended to be instantiated directly. Instead, use one of its specific subclasses.
998
- """
999
- def __init__(self,
1000
- filename: str,
1001
- ) -> None:
1002
- self.filename = _validate_string(string=filename, attribute_name="filename", extension=".pth")
1003
- self.target_name: Optional[str] = None
1004
- self.target_names: Optional[list[str]] = None
1005
- self.classification_threshold: Optional[float] = None
1006
- self.class_map: Optional[dict[str,int]] = None
1007
- self.initial_sequence: Optional[np.ndarray] = None
1008
- self.sequence_length: Optional[int] = None
1009
- self.task: str = 'UNKNOWN'
1010
-
1011
-
1012
- class FinalizeRegression(_FinalizeModelTraining):
1013
- """Parameters for finalizing a single-target regression model."""
1014
- def __init__(self,
1015
- filename: str,
1016
- target_name: str,
1017
- ) -> None:
1018
- """Initializes the finalization parameters.
1019
-
1020
- Args:
1021
- filename (str): The name of the file to be saved.
1022
- target_name (str): The name of the target variable.
1023
- """
1024
- super().__init__(filename=filename)
1025
- self.target_name = _validate_string(string=target_name, attribute_name="Target name")
1026
- self.task = MLTaskKeys.REGRESSION
1027
-
1028
-
1029
- class FinalizeMultiTargetRegression(_FinalizeModelTraining):
1030
- """Parameters for finalizing a multi-target regression model."""
1031
- def __init__(self,
1032
- filename: str,
1033
- target_names: list[str],
1034
- ) -> None:
1035
- """Initializes the finalization parameters.
1036
-
1037
- Args:
1038
- filename (str): The name of the file to be saved.
1039
- target_names (list[str]): A list of names for the target variables.
1040
- """
1041
- super().__init__(filename=filename)
1042
- safe_names = [_validate_string(string=target_name, attribute_name="All target names") for target_name in target_names]
1043
- self.target_names = safe_names
1044
- self.task = MLTaskKeys.MULTITARGET_REGRESSION
1045
-
1046
-
1047
- class FinalizeBinaryClassification(_FinalizeModelTraining):
1048
- """Parameters for finalizing a binary classification model."""
1049
- def __init__(self,
1050
- filename: str,
1051
- target_name: str,
1052
- classification_threshold: float,
1053
- class_map: dict[str,int]
1054
- ) -> None:
1055
- """Initializes the finalization parameters.
1056
-
1057
- Args:
1058
- filename (str): The name of the file to be saved.
1059
- target_name (str): The name of the target variable.
1060
- classification_threshold (float): The cutoff threshold for classifying as the positive class.
1061
- class_map (dict[str,int]): A dictionary mapping class names (str)
1062
- to their integer representations (e.g., {'cat': 0, 'dog': 1}).
1063
- """
1064
- super().__init__(filename=filename)
1065
- self.target_name = _validate_string(string=target_name, attribute_name="Target name")
1066
- self.classification_threshold = _validate_threshold(classification_threshold)
1067
- self.class_map = _validate_class_map(class_map)
1068
- self.task = MLTaskKeys.BINARY_CLASSIFICATION
1069
-
1070
-
1071
- class FinalizeMultiClassClassification(_FinalizeModelTraining):
1072
- """Parameters for finalizing a multi-class classification model."""
1073
- def __init__(self,
1074
- filename: str,
1075
- target_name: str,
1076
- class_map: dict[str,int]
1077
- ) -> None:
1078
- """Initializes the finalization parameters.
1079
-
1080
- Args:
1081
- filename (str): The name of the file to be saved.
1082
- target_name (str): The name of the target variable.
1083
- class_map (dict[str,int]): A dictionary mapping class names (str)
1084
- to their integer representations (e.g., {'cat': 0, 'dog': 1}).
1085
- """
1086
- super().__init__(filename=filename)
1087
- self.target_name = _validate_string(string=target_name, attribute_name="Target name")
1088
- self.class_map = _validate_class_map(class_map)
1089
- self.task = MLTaskKeys.MULTICLASS_CLASSIFICATION
1090
-
1091
-
1092
- class FinalizeBinaryImageClassification(_FinalizeModelTraining):
1093
- """Parameters for finalizing a binary image classification model."""
1094
- def __init__(self,
1095
- filename: str,
1096
- classification_threshold: float,
1097
- class_map: dict[str,int]
1098
- ) -> None:
1099
- """Initializes the finalization parameters.
1100
-
1101
- Args:
1102
- filename (str): The name of the file to be saved.
1103
- classification_threshold (float): The cutoff threshold for
1104
- classifying as the positive class.
1105
- class_map (dict[str,int]): A dictionary mapping class names (str)
1106
- to their integer representations (e.g., {'cat': 0, 'dog': 1}).
1107
- """
1108
- super().__init__(filename=filename)
1109
- self.classification_threshold = _validate_threshold(classification_threshold)
1110
- self.class_map = _validate_class_map(class_map)
1111
- self.task = MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
1112
-
1113
-
1114
- class FinalizeMultiClassImageClassification(_FinalizeModelTraining):
1115
- """Parameters for finalizing a multi-class image classification model."""
1116
- def __init__(self,
1117
- filename: str,
1118
- class_map: dict[str,int]
1119
- ) -> None:
1120
- """Initializes the finalization parameters.
1121
-
1122
- Args:
1123
- filename (str): The name of the file to be saved.
1124
- class_map (dict[str,int]): A dictionary mapping class names (str)
1125
- to their integer representations (e.g., {'cat': 0, 'dog': 1}).
1126
- """
1127
- super().__init__(filename=filename)
1128
- self.class_map = _validate_class_map(class_map)
1129
- self.task = MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
1130
-
1131
-
1132
- class FinalizeMultiLabelBinaryClassification(_FinalizeModelTraining):
1133
- """Parameters for finalizing a multi-label binary classification model."""
1134
- def __init__(self,
1135
- filename: str,
1136
- target_names: list[str],
1137
- classification_threshold: float,
1138
- ) -> None:
1139
- """Initializes the finalization parameters.
1140
-
1141
- Args:
1142
- filename (str): The name of the file to be saved.
1143
- target_names (list[str]): A list of names for the target variables.
1144
- classification_threshold (float): The cutoff threshold for classifying as the positive class.
1145
- """
1146
- super().__init__(filename=filename)
1147
- safe_names = [_validate_string(string=target_name, attribute_name="All target names") for target_name in target_names]
1148
- self.target_names = safe_names
1149
- self.classification_threshold = _validate_threshold(classification_threshold)
1150
- self.task = MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION
1151
-
1152
-
1153
- class FinalizeBinarySegmentation(_FinalizeModelTraining):
1154
- """Parameters for finalizing a binary segmentation model."""
1155
- def __init__(self,
1156
- filename: str,
1157
- class_map: dict[str,int],
1158
- classification_threshold: float,
1159
- ) -> None:
1160
- """Initializes the finalization parameters.
1161
-
1162
- Args:
1163
- filename (str): The name of the file to be saved.
1164
- classification_threshold (float): The cutoff threshold for classifying as the positive class (mask).
1165
- """
1166
- super().__init__(filename=filename)
1167
- self.classification_threshold = _validate_threshold(classification_threshold)
1168
- self.class_map = _validate_class_map(class_map)
1169
- self.task = MLTaskKeys.BINARY_SEGMENTATION
1170
-
1171
-
1172
- class FinalizeMultiClassSegmentation(_FinalizeModelTraining):
1173
- """Parameters for finalizing a multi-class segmentation model."""
1174
- def __init__(self,
1175
- filename: str,
1176
- class_map: dict[str,int]
1177
- ) -> None:
1178
- """Initializes the finalization parameters.
1179
-
1180
- Args:
1181
- filename (str): The name of the file to be saved.
1182
- """
1183
- super().__init__(filename=filename)
1184
- self.class_map = _validate_class_map(class_map)
1185
- self.task = MLTaskKeys.MULTICLASS_SEGMENTATION
1186
-
1187
-
1188
- class FinalizeObjectDetection(_FinalizeModelTraining):
1189
- """Parameters for finalizing an object detection model."""
1190
- def __init__(self,
1191
- filename: str,
1192
- class_map: dict[str,int]
1193
- ) -> None:
1194
- """Initializes the finalization parameters.
1195
-
1196
- Args:
1197
- filename (str): The name of the file to be saved.
1198
- """
1199
- super().__init__(filename=filename)
1200
- self.class_map = _validate_class_map(class_map)
1201
- self.task = MLTaskKeys.OBJECT_DETECTION
1202
-
1203
-
1204
- class FinalizeSequenceSequencePrediction(_FinalizeModelTraining):
1205
- """Parameters for finalizing a sequence-to-sequence prediction model."""
1206
- def __init__(self,
1207
- filename: str,
1208
- last_training_sequence: np.ndarray,
1209
- ) -> None:
1210
- """Initializes the finalization parameters.
1211
-
1212
- Args:
1213
- filename (str): The name of the file to be saved.
1214
- last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
1215
- """
1216
- super().__init__(filename=filename)
1217
-
1218
- if not isinstance(last_training_sequence, np.ndarray):
1219
- _LOGGER.error(f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}.")
1220
- raise TypeError()
1221
-
1222
- if last_training_sequence.ndim == 1:
1223
- # It's already 1D, (N,). This is valid.
1224
- self.initial_sequence = last_training_sequence
1225
- elif last_training_sequence.ndim == 2:
1226
- # Handle both (1, N) and (N, 1)
1227
- if last_training_sequence.shape[0] == 1:
1228
- self.initial_sequence = last_training_sequence.flatten()
1229
- elif last_training_sequence.shape[1] == 1:
1230
- self.initial_sequence = last_training_sequence.flatten()
1231
- else:
1232
- _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
1233
- raise ValueError()
1234
- else:
1235
- # It's 3D or more, which is not supported
1236
- _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
1237
- raise ValueError()
1238
-
1239
- # Save the length of the validated 1D sequence
1240
- self.sequence_length = len(self.initial_sequence) # type: ignore
1241
- self.task = MLTaskKeys.SEQUENCE_SEQUENCE
1242
-
1243
-
1244
- class FinalizeSequenceValuePrediction(_FinalizeModelTraining):
1245
- """Parameters for finalizing a sequence-to-value prediction model."""
1246
- def __init__(self,
1247
- filename: str,
1248
- last_training_sequence: np.ndarray,
1249
- ) -> None:
1250
- """Initializes the finalization parameters.
1251
-
1252
- Args:
1253
- filename (str): The name of the file to be saved.
1254
- last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
1255
- """
1256
- super().__init__(filename=filename)
1257
-
1258
- if not isinstance(last_training_sequence, np.ndarray):
1259
- _LOGGER.error(f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}.")
1260
- raise TypeError()
1261
-
1262
- if last_training_sequence.ndim == 1:
1263
- # It's already 1D, (N,). This is valid.
1264
- self.initial_sequence = last_training_sequence
1265
- elif last_training_sequence.ndim == 2:
1266
- # Handle both (1, N) and (N, 1)
1267
- if last_training_sequence.shape[0] == 1:
1268
- self.initial_sequence = last_training_sequence.flatten()
1269
- elif last_training_sequence.shape[1] == 1:
1270
- self.initial_sequence = last_training_sequence.flatten()
1271
- else:
1272
- _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
1273
- raise ValueError()
1274
- else:
1275
- # It's 3D or more, which is not supported
1276
- _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
1277
- raise ValueError()
1278
-
1279
- # Save the length of the validated 1D sequence
1280
- self.sequence_length = len(self.initial_sequence) # type: ignore
1281
- self.task = MLTaskKeys.SEQUENCE_VALUE
1282
-
1283
-
1284
- def _validate_string(string: str, attribute_name: str, extension: Optional[str]=None) -> str:
1285
- """Helper for finalize classes"""
1286
- if not isinstance(string, str):
1287
- _LOGGER.error(f"{attribute_name} must be a string.")
1288
- raise TypeError()
1289
-
1290
- if extension:
1291
- safe_name = sanitize_filename(string)
1292
-
1293
- if not safe_name.endswith(extension):
1294
- safe_name += extension
1295
- else:
1296
- safe_name = string
1297
-
1298
- return safe_name
1299
-
1300
- def _validate_threshold(threshold: float):
1301
- """Helper for finalize classes"""
1302
- if not isinstance(threshold, float):
1303
- _LOGGER.error(f"Classification threshold must be a float.")
1304
- raise TypeError()
1305
- elif threshold < 0.1 or threshold > 0.9:
1306
- _LOGGER.error(f"Classification threshold must be in the range [0.1, 0.9]")
1307
- raise ValueError()
1308
-
1309
- return threshold
1310
-
1311
- def _validate_class_map(map_dict: dict[str, int]):
1312
- """Helper for finalize classes"""
1313
- if not isinstance(map_dict, dict):
1314
- _LOGGER.error(f"Class map must be a dictionary, but got {type(map_dict)}.")
1315
- raise TypeError()
1316
-
1317
- if not map_dict:
1318
- _LOGGER.error("Class map dictionary cannot be empty.")
1319
- raise ValueError()
1320
-
1321
- for key, val in map_dict.items():
1322
- if not isinstance(key, str):
1323
- _LOGGER.error(f"All keys in the class map must be strings, but found key: {key} ({type(key)}).")
1324
- raise TypeError()
1325
- if not isinstance(val, int):
1326
- _LOGGER.error(f"All values in the class map must be integers, but for key '{key}' found value: {val} ({type(val)}).")
1327
- raise TypeError()
1328
-
1329
- return map_dict
1330
-
1331
- def info():
1332
- _script_info(__all__)