dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,1332 +0,0 @@
|
|
|
1
|
-
from typing import Union, Optional, List, Any, Dict, Literal, Tuple
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
from collections.abc import Mapping
|
|
4
|
-
import numpy as np
|
|
5
|
-
|
|
6
|
-
from ._schema import FeatureSchema
|
|
7
|
-
from ._script_info import _script_info
|
|
8
|
-
from ._logger import get_logger
|
|
9
|
-
from ._path_manager import sanitize_filename, make_fullpath
|
|
10
|
-
from ._keys import MLTaskKeys
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
# Module-wide logger, tagged "Configuration".
_LOGGER = get_logger("Configuration")


# Names exported by `from ... import *`, grouped by purpose.
__all__ = [
    # --- Metrics Formats ---
    "RegressionMetricsFormat",
    "MultiTargetRegressionMetricsFormat",
    "BinaryClassificationMetricsFormat",
    "MultiClassClassificationMetricsFormat",
    "BinaryImageClassificationMetricsFormat",
    "MultiClassImageClassificationMetricsFormat",
    "MultiLabelBinaryClassificationMetricsFormat",
    "BinarySegmentationMetricsFormat",
    "MultiClassSegmentationMetricsFormat",
    "SequenceValueMetricsFormat",
    "SequenceSequenceMetricsFormat",

    # --- Finalize Configs ---
    "FinalizeBinaryClassification",
    "FinalizeBinarySegmentation",
    "FinalizeBinaryImageClassification",
    "FinalizeMultiClassClassification",
    "FinalizeMultiClassImageClassification",
    "FinalizeMultiClassSegmentation",
    "FinalizeMultiLabelBinaryClassification",
    "FinalizeMultiTargetRegression",
    "FinalizeRegression",
    "FinalizeObjectDetection",
    "FinalizeSequenceSequencePrediction",
    "FinalizeSequenceValuePrediction",

    # --- Model Parameter Configs ---
    "DragonMLPParams",
    "DragonAttentionMLPParams",
    "DragonMultiHeadAttentionNetParams",
    "DragonTabularTransformerParams",
    "DragonGateParams",
    "DragonNodeParams",
    "DragonTabNetParams",
    "DragonAutoIntParams",

    # --- Training Config ---
    "DragonTrainingConfig",
    "DragonParetoConfig",
]
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
# --- Private base classes ---
|
|
61
|
-
|
|
62
|
-
class _BaseClassificationFormat:
    """
    [PRIVATE] Base configuration for single-label classification metrics.
    """
    def __init__(self,
                 cmap: str = "BuGn",
                 ROC_PR_line: str = 'darkorange',
                 calibration_bins: int = 15,
                 xtick_size: int = 22,
                 ytick_size: int = 22,
                 legend_size: int = 26,
                 font_size: int = 26,
                 cm_font_size: int = 26) -> None:
        """
        Initializes the formatting configuration for single-label classification metrics.

        Args:
            cmap (str): The matplotlib colormap name for the confusion matrix
                and report heatmap.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform sequential options: 'viridis', 'plasma', 'inferno'
                - Diverging options: 'coolwarm', 'bwr', 'seismic'

            ROC_PR_line (str): The color name or hex code for the line plotted
                on the ROC and Precision-Recall curves.
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            calibration_bins (int): The number of bins to use when
                creating the calibration (reliability) plot.

            font_size (int): The base font size to apply to the plots.

            xtick_size (int): Font size for x-axis tick labels.

            ytick_size (int): Font size for y-axis tick labels.

            legend_size (int): Font size for plot legends.

            cm_font_size (int): Font size for the confusion matrix.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        # Values are stored verbatim; no validation is performed here.
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.calibration_bins = calibration_bins
        self.font_size = font_size
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size
        self.legend_size = legend_size
        self.cm_font_size = cm_font_size

    def __repr__(self) -> str:
        # String options are shown quoted, numeric options bare.
        parts = [
            f"cmap='{self.cmap}'",
            f"ROC_PR_line='{self.ROC_PR_line}'",
            f"calibration_bins={self.calibration_bins}",
            f"font_size={self.font_size}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}",
            f"legend_size={self.legend_size}",
            f"cm_font_size={self.cm_font_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
class _BaseMultiLabelFormat:
    """
    [PRIVATE] Base configuration for multi-label binary classification metrics.
    """
    def __init__(self,
                 cmap: str = "BuGn",
                 ROC_PR_line: str = 'darkorange',
                 font_size: int = 25,
                 xtick_size: int = 20,
                 ytick_size: int = 20,
                 legend_size: int = 23) -> None:
        """
        Initializes the formatting configuration for multi-label classification metrics.

        Args:
            cmap (str): The matplotlib colormap name for the per-label
                confusion matrices.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
                - Perceptually uniform sequential options: 'viridis', 'plasma', 'inferno'
                - Diverging options: 'coolwarm', 'bwr', 'seismic'

            ROC_PR_line (str): The color name or hex code for the line plotted
                on the ROC and Precision-Recall curves (one for each label).
                - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
                - Hex codes: '#FF6347', '#4682B4'

            font_size (int): The base font size to apply to the plots.

            xtick_size (int): Font size for x-axis tick labels.

            ytick_size (int): Font size for y-axis tick labels.

            legend_size (int): Font size for plot legends.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        # Values are stored verbatim; no validation is performed here.
        self.cmap = cmap
        self.ROC_PR_line = ROC_PR_line
        self.font_size = font_size
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size
        self.legend_size = legend_size

    def __repr__(self) -> str:
        # String options are shown quoted, numeric options bare.
        parts = [
            f"cmap='{self.cmap}'",
            f"ROC_PR_line='{self.ROC_PR_line}'",
            f"font_size={self.font_size}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}",
            f"legend_size={self.legend_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
class _BaseRegressionFormat:
    """
    [PRIVATE] Base configuration for regression metrics.
    """
    def __init__(self,
                 font_size: int = 26,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto',
                 xtick_size: int = 22,
                 ytick_size: int = 22) -> None:
        """
        Initializes the formatting configuration for regression metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
                - Examples: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot.
                - Examples: 'k', 'red', 'darkgrey', '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot.
                - Examples: 'red', 'blue', 'k', '#4682B4'
            hist_bins (int | str): The number of bins for the residuals histogram.
                Defaults to 'auto' to use seaborn's automatic bin selection.
                - Options: 'auto', 'sqrt', 10, 20
            xtick_size (int): Font size for x-axis tick labels.
            ytick_size (int): Font size for y-axis tick labels.

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins
        self.xtick_size = xtick_size
        self.ytick_size = ytick_size

    def __repr__(self) -> str:
        parts = [
            f"font_size={self.font_size}",
            f"scatter_color='{self.scatter_color}'",
            f"scatter_alpha={self.scatter_alpha}",
            f"ideal_line_color='{self.ideal_line_color}'",
            f"residual_line_color='{self.residual_line_color}'",
            # hist_bins may be an int or a str: !r quotes strings ('auto')
            # but leaves integer bin counts bare (10), so the type is visible.
            f"hist_bins={self.hist_bins!r}",
            f"xtick_size={self.xtick_size}",
            f"ytick_size={self.ytick_size}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
class _BaseSegmentationFormat:
    """
    [PRIVATE] Plot-styling options shared by the segmentation metric reports.
    """
    def __init__(self,
                 heatmap_cmap: str = "BuGn",
                 cm_cmap: str = "Purples",
                 font_size: int = 16) -> None:
        """
        Stores the styling used for segmentation metric plots.

        Args:
            heatmap_cmap (str): Matplotlib colormap for the per-class metrics heatmap.
                - Sequential options: 'viridis', 'plasma', 'inferno', 'cividis'
                - Diverging options: 'coolwarm', 'bwr', 'seismic'
            cm_cmap (str): Matplotlib colormap for the pixel-level confusion matrix.
                - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges'
            font_size (int): Base font size applied to the plots.

        <br>

        ### [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
        """
        self.heatmap_cmap = heatmap_cmap
        self.cm_cmap = cm_cmap
        self.font_size = font_size

    def __repr__(self) -> str:
        # (name, already-formatted value) pairs; colormap names are quoted.
        fields = (
            ("heatmap_cmap", f"'{self.heatmap_cmap}'"),
            ("cm_cmap", f"'{self.cm_cmap}'"),
            ("font_size", f"{self.font_size}"),
        )
        inner = ", ".join(name + "=" + shown for name, shown in fields)
        return type(self).__name__ + "(" + inner + ")"
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
class _BaseSequenceValueFormat:
    """
    [PRIVATE] Base configuration for sequence to value metrics.
    """
    def __init__(self,
                 font_size: int = 25,
                 scatter_color: str = 'tab:blue',
                 scatter_alpha: float = 0.6,
                 ideal_line_color: str = 'k',
                 residual_line_color: str = 'red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        """
        Initializes the formatting configuration for sequence to value metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            scatter_color (str): Matplotlib color for the scatter plot points.
                - Examples: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            scatter_alpha (float): Alpha transparency for scatter plot points.
            ideal_line_color (str): Matplotlib color for the 'ideal' y=x line in the
                True vs. Predicted plot.
                - Examples: 'k', 'red', 'darkgrey', '#FF6347'
            residual_line_color (str): Matplotlib color for the y=0 line in the
                Residual plot.
                - Examples: 'red', 'blue', 'k', '#4682B4'
            hist_bins (int | str): The number of bins for the residuals histogram.
                Defaults to 'auto' to use seaborn's automatic bin selection.
                - Options: 'auto', 'sqrt', 10, 20

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)
        """
        self.font_size = font_size
        self.scatter_color = scatter_color
        self.scatter_alpha = scatter_alpha
        self.ideal_line_color = ideal_line_color
        self.residual_line_color = residual_line_color
        self.hist_bins = hist_bins

    def __repr__(self) -> str:
        parts = [
            f"font_size={self.font_size}",
            f"scatter_color='{self.scatter_color}'",
            f"scatter_alpha={self.scatter_alpha}",
            f"ideal_line_color='{self.ideal_line_color}'",
            f"residual_line_color='{self.residual_line_color}'",
            # hist_bins may be an int or a str: !r quotes strings ('auto')
            # but leaves integer bin counts bare (10), so the type is visible.
            f"hist_bins={self.hist_bins!r}"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
class _BaseSequenceSequenceFormat:
    """
    [PRIVATE] Base configuration for sequence-to-sequence metrics.
    """
    def __init__(self,
                 font_size: int = 25,
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--') -> None:
        """
        Initializes the formatting configuration for seq-to-seq metrics.

        Args:
            font_size (int): The base font size to apply to the plots.
            grid_style (str): Matplotlib linestyle for the plot grid.
                - Options: '--' (dashed), ':' (dotted), '-.' (dash-dot), '-' (solid)
            rmse_color (str): Matplotlib color for the RMSE line.
                - Examples: 'tab:blue', 'crimson', 'forestgreen', '#4682B4'
            rmse_marker (str): Matplotlib marker style for the RMSE line.
                - Options: 'o-' (circle), 's--' (square), '^:' (triangle), 'x' (x marker)
            mae_color (str): Matplotlib color for the MAE line.
                - Examples: 'tab:orange', 'purple', 'black', '#FF6347'
            mae_marker (str): Matplotlib marker style for the MAE line.
                - Options: 's--', 'o-', 'v:', '+' (plus marker)

        <br>

        ### [Matplotlib Colors](https://matplotlib.org/stable/gallery/color/named_colors.html)

        <br>

        ### [Matplotlib Linestyles](https://matplotlib.org/stable/gallery/lines_bars_and_markers/linestyles.html)

        <br>

        ### [Matplotlib Markers](https://matplotlib.org/stable/api/markers_api.html)
        """
        self.font_size = font_size
        self.grid_style = grid_style
        self.rmse_color = rmse_color
        self.rmse_marker = rmse_marker
        self.mae_color = mae_color
        self.mae_marker = mae_marker

    def __repr__(self) -> str:
        # Lists every configured attribute, matching the other format classes;
        # each marker follows its line color.
        parts = [
            f"font_size={self.font_size}",
            f"grid_style='{self.grid_style}'",
            f"rmse_color='{self.rmse_color}'",
            f"rmse_marker='{self.rmse_marker}'",
            f"mae_color='{self.mae_color}'",
            f"mae_marker='{self.mae_marker}'"
        ]
        return f"{self.__class__.__name__}({', '.join(parts)})"
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
class _BaseModelParams(Mapping):
    """
    [PRIVATE] Common parent for model hyper-parameter containers.

    Subclasses store each hyper-parameter as a plain instance attribute. This
    class implements the read-only ``Mapping`` protocol over the instance
    ``__dict__``. An instance can therefore be splatted with ``**params`` or
    converted with ``dict(params)``. Key order follows attribute assignment
    order.
    """

    def __getitem__(self, key: str) -> Any:
        # Unknown names raise KeyError, as the Mapping protocol expects.
        return vars(self)[key]

    def __iter__(self):
        return iter(vars(self))

    def __len__(self) -> int:
        return len(vars(self))

    def __or__(self, other) -> Dict[str, Any]:
        """``params | other``: a new dict in which keys from ``other`` win on clashes."""
        if not isinstance(other, Mapping):
            return NotImplemented
        return {**self, **other}

    def __ror__(self, other) -> Dict[str, Any]:
        """``other | params``: a new dict in which keys from ``params`` win on clashes."""
        if not isinstance(other, Mapping):
            return NotImplemented
        return {**other, **self}

    def __repr__(self) -> str:
        """Multi-line representation with one ``name=repr(value)`` entry per line."""
        body = ",\n".join(f" {name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}(\n{body}\n)"

    def to_log(self) -> Dict[str, Any]:
        """
        Return a shallow dict copy suitable for JSON logging.

        ``FeatureSchema`` values are replaced by their ``repr`` string (so they
        are not serialized as a list). ``Path`` values are replaced by ``str``,
        because JSON cannot encode them. All other values are passed through
        unchanged.
        """
        def _loggable(value: Any) -> Any:
            if isinstance(value, FeatureSchema):
                return repr(value)
            if isinstance(value, Path):
                return str(value)
            return value

        return {name: _loggable(value) for name, value in vars(self).items()}
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
# --- Public API classes ---
|
|
459
|
-
|
|
460
|
-
# ----------------------------
|
|
461
|
-
# Model Parameters Configurations
|
|
462
|
-
# ----------------------------
|
|
463
|
-
|
|
464
|
-
# --- Standard Models ---
|
|
465
|
-
|
|
466
|
-
class DragonMLPParams(_BaseModelParams):
    """
    Hyper-parameter container for the standard MLP model.

    Behaves like a read-only dict (via ``_BaseModelParams``), so it can be
    unpacked with ``**params``. Presumably the target is the ``DragonMLP``
    constructor; confirm against the model definition.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: List[int],
                 drop_out: float = 0.2) -> None:
        """
        Args:
            in_features (int): Forwarded as ``in_features`` (presumably the input width).
            out_targets (int): Forwarded as ``out_targets`` (presumably the output width).
            hidden_layers (List[int]): Forwarded as ``hidden_layers``; stored
                by reference, not copied.
            drop_out (float): Forwarded as ``drop_out``.
        """
        # Assignment order fixes the key order seen when iterating/unpacking.
        self.in_features = in_features
        self.out_targets = out_targets
        self.hidden_layers = hidden_layers
        self.drop_out = drop_out
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
class DragonAttentionMLPParams(_BaseModelParams):
    """
    Hyper-parameter container for the attention-augmented MLP model.

    Behaves like a read-only dict (via ``_BaseModelParams``), so it can be
    unpacked with ``**params``. Presumably the target is the
    ``DragonAttentionMLP`` constructor; confirm against the model definition.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: List[int],
                 drop_out: float = 0.2) -> None:
        """
        Args:
            in_features (int): Forwarded as ``in_features`` (presumably the input width).
            out_targets (int): Forwarded as ``out_targets`` (presumably the output width).
            hidden_layers (List[int]): Forwarded as ``hidden_layers``; stored
                by reference, not copied.
            drop_out (float): Forwarded as ``drop_out``.
        """
        # Assignment order fixes the key order seen when iterating/unpacking.
        self.in_features = in_features
        self.out_targets = out_targets
        self.hidden_layers = hidden_layers
        self.drop_out = drop_out
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
class DragonMultiHeadAttentionNetParams(_BaseModelParams):
    """
    Parameter container for the multi-head attention network.

    Each constructor argument is stored unchanged as an instance attribute of
    the same name, in signature order. No validation happens here.
    """
    def __init__(self,
                 in_features: int,
                 out_targets: int,
                 hidden_layers: List[int],
                 drop_out: float = 0.2,
                 num_heads: int = 4,
                 attention_dropout: float = 0.1) -> None:
        # Fields shared with the plain MLP parameter classes.
        (self.in_features,
         self.out_targets,
         self.hidden_layers,
         self.drop_out) = (in_features, out_targets, hidden_layers, drop_out)
        # Attention-specific fields.
        self.num_heads, self.attention_dropout = num_heads, attention_dropout
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
class DragonTabularTransformerParams(_BaseModelParams):
    """
    Parameter container for the tabular transformer model.

    All arguments are keyword-only and are stored unchanged as same-named
    instance attributes, in signature order. No validation happens here.
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 256,
                 num_heads: int = 8,
                 num_layers: int = 6,
                 dropout: float = 0.2) -> None:
        self.schema, self.out_targets = schema, out_targets
        (self.embedding_dim,
         self.num_heads,
         self.num_layers,
         self.dropout) = (embedding_dim, num_heads, num_layers, dropout)
