dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +704 -24
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +144 -39
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_callbacks.py
CHANGED
|
@@ -4,23 +4,22 @@ from tqdm.auto import tqdm
|
|
|
4
4
|
from typing import Union, Literal, Optional
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
|
-
from .path_manager import make_fullpath
|
|
8
|
-
from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
|
|
7
|
+
from .path_manager import make_fullpath
|
|
8
|
+
from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys
|
|
9
9
|
from ._logger import _LOGGER
|
|
10
10
|
from ._script_info import _script_info
|
|
11
11
|
|
|
12
12
|
|
|
13
13
|
__all__ = [
|
|
14
|
-
"Callback",
|
|
15
14
|
"History",
|
|
16
15
|
"TqdmProgressBar",
|
|
17
|
-
"EarlyStopping",
|
|
18
|
-
"ModelCheckpoint",
|
|
19
|
-
"LRScheduler"
|
|
16
|
+
"DragonEarlyStopping",
|
|
17
|
+
"DragonModelCheckpoint",
|
|
18
|
+
"DragonLRScheduler"
|
|
20
19
|
]
|
|
21
20
|
|
|
22
21
|
|
|
23
|
-
class Callback:
|
|
22
|
+
class _Callback:
|
|
24
23
|
"""
|
|
25
24
|
Abstract base class used to build new callbacks.
|
|
26
25
|
|
|
@@ -60,7 +59,7 @@ class Callback:
|
|
|
60
59
|
pass
|
|
61
60
|
|
|
62
61
|
|
|
63
|
-
class History(Callback):
|
|
62
|
+
class History(_Callback):
|
|
64
63
|
"""
|
|
65
64
|
Callback that records events into a `history` dictionary.
|
|
66
65
|
|
|
@@ -79,7 +78,7 @@ class History(Callback):
|
|
|
79
78
|
self.trainer.history.setdefault(k, []).append(v) # type: ignore
|
|
80
79
|
|
|
81
80
|
|
|
82
|
-
class TqdmProgressBar(Callback):
|
|
81
|
+
class TqdmProgressBar(_Callback):
|
|
83
82
|
"""Callback that provides a tqdm progress bar for training."""
|
|
84
83
|
def __init__(self):
|
|
85
84
|
self.epoch_bar = None
|
|
@@ -110,7 +109,7 @@ class TqdmProgressBar(Callback):
|
|
|
110
109
|
self.epoch_bar.close() # type: ignore
|
|
111
110
|
|
|
112
111
|
|
|
113
|
-
class EarlyStopping(Callback):
|
|
112
|
+
class DragonEarlyStopping(_Callback):
|
|
114
113
|
"""
|
|
115
114
|
Stop training when a monitored metric has stopped improving.
|
|
116
115
|
"""
|
|
@@ -187,11 +186,11 @@ class EarlyStopping(Callback):
|
|
|
187
186
|
_LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
|
|
188
187
|
|
|
189
188
|
|
|
190
|
-
class ModelCheckpoint(Callback):
|
|
189
|
+
class DragonModelCheckpoint(_Callback):
|
|
191
190
|
"""
|
|
192
191
|
Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
|
|
193
192
|
"""
|
|
194
|
-
def __init__(self, save_dir: Union[str,Path],
|
|
193
|
+
def __init__(self, save_dir: Union[str,Path], monitor: str = PyTorchLogKeys.VAL_LOSS,
|
|
195
194
|
save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
|
|
196
195
|
"""
|
|
197
196
|
- If `save_best_only` is True, it saves the single best model, deleting the previous best.
|
|
@@ -199,7 +198,6 @@ class ModelCheckpoint(Callback):
|
|
|
199
198
|
|
|
200
199
|
Args:
|
|
201
200
|
save_dir (str): Directory where checkpoint files will be saved.
|
|
202
|
-
checkpoint_name (str| None): If None, the filename will include the epoch and score.
|
|
203
201
|
monitor (str): Metric to monitor.
|
|
204
202
|
save_best_only (bool): If true, save only the best model.
|
|
205
203
|
mode (str): One of {'auto', 'min', 'max'}.
|
|
@@ -215,9 +213,8 @@ class ModelCheckpoint(Callback):
|
|
|
215
213
|
self.monitor = monitor
|
|
216
214
|
self.save_best_only = save_best_only
|
|
217
215
|
self.verbose = verbose
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
self.checkpoint_name = checkpoint_name
|
|
216
|
+
self._latest_checkpoint_path = None
|
|
217
|
+
self._checkpoint_name = PyTorchCheckpointKeys.CHECKPOINT_NAME
|
|
221
218
|
|
|
222
219
|
# State variables to be managed during training
|
|
223
220
|
self.saved_checkpoints = []
|
|
@@ -261,10 +258,7 @@ class ModelCheckpoint(Callback):
|
|
|
261
258
|
old_best_str = f"{self.best:.4f}" if self.best not in [np.inf, -np.inf] else "inf"
|
|
262
259
|
|
|
263
260
|
# Create a descriptive filename
|
|
264
|
-
|
|
265
|
-
filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
|
|
266
|
-
else:
|
|
267
|
-
filename = f"epoch{epoch}_{self.checkpoint_name}.pth"
|
|
261
|
+
filename = f"epoch{epoch}_{self._checkpoint_name}_{current:.4f}.pth"
|
|
268
262
|
new_filepath = self.save_dir / filename
|
|
269
263
|
|
|
270
264
|
if self.verbose > 0:
|
|
@@ -279,6 +273,7 @@ class ModelCheckpoint(Callback):
|
|
|
279
273
|
PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
|
|
280
274
|
PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
|
|
281
275
|
PyTorchCheckpointKeys.BEST_SCORE: self.best,
|
|
276
|
+
PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
|
|
282
277
|
}
|
|
283
278
|
|
|
284
279
|
# Check for scheduler
|
|
@@ -287,6 +282,7 @@ class ModelCheckpoint(Callback):
|
|
|
287
282
|
|
|
288
283
|
# Save the new best model
|
|
289
284
|
torch.save(checkpoint_data, new_filepath)
|
|
285
|
+
self._latest_checkpoint_path = new_filepath
|
|
290
286
|
|
|
291
287
|
# Delete the old best model file
|
|
292
288
|
if self.last_best_filepath and self.last_best_filepath.exists():
|
|
@@ -298,10 +294,8 @@ class ModelCheckpoint(Callback):
|
|
|
298
294
|
def _save_rolling_checkpoints(self, epoch, logs):
|
|
299
295
|
"""Saves the latest model and keeps only the most recent ones."""
|
|
300
296
|
current = logs.get(self.monitor)
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
else:
|
|
304
|
-
filename = f"epoch{epoch}_{self.checkpoint_name}.pth"
|
|
297
|
+
|
|
298
|
+
filename = f"epoch{epoch}_{self._checkpoint_name}_{current:.4f}.pth"
|
|
305
299
|
filepath = self.save_dir / filename
|
|
306
300
|
|
|
307
301
|
if self.verbose > 0:
|
|
@@ -313,12 +307,15 @@ class ModelCheckpoint(Callback):
|
|
|
313
307
|
PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
|
|
314
308
|
PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
|
|
315
309
|
PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
|
|
310
|
+
PyTorchCheckpointKeys.HISTORY: self.trainer.history, # type: ignore
|
|
316
311
|
}
|
|
317
312
|
|
|
318
313
|
if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
|
|
319
314
|
checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
|
|
320
315
|
|
|
321
316
|
torch.save(checkpoint_data, filepath)
|
|
317
|
+
|
|
318
|
+
self._latest_checkpoint_path = filepath
|
|
322
319
|
|
|
323
320
|
self.saved_checkpoints.append(filepath)
|
|
324
321
|
|
|
@@ -330,8 +327,16 @@ class ModelCheckpoint(Callback):
|
|
|
330
327
|
_LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
|
|
331
328
|
file_to_delete.unlink()
|
|
332
329
|
|
|
330
|
+
@property
|
|
331
|
+
def best_checkpoint_path(self):
|
|
332
|
+
if self._latest_checkpoint_path:
|
|
333
|
+
return self._latest_checkpoint_path
|
|
334
|
+
else:
|
|
335
|
+
_LOGGER.error("No checkpoint paths saved.")
|
|
336
|
+
raise ValueError()
|
|
337
|
+
|
|
333
338
|
|
|
334
|
-
class LRScheduler(Callback):
|
|
339
|
+
class DragonLRScheduler(_Callback):
|
|
335
340
|
"""
|
|
336
341
|
Callback to manage a PyTorch learning rate scheduler.
|
|
337
342
|
"""
|
|
@@ -361,6 +366,8 @@ class LRScheduler(Callback):
|
|
|
361
366
|
|
|
362
367
|
def on_epoch_end(self, epoch, logs=None):
|
|
363
368
|
"""Step the scheduler and log any change in learning rate."""
|
|
369
|
+
logs = logs or {}
|
|
370
|
+
|
|
364
371
|
# For schedulers that need a metric (e.g., val_loss)
|
|
365
372
|
if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
|
|
366
373
|
if self.monitor is None:
|
|
@@ -376,12 +383,22 @@ class LRScheduler(Callback):
|
|
|
376
383
|
# For all other schedulers
|
|
377
384
|
else:
|
|
378
385
|
self.scheduler.step()
|
|
386
|
+
|
|
387
|
+
# Get the current learning rate
|
|
388
|
+
current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
|
|
379
389
|
|
|
380
390
|
# Log the change if the LR was updated
|
|
381
|
-
current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
|
|
382
391
|
if current_lr != self.previous_lr:
|
|
383
392
|
_LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
|
|
384
393
|
self.previous_lr = current_lr
|
|
394
|
+
|
|
395
|
+
# --- Add LR to logs and history ---
|
|
396
|
+
# Add to the logs dict for any subsequent callbacks
|
|
397
|
+
logs[PyTorchLogKeys.LEARNING_RATE] = current_lr
|
|
398
|
+
|
|
399
|
+
# Also add directly to the trainer's history dict
|
|
400
|
+
if hasattr(self.trainer, 'history'):
|
|
401
|
+
self.trainer.history.setdefault(PyTorchLogKeys.LEARNING_RATE, []).append(current_lr) # type: ignore
|
|
385
402
|
|
|
386
403
|
|
|
387
404
|
def info():
|