dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
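Most of this release is a structural refactor: each flat module (e.g. ml_tools/utilities.py) becomes a package of the same name (ml_tools/utilities/__init__.py plus private _*.py submodules), and the monolithic ml_tools/_core internals are split up into those packages. A minimal import sketch, assuming (as the renames above suggest, though the diff does not show it directly) that each new __init__.py re-exports its public names:

    # Hypothetical usage; DragonTrainer appears in the old module's __all__ below.
    # 19.14.0: resolved by the flat module ml_tools/ML_trainer.py
    # 20.0.0:  resolved by the package ml_tools/ML_trainer/__init__.py, which
    #          would re-export it from ml_tools/ML_trainer/_dragon_trainer.py
    from ml_tools.ML_trainer import DragonTrainer

If that assumption holds, public import paths survive the refactor; only code that reached into ml_tools._core directly would break. The largest single change is the deletion of ml_tools/_core/_ML_trainer.py (entry 208, -2323 lines), shown below.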
@@ -1,2323 +0,0 @@
1
- from typing import List, Literal, Union, Optional, Callable, Dict, Any
2
- from pathlib import Path
3
- from torch.utils.data import DataLoader, Dataset
4
- import torch
5
- from torch import nn
6
- import numpy as np
7
- from abc import ABC, abstractmethod
8
-
9
- from ._path_manager import make_fullpath
10
- from ._ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, _DragonEarlyStopping, _DragonLRScheduler
11
- from ._ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
12
- from ._ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
13
- from ._ML_vision_evaluation import segmentation_metrics, object_detection_metrics
14
- from ._ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
15
- from ._ML_evaluation_captum import captum_feature_importance, _is_captum_available, captum_segmentation_heatmap, captum_image_heatmap
16
- from ._ML_configuration import (RegressionMetricsFormat,
17
- MultiTargetRegressionMetricsFormat,
18
- BinaryClassificationMetricsFormat,
19
- MultiClassClassificationMetricsFormat,
20
- BinaryImageClassificationMetricsFormat,
21
- MultiClassImageClassificationMetricsFormat,
22
- MultiLabelBinaryClassificationMetricsFormat,
23
- BinarySegmentationMetricsFormat,
24
- MultiClassSegmentationMetricsFormat,
25
- SequenceValueMetricsFormat,
26
- SequenceSequenceMetricsFormat,
27
-
28
- FinalizeBinaryClassification,
29
- FinalizeBinarySegmentation,
30
- FinalizeBinaryImageClassification,
31
- FinalizeMultiClassClassification,
32
- FinalizeMultiClassImageClassification,
33
- FinalizeMultiClassSegmentation,
34
- FinalizeMultiLabelBinaryClassification,
35
- FinalizeMultiTargetRegression,
36
- FinalizeRegression,
37
- FinalizeObjectDetection,
38
- FinalizeSequenceSequencePrediction,
39
- FinalizeSequenceValuePrediction)
40
-
41
- from ._script_info import _script_info
42
- from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys, SequenceDatasetKeys, ScalerKeys
43
- from ._logger import get_logger
44
-
45
-
46
- _LOGGER = get_logger("DragonTrainer")
47
-
48
-
49
- __all__ = [
50
- "DragonTrainer",
51
- "DragonDetectionTrainer",
52
- "DragonSequenceTrainer"
53
- ]
54
-
55
- class _BaseDragonTrainer(ABC):
56
- """
57
- Abstract base class for Dragon Trainers.
58
-
59
- Handles the common training loop orchestration, checkpointing, callback
60
- management, and device handling. Subclasses must implement the
61
- task-specific logic (dataloaders, train/val steps, evaluation).
62
- """
63
- def __init__(self,
64
- model: nn.Module,
65
- optimizer: torch.optim.Optimizer,
66
- device: Union[Literal['cuda', 'mps', 'cpu'],str],
67
- dataloader_workers: int = 2,
68
- checkpoint_callback: Optional[DragonModelCheckpoint] = None,
69
- early_stopping_callback: Optional[_DragonEarlyStopping] = None,
70
- lr_scheduler_callback: Optional[_DragonLRScheduler] = None,
71
- extra_callbacks: Optional[List[_Callback]] = None):
72
-
73
- self.model = model
74
- self.optimizer = optimizer
75
- self.scheduler = None
76
- self.device = self._validate_device(device)
77
- self.dataloader_workers = dataloader_workers
78
-
79
- # Callback handler
80
- default_callbacks = [History(), TqdmProgressBar()]
81
-
82
- self._checkpoint_callback = None
83
- if checkpoint_callback:
84
- default_callbacks.append(checkpoint_callback)
85
- self._checkpoint_callback = checkpoint_callback
86
- if early_stopping_callback:
87
- default_callbacks.append(early_stopping_callback)
88
- if lr_scheduler_callback:
89
- default_callbacks.append(lr_scheduler_callback)
90
-
91
- user_callbacks = extra_callbacks if extra_callbacks is not None else []
92
- self.callbacks = default_callbacks + user_callbacks
93
- self._set_trainer_on_callbacks()
94
-
95
- # Internal state
96
- self.train_loader: Optional[DataLoader] = None
97
- self.validation_loader: Optional[DataLoader] = None
98
- self.history: Dict[str, List[Any]] = {}
99
- self.epoch = 0
100
- self.epochs = 0 # Total epochs for the fit run
101
- self.start_epoch = 1
102
- self.stop_training = False
103
- self._batch_size = 10
104
-
105
- def _validate_device(self, device: str) -> torch.device:
106
- """Validates the selected device and returns a torch.device object."""
107
- device_lower = device.lower()
108
- if "cuda" in device_lower and not torch.cuda.is_available():
109
- _LOGGER.warning("CUDA not available, switching to CPU.")
110
- device = "cpu"
111
- elif device_lower == "mps" and not torch.backends.mps.is_available():
112
- _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
113
- device = "cpu"
114
- return torch.device(device)
115
-
116
- def _set_trainer_on_callbacks(self):
117
- """Gives each callback a reference to this trainer instance."""
118
- for callback in self.callbacks:
119
- callback.set_trainer(self)
120
-
121
- def _load_checkpoint(self, path: Union[str, Path]):
122
- """Loads a training checkpoint to resume training."""
123
- p = make_fullpath(path, enforce="file")
124
- _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
125
-
126
- try:
127
- checkpoint = torch.load(p, map_location=self.device)
128
-
129
- if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
130
- _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
131
- raise KeyError()
132
-
133
- self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
134
- self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
135
- self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
136
- self.start_epoch = self.epoch + 1 # Resume on the *next* epoch
137
-
138
- # --- Load History ---
139
- if PyTorchCheckpointKeys.HISTORY in checkpoint:
140
- self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
141
- _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
142
- else:
143
- _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
144
- self.history = {} # Ensure it's at least an empty dict
145
-
146
- # --- Scheduler State Loading Logic ---
147
- scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
148
- scheduler_object_exists = self.scheduler is not None
149
-
150
- if scheduler_object_exists and scheduler_state_exists:
151
- # Case 1: Both exist. Attempt to load.
152
- try:
153
- self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
154
- scheduler_name = self.scheduler.__class__.__name__
155
- _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
156
- except Exception as e:
157
- # Loading failed, likely a mismatch
158
- scheduler_name = self.scheduler.__class__.__name__
159
- _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
160
- raise e
161
-
162
- elif scheduler_object_exists and not scheduler_state_exists:
163
- # Case 2: Scheduler provided, but no state in checkpoint.
164
- scheduler_name = self.scheduler.__class__.__name__
165
- _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
166
-
167
- elif not scheduler_object_exists and scheduler_state_exists:
168
- # Case 3: State in checkpoint, but no scheduler provided.
169
- _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
170
- raise ValueError()
171
-
172
- # Restore callback states
173
- for cb in self.callbacks:
174
- if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
175
- cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
176
- _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
177
-
178
- _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
179
-
180
- except Exception as e:
181
- _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
182
- raise
183
-
184
- def fit(self,
185
- save_dir: Union[str,Path],
186
- epochs: int = 100,
187
- batch_size: int = 10,
188
- shuffle: bool = True,
189
- resume_from_checkpoint: Optional[Union[str, Path]] = None):
190
- """
191
- Starts the training-validation process of the model.
192
-
193
- Returns the "History" callback dictionary.
194
-
195
- Args:
196
- save_dir (str | Path): Directory to save the loss plot.
197
- epochs (int): The total number of epochs to train for.
198
- batch_size (int): The number of samples per batch.
199
- shuffle (bool): Whether to shuffle the training data at each epoch.
200
- resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
201
- """
202
- self.epochs = epochs
203
- self._batch_size = batch_size
204
- self._create_dataloaders(self._batch_size, shuffle) # type: ignore
205
- self.model.to(self.device)
206
-
207
- if resume_from_checkpoint:
208
- self._load_checkpoint(resume_from_checkpoint)
209
-
210
- # Reset stop_training flag on the trainer
211
- self.stop_training = False
212
-
213
- self._callbacks_hook('on_train_begin')
214
-
215
- if not self.train_loader:
216
- _LOGGER.error("Train loader is not initialized.")
217
- raise ValueError()
218
-
219
- if not self.validation_loader:
220
- _LOGGER.error("Validation loader is not initialized.")
221
- raise ValueError()
222
-
223
- for epoch in range(self.start_epoch, self.epochs + 1):
224
- self.epoch = epoch
225
- epoch_logs: Dict[str, Any] = {}
226
- self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
227
-
228
- train_logs = self._train_step()
229
- epoch_logs.update(train_logs)
230
-
231
- val_logs = self._validation_step()
232
- epoch_logs.update(val_logs)
233
-
234
- self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
235
-
236
- # Check the early stopping flag
237
- if self.stop_training:
238
- break
239
-
240
- self._callbacks_hook('on_train_end')
241
-
242
- # Training History
243
- plot_losses(self.history, save_dir=save_dir)
244
-
245
- return self.history
246
-
247
- def _callbacks_hook(self, method_name: str, *args, **kwargs):
248
- """Calls the specified method on all callbacks."""
249
- for callback in self.callbacks:
250
- method = getattr(callback, method_name)
251
- method(*args, **kwargs)
252
-
253
- def to_cpu(self):
254
- """
255
- Moves the model to the CPU and updates the trainer's device setting.
256
-
257
- This is useful for running operations that require the CPU.
258
- """
259
- self.device = torch.device('cpu')
260
- self.model.to(self.device)
261
- _LOGGER.info("Trainer and model moved to CPU.")
262
-
263
- def to_device(self, device: str):
264
- """
265
- Moves the model to the specified device and updates the trainer's device setting.
266
-
267
- Args:
268
- device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
269
- """
270
- self.device = self._validate_device(device)
271
- self.model.to(self.device)
272
- _LOGGER.info(f"Trainer and model moved to {self.device}.")
273
-
274
- def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
275
- """
276
- Private helper to load the correct model state_dict based on user's choice.
277
- This is called by finalize_model_training() in subclasses.
278
- """
279
- if isinstance(model_checkpoint, Path):
280
- self._load_checkpoint(path=model_checkpoint)
281
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
282
- path_to_latest = self._checkpoint_callback.best_checkpoint_path
283
- self._load_checkpoint(path_to_latest)
284
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
285
- _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
286
- raise ValueError()
287
- elif model_checkpoint == MagicWords.CURRENT:
288
- pass
289
- else:
290
- _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
291
- raise ValueError()
292
-
293
- # --- Abstract Methods ---
294
- # These must be implemented by subclasses
295
-
296
- @abstractmethod
297
- def _create_dataloaders(self, batch_size: int, shuffle: bool):
298
- """Initializes the DataLoaders."""
299
- raise NotImplementedError
300
-
301
- @abstractmethod
302
- def _train_step(self) -> Dict[str, float]:
303
- """Runs a single training epoch."""
304
- raise NotImplementedError
305
-
306
- @abstractmethod
307
- def _validation_step(self) -> Dict[str, float]:
308
- """Runs a single validation epoch."""
309
- raise NotImplementedError
310
-
311
- @abstractmethod
312
- def evaluate(self, *args, **kwargs):
313
- """Runs the full model evaluation."""
314
- raise NotImplementedError
315
-
316
- @abstractmethod
317
- def _evaluate(self, *args, **kwargs):
318
- """Internal evaluation helper."""
319
- raise NotImplementedError
320
-
321
- @abstractmethod
322
- def finalize_model_training(self, *args, **kwargs):
323
- """Saves the finalized model for inference."""
324
- raise NotImplementedError
325
-
326
-
327
- # --- DragonTrainer ----
328
- class DragonTrainer(_BaseDragonTrainer):
329
- def __init__(self,
330
- model: nn.Module,
331
- train_dataset: Dataset,
332
- validation_dataset: Dataset,
333
- kind: Literal["regression", "binary classification", "multiclass classification",
334
- "multitarget regression", "multilabel binary classification",
335
- "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
336
- optimizer: torch.optim.Optimizer,
337
- device: Union[Literal['cuda', 'mps', 'cpu'],str],
338
- checkpoint_callback: Optional[DragonModelCheckpoint],
339
- early_stopping_callback: Optional[_DragonEarlyStopping],
340
- lr_scheduler_callback: Optional[_DragonLRScheduler],
341
- extra_callbacks: Optional[List[_Callback]] = None,
342
- criterion: Union[nn.Module,Literal["auto"]] = "auto",
343
- dataloader_workers: int = 2):
344
- """
345
- Automates the training process of a PyTorch Model.
346
-
347
- Built-in Callbacks: `History`, `TqdmProgressBar`
348
-
349
- Args:
350
- model (nn.Module): The PyTorch model to train.
351
- train_dataset (Dataset): The training dataset.
352
- validation_dataset (Dataset): The validation dataset.
353
- kind (str): Used to redirect to the correct process.
354
- criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task
355
- optimizer (torch.optim.Optimizer): The optimizer.
356
- device (str): The device to run training on ('cpu', 'cuda', 'mps').
357
- dataloader_workers (int): Subprocesses for data loading.
358
- extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
359
-
360
- Note:
361
- - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
362
-
363
- - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
364
-
365
- - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
366
-
367
- - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
368
-
369
- - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
370
-
371
- - for **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
372
- """
373
- # Call the base class constructor with common parameters
374
- super().__init__(
375
- model=model,
376
- optimizer=optimizer,
377
- device=device,
378
- dataloader_workers=dataloader_workers,
379
- checkpoint_callback=checkpoint_callback,
380
- early_stopping_callback=early_stopping_callback,
381
- lr_scheduler_callback=lr_scheduler_callback,
382
- extra_callbacks=extra_callbacks
383
- )
384
-
385
- if kind not in [MLTaskKeys.REGRESSION,
386
- MLTaskKeys.BINARY_CLASSIFICATION,
387
- MLTaskKeys.MULTICLASS_CLASSIFICATION,
388
- MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
389
- MLTaskKeys.MULTITARGET_REGRESSION,
390
- MLTaskKeys.BINARY_SEGMENTATION,
391
- MLTaskKeys.MULTICLASS_SEGMENTATION,
392
- MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
393
- MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
394
- raise ValueError(f"'{kind}' is not a valid task type.")
395
-
396
- self.train_dataset = train_dataset
397
- self.validation_dataset = validation_dataset
398
- self.kind = kind
399
- self._classification_threshold: float = 0.5
400
-
401
- # loss function
402
- if criterion == "auto":
403
- if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
404
- self.criterion = nn.MSELoss()
405
- elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
406
- self.criterion = nn.BCEWithLogitsLoss()
407
- elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
408
- self.criterion = nn.CrossEntropyLoss()
409
- else:
410
- self.criterion = criterion
411
-
412
- def _create_dataloaders(self, batch_size: int, shuffle: bool):
413
- """Initializes the DataLoaders."""
414
- # Ensure stability on MPS devices by setting num_workers to 0
415
- loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
416
-
417
- self.train_loader = DataLoader(
418
- dataset=self.train_dataset,
419
- batch_size=batch_size,
420
- shuffle=shuffle,
421
- num_workers=loader_workers,
422
- pin_memory=("cuda" in self.device.type),
423
- drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
424
- )
425
-
426
- self.validation_loader = DataLoader(
427
- dataset=self.validation_dataset,
428
- batch_size=batch_size,
429
- shuffle=False,
430
- num_workers=loader_workers,
431
- pin_memory=("cuda" in self.device.type)
432
- )
433
-
434
- def _train_step(self):
435
- self.model.train()
436
- running_loss = 0.0
437
- total_samples = 0
438
-
439
- for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
440
- # Create a log dictionary for the batch
441
- batch_logs = {
442
- PyTorchLogKeys.BATCH_INDEX: batch_idx,
443
- PyTorchLogKeys.BATCH_SIZE: features.size(0)
444
- }
445
- self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
446
-
447
- features, target = features.to(self.device), target.to(self.device)
448
- self.optimizer.zero_grad()
449
-
450
- output = self.model(features)
451
-
452
- # --- Label Type/Shape Correction ---
453
- # Cast target to float for BCE-based losses
454
- if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
455
- target = target.float()
456
-
457
- # Reshape output to match target for single-logit tasks
458
- if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
459
- # If model outputs [N, 1] and target is [N], squeeze output
460
- if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
461
- output = output.squeeze(1)
462
-
463
- if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
464
- # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
465
- if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
466
- output = output.squeeze(1)
467
-
468
- loss = self.criterion(output, target)
469
-
470
- loss.backward()
471
- self.optimizer.step()
472
-
473
- # Calculate batch loss and update running loss for the epoch
474
- batch_loss = loss.item()
475
- batch_size = features.size(0)
476
- running_loss += batch_loss * batch_size # Accumulate total loss
477
- total_samples += batch_size # total samples
478
-
479
- # Add the batch loss to the logs and call the end-of-batch hook
480
- batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
481
- self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
482
-
483
- if total_samples == 0:
484
- _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
485
- return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
486
-
487
- return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
488
-
489
- def _validation_step(self):
490
- self.model.eval()
491
- running_loss = 0.0
492
-
493
- with torch.no_grad():
494
- for features, target in self.validation_loader: # type: ignore
495
- features, target = features.to(self.device), target.to(self.device)
496
-
497
- output = self.model(features)
498
-
499
- # --- Label Type/Shape Correction ---
500
- # Cast target to float for BCE-based losses
501
- if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
502
- target = target.float()
503
-
504
- # Reshape output to match target for single-logit tasks
505
- if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
506
- # If model outputs [N, 1] and target is [N], squeeze output
507
- if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
508
- output = output.squeeze(1)
509
-
510
- if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
511
- # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
512
- if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
513
- output = output.squeeze(1)
514
-
515
- loss = self.criterion(output, target)
516
-
517
- running_loss += loss.item() * features.size(0)
518
-
519
- if not self.validation_loader.dataset: # type: ignore
520
- _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
521
- return {PyTorchLogKeys.VAL_LOSS: 0.0}
522
-
523
- logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
524
- return logs
525
-
526
- def _predict_for_eval(self, dataloader: DataLoader):
527
- """
528
- Private method to yield model predictions batch by batch for evaluation.
529
-
530
- Automatically detects if `target_scaler` is present in the training dataset
531
- and applies inverse transformation for Regression tasks.
532
-
533
- Yields:
534
- tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
535
-
536
- - y_prob_batch is None for regression tasks.
537
- """
538
- self.model.eval()
539
- self.model.to(self.device)
540
-
541
- # --- Check for Target Scaler (for Regression Un-scaling) ---
542
- target_scaler = None
543
- if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
544
- # Try to get the scaler from the dataset attached to the trainer
545
- if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
546
- target_scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
547
- if target_scaler is not None:
548
- _LOGGER.debug("Target scaler detected. Un-scaling predictions and targets for metric calculation.")
549
-
550
- with torch.no_grad():
551
- for features, target in dataloader:
552
- features = features.to(self.device)
553
- # Keep target on device initially for potential un-scaling
554
- target = target.to(self.device)
555
-
556
- output = self.model(features)
557
-
558
- y_pred_batch = None
559
- y_prob_batch = None
560
- y_true_batch = None
561
-
562
- if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
563
-
564
- # --- Automatic Un-scaling Logic ---
565
- if target_scaler:
566
- # 1. Reshape output/target if flattened (common in single regression)
567
- # Scaler expects [N, Features]
568
- original_out_shape = output.shape
569
- original_target_shape = target.shape
570
-
571
- if output.ndim == 1: output = output.reshape(-1, 1)
572
- if target.ndim == 1: target = target.reshape(-1, 1)
573
-
574
- # 2. Apply Inverse Transform
575
- output = target_scaler.inverse_transform(output)
576
- target = target_scaler.inverse_transform(target)
577
-
578
- # 3. Restore shapes (optional, but good for consistency)
579
- if len(original_out_shape) == 1: output = output.flatten()
580
- if len(original_target_shape) == 1: target = target.flatten()
581
-
582
- y_pred_batch = output.cpu().numpy()
583
- y_true_batch = target.cpu().numpy()
584
-
585
- elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
586
- if output.ndim == 2 and output.shape[1] == 1:
587
- output = output.squeeze(1)
588
-
589
- probs_pos = torch.sigmoid(output)
590
- preds = (probs_pos >= self._classification_threshold).int()
591
- y_pred_batch = preds.cpu().numpy()
592
-
593
- probs_neg = 1.0 - probs_pos
594
- y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).cpu().numpy()
595
- y_true_batch = target.cpu().numpy()
596
-
597
- elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
598
- probs = torch.softmax(output, dim=1)
599
- preds = torch.argmax(probs, dim=1)
600
- y_pred_batch = preds.cpu().numpy()
601
- y_prob_batch = probs.cpu().numpy()
602
- y_true_batch = target.cpu().numpy()
603
-
604
- elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
605
- probs = torch.sigmoid(output)
606
- preds = (probs >= self._classification_threshold).int()
607
- y_pred_batch = preds.cpu().numpy()
608
- y_prob_batch = probs.cpu().numpy()
609
- y_true_batch = target.cpu().numpy()
610
-
611
- elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
612
- probs_pos = torch.sigmoid(output)
613
- preds = (probs_pos >= self._classification_threshold).int()
614
- y_pred_batch = preds.squeeze(1).cpu().numpy()
615
-
616
- probs_neg = 1.0 - probs_pos
617
- y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).cpu().numpy()
618
-
619
- if target.ndim == 4 and target.shape[1] == 1:
620
- target = target.squeeze(1)
621
- y_true_batch = target.cpu().numpy()
622
-
623
- elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
624
- probs = torch.softmax(output, dim=1)
625
- preds = torch.argmax(probs, dim=1)
626
- y_pred_batch = preds.cpu().numpy()
627
- y_prob_batch = probs.cpu().numpy()
628
-
629
- if target.ndim == 4 and target.shape[1] == 1:
630
- target = target.squeeze(1)
631
- y_true_batch = target.cpu().numpy()
632
-
633
- yield y_pred_batch, y_prob_batch, y_true_batch
634
-
635
- def evaluate(self,
636
- save_dir: Union[str, Path],
637
- model_checkpoint: Union[Path, Literal["best", "current"]],
638
- classification_threshold: Optional[float] = None,
639
- test_data: Optional[Union[DataLoader, Dataset]] = None,
640
- val_format_configuration: Optional[Union[
641
- RegressionMetricsFormat,
642
- MultiTargetRegressionMetricsFormat,
643
- BinaryClassificationMetricsFormat,
644
- MultiClassClassificationMetricsFormat,
645
- BinaryImageClassificationMetricsFormat,
646
- MultiClassImageClassificationMetricsFormat,
647
- MultiLabelBinaryClassificationMetricsFormat,
648
- BinarySegmentationMetricsFormat,
649
- MultiClassSegmentationMetricsFormat
650
- ]]=None,
651
- test_format_configuration: Optional[Union[
652
- RegressionMetricsFormat,
653
- MultiTargetRegressionMetricsFormat,
654
- BinaryClassificationMetricsFormat,
655
- MultiClassClassificationMetricsFormat,
656
- BinaryImageClassificationMetricsFormat,
657
- MultiClassImageClassificationMetricsFormat,
658
- MultiLabelBinaryClassificationMetricsFormat,
659
- BinarySegmentationMetricsFormat,
660
- MultiClassSegmentationMetricsFormat,
661
- ]]=None):
662
- """
663
- Evaluates the model, routing to the correct evaluation function based on task `kind`.
664
-
665
- Args:
666
- model_checkpoint (Path | "best" | "current"):
667
- - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
668
- - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
669
- - If 'current', use the current state of the trained model up the latest trained epoch.
670
- save_dir (str | Path): Directory to save all reports and plots.
671
- classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification)
672
- test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
673
- val_format_configuration (object): Optional configuration for metric format output for the validation set.
674
- test_format_configuration (object): Optional configuration for metric format output for the test set.
675
- """
676
- # Validate model checkpoint
677
- if isinstance(model_checkpoint, Path):
678
- checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
679
- elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
680
- checkpoint_validated = model_checkpoint
681
- else:
682
- _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
683
- raise ValueError()
684
-
685
- # Validate classification threshold
686
- if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
687
- # dummy value for tasks that do not need it
688
- threshold_validated = 0.5
689
- elif classification_threshold is None:
690
- # it should have been provided for binary tasks
691
- _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
692
- raise ValueError()
693
- elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
694
- # Invalid float
695
- _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
696
- raise ValueError()
697
- else:
698
- threshold_validated = classification_threshold
699
-
700
- # Validate val configuration
701
- if val_format_configuration is not None:
702
- if not isinstance(val_format_configuration, (RegressionMetricsFormat,
703
- MultiTargetRegressionMetricsFormat,
704
- BinaryClassificationMetricsFormat,
705
- MultiClassClassificationMetricsFormat,
706
- BinaryImageClassificationMetricsFormat,
707
- MultiClassImageClassificationMetricsFormat,
708
- MultiLabelBinaryClassificationMetricsFormat,
709
- BinarySegmentationMetricsFormat,
710
- MultiClassSegmentationMetricsFormat)):
711
- _LOGGER.error(f"Invalid 'format_configuration': '{type(val_format_configuration)}'.")
712
- raise ValueError()
713
- else:
714
- val_configuration_validated = val_format_configuration
715
- else: # config is None
716
- val_configuration_validated = None
717
-
718
- # Validate directory
719
- save_path = make_fullpath(save_dir, make=True, enforce="directory")
720
-
721
- # Validate test data and dispatch
722
- if test_data is not None:
723
- if not isinstance(test_data, (DataLoader, Dataset)):
724
- _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
725
- raise ValueError()
726
- test_data_validated = test_data
727
-
728
- validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
729
- test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
730
-
731
- # Dispatch validation set
732
- _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
733
- self._evaluate(save_dir=validation_metrics_path,
734
- model_checkpoint=checkpoint_validated,
735
- classification_threshold=threshold_validated,
736
- data=None,
737
- format_configuration=val_configuration_validated)
738
-
739
- # Validate test configuration
740
- if test_format_configuration is not None:
741
- if not isinstance(test_format_configuration, (RegressionMetricsFormat,
742
- MultiTargetRegressionMetricsFormat,
743
- BinaryClassificationMetricsFormat,
744
- MultiClassClassificationMetricsFormat,
745
- BinaryImageClassificationMetricsFormat,
746
- MultiClassImageClassificationMetricsFormat,
747
- MultiLabelBinaryClassificationMetricsFormat,
748
- BinarySegmentationMetricsFormat,
749
- MultiClassSegmentationMetricsFormat)):
750
- warning_message_type = f"Invalid test_format_configuration': '{type(test_format_configuration)}'."
751
- if val_configuration_validated is not None:
752
- warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
753
- test_configuration_validated = val_configuration_validated
754
- else:
755
- warning_message_type += " Using default format."
756
- test_configuration_validated = None
757
- _LOGGER.warning(warning_message_type)
758
- else:
759
- test_configuration_validated = test_format_configuration
760
- else: #config is None
761
- test_configuration_validated = None
762
-
763
- # Dispatch test set
764
- _LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
765
- self._evaluate(save_dir=test_metrics_path,
766
- model_checkpoint="current",
767
- classification_threshold=threshold_validated,
768
- data=test_data_validated,
769
- format_configuration=test_configuration_validated)
770
- else:
771
- # Dispatch validation set
772
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
773
- self._evaluate(save_dir=save_path,
774
- model_checkpoint=checkpoint_validated,
775
- classification_threshold=threshold_validated,
776
- data=None,
777
- format_configuration=val_configuration_validated)
778
-
779
- def _evaluate(self,
780
- save_dir: Union[str, Path],
781
- model_checkpoint: Union[Path, Literal["best", "current"]],
782
- classification_threshold: float,
783
- data: Optional[Union[DataLoader, Dataset]],
784
- format_configuration: Optional[Union[
785
- RegressionMetricsFormat,
786
- MultiTargetRegressionMetricsFormat,
787
- BinaryClassificationMetricsFormat,
788
- MultiClassClassificationMetricsFormat,
789
- BinaryImageClassificationMetricsFormat,
790
- MultiClassImageClassificationMetricsFormat,
791
- MultiLabelBinaryClassificationMetricsFormat,
792
- BinarySegmentationMetricsFormat,
793
- MultiClassSegmentationMetricsFormat
794
- ]]=None):
795
- """
796
- Changed to a private helper function.
797
- """
798
- dataset_for_artifacts = None
799
- eval_loader = None
800
-
801
- # set threshold
802
- self._classification_threshold = classification_threshold
803
-
804
- # load model checkpoint
805
- if isinstance(model_checkpoint, Path):
806
- self._load_checkpoint(path=model_checkpoint)
807
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
808
- path_to_latest = self._checkpoint_callback.best_checkpoint_path
809
- self._load_checkpoint(path_to_latest)
810
- elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
811
- _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
812
- raise ValueError()
813
-
814
- # Dataloader
815
- if isinstance(data, DataLoader):
816
- eval_loader = data
817
- # Try to get the dataset from the loader for fetching target names
818
- if hasattr(data, 'dataset'):
819
- dataset_for_artifacts = data.dataset # type: ignore
820
- elif isinstance(data, Dataset):
821
- # Create a new loader from the provided dataset
822
- eval_loader = DataLoader(data,
823
- batch_size=self._batch_size,
824
- shuffle=False,
825
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
826
- pin_memory=(self.device.type == "cuda"))
827
- dataset_for_artifacts = data
828
- else: # data is None, use the trainer's default test dataset
829
- if self.validation_dataset is None:
830
- _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
831
- raise ValueError()
832
- # Create a fresh DataLoader from the test_dataset
833
- eval_loader = DataLoader(self.validation_dataset,
834
- batch_size=self._batch_size,
835
- shuffle=False,
836
- num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
837
- pin_memory=(self.device.type == "cuda"))
838
-
839
- dataset_for_artifacts = self.validation_dataset
840
-
841
- if eval_loader is None:
842
- _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
843
- raise ValueError()
844
-
845
- # print("\n--- Model Evaluation ---")
846
-
847
- all_preds, all_probs, all_true = [], [], []
848
- for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
849
- if y_pred_b is not None: all_preds.append(y_pred_b)
850
- if y_prob_b is not None: all_probs.append(y_prob_b)
851
- if y_true_b is not None: all_true.append(y_true_b)
852
-
853
- if not all_true:
854
- _LOGGER.error("Evaluation failed: No data was processed.")
855
- return
856
-
857
- y_pred = np.concatenate(all_preds)
858
- y_true = np.concatenate(all_true)
859
- y_prob = np.concatenate(all_probs) if all_probs else None
860
-
861
- # --- Routing Logic ---
862
- # Single-target regression
863
- if self.kind == MLTaskKeys.REGRESSION:
864
- # Check configuration
865
- config = None
866
- if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
867
- config = format_configuration
868
- elif format_configuration:
869
- _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
870
-
871
- regression_metrics(y_true=y_true.flatten(),
872
- y_pred=y_pred.flatten(),
873
- save_dir=save_dir,
874
- config=config)
875
-
876
- # single target classification
877
- elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
878
- MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
879
- MLTaskKeys.MULTICLASS_CLASSIFICATION,
880
- MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
881
- # get the class map if it exists
882
- try:
883
- class_map = dataset_for_artifacts.class_map # type: ignore
884
- except AttributeError:
885
- _LOGGER.warning(f"Dataset has no 'class_map' attribute. Using generics.")
886
- class_map = None
887
- else:
888
- if not isinstance(class_map, dict):
889
- _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
890
- class_map = None
891
-
892
- # Check configuration
893
- config = None
894
- if format_configuration:
895
- if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, BinaryClassificationMetricsFormat):
896
- config = format_configuration
897
- elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, BinaryImageClassificationMetricsFormat):
898
- config = format_configuration
899
- elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, MultiClassClassificationMetricsFormat):
900
- config = format_configuration
901
- elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, MultiClassImageClassificationMetricsFormat):
902
- config = format_configuration
903
- else:
904
- _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
905
-
906
- classification_metrics(save_dir=save_dir,
907
- y_true=y_true,
908
- y_pred=y_pred,
909
- y_prob=y_prob,
910
- class_map=class_map,
911
- config=config)
912
-
913
- # multitarget regression
914
- elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
915
- try:
916
- target_names = dataset_for_artifacts.target_names # type: ignore
917
- except AttributeError:
918
- num_targets = y_true.shape[1]
919
- target_names = [f"target_{i}" for i in range(num_targets)]
920
- _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
921
-
922
- # Check configuration
923
- config = None
924
- if format_configuration and isinstance(format_configuration, MultiTargetRegressionMetricsFormat):
925
- config = format_configuration
926
- elif format_configuration:
927
- _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
928
-
929
- multi_target_regression_metrics(y_true=y_true,
930
- y_pred=y_pred,
931
- target_names=target_names,
932
- save_dir=save_dir,
933
- config=config)
934
-
935
- # multi-label binary classification
936
- elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
937
- try:
938
- target_names = dataset_for_artifacts.target_names # type: ignore
939
- except AttributeError:
940
- num_targets = y_true.shape[1]
941
- target_names = [f"label_{i}" for i in range(num_targets)]
942
- _LOGGER.warning(f"Dataset has no 'target_names' attribute. Using generic names.")
943
-
944
- if y_prob is None:
945
- _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
946
- return
947
-
948
- # Check configuration
949
- config = None
950
- if format_configuration and isinstance(format_configuration, MultiLabelBinaryClassificationMetricsFormat):
951
- config = format_configuration
952
- elif format_configuration:
953
- _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
954
-
955
- multi_label_classification_metrics(y_true=y_true,
956
- y_pred=y_pred,
957
- y_prob=y_prob,
958
- target_names=target_names,
959
- save_dir=save_dir,
960
- config=config)
961
-
962
- # Segmentation tasks
963
- elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
964
- class_names = None
965
- try:
966
- # Try to get 'classes' from VisionDatasetMaker
967
- if hasattr(dataset_for_artifacts, 'classes'):
968
- class_names = dataset_for_artifacts.classes # type: ignore
969
- # Fallback for Subset
970
- elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
971
- class_names = dataset_for_artifacts.dataset.classes # type: ignore
972
- except AttributeError:
973
- pass # class_names is still None
974
-
975
- if class_names is None:
976
- try:
977
- # Fallback to 'target_names'
978
- class_names = dataset_for_artifacts.target_names # type: ignore
979
- except AttributeError:
980
- # Fallback to inferring from labels
981
- labels = np.unique(y_true)
982
- class_names = [f"Class {i}" for i in labels]
983
- _LOGGER.warning(f"Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
984
-
985
- # Check configuration
986
- config = None
987
- if format_configuration and isinstance(format_configuration, (BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat)):
988
- config = format_configuration
989
- elif format_configuration:
990
- _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
991
-
992
- segmentation_metrics(y_true=y_true,
993
- y_pred=y_pred,
994
- save_dir=save_dir,
995
- class_names=class_names,
996
- config=config)
997
-
998
-     def explain_shap(self,
-                      save_dir: Union[str, Path],
-                      explain_dataset: Optional[Dataset] = None,
-                      n_samples: int = 300,
-                      feature_names: Optional[List[str]] = None,
-                      target_names: Optional[List[str]] = None,
-                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
-         """
-         Explains model predictions using SHAP and saves all artifacts.
- 
-         NOTE: SHAP support is limited to single-target tasks (Regression, Binary/Multiclass Classification).
-         For complex tasks (Multi-target, Multi-label, Sequences, Images), please use `explain_captum()`.
- 
-         The background data is automatically sampled from the trainer's training dataset.
- 
-         This method automatically routes to the appropriate SHAP summary plot
-         function based on the task. If `feature_names` or `target_names` (multi-target) are not provided,
-         it will attempt to extract them from the dataset.
- 
-         Args:
-             save_dir (str | Path): Directory to save all SHAP artifacts.
-             explain_dataset (Dataset | None): A specific dataset to explain.
-                 If None, the trainer's validation dataset is used.
-             n_samples (int): The number of samples to use for both background and explanation.
-             feature_names (list[str] | None): Feature names. If None, they are extracted from the dataset; an error is raised on failure.
-             target_names (list[str] | None): Target names for multi-target tasks.
-             explainer_type (Literal['deep', 'kernel']): The explainer to use.
-                 - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
-                 - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' (< 100).
-         """
-         # --- 1. Compatibility Guard ---
-         valid_shap_tasks = [
-             MLTaskKeys.REGRESSION,
-             MLTaskKeys.BINARY_CLASSIFICATION,
-             MLTaskKeys.MULTICLASS_CLASSIFICATION
-         ]
- 
-         if self.kind not in valid_shap_tasks:
-             _LOGGER.warning(f"SHAP explanation is deprecated for task '{self.kind}' due to instability. Please use 'explain_captum()' instead.")
-             return
- 
-         # memory-efficient helper
-         def _get_random_sample(dataset: Dataset, num_samples: int):
-             """
-             Memory-efficiently samples data from a dataset.
-             """
-             if dataset is None:
-                 return None
- 
-             dataset_len = len(dataset) # type: ignore
-             if dataset_len == 0:
-                 return None
- 
-             # For MPS devices, num_workers must be 0 to ensure stability
-             loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
- 
-             # Ensure batch_size is not larger than the dataset itself
-             batch_size = min(num_samples, 64, dataset_len)
- 
-             loader = DataLoader(
-                 dataset,
-                 batch_size=batch_size,
-                 shuffle=True, # Shuffle to get random samples
-                 num_workers=loader_workers
-             )
- 
-             collected_features = []
-             num_collected = 0
- 
-             for features, _ in loader:
-                 collected_features.append(features)
-                 num_collected += features.size(0)
-                 if num_collected >= num_samples:
-                     break # Stop once we have enough samples
- 
-             if not collected_features:
-                 return None
- 
-             full_data = torch.cat(collected_features, dim=0)
- 
-             # If we collected more than needed, trim it down
-             if full_data.size(0) > num_samples:
-                 return full_data[:num_samples]
- 
-             return full_data
- 
-         # 2. Get background data from the trainer's train_dataset
-         background_data = _get_random_sample(self.train_dataset, n_samples)
-         if background_data is None:
-             _LOGGER.error("Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
-             return
- 
-         # 3. Determine target dataset and get explanation instances
-         target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
-         instances_to_explain = _get_random_sample(target_dataset, n_samples)
-         if instances_to_explain is None:
-             _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
-             return
- 
-         # attempt to get feature names
-         if feature_names is None:
-             if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
-                 feature_names = target_dataset.feature_names # type: ignore
-             else:
-                 _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
-                 raise ValueError()
- 
-         # move model to device
-         self.model.to(self.device)
- 
-         # 4. Call the plotting function
-         if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
-             shap_summary_plot(
-                 model=self.model,
-                 background_data=background_data,
-                 instances_to_explain=instances_to_explain,
-                 feature_names=feature_names,
-                 save_dir=save_dir,
-                 explainer_type=explainer_type,
-                 device=self.device
-             )
-         # DEPRECATED: Multi-target SHAP support is unstable; recommend Captum instead.
-         elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
-             # try to get target names
-             if target_names is None:
-                 target_names = []
-                 if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
-                     target_names = target_dataset.target_names # type: ignore
-                 else:
-                     # Infer number of targets from the model's output layer
-                     try:
-                         num_targets = self.model.output_layer.out_features # type: ignore
-                         target_names = [f"target_{i}" for i in range(num_targets)] # type: ignore
-                         _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
-                     except AttributeError:
-                         _LOGGER.error("Cannot determine target names for multi-target SHAP plot. Skipping.")
-                         return
- 
-             multi_target_shap_summary_plot(
-                 model=self.model,
-                 background_data=background_data,
-                 instances_to_explain=instances_to_explain,
-                 feature_names=feature_names, # type: ignore
-                 target_names=target_names, # type: ignore
-                 save_dir=save_dir,
-                 explainer_type=explainer_type,
-                 device=self.device
-             )
- 
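A minimal usage sketch, assuming a fitted single-target trainer instance named `trainer`; the output path is illustrative:

# Hypothetical usage of explain_shap on a fitted single-target trainer.
trainer.explain_shap(
    save_dir="artifacts/shap",   # illustrative path
    explain_dataset=None,        # None -> the trainer's validation dataset
    n_samples=300,
    explainer_type="deep",       # 'kernel' is model-agnostic but very slow; keep n_samples < 100
)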
-     def explain_captum(self,
-                        save_dir: Union[str, Path],
-                        explain_dataset: Optional[Dataset] = None,
-                        n_samples: int = 100,
-                        feature_names: Optional[List[str]] = None,
-                        target_names: Optional[List[str]] = None,
-                        n_steps: int = 50):
-         """
-         Explains model predictions using Captum's Integrated Gradients.
- 
-         - **Tabular/Classification:** Generates Feature Importance Bar Charts.
-         - **Segmentation:** Generates Spatial Heatmaps for each class.
- 
-         Args:
-             save_dir (str | Path): Directory to save artifacts.
-             explain_dataset (Dataset | None): Dataset to sample from. Defaults to the validation set.
-             n_samples (int): Number of samples to evaluate.
-             feature_names (list[str] | None): Feature names.
-                 - Required for Tabular tasks.
-                 - Ignored/Optional for Image tasks (defaults to Channel names).
-             target_names (list[str] | None): Names for the model outputs (or Class names).
-                 - If None, attempts to extract from dataset attributes (`target_names`, `classes`, or `class_map`).
-                 - If extraction fails, generates generic names (e.g. "Output_0").
-             n_steps (int): Number of interpolation steps.
-         """
-         # 1. Check availability
-         if not _is_captum_available():
-             _LOGGER.error("Captum is not installed or could not be imported.")
-             return
- 
-         # 2. Prepare Data
-         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
-         if dataset_to_use is None:
-             _LOGGER.error("No dataset available for explanation.")
-             return
- 
-         # Efficient sampling helper
-         def _get_samples(ds, n):
-             # Use num_workers=0 for stability during ad-hoc sampling
-             loader = DataLoader(ds, batch_size=n, shuffle=True, num_workers=0)
-             data_iter = iter(loader)
-             features, targets = next(data_iter)
-             return features, targets
- 
-         input_data, _ = _get_samples(dataset_to_use, n_samples)
- 
-         # 3. Get Feature Names (only if NOT segmentation AND NOT image classification)
-         # Image tasks generally don't have explicit feature names; Captum will default to "Channel_X"
-         is_segmentation = self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]
-         is_image_classification = self.kind in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]
- 
-         if feature_names is None and not is_segmentation and not is_image_classification:
-             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
-                 feature_names = dataset_to_use.feature_names # type: ignore
-             else:
-                 _LOGGER.error("Could not extract `feature_names`. It must be provided if the dataset does not have it.")
-                 raise ValueError()
- 
-         # 4. Handle Target Names (or Class Names)
-         if target_names is None:
-             # A. Try dataset attributes first
-             if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
-                 target_names = dataset_to_use.target_names # type: ignore
-             elif hasattr(dataset_to_use, "classes"):
-                 target_names = dataset_to_use.classes # type: ignore
-             elif hasattr(dataset_to_use, "class_map") and isinstance(dataset_to_use.class_map, dict): # type: ignore
-                 # Sort by value (index) to ensure correct order: {name: index} -> [name_at_0, name_at_1...]
-                 sorted_items = sorted(dataset_to_use.class_map.items(), key=lambda item: item[1]) # type: ignore
-                 target_names = [k for k, v in sorted_items]
- 
-             # B. Infer based on task
-             if target_names is None:
-                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
-                     target_names = ["Output"]
-                 elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
-                     target_names = ["Foreground"]
- 
-             # For multiclass/multitarget without names, leave it None and let the evaluation function generate generics.
- 
-         # 5. Dispatch based on Task
-         if is_segmentation:
-             # lower n_steps for segmentation to save memory
-             if n_steps > 30:
-                 n_steps = 30
-                 _LOGGER.warning(f"Segmentation task detected: Reducing Captum n_steps to {n_steps} to prevent OOM. If you encounter OOM errors, consider lowering this further.")
- 
-             captum_segmentation_heatmap(
-                 model=self.model,
-                 input_data=input_data,
-                 save_dir=save_dir,
-                 target_names=target_names, # Can be None, helper handles it
-                 n_steps=n_steps,
-                 device=self.device
-             )
- 
-         elif is_image_classification:
-             captum_image_heatmap(
-                 model=self.model,
-                 input_data=input_data,
-                 save_dir=save_dir,
-                 target_names=target_names,
-                 n_steps=n_steps,
-                 device=self.device
-             )
- 
-         else:
-             # Standard tabular tasks (regression and classification)
-             captum_feature_importance(
-                 model=self.model,
-                 input_data=input_data,
-                 feature_names=feature_names,
-                 save_dir=save_dir,
-                 target_names=target_names,
-                 n_steps=n_steps,
-                 device=self.device
-             )
- 
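A hedged usage sketch; the trainer instance and feature names are placeholders, not library fixtures:

# Hypothetical usage of explain_captum on a fitted tabular trainer.
trainer.explain_captum(
    save_dir="artifacts/captum",
    n_samples=100,
    feature_names=["age", "dose", "ph"],  # required for tabular tasks if absent from the dataset
    target_names=None,                    # resolved from target_names/classes/class_map when available
    n_steps=50,                           # capped at 30 internally for segmentation tasks
)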
-     def _attention_helper(self, dataloader: DataLoader):
-         """
-         Private method to yield model attention weights batch by batch for evaluation.
- 
-         Args:
-             dataloader (DataLoader): The dataloader to predict on.
- 
-         Yields:
-             (torch.Tensor): Attention weights
-         """
-         self.model.eval()
-         self.model.to(self.device)
- 
-         with torch.no_grad():
-             for features, target in dataloader:
-                 features = features.to(self.device)
- 
-                 # Unpack logits and weights from the special forward method
-                 _output, attention_weights = self.model.forward_attention(features) # type: ignore
- 
-                 if attention_weights is not None:
-                     attention_weights = attention_weights.cpu()
- 
-                 yield attention_weights
- 
-     def explain_attention(self, save_dir: Union[str, Path],
-                           feature_names: Optional[List[str]] = None,
-                           explain_dataset: Optional[Dataset] = None,
-                           plot_n_features: int = 10):
-         """
-         Generates and saves a feature importance plot based on attention weights.
- 
-         This method only works for models with 'has_interpretable_attention'.
- 
-         Args:
-             save_dir (str | Path): Directory to save the plot and summary data.
-             feature_names (List[str] | None): Names for the features for plot labeling. If None, they are extracted from the dataset; an error is raised on failure.
-             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's validation dataset is used.
-             plot_n_features (int): Number of top features to plot.
-         """
-         # --- Step 1: Check if the model supports this explanation ---
-         if not getattr(self.model, 'has_interpretable_attention', False):
-             _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
-             return
- 
-         # --- Step 2: Set up the dataloader ---
-         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
-         if not isinstance(dataset_to_use, Dataset):
-             _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
-             return
- 
-         # Get feature names
-         if feature_names is None:
-             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
-                 feature_names = dataset_to_use.feature_names # type: ignore
-             else:
-                 _LOGGER.error(f"Could not extract `feature_names` from the dataset for the attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
-                 raise ValueError()
- 
-         explain_loader = DataLoader(
-             dataset=dataset_to_use, batch_size=32, shuffle=False,
-             num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-             pin_memory=("cuda" in self.device.type)
-         )
- 
-         # --- Step 3: Collect weights ---
-         all_weights = []
-         for att_weights_b in self._attention_helper(explain_loader):
-             if att_weights_b is not None:
-                 all_weights.append(att_weights_b)
- 
-         # --- Step 4: Call the plotting function ---
-         if all_weights:
-             plot_attention_importance(
-                 weights=all_weights,
-                 feature_names=feature_names,
-                 save_dir=save_dir,
-                 top_n=plot_n_features
-             )
-         else:
-             _LOGGER.error("No attention weights were collected from the model.")
- 
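A minimal usage sketch; `trainer` is a placeholder, and the call is only meaningful when the model exposes `has_interpretable_attention` and a `forward_attention()` method:

# Hypothetical usage of explain_attention.
trainer.explain_attention(
    save_dir="artifacts/attention",
    feature_names=None,    # pulled from the dataset when available
    plot_n_features=10,
)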
-     def finalize_model_training(self,
-                                 model_checkpoint: Union[Path, Literal['best', 'current']],
-                                 save_dir: Union[str, Path],
-                                 finalize_config: Union[FinalizeRegression,
-                                                        FinalizeMultiTargetRegression,
-                                                        FinalizeBinaryClassification,
-                                                        FinalizeBinaryImageClassification,
-                                                        FinalizeMultiClassClassification,
-                                                        FinalizeMultiClassImageClassification,
-                                                        FinalizeBinarySegmentation,
-                                                        FinalizeMultiClassSegmentation,
-                                                        FinalizeMultiLabelBinaryClassification]):
-         """
-         Saves a finalized, "inference-ready" model state to a .pth file.
- 
-         This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.
- 
-         Args:
-             model_checkpoint (Path | "best" | "current"):
-                 - Path: Loads the model state from a specific checkpoint file.
-                 - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
-                 - "current": Uses the model's state as it is.
-             save_dir (str | Path): The directory to save the finalized model.
-             finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
-         """
-         # Validate that the config class matches the trainer's task
-         expected_config_types = {
-             MLTaskKeys.REGRESSION: FinalizeRegression,
-             MLTaskKeys.MULTITARGET_REGRESSION: FinalizeMultiTargetRegression,
-             MLTaskKeys.BINARY_CLASSIFICATION: FinalizeBinaryClassification,
-             MLTaskKeys.BINARY_IMAGE_CLASSIFICATION: FinalizeBinaryImageClassification,
-             MLTaskKeys.MULTICLASS_CLASSIFICATION: FinalizeMultiClassClassification,
-             MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION: FinalizeMultiClassImageClassification,
-             MLTaskKeys.BINARY_SEGMENTATION: FinalizeBinarySegmentation,
-             MLTaskKeys.MULTICLASS_SEGMENTATION: FinalizeMultiClassSegmentation,
-             MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION: FinalizeMultiLabelBinaryClassification,
-         }
-         expected_type = expected_config_types.get(self.kind)
-         if expected_type is not None and not isinstance(finalize_config, expected_type):
-             _LOGGER.error(f"For task {self.kind}, expected finalize_config of type '{expected_type.__name__}', but got {type(finalize_config).__name__}.")
-             raise TypeError()
- 
-         # handle save path
-         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-         full_path = dir_path / finalize_config.filename
- 
-         # handle checkpoint
-         self._load_model_state_for_finalizing(model_checkpoint)
- 
-         # Create finalized data
-         finalized_data = {
-             PyTorchCheckpointKeys.EPOCH: self.epoch,
-             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
-             PyTorchCheckpointKeys.TASK: finalize_config.task
-         }
- 
-         # Parse config
-         if finalize_config.target_name is not None:
-             finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
-         if finalize_config.target_names is not None:
-             finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
-         if finalize_config.classification_threshold is not None:
-             finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
-         if finalize_config.class_map is not None:
-             finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
- 
-         # Save model file
-         torch.save(finalized_data, full_path)
- 
-         _LOGGER.info(f"Finalized model file saved to '{full_path}'")
- 
- 
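A hedged usage sketch. The `FinalizeRegression` constructor arguments shown below are assumptions inferred from the attributes the method reads (`filename`, `task`, `target_name`, ...); check the real dataclass before relying on them:

# Illustrative only: constructor arguments are inferred, not confirmed.
config = FinalizeRegression(filename="model_final.pth", target_name="yield")
trainer.finalize_model_training(
    model_checkpoint="best",   # or "current", or a Path to a .pth checkpoint
    save_dir="artifacts/final",
    finalize_config=config,
)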
- # Object Detection Trainer
- class DragonDetectionTrainer(_BaseDragonTrainer):
-     def __init__(self, model: nn.Module,
-                  train_dataset: Dataset,
-                  validation_dataset: Dataset,
-                  collate_fn: Callable,
-                  optimizer: torch.optim.Optimizer,
-                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
-                  checkpoint_callback: Optional[DragonModelCheckpoint],
-                  early_stopping_callback: Optional[_DragonEarlyStopping],
-                  lr_scheduler_callback: Optional[_DragonLRScheduler],
-                  extra_callbacks: Optional[List[_Callback]] = None,
-                  dataloader_workers: int = 2):
-         """
-         Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).
- 
-         Built-in Callbacks: `History`, `TqdmProgressBar`
- 
-         Args:
-             model (nn.Module): The PyTorch object detection model to train.
-             train_dataset (Dataset): The training dataset.
-             validation_dataset (Dataset): The testing/validation dataset.
-             collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
-             optimizer (torch.optim.Optimizer): The optimizer.
-             device (str): The device to run training on ('cpu', 'cuda', 'mps').
-             checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
-             early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
-             lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
-             extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
-             dataloader_workers (int): Subprocesses for data loading.
- 
-         ## Note:
-         This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
-         """
-         # Call the base class constructor with common parameters
-         super().__init__(
-             model=model,
-             optimizer=optimizer,
-             device=device,
-             dataloader_workers=dataloader_workers,
-             checkpoint_callback=checkpoint_callback,
-             early_stopping_callback=early_stopping_callback,
-             lr_scheduler_callback=lr_scheduler_callback,
-             extra_callbacks=extra_callbacks
-         )
- 
-         self.train_dataset = train_dataset
-         self.validation_dataset = validation_dataset
-         self.kind = MLTaskKeys.OBJECT_DETECTION
-         self.collate_fn = collate_fn
-         self.criterion = None # Criterion is handled inside the model
- 
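A hedged construction sketch; the model, datasets, and dataset-maker instance are placeholders:

# Hypothetical construction of the detection trainer.
trainer = DragonDetectionTrainer(
    model=detection_model,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    collate_fn=maker.collate_fn,  # from ObjectDetectionDatasetMaker
    optimizer=torch.optim.SGD(detection_model.parameters(), lr=5e-3, momentum=0.9),
    device="cuda",
    checkpoint_callback=None,     # or a DragonModelCheckpoint instance
    early_stopping_callback=None,
    lr_scheduler_callback=None,
)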
-     def _create_dataloaders(self, batch_size: int, shuffle: bool):
-         """Initializes the DataLoaders with the object detection collate_fn."""
-         # Ensure stability on MPS devices by setting num_workers to 0
-         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
- 
-         self.train_loader = DataLoader(
-             dataset=self.train_dataset,
-             batch_size=batch_size,
-             shuffle=shuffle,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type),
-             collate_fn=self.collate_fn, # Use the provided collate function
-             drop_last=True
-         )
- 
-         self.validation_loader = DataLoader(
-             dataset=self.validation_dataset,
-             batch_size=batch_size,
-             shuffle=False,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type),
-             collate_fn=self.collate_fn # Use the provided collate function
-         )
- 
-
1512
- def _train_step(self):
1513
- self.model.train()
1514
- running_loss = 0.0
1515
- total_samples = 0
1516
-
1517
- for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
1518
- # images is a tuple of tensors, targets is a tuple of dicts
1519
- batch_size = len(images)
1520
-
1521
- # Create a log dictionary for the batch
1522
- batch_logs = {
1523
- PyTorchLogKeys.BATCH_INDEX: batch_idx,
1524
- PyTorchLogKeys.BATCH_SIZE: batch_size
1525
- }
1526
- self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
1527
-
1528
- # Move data to device
1529
- images = list(img.to(self.device) for img in images)
1530
- targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
1531
-
1532
- self.optimizer.zero_grad()
1533
-
1534
- # Model returns a loss dict when in train() mode and targets are passed
1535
- loss_dict = self.model(images, targets)
1536
-
1537
- if not loss_dict:
1538
- # No losses returned, skip batch
1539
- _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
1540
- batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
1541
- self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
1542
- continue
1543
-
1544
- # Sum all losses
1545
- loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
1546
-
1547
- loss.backward()
1548
- self.optimizer.step()
1549
-
1550
- # Calculate batch loss and update running loss for the epoch
1551
- batch_loss = loss.item()
1552
- running_loss += batch_loss * batch_size
1553
- total_samples += batch_size # <-- Accumulate total samples
1554
-
1555
- # Add the batch loss to the logs and call the end-of-batch hook
1556
- batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
1557
- self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
1558
-
1559
- # Calculate loss using the correct denominator
1560
- if total_samples == 0:
1561
- _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
1562
- return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
1563
-
1564
- return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
1565
-
1566
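For context on the loss handling above: torchvision detection models return a dict of losses when called in train() mode with targets. A standalone demo (the tensors are dummy data):

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(weights=None, num_classes=2)
model.train()
images = [torch.rand(3, 256, 256)]
targets = [{"boxes": torch.tensor([[20.0, 20.0, 80.0, 80.0]]),
            "labels": torch.tensor([1])}]
loss_dict = model(images, targets)    # e.g. loss_classifier, loss_box_reg, ...
total_loss = sum(loss_dict.values())  # summed exactly as in _train_step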
-     def _validation_step(self):
-         self.model.train() # Keep train mode even for validation loss calculation,
-         # as model internals (e.g., proposals) might differ, but we still need the loss_dict.
-         # torch.no_grad() prevents gradient updates.
-         running_loss = 0.0
-         total_samples = 0
- 
-         with torch.no_grad():
-             for images, targets in self.validation_loader: # type: ignore
-                 batch_size = len(images)
- 
-                 # Move data to device
-                 images = list(img.to(self.device) for img in images)
-                 targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]
- 
-                 # Get loss dict
-                 loss_dict = self.model(images, targets)
- 
-                 if not loss_dict:
-                     _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
-                     continue # Skip if no losses
- 
-                 # Sum all losses
-                 loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore
- 
-                 running_loss += loss.item() * batch_size
-                 total_samples += batch_size # Accumulate total samples
- 
-         # Calculate loss using the correct denominator
-         if total_samples == 0:
-             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
-             return {PyTorchLogKeys.VAL_LOSS: 0.0}
- 
-         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
-         return logs
- 
-
1602
- def evaluate(self,
1603
- save_dir: Union[str, Path],
1604
- model_checkpoint: Union[Path, Literal["best", "current"]],
1605
- test_data: Optional[Union[DataLoader, Dataset]] = None):
1606
- """
1607
- Evaluates the model using object detection mAP metrics.
1608
-
1609
- Args:
1610
- save_dir (str | Path): Directory to save all reports and plots.
1611
- model_checkpoint (Path | "best" | "current"):
1612
- - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
1613
- - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
1614
- - If 'current', use the current state of the trained model up the latest trained epoch.
1615
- test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
1616
- """
1617
- # Validate model checkpoint
1618
- if isinstance(model_checkpoint, Path):
1619
- checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
1620
- elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
1621
- checkpoint_validated = model_checkpoint
1622
- else:
1623
- _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
1624
- raise ValueError()
1625
-
1626
- # Validate directory
1627
- save_path = make_fullpath(save_dir, make=True, enforce="directory")
1628
-
1629
- # Validate test data and dispatch
1630
- if test_data is not None:
1631
- if not isinstance(test_data, (DataLoader, Dataset)):
1632
- _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
1633
- raise ValueError()
1634
- test_data_validated = test_data
1635
-
1636
- validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
1637
- test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
1638
-
1639
- # Dispatch validation set
1640
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
1641
- self._evaluate(save_dir=validation_metrics_path,
1642
- model_checkpoint=checkpoint_validated,
1643
- data=None) # 'None' triggers use of self.test_dataset
1644
-
1645
- # Dispatch test set
1646
- _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
1647
- self._evaluate(save_dir=test_metrics_path,
1648
- model_checkpoint="current", # Use 'current' state after loading checkpoint once
1649
- data=test_data_validated)
1650
- else:
1651
- # Dispatch validation set
1652
- _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
1653
- self._evaluate(save_dir=save_path,
1654
- model_checkpoint=checkpoint_validated,
1655
- data=None) # 'None' triggers use of self.test_dataset
1656
-
1657
-     def _evaluate(self,
-                   save_dir: Union[str, Path],
-                   model_checkpoint: Union[Path, Literal["best", "current"]],
-                   data: Optional[Union[DataLoader, Dataset]]):
-         """
-         Private evaluation helper.
- 
-         Evaluates the model using object detection mAP metrics.
- 
-         Args:
-             save_dir (str | Path): Directory to save all reports and plots.
-             model_checkpoint (Path | "best" | "current"):
-                 - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
-                 - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
-                 - If 'current', use the current state of the trained model up to the latest trained epoch.
-             data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal validation dataset.
-         """
-         dataset_for_artifacts = None
-         eval_loader = None
- 
-         # load model checkpoint
-         if isinstance(model_checkpoint, Path):
-             self._load_checkpoint(path=model_checkpoint)
-         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
-             path_to_latest = self._checkpoint_callback.best_checkpoint_path
-             self._load_checkpoint(path_to_latest)
-         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
-             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
-             raise ValueError()
- 
-         # Dataloader
-         if isinstance(data, DataLoader):
-             eval_loader = data
-             if hasattr(data, 'dataset'):
-                 dataset_for_artifacts = data.dataset # type: ignore
-         elif isinstance(data, Dataset):
-             # Create a new loader from the provided dataset
-             eval_loader = DataLoader(data,
-                                      batch_size=self._batch_size,
-                                      shuffle=False,
-                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                                      pin_memory=(self.device.type == "cuda"),
-                                      collate_fn=self.collate_fn)
-             dataset_for_artifacts = data
-         else: # data is None, use the trainer's default validation dataset
-             if self.validation_dataset is None:
-                 _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
-                 raise ValueError()
-             # Create a fresh DataLoader from the validation_dataset
-             eval_loader = DataLoader(
-                 self.validation_dataset,
-                 batch_size=self._batch_size,
-                 shuffle=False,
-                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                 pin_memory=(self.device.type == "cuda"),
-                 collate_fn=self.collate_fn
-             )
-             dataset_for_artifacts = self.validation_dataset
- 
-         if eval_loader is None:
-             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
-             raise ValueError()
- 
-         all_predictions = []
-         all_targets = []
- 
-         self.model.eval() # Set model to evaluation mode
-         self.model.to(self.device)
- 
-         with torch.no_grad():
-             for images, targets in eval_loader:
-                 # Move images to device
-                 images = list(img.to(self.device) for img in images)
- 
-                 # Model returns predictions when in eval() mode
-                 predictions = self.model(images)
- 
-                 # Move predictions and targets to CPU for aggregation
-                 cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
-                 cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]
- 
-                 all_predictions.extend(cpu_preds)
-                 all_targets.extend(cpu_targets)
- 
-         if not all_targets:
-             _LOGGER.error("Evaluation failed: No data was processed.")
-             return
- 
-         # Get class names from the dataset for the report
-         class_names = None
-         try:
-             # Try to get 'classes' from ObjectDetectionDatasetMaker
-             if hasattr(dataset_for_artifacts, 'classes'):
-                 class_names = dataset_for_artifacts.classes # type: ignore
-             # Fallback for Subset
-             elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
-                 class_names = dataset_for_artifacts.dataset.classes # type: ignore
-         except AttributeError:
-             _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
- 
-         # --- Routing Logic ---
-         object_detection_metrics(
-             preds=all_predictions,
-             targets=all_targets,
-             save_dir=save_dir,
-             class_names=class_names,
-             print_output=False
-         )
- 
-
1768
- def finalize_model_training(self,
1769
- save_dir: Union[str, Path],
1770
- model_checkpoint: Union[Path, Literal['best', 'current']],
1771
- finalize_config: FinalizeObjectDetection
1772
- ):
1773
- """
1774
- Saves a finalized, "inference-ready" model state to a .pth file.
1775
-
1776
- This method saves the model's `state_dict` and the final epoch number.
1777
-
1778
- Args:
1779
- save_dir (Union[str, Path]): The directory to save the finalized model.
1780
- model_checkpoint (Union[Path, Literal["best", "current"]]):
1781
- - Path: Loads the model state from a specific checkpoint file.
1782
- - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
1783
- - "current": Uses the model's state as it is.
1784
- finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
1785
- """
1786
- if not isinstance(finalize_config, FinalizeObjectDetection):
1787
- _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
1788
- raise TypeError()
1789
-
1790
- # handle save path
1791
- dir_path = make_fullpath(save_dir, make=True, enforce="directory")
1792
- full_path = dir_path / finalize_config.filename
1793
-
1794
- # handle checkpoint
1795
- self._load_model_state_for_finalizing(model_checkpoint)
1796
-
1797
- # Create finalized data
1798
- finalized_data = {
1799
- PyTorchCheckpointKeys.EPOCH: self.epoch,
1800
- PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
1801
- PyTorchCheckpointKeys.TASK: finalize_config.task
1802
- }
1803
-
1804
- if finalize_config.class_map is not None:
1805
- finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
1806
-
1807
- torch.save(finalized_data, full_path)
1808
-
1809
- _LOGGER.info(f"Finalized model file saved to '{full_path}'")
1810
-
1811
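For context, the finalized file is a plain torch-serialized dict. A hedged sketch of reading it back (the literal key strings live in PyTorchCheckpointKeys and are not spelled out here; `model` and the path are placeholders):

# Illustrative only: attribute names below match the keys written above.
import torch

data = torch.load("artifacts/final/model_final.pth", map_location="cpu")
state_dict = data[PyTorchCheckpointKeys.MODEL_STATE]
epoch = data[PyTorchCheckpointKeys.EPOCH]
model.load_state_dict(state_dict)  # 'model' must match the trained architecture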
- # --- DragonSequenceTrainer ----
- class DragonSequenceTrainer(_BaseDragonTrainer):
-     def __init__(self,
-                  model: nn.Module,
-                  train_dataset: Dataset,
-                  validation_dataset: Dataset,
-                  kind: Literal["sequence-to-sequence", "sequence-to-value"],
-                  optimizer: torch.optim.Optimizer,
-                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
-                  checkpoint_callback: Optional[DragonModelCheckpoint],
-                  early_stopping_callback: Optional[_DragonEarlyStopping],
-                  lr_scheduler_callback: Optional[_DragonLRScheduler],
-                  extra_callbacks: Optional[List[_Callback]] = None,
-                  criterion: Union[nn.Module, Literal["auto"]] = "auto",
-                  dataloader_workers: int = 2):
-         """
-         Automates the training process of a PyTorch Sequence Model.
- 
-         Built-in Callbacks: `History`, `TqdmProgressBar`
- 
-         Args:
-             model (nn.Module): The PyTorch model to train.
-             train_dataset (Dataset): The training dataset.
-             validation_dataset (Dataset): The validation dataset.
-             kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
-             optimizer (torch.optim.Optimizer): The optimizer.
-             device (str): The device to run training on ('cpu', 'cuda', 'mps').
-             checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
-             early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
-             lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
-             extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
-             criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
-             dataloader_workers (int): Subprocesses for data loading.
-         """
-         # Call the base class constructor with common parameters
-         super().__init__(
-             model=model,
-             optimizer=optimizer,
-             device=device,
-             dataloader_workers=dataloader_workers,
-             checkpoint_callback=checkpoint_callback,
-             early_stopping_callback=early_stopping_callback,
-             lr_scheduler_callback=lr_scheduler_callback,
-             extra_callbacks=extra_callbacks
-         )
- 
-         if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
-             raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")
- 
-         self.train_dataset = train_dataset
-         self.validation_dataset = validation_dataset
-         self.kind = kind
- 
-         # try to validate against a Dragon Sequence model
-         if hasattr(self.model, "prediction_mode"):
-             key_to_check: str = self.model.prediction_mode # type: ignore
-             if key_to_check != self.kind:
-                 _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
-                 raise RuntimeError()
- 
-         # loss function
-         if criterion == "auto":
-             # Both sequence tasks are treated as regression problems
-             self.criterion = nn.MSELoss()
-         else:
-             self.criterion = criterion
- 
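A hedged construction sketch; the model and datasets are placeholders:

# Hypothetical construction of the sequence trainer.
trainer = DragonSequenceTrainer(
    model=seq_model,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    kind="sequence-to-value",
    optimizer=torch.optim.Adam(seq_model.parameters(), lr=1e-3),
    device="cuda",
    checkpoint_callback=None,
    early_stopping_callback=None,
    lr_scheduler_callback=None,
    criterion="auto",  # resolves to nn.MSELoss() for both sequence tasks
)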
-     def _create_dataloaders(self, batch_size: int, shuffle: bool):
-         """Initializes the DataLoaders."""
-         # Ensure stability on MPS devices by setting num_workers to 0
-         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
- 
-         self.train_loader = DataLoader(
-             dataset=self.train_dataset,
-             batch_size=batch_size,
-             shuffle=shuffle,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type),
-             drop_last=True # Drops the last batch if incomplete; selecting a good batch size is key.
-         )
- 
-         self.validation_loader = DataLoader(
-             dataset=self.validation_dataset,
-             batch_size=batch_size,
-             shuffle=False,
-             num_workers=loader_workers,
-             pin_memory=("cuda" in self.device.type)
-         )
- 
-     def _train_step(self):
-         self.model.train()
-         running_loss = 0.0
-         total_samples = 0
- 
-         for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
-             # Create a log dictionary for the batch
-             batch_logs = {
-                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
-                 PyTorchLogKeys.BATCH_SIZE: features.size(0)
-             }
-             self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
- 
-             features, target = features.to(self.device), target.to(self.device)
-             self.optimizer.zero_grad()
- 
-             output = self.model(features)
- 
-             # --- Label Type/Shape Correction ---
-             # Ensure target is float for MSELoss
-             target = target.float()
- 
-             # For seq-to-val, models might output [N, 1] but target is [N].
-             if self.kind == MLTaskKeys.SEQUENCE_VALUE:
-                 if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
-                     output = output.squeeze(1)
- 
-             # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
-             elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
-                 if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
-                     output = output.squeeze(-1)
- 
-             loss = self.criterion(output, target)
- 
-             loss.backward()
-             self.optimizer.step()
- 
-             # Calculate batch loss and update running loss for the epoch
-             batch_loss = loss.item()
-             batch_size = features.size(0)
-             running_loss += batch_loss * batch_size # Accumulate total loss
-             total_samples += batch_size # total samples
- 
-             # Add the batch loss to the logs and call the end-of-batch hook
-             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
-             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
- 
-         if total_samples == 0:
-             _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
-             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
- 
-         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples} # type: ignore
- 
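A standalone illustration of the shape correction above: trailing singleton dimensions are squeezed so the criterion sees matching output and target shapes (dummy tensors only):

import torch

out_val = torch.rand(8, 1)      # sequence-to-value output [N, 1]
tgt_val = torch.rand(8)         # target [N]
assert out_val.squeeze(1).shape == tgt_val.shape

out_seq = torch.rand(8, 20, 1)  # sequence-to-sequence output [N, Seq, 1]
tgt_seq = torch.rand(8, 20)     # target [N, Seq]
assert out_seq.squeeze(-1).shape == tgt_seq.shape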
-     def _validation_step(self):
-         self.model.eval()
-         running_loss = 0.0
- 
-         with torch.no_grad():
-             for features, target in self.validation_loader: # type: ignore
-                 features, target = features.to(self.device), target.to(self.device)
- 
-                 output = self.model(features)
- 
-                 # --- Label Type/Shape Correction ---
-                 target = target.float()
- 
-                 # For seq-to-val, models might output [N, 1] but target is [N].
-                 if self.kind == MLTaskKeys.SEQUENCE_VALUE:
-                     if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
-                         output = output.squeeze(1)
- 
-                 # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
-                 elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
-                     if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
-                         output = output.squeeze(-1)
- 
-                 loss = self.criterion(output, target)
- 
-                 running_loss += loss.item() * features.size(0)
- 
-         if len(self.validation_loader.dataset) == 0: # type: ignore
-             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
-             return {PyTorchLogKeys.VAL_LOSS: 0.0}
- 
-         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)} # type: ignore
-         return logs
- 
-     def _predict_for_eval(self, dataloader: DataLoader):
-         """
-         Private method to yield model predictions batch by batch for evaluation.
- 
-         Automatically checks for a target scaler.
- 
-         Yields:
-             tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
-                 y_prob_batch is always None for sequence tasks.
-         """
-         self.model.eval()
-         self.model.to(self.device)
- 
-         # --- Check for Scaler ---
-         # DragonDatasetSequence stores it as 'scaler'
-         scaler = None
-         if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
-             scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
-             if scaler is not None:
-                 _LOGGER.debug("Sequence scaler detected. Un-scaling predictions and targets.")
- 
-         with torch.no_grad():
-             for features, target in dataloader:
-                 features = features.to(self.device)
-                 target = target.to(self.device)
- 
-                 output = self.model(features)
- 
-                 # --- Automatic Un-scaling Logic ---
-                 if scaler:
-                     # 1. Reshape for the scaler: (N, 1) or (N*Seq, 1)
-                     original_out_shape = output.shape
-                     original_target_shape = target.shape
- 
-                     # Flatten sequence dims
-                     output_flat = output.reshape(-1, 1)
-                     target_flat = target.reshape(-1, 1)
- 
-                     # 2. Inverse Transform
-                     output_flat = scaler.inverse_transform(output_flat)
-                     target_flat = scaler.inverse_transform(target_flat)
- 
-                     # 3. Restore original shapes
-                     output = output_flat.reshape(original_out_shape)
-                     target = target_flat.reshape(original_target_shape)
- 
-                 # Move to CPU
-                 y_pred_batch = output.cpu().numpy()
-                 y_true_batch = target.cpu().numpy()
-                 y_prob_batch = None
- 
-                 yield y_pred_batch, y_prob_batch, y_true_batch
- 
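The flatten, inverse-transform, reshape round-trip above, shown standalone with a scikit-learn scaler for illustration (the actual scaler type is whatever the dataset stores; sklearn scalers expect 2-D arrays, hence the reshape):

import numpy as np
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler().fit(np.random.rand(100, 1))  # fitted on the target column
output = np.random.rand(8, 20)                         # e.g. a [N, Seq] prediction block
flat = output.reshape(-1, 1)                           # (N*Seq, 1) for the scaler
unscaled = scaler.inverse_transform(flat)
output = unscaled.reshape(8, 20)                       # restore [N, Seq]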
-     def evaluate(self,
-                  save_dir: Union[str, Path],
-                  model_checkpoint: Union[Path, Literal["best", "current"]],
-                  test_data: Optional[Union[DataLoader, Dataset]] = None,
-                  val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
-                                                           SequenceSequenceMetricsFormat]] = None,
-                  test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
-                                                            SequenceSequenceMetricsFormat]] = None):
-         """
-         Evaluates the model, routing to the correct evaluation function.
- 
-         Args:
-             save_dir (str | Path): Directory to save all reports and plots.
-             model_checkpoint (Path | "best" | "current"):
-                 - Path to a valid checkpoint for the model.
-                 - If 'best', the best checkpoint will be loaded.
-                 - If 'current', use the current state of the trained model.
-             test_data (DataLoader | Dataset | None): Optional test data.
-             val_format_configuration: Optional configuration for validation metrics.
-             test_format_configuration: Optional configuration for test metrics.
-         """
-         # Validate model checkpoint
-         if isinstance(model_checkpoint, Path):
-             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
-         elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
-             checkpoint_validated = model_checkpoint
-         else:
-             _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
-             raise ValueError()
- 
-         # Validate val configuration
-         if val_format_configuration is not None:
-             if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
-                 _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
-                 raise ValueError()
- 
-         # Validate directory
-         save_path = make_fullpath(save_dir, make=True, enforce="directory")
- 
-         # Validate test data and dispatch
-         if test_data is not None:
-             if not isinstance(test_data, (DataLoader, Dataset)):
-                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
-                 raise ValueError()
-             test_data_validated = test_data
- 
-             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
-             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
- 
-             # Dispatch validation set
-             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
-             self._evaluate(save_dir=validation_metrics_path,
-                            model_checkpoint=checkpoint_validated,
-                            data=None,
-                            format_configuration=val_format_configuration)
- 
-             # Validate test configuration
-             test_configuration_validated = None
-             if test_format_configuration is not None:
-                 if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
-                     warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
-                     if val_format_configuration is not None:
-                         warning_message_type += " 'val_format_configuration' will be used."
-                         test_configuration_validated = val_format_configuration
-                     else:
-                         warning_message_type += " Using default format."
-                     _LOGGER.warning(warning_message_type)
-                 else:
-                     test_configuration_validated = test_format_configuration
- 
-             # Dispatch test set
-             _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
-             self._evaluate(save_dir=test_metrics_path,
-                            model_checkpoint="current",
-                            data=test_data_validated,
-                            format_configuration=test_configuration_validated)
-         else:
-             # Dispatch validation set
-             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
-             self._evaluate(save_dir=save_path,
-                            model_checkpoint=checkpoint_validated,
-                            data=None,
-                            format_configuration=val_format_configuration)
- 
-     def _evaluate(self,
-                   save_dir: Union[str, Path],
-                   model_checkpoint: Union[Path, Literal["best", "current"]],
-                   data: Optional[Union[DataLoader, Dataset]],
-                   format_configuration: object):
-         """
-         Private evaluation helper.
-         """
-         eval_loader = None
- 
-         # load model checkpoint
-         if isinstance(model_checkpoint, Path):
-             self._load_checkpoint(path=model_checkpoint)
-         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
-             path_to_latest = self._checkpoint_callback.best_checkpoint_path
-             self._load_checkpoint(path_to_latest)
-         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
-             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
-             raise ValueError()
- 
-         # Dataloader
-         if isinstance(data, DataLoader):
-             eval_loader = data
-         elif isinstance(data, Dataset):
-             # Create a new loader from the provided dataset
-             eval_loader = DataLoader(data,
-                                      batch_size=self._batch_size,
-                                      shuffle=False,
-                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                                      pin_memory=(self.device.type == "cuda"))
-         else: # data is None, use the trainer's default validation dataset
-             if self.validation_dataset is None:
-                 _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
-                 raise ValueError()
-             eval_loader = DataLoader(self.validation_dataset,
-                                      batch_size=self._batch_size,
-                                      shuffle=False,
-                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                                      pin_memory=(self.device.type == "cuda"))
- 
-         if eval_loader is None:
-             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
-             raise ValueError()
- 
-         all_preds, all_true = [], []
-         for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
-             if y_pred_b is not None: all_preds.append(y_pred_b)
-             if y_true_b is not None: all_true.append(y_true_b)
- 
-         if not all_true:
-             _LOGGER.error("Evaluation failed: No data was processed.")
-             return
- 
-         y_pred = np.concatenate(all_preds)
-         y_true = np.concatenate(all_true)
- 
-         # --- Routing Logic ---
-         if self.kind == MLTaskKeys.SEQUENCE_VALUE:
-             config = None
-             if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
-                 config = format_configuration
-             elif format_configuration:
-                 _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")
- 
-             sequence_to_value_metrics(y_true=y_true,
-                                       y_pred=y_pred,
-                                       save_dir=save_dir,
-                                       config=config)
- 
-         elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
-             config = None
-             if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
-                 config = format_configuration
-             elif format_configuration:
-                 _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")
- 
-             sequence_to_sequence_metrics(y_true=y_true,
-                                          y_pred=y_pred,
-                                          save_dir=save_dir,
-                                          config=config)
- 
-     def explain_captum(self,
-                        save_dir: Union[str, Path],
-                        explain_dataset: Optional[Dataset] = None,
-                        n_samples: int = 100,
-                        feature_names: Optional[List[str]] = None,
-                        target_names: Optional[List[str]] = None,
-                        n_steps: int = 50):
-         """
-         Explains sequence model predictions using Captum's Integrated Gradients.
- 
-         This method calculates global feature importance by aggregating attributions across
-         the time dimension.
-         - For **multivariate** sequences, it highlights which variables (channels) are most influential.
-         - For **univariate** sequences, it attributes importance to the single signal feature.
- 
-         Args:
-             save_dir (str | Path): Directory to save the importance plots and CSV reports.
-             explain_dataset (Dataset | None): A specific dataset to sample from. If None, the
-                 trainer's validation dataset is used.
-             n_samples (int): The number of samples to use for the explanation (background + inputs).
-             feature_names (List[str] | None): Names of the features (signals). If None, attempts to extract them from the dataset attribute.
-             target_names (List[str] | None): Names of the model outputs (e.g., for Seq2Seq or multivariate output). If None, attempts to extract them from the dataset attribute.
-             n_steps (int): Number of integral approximation steps.
- 
-         Note:
-             For univariate data (shape: N, Seq_Len), the 'feature' is the signal itself.
-         """
-         if not _is_captum_available():
-             _LOGGER.error("Captum is not installed.")
-             return
- 
-         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
-         if dataset_to_use is None:
-             _LOGGER.error("No dataset available for explanation.")
-             return
- 
-         # Helper to sample data (same as DragonTrainer)
-         def _get_samples(ds, n):
-             loader = DataLoader(ds, batch_size=n, shuffle=True, num_workers=0)
-             data_iter = iter(loader)
-             features, targets = next(data_iter)
-             return features, targets
- 
-         input_data, _ = _get_samples(dataset_to_use, n_samples)
- 
-         if feature_names is None:
-             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
-                 feature_names = dataset_to_use.feature_names # type: ignore
-             else:
-                 # If retrieval fails, leave it as None.
-                 _LOGGER.warning("'feature_names' not provided and not found in dataset. Generic names will be used.")
- 
-         if target_names is None:
-             if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
-                 target_names = dataset_to_use.target_names # type: ignore
-             else:
-                 # If retrieval fails, leave it as None.
-                 _LOGGER.warning("'target_names' not provided and not found in dataset. Generic names will be used.")
- 
-         # Sequence models usually output [N, 1] (Value) or [N, Seq, 1] (Seq2Seq);
-         # captum_feature_importance handles the aggregation.
-         captum_feature_importance(
-             model=self.model,
-             input_data=input_data,
-             feature_names=feature_names,
-             save_dir=save_dir,
-             target_names=target_names,
-             n_steps=n_steps,
-             device=self.device
-         )
- 
-     def finalize_model_training(self,
-                                 save_dir: Union[str, Path],
-                                 model_checkpoint: Union[Path, Literal['best', 'current']],
-                                 finalize_config: Union[FinalizeSequenceSequencePrediction, FinalizeSequenceValuePrediction]):
-         """
-         Saves a finalized, "inference-ready" model state to a .pth file.
- 
-         This method saves the model's `state_dict` and the final epoch number.
- 
-         Args:
-             save_dir (Union[str, Path]): The directory to save the finalized model.
-             model_checkpoint (Union[Path, Literal["best", "current"]]):
-                 - Path: Loads the model state from a specific checkpoint file.
-                 - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
-                 - "current": Uses the model's state as it is.
-             finalize_config (FinalizeSequenceSequencePrediction | FinalizeSequenceValuePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.
-         """
-         if self.kind == MLTaskKeys.SEQUENCE_SEQUENCE and not isinstance(finalize_config, FinalizeSequenceSequencePrediction):
-             _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
-             raise TypeError()
-         elif self.kind == MLTaskKeys.SEQUENCE_VALUE and not isinstance(finalize_config, FinalizeSequenceValuePrediction):
-             _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
-             raise TypeError()
- 
-         # handle save path
-         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-         full_path = dir_path / finalize_config.filename
- 
-         # handle checkpoint
-         self._load_model_state_for_finalizing(model_checkpoint)
- 
-         # Create finalized data
-         finalized_data = {
-             PyTorchCheckpointKeys.EPOCH: self.epoch,
-             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
-             PyTorchCheckpointKeys.TASK: finalize_config.task
-         }
- 
-         if finalize_config.sequence_length is not None:
-             finalized_data[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
-         if finalize_config.initial_sequence is not None:
-             finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence
- 
-         torch.save(finalized_data, full_path)
- 
-         _LOGGER.info(f"Finalized model file saved to '{full_path}'")
- 
- 
- def info():
-     _script_info(__all__)