dragon-ml-toolbox 20.2.0__py3-none-any.whl → 20.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/METADATA +1 -1
- dragon_ml_toolbox-20.4.0.dist-info/RECORD +143 -0
- ml_tools/ETL_cleaning/__init__.py +5 -1
- ml_tools/ETL_cleaning/_basic_clean.py +1 -1
- ml_tools/ETL_engineering/__init__.py +5 -1
- ml_tools/GUI_tools/__init__.py +5 -1
- ml_tools/IO_tools/_IO_loggers.py +33 -21
- ml_tools/IO_tools/__init__.py +5 -1
- ml_tools/MICE/__init__.py +8 -2
- ml_tools/MICE/_dragon_mice.py +1 -1
- ml_tools/ML_callbacks/__init__.py +5 -1
- ml_tools/ML_chain/__init__.py +5 -1
- ml_tools/ML_configuration/__init__.py +7 -1
- ml_tools/ML_configuration/_training.py +65 -1
- ml_tools/ML_datasetmaster/__init__.py +5 -1
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +31 -20
- ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
- ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
- ml_tools/ML_evaluation/__init__.py +5 -1
- ml_tools/ML_evaluation/_classification.py +10 -2
- ml_tools/ML_evaluation_captum/__init__.py +5 -1
- ml_tools/ML_finalize_handler/__init__.py +5 -1
- ml_tools/ML_inference/__init__.py +5 -1
- ml_tools/ML_inference_sequence/__init__.py +5 -1
- ml_tools/ML_inference_vision/__init__.py +5 -1
- ml_tools/ML_models/__init__.py +21 -6
- ml_tools/ML_models/_dragon_autoint.py +302 -0
- ml_tools/ML_models/_dragon_gate.py +358 -0
- ml_tools/ML_models/_dragon_node.py +268 -0
- ml_tools/ML_models/_dragon_tabnet.py +255 -0
- ml_tools/ML_models_sequence/__init__.py +5 -1
- ml_tools/ML_models_vision/__init__.py +5 -1
- ml_tools/ML_optimization/__init__.py +11 -3
- ml_tools/ML_optimization/_multi_dragon.py +24 -8
- ml_tools/ML_optimization/_single_dragon.py +47 -67
- ml_tools/ML_optimization/_single_manual.py +1 -1
- ml_tools/ML_scaler/_ML_scaler.py +12 -7
- ml_tools/ML_scaler/__init__.py +5 -1
- ml_tools/ML_trainer/__init__.py +5 -1
- ml_tools/ML_trainer/_base_trainer.py +136 -13
- ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
- ml_tools/ML_trainer/_dragon_trainer.py +24 -85
- ml_tools/ML_utilities/__init__.py +5 -1
- ml_tools/ML_utilities/_inspection.py +44 -30
- ml_tools/ML_vision_transformers/__init__.py +8 -2
- ml_tools/PSO_optimization/__init__.py +5 -1
- ml_tools/SQL/__init__.py +8 -2
- ml_tools/VIF/__init__.py +5 -1
- ml_tools/data_exploration/__init__.py +4 -1
- ml_tools/data_exploration/_cleaning.py +4 -2
- ml_tools/ensemble_evaluation/__init__.py +5 -1
- ml_tools/ensemble_inference/__init__.py +5 -1
- ml_tools/ensemble_learning/__init__.py +5 -1
- ml_tools/excel_handler/__init__.py +5 -1
- ml_tools/keys/__init__.py +5 -1
- ml_tools/keys/_keys.py +1 -1
- ml_tools/math_utilities/__init__.py +5 -1
- ml_tools/optimization_tools/__init__.py +5 -1
- ml_tools/path_manager/__init__.py +8 -2
- ml_tools/plot_fonts/__init__.py +8 -2
- ml_tools/schema/__init__.py +8 -2
- ml_tools/schema/_feature_schema.py +3 -3
- ml_tools/serde/__init__.py +5 -1
- ml_tools/utilities/__init__.py +5 -1
- ml_tools/utilities/_utility_save_load.py +38 -20
- dragon_ml_toolbox-20.2.0.dist-info/RECORD +0 -179
- ml_tools/ETL_cleaning/_imprimir.py +0 -13
- ml_tools/ETL_engineering/_imprimir.py +0 -24
- ml_tools/GUI_tools/_imprimir.py +0 -12
- ml_tools/IO_tools/_imprimir.py +0 -14
- ml_tools/MICE/_imprimir.py +0 -11
- ml_tools/ML_callbacks/_imprimir.py +0 -12
- ml_tools/ML_chain/_imprimir.py +0 -12
- ml_tools/ML_configuration/_imprimir.py +0 -47
- ml_tools/ML_datasetmaster/_imprimir.py +0 -15
- ml_tools/ML_evaluation/_imprimir.py +0 -25
- ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
- ml_tools/ML_finalize_handler/_imprimir.py +0 -8
- ml_tools/ML_inference/_imprimir.py +0 -11
- ml_tools/ML_inference_sequence/_imprimir.py +0 -8
- ml_tools/ML_inference_vision/_imprimir.py +0 -8
- ml_tools/ML_models/_advanced_models.py +0 -1086
- ml_tools/ML_models/_imprimir.py +0 -18
- ml_tools/ML_models_sequence/_imprimir.py +0 -8
- ml_tools/ML_models_vision/_imprimir.py +0 -16
- ml_tools/ML_optimization/_imprimir.py +0 -13
- ml_tools/ML_scaler/_imprimir.py +0 -8
- ml_tools/ML_trainer/_imprimir.py +0 -10
- ml_tools/ML_utilities/_imprimir.py +0 -16
- ml_tools/ML_vision_transformers/_imprimir.py +0 -14
- ml_tools/PSO_optimization/_imprimir.py +0 -10
- ml_tools/SQL/_imprimir.py +0 -8
- ml_tools/VIF/_imprimir.py +0 -10
- ml_tools/data_exploration/_imprimir.py +0 -32
- ml_tools/ensemble_evaluation/_imprimir.py +0 -14
- ml_tools/ensemble_inference/_imprimir.py +0 -9
- ml_tools/ensemble_learning/_imprimir.py +0 -10
- ml_tools/excel_handler/_imprimir.py +0 -13
- ml_tools/keys/_imprimir.py +0 -11
- ml_tools/math_utilities/_imprimir.py +0 -11
- ml_tools/optimization_tools/_imprimir.py +0 -13
- ml_tools/path_manager/_imprimir.py +0 -15
- ml_tools/plot_fonts/_imprimir.py +0 -8
- ml_tools/schema/_imprimir.py +0 -10
- ml_tools/serde/_imprimir.py +0 -10
- ml_tools/utilities/_imprimir.py +0 -18
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_optimization/_single_dragon.py
CHANGED

```diff
@@ -1,9 +1,10 @@
 from typing import Literal, Union, Optional
 from pathlib import Path
 
-from ..optimization_tools import create_optimization_bounds
+from ..optimization_tools import create_optimization_bounds, load_continuous_bounds_template
 from ..ML_inference import DragonInferenceHandler
 from ..schema import FeatureSchema
+from ..ML_configuration import DragonOptimizerConfig
 
 from .._core import get_logger
 from ..keys._keys import MLTaskKeys
@@ -29,35 +30,28 @@ class DragonOptimizer:
     SNES and CEM algorithms do not accept bounds, the given bounds will be used as an initial starting point.
 
     Example:
-    >>> # 1. Define
-    >>>
+    >>> # 1. Define configuration
+    >>> config = DragonOptimizerConfig(
+    ...     target_name="my_target",
+    ...     task="max",
+    ...     continuous_bounds_map="path/to/bounds",
+    ...     save_directory="/path/to/results",
+    ...     algorithm="Genetic"
+    ... )
    >>>
     >>> # 2. Initialize the optimizer
     >>> optimizer = DragonOptimizer(
     ...     inference_handler=my_handler,
     ...     schema=schema,
-    ...
-    ...     continuous_bounds_map=cont_bounds,
-    ...     task="max",
-    ...     algorithm="Genetic",
+    ...     config=config
     ... )
     >>> # 3. Run the optimization
-    >>> best_result = optimizer.run(
-    ...     num_generations=100,
-    ...     save_dir="/path/to/results",
-    ...     save_format="csv"
-    ... )
+    >>> best_result = optimizer.run()
     """
     def __init__(self,
                  inference_handler: DragonInferenceHandler,
                  schema: FeatureSchema,
-
-                 continuous_bounds_map: dict[str, tuple[float, float]],
-                 task: Literal["min", "max"],
-                 algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
-                 population_size: int = 200,
-                 discretize_start_at_zero: bool = True,
-                 **searcher_kwargs):
+                 config: DragonOptimizerConfig):
         """
         Initializes the optimizer by creating the EvoTorch problem and searcher.
 
@@ -65,45 +59,43 @@
             inference_handler (DragonInferenceHandler):
                 An initialized inference handler containing the model.
             schema (FeatureSchema):
-                The definitive schema object
-
-
-            continuous_bounds_map (Dict[str, Tuple[float, float]]):
-                A dictionary mapping the *name* of each **continuous** feature
-                to its (min_bound, max_bound) tuple.
-            task (str): The optimization goal, either "min" or "max".
-
-            algorithm (str): The search algorithm to use ("SNES", "CEM", "Genetic").
-            population_size (int): Population size for CEM and GeneticAlgorithm.
-            discretize_start_at_zero (bool):
-                True if the discrete encoding starts at 0 (e.g., [0, 1, 2]).
-                False if it starts at 1 (e.g., [1, 2, 3]).
-            **searcher_kwargs: Additional keyword arguments for the selected
-                search algorithm's constructor.
+                The definitive schema object.
+            config (DragonOptimizerConfig):
+                Configuration object containing optimization parameters.
         """
         # --- Store schema ---
         self.schema = schema
         # --- Store inference handler ---
         self.inference_handler = inference_handler
 
+        # --- Store config ---
+        self.config = config
+
         # Ensure only Regression tasks are used
         allowed_tasks = [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]
         if self.inference_handler.task not in allowed_tasks:
             _LOGGER.error(f"DragonOptimizer only supports {allowed_tasks}. Got '{self.inference_handler.task}'.")
-            raise ValueError(
+            raise ValueError()
 
         # --- store target name ---
-        self.target_name = target_name
+        self.target_name = config.target_name
 
         # --- flag to control single vs multi-target ---
         self.is_multi_target = False
 
         # --- 1. Create bounds from schema ---
-        #
+        # Handle bounds loading if it's a path
+        raw_bounds_map = config.continuous_bounds_map
+        if isinstance(raw_bounds_map, (str, Path)):
+            continuous_bounds = load_continuous_bounds_template(raw_bounds_map)
+        else:
+            continuous_bounds = raw_bounds_map
+
+        # Robust way to get bounds
         bounds = create_optimization_bounds(
             schema=schema,
-            continuous_bounds_map=continuous_bounds_map,
-            start_at_zero=discretize_start_at_zero
+            continuous_bounds_map=continuous_bounds,
+            start_at_zero=config.discretize_start_at_zero
         )
 
         # Resolve target index if multi-target
@@ -114,26 +106,26 @@
             _LOGGER.error("The provided inference handler does not have 'target_ids' defined.")
             raise ValueError()
 
-        if target_name not in self.inference_handler.target_ids:
-            _LOGGER.error(f"Target name '{target_name}' not found in the inference handler's 'target_ids': {self.inference_handler.target_ids}")
+        if self.target_name not in self.inference_handler.target_ids:
+            _LOGGER.error(f"Target name '{self.target_name}' not found in the inference handler's 'target_ids': {self.inference_handler.target_ids}")
             raise ValueError()
 
         if len(self.inference_handler.target_ids) == 1:
             # Single target regression
             target_index = None
-            _LOGGER.info(f"Optimization locked to single-target model '{target_name}'.")
+            _LOGGER.info(f"Optimization locked to single-target model '{self.target_name}'.")
         else:
             # Multi-target regression (optimizing one specific column)
-            target_index = self.inference_handler.target_ids.index(target_name)
+            target_index = self.inference_handler.target_ids.index(self.target_name)
             self.is_multi_target = True
-            _LOGGER.info(f"Optimization locked to target '{target_name}' (Index {target_index}) in a multi-target model.")
+            _LOGGER.info(f"Optimization locked to target '{self.target_name}' (Index {target_index}) in a multi-target model.")
 
         # --- 2. Make a fitness function ---
         self.evaluator = FitnessEvaluator(
             inference_handler=inference_handler,
             # Get categorical info from the schema
             categorical_index_map=schema.categorical_index_map,
-            discretize_start_at_zero=discretize_start_at_zero,
+            discretize_start_at_zero=config.discretize_start_at_zero,
             target_index=target_index
         )
 
@@ -141,20 +133,13 @@
         self.problem, self.searcher_factory = create_pytorch_problem(
             evaluator=self.evaluator,
             bounds=bounds,
-            task=task,
-            algorithm=algorithm,
-            population_size=population_size,
-            **searcher_kwargs
+            task=config.task, # type: ignore
+            algorithm=config.algorithm, # type: ignore
+            population_size=config.population_size,
+            **config.searcher_kwargs
         )
-
-        # --- 4. Store other info needed by run() ---
-        self.discretize_start_at_zero = discretize_start_at_zero
 
     def run(self,
-            num_generations: int,
-            save_dir: Union[str, Path],
-            save_format: Literal['csv', 'sqlite', 'both'],
-            repetitions: int = 1,
             verbose: bool = True) -> Optional[dict]:
         """
         Runs the evolutionary optimization process using the pre-configured settings.
@@ -163,15 +148,10 @@
         provided during initialization.
 
         Args:
-            num_generations (int): The total number of generations for each repetition.
-            save_dir (str | Path): The directory where result files will be saved.
-            save_format (Literal['csv', 'sqlite', 'both']): The format for saving results.
-            repetitions (int): The number of independent times to run the optimization.
             verbose (bool): If True, enables detailed logging.
 
         Returns:
-            Optional[dict]: A dictionary with the best result if repetitions is 1,
-                otherwise None.
+            Optional[dict]: A dictionary with the best result if repetitions is 1, otherwise None.
         """
         # Pass inference handler and target names for multi-target only
         if self.is_multi_target:
@@ -185,18 +165,18 @@
         return run_optimization(
             problem=self.problem,
             searcher_factory=self.searcher_factory,
-            num_generations=num_generations,
+            num_generations=self.config.generations,
             target_name=self.target_name,
-            save_dir=save_dir,
-            save_format=save_format,
+            save_dir=self.config.save_directory,
+            save_format=self.config.save_format, # type: ignore
             # Get the definitive feature names (as a list) from the schema
             feature_names=list(self.schema.feature_names),
             # Get categorical info from the schema
             categorical_map=self.schema.categorical_index_map,
             categorical_mappings=self.schema.categorical_mappings,
-            repetitions=repetitions,
+            repetitions=self.config.repetitions,
             verbose=verbose,
-            discretize_start_at_zero=self.discretize_start_at_zero,
+            discretize_start_at_zero=self.config.discretize_start_at_zero,
             all_target_names=target_names_to_pass,
             inference_handler=inference_handler_to_pass
         )
@@ -506,5 +506,5 @@ def _save_result(
 
 def _handle_pandas_log(logger: PandasLogger, save_path: Path, target_name: str):
     log_dataframe = logger.to_dataframe()
-    save_dataframe_filename(df=log_dataframe, save_dir=save_path / "EvolutionLogs", filename=target_name)
+    save_dataframe_filename(df=log_dataframe, save_dir=save_path / "EvolutionLogs", filename=target_name, verbose=2)
 
```
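For consumers upgrading from 20.2.0, the parameter migration implied by this diff is mechanical: everything previously split between `__init__` and `run()` now lives on one `DragonOptimizerConfig` object. A hedged before/after sketch, using only names taken from the removed and added signatures above; `my_handler`, `schema`, and `cont_bounds` are placeholders for objects built elsewhere, and this is not runnable without the package:

```python
# 20.2.0 (removed signature) -- illustrative only
optimizer = DragonOptimizer(
    inference_handler=my_handler,
    schema=schema,
    continuous_bounds_map=cont_bounds,   # dict of (min, max) tuples only
    task="max",
    algorithm="Genetic",
)
best = optimizer.run(num_generations=100, save_dir="results/", save_format="csv")

# 20.4.0 (new signature) -- the config also carries generations, save_format,
# repetitions, population_size, discretize_start_at_zero, and searcher_kwargs,
# all of which run() and __init__ now read from it per the diff.
config = DragonOptimizerConfig(
    target_name="my_target",
    task="max",
    continuous_bounds_map=cont_bounds,   # may now also be a str/Path template,
    save_directory="results/",           # auto-loaded via load_continuous_bounds_template
    algorithm="Genetic",
)
optimizer = DragonOptimizer(inference_handler=my_handler, schema=schema, config=config)
best = optimizer.run(verbose=True)
```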
ml_tools/ML_scaler/_ML_scaler.py
CHANGED

```diff
@@ -33,7 +33,7 @@ class DragonScaler:
         self.continuous_feature_indices = continuous_feature_indices
 
     @classmethod
-    def fit(cls, dataset: Dataset, continuous_feature_indices: list[int], batch_size: int = 64) -> 'DragonScaler':
+    def fit(cls, dataset: Dataset, continuous_feature_indices: list[int], batch_size: int = 64, verbose: int = 3) -> 'DragonScaler':
         """
         Fits the scaler using a PyTorch Dataset (Method A) using Batched Welford's Algorithm.
         """
@@ -85,23 +85,25 @@
             n_total = new_n_total
 
         if n_total == 0:
-
-
+            _LOGGER.error("Dataset is empty. Scaler cannot be fitted.")
+            return cls(continuous_feature_indices=continuous_feature_indices)
 
         # Finalize Standard Deviation
         # Unbiased estimator (divide by n-1)
         if n_total < 2:
-            _LOGGER.warning(f"Only one sample found. Standard deviation set to 1.")
+            if verbose >= 1:
+                _LOGGER.warning(f"Only one sample found. Standard deviation set to 1.")
             std = torch.ones_like(mean_global) # type: ignore
         else:
             variance = m2_global / (n_total - 1)
             std = torch.sqrt(torch.clamp(variance, min=1e-8))
-
-        _LOGGER.info(f"Scaler fitted on {n_total} samples for {num_continuous_features} features (Welford's).")
+
+        if verbose >= 2:
+            _LOGGER.info(f"Scaler fitted on {n_total} samples for {num_continuous_features} features (Welford's).")
         return cls(mean=mean_global, std=std, continuous_feature_indices=continuous_feature_indices)
 
     @classmethod
-    def fit_tensor(cls, data: torch.Tensor) -> 'DragonScaler':
+    def fit_tensor(cls, data: torch.Tensor, verbose: int = 3) -> 'DragonScaler':
         """
         Fits the scaler directly on a Tensor (Method B).
         Useful for targets or small datasets already in memory.
@@ -118,6 +120,9 @@
         # Handle constant values (std=0) to prevent division by zero
         std = torch.where(std == 0, torch.tensor(1.0, device=data.device), std)
 
+        if verbose >= 2:
+            _LOGGER.info(f"Scaler fitted on tensor with {data.shape[0]} samples for {num_features} features.")
+
         return cls(mean=mean, std=std, continuous_feature_indices=indices)
 
     def transform(self, data: torch.Tensor) -> torch.Tensor:
```
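The `fit` path referenced here accumulates mean and variance with a batched (parallel) form of Welford's algorithm, which the diff's accumulator names (`n_total`, `mean_global`, `m2_global`, `variance = m2_global / (n_total - 1)`) make explicit. A minimal self-contained sketch of that update rule, independent of `DragonScaler` and assuming plain 2-D float tensors with at least one non-empty batch:

```python
import torch

def batched_welford(batches):
    """Running per-feature mean/std over an iterable of (batch, features)
    tensors, using Chan's parallel update of Welford's algorithm."""
    n_total, mean_global, m2_global = 0, None, None
    for batch in batches:
        n_b = batch.shape[0]
        mean_b = batch.mean(dim=0)
        m2_b = ((batch - mean_b) ** 2).sum(dim=0)
        if mean_global is None:
            n_total, mean_global, m2_global = n_b, mean_b, m2_b
            continue
        # Combine the running statistics with the new batch's statistics
        delta = mean_b - mean_global
        new_n = n_total + n_b
        mean_global = mean_global + delta * (n_b / new_n)
        m2_global = m2_global + m2_b + delta.pow(2) * n_total * n_b / new_n
        n_total = new_n
    # Unbiased estimator (divide by n-1), clamped as in the diff
    variance = m2_global / max(n_total - 1, 1)
    std = torch.sqrt(torch.clamp(variance, min=1e-8))
    return mean_global, std

# Matches torch's global statistics when run over chunks:
data = torch.randn(1000, 5)
mean, std = batched_welford(data.split(64))
assert torch.allclose(mean, data.mean(dim=0), atol=1e-5)
assert torch.allclose(std, data.std(dim=0), atol=1e-4)
```

Note also the `verbose` convention this release applies consistently (here and in the trainer below): `verbose >= 1` gates warnings, `>= 2` gates summary info, and `>= 3` gates detailed progress messages.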
ml_tools/ML_scaler/__init__.py
CHANGED
ml_tools/ML_trainer/__init__.py
CHANGED

```diff
@@ -10,7 +10,7 @@ from ._dragon_detection_trainer import (
     DragonDetectionTrainer
 )
 
-from ._imprimir import info
+from .._core import _imprimir_disponibles
 
 
 __all__ = [
@@ -18,3 +18,7 @@ __all__ = [
     "DragonSequenceTrainer",
     "DragonDetectionTrainer",
 ]
+
+
+def info():
+    _imprimir_disponibles(__all__)
```
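This is the pattern behind the long list of deleted `_imprimir.py` files above: every subpackage drops its private `_imprimir.py` helper ("imprimir" is Spanish for "print") in favor of a shared `_imprimir_disponibles` ("print available") utility in `_core`, exposed through a small module-level `info()` function as shown. Presumably the user-facing call is unchanged; a hypothetical usage sketch (the exact output format of `_imprimir_disponibles` is not shown in this diff):

```python
# Hypothetical usage of the new shared info() helper.
from ml_tools import ML_trainer

ML_trainer.info()  # prints the names listed in ML_trainer.__all__
```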
ml_tools/ML_trainer/_base_trainer.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 from typing import Literal, Union, Optional, Any
 from pathlib import Path
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 import torch
 from torch import nn
 from abc import ABC, abstractmethod
@@ -10,6 +10,7 @@ from ..ML_callbacks._checkpoint import DragonModelCheckpoint
 from ..ML_callbacks._early_stop import _DragonEarlyStopping
 from ..ML_callbacks._scheduler import _DragonLRScheduler
 from ..ML_evaluation import plot_losses
+from ..ML_utilities import inspect_pth_file
 
 from ..path_manager import make_fullpath
 from ..keys._keys import PyTorchCheckpointKeys, MagicWords
@@ -89,11 +90,128 @@ class _BaseDragonTrainer(ABC):
         """Gives each callback a reference to this trainer instance."""
         for callback in self.callbacks:
             callback.set_trainer(self)
+
+    def _make_dataloaders(self,
+                          train_dataset: Any,
+                          validation_dataset: Any,
+                          batch_size: int,
+                          shuffle: bool,
+                          collate_fn: Optional[Any] = None):
+        """
+        Shared logic to initialize standard DataLoaders.
+        Subclasses can call this inside their _create_dataloaders implementation.
+        """
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+        pin_memory = ("cuda" in self.device.type)
+
+        self.train_loader = DataLoader(
+            dataset=train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=pin_memory,
+            drop_last=True,
+            collate_fn=collate_fn
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=pin_memory,
+            collate_fn=collate_fn
+        )
+
+    def _validate_checkpoint_arg(self, model_checkpoint: Union[Path, str]) -> Union[Path, str]:
+        """Validates the model_checkpoint argument."""
+        if isinstance(model_checkpoint, Path):
+            return make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
+            return model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+    def _validate_save_dir(self, save_dir: Union[str, Path]) -> Path:
+        """Validates and creates the save directory."""
+        return make_fullpath(save_dir, make=True, enforce="directory")
+
+    def _prepare_eval_data(self,
+                           data: Optional[Union[DataLoader, Dataset]],
+                           default_dataset: Optional[Dataset],
+                           collate_fn: Optional[Any] = None) -> tuple[DataLoader, Any]:
+        """
+        Prepares the DataLoader and dataset artifact source for evaluation.
+
+        Returns:
+            (eval_loader, dataset_for_artifacts)
+        """
+        eval_loader = None
+        dataset_for_artifacts = None
+
+        # Loader workers config
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+        pin_memory = (self.device.type == "cuda")
 
-    def _load_checkpoint(self, path: Union[str, Path]):
+        if isinstance(data, DataLoader):
+            eval_loader = data
+            if hasattr(data, 'dataset'):
+                dataset_for_artifacts = data.dataset
+        elif isinstance(data, Dataset):
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=loader_workers,
+                                     pin_memory=pin_memory,
+                                     collate_fn=collate_fn)
+            dataset_for_artifacts = data
+        else:  # data is None
+            if default_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
+                raise ValueError()
+
+            eval_loader = DataLoader(default_dataset,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=loader_workers,
+                                     pin_memory=pin_memory,
+                                     collate_fn=collate_fn)
+            dataset_for_artifacts = default_dataset
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        return eval_loader, dataset_for_artifacts
+
+    def _save_finalized_artifact(self,
+                                 finalized_data: dict,
+                                 save_dir: Union[str, Path],
+                                 filename: str):
+        """
+        Handles the common logic for saving the finalized model dictionary to disk.
+        """
+        # handle save path
+        dir_path = self._validate_save_dir(save_dir)
+        full_path = dir_path / filename
+
+        # checkpoint loading happens before dict creation.
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
+        if full_path.is_file():
+            inspect_pth_file(pth_path=full_path, save_dir=dir_path, verbose=2)
+
+    def _load_checkpoint(self, path: Union[str, Path], verbose: int = 3):
         """Loads a training checkpoint to resume training."""
         p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
+
+        if verbose >= 2:
+            _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
 
         try:
             checkpoint = torch.load(p, map_location=self.device)
@@ -110,9 +228,11 @@
             # --- Load History ---
             if PyTorchCheckpointKeys.HISTORY in checkpoint:
                 self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
-                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+                if verbose >= 3:
+                    _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
             else:
-                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                if verbose >= 1:
+                    _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
                 self.history = {}  # Ensure it's at least an empty dict
 
             # --- Scheduler State Loading Logic ---
@@ -124,7 +244,8 @@
                 try:
                     self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])  # type: ignore
                     scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+                    if verbose >= 3:
+                        _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
                 except Exception as e:
                     # Loading failed, likely a mismatch
                     scheduler_name = self.scheduler.__class__.__name__
@@ -134,7 +255,8 @@
             elif scheduler_object_exists and not scheduler_state_exists:
                 # Case 2: Scheduler provided, but no state in checkpoint.
                 scheduler_name = self.scheduler.__class__.__name__
-                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+                if verbose >= 1:
+                    _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
 
             elif not scheduler_object_exists and scheduler_state_exists:
                 # Case 3: State in checkpoint, but no scheduler provided.
@@ -145,9 +267,11 @@
             for cb in self.callbacks:
                 if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
                     cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
-                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+                    if verbose >= 3:
+                        _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
 
-            _LOGGER.info(f"Model restored to epoch {self.epoch}.")
+            if verbose >= 2:
+                _LOGGER.info(f"Model restored to epoch {self.epoch}.")
 
         except Exception as e:
             _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
@@ -243,16 +367,15 @@
         self.model.to(self.device)
         _LOGGER.info(f"Trainer and model moved to {self.device}.")
 
-    def _load_model_state_wrapper(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
+    def _load_model_state_wrapper(self, model_checkpoint: Union[Path, Literal['best', 'current']], verbose: int = 2):
         """
         Private helper to load the correct model state_dict based on user's choice.
-        This is called by finalize_model_training() in subclasses.
         """
         if isinstance(model_checkpoint, Path):
-            self._load_checkpoint(path=model_checkpoint)
+            self._load_checkpoint(path=model_checkpoint, verbose=verbose)
         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
             path_to_latest = self._checkpoint_callback.best_checkpoint_path
-            self._load_checkpoint(path_to_latest)
+            self._load_checkpoint(path_to_latest, verbose=verbose)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
             raise ValueError()
```