|
|
520
|
-
|
|
521
|
-
# --- Advanced Models ---
|
|
522
|
-
|
|
523
|
-
class DragonGateParams(_BaseModelParams):
    """
    Parameter container for the GATE model (advanced models section).

    All arguments are keyword-only and are stored unchanged as same-named
    instance attributes, in signature order. No validation happens here.
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 16,
                 gflu_stages: int = 6,
                 gflu_dropout: float = 0.1,
                 num_trees: int = 20,
                 tree_depth: int = 4,
                 tree_dropout: float = 0.1,
                 chain_trees: bool = False,
                 tree_wise_attention: bool = True,
                 tree_wise_attention_dropout: float = 0.1,
                 binning_activation: Literal['entmoid', 'sparsemoid', 'sigmoid'] = "entmoid",
                 feature_mask_function: Literal['entmax', 'sparsemax', 'softmax', 't-softmax'] = "entmax",
                 share_head_weights: bool = True,
                 batch_norm_continuous: bool = True) -> None:
        self.schema, self.out_targets, self.embedding_dim = schema, out_targets, embedding_dim
        # gflu_* arguments
        self.gflu_stages, self.gflu_dropout = gflu_stages, gflu_dropout
        # tree-related arguments
        (self.num_trees,
         self.tree_depth,
         self.tree_dropout,
         self.chain_trees,
         self.tree_wise_attention,
         self.tree_wise_attention_dropout) = (num_trees, tree_depth, tree_dropout,
                                              chain_trees, tree_wise_attention,
                                              tree_wise_attention_dropout)
        # activation / masking choices and remaining flags
        (self.binning_activation,
         self.feature_mask_function,
         self.share_head_weights,
         self.batch_norm_continuous) = (binning_activation, feature_mask_function,
                                        share_head_weights, batch_norm_continuous)
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
class DragonNodeParams(_BaseModelParams):
    """
    Parameter container for the NODE model (advanced models section).

    All arguments are keyword-only and are stored unchanged as same-named
    instance attributes, in signature order. No validation happens here.
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 24,
                 num_trees: int = 1024,
                 num_layers: int = 2,
                 tree_depth: int = 6,
                 additional_tree_output_dim: int = 3,
                 max_features: Optional[int] = None,
                 input_dropout: float = 0.0,
                 embedding_dropout: float = 0.0,
                 choice_function: Literal['entmax', 'sparsemax', 'softmax'] = 'entmax',
                 bin_function: Literal['entmoid', 'sparsemoid', 'sigmoid'] = 'entmoid',
                 batch_norm_continuous: bool = False) -> None:
        self.schema, self.out_targets, self.embedding_dim = schema, out_targets, embedding_dim
        # tree / layer structure
        (self.num_trees,
         self.num_layers,
         self.tree_depth,
         self.additional_tree_output_dim,
         self.max_features) = (num_trees, num_layers, tree_depth,
                               additional_tree_output_dim, max_features)
        # dropout rates
        self.input_dropout, self.embedding_dropout = input_dropout, embedding_dropout
        # function choices and remaining flags
        (self.choice_function,
         self.bin_function,
         self.batch_norm_continuous) = (choice_function, bin_function, batch_norm_continuous)
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
class DragonAutoIntParams(_BaseModelParams):
    """
    Parameter container for the AutoInt model (advanced models section).

    All arguments are keyword-only and are stored unchanged as same-named
    instance attributes, in signature order. No validation happens here
    (e.g. ``layers`` is kept as the raw string such as "128-64-32").
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 embedding_dim: int = 32,
                 attn_embed_dim: int = 32,
                 num_heads: int = 2,
                 num_attn_blocks: int = 3,
                 attn_dropout: float = 0.1,
                 has_residuals: bool = True,
                 attention_pooling: bool = True,
                 deep_layers: bool = True,
                 layers: str = "128-64-32",
                 activation: str = "ReLU",
                 embedding_dropout: float = 0.0,
                 batch_norm_continuous: bool = False) -> None:
        self.schema, self.out_targets, self.embedding_dim = schema, out_targets, embedding_dim
        # attention block settings
        (self.attn_embed_dim,
         self.num_heads,
         self.num_attn_blocks,
         self.attn_dropout,
         self.has_residuals,
         self.attention_pooling) = (attn_embed_dim, num_heads, num_attn_blocks,
                                    attn_dropout, has_residuals, attention_pooling)
        # deep-layer settings and remaining flags
        (self.deep_layers,
         self.layers,
         self.activation,
         self.embedding_dropout,
         self.batch_norm_continuous) = (deep_layers, layers, activation,
                                        embedding_dropout, batch_norm_continuous)
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
class DragonTabNetParams(_BaseModelParams):
    """
    Parameter container for the TabNet model (advanced models section).

    All arguments are keyword-only and are stored unchanged as same-named
    instance attributes, in signature order. No validation happens here.
    """
    def __init__(self, *,
                 schema: FeatureSchema,
                 out_targets: int,
                 n_d: int = 8,
                 n_a: int = 8,
                 n_steps: int = 3,
                 gamma: float = 1.3,
                 n_independent: int = 2,
                 n_shared: int = 2,
                 virtual_batch_size: int = 128,
                 momentum: float = 0.02,
                 mask_type: Literal['sparsemax', 'entmax', 'softmax'] = 'sparsemax',
                 batch_norm_continuous: bool = False) -> None:
        self.schema, self.out_targets = schema, out_targets
        # n_* / step structure
        (self.n_d,
         self.n_a,
         self.n_steps,
         self.gamma,
         self.n_independent,
         self.n_shared) = (n_d, n_a, n_steps, gamma, n_independent, n_shared)
        # batching, masking and remaining flags
        (self.virtual_batch_size,
         self.momentum,
         self.mask_type,
         self.batch_norm_continuous) = (virtual_batch_size, momentum,
                                        mask_type, batch_norm_continuous)
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
# --- Training Configuration ---
|
|
648
|
-
|
|
649
|
-
class DragonTrainingConfig(_BaseModelParams):
    """
    Configuration object for the training process.

    The five named arguments are stored as same-named attributes. Any extra
    keyword arguments are attached as attributes as well, after a type check:
    values must be ``None`` or one of str/int/float/bool/list/tuple; dicts are
    rejected explicitly (unpack them first).
    """
    def __init__(self,
                 validation_size: float,
                 test_size: float,
                 initial_learning_rate: float,
                 batch_size: int,
                 random_state: int = 101,
                 **kwargs: Any) -> None:
        """
        Args:
            validation_size (float): Proportion of data for validation set.
            test_size (float): Proportion of data for test set.
            initial_learning_rate (float): Starting learning rate.
            batch_size (int): Number of samples per training batch.
            random_state (int): Seed for reproducibility.
            **kwargs: Additional training parameters as key-value pairs.

        Raises:
            TypeError: If an extra value is a dict or any other unsupported type.
        """
        (self.validation_size,
         self.test_size,
         self.initial_learning_rate,
         self.batch_size,
         self.random_state) = (validation_size, test_size, initial_learning_rate,
                               batch_size, random_state)

        supported_types = (str, int, float, bool, list, tuple)
        for key, value in kwargs.items():
            # None is always accepted; everything else must pass the type checks.
            if value is not None:
                if isinstance(value, dict):
                    _LOGGER.error("Nested dictionaries are not supported, unpack them first.")
                    raise TypeError()
                if not isinstance(value, supported_types):
                    _LOGGER.error(f"Invalid type for configuration '{key}': {type(value).__name__}")
                    raise TypeError()
            setattr(self, key, value)
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
class DragonParetoConfig(_BaseModelParams):
    """
    Configuration object for the Pareto Optimization process.
    """
    def __init__(self,
                 save_directory: Union[str, Path],
                 target_objectives: Dict[str, Literal["min", "max"]],
                 continuous_bounds_map: Union[Dict[str, Tuple[float, float]], Dict[str, List[float]], str, Path],
                 columns_to_round: Optional[List[str]] = None,
                 population_size: int = 500,
                 generations: int = 1000,
                 solutions_filename: str = "NonDominatedSolutions",
                 float_precision: int = 4,
                 log_interval: int = 10,
                 plot_size: Tuple[int, int] = (10, 7),
                 plot_font_size: int = 16,
                 discretize_start_at_zero: bool = True):
        """
        Configure the Pareto Optimizer.

        Args:
            save_directory (str | Path): Directory for saved artifacts; resolved via
                ``make_fullpath`` with ``make=True``.
            target_objectives (Dict[str, "min"|"max"]): Maps each target name to its
                optimization direction, e.g. {"price": "max", "error": "min"}.
            continuous_bounds_map (Dict | str | Path): Bounds for continuous features
                {name: (min, max)}, or a path to a directory containing the
                "optimization_bounds.json" file (resolved with ``make=False``).
            columns_to_round (List[str] | None): Continuous column names to round to the nearest integer.
            population_size (int): Size of the genetic population.
            generations (int): Number of generations to run.
            solutions_filename (str): Filename for saving Pareto solutions.
            float_precision (int): Decimal places used for standard float columns.
            log_interval (int): Interval for logging progress.
            plot_size (Tuple[int, int]): Size of the 2D plots.
            plot_font_size (int): Font size for plot text.
            discretize_start_at_zero (bool): Categorical encoding start index. True=0, False=1.
        """
        # Output directory is resolved (and may be created) first.
        resolved_save_dir = make_fullpath(save_directory, make=True, enforce="directory")

        # A str/Path bounds source is resolved as a directory without creating it;
        # dict inputs are kept as given.
        bounds_source = continuous_bounds_map
        if isinstance(bounds_source, (str, Path)):
            bounds_source = make_fullpath(bounds_source, make=False, enforce="directory")

        self.save_directory = resolved_save_dir
        self.target_objectives = target_objectives
        self.continuous_bounds_map = bounds_source
        self.columns_to_round = columns_to_round
        # Genetic-algorithm settings
        self.population_size, self.generations = population_size, generations
        # Output / reporting settings
        (self.solutions_filename,
         self.float_precision,
         self.log_interval,
         self.plot_size,
         self.plot_font_size,
         self.discretize_start_at_zero) = (solutions_filename, float_precision, log_interval,
                                           plot_size, plot_font_size, discretize_start_at_zero)
|
|
759
|
-
|
|
760
|
-
# ----------------------------
|
|
761
|
-
# Metrics Configurations
|
|
762
|
-
# ----------------------------
|
|
763
|
-
|
|
764
|
-
# Regression
|
|
765
|
-
class RegressionMetricsFormat(_BaseRegressionFormat):
    """
    Configuration for single-target regression metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseRegressionFormat``.
    """
    def __init__(self,
                 font_size: int=26,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto',
                 xtick_size: int=22,
                 ytick_size: int=22) -> None:
        options = dict(font_size=font_size,
                       scatter_color=scatter_color,
                       scatter_alpha=scatter_alpha,
                       ideal_line_color=ideal_line_color,
                       residual_line_color=residual_line_color,
                       hist_bins=hist_bins,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size)
        super().__init__(**options)
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
# Multitarget regression
|
|
789
|
-
class MultiTargetRegressionMetricsFormat(_BaseRegressionFormat):
    """
    Configuration for multi-target regression metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseRegressionFormat``.
    """
    def __init__(self,
                 font_size: int=26,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto',
                 xtick_size: int=22,
                 ytick_size: int=22) -> None:
        options = dict(font_size=font_size,
                       scatter_color=scatter_color,
                       scatter_alpha=scatter_alpha,
                       ideal_line_color=ideal_line_color,
                       residual_line_color=residual_line_color,
                       hist_bins=hist_bins,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size)
        super().__init__(**options)
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
# Classification
|
|
813
|
-
class BinaryClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for binary classification metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseClassificationFormat``.
    """
    def __init__(self,
                 cmap: str="BuGn",
                 ROC_PR_line: str='darkorange',
                 calibration_bins: int=15,
                 font_size: int=26,
                 xtick_size: int=22,
                 ytick_size: int=22,
                 legend_size: int=26,
                 cm_font_size: int=26
                 ) -> None:
        options = dict(cmap=cmap,
                       ROC_PR_line=ROC_PR_line,
                       calibration_bins=calibration_bins,
                       font_size=font_size,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size,
                       legend_size=legend_size,
                       cm_font_size=cm_font_size)
        super().__init__(**options)
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
class MultiClassClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for multi-class classification metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseClassificationFormat``.
    """
    def __init__(self,
                 cmap: str="BuGn",
                 ROC_PR_line: str='darkorange',
                 calibration_bins: int=15,
                 font_size: int=26,
                 xtick_size: int=22,
                 ytick_size: int=22,
                 legend_size: int=26,
                 cm_font_size: int=26
                 ) -> None:
        options = dict(cmap=cmap,
                       ROC_PR_line=ROC_PR_line,
                       calibration_bins=calibration_bins,
                       font_size=font_size,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size,
                       legend_size=legend_size,
                       cm_font_size=cm_font_size)
        super().__init__(**options)
|
|
859
|
-
|
|
860
|
-
class BinaryImageClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for binary image classification metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseClassificationFormat``.
    """
    def __init__(self,
                 cmap: str="BuGn",
                 ROC_PR_line: str='darkorange',
                 calibration_bins: int=15,
                 font_size: int=26,
                 xtick_size: int=22,
                 ytick_size: int=22,
                 legend_size: int=26,
                 cm_font_size: int=26
                 ) -> None:
        options = dict(cmap=cmap,
                       ROC_PR_line=ROC_PR_line,
                       calibration_bins=calibration_bins,
                       font_size=font_size,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size,
                       legend_size=legend_size,
                       cm_font_size=cm_font_size)
        super().__init__(**options)
