dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
from typing import Literal, Union, Optional
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
from ..optimization_tools import create_optimization_bounds
|
|
5
|
+
from ..ML_inference import DragonInferenceHandler
|
|
6
|
+
from ..schema import FeatureSchema
|
|
7
|
+
|
|
8
|
+
from .._core import get_logger
|
|
9
|
+
from ..keys._keys import MLTaskKeys
|
|
10
|
+
|
|
11
|
+
from ._single_manual import FitnessEvaluator, create_pytorch_problem, run_optimization
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
_LOGGER = get_logger("DragonOptimizer")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
__all__ = [
|
|
18
|
+
"DragonOptimizer",
|
|
19
|
+
]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class DragonOptimizer:
|
|
23
|
+
"""
|
|
24
|
+
A wrapper class for setting up and running EvoTorch optimization tasks for regression models.
|
|
25
|
+
|
|
26
|
+
This class combines the functionality of `FitnessEvaluator`, `create_pytorch_problem`, and
|
|
27
|
+
`run_optimization` into a single, streamlined workflow.
|
|
28
|
+
|
|
29
|
+
SNES and CEM algorithms do not accept bounds, the given bounds will be used as an initial starting point.
|
|
30
|
+
|
|
31
|
+
Example:
|
|
32
|
+
>>> # 1. Define bounds for continuous features
|
|
33
|
+
>>> cont_bounds = {'feature_A': (0, 100), 'feature_B': (-10, 10)}
|
|
34
|
+
>>>
|
|
35
|
+
>>> # 2. Initialize the optimizer
|
|
36
|
+
>>> optimizer = DragonOptimizer(
|
|
37
|
+
... inference_handler=my_handler,
|
|
38
|
+
... schema=schema,
|
|
39
|
+
... target_name="my_target",
|
|
40
|
+
... continuous_bounds_map=cont_bounds,
|
|
41
|
+
... task="max",
|
|
42
|
+
... algorithm="Genetic",
|
|
43
|
+
... )
|
|
44
|
+
>>> # 3. Run the optimization
|
|
45
|
+
>>> best_result = optimizer.run(
|
|
46
|
+
... num_generations=100,
|
|
47
|
+
... save_dir="/path/to/results",
|
|
48
|
+
... save_format="csv"
|
|
49
|
+
... )
|
|
50
|
+
"""
|
|
51
|
+
def __init__(self,
|
|
52
|
+
inference_handler: DragonInferenceHandler,
|
|
53
|
+
schema: FeatureSchema,
|
|
54
|
+
target_name: str,
|
|
55
|
+
continuous_bounds_map: dict[str, tuple[float, float]],
|
|
56
|
+
task: Literal["min", "max"],
|
|
57
|
+
algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
|
|
58
|
+
population_size: int = 200,
|
|
59
|
+
discretize_start_at_zero: bool = True,
|
|
60
|
+
**searcher_kwargs):
|
|
61
|
+
"""
|
|
62
|
+
Initializes the optimizer by creating the EvoTorch problem and searcher.
|
|
63
|
+
|
|
64
|
+
Args:
|
|
65
|
+
inference_handler (DragonInferenceHandler):
|
|
66
|
+
An initialized inference handler containing the model.
|
|
67
|
+
schema (FeatureSchema):
|
|
68
|
+
The definitive schema object from data_exploration.
|
|
69
|
+
target_name (str):
|
|
70
|
+
target name to optimize.
|
|
71
|
+
continuous_bounds_map (Dict[str, Tuple[float, float]]):
|
|
72
|
+
A dictionary mapping the *name* of each **continuous** feature
|
|
73
|
+
to its (min_bound, max_bound) tuple.
|
|
74
|
+
task (str): The optimization goal, either "min" or "max".
|
|
75
|
+
|
|
76
|
+
algorithm (str): The search algorithm to use ("SNES", "CEM", "Genetic").
|
|
77
|
+
population_size (int): Population size for CEM and GeneticAlgorithm.
|
|
78
|
+
discretize_start_at_zero (bool):
|
|
79
|
+
True if the discrete encoding starts at 0 (e.g., [0, 1, 2]).
|
|
80
|
+
False if it starts at 1 (e.g., [1, 2, 3]).
|
|
81
|
+
**searcher_kwargs: Additional keyword arguments for the selected
|
|
82
|
+
search algorithm's constructor.
|
|
83
|
+
"""
|
|
84
|
+
# --- Store schema ---
|
|
85
|
+
self.schema = schema
|
|
86
|
+
# --- Store inference handler ---
|
|
87
|
+
self.inference_handler = inference_handler
|
|
88
|
+
|
|
89
|
+
# Ensure only Regression tasks are used
|
|
90
|
+
allowed_tasks = [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]
|
|
91
|
+
if self.inference_handler.task not in allowed_tasks:
|
|
92
|
+
_LOGGER.error(f"DragonOptimizer only supports {allowed_tasks}. Got '{self.inference_handler.task}'.")
|
|
93
|
+
raise ValueError(f"Invalid Task: {self.inference_handler.task}")
|
|
94
|
+
|
|
95
|
+
# --- store target name ---
|
|
96
|
+
self.target_name = target_name
|
|
97
|
+
|
|
98
|
+
# --- flag to control single vs multi-target ---
|
|
99
|
+
self.is_multi_target = False
|
|
100
|
+
|
|
101
|
+
# --- 1. Create bounds from schema ---
|
|
102
|
+
# This is the robust way to get bounds
|
|
103
|
+
bounds = create_optimization_bounds(
|
|
104
|
+
schema=schema,
|
|
105
|
+
continuous_bounds_map=continuous_bounds_map,
|
|
106
|
+
start_at_zero=discretize_start_at_zero
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
# Resolve target index if multi-target
|
|
110
|
+
target_index = None
|
|
111
|
+
|
|
112
|
+
if self.inference_handler.target_ids is None:
|
|
113
|
+
# This should be caught by ML_inference logic
|
|
114
|
+
_LOGGER.error("The provided inference handler does not have 'target_ids' defined.")
|
|
115
|
+
raise ValueError()
|
|
116
|
+
|
|
117
|
+
if target_name not in self.inference_handler.target_ids:
|
|
118
|
+
_LOGGER.error(f"Target name '{target_name}' not found in the inference handler's 'target_ids': {self.inference_handler.target_ids}")
|
|
119
|
+
raise ValueError()
|
|
120
|
+
|
|
121
|
+
if len(self.inference_handler.target_ids) == 1:
|
|
122
|
+
# Single target regression
|
|
123
|
+
target_index = None
|
|
124
|
+
_LOGGER.info(f"Optimization locked to single-target model '{target_name}'.")
|
|
125
|
+
else:
|
|
126
|
+
# Multi-target regression (optimizing one specific column)
|
|
127
|
+
target_index = self.inference_handler.target_ids.index(target_name)
|
|
128
|
+
self.is_multi_target = True
|
|
129
|
+
_LOGGER.info(f"Optimization locked to target '{target_name}' (Index {target_index}) in a multi-target model.")
|
|
130
|
+
|
|
131
|
+
# --- 2. Make a fitness function ---
|
|
132
|
+
self.evaluator = FitnessEvaluator(
|
|
133
|
+
inference_handler=inference_handler,
|
|
134
|
+
# Get categorical info from the schema
|
|
135
|
+
categorical_index_map=schema.categorical_index_map,
|
|
136
|
+
discretize_start_at_zero=discretize_start_at_zero,
|
|
137
|
+
target_index=target_index
|
|
138
|
+
)
|
|
139
|
+
|
|
140
|
+
# --- 3. Create the problem and searcher factory ---
|
|
141
|
+
self.problem, self.searcher_factory = create_pytorch_problem(
|
|
142
|
+
evaluator=self.evaluator,
|
|
143
|
+
bounds=bounds,
|
|
144
|
+
task=task,
|
|
145
|
+
algorithm=algorithm,
|
|
146
|
+
population_size=population_size,
|
|
147
|
+
**searcher_kwargs
|
|
148
|
+
)
|
|
149
|
+
|
|
150
|
+
# --- 4. Store other info needed by run() ---
|
|
151
|
+
self.discretize_start_at_zero = discretize_start_at_zero
|
|
152
|
+
|
|
153
|
+
def run(self,
|
|
154
|
+
num_generations: int,
|
|
155
|
+
save_dir: Union[str, Path],
|
|
156
|
+
save_format: Literal['csv', 'sqlite', 'both'],
|
|
157
|
+
repetitions: int = 1,
|
|
158
|
+
verbose: bool = True) -> Optional[dict]:
|
|
159
|
+
"""
|
|
160
|
+
Runs the evolutionary optimization process using the pre-configured settings.
|
|
161
|
+
|
|
162
|
+
The `feature_names` are automatically pulled from the `FeatureSchema`
|
|
163
|
+
provided during initialization.
|
|
164
|
+
|
|
165
|
+
Args:
|
|
166
|
+
num_generations (int): The total number of generations for each repetition.
|
|
167
|
+
save_dir (str | Path): The directory where result files will be saved.
|
|
168
|
+
save_format (Literal['csv', 'sqlite', 'both']): The format for saving results.
|
|
169
|
+
repetitions (int): The number of independent times to run the optimization.
|
|
170
|
+
verbose (bool): If True, enables detailed logging.
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
Optional[dict]: A dictionary with the best result if repetitions is 1,
|
|
174
|
+
otherwise None.
|
|
175
|
+
"""
|
|
176
|
+
# Pass inference handler and target names for multi-target only
|
|
177
|
+
if self.is_multi_target:
|
|
178
|
+
target_names_to_pass = self.inference_handler.target_ids
|
|
179
|
+
inference_handler_to_pass = self.inference_handler
|
|
180
|
+
else:
|
|
181
|
+
target_names_to_pass = None
|
|
182
|
+
inference_handler_to_pass = None
|
|
183
|
+
|
|
184
|
+
# Call the existing run function, passing info from the schema
|
|
185
|
+
return run_optimization(
|
|
186
|
+
problem=self.problem,
|
|
187
|
+
searcher_factory=self.searcher_factory,
|
|
188
|
+
num_generations=num_generations,
|
|
189
|
+
target_name=self.target_name,
|
|
190
|
+
save_dir=save_dir,
|
|
191
|
+
save_format=save_format,
|
|
192
|
+
# Get the definitive feature names (as a list) from the schema
|
|
193
|
+
feature_names=list(self.schema.feature_names),
|
|
194
|
+
# Get categorical info from the schema
|
|
195
|
+
categorical_map=self.schema.categorical_index_map,
|
|
196
|
+
categorical_mappings=self.schema.categorical_mappings,
|
|
197
|
+
repetitions=repetitions,
|
|
198
|
+
verbose=verbose,
|
|
199
|
+
discretize_start_at_zero=self.discretize_start_at_zero,
|
|
200
|
+
all_target_names=target_names_to_pass,
|
|
201
|
+
inference_handler=inference_handler_to_pass
|
|
202
|
+
)
|
|
203
|
+
|
|
@@ -1,222 +1,39 @@
|
|
|
1
|
-
import pandas
|
|
1
|
+
import pandas as pd
|
|
2
2
|
import torch
|
|
3
|
-
import numpy
|
|
3
|
+
import numpy
|
|
4
4
|
import evotorch
|
|
5
5
|
from evotorch.algorithms import SNES, CEM, GeneticAlgorithm
|
|
6
6
|
from evotorch.logging import PandasLogger
|
|
7
7
|
from evotorch.operators import SimulatedBinaryCrossOver, GaussianMutation
|
|
8
|
-
from typing import Literal, Union,
|
|
8
|
+
from typing import Literal, Union, Optional, Any, Callable
|
|
9
9
|
from pathlib import Path
|
|
10
10
|
from tqdm.auto import trange
|
|
11
11
|
from contextlib import nullcontext
|
|
12
12
|
from functools import partial
|
|
13
13
|
|
|
14
|
-
from
|
|
15
|
-
from
|
|
16
|
-
from
|
|
17
|
-
from ._math_utilities import discretize_categorical_values
|
|
18
|
-
from ._ML_inference import DragonInferenceHandler
|
|
19
|
-
from ._path_manager import make_fullpath, sanitize_filename
|
|
20
|
-
from ._logger import get_logger
|
|
21
|
-
from ._script_info import _script_info
|
|
22
|
-
from ._keys import PyTorchInferenceKeys, MLTaskKeys
|
|
23
|
-
from ._schema import FeatureSchema
|
|
14
|
+
from ..SQL import DragonSQL
|
|
15
|
+
from ..utilities import save_dataframe_filename
|
|
16
|
+
from ..ML_inference import DragonInferenceHandler
|
|
24
17
|
|
|
18
|
+
from ..math_utilities import discretize_categorical_values
|
|
19
|
+
from ..path_manager import make_fullpath, sanitize_filename
|
|
20
|
+
from .._core import get_logger
|
|
21
|
+
from ..keys._keys import PyTorchInferenceKeys
|
|
25
22
|
|
|
26
|
-
|
|
23
|
+
|
|
24
|
+
_LOGGER = get_logger("Optimization")
|
|
27
25
|
|
|
28
26
|
|
|
29
27
|
__all__ = [
|
|
30
|
-
"DragonOptimizer",
|
|
31
28
|
"FitnessEvaluator",
|
|
32
29
|
"create_pytorch_problem",
|
|
33
|
-
"run_optimization"
|
|
30
|
+
"run_optimization",
|
|
31
|
+
"_save_result",
|
|
32
|
+
"_handle_pandas_log",
|
|
33
|
+
"_run_single_optimization_rep"
|
|
34
34
|
]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
class DragonOptimizer:
|
|
38
|
-
"""
|
|
39
|
-
A wrapper class for setting up and running EvoTorch optimization tasks for regression models.
|
|
40
|
-
|
|
41
|
-
This class combines the functionality of `FitnessEvaluator`, `create_pytorch_problem`, and
|
|
42
|
-
`run_optimization` into a single, streamlined workflow.
|
|
43
|
-
|
|
44
|
-
SNES and CEM algorithms do not accept bounds, the given bounds will be used as an initial starting point.
|
|
45
|
-
|
|
46
|
-
Example:
|
|
47
|
-
>>> # 1. Define bounds for continuous features
|
|
48
|
-
>>> cont_bounds = {'feature_A': (0, 100), 'feature_B': (-10, 10)}
|
|
49
|
-
>>>
|
|
50
|
-
>>> # 2. Initialize the optimizer
|
|
51
|
-
>>> optimizer = DragonOptimizer(
|
|
52
|
-
... inference_handler=my_handler,
|
|
53
|
-
... schema=schema,
|
|
54
|
-
... target_name="my_target",
|
|
55
|
-
... continuous_bounds_map=cont_bounds,
|
|
56
|
-
... task="max",
|
|
57
|
-
... algorithm="Genetic",
|
|
58
|
-
... )
|
|
59
|
-
>>> # 3. Run the optimization
|
|
60
|
-
>>> best_result = optimizer.run(
|
|
61
|
-
... num_generations=100,
|
|
62
|
-
... save_dir="/path/to/results",
|
|
63
|
-
... save_format="csv"
|
|
64
|
-
... )
|
|
65
|
-
"""
|
|
66
|
-
def __init__(self,
|
|
67
|
-
inference_handler: DragonInferenceHandler,
|
|
68
|
-
schema: FeatureSchema,
|
|
69
|
-
target_name: str,
|
|
70
|
-
continuous_bounds_map: Dict[str, Tuple[float, float]],
|
|
71
|
-
task: Literal["min", "max"],
|
|
72
|
-
algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
|
|
73
|
-
population_size: int = 200,
|
|
74
|
-
discretize_start_at_zero: bool = True,
|
|
75
|
-
**searcher_kwargs):
|
|
76
|
-
"""
|
|
77
|
-
Initializes the optimizer by creating the EvoTorch problem and searcher.
|
|
78
|
-
|
|
79
|
-
Args:
|
|
80
|
-
inference_handler (DragonInferenceHandler):
|
|
81
|
-
An initialized inference handler containing the model.
|
|
82
|
-
schema (FeatureSchema):
|
|
83
|
-
The definitive schema object from data_exploration.
|
|
84
|
-
target_name (str):
|
|
85
|
-
target name to optimize.
|
|
86
|
-
continuous_bounds_map (Dict[str, Tuple[float, float]]):
|
|
87
|
-
A dictionary mapping the *name* of each **continuous** feature
|
|
88
|
-
to its (min_bound, max_bound) tuple.
|
|
89
|
-
task (str): The optimization goal, either "min" or "max".
|
|
90
|
-
|
|
91
|
-
algorithm (str): The search algorithm to use ("SNES", "CEM", "Genetic").
|
|
92
|
-
population_size (int): Population size for CEM and GeneticAlgorithm.
|
|
93
|
-
discretize_start_at_zero (bool):
|
|
94
|
-
True if the discrete encoding starts at 0 (e.g., [0, 1, 2]).
|
|
95
|
-
False if it starts at 1 (e.g., [1, 2, 3]).
|
|
96
|
-
**searcher_kwargs: Additional keyword arguments for the selected
|
|
97
|
-
search algorithm's constructor.
|
|
98
|
-
"""
|
|
99
|
-
# --- Store schema ---
|
|
100
|
-
self.schema = schema
|
|
101
|
-
# --- Store inference handler ---
|
|
102
|
-
self.inference_handler = inference_handler
|
|
103
|
-
|
|
104
|
-
# Ensure only Regression tasks are used
|
|
105
|
-
allowed_tasks = [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]
|
|
106
|
-
if self.inference_handler.task not in allowed_tasks:
|
|
107
|
-
_LOGGER.error(f"DragonOptimizer only supports {allowed_tasks}. Got '{self.inference_handler.task}'.")
|
|
108
|
-
raise ValueError(f"Invalid Task: {self.inference_handler.task}")
|
|
109
|
-
|
|
110
|
-
# --- store target name ---
|
|
111
|
-
self.target_name = target_name
|
|
112
|
-
|
|
113
|
-
# --- flag to control single vs multi-target ---
|
|
114
|
-
self.is_multi_target = False
|
|
115
|
-
|
|
116
|
-
# --- 1. Create bounds from schema ---
|
|
117
|
-
# This is the robust way to get bounds
|
|
118
|
-
bounds = create_optimization_bounds(
|
|
119
|
-
schema=schema,
|
|
120
|
-
continuous_bounds_map=continuous_bounds_map,
|
|
121
|
-
start_at_zero=discretize_start_at_zero
|
|
122
|
-
)
|
|
123
|
-
|
|
124
|
-
# Resolve target index if multi-target
|
|
125
|
-
target_index = None
|
|
126
|
-
|
|
127
|
-
if self.inference_handler.target_ids is None:
|
|
128
|
-
# This should be caught by ML_inference logic
|
|
129
|
-
_LOGGER.error("The provided inference handler does not have 'target_ids' defined.")
|
|
130
|
-
raise ValueError()
|
|
131
|
-
|
|
132
|
-
if target_name not in self.inference_handler.target_ids:
|
|
133
|
-
_LOGGER.error(f"Target name '{target_name}' not found in the inference handler's 'target_ids': {self.inference_handler.target_ids}")
|
|
134
|
-
raise ValueError()
|
|
135
|
-
|
|
136
|
-
if len(self.inference_handler.target_ids) == 1:
|
|
137
|
-
# Single target regression
|
|
138
|
-
target_index = None
|
|
139
|
-
_LOGGER.info(f"Optimization locked to single-target model '{target_name}'.")
|
|
140
|
-
else:
|
|
141
|
-
# Multi-target regression (optimizing one specific column)
|
|
142
|
-
target_index = self.inference_handler.target_ids.index(target_name)
|
|
143
|
-
self.is_multi_target = True
|
|
144
|
-
_LOGGER.info(f"Optimization locked to target '{target_name}' (Index {target_index}) in a multi-target model.")
|
|
145
|
-
|
|
146
|
-
# --- 2. Make a fitness function ---
|
|
147
|
-
self.evaluator = FitnessEvaluator(
|
|
148
|
-
inference_handler=inference_handler,
|
|
149
|
-
# Get categorical info from the schema
|
|
150
|
-
categorical_index_map=schema.categorical_index_map,
|
|
151
|
-
discretize_start_at_zero=discretize_start_at_zero,
|
|
152
|
-
target_index=target_index
|
|
153
|
-
)
|
|
154
|
-
|
|
155
|
-
# --- 3. Create the problem and searcher factory ---
|
|
156
|
-
self.problem, self.searcher_factory = create_pytorch_problem(
|
|
157
|
-
evaluator=self.evaluator,
|
|
158
|
-
bounds=bounds,
|
|
159
|
-
task=task,
|
|
160
|
-
algorithm=algorithm,
|
|
161
|
-
population_size=population_size,
|
|
162
|
-
**searcher_kwargs
|
|
163
|
-
)
|
|
164
|
-
|
|
165
|
-
# --- 4. Store other info needed by run() ---
|
|
166
|
-
self.discretize_start_at_zero = discretize_start_at_zero
|
|
167
|
-
|
|
168
|
-
def run(self,
|
|
169
|
-
num_generations: int,
|
|
170
|
-
save_dir: Union[str, Path],
|
|
171
|
-
save_format: Literal['csv', 'sqlite', 'both'],
|
|
172
|
-
repetitions: int = 1,
|
|
173
|
-
verbose: bool = True) -> Optional[dict]:
|
|
174
|
-
"""
|
|
175
|
-
Runs the evolutionary optimization process using the pre-configured settings.
|
|
176
|
-
|
|
177
|
-
The `feature_names` are automatically pulled from the `FeatureSchema`
|
|
178
|
-
provided during initialization.
|
|
179
|
-
|
|
180
|
-
Args:
|
|
181
|
-
num_generations (int): The total number of generations for each repetition.
|
|
182
|
-
save_dir (str | Path): The directory where result files will be saved.
|
|
183
|
-
save_format (Literal['csv', 'sqlite', 'both']): The format for saving results.
|
|
184
|
-
repetitions (int): The number of independent times to run the optimization.
|
|
185
|
-
verbose (bool): If True, enables detailed logging.
|
|
186
|
-
|
|
187
|
-
Returns:
|
|
188
|
-
Optional[dict]: A dictionary with the best result if repetitions is 1,
|
|
189
|
-
otherwise None.
|
|
190
|
-
"""
|
|
191
|
-
# Pass inference handler and target names for multi-target only
|
|
192
|
-
if self.is_multi_target:
|
|
193
|
-
target_names_to_pass = self.inference_handler.target_ids
|
|
194
|
-
inference_handler_to_pass = self.inference_handler
|
|
195
|
-
else:
|
|
196
|
-
target_names_to_pass = None
|
|
197
|
-
inference_handler_to_pass = None
|
|
198
|
-
|
|
199
|
-
# Call the existing run function, passing info from the schema
|
|
200
|
-
return run_optimization(
|
|
201
|
-
problem=self.problem,
|
|
202
|
-
searcher_factory=self.searcher_factory,
|
|
203
|
-
num_generations=num_generations,
|
|
204
|
-
target_name=self.target_name,
|
|
205
|
-
save_dir=save_dir,
|
|
206
|
-
save_format=save_format,
|
|
207
|
-
# Get the definitive feature names (as a list) from the schema
|
|
208
|
-
feature_names=list(self.schema.feature_names),
|
|
209
|
-
# Get categorical info from the schema
|
|
210
|
-
categorical_map=self.schema.categorical_index_map,
|
|
211
|
-
categorical_mappings=self.schema.categorical_mappings,
|
|
212
|
-
repetitions=repetitions,
|
|
213
|
-
verbose=verbose,
|
|
214
|
-
discretize_start_at_zero=self.discretize_start_at_zero,
|
|
215
|
-
all_target_names=target_names_to_pass,
|
|
216
|
-
inference_handler=inference_handler_to_pass
|
|
217
|
-
)
|
|
218
|
-
|
|
219
|
-
|
|
220
37
|
class FitnessEvaluator:
|
|
221
38
|
"""
|
|
222
39
|
A callable class that wraps the PyTorch model inference handler and performs
|
|
@@ -227,7 +44,7 @@ class FitnessEvaluator:
|
|
|
227
44
|
"""
|
|
228
45
|
def __init__(self,
|
|
229
46
|
inference_handler: DragonInferenceHandler,
|
|
230
|
-
categorical_index_map: Optional[
|
|
47
|
+
categorical_index_map: Optional[dict[int, int]] = None,
|
|
231
48
|
target_index: Optional[int] = None,
|
|
232
49
|
discretize_start_at_zero: bool = True):
|
|
233
50
|
"""
|
|
@@ -287,12 +104,12 @@ class FitnessEvaluator:
|
|
|
287
104
|
|
|
288
105
|
def create_pytorch_problem(
|
|
289
106
|
evaluator: FitnessEvaluator,
|
|
290
|
-
bounds:
|
|
107
|
+
bounds: tuple[list[float], list[float]],
|
|
291
108
|
task: Literal["min", "max"],
|
|
292
109
|
algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
|
|
293
110
|
population_size: int = 200,
|
|
294
111
|
**searcher_kwargs
|
|
295
|
-
) ->
|
|
112
|
+
) -> tuple[evotorch.Problem, Callable[[], Any]]:
|
|
296
113
|
"""
|
|
297
114
|
Creates and configures an EvoTorch Problem and a Searcher factory class for a PyTorch model.
|
|
298
115
|
|
|
@@ -389,13 +206,13 @@ def run_optimization(
|
|
|
389
206
|
target_name: str,
|
|
390
207
|
save_dir: Union[str, Path],
|
|
391
208
|
save_format: Literal['csv', 'sqlite', 'both'],
|
|
392
|
-
feature_names: Optional[
|
|
209
|
+
feature_names: Optional[list[str]],
|
|
393
210
|
repetitions: int = 1,
|
|
394
211
|
verbose: bool = True,
|
|
395
|
-
categorical_map: Optional[
|
|
396
|
-
categorical_mappings: Optional[
|
|
212
|
+
categorical_map: Optional[dict[int, int]] = None,
|
|
213
|
+
categorical_mappings: Optional[dict[str, dict[str, int]]] = None,
|
|
397
214
|
discretize_start_at_zero: bool = True,
|
|
398
|
-
all_target_names: Optional[
|
|
215
|
+
all_target_names: Optional[list[str]] = None,
|
|
399
216
|
inference_handler: Optional[DragonInferenceHandler] = None
|
|
400
217
|
) -> Optional[dict]:
|
|
401
218
|
"""
|
|
@@ -563,14 +380,14 @@ def run_optimization(
|
|
|
563
380
|
def _run_single_optimization_rep(
|
|
564
381
|
searcher_factory: Callable[[],Any],
|
|
565
382
|
num_generations: int,
|
|
566
|
-
feature_names:
|
|
383
|
+
feature_names: list[str],
|
|
567
384
|
target_name: str,
|
|
568
|
-
categorical_map: Optional[
|
|
385
|
+
categorical_map: Optional[dict[int, int]],
|
|
569
386
|
discretize_start_at_zero: bool,
|
|
570
387
|
attach_logger: bool,
|
|
571
|
-
all_target_names:
|
|
388
|
+
all_target_names: list[str],
|
|
572
389
|
inference_handler: Optional[DragonInferenceHandler]
|
|
573
|
-
) ->
|
|
390
|
+
) -> tuple[dict, Optional[PandasLogger]]:
|
|
574
391
|
"""
|
|
575
392
|
Internal helper to run one full optimization repetition.
|
|
576
393
|
|
|
@@ -639,10 +456,55 @@ def _run_single_optimization_rep(
|
|
|
639
456
|
return result_dict, pandas_logger
|
|
640
457
|
|
|
641
458
|
|
|
459
|
+
def _save_result(
|
|
460
|
+
result_dict: dict,
|
|
461
|
+
save_format: Literal['csv', 'sqlite', 'both'],
|
|
462
|
+
csv_path: Path,
|
|
463
|
+
db_manager: Optional[DragonSQL] = None,
|
|
464
|
+
db_table_name: Optional[str] = None,
|
|
465
|
+
categorical_mappings: Optional[dict[str, dict[str, int]]] = None
|
|
466
|
+
):
|
|
467
|
+
"""
|
|
468
|
+
Private helper to handle saving a single result to CSV, SQLite, or both.
|
|
469
|
+
|
|
470
|
+
If `categorical_mappings` is provided, it will reverse-map integer values
|
|
471
|
+
to their string representations before saving.
|
|
472
|
+
"""
|
|
473
|
+
# --- Reverse Mapping Logic ---
|
|
474
|
+
# Create a copy to hold the values to be saved
|
|
475
|
+
save_dict = result_dict.copy()
|
|
476
|
+
|
|
477
|
+
if categorical_mappings:
|
|
478
|
+
for feature_name, mapping in categorical_mappings.items():
|
|
479
|
+
if feature_name in save_dict:
|
|
480
|
+
# Create a reverse map {0: 'Category_A', 1: 'Category_B'}
|
|
481
|
+
reverse_map = {idx: name for name, idx in mapping.items()}
|
|
482
|
+
|
|
483
|
+
# Get the integer value from the results (e.g., 0)
|
|
484
|
+
int_value = save_dict[feature_name]
|
|
485
|
+
|
|
486
|
+
# Find the corresponding string (e.g., 'Category_A')
|
|
487
|
+
# Use .get() for safety, defaulting to the original value if not found
|
|
488
|
+
string_value = reverse_map.get(int_value, int_value)
|
|
489
|
+
|
|
490
|
+
# Update the dictionary that will be saved
|
|
491
|
+
save_dict[feature_name] = string_value
|
|
492
|
+
|
|
493
|
+
# Save to CSV
|
|
494
|
+
if save_format in ['csv', 'both']:
|
|
495
|
+
df_row = pd.DataFrame([save_dict])
|
|
496
|
+
file_exists = csv_path.exists()
|
|
497
|
+
df_row.to_csv(csv_path, mode='a', index=False, header=not file_exists)
|
|
498
|
+
|
|
499
|
+
# Save to SQLite
|
|
500
|
+
if save_format in ['sqlite', 'both']:
|
|
501
|
+
if db_manager and db_table_name:
|
|
502
|
+
db_manager.insert_row(db_table_name, save_dict)
|
|
503
|
+
else:
|
|
504
|
+
_LOGGER.warning("SQLite saving requested but db_manager or table_name not provided.")
|
|
505
|
+
|
|
506
|
+
|
|
642
507
|
def _handle_pandas_log(logger: PandasLogger, save_path: Path, target_name: str):
|
|
643
508
|
log_dataframe = logger.to_dataframe()
|
|
644
509
|
save_dataframe_filename(df=log_dataframe, save_dir=save_path / "EvolutionLogs", filename=target_name)
|
|
645
510
|
|
|
646
|
-
|
|
647
|
-
def info():
|
|
648
|
-
_script_info(__all__)
|
|
@@ -1,12 +1,11 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
from torch.utils.data import Dataset, DataLoader
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from typing import Union,
|
|
4
|
+
from typing import Union, Optional
|
|
5
5
|
|
|
6
|
-
from
|
|
7
|
-
from
|
|
8
|
-
from .
|
|
9
|
-
from ._keys import ScalerKeys
|
|
6
|
+
from .._core import get_logger
|
|
7
|
+
from ..path_manager import make_fullpath
|
|
8
|
+
from ..keys._keys import ScalerKeys
|
|
10
9
|
|
|
11
10
|
|
|
12
11
|
_LOGGER = get_logger("DragonScaler")
|
|
@@ -25,7 +24,7 @@ class DragonScaler:
|
|
|
25
24
|
def __init__(self,
|
|
26
25
|
mean: Optional[torch.Tensor] = None,
|
|
27
26
|
std: Optional[torch.Tensor] = None,
|
|
28
|
-
continuous_feature_indices: Optional[
|
|
27
|
+
continuous_feature_indices: Optional[list[int]] = None):
|
|
29
28
|
"""
|
|
30
29
|
Initializes the scaler.
|
|
31
30
|
"""
|
|
@@ -34,7 +33,7 @@ class DragonScaler:
|
|
|
34
33
|
self.continuous_feature_indices = continuous_feature_indices
|
|
35
34
|
|
|
36
35
|
@classmethod
|
|
37
|
-
def fit(cls, dataset: Dataset, continuous_feature_indices:
|
|
36
|
+
def fit(cls, dataset: Dataset, continuous_feature_indices: list[int], batch_size: int = 64) -> 'DragonScaler':
|
|
38
37
|
"""
|
|
39
38
|
Fits the scaler using a PyTorch Dataset (Method A) using Batched Welford's Algorithm.
|
|
40
39
|
"""
|
|
@@ -72,7 +71,7 @@ class DragonScaler:
|
|
|
72
71
|
else:
|
|
73
72
|
# Batched Welford's Update
|
|
74
73
|
# Combine existing global stats (A) with new batch stats (B)
|
|
75
|
-
delta = mean_batch - mean_global
|
|
74
|
+
delta = mean_batch - mean_global # type: ignore
|
|
76
75
|
new_n_total = n_total + n_batch
|
|
77
76
|
|
|
78
77
|
# Update M2 (Sum of Squares)
|
|
@@ -93,7 +92,7 @@ class DragonScaler:
|
|
|
93
92
|
# Unbiased estimator (divide by n-1)
|
|
94
93
|
if n_total < 2:
|
|
95
94
|
_LOGGER.warning(f"Only one sample found. Standard deviation set to 1.")
|
|
96
|
-
std = torch.ones_like(mean_global)
|
|
95
|
+
std = torch.ones_like(mean_global) # type: ignore
|
|
97
96
|
else:
|
|
98
97
|
variance = m2_global / (n_total - 1)
|
|
99
98
|
std = torch.sqrt(torch.clamp(variance, min=1e-8))
|
|
@@ -218,5 +217,3 @@ class DragonScaler:
|
|
|
218
217
|
return f"DragonScaler(fitted for {num_features} columns)"
|
|
219
218
|
return "DragonScaler(not fitted)"
|
|
220
219
|
|
|
221
|
-
def info():
|
|
222
|
-
_script_info(__all__)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from ._dragon_trainer import (
|
|
2
|
+
DragonTrainer
|
|
3
|
+
)
|
|
4
|
+
|
|
5
|
+
from ._dragon_sequence_trainer import (
|
|
6
|
+
DragonSequenceTrainer
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from ._dragon_detection_trainer import (
|
|
10
|
+
DragonDetectionTrainer
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from ._imprimir import info
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
__all__ = [
|
|
17
|
+
"DragonTrainer",
|
|
18
|
+
"DragonSequenceTrainer",
|
|
19
|
+
"DragonDetectionTrainer",
|
|
20
|
+
]
|