dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,593 @@
|
|
|
1
|
+
from typing import Union
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Public metric-formatting configurations exported by this module, grouped by task.
__all__ = [
    # Regression
    "FormatRegressionMetrics",
    "FormatMultiTargetRegressionMetrics",
    # Single-label classification (tabular and image)
    "FormatBinaryClassificationMetrics",
    "FormatMultiClassClassificationMetrics",
    "FormatBinaryImageClassificationMetrics",
    "FormatMultiClassImageClassificationMetrics",
    # Multi-label classification
    "FormatMultiLabelBinaryClassificationMetrics",
    # Segmentation
    "FormatBinarySegmentationMetrics",
    "FormatMultiClassSegmentationMetrics",
    # Sequence
    "FormatSequenceValueMetrics",
    "FormatSequenceSequenceMetrics",
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# --- Private base classes ---
|
|
21
|
+
|
|
22
|
+
class _BaseClassificationFormat:
    """
    [PRIVATE] Base configuration for single-label classification metrics.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """
    def __init__(self,
                 cmap: str="BuGn",
                 ROC_PR_line: str='darkorange',
                 calibration_bins: int=15,
                 xtick_size: int=22,
                 ytick_size: int=22,
                 legend_size: int=26,
                 font_size: int=26,
                 cm_font_size: int=26) -> None:
        """
        Initializes the formatting configuration for single-label classification metrics.

        Args:
            cmap (str): The matplotlib colormap name for the confusion matrix
                and report heatmap.
                - Sequential options: 'BuGn', 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
                - Diverging options: 'coolwarm', 'RdBu', 'bwr', 'seismic'

            ROC_PR_line (str): The color name or hex code for the line plotted
                on the ROC and Precision-Recall curves.
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            calibration_bins (int): The number of bins to use when
                creating the calibration (reliability) plot.

            xtick_size (int): Font size for x-axis tick labels.

            ytick_size (int): Font size for y-axis tick labels.

            legend_size (int): Font size for plot legends.

            font_size (int): The base font size to apply to the plots.

            cm_font_size (int): Font size for the confusion matrix.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.calibration_bins = calibration_bins
        self.font_size = font_size
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size
        self.legend_size = legend_size
        self.cm_font_size = cm_font_size

    def __repr__(self) -> str:
        parts = [
            f"cmap='{self.cmap}'",
            f"ROC_PR_line='{self.ROC_PR_line}'",
            f"calibration_bins={self.calibration_bins}",
            f"font_size={self.font_size}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}",
            f"legend_size={self.legend_size}",
            f"cm_font_size={self.cm_font_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class _BaseMultiLabelFormat:
    """
    [PRIVATE] Base configuration for multi-label binary classification metrics.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """
    def __init__(self,
                 cmap: str = "BuGn",
                 ROC_PR_line: str='darkorange',
                 font_size: int = 25,
                 xtick_size: int=20,
                 ytick_size: int=20,
                 legend_size: int=23) -> None:
        """
        Initializes the formatting configuration for multi-label classification metrics.

        Args:
            cmap (str): The matplotlib colormap name for the per-label
                confusion matrices.
                - Sequential options: 'BuGn', 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
                - Diverging options: 'coolwarm', 'RdBu', 'bwr', 'seismic'

            ROC_PR_line (str): The color name or hex code for the line plotted
                on the ROC and Precision-Recall curves (one for each label).
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            font_size (int): The base font size to apply to the plots.

            xtick_size (int): Font size for x-axis tick labels.

            ytick_size (int): Font size for y-axis tick labels.

            legend_size (int): Font size for plot legends.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.font_size = font_size
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size
        self.legend_size = legend_size

    def __repr__(self) -> str:
        parts = [
            f"cmap='{self.cmap}'",
            f"ROC_PR_line='{self.ROC_PR_line}'",
            f"font_size={self.font_size}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}",
            f"legend_size={self.legend_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class _BaseRegressionFormat:
    """
    [PRIVATE] Base configuration for regression metrics.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """
    def __init__(self,
                 font_size: int=26,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto',
                 xtick_size: int=22,
                 ytick_size: int=22) -> None:
        """
        Initializes the formatting configuration for regression metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
                - Examples: 'tab:blue', 'crimson', 'forestgreen', or a hex code such as '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot.
                - Examples: 'k', 'red', 'darkgrey', or a hex code such as '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot.
                - Examples: 'red', 'blue', 'k', or a hex code such as '#4682B4'
            hist_bins (int | str): Bins for the residuals histogram. An int fixes the
                number of bins; a string names an automatic binning rule.
                Defaults to 'auto' (automatic bin selection).
                - Options: 'auto', 'sqrt', 10, 20
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size

    def __repr__(self) -> str:
        parts = [
            f"font_size={self.font_size}",
            f"scatter_color='{self.scatter_color}'",
            f"scatter_alpha={self.scatter_alpha}",
            f"ideal_line_color='{self.ideal_line_color}'",
            f"residual_line_color='{self.residual_line_color}'",
            # hist_bins may be an int or a str: !r quotes only strings, so an int
            # bin count is not misreported as a string (e.g. hist_bins='10').
            f"hist_bins={self.hist_bins!r}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
class _BaseSegmentationFormat:
    """
    [PRIVATE] Base configuration for segmentation metrics.

    Holds two colormap names and a base font size as plain attributes named
    after the constructor arguments.
    """
    def __init__(self,
                 heatmap_cmap: str = "BuGn",
                 cm_cmap: str = "Purples",
                 font_size: int = 16) -> None:
        """
        Initializes the formatting configuration for segmentation metrics.

        Args:
            heatmap_cmap (str): The matplotlib colormap name for the per-class
                metrics heatmap.
                - Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
                - Diverging options: 'coolwarm', 'bwr', 'seismic'
            cm_cmap (str): The matplotlib colormap name for the pixel-level
                confusion matrix.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
            font_size (int): The base font size to apply to the plots.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
        """
        self.heatmap_cmap, self.cm_cmap, self.font_size = heatmap_cmap, cm_cmap, font_size

    def __repr__(self) -> str:
        # (field name, rendered value) pairs; string options are shown quoted.
        fields = (
            ("heatmap_cmap", f"'{self.heatmap_cmap}'"),
            ("cm_cmap", f"'{self.cm_cmap}'"),
            ("font_size", f"{self.font_size}"),
        )
        body = ", ".join(f"{key}={text}" for key, text in fields)
        return f"{self.__class__.__name__}({body})"
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
class _BaseSequenceValueFormat:
    """
    [PRIVATE] Base configuration for sequence to value metrics.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """
    def __init__(self,
                 font_size: int=25,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Initializes the formatting configuration for sequence to value metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
                - Examples: 'tab:blue', 'crimson', 'forestgreen', or a hex code such as '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot.
                - Examples: 'k', 'red', 'darkgrey', or a hex code such as '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot.
                - Examples: 'red', 'blue', 'k', or a hex code such as '#4682B4'
            hist_bins (int | str): Bins for the residuals histogram. An int fixes the
                number of bins; a string names an automatic binning rule.
                Defaults to 'auto' (automatic bin selection).
                - Options: 'auto', 'sqrt', 10, 20

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins

    def __repr__(self) -> str:
        parts = [
            f"font_size={self.font_size}",
            f"scatter_color='{self.scatter_color}'",
            f"scatter_alpha={self.scatter_alpha}",
            f"ideal_line_color='{self.ideal_line_color}'",
            f"residual_line_color='{self.residual_line_color}'",
            # hist_bins may be an int or a str: !r quotes only strings, so an int
            # bin count is not misreported as a string (e.g. hist_bins='10').
            f"hist_bins={self.hist_bins!r}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class _BaseSequenceSequenceFormat:
    """
    [PRIVATE] Base configuration for sequence-to-sequence metrics.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """
    def __init__(self,
                 font_size: int = 25,
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--') -> None:
        """
        Initializes the formatting configuration for seq-to-seq metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            grid_style (str): Matplotlib linestyle for the plot grid.
                - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
            rmse_color (str): Matplotlib color for the RMSE line.
                - Examples: 'tab:blue', 'crimson', 'forestgreen', or a hex code such as '#4682B4'
            rmse_marker (str): Matplotlib marker/line format string for the RMSE line
                (a marker symbol optionally followed by a linestyle).
                - Options: 'o-' (circles, solid line), 's--' (squares, dashed line),
                  '^:' (triangles, dotted line), 'x' (x markers only)
            mae_color (str): Matplotlib color for the MAE line.
                - Examples: 'tab:orange', 'purple', 'black', or a hex code such as '#FF6347'
            mae_marker (str): Matplotlib marker/line format string for the MAE line.
                - Options: 's--', 'o-', 'v:', '+' (plus markers only)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)

        <br>

        ### [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)

        <br>

        ### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
        """
        self.font_size = font_size
        self.grid_style = grid_style
        self.rmse_color = rmse_color
        self.rmse_marker = rmse_marker
        self.mae_color = mae_color
        self.mae_marker = mae_marker

    def __repr__(self) -> str:
        # Report every configured option (the marker styles were previously omitted,
        # so configs differing only in markers printed identically).
        parts = [
            f"font_size={self.font_size}",
            f"grid_style='{self.grid_style}'",
            f"rmse_color='{self.rmse_color}'",
            f"rmse_marker='{self.rmse_marker}'",
            f"mae_color='{self.mae_color}'",
            f"mae_marker='{self.mae_marker}'"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
# ----------------------------
|
|
360
|
+
# Metrics Configurations
|
|
361
|
+
# ----------------------------
|
|
362
|
+
|
|
363
|
+
# Regression
|
|
364
|
+
class FormatRegressionMetrics(_BaseRegressionFormat):
    """
    Configuration for single-target regression.

    Public entry point: every argument is forwarded unchanged to
    `_BaseRegressionFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        font_size: int = 26,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
        xtick_size: int = 22,
        ytick_size: int = 22,
    ) -> None:
        """
        Builds the single-target regression plotting configuration.

        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Matplotlib color of the scatter points.
            scatter_alpha (float): Alpha transparency of the scatter points.
            ideal_line_color (str): Color of the y=x line in the True vs. Predicted plot.
            residual_line_color (str): Color of the y=0 line in the Residual plot.
            hist_bins (int | str): Bin count or binning rule for the residuals histogram.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
        )
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
# Multitarget regression
|
|
388
|
+
class FormatMultiTargetRegressionMetrics(_BaseRegressionFormat):
    """
    Configuration for multi-target regression.

    Public entry point: every argument is forwarded unchanged to
    `_BaseRegressionFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        font_size: int = 26,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
        xtick_size: int = 22,
        ytick_size: int = 22,
    ) -> None:
        """
        Builds the multi-target regression plotting configuration.

        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Matplotlib color of the scatter points.
            scatter_alpha (float): Alpha transparency of the scatter points.
            ideal_line_color (str): Color of the y=x line in the True vs. Predicted plot.
            residual_line_color (str): Color of the y=0 line in the Residual plot.
            hist_bins (int | str): Bin count or binning rule for the residuals histogram.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
        )
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
# Classification
|
|
412
|
+
class FormatBinaryClassificationMetrics(_BaseClassificationFormat):
    """
    Configuration for binary classification.

    Public entry point: every argument is forwarded unchanged to
    `_BaseClassificationFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Builds the binary-classification plotting configuration.

        Args:
            cmap (str): Matplotlib colormap name for the confusion matrix and report heatmap.
            ROC_PR_line (str): Color of the line on the ROC and Precision-Recall curves.
            calibration_bins (int): Number of bins in the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        # Forward by keyword: the base class orders `font_size` differently,
        # so positional forwarding would silently swap values.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class FormatMultiClassClassificationMetrics(_BaseClassificationFormat):
    """
    Configuration for multi-class classification.

    Public entry point: every argument is forwarded unchanged to
    `_BaseClassificationFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Builds the multi-class classification plotting configuration.

        Args:
            cmap (str): Matplotlib colormap name for the confusion matrix and report heatmap.
            ROC_PR_line (str): Color of the line on the ROC and Precision-Recall curves.
            calibration_bins (int): Number of bins in the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        # Forward by keyword: the base class orders `font_size` differently,
        # so positional forwarding would silently swap values.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
