dragon-ml-toolbox 12.13.0__py3-none-any.whl → 13.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

ml_tools/ML_models.py CHANGED
@@ -8,6 +8,7 @@ from ._logger import _LOGGER
  from .path_manager import make_fullpath
  from ._script_info import _script_info
  from .keys import PytorchModelArchitectureKeys
+ from ._schema import FeatureSchema
 
 
  __all__ = [
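
The new `FeatureSchema` import is the pivot of this release: the models and optimizers below consume a single schema object instead of loose index maps. Its definition lives in `ml_tools/_schema.py` and is not part of this diff; the following is only a sketch of the fields the changed code relies on, inferred from how they are used further down (the real class may differ):

    # Hypothetical sketch of FeatureSchema, inferred from usage in this diff.
    from dataclasses import dataclass
    from typing import Dict, Optional, Tuple

    @dataclass(frozen=True)
    class FeatureSchema:
        feature_names: Tuple[str, ...]                        # every input column, in model order
        continuous_feature_names: Tuple[str, ...]             # columns treated as numerical
        categorical_feature_names: Tuple[str, ...]            # columns embedded via nn.Embedding
        categorical_index_map: Optional[Dict[int, int]]       # column index -> cardinality
        categorical_mappings: Optional[Dict[str, Dict[str, int]]]  # name -> {category: code}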
@@ -298,76 +299,59 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
  """
  A Transformer-based model for tabular data tasks.
 
- This model uses a Feature Tokenizer to convert all input features into a sequence of embeddings, prepends a [CLS] token, and processes the
+ This model uses a Feature Tokenizer to convert all input features into a
+ sequence of embeddings, prepends a [CLS] token, and processes the
  sequence with a standard Transformer Encoder.
  """
  def __init__(self, *,
- in_features: int,
+ schema: FeatureSchema,
  out_targets: int,
- categorical_index_map: Dict[int, int],
  embedding_dim: int = 32,
  num_heads: int = 8,
  num_layers: int = 6,
  dropout: float = 0.1):
  """
  Args:
- in_features (int): The total number of columns in the input data (features).
- out_targets (int): Number of output targets (1 for regression).
- categorical_index_map (Dict[int, int]): Maps categorical column index to its cardinality (number of unique categories).
- embedding_dim (int): The dimension for all feature embeddings. Must be divisible by num_heads.
- num_heads (int): The number of heads in the multi-head attention mechanism.
- num_layers (int): The number of sub-encoder-layers in the transformer encoder.
- dropout (float): The dropout value.
-
- Note:
- - All arguments are keyword-only to promote clarity.
- - Column indices start at 0.
-
- ### Data Preparation
- The model requires a specific input format. All columns in the input DataFrame must be numerical, but they are treated differently based on the
- provided index lists.
-
- **Nominal Categorical Features** (e.g., 'City', 'Color'): Should **NOT** be one-hot encoded.
- Instead, convert them to integer codes (label encoding). You must then provide a dictionary mapping their column indices to
- their cardinality (the number of unique categories) via the `categorical_map` parameter.
-
- **Ordinal & Binary Features** (e.g., 'Low/Medium/High', 'True/False'): Should be treated as **numerical**. Map them to numbers that
- represent their state (e.g., `{'Low': 0, 'Medium': 1}` or `{False: 0, True: 1}`). Their column indices should **NOT** be included in the
- `categorical_map` parameter.
-
- **Standard Numerical and Continuous Features** (e.g., 'Age', 'Price'): It is highly recommended to scale them before training.
+ schema (FeatureSchema):
+ The definitive schema object created by `data_exploration.finalize_feature_schema()`.
+ out_targets (int):
+ Number of output targets (1 for regression).
+ embedding_dim (int):
+ The dimension for all feature embeddings. Must be divisible
+ by num_heads.
+ num_heads (int):
+ The number of heads in the multi-head attention mechanism.
+ num_layers (int):
+ The number of sub-encoder-layers in the transformer encoder.
+ dropout (float):
+ The dropout value.
  """
  super().__init__()
 
+ # --- Get info from schema ---
+ in_features = len(schema.feature_names)
+ categorical_index_map = schema.categorical_index_map
+
  # --- Validation ---
- if categorical_index_map and max(categorical_index_map.keys()) >= in_features:
+ if categorical_index_map and (max(categorical_index_map.keys()) >= in_features):
  _LOGGER.error(f"A categorical index ({max(categorical_index_map.keys())}) is out of bounds for the provided input features ({in_features}).")
  raise ValueError()
 
- # --- Derive numerical indices ---
- all_indices = set(range(in_features))
- categorical_indices_set = set(categorical_index_map.keys())
- numerical_indices = sorted(list(all_indices - categorical_indices_set))
-
  # --- Save configuration ---
- self.in_features = in_features
+ self.schema = schema # <-- Save the whole schema
  self.out_targets = out_targets
- self.numerical_indices = numerical_indices
- self.categorical_map = categorical_index_map
  self.embedding_dim = embedding_dim
  self.num_heads = num_heads
  self.num_layers = num_layers
  self.dropout = dropout
 
- # --- 1. Feature Tokenizer ---
+ # --- 1. Feature Tokenizer (now takes the schema) ---
  self.tokenizer = _FeatureTokenizer(
- numerical_indices=numerical_indices,
- categorical_map=categorical_index_map,
+ schema=schema,
  embedding_dim=embedding_dim
  )
 
  # --- 2. CLS Token ---
- # A learnable token that will be prepended to the sequence.
  self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
 
  # --- 3. Transformer Encoder ---
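
With the new signature, `in_features` and the numerical/categorical split are derived from the schema rather than passed in. A minimal construction sketch follows; the feature names, indices, and cardinalities are illustrative only, and in normal use the schema would come from `data_exploration.finalize_feature_schema()` rather than being built by hand:

    # Assumed example: 'city' is the only categorical column, at index 2 with 3 categories.
    schema = FeatureSchema(
        feature_names=("age", "price", "city"),
        continuous_feature_names=("age", "price"),
        categorical_feature_names=("city",),
        categorical_index_map={2: 3},
        categorical_mappings={"city": {"NY": 0, "LA": 1, "SF": 2}},
    )
    model = TabularTransformer(
        schema=schema,
        out_targets=1,       # regression
        embedding_dim=32,    # must be divisible by num_heads
        num_heads=8,
    )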
@@ -416,21 +400,87 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
 
  def get_architecture_config(self) -> Dict[str, Any]:
  """Returns the full configuration of the model."""
+ # Deconstruct schema into a JSON-friendly dict
+ # Tuples are saved as lists
+ schema_dict = {
+ 'feature_names': self.schema.feature_names,
+ 'continuous_feature_names': self.schema.continuous_feature_names,
+ 'categorical_feature_names': self.schema.categorical_feature_names,
+ 'categorical_index_map': self.schema.categorical_index_map,
+ 'categorical_mappings': self.schema.categorical_mappings
+ }
+
  return {
- 'in_features': self.in_features,
+ 'schema_dict': schema_dict,
  'out_targets': self.out_targets,
- 'categorical_map': self.categorical_map,
  'embedding_dim': self.embedding_dim,
  'num_heads': self.num_heads,
  'num_layers': self.num_layers,
  'dropout': self.dropout
  }
+
+ @classmethod
+ def load(cls: type, file_or_dir: Union[str, Path], verbose: bool = True) -> nn.Module:
+ """Loads a model architecture from a JSON file."""
+ user_path = make_fullpath(file_or_dir)
+
+ if user_path.is_dir():
+ json_filename = PytorchModelArchitectureKeys.SAVENAME + ".json"
+ target_path = make_fullpath(user_path / json_filename, enforce="file")
+ elif user_path.is_file():
+ target_path = user_path
+ else:
+ _LOGGER.error(f"Invalid path: '{file_or_dir}'")
+ raise IOError()
+
+ with open(target_path, 'r') as f:
+ saved_data = json.load(f)
+
+ saved_class_name = saved_data[PytorchModelArchitectureKeys.MODEL]
+ config = saved_data[PytorchModelArchitectureKeys.CONFIG]
+
+ if saved_class_name != cls.__name__:
+ _LOGGER.error(f"Model class mismatch. File specifies '{saved_class_name}', but '{cls.__name__}' was expected.")
+ raise ValueError()
+
+ # --- RECONSTRUCTION LOGIC ---
+ if 'schema_dict' not in config:
+ _LOGGER.error("Invalid architecture file: missing 'schema_dict'. This file may be from an older version.")
+ raise ValueError("Missing 'schema_dict' in config.")
+
+ schema_data = config.pop('schema_dict')
+
+ # Re-hydrate the categorical_index_map
+ # JSON saves all dict keys as strings, so we must convert them back to int.
+ raw_index_map = schema_data['categorical_index_map']
+ if raw_index_map is not None:
+ rehydrated_index_map = {int(k): v for k, v in raw_index_map.items()}
+ else:
+ rehydrated_index_map = None
+
+ # Re-hydrate the FeatureSchema object
+ # JSON deserializes tuples as lists, so we must convert them back.
+ schema = FeatureSchema(
+ feature_names=tuple(schema_data['feature_names']),
+ continuous_feature_names=tuple(schema_data['continuous_feature_names']),
+ categorical_feature_names=tuple(schema_data['categorical_feature_names']),
+ categorical_index_map=rehydrated_index_map,
+ categorical_mappings=schema_data['categorical_mappings']
+ )
+
+ config['schema'] = schema
+ # --- End Reconstruction ---
+
+ model = cls(**config)
+ if verbose:
+ _LOGGER.info(f"Successfully loaded architecture for '{saved_class_name}'")
+ return model
 
  def __repr__(self) -> str:
  """Returns the developer-friendly string representation of the model."""
  # Build the architecture string part-by-part
  parts = [
- f"Tokenizer(features={self.in_features}, dim={self.embedding_dim})",
+ f"Tokenizer(features={len(self.schema.feature_names)}, dim={self.embedding_dim})",
  "[CLS]",
  f"TransformerEncoder(layers={self.num_layers}, heads={self.num_heads})",
  f"PredictionHead(outputs={self.out_targets})"
@@ -443,29 +493,41 @@ class TabularTransformer(nn.Module, _ArchitectureHandlerMixin):
 
  class _FeatureTokenizer(nn.Module):
  """
- Transforms raw numerical and categorical features from any column order into a sequence of embeddings.
+ Transforms raw numerical and categorical features from any column order
+ into a sequence of embeddings.
  """
  def __init__(self,
- numerical_indices: List[int],
- categorical_map: Dict[int, int],
+ schema: FeatureSchema,
  embedding_dim: int):
  """
  Args:
- numerical_indices (List[int]): A list of column indices for the numerical features.
- categorical_map (Dict[int, int]): A dictionary mapping each categorical column index to its cardinality (number of unique categories).
- embedding_dim (int): The dimension for all feature embeddings.
+ schema (FeatureSchema):
+ The definitive schema object from data_exploration.
+ embedding_dim (int):
+ The dimension for all feature embeddings.
  """
  super().__init__()
 
- # Unpack the dictionary into separate lists for indices and cardinalities
- self.categorical_indices = list(categorical_map.keys())
- cardinalities = list(categorical_map.values())
+ # --- Get info from schema ---
+ categorical_map = schema.categorical_index_map
+
+ if categorical_map:
+ # Unpack the dictionary into separate lists
+ self.categorical_indices = list(categorical_map.keys())
+ cardinalities = list(categorical_map.values())
+ else:
+ self.categorical_indices = []
+ cardinalities = []
+
+ # Derive numerical indices by finding what's not categorical
+ all_indices = set(range(len(schema.feature_names)))
+ categorical_indices_set = set(self.categorical_indices)
+ self.numerical_indices = sorted(list(all_indices - categorical_indices_set))
 
- self.numerical_indices = numerical_indices
  self.embedding_dim = embedding_dim
 
  # A learnable embedding for each numerical feature
- self.numerical_embeddings = nn.Parameter(torch.randn(len(numerical_indices), embedding_dim))
+ self.numerical_embeddings = nn.Parameter(torch.randn(len(self.numerical_indices), embedding_dim))
 
  # A standard embedding layer for each categorical feature
  self.categorical_embeddings = nn.ModuleList(
@@ -487,6 +549,8 @@ class _FeatureTokenizer(nn.Module):
  # Process categorical features
  categorical_tokens = []
  for i, embed_layer in enumerate(self.categorical_embeddings):
+ # x_categorical[:, i] selects the i-th categorical column
+ # (e.g., all values for the 'color' feature)
  token = embed_layer(x_categorical[:, i]).unsqueeze(1)
  categorical_tokens.append(token)
 
@@ -17,9 +17,10 @@ from ._script_info import _script_info
  from .ML_inference import PyTorchInferenceHandler
  from .keys import PyTorchInferenceKeys
  from .SQL import DatabaseManager
- from .optimization_tools import _save_result
+ from .optimization_tools import _save_result, create_optimization_bounds
  from .utilities import save_dataframe_filename
  from .math_utilities import discretize_categorical_values
+ from ._schema import FeatureSchema
 
 
  __all__ = [
@@ -40,66 +41,76 @@ class MLOptimizer:
  SNES and CEM algorithms do not accept bounds, the given bounds will be used as an initial starting point.
 
  Example:
- >>> # 1. Get categorical info from preprocessing steps
- >>> # e.g., from data_exploration.encode_categorical_features
- >>> cat_mappings = {'feature_C': {'A': 0, 'B': 1}, 'feature_D': {'X': 0, 'Y': 1}}
- >>> # e.g., from data_exploration.create_transformer_categorical_map
- >>> # Assumes feature_C is at index 2 (cardinality 2) and feature_D is at index 3 (cardinality 2)
- >>> cat_index_map = {2: 2, 3: 2}
+ >>> # 1. Get the final schema from data exploration
+ >>> schema = data_exploration.finalize_feature_schema(...)
+ >>> # 2. Define bounds for continuous features
+ >>> cont_bounds = {'feature_A': (0, 100), 'feature_B': (-10, 10)}
  >>>
- >>> # 2. Initialize the optimizer
+ >>> # 3. Initialize the optimizer
  >>> optimizer = MLOptimizer(
  ... inference_handler=my_handler,
- ... bounds=(lower_bounds, upper_bounds), # Bounds for ALL features
+ ... schema=schema,
+ ... continuous_bounds_map=cont_bounds,
  ... task="max",
  ... algorithm="Genetic",
- ... categorical_index_map=cat_index_map,
- ... categorical_mappings=cat_mappings,
  ... )
- >>> # 3. Run the optimization
+ >>> # 4. Run the optimization
  >>> best_result = optimizer.run(
  ... num_generations=100,
  ... target_name="my_target",
- ... feature_names=my_feature_names,
  ... save_dir="/path/to/results",
  ... save_format="csv"
  ... )
  """
  def __init__(self,
  inference_handler: PyTorchInferenceHandler,
- bounds: Tuple[List[float], List[float]],
+ schema: FeatureSchema,
+ continuous_bounds_map: Dict[str, Tuple[float, float]],
  task: Literal["min", "max"],
  algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
  population_size: int = 200,
- categorical_index_map: Optional[Dict[int, int]] = None,
- categorical_mappings: Optional[Dict[str, Dict[str, int]]] = None,
  discretize_start_at_zero: bool = True,
  **searcher_kwargs):
  """
  Initializes the optimizer by creating the EvoTorch problem and searcher.
 
  Args:
- inference_handler (PyTorchInferenceHandler): An initialized inference handler containing the model and weights.
- bounds (tuple[list[float], list[float]]): A tuple containing the lower and upper bounds for ALL solution features.
- Use the `optimization_tools.create_optimization_bounds()` helper to easily generate this and ensure unbiased categorical bounds.
+ inference_handler (PyTorchInferenceHandler):
+ An initialized inference handler containing the model.
+ schema (FeatureSchema):
+ The definitive schema object from data_exploration.
+ continuous_bounds_map (Dict[str, Tuple[float, float]]):
+ A dictionary mapping the *name* of each **continuous** feature
+ to its (min_bound, max_bound) tuple.
  task (str): The optimization goal, either "min" or "max".
  algorithm (str): The search algorithm to use ("SNES", "CEM", "Genetic").
  population_size (int): Population size for CEM and GeneticAlgorithm.
- categorical_index_map (Dict[int, int] | None): Used to discretize values after optimization. Maps {column_index: cardinality}.
- categorical_mappings (Dict[str, Dict[str, int]] | None): Used to map discrete integer values back to strings (e.g., {0: 'Category_A'}) before saving.
  discretize_start_at_zero (bool):
  True if the discrete encoding starts at 0 (e.g., [0, 1, 2]).
  False if it starts at 1 (e.g., [1, 2, 3]).
- **searcher_kwargs: Additional keyword arguments for the selected search algorithm's constructor.
+ **searcher_kwargs: Additional keyword arguments for the selected
+ search algorithm's constructor.
  """
- # Make a fitness function
+ # --- Store schema ---
+ self.schema = schema
+
+ # --- 1. Create bounds from schema ---
+ # This is the new, robust way to get bounds
+ bounds = create_optimization_bounds(
+ schema=schema,
+ continuous_bounds_map=continuous_bounds_map,
+ start_at_zero=discretize_start_at_zero
+ )
+
+ # --- 2. Make a fitness function ---
  self.evaluator = FitnessEvaluator(
  inference_handler=inference_handler,
- categorical_index_map=categorical_index_map,
+ # Get categorical info from the schema
+ categorical_index_map=schema.categorical_index_map,
  discretize_start_at_zero=discretize_start_at_zero
  )
 
- # Call the existing factory function to get the problem and searcher factory
+ # --- 3. Create the problem and searcher factory ---
  self.problem, self.searcher_factory = create_pytorch_problem(
  evaluator=self.evaluator,
  bounds=bounds,
@@ -108,36 +119,36 @@
  population_size=population_size,
  **searcher_kwargs
  )
- # Store categorical info to pass to the run function
- self.categorical_map = categorical_index_map
- self.categorical_mappings = categorical_mappings
+
+ # --- 4. Store other info needed by run() ---
  self.discretize_start_at_zero = discretize_start_at_zero
 
  def run(self,
  num_generations: int,
  target_name: str,
  save_dir: Union[str, Path],
- feature_names: Optional[List[str]],
  save_format: Literal['csv', 'sqlite', 'both'],
  repetitions: int = 1,
  verbose: bool = True) -> Optional[dict]:
  """
  Runs the evolutionary optimization process using the pre-configured settings.
 
+ The `feature_names` are automatically pulled from the `FeatureSchema`
+ provided during initialization.
+
  Args:
  num_generations (int): The total number of generations for each repetition.
  target_name (str): Target name used for the CSV filename and/or SQL table.
  save_dir (str | Path): The directory where result files will be saved.
- feature_names (List[str] | None): Names of the solution features for labeling output.
- If None, generic names like 'feature_0', 'feature_1', ... , will be created.
  save_format (Literal['csv', 'sqlite', 'both']): The format for saving results.
  repetitions (int): The number of independent times to run the optimization.
  verbose (bool): If True, enables detailed logging.
 
  Returns:
- Optional[dict]: A dictionary with the best result if repetitions is 1, otherwise None.
+ Optional[dict]: A dictionary with the best result if repetitions is 1,
+ otherwise None.
  """
- # Call the existing run function with the stored problem, searcher, and categorical info
+ # Call the existing run function, passing info from the schema
  return run_optimization(
  problem=self.problem,
  searcher_factory=self.searcher_factory,
@@ -145,11 +156,13 @@
  target_name=target_name,
  save_dir=save_dir,
  save_format=save_format,
- feature_names=feature_names,
+ # Get the definitive feature names (as a list) from the schema
+ feature_names=list(self.schema.feature_names),
+ # Get categorical info from the schema
+ categorical_map=self.schema.categorical_index_map,
+ categorical_mappings=self.schema.categorical_mappings,
  repetitions=repetitions,
  verbose=verbose,
- categorical_map=self.categorical_map,
- categorical_mappings=self.categorical_mappings,
  discretize_start_at_zero=self.discretize_start_at_zero
  )
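
For callers migrating from 12.x: `bounds`, `categorical_index_map`, and `categorical_mappings` disappear from `__init__`, and `feature_names` disappears from `run()`; everything is now derived from the schema plus a per-feature bounds dict for the continuous columns, with `create_optimization_bounds()` called internally. A condensed sketch, where the handler, schema, feature names, and paths are placeholders:

    optimizer = MLOptimizer(
        inference_handler=my_handler,        # PyTorchInferenceHandler
        schema=schema,                       # FeatureSchema from data_exploration
        continuous_bounds_map={"age": (18, 90), "price": (0.0, 500.0)},
        task="max",
        algorithm="Genetic",
    )
    best = optimizer.run(
        num_generations=100,
        target_name="my_target",
        save_dir="/path/to/results",         # placeholder
        save_format="csv",
    )  # note: feature_names is no longer passed; run() reads them from the schema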
 
ml_tools/ML_trainer.py CHANGED
@@ -5,12 +5,13 @@ import torch
  from torch import nn
  import numpy as np
 
- from .ML_callbacks import Callback, History, TqdmProgressBar
+ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
  from ._script_info import _script_info
- from .keys import PyTorchLogKeys
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
  from ._logger import _LOGGER
+ from .path_manager import make_fullpath
 
 
  __all__ = [
@@ -55,6 +56,7 @@ class MLTrainer:
  self.kind = kind
  self.criterion = criterion
  self.optimizer = optimizer
+ self.scheduler = None
  self.device = self._validate_device(device)
  self.dataloader_workers = dataloader_workers
 
@@ -70,6 +72,7 @@ class MLTrainer:
  self.history = {}
  self.epoch = 0
  self.epochs = 0 # Total epochs for the fit run
+ self.start_epoch = 1
  self.stop_training = False
 
  def _validate_device(self, device: str) -> torch.device:
@@ -109,8 +112,66 @@ class MLTrainer:
  num_workers=loader_workers,
  pin_memory=("cuda" in self.device.type)
  )
+
+ def _load_checkpoint(self, path: Union[str, Path]):
+ """Loads a training checkpoint to resume training."""
+ p = make_fullpath(path, enforce="file")
+ _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+
+ try:
+ checkpoint = torch.load(p, map_location=self.device)
+
+ if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
+ _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
+ raise KeyError()
 
- def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
+ self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
+ self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
+ self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
+
+ # --- Scheduler State Loading Logic ---
+ scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
+ scheduler_object_exists = self.scheduler is not None
+
+ if scheduler_object_exists and scheduler_state_exists:
+ # Case 1: Both exist. Attempt to load.
+ try:
+ self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
+ scheduler_name = self.scheduler.__class__.__name__
+ _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+ except Exception as e:
+ # Loading failed, likely a mismatch
+ scheduler_name = self.scheduler.__class__.__name__
+ _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
+ raise e
+
+ elif scheduler_object_exists and not scheduler_state_exists:
+ # Case 2: Scheduler provided, but no state in checkpoint.
+ scheduler_name = self.scheduler.__class__.__name__
+ _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+
+ elif not scheduler_object_exists and scheduler_state_exists:
+ # Case 3: State in checkpoint, but no scheduler provided.
+ _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
+ raise ValueError()
+
+ # Restore callback states
+ for cb in self.callbacks:
+ if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+ cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
+ _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+
+ _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+
+ except Exception as e:
+ _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
+ raise
+
+ def fit(self,
+ epochs: int = 10,
+ batch_size: int = 10,
+ shuffle: bool = True,
+ resume_from_checkpoint: Optional[Union[str, Path]] = None):
  """
  Starts the training-validation process of the model.
 
@@ -120,6 +181,7 @@
  epochs (int): The total number of epochs to train for.
  batch_size (int): The number of samples per batch.
  shuffle (bool): Whether to shuffle the training data at each epoch.
+ resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
 
  Note:
  For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
@@ -132,15 +194,18 @@
  self._create_dataloaders(batch_size, shuffle)
  self.model.to(self.device)
 
+ if resume_from_checkpoint:
+ self._load_checkpoint(resume_from_checkpoint)
+
  # Reset stop_training flag on the trainer
  self.stop_training = False
 
- self.callbacks_hook('on_train_begin')
+ self._callbacks_hook('on_train_begin')
 
- for epoch in range(1, self.epochs + 1):
+ for epoch in range(self.start_epoch, self.epochs + 1):
  self.epoch = epoch
  epoch_logs = {}
- self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
+ self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
 
  train_logs = self._train_step()
  epoch_logs.update(train_logs)
@@ -148,13 +213,13 @@
  val_logs = self._validation_step()
  epoch_logs.update(val_logs)
 
- self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
+ self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
 
  # Check the early stopping flag
  if self.stop_training:
  break
 
- self.callbacks_hook('on_train_end')
+ self._callbacks_hook('on_train_end')
  return self.history
 
  def _train_step(self):
@@ -166,7 +231,7 @@
  PyTorchLogKeys.BATCH_INDEX: batch_idx,
  PyTorchLogKeys.BATCH_SIZE: features.size(0)
  }
- self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+ self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
 
  features, target = features.to(self.device), target.to(self.device)
  self.optimizer.zero_grad()
@@ -188,7 +253,7 @@
 
  # Add the batch loss to the logs and call the end-of-batch hook
  batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
- self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+ self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
 
  return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
 
@@ -538,11 +603,33 @@
  else:
  _LOGGER.error("No attention weights were collected from the model.")
 
- def callbacks_hook(self, method_name: str, *args, **kwargs):
+ def _callbacks_hook(self, method_name: str, *args, **kwargs):
  """Calls the specified method on all callbacks."""
  for callback in self.callbacks:
  method = getattr(callback, method_name)
  method(*args, **kwargs)
+
+ def to_cpu(self):
+ """
+ Moves the model to the CPU and updates the trainer's device setting.
+
+ This is useful for running operations that require the CPU.
+ """
+ self.device = torch.device('cpu')
+ self.model.to(self.device)
+ _LOGGER.info("Trainer and model moved to CPU.")
+
+ def to_device(self, device: str):
+ """
+ Moves the model to the specified device and updates the trainer's device setting.
+
+ Args:
+ device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+ """
+ self.device = self._validate_device(device)
+ self.model.to(self.device)
+ _LOGGER.info(f"Trainer and model moved to {self.device}.")
+
 
  def info():
  _script_info(__all__)
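
The trainer changes make interrupted runs resumable: `fit()` gained `resume_from_checkpoint`, and `_load_checkpoint()` restores model and optimizer state, the epoch counter, any LR-scheduler state, and the `best` score of a `ModelCheckpoint` callback. A usage sketch; the trainer constructor arguments and the checkpoint filename are placeholders, not taken from this diff:

    trainer = MLTrainer(...)                   # model, datasets, criterion, optimizer, callbacks as before
    trainer.fit(epochs=50, batch_size=32)      # first run; ModelCheckpoint writes checkpoint files

    # Later, continue from where the interrupted run stopped:
    trainer.fit(
        epochs=50,
        batch_size=32,
        resume_from_checkpoint="/path/to/checkpoints/epoch_30.pth",  # placeholder path
    )

    # New helpers for moving the model and the trainer's device setting together:
    trainer.to_cpu()
    trainer.to_device("cuda")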
@@ -17,6 +17,10 @@ from ._script_info import _script_info
  from .SQL import DatabaseManager
  from .optimization_tools import _save_result
 
+ """
+ DEPRECATED
+ """
+
 
  __all__ = [
  "ObjectiveFunction",
@@ -46,7 +50,7 @@ class ObjectiveFunction():
  self.binary_features = binary_features
  self.is_hybrid = False if binary_features <= 0 else True
  self.use_noise = add_noise
- self._artifact = deserialize_object(trained_model_path, verbose=False, raise_on_error=True)
+ self._artifact = deserialize_object(trained_model_path, verbose=False)
  self.model = self._get_from_artifact(EnsembleKeys.MODEL)
  self.feature_names: Optional[list[str]] = self._get_from_artifact(EnsembleKeys.FEATURES) # type: ignore
  self.target_name: Optional[str] = self._get_from_artifact(EnsembleKeys.TARGET) # type: ignore