dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from ..keys._keys import PyTorchLogKeys
|
|
5
|
+
from .._core import get_logger
|
|
6
|
+
|
|
7
|
+
from ._base import _Callback
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module-level logger used by the scheduler callbacks below.
_LOGGER = get_logger("LR Scheduler")


# Public API of this module.
__all__ = [
    "DragonScheduler",
    "DragonPlateauScheduler",
]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class _DragonLRScheduler(_Callback):
    """
    Base class for Dragon LR Schedulers.

    Handles the logic shared by the scheduler callbacks:
    - registering the wrapped scheduler on the trainer (``trainer.scheduler``),
    - remembering the last observed learning rate of the first param group,
    - writing the current learning rate into the epoch logs and the trainer history.

    Subclasses assign ``self.scheduler`` and call ``_check_and_log_lr`` from their
    epoch hook.
    """
    def __init__(self):
        super().__init__()
        # Wrapped torch scheduler; assigned by subclasses (eagerly or in set_trainer).
        self.scheduler = None
        # LR of the first param group at the last check; None until first observed.
        self.previous_lr = None

    def set_trainer(self, trainer):
        """
        Associates the callback with the trainer.

        If a scheduler has already been assigned it is also registered as
        ``trainer.scheduler``. Subclasses that build their scheduler lazily must
        register it themselves once it exists.
        """
        super().set_trainer(trainer)
        # Explicit None check: scheduler objects should not be judged by truthiness.
        if self.scheduler is not None:
            self.trainer.scheduler = self.scheduler # type: ignore

    def _current_lr(self):
        """Return the LR of the trainer optimizer's first param group, or None if there is no optimizer."""
        # getattr: tolerate a trainer that has no `optimizer` attribute at all.
        optimizer = getattr(self.trainer, 'optimizer', None)
        if optimizer is None:
            return None
        return optimizer.param_groups[0]['lr']

    def on_train_begin(self, logs=None):
        """Store the initial learning rate."""
        initial_lr = self._current_lr()
        if initial_lr is None:
            _LOGGER.warning("No optimizer found in trainer. LRScheduler cannot track learning rate.")
            return
        self.previous_lr = initial_lr

    def _check_and_log_lr(self, epoch, logs, verbose: bool):
        """
        Helper to log LR changes and update history.

        Args:
            epoch: Epoch index, used only in the console message.
            logs (dict | None): Epoch log dictionary; the current LR is written into it
                under ``PyTorchLogKeys.LEARNING_RATE`` when it is not None.
            verbose (bool): If True, prints a line whenever the LR changed since the last check.
        """
        current_lr = self._current_lr()
        if current_lr is None:
            return

        # Log change
        if self.previous_lr is not None and current_lr != self.previous_lr and verbose:
            print(f" > Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
        # Always refresh the baseline so tracking starts even if on_train_begin
        # could not read an LR.
        self.previous_lr = current_lr

        # Log to dictionary
        if logs is not None:
            logs[PyTorchLogKeys.LEARNING_RATE] = current_lr

        # Log to history (skip if the trainer has none or it is unset).
        history = getattr(self.trainer, 'history', None)
        if history is not None:
            history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class DragonScheduler(_DragonLRScheduler):
    """
    Callback for standard PyTorch Learning Rate Schedulers.

    The wrapped scheduler is stepped once, with no metric argument, at the end of
    every epoch, and the resulting learning rate is recorded in the epoch logs.

    Compatible with: StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, etc.

    NOT Compatible with: ReduceLROnPlateau (Use `DragonPlateauScheduler` instead).
    """
    def __init__(self, scheduler, verbose: bool=True):
        """
        Args:
            scheduler: An initialized PyTorch learning rate scheduler instance.
            verbose (bool): If True, logs learning rate changes to console.

        Raises:
            ValueError: If `scheduler` is a `ReduceLROnPlateau` instance.
        """
        super().__init__()
        # ReduceLROnPlateau.step() needs a metric, which this callback never supplies.
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            raise ValueError(
                "DragonScheduler does not support 'ReduceLROnPlateau'. "
                "Please use the `DragonPlateauScheduler` callback instead."
            )
        self.scheduler = scheduler
        self.verbose = verbose

    def set_trainer(self, trainer):
        """Attach to the trainer and register the scheduler as `trainer.scheduler`."""
        super().set_trainer(trainer)
        # Explicitly register the scheduler again to be safe
        self.trainer.scheduler = self.scheduler # type: ignore
        if self.verbose:
            _LOGGER.info(f"Registered LR Scheduler: {self.scheduler.__class__.__name__}")

    def on_epoch_end(self, epoch, logs=None):
        """Step the scheduler and record the (possibly updated) learning rate."""
        # Only substitute a dict for a missing one: an empty dict passed by the
        # trainer must be kept so the LR entry reaches the caller's logs.
        if logs is None:
            logs = {}

        # Standard step (no metrics needed)
        self.scheduler.step()

        self._check_and_log_lr(epoch, logs, self.verbose)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class DragonPlateauScheduler(_DragonLRScheduler):
    """
    Specific callback for `torch.optim.lr_scheduler.ReduceLROnPlateau`. Reduces learning rate when a monitored metric has stopped improving.

    This wrapper initializes the scheduler internally using the Trainer's optimizer, simplifying the setup process.
    Because the optimizer is needed, the scheduler is only created once the callback
    is attached to a trainer (`set_trainer`).
    """
    def __init__(self,
                 monitor: Literal["Training Loss", "Validation Loss"] = "Validation Loss",
                 mode: Literal['min', 'max'] = 'min',
                 factor: float = 0.1,
                 patience: int = 5,
                 threshold: float = 1e-4,
                 threshold_mode: Literal['rel', 'abs'] = 'rel',
                 cooldown: int = 0,
                 min_lr: float = 0,
                 eps: float = 1e-8,
                 verbose: bool = True):
        """
        Args:
            monitor ("Training Loss", "Validation Loss"): Metric to monitor.
            mode ('min', 'max'): One of 'min', 'max'.
            factor (float): Factor by which the learning rate will be reduced. new_lr = lr * factor.
            patience (int): Number of epochs with no improvement after which learning rate will be reduced.
            threshold (float): Threshold for measuring the new optimum.
            threshold_mode ('rel', 'abs'): One of 'rel', 'abs'.
            cooldown (int): Number of epochs to wait before resuming normal operation after lr has been reduced.
            min_lr (float or list): A scalar or a list of scalars.
            eps (float): Minimal decay applied to lr.
            verbose (bool): If True, logs learning rate changes to console.

        Raises:
            ValueError: If `monitor` is not one of the supported names.
        """
        super().__init__()

        # Standardize monitor key (user-facing name -> key in the epoch logs)
        if monitor == "Training Loss":
            std_monitor = PyTorchLogKeys.TRAIN_LOSS
        elif monitor == "Validation Loss":
            std_monitor = PyTorchLogKeys.VAL_LOSS
        else:
            _LOGGER.error(f"Unknown monitor key: {monitor}.")
            raise ValueError(f"Unknown monitor key: {monitor}.")

        self.monitor = std_monitor
        self.verbose = verbose

        # Config storage for delayed initialization (forwarded to ReduceLROnPlateau)
        self.config = {
            'mode': mode,
            'factor': factor,
            'patience': patience,
            'threshold': threshold,
            'threshold_mode': threshold_mode,
            'cooldown': cooldown,
            'min_lr': min_lr,
            'eps': eps,
        }

    def set_trainer(self, trainer):
        """
        Initializes the ReduceLROnPlateau scheduler using the trainer's optimizer and registers it.

        Raises:
            ValueError: If the trainer has no optimizer (missing attribute or None).
        """
        super().set_trainer(trainer)

        # A missing attribute and an unset (None) optimizer are both unusable.
        optimizer = getattr(self.trainer, 'optimizer', None)
        if optimizer is None:
            _LOGGER.error("Trainer has no optimizer. Cannot initialize ReduceLROnPlateau.")
            raise ValueError("Trainer has no optimizer. Cannot initialize ReduceLROnPlateau.")

        # Initialize the actual scheduler with the optimizer
        if self.verbose:
            _LOGGER.info(f"Initializing ReduceLROnPlateau monitoring '{self.monitor}'")

        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer=optimizer, # type: ignore
            **self.config
        )

        # Register with trainer for checkpointing
        self.trainer.scheduler = self.scheduler # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        """Step the scheduler with the monitored metric (if present) and record the LR."""
        # Keep the trainer's dict even when it is empty so the LR entry is visible to it.
        if logs is None:
            logs = {}

        metric_val = logs.get(self.monitor)

        if metric_val is None:
            _LOGGER.warning(f"DragonPlateauScheduler could not find metric '{self.monitor}' in logs. Scheduler step skipped.")
            # Still log LR to keep history consistent
            self._check_and_log_lr(epoch, logs, self.verbose)
            return

        # Step with metric
        self.scheduler.step(metric_val)

        self._check_and_log_lr(epoch, logs, self.verbose)
