dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,402 @@
1
+ from typing import Literal, Union, Optional, Callable
2
+ from pathlib import Path
3
+ from torch.utils.data import DataLoader, Dataset
4
+ import torch
5
+ from torch import nn
6
+
7
+ from ..ML_callbacks._base import _Callback
8
+ from ..ML_callbacks._checkpoint import DragonModelCheckpoint
9
+ from ..ML_callbacks._early_stop import _DragonEarlyStopping
10
+ from ..ML_callbacks._scheduler import _DragonLRScheduler
11
+ from ..ML_evaluation import object_detection_metrics
12
+ from ..ML_configuration import FinalizeObjectDetection
13
+
14
+ from ..path_manager import make_fullpath
15
+ from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
16
+ from .._core import get_logger
17
+
18
+ from ._base_trainer import _BaseDragonTrainer
19
+
20
+
21
# Module-level logger dedicated to this trainer.
_LOGGER = get_logger("DragonDetectionTrainer")

# Explicit public API of this module.
__all__ = ["DragonDetectionTrainer"]
27
+
28
+
29
+ # Object Detection Trainer
30
# Object Detection Trainer
class DragonDetectionTrainer(_BaseDragonTrainer):
    """
    Trainer specialized for object detection models (e.g. Faster R-CNN style
    models) whose forward pass returns a loss dictionary during training and a
    list of prediction dictionaries during evaluation.

    Unlike the generic trainers, no `criterion` is accepted: the loss is
    computed inside the model itself (`self.criterion` is set to None).
    """
    def __init__(self, model: nn.Module,
                 train_dataset: Dataset,
                 validation_dataset: Dataset,
                 collate_fn: Callable,
                 optimizer: torch.optim.Optimizer,
                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
                 checkpoint_callback: Optional[DragonModelCheckpoint],
                 early_stopping_callback: Optional[_DragonEarlyStopping],
                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                 extra_callbacks: Optional[list[_Callback]] = None,
                 dataloader_workers: int = 2):
        """
        Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).

        Built-in Callbacks: `History`, `TqdmProgressBar`

        Args:
            model (nn.Module): The PyTorch object detection model to train.
            train_dataset (Dataset): The training dataset.
            validation_dataset (Dataset): The validation dataset.
            collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`;
                required because detection batches are variable-sized (tuples of images and target dicts).
            optimizer (torch.optim.Optimizer): The optimizer.
            device (str): The device to run training on ('cpu', 'cuda', 'mps').
            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
            extra_callbacks (list[_Callback] | None): Extra callbacks to use during training.
            dataloader_workers (int): Subprocesses for data loading.

        ## Note:
        This trainer is specialized. It does not take a `criterion` because object
        detection models like Faster R-CNN return a dictionary of losses directly
        from their forward pass during training.
        """
        # Call the base class constructor with common parameters
        super().__init__(
            model=model,
            optimizer=optimizer,
            device=device,
            dataloader_workers=dataloader_workers,
            checkpoint_callback=checkpoint_callback,
            early_stopping_callback=early_stopping_callback,
            lr_scheduler_callback=lr_scheduler_callback,
            extra_callbacks=extra_callbacks
        )

        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.kind = MLTaskKeys.OBJECT_DETECTION
        self.collate_fn = collate_fn
        self.criterion = None  # Loss computation is handled inside the model itself

    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        """Initializes the train/validation DataLoaders with the object detection collate_fn."""
        # Ensure stability on MPS devices by setting num_workers to 0
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

        self.train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            collate_fn=self.collate_fn,  # Use the provided collate function
            drop_last=True  # drop ragged last batch during training
        )

        # Validation loader: never shuffled, keeps the last partial batch
        self.validation_loader = DataLoader(
            dataset=self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            collate_fn=self.collate_fn  # Use the provided collate function
        )

    def _train_step(self):
        """
        Runs one training epoch and returns {TRAIN_LOSS: mean sample loss}.

        The per-batch loss is the sum of all entries in the model's loss dict;
        the epoch loss is weighted by batch size so uneven batches average correctly.
        """
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        for batch_idx, (images, targets) in enumerate(self.train_loader): # type: ignore
            # images is a tuple of tensors, targets is a tuple of dicts
            batch_size = len(images)

            # Create a log dictionary for the batch
            batch_logs = {
                PyTorchLogKeys.BATCH_INDEX: batch_idx,
                PyTorchLogKeys.BATCH_SIZE: batch_size
            }
            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

            # Move data to device
            images = list(img.to(self.device) for img in images)
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

            self.optimizer.zero_grad()

            # Model returns a loss dict when in train() mode and targets are passed
            loss_dict = self.model(images, targets)

            if not loss_dict:
                # No losses returned: log a zero loss so callbacks still fire, then skip
                _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
                batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
                self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
                continue

            # Sum all losses into a single scalar for backprop
            loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore

            loss.backward()
            self.optimizer.step()

            # Calculate batch loss and update running loss for the epoch
            batch_loss = loss.item()
            running_loss += batch_loss * batch_size
            total_samples += batch_size

            # Add the batch loss to the logs and call the end-of-batch hook
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss # type: ignore
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

        # Guard against an empty epoch (e.g. every batch skipped)
        if total_samples == 0:
            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}

    def _validation_step(self):
        """
        Computes the validation loss for one epoch and returns {VAL_LOSS: mean sample loss}.

        NOTE: the model is deliberately kept in train() mode here — detection models
        only return a loss dict in train mode; eval mode would return predictions
        instead. torch.no_grad() still prevents any gradient updates.
        """
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for images, targets in self.validation_loader: # type: ignore
                batch_size = len(images)

                # Move data to device
                images = list(img.to(self.device) for img in images)
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

                # Get loss dict
                loss_dict = self.model(images, targets)

                if not loss_dict:
                    _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
                    continue  # Skip if no losses

                # Sum all losses
                loss: torch.Tensor = sum(l for l in loss_dict.values()) # type: ignore

                running_loss += loss.item() * batch_size
                total_samples += batch_size

        # Guard against an empty validation pass
        if total_samples == 0:
            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
            return {PyTorchLogKeys.VAL_LOSS: 0.0}

        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
        return logs

    def evaluate(self,
                 save_dir: Union[str, Path],
                 model_checkpoint: Union[Path, Literal["best", "current"]],
                 test_data: Optional[Union[DataLoader, Dataset]] = None):
        """
        Evaluates the model using object detection mAP metrics.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            model_checkpoint (Path | "best" | "current"):
                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
                - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
                - If 'current', use the current state of the trained model up to the latest trained epoch.
            test_data (DataLoader | Dataset | None): Optional test data to evaluate model
                performance on. When given, validation and test metrics are saved to
                separate subdirectories of `save_dir`.

        Raises:
            ValueError: If `model_checkpoint` or `test_data` is of an unsupported type.
        """
        # Validate model checkpoint argument before touching the filesystem
        if isinstance(model_checkpoint, Path):
            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
            checkpoint_validated = model_checkpoint
        else:
            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
            raise ValueError()

        # Validate / create the output directory
        save_path = make_fullpath(save_dir, make=True, enforce="directory")

        # Validate test data and dispatch
        if test_data is not None:
            if not isinstance(test_data, (DataLoader, Dataset)):
                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
                raise ValueError()
            test_data_validated = test_data

            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

            # Dispatch validation set first (this is where the checkpoint gets loaded)
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
            self._evaluate(save_dir=validation_metrics_path,
                           model_checkpoint=checkpoint_validated,
                           data=None)  # 'None' triggers use of the trainer's validation dataset

            # Dispatch test set using the already-loaded model state
            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
            self._evaluate(save_dir=test_metrics_path,
                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
                           data=test_data_validated)
        else:
            # Only the validation set: metrics go directly into save_dir
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
            self._evaluate(save_dir=save_path,
                           model_checkpoint=checkpoint_validated,
                           data=None)  # 'None' triggers use of the trainer's validation dataset

    def _evaluate(self,
                  save_dir: Union[str, Path],
                  model_checkpoint: Union[Path, Literal["best", "current"]],
                  data: Optional[Union[DataLoader, Dataset]]):
        """
        Private helper that runs a single evaluation pass and writes mAP metrics.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            model_checkpoint (Path | "best" | "current"):
                - Path: loads that checkpoint into the model, overwriting its state in place.
                - "best": loads the best checkpoint tracked by DragonModelCheckpoint (raises if none was configured).
                - "current": uses the model state as-is.
            data (DataLoader | Dataset | None): The data to evaluate on. If None,
                defaults to the trainer's internal validation dataset.
        """
        dataset_for_artifacts = None
        eval_loader = None

        # Load the requested model checkpoint, if any ("current" falls through untouched)
        if isinstance(model_checkpoint, Path):
            self._load_checkpoint(path=model_checkpoint)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
            path_to_latest = self._checkpoint_callback.best_checkpoint_path
            self._load_checkpoint(path_to_latest)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
            raise ValueError()

        # Resolve the evaluation DataLoader from whichever form of data was given
        if isinstance(data, DataLoader):
            eval_loader = data
            if hasattr(data, 'dataset'):
                dataset_for_artifacts = data.dataset # type: ignore
        elif isinstance(data, Dataset):
            # Create a new loader from the provided dataset
            eval_loader = DataLoader(data,
                                     batch_size=self._batch_size,
                                     shuffle=False,
                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                     pin_memory=(self.device.type == "cuda"),
                                     collate_fn=self.collate_fn)
            dataset_for_artifacts = data
        else:  # data is None, use the trainer's validation dataset
            if self.validation_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                raise ValueError()
            # Create a fresh DataLoader from the validation dataset
            eval_loader = DataLoader(
                self.validation_dataset,
                batch_size=self._batch_size,
                shuffle=False,
                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                pin_memory=(self.device.type == "cuda"),
                collate_fn=self.collate_fn
            )
            dataset_for_artifacts = self.validation_dataset

        if eval_loader is None:
            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
            raise ValueError()

        all_predictions = []
        all_targets = []

        self.model.eval()  # eval() mode makes the detection model return predictions, not losses
        self.model.to(self.device)

        with torch.no_grad():
            for images, targets in eval_loader:
                # Move images to device; targets stay on CPU for aggregation
                images = list(img.to(self.device) for img in images)

                # Model returns predictions when in eval() mode
                predictions = self.model(images)

                # Move predictions and targets to CPU for aggregation
                cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
                cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]

                all_predictions.extend(cpu_preds)
                all_targets.extend(cpu_targets)

        if not all_targets:
            _LOGGER.error("Evaluation failed: No data was processed.")
            return

        # Get class names from the dataset for the report (best effort)
        class_names = None
        try:
            # Try to get 'classes' from ObjectDetectionDatasetMaker
            if hasattr(dataset_for_artifacts, 'classes'):
                class_names = dataset_for_artifacts.classes # type: ignore
            # Fallback for torch.utils.data.Subset, which wraps the real dataset
            elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
                class_names = dataset_for_artifacts.dataset.classes # type: ignore
        except AttributeError:
            _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
            pass  # class_names is still None

        # Compute and persist the object detection metrics report
        object_detection_metrics(
            preds=all_predictions,
            targets=all_targets,
            save_dir=save_dir,
            class_names=class_names,
            print_output=False
        )

    def finalize_model_training(self,
                                save_dir: Union[str, Path],
                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                finalize_config: FinalizeObjectDetection
                                ):
        """
        Saves a finalized, "inference-ready" model state to a .pth file.

        The saved payload contains the model's `state_dict`, the final epoch number,
        the task identifier, and (when present) the class map from `finalize_config`.

        Args:
            save_dir (str | Path): The directory to save the finalized model.
            model_checkpoint (Path | "best" | "current"):
                - Path: Loads the model state from a specific checkpoint file.
                - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
                - "current": Uses the model's state as it is.
            finalize_config (FinalizeObjectDetection): Task-specific metadata required
                for inference (filename, task identifier, optional class map).

        Raises:
            TypeError: If `finalize_config` is not a `FinalizeObjectDetection` instance.
        """
        if not isinstance(finalize_config, FinalizeObjectDetection):
            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
            raise TypeError()

        # Resolve (and create if needed) the output directory
        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
        full_path = dir_path / finalize_config.filename

        # Load the requested checkpoint into self.model (delegated to the base trainer)
        self._load_model_state_for_finalizing(model_checkpoint)

        # Assemble the inference payload
        finalized_data = {
            PyTorchCheckpointKeys.EPOCH: self.epoch,
            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
            PyTorchCheckpointKeys.TASK: finalize_config.task
        }

        # Class map is optional metadata
        if finalize_config.class_map is not None:
            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map

        torch.save(finalized_data, full_path)

        _LOGGER.info(f"Finalized model file saved to '{full_path}'")