dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
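
Most of this release is a structural refactor rather than new behavior: each flat top-level module (e.g. ml_tools/ETL_cleaning.py) becomes a package whose __init__.py re-exports the public API, the implementation moves out of the old ml_tools/_core/ monoliths into per-feature private modules (plus an _imprimir.py per package), and the corresponding 19.14.0 shims and _core files are deleted. The small line deltas on the renamed __init__.py files suggest the public import paths survive the move; a sketch of the intended effect (hedged: only DragonSequenceTrainer is confirmed by the diff below, and its re-export through ML_trainer/__init__.py is an assumption implied by the refactor pattern, not verified against the release):

# 19.14.0: thin shim ml_tools/ML_trainer.py backed by ml_tools/_core/_ML_trainer.py
# 20.0.0: ML_trainer is a package; its __init__.py presumably re-exports from the
# private modules, e.g. ml_tools/ML_trainer/_dragon_sequence_trainer.py (shown below)
from ml_tools.ML_trainer import DragonSequenceTrainer
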
ml_tools/ML_trainer/_dragon_sequence_trainer.py ADDED
@@ -0,0 +1,540 @@
+ from typing import Literal, Union, Optional
+ from pathlib import Path
+ from torch.utils.data import DataLoader, Dataset
+ import torch
+ from torch import nn
+ import numpy as np
+
+ from ..ML_callbacks._base import _Callback
+ from ..ML_callbacks._checkpoint import DragonModelCheckpoint
+ from ..ML_callbacks._early_stop import _DragonEarlyStopping
+ from ..ML_callbacks._scheduler import _DragonLRScheduler
+ from ..ML_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
+ from ..ML_evaluation_captum import captum_feature_importance
+ from ..ML_configuration import (FormatSequenceValueMetrics,
+                                 FormatSequenceSequenceMetrics,
+
+                                 FinalizeSequenceSequencePrediction,
+                                 FinalizeSequenceValuePrediction)
+
+ from ..path_manager import make_fullpath
+ from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys, ScalerKeys
+ from .._core import get_logger
+
+ from ._base_trainer import _BaseDragonTrainer
+
+
+ _LOGGER = get_logger("DragonSequenceTrainer")
+
+
+ __all__ = [
+     "DragonSequenceTrainer"
+ ]
+
+
+ # --- DragonSequenceTrainer ---
+ class DragonSequenceTrainer(_BaseDragonTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  train_dataset: Dataset,
+                  validation_dataset: Dataset,
+                  kind: Literal["sequence-to-sequence", "sequence-to-value"],
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                  checkpoint_callback: Optional[DragonModelCheckpoint],
+                  early_stopping_callback: Optional[_DragonEarlyStopping],
+                  lr_scheduler_callback: Optional[_DragonLRScheduler],
+                  extra_callbacks: Optional[list[_Callback]] = None,
+                  criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                  dataloader_workers: int = 2):
+         """
+         Automates the training process of a PyTorch Sequence Model.
+
+         Built-in Callbacks: `History`, `TqdmProgressBar`
+
+         Args:
+             model (nn.Module): The PyTorch model to train.
+             train_dataset (Dataset): The training dataset.
+             validation_dataset (Dataset): The validation dataset.
+             kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
+             criterion (nn.Module | "auto"): The loss function to use. If "auto", it is inferred from the selected task.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading.
+             extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
+         """
+         # Call the base class constructor with common parameters
+         super().__init__(
+             model=model,
+             optimizer=optimizer,
+             device=device,
+             dataloader_workers=dataloader_workers,
+             checkpoint_callback=checkpoint_callback,
+             early_stopping_callback=early_stopping_callback,
+             lr_scheduler_callback=lr_scheduler_callback,
+             extra_callbacks=extra_callbacks
+         )
+
+         if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
+             raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
+
+         self.train_dataset = train_dataset
+         self.validation_dataset = validation_dataset
+         self.kind = kind
+
+         # try to validate against Dragon Sequence model
+         if hasattr(self.model, "prediction_mode"):
+             key_to_check: str = self.model.prediction_mode  # type: ignore
+             if key_to_check != self.kind:
+                 _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
+                 raise RuntimeError()
+
+         # loss function
+         if criterion == "auto":
+             # Both sequence tasks are treated as regression problems
+             self.criterion = nn.MSELoss()
+         else:
+             self.criterion = criterion
+
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             drop_last=True  # Drops the last batch if incomplete; selecting a good batch size is key.
+         )
+
+         self.validation_loader = DataLoader(
+             dataset=self.validation_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type)
+         )
+
+     def _train_step(self):
+         self.model.train()
+         running_loss = 0.0
+         total_samples = 0
+
+         for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
+             # Create a log dictionary for the batch
+             batch_logs = {
+                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                 PyTorchLogKeys.BATCH_SIZE: features.size(0)
+             }
+             self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+             features, target = features.to(self.device), target.to(self.device)
+             self.optimizer.zero_grad()
+
+             output = self.model(features)
+
+             # --- Label Type/Shape Correction ---
+             # Ensure target is float for MSELoss
+             target = target.float()
+
+             # For seq-to-val, models might output [N, 1] but target is [N].
+             if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                 if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                     output = output.squeeze(1)
+
+             # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+             elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                 if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                     output = output.squeeze(-1)
+
+             loss = self.criterion(output, target)
+
+             loss.backward()
+             self.optimizer.step()
+
+             # Calculate batch loss and update running loss for the epoch
+             batch_loss = loss.item()
+             batch_size = features.size(0)
+             running_loss += batch_loss * batch_size  # Accumulate total loss
+             total_samples += batch_size  # total samples
+
+             # Add the batch loss to the logs and call the end-of-batch hook
+             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
+
+     def _validation_step(self):
+         self.model.eval()
+         running_loss = 0.0
+
+         with torch.no_grad():
+             for features, target in self.validation_loader:  # type: ignore
+                 features, target = features.to(self.device), target.to(self.device)
+
+                 output = self.model(features)
+
+                 # --- Label Type/Shape Correction ---
+                 target = target.float()
+
+                 # For seq-to-val, models might output [N, 1] but target is [N].
+                 if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+                     if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                         output = output.squeeze(1)
+
+                 # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
+                 elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+                     if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
+                         output = output.squeeze(-1)
+
+                 loss = self.criterion(output, target)
+
+                 running_loss += loss.item() * features.size(0)
+
+         if not self.validation_loader.dataset:  # type: ignore
+             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+             return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
+         return logs
+
+     def _predict_for_eval(self, dataloader: DataLoader):
+         """
+         Private method to yield model predictions batch by batch for evaluation.
+
+         Automatically checks for 'scaler'.
+
+         Yields:
+             tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                 y_prob_batch is always None for sequence tasks.
+         """
+         self.model.eval()
+         self.model.to(self.device)
+
+         # --- Check for Scaler ---
+         # DragonDatasetSequence stores it as 'scaler'
+         scaler = None
+         if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
+             scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
+         if scaler is not None:
+             _LOGGER.debug("Sequence scaler detected. Un-scaling predictions and targets.")
+
+         with torch.no_grad():
+             for features, target in dataloader:
+                 features = features.to(self.device)
+                 target = target.to(self.device)
+
+                 output = self.model(features)
+
+                 # --- Automatic Un-scaling Logic ---
+                 if scaler:
+                     # 1. Reshape for scaler (N, 1) or (N*Seq, 1)
+                     original_out_shape = output.shape
+                     original_target_shape = target.shape
+
+                     # Flatten sequence dims
+                     output_flat = output.reshape(-1, 1)
+                     target_flat = target.reshape(-1, 1)
+
+                     # 2. Inverse Transform
+                     output_flat = scaler.inverse_transform(output_flat)
+                     target_flat = scaler.inverse_transform(target_flat)
+
+                     # 3. Restore
+                     output = output_flat.reshape(original_out_shape)
+                     target = target_flat.reshape(original_target_shape)
+
+                 # Move to CPU
+                 y_pred_batch = output.cpu().numpy()
+                 y_true_batch = target.cpu().numpy()
+                 y_prob_batch = None
+
+                 yield y_pred_batch, y_prob_batch, y_true_batch
+
+     def evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["best", "current"]],
+                  test_data: Optional[Union[DataLoader, Dataset]] = None,
+                  val_format_configuration: Optional[Union[FormatSequenceValueMetrics,
+                                                           FormatSequenceSequenceMetrics]] = None,
+                  test_format_configuration: Optional[Union[FormatSequenceValueMetrics,
+                                                            FormatSequenceSequenceMetrics]] = None):
+         """
+         Evaluates the model, routing to the correct evaluation function.
+
+         Args:
+             model_checkpoint (Path | "best" | "current"):
+                 - Path to a valid checkpoint for the model.
+                 - If 'best', the best checkpoint will be loaded.
+                 - If 'current', use the current state of the trained model.
+             save_dir (str | Path): Directory to save all reports and plots.
+             test_data (DataLoader | Dataset | None): Optional Test data.
+             val_format_configuration: Optional configuration for validation metrics.
+             test_format_configuration: Optional configuration for test metrics.
+         """
+         # Validate model checkpoint
+         if isinstance(model_checkpoint, Path):
+             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+         elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
+             checkpoint_validated = model_checkpoint
+         else:
+             _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
+             raise ValueError()
+
+         # Validate val configuration
+         if val_format_configuration is not None:
+             if not isinstance(val_format_configuration, (FormatSequenceValueMetrics, FormatSequenceSequenceMetrics)):
+                 _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                 raise ValueError()
+
+         # Validate directory
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+         # Validate test data and dispatch
+         if test_data is not None:
+             if not isinstance(test_data, (DataLoader, Dataset)):
+                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                 raise ValueError()
+             test_data_validated = test_data
+
+             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+             self._evaluate(save_dir=validation_metrics_path,
+                            model_checkpoint=checkpoint_validated,
+                            data=None,
+                            format_configuration=val_format_configuration)
+
+             # Validate test configuration
+             test_configuration_validated = None
+             if test_format_configuration is not None:
+                 if not isinstance(test_format_configuration, (FormatSequenceValueMetrics, FormatSequenceSequenceMetrics)):
+                     warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                     if val_format_configuration is not None:
+                         warning_message_type += " 'val_format_configuration' will be used."
+                         test_configuration_validated = val_format_configuration
+                     else:
+                         warning_message_type += " Using default format."
+                     _LOGGER.warning(warning_message_type)
+                 else:
+                     test_configuration_validated = test_format_configuration
+
+             # Dispatch test set
+             _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+             self._evaluate(save_dir=test_metrics_path,
+                            model_checkpoint="current",
+                            data=test_data_validated,
+                            format_configuration=test_configuration_validated)
+         else:
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+             self._evaluate(save_dir=save_path,
+                            model_checkpoint=checkpoint_validated,
+                            data=None,
+                            format_configuration=val_format_configuration)
+
+     def _evaluate(self,
+                   save_dir: Union[str, Path],
+                   model_checkpoint: Union[Path, Literal["best", "current"]],
+                   data: Optional[Union[DataLoader, Dataset]],
+                   format_configuration: object):
+         """
+         Private evaluation helper.
+         """
+         eval_loader = None
+
+         # load model checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
+             raise ValueError()
+
+         # Dataloader
+         if isinstance(data, DataLoader):
+             eval_loader = data
+         elif isinstance(data, Dataset):
+             # Create a new loader from the provided dataset
+             eval_loader = DataLoader(data,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"))
+         else:  # data is None, use the trainer's default validation dataset
+             if self.validation_dataset is None:
+                 _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
+                 raise ValueError()
+             eval_loader = DataLoader(self.validation_dataset,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"))
+
+         if eval_loader is None:
+             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+             raise ValueError()
+
+         all_preds, _, all_true = [], [], []
+         for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+             if y_pred_b is not None: all_preds.append(y_pred_b)
+             if y_true_b is not None: all_true.append(y_true_b)
+
+         if not all_true:
+             _LOGGER.error("Evaluation failed: No data was processed.")
+             return
+
+         y_pred = np.concatenate(all_preds)
+         y_true = np.concatenate(all_true)
+
+         # --- Routing Logic ---
+         if self.kind == MLTaskKeys.SEQUENCE_VALUE:
+             config = None
+             if format_configuration and isinstance(format_configuration, FormatSequenceValueMetrics):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected FormatSequenceValueMetrics.")
+
+             sequence_to_value_metrics(y_true=y_true,
+                                       y_pred=y_pred,
+                                       save_dir=save_dir,
+                                       config=config)
+
+         elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
+             config = None
+             if format_configuration and isinstance(format_configuration, FormatSequenceSequenceMetrics):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected FormatSequenceSequenceMetrics.")
+
+             sequence_to_sequence_metrics(y_true=y_true,
+                                          y_pred=y_pred,
+                                          save_dir=save_dir,
+                                          config=config)
+
+     def explain_captum(self,
+                        save_dir: Union[str, Path],
+                        explain_dataset: Optional[Dataset] = None,
+                        n_samples: int = 100,
+                        feature_names: Optional[list[str]] = None,
+                        target_names: Optional[list[str]] = None,
+                        n_steps: int = 50):
+         """
+         Explains sequence model predictions using Captum's Integrated Gradients.
+
+         This method calculates global feature importance by aggregating attributions across
+         the time dimension.
+         - For **multivariate** sequences, it highlights which variables (channels) are most influential.
+         - For **univariate** sequences, it attributes importance to the single signal feature.
+
+         Args:
+             save_dir (str | Path): Directory to save the importance plots and CSV reports.
+             explain_dataset (Dataset | None): A specific dataset to sample from. If None, the
+                 trainer's validation dataset is used.
+             n_samples (int): The number of samples to use for the explanation (background + inputs).
+             feature_names (List[str] | None): Names of the features (signals). If None, attempts to extract them from the dataset attribute.
+             target_names (List[str] | None): Names of the model outputs (e.g., for Seq2Seq or Multivariate output). If None, attempts to extract them from the dataset attribute.
+             n_steps (int): Number of integral approximation steps.
+
+         Note:
+             For univariate data (Shape: N, Seq_Len), the 'feature' is the signal itself.
+         """
+         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
+         if dataset_to_use is None:
+             _LOGGER.error("No dataset available for explanation.")
+             return
+
+         # Helper to sample data (same as DragonTrainer)
+         def _get_samples(ds, n):
+             loader = DataLoader(ds, batch_size=n, shuffle=True, num_workers=0)
+             data_iter = iter(loader)
+             features, targets = next(data_iter)
+             return features, targets
+
+         input_data, _ = _get_samples(dataset_to_use, n_samples)
+
+         if feature_names is None:
+             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
+                 feature_names = dataset_to_use.feature_names  # type: ignore
+             else:
+                 # If retrieval fails, leave it as None.
+                 _LOGGER.warning("'feature_names' not provided and not found in dataset. Generic names will be used.")
+
+         if target_names is None:
+             if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
+                 target_names = dataset_to_use.target_names  # type: ignore
+             else:
+                 # If retrieval fails, leave it as None.
+                 _LOGGER.warning("'target_names' not provided and not found in dataset. Generic names will be used.")
+
+         # Sequence models usually output [N, 1] (Value) or [N, Seq, 1] (Seq2Seq)
+         # captum_feature_importance handles the aggregation.
+
+         captum_feature_importance(
+             model=self.model,
+             input_data=input_data,
+             feature_names=feature_names,
+             save_dir=save_dir,
+             target_names=target_names,
+             n_steps=n_steps,
+             device=self.device
+         )
+
+     def finalize_model_training(self,
+                                 save_dir: Union[str, Path],
+                                 model_checkpoint: Union[Path, Literal['best', 'current']],
+                                 finalize_config: Union[FinalizeSequenceSequencePrediction, FinalizeSequenceValuePrediction]):
+         """
+         Saves a finalized, "inference-ready" model state to a .pth file.
+
+         This method saves the model's `state_dict` and the final epoch number.
+
+         Args:
+             save_dir (Union[str, Path]): The directory to save the finalized model.
+             model_checkpoint (Union[Path, Literal["best", "current"]]):
+                 - Path: Loads the model state from a specific checkpoint file.
+                 - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                 - "current": Uses the model's state as it is.
+             finalize_config (FinalizeSequenceSequencePrediction | FinalizeSequenceValuePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
+         """
+         if self.kind == MLTaskKeys.SEQUENCE_SEQUENCE and not isinstance(finalize_config, FinalizeSequenceSequencePrediction):
+             _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
+             raise TypeError()
+         elif self.kind == MLTaskKeys.SEQUENCE_VALUE and not isinstance(finalize_config, FinalizeSequenceValuePrediction):
+             _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
+             raise TypeError()
+
+         # handle save path
+         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         full_path = dir_path / finalize_config.filename
+
+         # handle checkpoint
+         self._load_model_state_for_finalizing(model_checkpoint)
+
+         # Create finalized data
+         finalized_data = {
+             PyTorchCheckpointKeys.EPOCH: self.epoch,
+             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+             PyTorchCheckpointKeys.TASK: finalize_config.task
+         }
+
+         if finalize_config.sequence_length is not None:
+             finalized_data[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
+         if finalize_config.initial_sequence is not None:
+             finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
+
+         torch.save(finalized_data, full_path)
+
+         _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
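
For orientation, a minimal instantiation sketch built only from the __init__ signature above. The toy model and datasets are placeholders, the ml_tools.ML_trainer re-export is an assumption, and the actual training loop lives in _BaseDragonTrainer (entry 99 in the file list), which this diff does not include:

import torch
from torch import nn
from torch.utils.data import TensorDataset

from ml_tools.ML_trainer import DragonSequenceTrainer  # assumed re-export via __init__.py

# Toy data: 32 sequences of length 10, each mapped to a single value.
X = torch.randn(32, 10)
y = torch.randn(32)
train_ds = TensorDataset(X, y)
val_ds = TensorDataset(X.clone(), y.clone())

model = nn.Linear(10, 1)  # stand-in; emits [N, 1], which _train_step squeezes to [N]

trainer = DragonSequenceTrainer(
    model=model,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    kind="sequence-to-value",  # or "sequence-to-sequence", per the Literal hint
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    device="cpu",
    checkpoint_callback=None,
    early_stopping_callback=None,
    lr_scheduler_callback=None,
    criterion="auto",  # resolves to nn.MSELoss() for both sequence tasks (see above)
)

After fitting, trainer.evaluate(save_dir=..., model_checkpoint="best" or "current") writes validation metrics (and test metrics when test_data is passed), and trainer.finalize_model_training(...) saves the inference-ready .pth, following the signatures shown above.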