dragon-ml-toolbox 12.13.0__tar.gz → 13.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-12.13.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-13.0.0}/PKG-INFO +1 -1
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_callbacks.py +40 -8
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation.py +6 -2
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation_multi.py +8 -4
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_inference.py +14 -4
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_trainer.py +98 -11
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/keys.py +9 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/LICENSE +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/README.md +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ETL_cleaning.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_datasetmaster.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_models.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_scaler.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_simple_optimization.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_utilities.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/constants.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/custom_logger.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/math_utilities.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/path_manager.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/serde.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/ml_tools/utilities.py +0 -0
- {dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/setup.cfg +0 -0
|
@@ -5,7 +5,7 @@ from typing import Union, Literal, Optional
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
|
|
7
7
|
from .path_manager import make_fullpath, sanitize_filename
|
|
8
|
-
from .keys import PyTorchLogKeys
|
|
8
|
+
from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
|
|
9
9
|
from ._logger import _LOGGER
|
|
10
10
|
from ._script_info import _script_info
|
|
11
11
|
|
|
@@ -189,7 +189,7 @@ class EarlyStopping(Callback):
|
|
|
189
189
|
|
|
190
190
|
class ModelCheckpoint(Callback):
|
|
191
191
|
"""
|
|
192
|
-
Saves the model weights to a directory with automated filename generation and rotation.
|
|
192
|
+
Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
|
|
193
193
|
"""
|
|
194
194
|
def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
|
|
195
195
|
save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
|
|
@@ -200,7 +200,7 @@ class ModelCheckpoint(Callback):
|
|
|
200
200
|
Args:
|
|
201
201
|
save_dir (str): Directory where checkpoint files will be saved.
|
|
202
202
|
checkpoint_name (str| None): If None, the filename will include the epoch and score.
|
|
203
|
-
monitor (str): Metric to monitor
|
|
203
|
+
monitor (str): Metric to monitor.
|
|
204
204
|
save_best_only (bool): If true, save only the best model.
|
|
205
205
|
mode (str): One of {'auto', 'min', 'max'}.
|
|
206
206
|
verbose (int): Verbosity mode.
|
|
@@ -270,15 +270,29 @@ class ModelCheckpoint(Callback):
|
|
|
270
270
|
if self.verbose > 0:
|
|
271
271
|
_LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
|
|
272
272
|
|
|
273
|
+
# Update best score *before* saving
|
|
274
|
+
self.best = current
|
|
275
|
+
|
|
276
|
+
# Create a comprehensive checkpoint dictionary
|
|
277
|
+
checkpoint_data = {
|
|
278
|
+
PyTorchCheckpointKeys.EPOCH: epoch,
|
|
279
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
|
|
280
|
+
PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
|
|
281
|
+
PyTorchCheckpointKeys.BEST_SCORE: self.best,
|
|
282
|
+
}
|
|
283
|
+
|
|
284
|
+
# Check for scheduler
|
|
285
|
+
if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
|
|
286
|
+
checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
|
|
287
|
+
|
|
273
288
|
# Save the new best model
|
|
274
|
-
torch.save(
|
|
289
|
+
torch.save(checkpoint_data, new_filepath)
|
|
275
290
|
|
|
276
291
|
# Delete the old best model file
|
|
277
292
|
if self.last_best_filepath and self.last_best_filepath.exists():
|
|
278
293
|
self.last_best_filepath.unlink()
|
|
279
294
|
|
|
280
295
|
# Update state
|
|
281
|
-
self.best = current
|
|
282
296
|
self.last_best_filepath = new_filepath
|
|
283
297
|
|
|
284
298
|
def _save_rolling_checkpoints(self, epoch, logs):
|
|
@@ -292,7 +306,19 @@ class ModelCheckpoint(Callback):
|
|
|
292
306
|
|
|
293
307
|
if self.verbose > 0:
|
|
294
308
|
_LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
|
|
295
|
-
|
|
309
|
+
|
|
310
|
+
# Create a comprehensive checkpoint dictionary
|
|
311
|
+
checkpoint_data = {
|
|
312
|
+
PyTorchCheckpointKeys.EPOCH: epoch,
|
|
313
|
+
PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
|
|
314
|
+
PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
|
|
315
|
+
PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
|
|
319
|
+
checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
|
|
320
|
+
|
|
321
|
+
torch.save(checkpoint_data, filepath)
|
|
296
322
|
|
|
297
323
|
self.saved_checkpoints.append(filepath)
|
|
298
324
|
|
|
@@ -309,19 +335,25 @@ class LRScheduler(Callback):
|
|
|
309
335
|
"""
|
|
310
336
|
Callback to manage a PyTorch learning rate scheduler.
|
|
311
337
|
"""
|
|
312
|
-
def __init__(self, scheduler, monitor: Optional[str] =
|
|
338
|
+
def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
|
|
313
339
|
"""
|
|
314
340
|
This callback automatically calls the scheduler's `step()` method at the
|
|
315
341
|
end of each epoch. It also logs a message when the learning rate changes.
|
|
316
342
|
|
|
317
343
|
Args:
|
|
318
344
|
scheduler: An initialized PyTorch learning rate scheduler.
|
|
319
|
-
monitor (str
|
|
345
|
+
monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
|
|
320
346
|
"""
|
|
321
347
|
super().__init__()
|
|
322
348
|
self.scheduler = scheduler
|
|
323
349
|
self.monitor = monitor
|
|
324
350
|
self.previous_lr = None
|
|
351
|
+
|
|
352
|
+
def set_trainer(self, trainer):
|
|
353
|
+
"""This is called by the Trainer to associate itself with the callback."""
|
|
354
|
+
super().set_trainer(trainer)
|
|
355
|
+
# Register the scheduler with the trainer so it can be added to the checkpoint
|
|
356
|
+
self.trainer.scheduler = self.scheduler # type: ignore
|
|
325
357
|
|
|
326
358
|
def on_train_begin(self, logs=None):
|
|
327
359
|
"""Store the initial learning rate."""
|
|
@@ -19,6 +19,7 @@ import torch
|
|
|
19
19
|
import shap
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from typing import Union, Optional, List, Literal
|
|
22
|
+
import warnings
|
|
22
23
|
|
|
23
24
|
from .path_manager import make_fullpath
|
|
24
25
|
from ._logger import _LOGGER
|
|
@@ -298,8 +299,11 @@ def shap_summary_plot(model,
|
|
|
298
299
|
|
|
299
300
|
background_data = background_data.to(device)
|
|
300
301
|
instances_to_explain = instances_to_explain.to(device)
|
|
301
|
-
|
|
302
|
-
|
|
302
|
+
|
|
303
|
+
with warnings.catch_warnings():
|
|
304
|
+
warnings.simplefilter("ignore", category=UserWarning)
|
|
305
|
+
explainer = shap.DeepExplainer(model, background_data)
|
|
306
|
+
|
|
303
307
|
# print("Calculating SHAP values with DeepExplainer...")
|
|
304
308
|
shap_values = explainer.shap_values(instances_to_explain)
|
|
305
309
|
instances_to_explain_np = instances_to_explain.cpu().numpy()
|
|
@@ -20,6 +20,7 @@ from sklearn.metrics import (
|
|
|
20
20
|
)
|
|
21
21
|
from pathlib import Path
|
|
22
22
|
from typing import Union, List, Literal
|
|
23
|
+
import warnings
|
|
23
24
|
|
|
24
25
|
from .path_manager import make_fullpath, sanitize_filename
|
|
25
26
|
from ._logger import _LOGGER
|
|
@@ -273,9 +274,12 @@ def multi_target_shap_summary_plot(
|
|
|
273
274
|
|
|
274
275
|
background_data = background_data.to(device)
|
|
275
276
|
instances_to_explain = instances_to_explain.to(device)
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
277
|
+
|
|
278
|
+
with warnings.catch_warnings():
|
|
279
|
+
warnings.simplefilter("ignore", category=UserWarning)
|
|
280
|
+
explainer = shap.DeepExplainer(model, background_data)
|
|
281
|
+
|
|
282
|
+
# print("Calculating SHAP values with DeepExplainer...")
|
|
279
283
|
# DeepExplainer returns a list of arrays for multi-output models
|
|
280
284
|
shap_values_list = explainer.shap_values(instances_to_explain)
|
|
281
285
|
instances_to_explain_np = instances_to_explain.cpu().numpy()
|
|
@@ -304,7 +308,7 @@ def multi_target_shap_summary_plot(
|
|
|
304
308
|
return output.cpu().numpy() # Return full multi-output array
|
|
305
309
|
|
|
306
310
|
explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
|
|
307
|
-
print("Calculating SHAP values with KernelExplainer...")
|
|
311
|
+
# print("Calculating SHAP values with KernelExplainer...")
|
|
308
312
|
# KernelExplainer also returns a list of arrays for multi-output models
|
|
309
313
|
shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
|
|
310
314
|
# instances_to_explain_np is already set
|
|
@@ -9,7 +9,7 @@ from .ML_scaler import PytorchScaler
|
|
|
9
9
|
from ._script_info import _script_info
|
|
10
10
|
from ._logger import _LOGGER
|
|
11
11
|
from .path_manager import make_fullpath
|
|
12
|
-
from .keys import PyTorchInferenceKeys
|
|
12
|
+
from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
@@ -56,11 +56,21 @@ class _BaseInferenceHandler(ABC):
|
|
|
56
56
|
model_p = make_fullpath(state_dict, enforce="file")
|
|
57
57
|
|
|
58
58
|
try:
|
|
59
|
-
# Load
|
|
60
|
-
|
|
59
|
+
# Load whatever is in the file
|
|
60
|
+
loaded_data = torch.load(model_p, map_location=self.device)
|
|
61
|
+
|
|
62
|
+
# Check if it's the new checkpoint dictionary or an old weights-only file
|
|
63
|
+
if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
|
|
64
|
+
# It's a new training checkpoint, extract the weights
|
|
65
|
+
self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
|
|
66
|
+
else:
|
|
67
|
+
# It's an old-style file (or just a state_dict), load it directly
|
|
68
|
+
self.model.load_state_dict(loaded_data)
|
|
69
|
+
|
|
70
|
+
_LOGGER.info(f"Model state loaded from '{model_p.name}'.")
|
|
71
|
+
|
|
61
72
|
self.model.to(self.device)
|
|
62
73
|
self.model.eval() # Set the model to evaluation mode
|
|
63
|
-
_LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
|
|
64
74
|
except Exception as e:
|
|
65
75
|
_LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
|
|
66
76
|
raise
|
|
@@ -5,12 +5,13 @@ import torch
|
|
|
5
5
|
from torch import nn
|
|
6
6
|
import numpy as np
|
|
7
7
|
|
|
8
|
-
from .ML_callbacks import Callback, History, TqdmProgressBar
|
|
8
|
+
from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
|
|
9
9
|
from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
|
|
10
10
|
from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
11
11
|
from ._script_info import _script_info
|
|
12
|
-
from .keys import PyTorchLogKeys
|
|
12
|
+
from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
|
|
13
13
|
from ._logger import _LOGGER
|
|
14
|
+
from .path_manager import make_fullpath
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
__all__ = [
|
|
@@ -55,6 +56,7 @@ class MLTrainer:
|
|
|
55
56
|
self.kind = kind
|
|
56
57
|
self.criterion = criterion
|
|
57
58
|
self.optimizer = optimizer
|
|
59
|
+
self.scheduler = None
|
|
58
60
|
self.device = self._validate_device(device)
|
|
59
61
|
self.dataloader_workers = dataloader_workers
|
|
60
62
|
|
|
@@ -70,6 +72,7 @@ class MLTrainer:
|
|
|
70
72
|
self.history = {}
|
|
71
73
|
self.epoch = 0
|
|
72
74
|
self.epochs = 0 # Total epochs for the fit run
|
|
75
|
+
self.start_epoch = 1
|
|
73
76
|
self.stop_training = False
|
|
74
77
|
|
|
75
78
|
def _validate_device(self, device: str) -> torch.device:
|
|
@@ -109,8 +112,66 @@ class MLTrainer:
|
|
|
109
112
|
num_workers=loader_workers,
|
|
110
113
|
pin_memory=("cuda" in self.device.type)
|
|
111
114
|
)
|
|
115
|
+
|
|
116
|
+
def _load_checkpoint(self, path: Union[str, Path]):
|
|
117
|
+
"""Loads a training checkpoint to resume training."""
|
|
118
|
+
p = make_fullpath(path, enforce="file")
|
|
119
|
+
_LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
|
|
120
|
+
|
|
121
|
+
try:
|
|
122
|
+
checkpoint = torch.load(p, map_location=self.device)
|
|
123
|
+
|
|
124
|
+
if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
|
|
125
|
+
_LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
|
|
126
|
+
raise KeyError()
|
|
112
127
|
|
|
113
|
-
|
|
128
|
+
self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
|
|
129
|
+
self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
|
|
130
|
+
self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
|
|
131
|
+
|
|
132
|
+
# --- Scheduler State Loading Logic ---
|
|
133
|
+
scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
|
|
134
|
+
scheduler_object_exists = self.scheduler is not None
|
|
135
|
+
|
|
136
|
+
if scheduler_object_exists and scheduler_state_exists:
|
|
137
|
+
# Case 1: Both exist. Attempt to load.
|
|
138
|
+
try:
|
|
139
|
+
self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
|
|
140
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
141
|
+
_LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
|
|
142
|
+
except Exception as e:
|
|
143
|
+
# Loading failed, likely a mismatch
|
|
144
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
145
|
+
_LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
|
|
146
|
+
raise e
|
|
147
|
+
|
|
148
|
+
elif scheduler_object_exists and not scheduler_state_exists:
|
|
149
|
+
# Case 2: Scheduler provided, but no state in checkpoint.
|
|
150
|
+
scheduler_name = self.scheduler.__class__.__name__
|
|
151
|
+
_LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
|
|
152
|
+
|
|
153
|
+
elif not scheduler_object_exists and scheduler_state_exists:
|
|
154
|
+
# Case 3: State in checkpoint, but no scheduler provided.
|
|
155
|
+
_LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
|
|
156
|
+
raise ValueError()
|
|
157
|
+
|
|
158
|
+
# Restore callback states
|
|
159
|
+
for cb in self.callbacks:
|
|
160
|
+
if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
|
|
161
|
+
cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
|
|
162
|
+
_LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
|
|
163
|
+
|
|
164
|
+
_LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
|
|
165
|
+
|
|
166
|
+
except Exception as e:
|
|
167
|
+
_LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
|
|
168
|
+
raise
|
|
169
|
+
|
|
170
|
+
def fit(self,
|
|
171
|
+
epochs: int = 10,
|
|
172
|
+
batch_size: int = 10,
|
|
173
|
+
shuffle: bool = True,
|
|
174
|
+
resume_from_checkpoint: Optional[Union[str, Path]] = None):
|
|
114
175
|
"""
|
|
115
176
|
Starts the training-validation process of the model.
|
|
116
177
|
|
|
@@ -120,6 +181,7 @@ class MLTrainer:
|
|
|
120
181
|
epochs (int): The total number of epochs to train for.
|
|
121
182
|
batch_size (int): The number of samples per batch.
|
|
122
183
|
shuffle (bool): Whether to shuffle the training data at each epoch.
|
|
184
|
+
resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
|
|
123
185
|
|
|
124
186
|
Note:
|
|
125
187
|
For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
|
|
@@ -132,15 +194,18 @@ class MLTrainer:
|
|
|
132
194
|
self._create_dataloaders(batch_size, shuffle)
|
|
133
195
|
self.model.to(self.device)
|
|
134
196
|
|
|
197
|
+
if resume_from_checkpoint:
|
|
198
|
+
self._load_checkpoint(resume_from_checkpoint)
|
|
199
|
+
|
|
135
200
|
# Reset stop_training flag on the trainer
|
|
136
201
|
self.stop_training = False
|
|
137
202
|
|
|
138
|
-
self.
|
|
203
|
+
self._callbacks_hook('on_train_begin')
|
|
139
204
|
|
|
140
|
-
for epoch in range(
|
|
205
|
+
for epoch in range(self.start_epoch, self.epochs + 1):
|
|
141
206
|
self.epoch = epoch
|
|
142
207
|
epoch_logs = {}
|
|
143
|
-
self.
|
|
208
|
+
self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
|
|
144
209
|
|
|
145
210
|
train_logs = self._train_step()
|
|
146
211
|
epoch_logs.update(train_logs)
|
|
@@ -148,13 +213,13 @@ class MLTrainer:
|
|
|
148
213
|
val_logs = self._validation_step()
|
|
149
214
|
epoch_logs.update(val_logs)
|
|
150
215
|
|
|
151
|
-
self.
|
|
216
|
+
self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
|
|
152
217
|
|
|
153
218
|
# Check the early stopping flag
|
|
154
219
|
if self.stop_training:
|
|
155
220
|
break
|
|
156
221
|
|
|
157
|
-
self.
|
|
222
|
+
self._callbacks_hook('on_train_end')
|
|
158
223
|
return self.history
|
|
159
224
|
|
|
160
225
|
def _train_step(self):
|
|
@@ -166,7 +231,7 @@ class MLTrainer:
|
|
|
166
231
|
PyTorchLogKeys.BATCH_INDEX: batch_idx,
|
|
167
232
|
PyTorchLogKeys.BATCH_SIZE: features.size(0)
|
|
168
233
|
}
|
|
169
|
-
self.
|
|
234
|
+
self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
|
|
170
235
|
|
|
171
236
|
features, target = features.to(self.device), target.to(self.device)
|
|
172
237
|
self.optimizer.zero_grad()
|
|
@@ -188,7 +253,7 @@ class MLTrainer:
|
|
|
188
253
|
|
|
189
254
|
# Add the batch loss to the logs and call the end-of-batch hook
|
|
190
255
|
batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
|
|
191
|
-
self.
|
|
256
|
+
self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
|
|
192
257
|
|
|
193
258
|
return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
|
|
194
259
|
|
|
@@ -538,11 +603,33 @@ class MLTrainer:
|
|
|
538
603
|
else:
|
|
539
604
|
_LOGGER.error("No attention weights were collected from the model.")
|
|
540
605
|
|
|
541
|
-
def
|
|
606
|
+
def _callbacks_hook(self, method_name: str, *args, **kwargs):
|
|
542
607
|
"""Calls the specified method on all callbacks."""
|
|
543
608
|
for callback in self.callbacks:
|
|
544
609
|
method = getattr(callback, method_name)
|
|
545
610
|
method(*args, **kwargs)
|
|
611
|
+
|
|
612
|
+
def to_cpu(self):
|
|
613
|
+
"""
|
|
614
|
+
Moves the model to the CPU and updates the trainer's device setting.
|
|
615
|
+
|
|
616
|
+
This is useful for running operations that require the CPU.
|
|
617
|
+
"""
|
|
618
|
+
self.device = torch.device('cpu')
|
|
619
|
+
self.model.to(self.device)
|
|
620
|
+
_LOGGER.info("Trainer and model moved to CPU.")
|
|
621
|
+
|
|
622
|
+
def to_device(self, device: str):
|
|
623
|
+
"""
|
|
624
|
+
Moves the model to the specified device and updates the trainer's device setting.
|
|
625
|
+
|
|
626
|
+
Args:
|
|
627
|
+
device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
|
|
628
|
+
"""
|
|
629
|
+
self.device = self._validate_device(device)
|
|
630
|
+
self.model.to(self.device)
|
|
631
|
+
_LOGGER.info(f"Trainer and model moved to {self.device}.")
|
|
632
|
+
|
|
546
633
|
|
|
547
634
|
def info():
|
|
548
635
|
_script_info(__all__)
|
|
@@ -68,6 +68,15 @@ class SHAPKeys:
|
|
|
68
68
|
SAVENAME = "shap_summary"
|
|
69
69
|
|
|
70
70
|
|
|
71
|
+
class PyTorchCheckpointKeys:
|
|
72
|
+
"""Keys for saving/loading a training checkpoint dictionary."""
|
|
73
|
+
MODEL_STATE = "model_state_dict"
|
|
74
|
+
OPTIMIZER_STATE = "optimizer_state_dict"
|
|
75
|
+
SCHEDULER_STATE = "scheduler_state_dict"
|
|
76
|
+
EPOCH = "epoch"
|
|
77
|
+
BEST_SCORE = "best_score"
|
|
78
|
+
|
|
79
|
+
|
|
71
80
|
class _OneHotOtherPlaceholder:
|
|
72
81
|
"""Used internally by GUI_tools."""
|
|
73
82
|
OTHER_GUI = "OTHER"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/SOURCES.txt
RENAMED
|
File without changes
|
|
File without changes
|
{dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/requires.txt
RENAMED
|
File without changes
|
{dragon_ml_toolbox-12.13.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/top_level.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|