|
|
197
|
+
|
|
@@ -1,11 +1,16 @@
|
|
|
1
|
-
from .
|
|
2
|
-
DragonChainOrchestrator
|
|
1
|
+
from ._dragon_chain import (
|
|
2
|
+
DragonChainOrchestrator
|
|
3
|
+
)
|
|
4
|
+
|
|
5
|
+
from ._chaining_tools import (
|
|
3
6
|
augment_dataset_with_predictions,
|
|
4
7
|
augment_dataset_with_predictions_multi,
|
|
5
8
|
prepare_chaining_dataset,
|
|
6
|
-
info
|
|
7
9
|
)
|
|
8
10
|
|
|
11
|
+
from ._imprimir import info
|
|
12
|
+
|
|
13
|
+
|
|
9
14
|
__all__ = [
|
|
10
15
|
"DragonChainOrchestrator",
|
|
11
16
|
"augment_dataset_with_predictions",
|
|
@@ -3,17 +3,16 @@ import numpy as np
|
|
|
3
3
|
from math import ceil
|
|
4
4
|
from typing import Optional, Literal
|
|
5
5
|
|
|
6
|
-
from
|
|
7
|
-
from ._keys import MLTaskKeys, PyTorchInferenceKeys
|
|
8
|
-
from ._logger import get_logger
|
|
9
|
-
from ._script_info import _script_info
|
|
6
|
+
from ..ML_inference import DragonInferenceHandler
|
|
10
7
|
|
|
8
|
+
from ..keys._keys import MLTaskKeys, PyTorchInferenceKeys
|
|
9
|
+
from .._core import get_logger
|
|
11
10
|
|
|
12
|
-
|
|
11
|
+
|
|
12
|
+
# Module-level logger for this module.
_LOGGER = get_logger("ML Chain")
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
16
|
-
"DragonChainOrchestrator",
|
|
17
16
|
"augment_dataset_with_predictions",
|
|
18
17
|
"augment_dataset_with_predictions_multi",
|
|
19
18
|
"prepare_chaining_dataset",
|
|
@@ -321,126 +320,3 @@ def prepare_chaining_dataset(
|
|
|
321
320
|
|
|
322
321
|
return df
|
|
323
322
|
|
|
324
|
-
|
|
325
|
-
class DragonChainOrchestrator:
    """
    Manages the data flow for a sequential chain of ML models (Model 1 -> Model 2 -> ... -> Model N).

    This orchestrator maintains a master copy of the dataset that grows as models are applied.
    1. Use `get_training_data` to extract a clean, target-specific subset for training a model.
    2. Train your model externally.
    3. Use `update_with_inference` to run that model on the master dataset and append predictions
       as features for subsequent steps.
    """
    def __init__(self, initial_dataset: pd.DataFrame, all_targets: list[str]):
        """
        Args:
            initial_dataset (pd.DataFrame): The starting dataframe with original features and all ground truth targets.
                A copy is stored; the caller's dataframe is not modified.
            all_targets (list[str]): A list of all ground truth target column names present in the dataset.
                The names are copied (duplicates dropped, order kept), so later changes to the
                caller's list do not affect the orchestrator.

        Raises:
            ValueError: If any target name is not a column of `initial_dataset`.
        """
        # Materialize once: an iterator would otherwise be exhausted by the validation below.
        targets = list(dict.fromkeys(all_targets))

        # Validation: Ensure targets exist
        missing = [t for t in targets if t not in initial_dataset.columns]
        if missing:
            message = f"The following targets were not found in the initial dataset: {missing}"
            _LOGGER.error(message)
            raise ValueError(message)

        self.current_dataset = initial_dataset.copy()
        self.all_targets = targets
        _LOGGER.info(f"Orchestrator initialized with {len(initial_dataset)} samples, {len(initial_dataset.columns) - len(targets)} features, and {len(targets)} targets.")

    def get_training_data(
        self,
        target_subset: list[str],
        dropna_how: Literal["any", "all"] = "all"
    ) -> pd.DataFrame:
        """
        Generates a clean dataframe tailored for training a specific step in the chain.

        This method does NOT modify the internal state. It returns the dataframe built by
        `prepare_chaining_dataset` from the current master dataset, containing:
        - Current features (including previous model predictions).
        - Only the specified `target_subset`.
        - Rows cleaned based on `dropna_how`.

        Args:
            target_subset (list[str]): The targets for the current model.
            dropna_how (Literal["any", "all"]): "any" drops row if any target is missing; "all" drops if all are missing.

        Returns:
            pd.DataFrame: A prepared dataframe for training.
        """
        _LOGGER.info(f"Extracting training data for targets {target_subset}...")
        return prepare_chaining_dataset(
            dataset=self.current_dataset,
            all_targets=self.all_targets,
            target_subset=target_subset,
            dropna_how=dropna_how,
            verbose=False
        )

    def update_with_inference(
        self,
        handler: DragonInferenceHandler,
        prefix: str = "pred_",
        batch_size: int = 4096
    ) -> None:
        """
        Runs inference using the provided handler on the full internal dataset and appends the results as new features.

        This updates the internal state of the Orchestrator. Subsequent calls to `get_training_data`
        will include these new prediction columns as features.

        Args:
            handler (DragonInferenceHandler): The trained model handler.
            prefix (str): Prefix for the new prediction columns (e.g., "m1_", "step2_").
            batch_size (int): Batch size for inference.
        """
        _LOGGER.info(f"Orchestrator: Updating internal state with predictions from handler (Targets: {handler.target_ids})...")

        # We use the existing utility to handle the augmentation
        # This keeps the logic consistent (drop GT -> predict -> concat GT)
        self.current_dataset = augment_dataset_with_predictions(
            handler=handler,
            dataset=self.current_dataset,
            ground_truth_targets=self.all_targets,
            prediction_col_prefix=prefix,
            batch_size=batch_size
        )

        _LOGGER.debug(f"Orchestrator State updated. Current feature count (approx): {self.current_dataset.shape[1] - len(self.all_targets)}")

    def update_with_ensemble(
        self,
        handlers: list[DragonInferenceHandler],
        prefixes: Optional[list[str]] = None,
        batch_size: int = 4096
    ) -> None:
        """
        Runs multiple independent inference handlers (e.g. for Stacking) on the full internal dataset
        and appends all results as new features.

        Args:
            handlers (list[DragonInferenceHandler]): List of trained model handlers.
            prefixes (list[str], optional): Prefixes for each model's columns.
            batch_size (int): Batch size for inference.
        """
        _LOGGER.info(f"Orchestrator: Updating internal state with ensemble of {len(handlers)} models...")

        self.current_dataset = augment_dataset_with_predictions_multi(
            handlers=handlers,
            dataset=self.current_dataset,
            ground_truth_targets=self.all_targets,
            model_prefixes=prefixes,
            batch_size=batch_size
        )

        new_feat_count = self.current_dataset.shape[1] - len(self.all_targets)
        _LOGGER.debug(f"Orchestrator: State updated. Total current features: {new_feat_count}")

    @property
    def latest_dataset(self) -> pd.DataFrame:
        """Returns a copy of the current master dataset including all accumulated predictions."""
        return self.current_dataset.copy()
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
def info():
    """Forward this module's ``__all__`` list to the shared ``_script_info`` helper."""
    public_names = __all__
    _script_info(public_names)
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
from typing import Optional, Literal
|
|
3
|
+
|
|
4
|
+
from ..ML_inference import DragonInferenceHandler
|
|
5
|
+
|
|
6
|
+
from .._core import get_logger
|
|
7
|
+
|
|
8
|
+
from ._chaining_tools import (
|
|
9
|
+
augment_dataset_with_predictions,
|
|
10
|
+
augment_dataset_with_predictions_multi,
|
|
11
|
+
prepare_chaining_dataset,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Module-level logger used by DragonChainOrchestrator.
_LOGGER = get_logger("DragonChainOrchestrator")


# Public API of this module.
__all__ = ["DragonChainOrchestrator"]
|
|
21
|
+
|
|
22
|
+
class DragonChainOrchestrator:
|
|
23
|
+
"""
|
|
24
|
+
Manages the data flow for a sequential chain of ML models (Model 1 -> Model 2 -> ... -> Model N).
|
|
25
|
+
|
|
26
|
+
This orchestrator maintains a master copy of the dataset that grows as models are applied.
|
|
27
|
+
1. Use `get_training_data` to extract a clean, target-specific subset for training a model.
|
|
28
|
+
2. Train your model externally.
|
|
29
|
+
3. Use `update_with_inference` to run that model on the master dataset and append predictions
|
|
30
|
+
as features for subsequent steps.
|
|
31
|
+
"""
|
|
32
|
+
def __init__(self, initial_dataset: pd.DataFrame, all_targets: list[str]):
|
|
33
|
+
"""
|
|
34
|
+
Args:
|
|
35
|
+
initial_dataset (pd.DataFrame): The starting dataframe with original features and all ground truth targets.
|
|
36
|
+
all_targets (list[str]): A list of all ground truth target column names present in the dataset.
|
|
37
|
+
"""
|
|
38
|
+
# Validation: Ensure targets exist
|
|
39
|
+
missing = [t for t in all_targets if t not in initial_dataset.columns]
|
|
40
|
+
if missing:
|
|
41
|
+
_LOGGER.error(f"The following targets were not found in the initial dataset: {missing}")
|
|
42
|
+
raise ValueError()
|
|
43
|
+
|
|
44
|
+
self.current_dataset = initial_dataset.copy()
|
|
45
|
+
self.all_targets = all_targets
|
|
46
|
+
_LOGGER.info(f"Orchestrator initialized with {len(initial_dataset)} samples, {len(initial_dataset.columns) - len(all_targets)} features, and {len(all_targets)} targets.")
|
|
47
|
+
|
|
48
|
+
def get_training_data(
|
|
49
|
+
self,
|
|
50
|
+
target_subset: list[str],
|
|
51
|
+
dropna_how: Literal["any", "all"] = "all"
|
|
52
|
+
) -> pd.DataFrame:
|
|
53
|
+
"""
|
|
54
|
+
Generates a clean dataframe tailored for training a specific step in the chain.
|
|
55
|
+
|
|
56
|
+
This method does NOT modify the internal state. It returns a view with:
|
|
57
|
+
- Current features (including previous model predictions).
|
|
58
|
+
- Only the specified `target_subset`.
|
|
59
|
+
- Rows cleaned based on `dropna_how`.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
target_subset (list[str]): The targets for the current model.
|
|
63
|
+
dropna_how (Literal["any", "all"]): "any" drops row if any target is missing; "all" drops if all are missing.
|
|
64
|
+
|
|
65
|
+
Returns:
|
|
66
|
+
pd.DataFrame: A prepared dataframe for training.
|
|
67
|
+
"""
|
|
68
|
+
_LOGGER.info(f"Extracting training data for targets {target_subset}...")
|
|
69
|
+
return prepare_chaining_dataset(
|
|
70
|
+
dataset=self.current_dataset,
|
|
71
|
+
all_targets=self.all_targets,
|
|
72
|
+
target_subset=target_subset,
|
|
73
|
+
dropna_how=dropna_how,
|
|
74
|
+
verbose=False
|
|
75
|
+
)
|
|
76
|
+
|
|
77
|
+
def update_with_inference(
|
|
78
|
+
self,
|
|
79
|
+
handler: DragonInferenceHandler,
|
|
80
|
+
prefix: str = "pred_",
|
|
81
|
+
batch_size: int = 4096
|
|
82
|
+
) -> None:
|
|
83
|
+
"""
|
|
84
|
+
Runs inference using the provided handler on the full internal dataset and appends the results as new features.
|
|
85
|
+
|
|
86
|
+
This updates the internal state of the Orchestrator. Subsequent calls to `get_training_data`
|
|
87
|
+
will include these new prediction columns as features.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
handler (DragonInferenceHandler): The trained model handler.
|
|
91
|
+
prefix (str): Prefix for the new prediction columns (e.g., "m1_", "step2_").
|
|
92
|
+
batch_size (int): Batch size for inference.
|
|
93
|
+
"""
|
|
94
|
+
_LOGGER.info(f"Orchestrator: Updating internal state with predictions from handler (Targets: {handler.target_ids})...")
|
|
95
|
+
|
|
96
|
+
# We use the existing utility to handle the augmentation
|
|
97
|
+
# This keeps the logic consistent (drop GT -> predict -> concat GT)
|
|
98
|
+
self.current_dataset = augment_dataset_with_predictions(
|
|
99
|
+
handler=handler,
|
|
100
|
+
dataset=self.current_dataset,
|
|
101
|
+
ground_truth_targets=self.all_targets,
|
|
102
|
+
prediction_col_prefix=prefix,
|
|
103
|
+
batch_size=batch_size
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
_LOGGER.debug(f"Orchestrator State updated. Current feature count (approx): {self.current_dataset.shape[1] - len(self.all_targets)}")
|
|
107
|
+
|
|
108
|
+
def update_with_ensemble(
    self,
    handlers: list[DragonInferenceHandler],
    prefixes: Optional[list[str]] = None,
    batch_size: int = 4096
) -> None:
    """
    Run several independent inference handlers over the entire internal dataset and append every output as features.

    Intended for stacking-style setups. ``current_dataset`` is replaced by the augmented frame.

    Args:
        handlers (list[DragonInferenceHandler]): Trained model handlers to run.
        prefixes (list[str], optional): Column-name prefix for each handler's predictions;
            forwarded unchanged to the multi-model augmentation utility.
        batch_size (int): Number of rows per inference batch.
    """
    _LOGGER.info(f"Orchestrator: Updating internal state with ensemble of {len(handlers)} models...")

    augmented = augment_dataset_with_predictions_multi(
        handlers=handlers,
        dataset=self.current_dataset,
        ground_truth_targets=self.all_targets,
        model_prefixes=prefixes,
        batch_size=batch_size,
    )
    self.current_dataset = augmented

    # Non-target columns are the features available to the next training step.
    feature_total = augmented.shape[1] - len(self.all_targets)
    _LOGGER.debug(f"Orchestrator: State updated. Total current features: {feature_total}")
|
|
135
|
+
|
|
136
|
+
@property
def latest_dataset(self) -> pd.DataFrame:
    """
    Snapshot of the master dataset, including every prediction column accumulated so far.

    The result is a separate DataFrame produced by ``copy(deep=True)`` (the pandas
    default), so editing it does not alter the Orchestrator's internal frame.
    """
    snapshot = self.current_dataset.copy(deep=True)
    return snapshot
|
|
140
|
+
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from ._metrics import (
|
|
2
|
+
FormatRegressionMetrics,
|
|
3
|
+
FormatMultiTargetRegressionMetrics,
|
|
4
|
+
FormatBinaryClassificationMetrics,
|
|
5
|
+
FormatMultiClassClassificationMetrics,
|
|
6
|
+
FormatBinaryImageClassificationMetrics,
|
|
7
|
+
FormatMultiClassImageClassificationMetrics,
|
|
8
|
+
FormatMultiLabelBinaryClassificationMetrics,
|
|
9
|
+
FormatBinarySegmentationMetrics,
|
|
10
|
+
FormatMultiClassSegmentationMetrics,
|
|
11
|
+
FormatSequenceValueMetrics,
|
|
12
|
+
FormatSequenceSequenceMetrics,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from ._finalize import (
|
|
16
|
+
FinalizeBinaryClassification,
|
|
17
|
+
FinalizeBinarySegmentation,
|
|
18
|
+
FinalizeBinaryImageClassification,
|
|
19
|
+
FinalizeMultiClassClassification,
|
|
20
|
+
FinalizeMultiClassImageClassification,
|
|
21
|
+
FinalizeMultiClassSegmentation,
|
|
22
|
+
FinalizeMultiLabelBinaryClassification,
|
|
23
|
+
FinalizeMultiTargetRegression,
|
|
24
|
+
FinalizeRegression,
|
|
25
|
+
FinalizeObjectDetection,
|
|
26
|
+
FinalizeSequenceSequencePrediction,
|
|
27
|
+
FinalizeSequenceValuePrediction,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from ._models import (
|
|
31
|
+
DragonMLPParams,
|
|
32
|
+
DragonAttentionMLPParams,
|
|
33
|
+
DragonMultiHeadAttentionNetParams,
|
|
34
|
+
DragonTabularTransformerParams,
|
|
35
|
+
DragonGateParams,
|
|
36
|
+
DragonNodeParams,
|
|
37
|
+
DragonTabNetParams,
|
|
38
|
+
DragonAutoIntParams,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
from ._training import (
|
|
42
|
+
DragonTrainingConfig,
|
|
43
|
+
DragonParetoConfig,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
from ._imprimir import info
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Public API of this package. Kept as a plain list literal (same order as before)
# so static analysers and `from ... import *` both see the exported names.
__all__ = [
    # Metric formatting configs
    "FormatRegressionMetrics",
    "FormatMultiTargetRegressionMetrics",
    "FormatBinaryClassificationMetrics",
    "FormatMultiClassClassificationMetrics",
    "FormatBinaryImageClassificationMetrics",
    "FormatMultiClassImageClassificationMetrics",
    "FormatMultiLabelBinaryClassificationMetrics",
    "FormatBinarySegmentationMetrics",
    "FormatMultiClassSegmentationMetrics",
    "FormatSequenceValueMetrics",
    "FormatSequenceSequenceMetrics",
    # Finalization configs
    "FinalizeBinaryClassification",
    "FinalizeBinarySegmentation",
    "FinalizeBinaryImageClassification",
    "FinalizeMultiClassClassification",
    "FinalizeMultiClassImageClassification",
    "FinalizeMultiClassSegmentation",
    "FinalizeMultiLabelBinaryClassification",
    "FinalizeMultiTargetRegression",
    "FinalizeRegression",
    "FinalizeObjectDetection",
    "FinalizeSequenceSequencePrediction",
    "FinalizeSequenceValuePrediction",
    # Model hyper-parameter configs
    "DragonMLPParams",
    "DragonAttentionMLPParams",
    "DragonMultiHeadAttentionNetParams",
    "DragonTabularTransformerParams",
    "DragonGateParams",
    "DragonNodeParams",
    "DragonTabNetParams",
    "DragonAutoIntParams",
    # Training configs
    "DragonTrainingConfig",
    "DragonParetoConfig",
]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
|
|
5
|
+
from ..schema import FeatureSchema
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Only the private base class is exported, for reuse by sibling config modules.
__all__ = ["_BaseModelParams"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class _BaseModelParams(Mapping):
    """
    [PRIVATE] Common base for model parameter configuration objects.

    Implements the read-only ``Mapping`` protocol over the instance attribute
    dictionary, so a config object can be ``**``-unpacked straight into a model
    constructor, passed to ``dict(...)``, or merged with ``|``.
    """

    def __getitem__(self, key: str) -> Any:
        # Plain attribute-dict lookup; a missing key raises KeyError, as Mapping expects.
        return vars(self)[key]

    def __iter__(self):
        # Yields attribute names in insertion order.
        return iter(vars(self))

    def __len__(self) -> int:
        return len(vars(self))

    def __or__(self, other) -> dict[str, Any]:
        """Merge via ``params | other``; values from ``other`` win on key clashes. Returns a new dict."""
        if not isinstance(other, Mapping):
            return NotImplemented
        return {**self, **other}

    def __ror__(self, other) -> dict[str, Any]:
        """Merge via ``other | params``; values from ``params`` win on key clashes. Returns a new dict."""
        if not isinstance(other, Mapping):
            return NotImplemented
        return {**other, **self}

    def __repr__(self) -> str:
        """Render as ``ClassName(`` + one ``key=repr(value)`` line per parameter + ``)``."""
        body = ",\n".join(f" {name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}(\n{body}\n)"

    def to_log(self) -> dict[str, Any]:
        """
        Return a copy of the parameters with values made friendlier for JSON logging.

        ``FeatureSchema`` values become their ``repr()`` string (otherwise json.dump
        treats them as a list), ``Path`` values become ``str()``, and every other
        value is passed through unchanged.
        """
        def _loggable(value: Any) -> Any:
            if isinstance(value, FeatureSchema):
                return repr(value)
            if isinstance(value, Path):
                return str(value)
            return value

        return {name: _loggable(value) for name, value in vars(self).items()}
|