|
|
882
|
-
|
|
883
|
-
class MultiClassImageClassificationMetricsFormat(_BaseClassificationFormat):
    """
    Configuration for multi-class image classification metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseClassificationFormat``.
    """
    def __init__(self,
                 cmap: str="BuGn",
                 ROC_PR_line: str='darkorange',
                 calibration_bins: int=15,
                 font_size: int=26,
                 xtick_size: int=22,
                 ytick_size: int=22,
                 legend_size: int=26,
                 cm_font_size: int=26
                 ) -> None:
        options = dict(cmap=cmap,
                       ROC_PR_line=ROC_PR_line,
                       calibration_bins=calibration_bins,
                       font_size=font_size,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size,
                       legend_size=legend_size,
                       cm_font_size=cm_font_size)
        super().__init__(**options)
|
|
905
|
-
|
|
906
|
-
# Multi-Label classification
|
|
907
|
-
class MultiLabelBinaryClassificationMetricsFormat(_BaseMultiLabelFormat):
    """
    Configuration for multi-label binary classification metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseMultiLabelFormat``.
    Note the defaults differ slightly from the single-label classification formats.
    """
    def __init__(self,
                 cmap: str = "BuGn",
                 ROC_PR_line: str='darkorange',
                 font_size: int = 25,
                 xtick_size: int=20,
                 ytick_size: int=20,
                 legend_size: int=23
                 ) -> None:
        options = dict(cmap=cmap,
                       ROC_PR_line=ROC_PR_line,
                       font_size=font_size,
                       xtick_size=xtick_size,
                       ytick_size=ytick_size,
                       legend_size=legend_size)
        super().__init__(**options)
