dragon-ml-toolbox 12.13.0__py3-none-any.whl → 13.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 12.13.0
+ Version: 13.0.0
  Summary: A collection of tools for data science and machine learning projects.
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
  License-Expression: MIT
@@ -1,19 +1,19 @@
- dragon_ml_toolbox-12.13.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
- dragon_ml_toolbox-12.13.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
+ dragon_ml_toolbox-13.0.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+ dragon_ml_toolbox-13.0.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
  ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
  ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
  ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
  ml_tools/MICE_imputation.py,sha256=X273Qlgoqqg7KTmoKd75YDyAPB0UIbTzGP3xsCmRh3E,11717
- ml_tools/ML_callbacks.py,sha256=2ZazJjlbClP-ALc8q0ru2oalkugbhO3TFwPg4RFZpck,14056
+ ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
  ml_tools/ML_datasetmaster.py,sha256=kedCGneR3S2zui0_JFZN6TBL5e69XWkdpkE_QohyqSM,31433
- ml_tools/ML_evaluation.py,sha256=h7fAtk0lS4gTqQ46fiVjucTvFlX4rsufKnEtate6Nu0,18381
- ml_tools/ML_evaluation_multi.py,sha256=Kn9n5lfxo7A0TvgIDMx8UHZCvzTqv1ViezzwJBF-ypM,15970
- ml_tools/ML_inference.py,sha256=ymFvncFsU10PExq87xnEj541DKV5ck0nMuK8ToJHzVQ,23067
+ ml_tools/ML_evaluation.py,sha256=3u5dOhS77gn3kAshKr2GwSa5xZBF0YM77ZkFevqNPvA,18528
+ ml_tools/ML_evaluation_multi.py,sha256=L6Ub_uObXsI7ToVCF6DtmAFekHRcga5wWMOnRYRR-BY,16121
+ ml_tools/ML_inference.py,sha256=yq2gdN6s_OUYC5ZLQrIJC5BA5H33q8UKODXwb-_0M2c,23549
  ml_tools/ML_models.py,sha256=G64NPhYZfYvHTIUwkIrMrNLgfDTKJwqdc8jwesPqB9E,28090
  ml_tools/ML_optimization.py,sha256=es3TlQbY7RYgJMZnznkjYGbUxFnAqzZxE_g3_qLK9Q8,22960
  ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
  ml_tools/ML_simple_optimization.py,sha256=W2mce1XFCuiOHTOjOsCNbETISHn5MwYlYsTIXH5hMMo,18177
- ml_tools/ML_trainer.py,sha256=UmCuKr_GzQGYqhEZ-kaRv9Buj44DsNyuOzmOM7Fw8N0,24569
+ ml_tools/ML_trainer.py,sha256=9BP6JFClqGfe7GL-FGG3n5e-no9ssjEOLol7P6baGrI,29019
  ml_tools/ML_utilities.py,sha256=EnKpPTnJ2qjZmz7kvows4Uu5CfSA7ByRmI1v2-KarKw,9337
  ml_tools/PSO_optimization.py,sha256=fVHeemqilBS0zrGV25E5yKwDlGdd2ZKa18d8CZ6Q6Fk,22961
  ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
@@ -29,13 +29,13 @@ ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNV
  ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
  ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
  ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
- ml_tools/keys.py,sha256=FDpbS3Jb0pjrVvvp2_8nZi919mbob_-xwuy5OOtKM_A,1848
+ ml_tools/keys.py,sha256=eJ4St5fl8uHstEGO1XVdP8G-ddwjOxV9zqG0D6W8pCI,2124
  ml_tools/math_utilities.py,sha256=PxoOrnuj6Ntp7_TJqyDWi0JX03WpAO5iaFNK2Oeq5I4,8800
  ml_tools/optimization_tools.py,sha256=P074YCuZzkqkONnAsM-Zb9DTX_i8cRkkJLpwAWz6CRw,13521
  ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
  ml_tools/serde.py,sha256=ll2mVC0sO2jIEdG3K6xMcgEN13N4YSb8VjviGvw_ers,4949
  ml_tools/utilities.py,sha256=OcAyV1tEcYAfOWlGjRgopsjDLxU3DcI5EynzvWV4q3A,15754
- dragon_ml_toolbox-12.13.0.dist-info/METADATA,sha256=p3-oOSqq1hhJj13KjIXeFnwBu3UTfBJu5mTDL9MCpdU,6167
- dragon_ml_toolbox-12.13.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-12.13.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-12.13.0.dist-info/RECORD,,
+ dragon_ml_toolbox-13.0.0.dist-info/METADATA,sha256=trY1fFyTTXLS6TZdrJXxq4_YMPjEZhKCilzCg6qFxzw,6166
+ dragon_ml_toolbox-13.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-13.0.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-13.0.0.dist-info/RECORD,,
ml_tools/ML_callbacks.py CHANGED
@@ -5,7 +5,7 @@ from typing import Union, Literal, Optional
  from pathlib import Path

  from .path_manager import make_fullpath, sanitize_filename
- from .keys import PyTorchLogKeys
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
  from ._logger import _LOGGER
  from ._script_info import _script_info

@@ -189,7 +189,7 @@ class EarlyStopping(Callback):

  class ModelCheckpoint(Callback):
  """
- Saves the model weights to a directory with automated filename generation and rotation.
+ Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
  """
  def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
@@ -200,7 +200,7 @@ class ModelCheckpoint(Callback):
  Args:
  save_dir (str): Directory where checkpoint files will be saved.
  checkpoint_name (str| None): If None, the filename will include the epoch and score.
- monitor (str): Metric to monitor for `save_best_only=True`.
+ monitor (str): Metric to monitor.
  save_best_only (bool): If true, save only the best model.
  mode (str): One of {'auto', 'min', 'max'}.
  verbose (int): Verbosity mode.
@@ -270,15 +270,29 @@ class ModelCheckpoint(Callback):
  if self.verbose > 0:
  _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")

+ # Update best score *before* saving
+ self.best = current
+
+ # Create a comprehensive checkpoint dictionary
+ checkpoint_data = {
+ PyTorchCheckpointKeys.EPOCH: epoch,
+ PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
+ PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
+ PyTorchCheckpointKeys.BEST_SCORE: self.best,
+ }
+
+ # Check for scheduler
+ if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
+ checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
+
  # Save the new best model
- torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore
+ torch.save(checkpoint_data, new_filepath)

  # Delete the old best model file
  if self.last_best_filepath and self.last_best_filepath.exists():
  self.last_best_filepath.unlink()

  # Update state
- self.best = current
  self.last_best_filepath = new_filepath

  def _save_rolling_checkpoints(self, epoch, logs):
@@ -292,7 +306,19 @@ class ModelCheckpoint(Callback):

  if self.verbose > 0:
  _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
- torch.save(self.trainer.model.state_dict(), filepath) # type: ignore
+
+ # Create a comprehensive checkpoint dictionary
+ checkpoint_data = {
+ PyTorchCheckpointKeys.EPOCH: epoch,
+ PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
+ PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
+ PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
+ }
+
+ if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
+ checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
+
+ torch.save(checkpoint_data, filepath)

  self.saved_checkpoints.append(filepath)

@@ -309,19 +335,25 @@ class LRScheduler(Callback):
  """
  Callback to manage a PyTorch learning rate scheduler.
  """
- def __init__(self, scheduler, monitor: Optional[str] = None):
+ def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
  """
  This callback automatically calls the scheduler's `step()` method at the
  end of each epoch. It also logs a message when the learning rate changes.

  Args:
  scheduler: An initialized PyTorch learning rate scheduler.
- monitor (str, optional): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
+ monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
  """
  super().__init__()
  self.scheduler = scheduler
  self.monitor = monitor
  self.previous_lr = None
+
+ def set_trainer(self, trainer):
+ """This is called by the Trainer to associate itself with the callback."""
+ super().set_trainer(trainer)
+ # Register the scheduler with the trainer so it can be added to the checkpoint
+ self.trainer.scheduler = self.scheduler # type: ignore

  def on_train_begin(self, logs=None):
  """Store the initial learning rate."""
ml_tools/ML_evaluation.py CHANGED
@@ -19,6 +19,7 @@ import torch
  import shap
  from pathlib import Path
  from typing import Union, Optional, List, Literal
+ import warnings

  from .path_manager import make_fullpath
  from ._logger import _LOGGER
@@ -298,8 +299,11 @@ def shap_summary_plot(model,

  background_data = background_data.to(device)
  instances_to_explain = instances_to_explain.to(device)
-
- explainer = shap.DeepExplainer(model, background_data)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=UserWarning)
+ explainer = shap.DeepExplainer(model, background_data)
+
  # print("Calculating SHAP values with DeepExplainer...")
  shap_values = explainer.shap_values(instances_to_explain)
  instances_to_explain_np = instances_to_explain.cpu().numpy()
ml_tools/ML_evaluation_multi.py CHANGED
@@ -20,6 +20,7 @@ from sklearn.metrics import (
  )
  from pathlib import Path
  from typing import Union, List, Literal
+ import warnings

  from .path_manager import make_fullpath, sanitize_filename
  from ._logger import _LOGGER
@@ -273,9 +274,12 @@ def multi_target_shap_summary_plot(

  background_data = background_data.to(device)
  instances_to_explain = instances_to_explain.to(device)
-
- explainer = shap.DeepExplainer(model, background_data)
- print("Calculating SHAP values with DeepExplainer...")
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", category=UserWarning)
+ explainer = shap.DeepExplainer(model, background_data)
+
+ # print("Calculating SHAP values with DeepExplainer...")
  # DeepExplainer returns a list of arrays for multi-output models
  shap_values_list = explainer.shap_values(instances_to_explain)
  instances_to_explain_np = instances_to_explain.cpu().numpy()
@@ -304,7 +308,7 @@
  return output.cpu().numpy() # Return full multi-output array

  explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
- print("Calculating SHAP values with KernelExplainer...")
+ # print("Calculating SHAP values with KernelExplainer...")
  # KernelExplainer also returns a list of arrays for multi-output models
  shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
  # instances_to_explain_np is already set
ml_tools/ML_inference.py CHANGED
@@ -9,7 +9,7 @@ from .ML_scaler import PytorchScaler
  from ._script_info import _script_info
  from ._logger import _LOGGER
  from .path_manager import make_fullpath
- from .keys import PyTorchInferenceKeys
+ from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys


  __all__ = [
@@ -56,11 +56,21 @@ class _BaseInferenceHandler(ABC):
  model_p = make_fullpath(state_dict, enforce="file")

  try:
- # Load the state dictionary and apply it to the model structure
- self.model.load_state_dict(torch.load(model_p, map_location=self.device))
+ # Load whatever is in the file
+ loaded_data = torch.load(model_p, map_location=self.device)
+
+ # Check if it's the new checkpoint dictionary or an old weights-only file
+ if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
+ # It's a new training checkpoint, extract the weights
+ self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
+ else:
+ # It's an old-style file (or just a state_dict), load it directly
+ self.model.load_state_dict(loaded_data)
+
+ _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
+
  self.model.to(self.device)
  self.model.eval() # Set the model to evaluation mode
- _LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
  except Exception as e:
  _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
  raise
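
Since inference handlers now accept both formats, scripts that read these files outside the toolbox may want the same fallback. A small sketch (the function name and path are placeholders, not part of the package API):

import torch
from ml_tools.keys import PyTorchCheckpointKeys

def load_weights(model: torch.nn.Module, path: str, device: str = "cpu") -> torch.nn.Module:
    """Load either a 13.0.0 training checkpoint or an old bare state_dict into `model`."""
    loaded = torch.load(path, map_location=device)
    if isinstance(loaded, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded:
        # New-style checkpoint written by ModelCheckpoint: extract the weights.
        model.load_state_dict(loaded[PyTorchCheckpointKeys.MODEL_STATE])
    else:
        # Old-style file containing only the state_dict.
        model.load_state_dict(loaded)
    model.to(device)
    model.eval()
    return model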
ml_tools/ML_trainer.py CHANGED
@@ -5,12 +5,13 @@ import torch
  from torch import nn
  import numpy as np

- from .ML_callbacks import Callback, History, TqdmProgressBar
+ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
  from ._script_info import _script_info
- from .keys import PyTorchLogKeys
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
  from ._logger import _LOGGER
+ from .path_manager import make_fullpath


  __all__ = [
@@ -55,6 +56,7 @@ class MLTrainer:
  self.kind = kind
  self.criterion = criterion
  self.optimizer = optimizer
+ self.scheduler = None
  self.device = self._validate_device(device)
  self.dataloader_workers = dataloader_workers

@@ -70,6 +72,7 @@ class MLTrainer:
  self.history = {}
  self.epoch = 0
  self.epochs = 0 # Total epochs for the fit run
+ self.start_epoch = 1
  self.stop_training = False

  def _validate_device(self, device: str) -> torch.device:
@@ -109,8 +112,66 @@ class MLTrainer:
  num_workers=loader_workers,
  pin_memory=("cuda" in self.device.type)
  )
+
+ def _load_checkpoint(self, path: Union[str, Path]):
+ """Loads a training checkpoint to resume training."""
+ p = make_fullpath(path, enforce="file")
+ _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+
+ try:
+ checkpoint = torch.load(p, map_location=self.device)
+
+ if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
+ _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
+ raise KeyError()

- def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
+ self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
+ self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
+ self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
+
+ # --- Scheduler State Loading Logic ---
+ scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
+ scheduler_object_exists = self.scheduler is not None
+
+ if scheduler_object_exists and scheduler_state_exists:
+ # Case 1: Both exist. Attempt to load.
+ try:
+ self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
+ scheduler_name = self.scheduler.__class__.__name__
+ _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+ except Exception as e:
+ # Loading failed, likely a mismatch
+ scheduler_name = self.scheduler.__class__.__name__
+ _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
+ raise e
+
+ elif scheduler_object_exists and not scheduler_state_exists:
+ # Case 2: Scheduler provided, but no state in checkpoint.
+ scheduler_name = self.scheduler.__class__.__name__
+ _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+
+ elif not scheduler_object_exists and scheduler_state_exists:
+ # Case 3: State in checkpoint, but no scheduler provided.
+ _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
+ raise ValueError()
+
+ # Restore callback states
+ for cb in self.callbacks:
+ if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+ cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
+ _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+
+ _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+
+ except Exception as e:
+ _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
+ raise
+
+ def fit(self,
+ epochs: int = 10,
+ batch_size: int = 10,
+ shuffle: bool = True,
+ resume_from_checkpoint: Optional[Union[str, Path]] = None):
  """
  Starts the training-validation process of the model.

@@ -120,6 +181,7 @@ class MLTrainer:
  epochs (int): The total number of epochs to train for.
  batch_size (int): The number of samples per batch.
  shuffle (bool): Whether to shuffle the training data at each epoch.
+ resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.

  Note:
  For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
@@ -132,15 +194,18 @@ class MLTrainer:
  self._create_dataloaders(batch_size, shuffle)
  self.model.to(self.device)

+ if resume_from_checkpoint:
+ self._load_checkpoint(resume_from_checkpoint)
+
  # Reset stop_training flag on the trainer
  self.stop_training = False

- self.callbacks_hook('on_train_begin')
+ self._callbacks_hook('on_train_begin')

- for epoch in range(1, self.epochs + 1):
+ for epoch in range(self.start_epoch, self.epochs + 1):
  self.epoch = epoch
  epoch_logs = {}
- self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
+ self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)

  train_logs = self._train_step()
  epoch_logs.update(train_logs)
@@ -148,13 +213,13 @@ class MLTrainer:
  val_logs = self._validation_step()
  epoch_logs.update(val_logs)

- self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
+ self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)

  # Check the early stopping flag
  if self.stop_training:
  break

- self.callbacks_hook('on_train_end')
+ self._callbacks_hook('on_train_end')
  return self.history

  def _train_step(self):
@@ -166,7 +231,7 @@ class MLTrainer:
  PyTorchLogKeys.BATCH_INDEX: batch_idx,
  PyTorchLogKeys.BATCH_SIZE: features.size(0)
  }
- self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+ self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

  features, target = features.to(self.device), target.to(self.device)
  self.optimizer.zero_grad()
@@ -188,7 +253,7 @@ class MLTrainer:

  # Add the batch loss to the logs and call the end-of-batch hook
  batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
- self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+ self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

  return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore

@@ -538,11 +603,33 @@ class MLTrainer:
  else:
  _LOGGER.error("No attention weights were collected from the model.")

- def callbacks_hook(self, method_name: str, *args, **kwargs):
+ def _callbacks_hook(self, method_name: str, *args, **kwargs):
  """Calls the specified method on all callbacks."""
  for callback in self.callbacks:
  method = getattr(callback, method_name)
  method(*args, **kwargs)
+
+ def to_cpu(self):
+ """
+ Moves the model to the CPU and updates the trainer's device setting.
+
+ This is useful for running operations that require the CPU.
+ """
+ self.device = torch.device('cpu')
+ self.model.to(self.device)
+ _LOGGER.info("Trainer and model moved to CPU.")
+
+ def to_device(self, device: str):
+ """
+ Moves the model to the specified device and updates the trainer's device setting.
+
+ Args:
+ device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+ """
+ self.device = self._validate_device(device)
+ self.model.to(self.device)
+ _LOGGER.info(f"Trainer and model moved to {self.device}.")
+

  def info():
  _script_info(__all__)
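
Putting the trainer changes together, a resumed run could look like the sketch below. It assumes a trainer that was already constructed with a ModelCheckpoint (and optionally an LRScheduler) callback; construction is elided because it is unchanged by this release, and the checkpoint filename is illustrative.

# trainer: an existing MLTrainer with ModelCheckpoint/LRScheduler callbacks attached
history = trainer.fit(
    epochs=100,      # total target epochs; training resumes at the checkpoint's next epoch
    batch_size=32,
    resume_from_checkpoint="checkpoints/epoch_37-0.2210.pth",  # new in 13.0.0
)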
ml_tools/keys.py CHANGED
@@ -68,6 +68,15 @@ class SHAPKeys:
  SAVENAME = "shap_summary"


+ class PyTorchCheckpointKeys:
+ """Keys for saving/loading a training checkpoint dictionary."""
+ MODEL_STATE = "model_state_dict"
+ OPTIMIZER_STATE = "optimizer_state_dict"
+ SCHEDULER_STATE = "scheduler_state_dict"
+ EPOCH = "epoch"
+ BEST_SCORE = "best_score"
+
+
  class _OneHotOtherPlaceholder:
  """Used internally by GUI_tools."""
  OTHER_GUI = "OTHER"