dragon-ml-toolbox 20.2.0__py3-none-any.whl → 20.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (111)
  1. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/METADATA +1 -1
  2. dragon_ml_toolbox-20.4.0.dist-info/RECORD +143 -0
  3. ml_tools/ETL_cleaning/__init__.py +5 -1
  4. ml_tools/ETL_cleaning/_basic_clean.py +1 -1
  5. ml_tools/ETL_engineering/__init__.py +5 -1
  6. ml_tools/GUI_tools/__init__.py +5 -1
  7. ml_tools/IO_tools/_IO_loggers.py +33 -21
  8. ml_tools/IO_tools/__init__.py +5 -1
  9. ml_tools/MICE/__init__.py +8 -2
  10. ml_tools/MICE/_dragon_mice.py +1 -1
  11. ml_tools/ML_callbacks/__init__.py +5 -1
  12. ml_tools/ML_chain/__init__.py +5 -1
  13. ml_tools/ML_configuration/__init__.py +7 -1
  14. ml_tools/ML_configuration/_training.py +65 -1
  15. ml_tools/ML_datasetmaster/__init__.py +5 -1
  16. ml_tools/ML_datasetmaster/_base_datasetmaster.py +31 -20
  17. ml_tools/ML_datasetmaster/_datasetmaster.py +26 -9
  18. ml_tools/ML_datasetmaster/_sequence_datasetmaster.py +38 -23
  19. ml_tools/ML_evaluation/__init__.py +5 -1
  20. ml_tools/ML_evaluation/_classification.py +10 -2
  21. ml_tools/ML_evaluation_captum/__init__.py +5 -1
  22. ml_tools/ML_finalize_handler/__init__.py +5 -1
  23. ml_tools/ML_inference/__init__.py +5 -1
  24. ml_tools/ML_inference_sequence/__init__.py +5 -1
  25. ml_tools/ML_inference_vision/__init__.py +5 -1
  26. ml_tools/ML_models/__init__.py +21 -6
  27. ml_tools/ML_models/_dragon_autoint.py +302 -0
  28. ml_tools/ML_models/_dragon_gate.py +358 -0
  29. ml_tools/ML_models/_dragon_node.py +268 -0
  30. ml_tools/ML_models/_dragon_tabnet.py +255 -0
  31. ml_tools/ML_models_sequence/__init__.py +5 -1
  32. ml_tools/ML_models_vision/__init__.py +5 -1
  33. ml_tools/ML_optimization/__init__.py +11 -3
  34. ml_tools/ML_optimization/_multi_dragon.py +24 -8
  35. ml_tools/ML_optimization/_single_dragon.py +47 -67
  36. ml_tools/ML_optimization/_single_manual.py +1 -1
  37. ml_tools/ML_scaler/_ML_scaler.py +12 -7
  38. ml_tools/ML_scaler/__init__.py +5 -1
  39. ml_tools/ML_trainer/__init__.py +5 -1
  40. ml_tools/ML_trainer/_base_trainer.py +136 -13
  41. ml_tools/ML_trainer/_dragon_detection_trainer.py +31 -91
  42. ml_tools/ML_trainer/_dragon_sequence_trainer.py +24 -74
  43. ml_tools/ML_trainer/_dragon_trainer.py +24 -85
  44. ml_tools/ML_utilities/__init__.py +5 -1
  45. ml_tools/ML_utilities/_inspection.py +44 -30
  46. ml_tools/ML_vision_transformers/__init__.py +8 -2
  47. ml_tools/PSO_optimization/__init__.py +5 -1
  48. ml_tools/SQL/__init__.py +8 -2
  49. ml_tools/VIF/__init__.py +5 -1
  50. ml_tools/data_exploration/__init__.py +4 -1
  51. ml_tools/data_exploration/_cleaning.py +4 -2
  52. ml_tools/ensemble_evaluation/__init__.py +5 -1
  53. ml_tools/ensemble_inference/__init__.py +5 -1
  54. ml_tools/ensemble_learning/__init__.py +5 -1
  55. ml_tools/excel_handler/__init__.py +5 -1
  56. ml_tools/keys/__init__.py +5 -1
  57. ml_tools/keys/_keys.py +1 -1
  58. ml_tools/math_utilities/__init__.py +5 -1
  59. ml_tools/optimization_tools/__init__.py +5 -1
  60. ml_tools/path_manager/__init__.py +8 -2
  61. ml_tools/plot_fonts/__init__.py +8 -2
  62. ml_tools/schema/__init__.py +8 -2
  63. ml_tools/schema/_feature_schema.py +3 -3
  64. ml_tools/serde/__init__.py +5 -1
  65. ml_tools/utilities/__init__.py +5 -1
  66. ml_tools/utilities/_utility_save_load.py +38 -20
  67. dragon_ml_toolbox-20.2.0.dist-info/RECORD +0 -179
  68. ml_tools/ETL_cleaning/_imprimir.py +0 -13
  69. ml_tools/ETL_engineering/_imprimir.py +0 -24
  70. ml_tools/GUI_tools/_imprimir.py +0 -12
  71. ml_tools/IO_tools/_imprimir.py +0 -14
  72. ml_tools/MICE/_imprimir.py +0 -11
  73. ml_tools/ML_callbacks/_imprimir.py +0 -12
  74. ml_tools/ML_chain/_imprimir.py +0 -12
  75. ml_tools/ML_configuration/_imprimir.py +0 -47
  76. ml_tools/ML_datasetmaster/_imprimir.py +0 -15
  77. ml_tools/ML_evaluation/_imprimir.py +0 -25
  78. ml_tools/ML_evaluation_captum/_imprimir.py +0 -10
  79. ml_tools/ML_finalize_handler/_imprimir.py +0 -8
  80. ml_tools/ML_inference/_imprimir.py +0 -11
  81. ml_tools/ML_inference_sequence/_imprimir.py +0 -8
  82. ml_tools/ML_inference_vision/_imprimir.py +0 -8
  83. ml_tools/ML_models/_advanced_models.py +0 -1086
  84. ml_tools/ML_models/_imprimir.py +0 -18
  85. ml_tools/ML_models_sequence/_imprimir.py +0 -8
  86. ml_tools/ML_models_vision/_imprimir.py +0 -16
  87. ml_tools/ML_optimization/_imprimir.py +0 -13
  88. ml_tools/ML_scaler/_imprimir.py +0 -8
  89. ml_tools/ML_trainer/_imprimir.py +0 -10
  90. ml_tools/ML_utilities/_imprimir.py +0 -16
  91. ml_tools/ML_vision_transformers/_imprimir.py +0 -14
  92. ml_tools/PSO_optimization/_imprimir.py +0 -10
  93. ml_tools/SQL/_imprimir.py +0 -8
  94. ml_tools/VIF/_imprimir.py +0 -10
  95. ml_tools/data_exploration/_imprimir.py +0 -32
  96. ml_tools/ensemble_evaluation/_imprimir.py +0 -14
  97. ml_tools/ensemble_inference/_imprimir.py +0 -9
  98. ml_tools/ensemble_learning/_imprimir.py +0 -10
  99. ml_tools/excel_handler/_imprimir.py +0 -13
  100. ml_tools/keys/_imprimir.py +0 -11
  101. ml_tools/math_utilities/_imprimir.py +0 -11
  102. ml_tools/optimization_tools/_imprimir.py +0 -13
  103. ml_tools/path_manager/_imprimir.py +0 -15
  104. ml_tools/plot_fonts/_imprimir.py +0 -8
  105. ml_tools/schema/_imprimir.py +0 -10
  106. ml_tools/serde/_imprimir.py +0 -10
  107. ml_tools/utilities/_imprimir.py +0 -18
  108. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/WHEEL +0 -0
  109. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/licenses/LICENSE +0 -0
  110. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  111. {dragon_ml_toolbox-20.2.0.dist-info → dragon_ml_toolbox-20.4.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_optimization/_single_dragon.py

@@ -1,9 +1,10 @@
 from typing import Literal, Union, Optional
 from pathlib import Path
 
-from ..optimization_tools import create_optimization_bounds
+from ..optimization_tools import create_optimization_bounds, load_continuous_bounds_template
 from ..ML_inference import DragonInferenceHandler
 from ..schema import FeatureSchema
+from ..ML_configuration import DragonOptimizerConfig
 
 from .._core import get_logger
 from ..keys._keys import MLTaskKeys
@@ -29,35 +30,28 @@ class DragonOptimizer:
     SNES and CEM algorithms do not accept bounds, the given bounds will be used as an initial starting point.
 
     Example:
-    >>> # 1. Define bounds for continuous features
-    >>> cont_bounds = {'feature_A': (0, 100), 'feature_B': (-10, 10)}
+    >>> # 1. Define configuration
+    >>> config = DragonOptimizerConfig(
+    ...     target_name="my_target",
+    ...     task="max",
+    ...     continuous_bounds_map="path/to/bounds",
+    ...     save_directory="/path/to/results",
+    ...     algorithm="Genetic"
+    ... )
     >>>
     >>> # 2. Initialize the optimizer
     >>> optimizer = DragonOptimizer(
     ...     inference_handler=my_handler,
     ...     schema=schema,
-    ...     target_name="my_target",
-    ...     continuous_bounds_map=cont_bounds,
-    ...     task="max",
-    ...     algorithm="Genetic",
+    ...     config=config
     ... )
     >>> # 3. Run the optimization
-    >>> best_result = optimizer.run(
-    ...     num_generations=100,
-    ...     save_dir="/path/to/results",
-    ...     save_format="csv"
-    ... )
+    >>> best_result = optimizer.run()
     """
     def __init__(self,
                  inference_handler: DragonInferenceHandler,
                  schema: FeatureSchema,
-                 target_name: str,
-                 continuous_bounds_map: dict[str, tuple[float, float]],
-                 task: Literal["min", "max"],
-                 algorithm: Literal["SNES", "CEM", "Genetic"] = "Genetic",
-                 population_size: int = 200,
-                 discretize_start_at_zero: bool = True,
-                 **searcher_kwargs):
+                 config: DragonOptimizerConfig):
         """
         Initializes the optimizer by creating the EvoTorch problem and searcher.
 
@@ -65,45 +59,43 @@
             inference_handler (DragonInferenceHandler):
                 An initialized inference handler containing the model.
             schema (FeatureSchema):
-                The definitive schema object from data_exploration.
-            target_name (str):
-                target name to optimize.
-            continuous_bounds_map (Dict[str, Tuple[float, float]]):
-                A dictionary mapping the *name* of each **continuous** feature
-                to its (min_bound, max_bound) tuple.
-            task (str): The optimization goal, either "min" or "max".
-
-            algorithm (str): The search algorithm to use ("SNES", "CEM", "Genetic").
-            population_size (int): Population size for CEM and GeneticAlgorithm.
-            discretize_start_at_zero (bool):
-                True if the discrete encoding starts at 0 (e.g., [0, 1, 2]).
-                False if it starts at 1 (e.g., [1, 2, 3]).
-            **searcher_kwargs: Additional keyword arguments for the selected
-                search algorithm's constructor.
+                The definitive schema object.
+            config (DragonOptimizerConfig):
+                Configuration object containing optimization parameters.
         """
         # --- Store schema ---
         self.schema = schema
         # --- Store inference handler ---
         self.inference_handler = inference_handler
 
+        # --- Store config ---
+        self.config = config
+
         # Ensure only Regression tasks are used
         allowed_tasks = [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]
         if self.inference_handler.task not in allowed_tasks:
             _LOGGER.error(f"DragonOptimizer only supports {allowed_tasks}. Got '{self.inference_handler.task}'.")
-            raise ValueError(f"Invalid Task: {self.inference_handler.task}")
+            raise ValueError()
 
         # --- store target name ---
-        self.target_name = target_name
+        self.target_name = config.target_name
 
         # --- flag to control single vs multi-target ---
         self.is_multi_target = False
 
         # --- 1. Create bounds from schema ---
-        # This is the robust way to get bounds
+        # Handle bounds loading if it's a path
+        raw_bounds_map = config.continuous_bounds_map
+        if isinstance(raw_bounds_map, (str, Path)):
+            continuous_bounds = load_continuous_bounds_template(raw_bounds_map)
+        else:
+            continuous_bounds = raw_bounds_map
+
+        # Robust way to get bounds
         bounds = create_optimization_bounds(
             schema=schema,
-            continuous_bounds_map=continuous_bounds_map,
-            start_at_zero=discretize_start_at_zero
+            continuous_bounds_map=continuous_bounds,
+            start_at_zero=config.discretize_start_at_zero
         )
 
         # Resolve target index if multi-target
@@ -114,26 +106,26 @@
            _LOGGER.error("The provided inference handler does not have 'target_ids' defined.")
            raise ValueError()
 
-        if target_name not in self.inference_handler.target_ids:
-            _LOGGER.error(f"Target name '{target_name}' not found in the inference handler's 'target_ids': {self.inference_handler.target_ids}")
+        if self.target_name not in self.inference_handler.target_ids:
+            _LOGGER.error(f"Target name '{self.target_name}' not found in the inference handler's 'target_ids': {self.inference_handler.target_ids}")
             raise ValueError()
 
         if len(self.inference_handler.target_ids) == 1:
             # Single target regression
             target_index = None
-            _LOGGER.info(f"Optimization locked to single-target model '{target_name}'.")
+            _LOGGER.info(f"Optimization locked to single-target model '{self.target_name}'.")
         else:
             # Multi-target regression (optimizing one specific column)
-            target_index = self.inference_handler.target_ids.index(target_name)
+            target_index = self.inference_handler.target_ids.index(self.target_name)
             self.is_multi_target = True
-            _LOGGER.info(f"Optimization locked to target '{target_name}' (Index {target_index}) in a multi-target model.")
+            _LOGGER.info(f"Optimization locked to target '{self.target_name}' (Index {target_index}) in a multi-target model.")
 
         # --- 2. Make a fitness function ---
         self.evaluator = FitnessEvaluator(
             inference_handler=inference_handler,
             # Get categorical info from the schema
             categorical_index_map=schema.categorical_index_map,
-            discretize_start_at_zero=discretize_start_at_zero,
+            discretize_start_at_zero=config.discretize_start_at_zero,
             target_index=target_index
         )
 
@@ -141,20 +133,13 @@
         self.problem, self.searcher_factory = create_pytorch_problem(
             evaluator=self.evaluator,
             bounds=bounds,
-            task=task,
-            algorithm=algorithm,
-            population_size=population_size,
-            **searcher_kwargs
+            task=config.task, # type: ignore
+            algorithm=config.algorithm, # type: ignore
+            population_size=config.population_size,
+            **config.searcher_kwargs
         )
-
-        # --- 4. Store other info needed by run() ---
-        self.discretize_start_at_zero = discretize_start_at_zero
 
     def run(self,
-            num_generations: int,
-            save_dir: Union[str, Path],
-            save_format: Literal['csv', 'sqlite', 'both'],
-            repetitions: int = 1,
             verbose: bool = True) -> Optional[dict]:
         """
         Runs the evolutionary optimization process using the pre-configured settings.
@@ -163,15 +148,10 @@
         provided during initialization.
 
         Args:
-            num_generations (int): The total number of generations for each repetition.
-            save_dir (str | Path): The directory where result files will be saved.
-            save_format (Literal['csv', 'sqlite', 'both']): The format for saving results.
-            repetitions (int): The number of independent times to run the optimization.
             verbose (bool): If True, enables detailed logging.
 
         Returns:
-            Optional[dict]: A dictionary with the best result if repetitions is 1,
-                otherwise None.
+            Optional[dict]: A dictionary with the best result if repetitions is 1, otherwise None.
         """
         # Pass inference handler and target names for multi-target only
         if self.is_multi_target:
@@ -185,18 +165,18 @@
         return run_optimization(
             problem=self.problem,
             searcher_factory=self.searcher_factory,
-            num_generations=num_generations,
+            num_generations=self.config.generations,
             target_name=self.target_name,
-            save_dir=save_dir,
-            save_format=save_format,
+            save_dir=self.config.save_directory,
+            save_format=self.config.save_format, # type: ignore
             # Get the definitive feature names (as a list) from the schema
             feature_names=list(self.schema.feature_names),
             # Get categorical info from the schema
             categorical_map=self.schema.categorical_index_map,
             categorical_mappings=self.schema.categorical_mappings,
-            repetitions=repetitions,
+            repetitions=self.config.repetitions,
             verbose=verbose,
-            discretize_start_at_zero=self.discretize_start_at_zero,
+            discretize_start_at_zero=self.config.discretize_start_at_zero,
             all_target_names=target_names_to_pass,
             inference_handler=inference_handler_to_pass
         )
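
The net effect of the hunks above: every optimizer parameter moves off __init__/run() and onto a single DragonOptimizerConfig, and continuous_bounds_map may now also be a path to a saved bounds template (loaded via load_continuous_bounds_template). A minimal migration sketch based only on these hunks; my_handler and schema are placeholders for an existing DragonInferenceHandler and FeatureSchema, and the import paths are assumptions from the package layout:

    # Migration sketch (20.2.0 -> 20.4.0). Placeholders: my_handler, schema.
    from ml_tools.ML_configuration import DragonOptimizerConfig
    from ml_tools.ML_optimization import DragonOptimizer

    # 20.2.0 (removed): parameters went to __init__ and run() directly.
    # optimizer = DragonOptimizer(inference_handler=my_handler, schema=schema,
    #                             target_name="my_target",
    #                             continuous_bounds_map={"feature_A": (0, 100)},
    #                             task="max", algorithm="Genetic")
    # optimizer.run(num_generations=100, save_dir="results", save_format="csv")

    # 20.4.0: everything lives on the config object; a dict or a saved
    # bounds-template path are both accepted for continuous_bounds_map.
    config = DragonOptimizerConfig(
        target_name="my_target",
        task="max",
        continuous_bounds_map={"feature_A": (0, 100), "feature_B": (-10, 10)},
        save_directory="results",
        algorithm="Genetic",
    )
    optimizer = DragonOptimizer(inference_handler=my_handler, schema=schema, config=config)
    best_result = optimizer.run()  # generations, repetitions, save_format come from config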
ml_tools/ML_optimization/_single_manual.py

@@ -506,5 +506,5 @@ def _save_result(
 
 def _handle_pandas_log(logger: PandasLogger, save_path: Path, target_name: str):
     log_dataframe = logger.to_dataframe()
-    save_dataframe_filename(df=log_dataframe, save_dir=save_path / "EvolutionLogs", filename=target_name)
+    save_dataframe_filename(df=log_dataframe, save_dir=save_path / "EvolutionLogs", filename=target_name, verbose=2)
 
ml_tools/ML_scaler/_ML_scaler.py

@@ -33,7 +33,7 @@ class DragonScaler:
         self.continuous_feature_indices = continuous_feature_indices
 
     @classmethod
-    def fit(cls, dataset: Dataset, continuous_feature_indices: list[int], batch_size: int = 64) -> 'DragonScaler':
+    def fit(cls, dataset: Dataset, continuous_feature_indices: list[int], batch_size: int = 64, verbose: int = 3) -> 'DragonScaler':
         """
         Fits the scaler using a PyTorch Dataset (Method A) using Batched Welford's Algorithm.
         """
@@ -85,23 +85,25 @@ class DragonScaler:
             n_total = new_n_total
 
         if n_total == 0:
-            _LOGGER.error("Dataset is empty. Scaler cannot be fitted.")
-            return cls(continuous_feature_indices=continuous_feature_indices)
+            _LOGGER.error("Dataset is empty. Scaler cannot be fitted.")
+            return cls(continuous_feature_indices=continuous_feature_indices)
 
         # Finalize Standard Deviation
         # Unbiased estimator (divide by n-1)
         if n_total < 2:
-            _LOGGER.warning(f"Only one sample found. Standard deviation set to 1.")
+            if verbose >= 1:
+                _LOGGER.warning(f"Only one sample found. Standard deviation set to 1.")
             std = torch.ones_like(mean_global) # type: ignore
         else:
             variance = m2_global / (n_total - 1)
             std = torch.sqrt(torch.clamp(variance, min=1e-8))
-
+
+        if verbose >= 2:
+            _LOGGER.info(f"Scaler fitted on {n_total} samples for {num_continuous_features} features (Welford's).")
         return cls(mean=mean_global, std=std, continuous_feature_indices=continuous_feature_indices)
 
     @classmethod
-    def fit_tensor(cls, data: torch.Tensor) -> 'DragonScaler':
+    def fit_tensor(cls, data: torch.Tensor, verbose: int = 3) -> 'DragonScaler':
         """
         Fits the scaler directly on a Tensor (Method B).
         Useful for targets or small datasets already in memory.
@@ -118,6 +120,9 @@ class DragonScaler:
         # Handle constant values (std=0) to prevent division by zero
         std = torch.where(std == 0, torch.tensor(1.0, device=data.device), std)
 
+        if verbose >= 2:
+            _LOGGER.info(f"Scaler fitted on tensor with {data.shape[0]} samples for {num_features} features.")
+
         return cls(mean=mean, std=std, continuous_feature_indices=indices)
 
     def transform(self, data: torch.Tensor) -> torch.Tensor:
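
Both fitting paths gain a verbose level: judging from the guards above, warnings survive at verbose >= 1 and the fit summary at verbose >= 2, with 3 as the default. A short sketch; my_dataset is a placeholder PyTorch Dataset whose continuous features sit at column indices 0 and 2:

    import torch
    from ml_tools.ML_scaler import DragonScaler

    # Placeholder dataset; continuous features at columns 0 and 2.
    scaler = DragonScaler.fit(my_dataset, continuous_feature_indices=[0, 2],
                              batch_size=64, verbose=3)   # default: full logging
    targets = torch.randn(128, 1)
    target_scaler = DragonScaler.fit_tensor(targets, verbose=0)  # silences the fit summary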
ml_tools/ML_scaler/__init__.py

@@ -2,9 +2,13 @@ from ._ML_scaler import (
     DragonScaler
 )
 
-from ._imprimir import info
+from .._core import _imprimir_disponibles
 
 
 __all__ = [
     "DragonScaler"
 ]
+
+
+def info():
+    _imprimir_disponibles(__all__)
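
This same edit repeats across every subpackage (hence the ~40 deleted _imprimir.py files in the list above): each __init__ now defines info() locally on top of the shared _core._imprimir_disponibles helper, so the public surface is unchanged. Assuming info() still prints the subpackage's exported names:

    from ml_tools.ML_scaler import info

    info()  # prints the entries of ml_tools.ML_scaler.__all__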
ml_tools/ML_trainer/__init__.py

@@ -10,7 +10,7 @@ from ._dragon_detection_trainer import (
     DragonDetectionTrainer
 )
 
-from ._imprimir import info
+from .._core import _imprimir_disponibles
 
 
 __all__ = [
@@ -18,3 +18,7 @@ __all__ = [
     "DragonSequenceTrainer",
     "DragonDetectionTrainer",
 ]
+
+
+def info():
+    _imprimir_disponibles(__all__)
ml_tools/ML_trainer/_base_trainer.py

@@ -1,6 +1,6 @@
 from typing import Literal, Union, Optional, Any
 from pathlib import Path
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 import torch
 from torch import nn
 from abc import ABC, abstractmethod
@@ -10,6 +10,7 @@ from ..ML_callbacks._checkpoint import DragonModelCheckpoint
 from ..ML_callbacks._early_stop import _DragonEarlyStopping
 from ..ML_callbacks._scheduler import _DragonLRScheduler
 from ..ML_evaluation import plot_losses
+from ..ML_utilities import inspect_pth_file
 
 from ..path_manager import make_fullpath
 from ..keys._keys import PyTorchCheckpointKeys, MagicWords
@@ -89,11 +90,128 @@ class _BaseDragonTrainer(ABC):
         """Gives each callback a reference to this trainer instance."""
         for callback in self.callbacks:
             callback.set_trainer(self)
+
+    def _make_dataloaders(self,
+                          train_dataset: Any,
+                          validation_dataset: Any,
+                          batch_size: int,
+                          shuffle: bool,
+                          collate_fn: Optional[Any] = None):
+        """
+        Shared logic to initialize standard DataLoaders.
+        Subclasses can call this inside their _create_dataloaders implementation.
+        """
+        # Ensure stability on MPS devices by setting num_workers to 0
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+        pin_memory = ("cuda" in self.device.type)
+
+        self.train_loader = DataLoader(
+            dataset=train_dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            num_workers=loader_workers,
+            pin_memory=pin_memory,
+            drop_last=True,
+            collate_fn=collate_fn
+        )
+
+        self.validation_loader = DataLoader(
+            dataset=validation_dataset,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=loader_workers,
+            pin_memory=pin_memory,
+            collate_fn=collate_fn
+        )
+
+    def _validate_checkpoint_arg(self, model_checkpoint: Union[Path, str]) -> Union[Path, str]:
+        """Validates the model_checkpoint argument."""
+        if isinstance(model_checkpoint, Path):
+            return make_fullpath(model_checkpoint, enforce="file")
+        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
+            return model_checkpoint
+        else:
+            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
+            raise ValueError()
+
+    def _validate_save_dir(self, save_dir: Union[str, Path]) -> Path:
+        """Validates and creates the save directory."""
+        return make_fullpath(save_dir, make=True, enforce="directory")
+
+    def _prepare_eval_data(self,
+                           data: Optional[Union[DataLoader, Dataset]],
+                           default_dataset: Optional[Dataset],
+                           collate_fn: Optional[Any] = None) -> tuple[DataLoader, Any]:
+        """
+        Prepares the DataLoader and dataset artifact source for evaluation.
+
+        Returns:
+            (eval_loader, dataset_for_artifacts)
+        """
+        eval_loader = None
+        dataset_for_artifacts = None
+
+        # Loader workers config
+        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+        pin_memory = (self.device.type == "cuda")
 
-    def _load_checkpoint(self, path: Union[str, Path]):
+        if isinstance(data, DataLoader):
+            eval_loader = data
+            if hasattr(data, 'dataset'):
+                dataset_for_artifacts = data.dataset
+        elif isinstance(data, Dataset):
+            eval_loader = DataLoader(data,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=loader_workers,
+                                     pin_memory=pin_memory,
+                                     collate_fn=collate_fn)
+            dataset_for_artifacts = data
+        else: # data is None
+            if default_dataset is None:
+                _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
+                raise ValueError()
+
+            eval_loader = DataLoader(default_dataset,
+                                     batch_size=self._batch_size,
+                                     shuffle=False,
+                                     num_workers=loader_workers,
+                                     pin_memory=pin_memory,
+                                     collate_fn=collate_fn)
+            dataset_for_artifacts = default_dataset
+
+        if eval_loader is None:
+            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+            raise ValueError()
+
+        return eval_loader, dataset_for_artifacts
+
+    def _save_finalized_artifact(self,
+                                 finalized_data: dict,
+                                 save_dir: Union[str, Path],
+                                 filename: str):
+        """
+        Handles the common logic for saving the finalized model dictionary to disk.
+        """
+        # handle save path
+        dir_path = self._validate_save_dir(save_dir)
+        full_path = dir_path / filename
+
+        # checkpoint loading happens before dict creation.
+
+        torch.save(finalized_data, full_path)
+
+        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
+        if full_path.is_file():
+            inspect_pth_file(pth_path=full_path, save_dir=dir_path, verbose=2)
+
+    def _load_checkpoint(self, path: Union[str, Path], verbose: int = 3):
         """Loads a training checkpoint to resume training."""
         p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
+
+        if verbose >= 2:
+            _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
 
         try:
             checkpoint = torch.load(p, map_location=self.device)
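
These helpers hoist DataLoader construction and evaluation-data handling out of the concrete trainers (note the matching deletions in _dragon_trainer.py, _dragon_sequence_trainer.py and _dragon_detection_trainer.py above). A hypothetical subclass sketch; the train_dataset/validation_dataset attribute names are assumptions, not taken from this diff:

    # Hypothetical subclass; attribute names below are illustrative only.
    class MyTrainer(_BaseDragonTrainer):
        def _create_dataloaders(self, batch_size: int, shuffle: bool):
            # One call now covers MPS worker quirks, pin_memory and drop_last.
            self._make_dataloaders(self.train_dataset, self.validation_dataset,
                                   batch_size=batch_size, shuffle=shuffle)

        def evaluate(self, data=None):
            # Accepts a DataLoader, a bare Dataset, or None (falls back to the
            # default dataset) and returns a ready loader plus the dataset used
            # for artifact extraction.
            eval_loader, ds = self._prepare_eval_data(data, self.validation_dataset)
            ...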
@@ -110,9 +228,11 @@
             # --- Load History ---
             if PyTorchCheckpointKeys.HISTORY in checkpoint:
                 self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
-                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
+                if verbose >= 3:
+                    _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
             else:
-                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
+                if verbose >= 1:
+                    _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
                 self.history = {} # Ensure it's at least an empty dict
 
             # --- Scheduler State Loading Logic ---
@@ -124,7 +244,8 @@
                 try:
                     self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
                     scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+                    if verbose >= 3:
+                        _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
                 except Exception as e:
                     # Loading failed, likely a mismatch
                     scheduler_name = self.scheduler.__class__.__name__
@@ -134,7 +255,8 @@
             elif scheduler_object_exists and not scheduler_state_exists:
                 # Case 2: Scheduler provided, but no state in checkpoint.
                 scheduler_name = self.scheduler.__class__.__name__
-                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+                if verbose >= 1:
+                    _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
 
             elif not scheduler_object_exists and scheduler_state_exists:
                 # Case 3: State in checkpoint, but no scheduler provided.
@@ -145,9 +267,11 @@
             for cb in self.callbacks:
                 if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
                     cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
-                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+                    if verbose >= 3:
+                        _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
 
-            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+            if verbose >= 2:
+                _LOGGER.info(f"Model restored to epoch {self.epoch}.")
 
         except Exception as e:
             _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
@@ -243,16 +367,15 @@
         self.model.to(self.device)
         _LOGGER.info(f"Trainer and model moved to {self.device}.")
 
-    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
+    def _load_model_state_wrapper(self, model_checkpoint: Union[Path, Literal['best', 'current']], verbose: int = 2):
         """
         Private helper to load the correct model state_dict based on user's choice.
-        This is called by finalize_model_training() in subclasses.
         """
         if isinstance(model_checkpoint, Path):
-            self._load_checkpoint(path=model_checkpoint)
+            self._load_checkpoint(path=model_checkpoint, verbose=verbose)
         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
             path_to_latest = self._checkpoint_callback.best_checkpoint_path
-            self._load_checkpoint(path_to_latest)
+            self._load_checkpoint(path_to_latest, verbose=verbose)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
            raise ValueError()
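
Read together, the verbose integer threaded through _load_checkpoint and _load_model_state_wrapper acts as a threshold (an inference from the guards above, not a documented contract): warnings at verbose >= 1, headline messages at >= 2, per-component detail at >= 3. For example, with trainer as a placeholder instance of a concrete trainer:

    from pathlib import Path

    # Inferred levels: 1 -> warnings only, 2 -> headline info, 3 -> full detail.
    trainer._load_checkpoint(path="checkpoints/epoch_12.pth", verbose=1)
    trainer._load_model_state_wrapper(model_checkpoint=Path("best.pth"))  # defaults to verbose=2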