class FormatBinaryImageClassificationMetrics(_BaseClassificationFormat):
    """
    Configuration for binary image classification.

    Public entry point: every argument is forwarded unchanged to
    `_BaseClassificationFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Builds the binary image-classification plotting configuration.

        Args:
            cmap (str): Matplotlib colormap name for the confusion matrix and report heatmap.
            ROC_PR_line (str): Color of the line on the ROC and Precision-Recall curves.
            calibration_bins (int): Number of bins in the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        # Forward by keyword: the base class orders `font_size` differently,
        # so positional forwarding would silently swap values.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
class FormatMultiClassImageClassificationMetrics(_BaseClassificationFormat):
    """
    Configuration for multi-class image classification.

    Public entry point: every argument is forwarded unchanged to
    `_BaseClassificationFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        calibration_bins: int = 15,
        font_size: int = 26,
        xtick_size: int = 22,
        ytick_size: int = 22,
        legend_size: int = 26,
        cm_font_size: int = 26,
    ) -> None:
        """
        Builds the multi-class image-classification plotting configuration.

        Args:
            cmap (str): Matplotlib colormap name for the confusion matrix and report heatmap.
            ROC_PR_line (str): Color of the line on the ROC and Precision-Recall curves.
            calibration_bins (int): Number of bins in the calibration (reliability) plot.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
            cm_font_size (int): Font size for the confusion matrix.
        """
        # Forward by keyword: the base class orders `font_size` differently,
        # so positional forwarding would silently swap values.
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            calibration_bins=calibration_bins,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
            cm_font_size=cm_font_size,
        )
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
# Multi-Label classification
|
|
509
|
+
class FormatMultiLabelBinaryClassificationMetrics(_BaseMultiLabelFormat):
    """
    Configuration for multi-label binary classification.

    Public entry point: every argument is forwarded unchanged to
    `_BaseMultiLabelFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        cmap: str = "BuGn",
        ROC_PR_line: str = 'darkorange',
        font_size: int = 25,
        xtick_size: int = 20,
        ytick_size: int = 20,
        legend_size: int = 23,
    ) -> None:
        """
        Builds the multi-label binary classification plotting configuration.

        Args:
            cmap (str): Matplotlib colormap name for the per-label confusion matrices.
            ROC_PR_line (str): Color of the line on the per-label ROC and Precision-Recall curves.
            font_size (int): Base font size for the plots.
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.
            legend_size (int): Font size for plot legends.
        """
        super().__init__(
            cmap=cmap,
            ROC_PR_line=ROC_PR_line,
            font_size=font_size,
            xtick_size=xtick_size,
            ytick_size=ytick_size,
            legend_size=legend_size,
        )
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
# Segmentation
|
|
530
|
+
class FormatBinarySegmentationMetrics(_BaseSegmentationFormat):
    """
    Configuration for binary segmentation.

    Public entry point: every argument is forwarded unchanged to
    `_BaseSegmentationFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        heatmap_cmap: str = "BuGn",
        cm_cmap: str = "Purples",
        font_size: int = 16,
    ) -> None:
        """
        Builds the binary segmentation plotting configuration.

        Args:
            heatmap_cmap (str): Matplotlib colormap name for the per-class metrics heatmap.
            cm_cmap (str): Matplotlib colormap name for the pixel-level confusion matrix.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