|
|
925
|
-
|
|
926
|
-
# Segmentation
|
|
927
|
-
class BinarySegmentationMetricsFormat(_BaseSegmentationFormat):
    """
    Configuration for binary segmentation metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseSegmentationFormat``.
    """
    def __init__(self,
                 heatmap_cmap: str = "BuGn",
                 cm_cmap: str = "Purples",
                 font_size: int = 16) -> None:
        options = dict(heatmap_cmap=heatmap_cmap, cm_cmap=cm_cmap, font_size=font_size)
        super().__init__(**options)
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
class MultiClassSegmentationMetricsFormat(_BaseSegmentationFormat):
    """
    Configuration for multi-class segmentation metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseSegmentationFormat``.
    """
    def __init__(self,
                 heatmap_cmap: str = "BuGn",
                 cm_cmap: str = "Purples",
                 font_size: int = 16) -> None:
        options = dict(heatmap_cmap=heatmap_cmap, cm_cmap=cm_cmap, font_size=font_size)
        super().__init__(**options)
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
# Sequence
|
|
954
|
-
class SequenceValueMetricsFormat(_BaseSequenceValueFormat):
    """
    Configuration for sequence-to-value prediction metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseSequenceValueFormat``.
    """
    def __init__(self,
                 font_size: int=25,
                 scatter_color: str='tab:blue',
                 scatter_alpha: float=0.6,
                 ideal_line_color: str='k',
                 residual_line_color: str='red',
                 hist_bins: Union[int, str] = 'auto') -> None:
        options = dict(font_size=font_size,
                       scatter_color=scatter_color,
                       scatter_alpha=scatter_alpha,
                       ideal_line_color=ideal_line_color,
                       residual_line_color=residual_line_color,
                       hist_bins=hist_bins)
        super().__init__(**options)
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
class SequenceSequenceMetricsFormat(_BaseSequenceSequenceFormat):
    """
    Configuration for sequence-to-sequence prediction metrics.

    Every argument is forwarded by keyword, unchanged, to ``_BaseSequenceSequenceFormat``.
    """
    def __init__(self,
                 font_size: int = 25,
                 grid_style: str = '--',
                 rmse_color: str = 'tab:blue',
                 rmse_marker: str = 'o-',
                 mae_color: str = 'tab:orange',
                 mae_marker: str = 's--'):
        options = dict(font_size=font_size,
                       grid_style=grid_style,
                       rmse_color=rmse_color,
                       rmse_marker=rmse_marker,
                       mae_color=mae_color,
                       mae_marker=mae_marker)
        super().__init__(**options)
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
# -------- Finalize classes --------
|
|
993
|
-
class _FinalizeModelTraining:
    """
    Shared base for the ``Finalize*`` parameter classes.

    Not intended for direct instantiation; use one of the concrete subclasses.
    The initializer validates ``filename`` (passing extension ".pth" to
    ``_validate_string``) and gives every task-specific field an empty default,
    which the subclasses overwrite as relevant to their task.
    """
    def __init__(self,
                 filename: str,
                 ) -> None:
        checked_filename = _validate_string(string=filename, attribute_name="filename", extension=".pth")
        self.filename = checked_filename

        # Target description: single-target subclasses set target_name, multi-target ones target_names.
        self.target_name: Optional[str] = None
        self.target_names: Optional[list[str]] = None
        # Threshold / class-map settings (set by classification, segmentation and detection subclasses).
        self.classification_threshold: Optional[float] = None
        self.class_map: Optional[dict[str,int]] = None
        # Seed sequence and its length (set by the sequence-prediction subclasses).
        self.initial_sequence: Optional[np.ndarray] = None
        self.sequence_length: Optional[int] = None
        # Task key; stays 'UNKNOWN' until a subclass assigns an MLTaskKeys value.
        self.task: str = 'UNKNOWN'
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
class FinalizeRegression(_FinalizeModelTraining):
    """Finalization parameters for a single-target regression model."""
    def __init__(self,
                 filename: str,
                 target_name: str,
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            target_name (str): Name of the target variable.
        """
        super().__init__(filename=filename)
        checked_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = checked_target
        self.task = MLTaskKeys.REGRESSION
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
class FinalizeMultiTargetRegression(_FinalizeModelTraining):
    """Parameters for finalizing a multi-target regression model."""
    def __init__(self,
                 filename: str,
                 target_names: list[str],
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_names (list[str]): A non-empty list of names for the target variables.

        Raises:
            TypeError: If ``target_names`` is a single string instead of a list of strings.
            ValueError: If ``target_names`` is empty.
        """
        super().__init__(filename=filename)

        # A bare string is iterable, so without this check "price" would silently
        # become the five one-character targets ['p', 'r', 'i', 'c', 'e'].
        if isinstance(target_names, str):
            _LOGGER.error(f"'target_names' must be a list of strings, got a single string: '{target_names}'.")
            raise TypeError()

        safe_names = [_validate_string(string=target_name, attribute_name="All target names") for target_name in target_names]

        # A multi-target model needs at least one target name.
        if not safe_names:
            _LOGGER.error("'target_names' must contain at least one target name.")
            raise ValueError()

        self.target_names = safe_names
        self.task = MLTaskKeys.MULTITARGET_REGRESSION
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
class FinalizeBinaryClassification(_FinalizeModelTraining):
    """Finalization parameters for a binary classification model."""
    def __init__(self,
                 filename: str,
                 target_name: str,
                 classification_threshold: float,
                 class_map: dict[str,int]
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            target_name (str): Name of the target variable.
            classification_threshold (float): Cutoff for assigning the positive class.
            class_map (dict[str,int]): Class name -> integer code, e.g. {'cat': 0, 'dog': 1}.
        """
        super().__init__(filename=filename)
        # Validation order: target name, then threshold, then class map.
        checked_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = checked_target
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.BINARY_CLASSIFICATION
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
class FinalizeMultiClassClassification(_FinalizeModelTraining):
    """Finalization parameters for a multi-class classification model."""
    def __init__(self,
                 filename: str,
                 target_name: str,
                 class_map: dict[str,int]
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            target_name (str): Name of the target variable.
            class_map (dict[str,int]): Class name -> integer code, e.g. {'cat': 0, 'dog': 1}.
        """
        super().__init__(filename=filename)
        checked_target = _validate_string(string=target_name, attribute_name="Target name")
        self.target_name = checked_target
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.MULTICLASS_CLASSIFICATION
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
class FinalizeBinaryImageClassification(_FinalizeModelTraining):
    """Finalization parameters for a binary image classification model."""
    def __init__(self,
                 filename: str,
                 classification_threshold: float,
                 class_map: dict[str,int]
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            classification_threshold (float): Cutoff for assigning the positive class.
            class_map (dict[str,int]): Class name -> integer code, e.g. {'cat': 0, 'dog': 1}.
        """
        super().__init__(filename=filename)
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.BINARY_IMAGE_CLASSIFICATION
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
class FinalizeMultiClassImageClassification(_FinalizeModelTraining):
    """Finalization parameters for a multi-class image classification model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int]
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            class_map (dict[str,int]): Class name -> integer code, e.g. {'cat': 0, 'dog': 1}.
        """
        super().__init__(filename=filename)
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
class FinalizeMultiLabelBinaryClassification(_FinalizeModelTraining):
    """Parameters for finalizing a multi-label binary classification model."""
    def __init__(self,
                 filename: str,
                 target_names: list[str],
                 classification_threshold: float,
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            target_names (list[str]): A non-empty list of names for the target variables.
            classification_threshold (float): The cutoff threshold for classifying as the positive class.

        Raises:
            TypeError: If ``target_names`` is a single string instead of a list of strings.
            ValueError: If ``target_names`` is empty.
        """
        super().__init__(filename=filename)

        # A bare string is iterable, so without this check "label" would silently
        # become the one-character targets ['l', 'a', 'b', 'e', 'l'].
        if isinstance(target_names, str):
            _LOGGER.error(f"'target_names' must be a list of strings, got a single string: '{target_names}'.")
            raise TypeError()

        safe_names = [_validate_string(string=target_name, attribute_name="All target names") for target_name in target_names]

        # A multi-label model needs at least one label name.
        if not safe_names:
            _LOGGER.error("'target_names' must contain at least one target name.")
            raise ValueError()

        self.target_names = safe_names
        self.classification_threshold = _validate_threshold(classification_threshold)
        self.task = MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
class FinalizeBinarySegmentation(_FinalizeModelTraining):
    """Finalization parameters for a binary segmentation model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int],
                 classification_threshold: float,
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            class_map (dict[str,int]): Class name -> integer code, e.g. {'background': 0, 'object': 1}.
            classification_threshold (float): Cutoff for assigning the positive class (mask).
        """
        super().__init__(filename=filename)
        # The threshold is validated before the class map, matching the original order.
        checked_threshold = _validate_threshold(classification_threshold)
        self.classification_threshold = checked_threshold
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.BINARY_SEGMENTATION
|
|
1170
|
-
|
|
1171
|
-
|
|
1172
|
-
class FinalizeMultiClassSegmentation(_FinalizeModelTraining):
    """Finalization parameters for a multi-class segmentation model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int]
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            class_map (dict[str,int]): Class name -> integer code.
        """
        super().__init__(filename=filename)
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.MULTICLASS_SEGMENTATION
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
class FinalizeObjectDetection(_FinalizeModelTraining):
    """Finalization parameters for an object detection model."""
    def __init__(self,
                 filename: str,
                 class_map: dict[str,int]
                 ) -> None:
        """Validate and store the finalization parameters.

        Args:
            filename (str): Name of the file to be saved (validated by the base class).
            class_map (dict[str,int]): Class name -> integer code.
        """
        super().__init__(filename=filename)
        checked_map = _validate_class_map(class_map)
        self.class_map = checked_map
        self.task = MLTaskKeys.OBJECT_DETECTION
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
class FinalizeSequenceSequencePrediction(_FinalizeModelTraining):
    """Parameters for finalizing a sequence-to-sequence prediction model."""
    def __init__(self,
                 filename: str,
                 last_training_sequence: np.ndarray,
                 ) -> None:
        """Initializes the finalization parameters.

        Args:
            filename (str): The name of the file to be saved.
            last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
                Accepted shapes are ``(N,)``, ``(1, N)`` and ``(N, 1)``; 2D inputs are flattened to ``(N,)``.

        Raises:
            TypeError: If ``last_training_sequence`` is not a numpy array.
            ValueError: If it is not 1D (nor a 2D row/column vector), or if it is empty.
        """
        super().__init__(filename=filename)

        if not isinstance(last_training_sequence, np.ndarray):
            _LOGGER.error(f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}.")
            raise TypeError()

        sequence = last_training_sequence
        # A (1, N) row vector or (N, 1) column vector is treated as a 1D sequence.
        if sequence.ndim == 2 and 1 in sequence.shape:
            sequence = sequence.flatten()

        # Anything still not 1D (0-D scalars, other 2D shapes, 3D or more) is rejected.
        if sequence.ndim != 1:
            _LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
            raise ValueError()

        # An empty seed sequence leaves nothing to start predictions from.
        if sequence.size == 0:
            _LOGGER.error("The last training sequence must not be empty.")
            raise ValueError()

        self.initial_sequence = sequence
        # Save the length of the validated 1D sequence
        self.sequence_length = len(sequence)
        self.task = MLTaskKeys.SEQUENCE_SEQUENCE
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
class FinalizeSequenceValuePrediction(_FinalizeModelTraining):
|
|
1245
|
-
"""Parameters for finalizing a sequence-to-value prediction model."""
|
|
1246
|
-
def __init__(self,
|
|
1247
|
-
filename: str,
|
|
1248
|
-
last_training_sequence: np.ndarray,
|
|
1249
|
-
) -> None:
|
|
1250
|
-
"""Initializes the finalization parameters.
|
|
1251
|
-
|
|
1252
|
-
Args:
|
|
1253
|
-
filename (str): The name of the file to be saved.
|
|
1254
|
-
last_training_sequence (np.ndarray): The last sequence from the training data, needed to start predictions.
|
|
1255
|
-
"""
|
|
1256
|
-
super().__init__(filename=filename)
|
|
1257
|
-
|
|
1258
|
-
if not isinstance(last_training_sequence, np.ndarray):
|
|
1259
|
-
_LOGGER.error(f"The last training sequence must be a 1D numpy array, got {type(last_training_sequence)}.")
|
|
1260
|
-
raise TypeError()
|
|
1261
|
-
|
|
1262
|
-
if last_training_sequence.ndim == 1:
|
|
1263
|
-
# It's already 1D, (N,). This is valid.
|
|
1264
|
-
self.initial_sequence = last_training_sequence
|
|
1265
|
-
elif last_training_sequence.ndim == 2:
|
|
1266
|
-
# Handle both (1, N) and (N, 1)
|
|
1267
|
-
if last_training_sequence.shape[0] == 1:
|
|
1268
|
-
self.initial_sequence = last_training_sequence.flatten()
|
|
1269
|
-
elif last_training_sequence.shape[1] == 1:
|
|
1270
|
-
self.initial_sequence = last_training_sequence.flatten()
|
|
1271
|
-
else:
|
|
1272
|
-
_LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
|
|
1273
|
-
raise ValueError()
|
|
1274
|
-
else:
|
|
1275
|
-
# It's 3D or more, which is not supported
|
|
1276
|
-
_LOGGER.error(f"The last training sequence must be a 1D numpy array, got shape {last_training_sequence.shape}.")
|
|
1277
|
-
raise ValueError()
|
|
1278
|
-
|
|
1279
|
-
# Save the length of the validated 1D sequence
|
|
1280
|
-
self.sequence_length = len(self.initial_sequence) # type: ignore
|
|
1281
|
-
self.task = MLTaskKeys.SEQUENCE_VALUE
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
def _validate_string(string: str, attribute_name: str, extension: Optional[str]=None) -> str:
|
|
1285
|
-
"""Helper for finalize classes"""
|
|
1286
|
-
if not isinstance(string, str):
|
|
1287
|
-
_LOGGER.error(f"{attribute_name} must be a string.")
|
|
1288
|
-
raise TypeError()
|
|
1289
|
-
|
|
1290
|
-
if extension:
|
|
1291
|
-
safe_name = sanitize_filename(string)
|
|
1292
|
-
|
|
1293
|
-
if not safe_name.endswith(extension):
|
|
1294
|
-
safe_name += extension
|
|
1295
|
-
else:
|
|
1296
|
-
safe_name = string
|
|
1297
|
-
|
|
1298
|
-
return safe_name
|
|
1299
|
-
|
|
1300
|
-
def _validate_threshold(threshold: float):
|
|
1301
|
-
"""Helper for finalize classes"""
|
|
1302
|
-
if not isinstance(threshold, float):
|
|
1303
|
-
_LOGGER.error(f"Classification threshold must be a float.")
|
|
1304
|
-
raise TypeError()
|
|
1305
|
-
elif threshold < 0.1 or threshold > 0.9:
|
|
1306
|
-
_LOGGER.error(f"Classification threshold must be in the range [0.1, 0.9]")
|
|
1307
|
-
raise ValueError()
|
|
1308
|
-
|
|
1309
|
-
return threshold
|
|
1310
|
-
|
|
1311
|
-
def _validate_class_map(map_dict: dict[str, int]):
|
|
1312
|
-
"""Helper for finalize classes"""
|
|
1313
|
-
if not isinstance(map_dict, dict):
|
|
1314
|
-
_LOGGER.error(f"Class map must be a dictionary, but got {type(map_dict)}.")
|
|
1315
|
-
raise TypeError()
|
|
1316
|
-
|
|
1317
|
-
if not map_dict:
|
|
1318
|
-
_LOGGER.error("Class map dictionary cannot be empty.")
|
|
1319
|
-
raise ValueError()
|
|
1320
|
-
|
|
1321
|
-
for key, val in map_dict.items():
|
|
1322
|
-
if not isinstance(key, str):
|
|
1323
|
-
_LOGGER.error(f"All keys in the class map must be strings, but found key: {key} ({type(key)}).")
|
|
1324
|
-
raise TypeError()
|
|
1325
|
-
if not isinstance(val, int):
|
|
1326
|
-
_LOGGER.error(f"All values in the class map must be integers, but for key '{key}' found value: {val} ({type(val)}).")
|
|
1327
|
-
raise TypeError()
|
|
1328
|
-
|
|
1329
|
-
return map_dict
|
|
1330
|
-
|
|
1331
|
-
def info():
|
|
1332
|
-
_script_info(__all__)
|