class FormatMultiClassSegmentationMetrics(_BaseSegmentationFormat):
    """
    Configuration for multi-class segmentation.

    Public entry point: every argument is forwarded unchanged to
    `_BaseSegmentationFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        heatmap_cmap: str = "BuGn",
        cm_cmap: str = "Purples",
        font_size: int = 16,
    ) -> None:
        """
        Builds the multi-class segmentation plotting configuration.

        Args:
            heatmap_cmap (str): Matplotlib colormap name for the per-class metrics heatmap.
            cm_cmap (str): Matplotlib colormap name for the pixel-level confusion matrix.
            font_size (int): Base font size for the plots.
        """
        super().__init__(
            heatmap_cmap=heatmap_cmap,
            cm_cmap=cm_cmap,
            font_size=font_size,
        )
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
# Sequence
|
|
557
|
+
class FormatSequenceValueMetrics(_BaseSequenceValueFormat):
    """
    Configuration for sequence-to-value prediction.

    Public entry point: every argument is forwarded unchanged to
    `_BaseSequenceValueFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        font_size: int = 25,
        scatter_color: str = 'tab:blue',
        scatter_alpha: float = 0.6,
        ideal_line_color: str = 'k',
        residual_line_color: str = 'red',
        hist_bins: Union[int, str] = 'auto',
    ) -> None:
        """
        Builds the sequence-to-value plotting configuration.

        Args:
            font_size (int): Base font size for the plots.
            scatter_color (str): Matplotlib color of the scatter points.
            scatter_alpha (float): Alpha transparency of the scatter points.
            ideal_line_color (str): Color of the y=x line in the True vs. Predicted plot.
            residual_line_color (str): Color of the y=0 line in the Residual plot.
            hist_bins (int | str): Bin count or binning rule for the residuals histogram.
        """
        super().__init__(
            font_size=font_size,
            scatter_color=scatter_color,
            scatter_alpha=scatter_alpha,
            ideal_line_color=ideal_line_color,
            residual_line_color=residual_line_color,
            hist_bins=hist_bins,
        )
|
|
574
|
+
|
|
575
|
+
|
|
576
|
+
class FormatSequenceSequenceMetrics(_BaseSequenceSequenceFormat):
    """
    Configuration for sequence-to-sequence prediction.

    Public entry point: every argument is forwarded unchanged to
    `_BaseSequenceSequenceFormat`, which stores it as an attribute of the same name.
    """
    def __init__(
        self,
        font_size: int = 25,
        grid_style: str = '--',
        rmse_color: str = 'tab:blue',
        rmse_marker: str = 'o-',
        mae_color: str = 'tab:orange',
        mae_marker: str = 's--',
    ):
        """
        Builds the sequence-to-sequence plotting configuration.

        Args:
            font_size (int): Base font size for the plots.
            grid_style (str): Matplotlib linestyle for the plot grid.
            rmse_color (str): Matplotlib color of the RMSE line.
            rmse_marker (str): Marker/line format string for the RMSE line.
            mae_color (str): Matplotlib color of the MAE line.
            mae_marker (str): Marker/line format string for the MAE line.
        """
        super().__init__(
            font_size=font_size,
            grid_style=grid_style,
            rmse_color=rmse_color,
            rmse_marker=rmse_marker,
            mae_color=mae_color,
            mae_marker=mae_marker,
        )
|
|
593
|
+
|