dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_trainer/_dragon_trainer.py (new file)
@@ -0,0 +1,1160 @@
+ from typing import Literal, Union, Optional
+ from pathlib import Path
+ from torch.utils.data import DataLoader, Dataset
+ import torch
+ from torch import nn
+ import numpy as np
+
+ from ..ML_callbacks._base import _Callback
+ from ..ML_callbacks._checkpoint import DragonModelCheckpoint
+ from ..ML_callbacks._early_stop import _DragonEarlyStopping
+ from ..ML_callbacks._scheduler import _DragonLRScheduler
+ from ..ML_evaluation import classification_metrics, regression_metrics, shap_summary_plot, plot_attention_importance
+ from ..ML_evaluation import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
+ from ..ML_evaluation import segmentation_metrics
+ from ..ML_evaluation_captum import captum_feature_importance, captum_segmentation_heatmap, captum_image_heatmap
+ from ..ML_configuration import (FormatRegressionMetrics,
+                                 FormatMultiTargetRegressionMetrics,
+                                 FormatBinaryClassificationMetrics,
+                                 FormatMultiClassClassificationMetrics,
+                                 FormatBinaryImageClassificationMetrics,
+                                 FormatMultiClassImageClassificationMetrics,
+                                 FormatMultiLabelBinaryClassificationMetrics,
+                                 FormatBinarySegmentationMetrics,
+                                 FormatMultiClassSegmentationMetrics,
+
+                                 FinalizeBinaryClassification,
+                                 FinalizeBinarySegmentation,
+                                 FinalizeBinaryImageClassification,
+                                 FinalizeMultiClassClassification,
+                                 FinalizeMultiClassImageClassification,
+                                 FinalizeMultiClassSegmentation,
+                                 FinalizeMultiLabelBinaryClassification,
+                                 FinalizeMultiTargetRegression,
+                                 FinalizeRegression)
+
+ from ..path_manager import make_fullpath
+ from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys, ScalerKeys
+ from .._core import get_logger
+
+ from ._base_trainer import _BaseDragonTrainer
+
+
+ _LOGGER = get_logger("DragonTrainer")
+
+
+ __all__ = [
+     "DragonTrainer",
+ ]
+
+
+ # --- DragonTrainer ----
+ class DragonTrainer(_BaseDragonTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  train_dataset: Dataset,
+                  validation_dataset: Dataset,
+                  kind: Literal["regression",
+                                "binary classification",
+                                "multiclass classification",
+                                "multitarget regression",
+                                "multilabel binary classification",
+                                "binary segmentation",
+                                "multiclass segmentation",
+                                "binary image classification",
+                                "multiclass image classification"],
+                  optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'], str],
+                  checkpoint_callback: Optional[DragonModelCheckpoint],
+                  early_stopping_callback: Optional[_DragonEarlyStopping],
+                  lr_scheduler_callback: Optional[_DragonLRScheduler],
+                  extra_callbacks: Optional[list[_Callback]] = None,
+                  criterion: Union[nn.Module, Literal["auto"]] = "auto",
+                  dataloader_workers: int = 2):
+         """
+         Automates the training process of a PyTorch model.
+
+         Built-in Callbacks: `History`, `TqdmProgressBar`
+
+         Args:
+             model (nn.Module): The PyTorch model to train.
+             train_dataset (Dataset): The training dataset.
+             validation_dataset (Dataset): The validation dataset.
+             kind (str): The task type; it routes loss selection, output-shape handling, and evaluation.
+             criterion (nn.Module | "auto"): The loss function to use. If "auto", it is inferred from the selected task.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading.
+             extra_callbacks (list[_Callback] | None): A list of extra callbacks to use during training.
+
+         Note:
+             - For **regression** and **multitarget regression** tasks, suggested criteria include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
+
+             - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
+
+             - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
+
+             - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice, as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
+
+             - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
+
+             - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
+         """
+         # Call the base class constructor with common parameters
+         super().__init__(
+             model=model,
+             optimizer=optimizer,
+             device=device,
+             dataloader_workers=dataloader_workers,
+             checkpoint_callback=checkpoint_callback,
+             early_stopping_callback=early_stopping_callback,
+             lr_scheduler_callback=lr_scheduler_callback,
+             extra_callbacks=extra_callbacks
+         )
+
+         if kind not in [MLTaskKeys.REGRESSION,
+                         MLTaskKeys.BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                         MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
+                         MLTaskKeys.MULTITARGET_REGRESSION,
+                         MLTaskKeys.BINARY_SEGMENTATION,
+                         MLTaskKeys.MULTICLASS_SEGMENTATION,
+                         MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                         MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             raise ValueError(f"'{kind}' is not a valid task type.")
+
+         self.train_dataset = train_dataset
+         self.validation_dataset = validation_dataset
+         self.kind = kind
+         self._classification_threshold: float = 0.5
+
+         # loss function
+         if criterion == "auto":
+             if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+                 self.criterion = nn.MSELoss()
+             elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+                 self.criterion = nn.BCEWithLogitsLoss()
+             elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+                 self.criterion = nn.CrossEntropyLoss()
+         else:
+             self.criterion = criterion
+
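A minimal construction sketch for reference (illustrative only; the in-memory datasets are hypothetical stand-ins, and with criterion="auto" a "regression" task resolves to nn.MSELoss() as shown above):

    import torch
    from torch import nn
    from torch.utils.data import TensorDataset

    # Hypothetical data: 16 features -> 1 regression target
    X, y = torch.randn(256, 16), torch.randn(256)
    train_ds = TensorDataset(X[:200], y[:200])
    val_ds = TensorDataset(X[200:], y[200:])

    model = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 1))
    trainer = DragonTrainer(
        model=model,
        train_dataset=train_ds,
        validation_dataset=val_ds,
        kind="regression",
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
        device="cpu",
        checkpoint_callback=None,
        early_stopping_callback=None,
        lr_scheduler_callback=None,
        criterion="auto",  # inferred as nn.MSELoss() for "regression"
    )
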
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type),
+             drop_last=True  # Drops the last batch if incomplete; choosing a good batch size matters.
+         )
+
+         self.validation_loader = DataLoader(
+             dataset=self.validation_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=("cuda" in self.device.type)
+         )
+
+     def _train_step(self):
+         self.model.train()
+         running_loss = 0.0
+         total_samples = 0
+
+         for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
+             # Create a log dictionary for the batch
+             batch_logs = {
+                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
+                 PyTorchLogKeys.BATCH_SIZE: features.size(0)
+             }
+             self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+             features, target = features.to(self.device), target.to(self.device)
+             self.optimizer.zero_grad()
+
+             output = self.model(features)
+
+             # --- Label Type/Shape Correction ---
+             # Cast target to float for BCE-based losses
+             if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                 target = target.float()
+
+             # Reshape output to match target for single-logit tasks
+             if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                 # If model outputs [N, 1] and target is [N], squeeze output
+                 if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                     output = output.squeeze(1)
+
+             if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                 # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                 if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                     output = output.squeeze(1)
+
+             loss = self.criterion(output, target)
+
+             loss.backward()
+             self.optimizer.step()
+
+             # Calculate batch loss and update running loss for the epoch
+             batch_loss = loss.item()
+             batch_size = features.size(0)
+             running_loss += batch_loss * batch_size  # Accumulate total loss
+             total_samples += batch_size  # total samples
+
+             # Add the batch loss to the logs and call the end-of-batch hook
+             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
+             self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+         if total_samples == 0:
+             _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
+             return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
+
+         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
+
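The accumulation above produces a sample-weighted epoch loss, sum(batch_loss_i * n_i) / sum(n_i). With drop_last=True the training batches are equal-sized, but the same pattern in _validation_step below keeps the average correct when the final batch is short. A standalone arithmetic check:

    # (mean batch loss, batch size); the last batch is smaller
    batches = [(0.5, 32), (0.4, 32), (0.9, 8)]
    weighted = sum(loss * n for loss, n in batches) / sum(n for _, n in batches)
    unweighted = sum(loss for loss, _ in batches) / len(batches)
    print(weighted)    # 0.5  -> (16.0 + 12.8 + 7.2) / 72
    print(unweighted)  # 0.6  -> over-counts the small batch
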
+     def _validation_step(self):
+         self.model.eval()
+         running_loss = 0.0
+
+         with torch.no_grad():
+             for features, target in self.validation_loader:  # type: ignore
+                 features, target = features.to(self.device), target.to(self.device)
+
+                 output = self.model(features)
+
+                 # --- Label Type/Shape Correction ---
+                 # Cast target to float for BCE-based losses
+                 if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
+                     target = target.float()
+
+                 # Reshape output to match target for single-logit tasks
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     # If model outputs [N, 1] and target is [N], squeeze output
+                     if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
+                         output = output.squeeze(1)
+
+                 if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
+                     if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
+                         output = output.squeeze(1)
+
+                 loss = self.criterion(output, target)
+
+                 running_loss += loss.item() * features.size(0)
+
+         if not self.validation_loader.dataset:  # type: ignore
+             _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
+             return {PyTorchLogKeys.VAL_LOSS: 0.0}
+
+         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
+         return logs
+
+     def _predict_for_eval(self, dataloader: DataLoader):
+         """
+         Private method to yield model predictions batch by batch for evaluation.
+
+         Automatically detects if `target_scaler` is present in the training dataset
+         and applies inverse transformation for Regression tasks.
+
+         Yields:
+             tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+
+             - y_prob_batch is None for regression tasks.
+         """
+         self.model.eval()
+         self.model.to(self.device)
+
+         # --- Check for Target Scaler (for Regression Un-scaling) ---
+         target_scaler = None
+         if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+             # Try to get the scaler from the dataset attached to the trainer
+             if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
+                 target_scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
+             if target_scaler is not None:
+                 _LOGGER.debug("Target scaler detected. Un-scaling predictions and targets for metric calculation.")
+
+         with torch.no_grad():
+             for features, target in dataloader:
+                 features = features.to(self.device)
+                 # Keep target on device initially for potential un-scaling
+                 target = target.to(self.device)
+
+                 output = self.model(features)
+
+                 y_pred_batch = None
+                 y_prob_batch = None
+                 y_true_batch = None
+
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
+
+                     # --- Automatic Un-scaling Logic ---
+                     if target_scaler:
+                         # 1. Reshape output/target if flattened (common in single regression)
+                         # Scaler expects [N, Features]
+                         original_out_shape = output.shape
+                         original_target_shape = target.shape
+
+                         if output.ndim == 1: output = output.reshape(-1, 1)
+                         if target.ndim == 1: target = target.reshape(-1, 1)
+
+                         # 2. Apply Inverse Transform
+                         output = target_scaler.inverse_transform(output)
+                         target = target_scaler.inverse_transform(target)
+
+                         # 3. Restore shapes (optional, but good for consistency)
+                         if len(original_out_shape) == 1: output = output.flatten()
+                         if len(original_target_shape) == 1: target = target.flatten()
+
+                     y_pred_batch = output.cpu().numpy()
+                     y_true_batch = target.cpu().numpy()
+
+                 elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     if output.ndim == 2 and output.shape[1] == 1:
+                         output = output.squeeze(1)
+
+                     probs_pos = torch.sigmoid(output)
+                     preds = (probs_pos >= self._classification_threshold).int()
+                     y_pred_batch = preds.cpu().numpy()
+
+                     probs_neg = 1.0 - probs_pos
+                     y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).cpu().numpy()
+                     y_true_batch = target.cpu().numpy()
+
+                 elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+                     probs = torch.softmax(output, dim=1)
+                     preds = torch.argmax(probs, dim=1)
+                     y_pred_batch = preds.cpu().numpy()
+                     y_prob_batch = probs.cpu().numpy()
+                     y_true_batch = target.cpu().numpy()
+
+                 elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
+                     probs = torch.sigmoid(output)
+                     preds = (probs >= self._classification_threshold).int()
+                     y_pred_batch = preds.cpu().numpy()
+                     y_prob_batch = probs.cpu().numpy()
+                     y_true_batch = target.cpu().numpy()
+
+                 elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     probs_pos = torch.sigmoid(output)
+                     preds = (probs_pos >= self._classification_threshold).int()
+                     y_pred_batch = preds.squeeze(1).cpu().numpy()
+
+                     probs_neg = 1.0 - probs_pos
+                     y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).cpu().numpy()
+
+                     if target.ndim == 4 and target.shape[1] == 1:
+                         target = target.squeeze(1)
+                     y_true_batch = target.cpu().numpy()
+
+                 elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
+                     probs = torch.softmax(output, dim=1)
+                     preds = torch.argmax(probs, dim=1)
+                     y_pred_batch = preds.cpu().numpy()
+                     y_prob_batch = probs.cpu().numpy()
+
+                     if target.ndim == 4 and target.shape[1] == 1:
+                         target = target.squeeze(1)
+                     y_true_batch = target.cpu().numpy()
+
+                 yield y_pred_batch, y_prob_batch, y_true_batch
+
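As a small illustration of the binary branches above: a single sigmoid probability per sample is thresholded into a label and expanded into a two-column [P(negative), P(positive)] array, the y_prob layout produced by this generator:

    import torch

    logits = torch.tensor([0.3, -1.2, 2.0])  # one logit per sample
    p_pos = torch.sigmoid(logits)
    preds = (p_pos >= 0.5).int()             # classification threshold
    probs = torch.stack([1.0 - p_pos, p_pos], dim=1)
    print(preds.tolist())  # [1, 0, 1]
    print(probs.shape)     # torch.Size([3, 2]); each row sums to 1
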
+     def evaluate(self,
+                  save_dir: Union[str, Path],
+                  model_checkpoint: Union[Path, Literal["best", "current"]],
+                  classification_threshold: Optional[float] = None,
+                  test_data: Optional[Union[DataLoader, Dataset]] = None,
+                  val_format_configuration: Optional[Union[
+                      FormatRegressionMetrics,
+                      FormatMultiTargetRegressionMetrics,
+                      FormatBinaryClassificationMetrics,
+                      FormatMultiClassClassificationMetrics,
+                      FormatBinaryImageClassificationMetrics,
+                      FormatMultiClassImageClassificationMetrics,
+                      FormatMultiLabelBinaryClassificationMetrics,
+                      FormatBinarySegmentationMetrics,
+                      FormatMultiClassSegmentationMetrics
+                  ]]=None,
+                  test_format_configuration: Optional[Union[
+                      FormatRegressionMetrics,
+                      FormatMultiTargetRegressionMetrics,
+                      FormatBinaryClassificationMetrics,
+                      FormatMultiClassClassificationMetrics,
+                      FormatBinaryImageClassificationMetrics,
+                      FormatMultiClassImageClassificationMetrics,
+                      FormatMultiLabelBinaryClassificationMetrics,
+                      FormatBinarySegmentationMetrics,
+                      FormatMultiClassSegmentationMetrics,
+                  ]]=None):
+         """
+         Evaluates the model, routing to the correct evaluation function based on the task `kind`.
+
+         Args:
+             model_checkpoint (Path | "best" | "current"):
+                 - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
+                 - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
+                 - If 'current', use the current state of the trained model up to the latest trained epoch.
+             save_dir (str | Path): Directory to save all reports and plots.
+             classification_threshold (float | None): Required for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification).
+             test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
+             val_format_configuration (object): Optional configuration for the metric format output of the validation set.
+             test_format_configuration (object): Optional configuration for the metric format output of the test set.
+         """
+         # Validate model checkpoint
+         if isinstance(model_checkpoint, Path):
+             checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
+         elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
+             checkpoint_validated = model_checkpoint
+         else:
+             _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
+             raise ValueError()
+
+         # Validate classification threshold
+         if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
+             # dummy value for tasks that do not need it
+             threshold_validated = 0.5
+         elif classification_threshold is None:
+             # it should have been provided for binary tasks
+             _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
+             raise ValueError()
+         elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
+             # Invalid float
+             _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. It must lie in the open interval (0.0, 1.0).")
+             raise ValueError()
+         else:
+             threshold_validated = classification_threshold
+
+         # Validate val configuration
+         if val_format_configuration is not None:
+             if not isinstance(val_format_configuration, (FormatRegressionMetrics,
+                                                          FormatMultiTargetRegressionMetrics,
+                                                          FormatBinaryClassificationMetrics,
+                                                          FormatMultiClassClassificationMetrics,
+                                                          FormatBinaryImageClassificationMetrics,
+                                                          FormatMultiClassImageClassificationMetrics,
+                                                          FormatMultiLabelBinaryClassificationMetrics,
+                                                          FormatBinarySegmentationMetrics,
+                                                          FormatMultiClassSegmentationMetrics)):
+                 _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
+                 raise ValueError()
+             else:
+                 val_configuration_validated = val_format_configuration
+         else:  # config is None
+             val_configuration_validated = None
+
+         # Validate directory
+         save_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+         # Validate test data and dispatch
+         if test_data is not None:
+             if not isinstance(test_data, (DataLoader, Dataset)):
+                 _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
+                 raise ValueError()
+             test_data_validated = test_data
+
+             validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
+             test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
+
+             # Dispatch validation set
+             _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
+             self._evaluate(save_dir=validation_metrics_path,
+                            model_checkpoint=checkpoint_validated,
+                            classification_threshold=threshold_validated,
+                            data=None,
+                            format_configuration=val_configuration_validated)
+
+             # Validate test configuration
+             if test_format_configuration is not None:
+                 if not isinstance(test_format_configuration, (FormatRegressionMetrics,
+                                                               FormatMultiTargetRegressionMetrics,
+                                                               FormatBinaryClassificationMetrics,
+                                                               FormatMultiClassClassificationMetrics,
+                                                               FormatBinaryImageClassificationMetrics,
+                                                               FormatMultiClassImageClassificationMetrics,
+                                                               FormatMultiLabelBinaryClassificationMetrics,
+                                                               FormatBinarySegmentationMetrics,
+                                                               FormatMultiClassSegmentationMetrics)):
+                     warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
+                     if val_configuration_validated is not None:
+                         warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
+                         test_configuration_validated = val_configuration_validated
+                     else:
+                         warning_message_type += " Using default format."
+                         test_configuration_validated = None
+                     _LOGGER.warning(warning_message_type)
+                 else:
+                     test_configuration_validated = test_format_configuration
+             else:  # config is None
+                 test_configuration_validated = None
+
+             # Dispatch test set
+             _LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
+             self._evaluate(save_dir=test_metrics_path,
+                            model_checkpoint="current",
+                            classification_threshold=threshold_validated,
+                            data=test_data_validated,
+                            format_configuration=test_configuration_validated)
+         else:
+             # Dispatch validation set
+             _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
+             self._evaluate(save_dir=save_path,
+                            model_checkpoint=checkpoint_validated,
+                            classification_threshold=threshold_validated,
+                            data=None,
+                            format_configuration=val_configuration_validated)
+
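A call sketch for the method above (illustrative; the directory and test dataset are hypothetical). For binary-style tasks the threshold is mandatory; pure regression tasks may pass None:

    trainer.evaluate(
        save_dir="artifacts/evaluation",
        model_checkpoint="best",   # requires a DragonModelCheckpoint callback
        classification_threshold=0.5,
        test_data=test_ds,         # optional; writes test metrics to a subdirectory
    )
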
+     def _evaluate(self,
+                   save_dir: Union[str, Path],
+                   model_checkpoint: Union[Path, Literal["best", "current"]],
+                   classification_threshold: float,
+                   data: Optional[Union[DataLoader, Dataset]],
+                   format_configuration: Optional[Union[
+                       FormatRegressionMetrics,
+                       FormatMultiTargetRegressionMetrics,
+                       FormatBinaryClassificationMetrics,
+                       FormatMultiClassClassificationMetrics,
+                       FormatBinaryImageClassificationMetrics,
+                       FormatMultiClassImageClassificationMetrics,
+                       FormatMultiLabelBinaryClassificationMetrics,
+                       FormatBinarySegmentationMetrics,
+                       FormatMultiClassSegmentationMetrics
+                   ]]=None):
+         """
+         Private evaluation helper shared by `evaluate()` for the validation and test runs.
+         """
+         dataset_for_artifacts = None
+         eval_loader = None
+
+         # set threshold
+         self._classification_threshold = classification_threshold
+
+         # load model checkpoint
+         if isinstance(model_checkpoint, Path):
+             self._load_checkpoint(path=model_checkpoint)
+         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
+             path_to_latest = self._checkpoint_callback.best_checkpoint_path
+             self._load_checkpoint(path_to_latest)
+         elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
+             _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
+             raise ValueError()
+
+         # Dataloader
+         if isinstance(data, DataLoader):
+             eval_loader = data
+             # Try to get the dataset from the loader for fetching target names
+             if hasattr(data, 'dataset'):
+                 dataset_for_artifacts = data.dataset  # type: ignore
+         elif isinstance(data, Dataset):
+             # Create a new loader from the provided dataset
+             eval_loader = DataLoader(data,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"))
+             dataset_for_artifacts = data
+         else:  # data is None, use the trainer's validation dataset
+             if self.validation_dataset is None:
+                 _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
+                 raise ValueError()
+             # Create a fresh DataLoader from the validation dataset
+             eval_loader = DataLoader(self.validation_dataset,
+                                      batch_size=self._batch_size,
+                                      shuffle=False,
+                                      num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                      pin_memory=(self.device.type == "cuda"))
+
+             dataset_for_artifacts = self.validation_dataset
+
+         if eval_loader is None:
+             _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
+             raise ValueError()
+
+         # print("\n--- Model Evaluation ---")
+
+         all_preds, all_probs, all_true = [], [], []
+         for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
+             if y_pred_b is not None: all_preds.append(y_pred_b)
+             if y_prob_b is not None: all_probs.append(y_prob_b)
+             if y_true_b is not None: all_true.append(y_true_b)
+
+         if not all_true:
+             _LOGGER.error("Evaluation failed: No data was processed.")
+             return
+
+         y_pred = np.concatenate(all_preds)
+         y_true = np.concatenate(all_true)
+         y_prob = np.concatenate(all_probs) if all_probs else None
+
+         # --- Routing Logic ---
+         # Single-target regression
+         if self.kind == MLTaskKeys.REGRESSION:
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, FormatRegressionMetrics):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             regression_metrics(y_true=y_true.flatten(),
+                                y_pred=y_pred.flatten(),
+                                save_dir=save_dir,
+                                config=config)
+
+         # single-target classification
+         elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
+                            MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                            MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                            MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             # get the class map if it exists
+             try:
+                 class_map = dataset_for_artifacts.class_map  # type: ignore
+             except AttributeError:
+                 _LOGGER.warning("Dataset has no 'class_map' attribute. Using generic names.")
+                 class_map = None
+             else:
+                 if not isinstance(class_map, dict):
+                     _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
+                     class_map = None
+
+             # Check configuration
+             config = None
+             if format_configuration:
+                 if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, FormatBinaryClassificationMetrics):
+                     config = format_configuration
+                 elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, FormatBinaryImageClassificationMetrics):
+                     config = format_configuration
+                 elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, FormatMultiClassClassificationMetrics):
+                     config = format_configuration
+                 elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, FormatMultiClassImageClassificationMetrics):
+                     config = format_configuration
+                 else:
+                     _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             classification_metrics(save_dir=save_dir,
+                                    y_true=y_true,
+                                    y_pred=y_pred,
+                                    y_prob=y_prob,
+                                    class_map=class_map,
+                                    config=config)
+
+         # multitarget regression
+         elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
+             try:
+                 target_names = dataset_for_artifacts.target_names  # type: ignore
+             except AttributeError:
+                 num_targets = y_true.shape[1]
+                 target_names = [f"target_{i}" for i in range(num_targets)]
+                 _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
+
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, FormatMultiTargetRegressionMetrics):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             multi_target_regression_metrics(y_true=y_true,
+                                             y_pred=y_pred,
+                                             target_names=target_names,
+                                             save_dir=save_dir,
+                                             config=config)
+
+         # multi-label binary classification
+         elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
+             try:
+                 target_names = dataset_for_artifacts.target_names  # type: ignore
+             except AttributeError:
+                 num_targets = y_true.shape[1]
+                 target_names = [f"label_{i}" for i in range(num_targets)]
+                 _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
+
+             if y_prob is None:
+                 _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
+                 return
+
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, FormatMultiLabelBinaryClassificationMetrics):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             multi_label_classification_metrics(y_true=y_true,
+                                                y_pred=y_pred,
+                                                y_prob=y_prob,
+                                                target_names=target_names,
+                                                save_dir=save_dir,
+                                                config=config)
+
+         # Segmentation tasks
+         elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
+             class_names = None
+             try:
+                 # Try to get 'classes' from VisionDatasetMaker
+                 if hasattr(dataset_for_artifacts, 'classes'):
+                     class_names = dataset_for_artifacts.classes  # type: ignore
+                 # Fallback for Subset
+                 elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'):  # type: ignore
+                     class_names = dataset_for_artifacts.dataset.classes  # type: ignore
+             except AttributeError:
+                 pass  # class_names is still None
+
+             if class_names is None:
+                 try:
+                     # Fallback to 'target_names'
+                     class_names = dataset_for_artifacts.target_names  # type: ignore
+                 except AttributeError:
+                     # Fallback to inferring from labels
+                     labels = np.unique(y_true)
+                     class_names = [f"Class {i}" for i in labels]
+                     _LOGGER.warning("Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
+
+             # Check configuration
+             config = None
+             if format_configuration and isinstance(format_configuration, (FormatBinarySegmentationMetrics, FormatMultiClassSegmentationMetrics)):
+                 config = format_configuration
+             elif format_configuration:
+                 _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
+
+             segmentation_metrics(y_true=y_true,
+                                  y_pred=y_pred,
+                                  save_dir=save_dir,
+                                  class_names=class_names,
+                                  config=config)
+
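A sketch of the optional dataset attributes the routing above probes for; the attribute names mirror the accesses in this method, while the class itself is a hypothetical example:

    from torch.utils.data import Dataset

    class AnnotatedDataset(Dataset):
        class_map = {"cat": 0, "dog": 1}        # single-target classification
        target_names = ["t1", "t2"]             # multitarget / multilabel tasks
        classes = ["background", "foreground"]  # segmentation tasks

        def __len__(self): ...
        def __getitem__(self, idx): ...
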
+     def explain_shap(self,
+                      save_dir: Union[str, Path],
+                      explain_dataset: Optional[Dataset] = None,
+                      n_samples: int = 300,
+                      feature_names: Optional[list[str]] = None,
+                      target_names: Optional[list[str]] = None,
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
+         """
+         Explains model predictions using SHAP and saves all artifacts.
+
+         NOTE: SHAP support is limited to single-target tasks (Regression, Binary/Multiclass Classification).
+         For complex tasks (Multi-target, Multi-label, Sequences, Images), please use `explain_captum()`.
+
+         The background data is automatically sampled from the trainer's training dataset.
+
+         This method automatically routes to the appropriate SHAP summary plot
+         function based on the task. If `feature_names` or `target_names` (multi-target) are not provided,
+         it will attempt to extract them from the dataset.
+
+         Args:
+             explain_dataset (Dataset | None): A specific dataset to explain.
+                 If None, the trainer's validation dataset is used.
+             n_samples (int): The number of samples to use for both background and explanation.
+             feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset; an error is raised on failure.
+             target_names (list[str] | None): Target names for multi-target tasks.
+             save_dir (str | Path): Directory to save all SHAP artifacts.
+             explainer_type (Literal['deep', 'kernel']): The explainer to use.
+                 - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
+                 - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' (< 100).
+         """
+         # --- 1. Compatibility Guard ---
+         valid_shap_tasks = [
+             MLTaskKeys.REGRESSION,
+             MLTaskKeys.BINARY_CLASSIFICATION,
+             MLTaskKeys.MULTICLASS_CLASSIFICATION
+         ]
+
+         if self.kind not in valid_shap_tasks:
+             _LOGGER.warning(f"SHAP explanation is deprecated for task '{self.kind}' due to instability. Please use 'explain_captum()' instead.")
+             return
+
+         # memory-efficient helper
+         def _get_random_sample(dataset: Dataset, num_samples: int):
+             """
+             Memory-efficiently samples data from a dataset.
+             """
+             if dataset is None:
+                 return None
+
+             dataset_len = len(dataset)  # type: ignore
+             if dataset_len == 0:
+                 return None
+
+             # For MPS devices, num_workers must be 0 to ensure stability
+             loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+             # Ensure batch_size is not larger than the dataset itself
+             batch_size = min(num_samples, 64, dataset_len)
+
+             loader = DataLoader(
+                 dataset,
+                 batch_size=batch_size,
+                 shuffle=True,  # Shuffle to get random samples
+                 num_workers=loader_workers
+             )
+
+             collected_features = []
+             num_collected = 0
+
+             for features, _ in loader:
+                 collected_features.append(features)
+                 num_collected += features.size(0)
+                 if num_collected >= num_samples:
+                     break  # Stop once we have enough samples
+
+             if not collected_features:
+                 return None
+
+             full_data = torch.cat(collected_features, dim=0)
+
+             # If we collected more than needed, trim it down
+             if full_data.size(0) > num_samples:
+                 return full_data[:num_samples]
+
+             return full_data
+
+         # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+
+         # 1. Get background data from the trainer's train_dataset
+         background_data = _get_random_sample(self.train_dataset, n_samples)
+         if background_data is None:
+             _LOGGER.error("Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
+             return
+
+         # 2. Determine target dataset and get explanation instances
+         target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
+         instances_to_explain = _get_random_sample(target_dataset, n_samples)
+         if instances_to_explain is None:
+             _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
+             return
+
+         # attempt to get feature names
+         if feature_names is None:
+             # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
+             if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
+                 feature_names = target_dataset.feature_names  # type: ignore
+             else:
+                 _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
+                 raise ValueError()
+
+         # move model to device
+         self.model.to(self.device)
+
+         # 3. Call the plotting function
+         if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
+             shap_summary_plot(
+                 model=self.model,
+                 background_data=background_data,
+                 instances_to_explain=instances_to_explain,
+                 feature_names=feature_names,
+                 save_dir=save_dir,
+                 explainer_type=explainer_type,
+                 device=self.device
+             )
+         # DEPRECATED: Multi-target SHAP support is unstable; recommend Captum instead.
+         elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+             # try to get target names
+             if target_names is None:
+                 target_names = []
+                 if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
+                     target_names = target_dataset.target_names  # type: ignore
+                 else:
+                     # Infer number of targets from the model's output layer
+                     try:
+                         num_targets = self.model.output_layer.out_features  # type: ignore
+                         target_names = [f"target_{i}" for i in range(num_targets)]  # type: ignore
+                         _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
+                     except AttributeError:
+                         _LOGGER.error("Cannot determine target names for multi-target SHAP plot. Skipping.")
+                         return
+
+             multi_target_shap_summary_plot(
+                 model=self.model,
+                 background_data=background_data,
+                 instances_to_explain=instances_to_explain,
+                 feature_names=feature_names,  # type: ignore
+                 target_names=target_names,  # type: ignore
+                 save_dir=save_dir,
+                 explainer_type=explainer_type,
+                 device=self.device
+             )
+
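A call sketch (illustrative). Per the docstring above, the kernel explainer is model-agnostic but very slow, so it is paired here with a small n_samples; feature names are pulled from the dataset when omitted:

    trainer.explain_shap(
        save_dir="artifacts/shap",
        n_samples=64,              # keep low for explainer_type='kernel'
        explainer_type="kernel",
    )
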
+     def explain_captum(self,
+                        save_dir: Union[str, Path],
+                        explain_dataset: Optional[Dataset] = None,
+                        n_samples: int = 100,
+                        feature_names: Optional[list[str]] = None,
+                        target_names: Optional[list[str]] = None,
+                        n_steps: int = 50):
+         """
+         Explains model predictions using Captum's Integrated Gradients.
+
+         - **Tabular/Classification:** Generates feature importance bar charts.
+         - **Segmentation:** Generates spatial heatmaps for each class.
+
+         Args:
+             save_dir (str | Path): Directory to save artifacts.
+             explain_dataset (Dataset | None): Dataset to sample from. Defaults to the validation set.
+             n_samples (int): Number of samples to evaluate.
+             feature_names (list[str] | None): Feature names.
+                 - Required for tabular tasks.
+                 - Ignored/optional for image tasks (defaults to channel names).
+             target_names (list[str] | None): Names for the model outputs (or class names).
+                 - If None, attempts to extract from dataset attributes (`target_names`, `classes`, or `class_map`).
+                 - If extraction fails, generates generic names (e.g. "Output_0").
+             n_steps (int): Number of interpolation steps.
+         """
+         # 1. Prepare Data
+         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
+         if dataset_to_use is None:
+             _LOGGER.error("No dataset available for explanation.")
+             return
+
+         # Efficient sampling helper
+         def _get_samples(ds, n):
+             # Use num_workers=0 for stability during ad-hoc sampling
+             loader = DataLoader(ds, batch_size=n, shuffle=True, num_workers=0)
+             data_iter = iter(loader)
+             features, targets = next(data_iter)
+             return features, targets
+
+         input_data, _ = _get_samples(dataset_to_use, n_samples)
+
+         # 2. Get Feature Names (only if NOT segmentation AND NOT image classification)
+         # Image tasks generally don't have explicit feature names; Captum will default to "Channel_X"
+         is_segmentation = self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]
+         is_image_classification = self.kind in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]
+
+         if feature_names is None and not is_segmentation and not is_image_classification:
+             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
+                 feature_names = dataset_to_use.feature_names  # type: ignore
+             else:
+                 _LOGGER.error("Could not extract `feature_names`. It must be provided if the dataset does not have it.")
+                 raise ValueError()
+
+         # 3. Handle Target Names (or Class Names)
+         if target_names is None:
+             # A. Try dataset attributes first
+             if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
+                 target_names = dataset_to_use.target_names  # type: ignore
+             elif hasattr(dataset_to_use, "classes"):
+                 target_names = dataset_to_use.classes  # type: ignore
+             elif hasattr(dataset_to_use, "class_map") and isinstance(dataset_to_use.class_map, dict):  # type: ignore
+                 # Sort by value (index) to ensure correct order: {name: index} -> [name_at_0, name_at_1...]
+                 sorted_items = sorted(dataset_to_use.class_map.items(), key=lambda item: item[1])  # type: ignore
+                 target_names = [k for k, v in sorted_items]
+
+             # B. Infer based on task
+             if target_names is None:
+                 if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
+                     target_names = ["Output"]
+                 elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
+                     target_names = ["Foreground"]
+
+                 # For multiclass/multitarget without names, leave it None and let the evaluation function generate generics.
+
+         # 4. Dispatch based on Task
+         if is_segmentation:
+             # lower n_steps for segmentation to save memory
+             if n_steps > 30:
+                 n_steps = 30
+                 _LOGGER.warning(f"Segmentation task detected: Reducing Captum n_steps to {n_steps} to prevent OOM. If you encounter OOM errors, consider lowering this further.")
+
+             captum_segmentation_heatmap(
+                 model=self.model,
+                 input_data=input_data,
+                 save_dir=save_dir,
+                 target_names=target_names,  # Can be None, the helper handles it
+                 n_steps=n_steps,
+                 device=self.device
+             )
+
+         elif is_image_classification:
+             captum_image_heatmap(
+                 model=self.model,
+                 input_data=input_data,
+                 save_dir=save_dir,
+                 target_names=target_names,
+                 n_steps=n_steps,
+                 device=self.device
+             )
+
+         else:
+             # Standard tabular tasks (feature importance)
+             captum_feature_importance(
+                 model=self.model,
+                 input_data=input_data,
+                 feature_names=feature_names,
+                 save_dir=save_dir,
+                 target_names=target_names,
+                 n_steps=n_steps,
+                 device=self.device
+             )
+
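A call sketch (illustrative); as implemented above, n_steps is capped at 30 (with a warning) for segmentation tasks, and target names fall back to the dataset's `target_names`, `classes`, or `class_map` attributes:

    trainer.explain_captum(
        save_dir="artifacts/captum",
        n_samples=100,
        n_steps=50,
    )
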
992
+ def _attention_helper(self, dataloader: DataLoader):
993
+ """
994
+ Private method to yield model attention weights batch by batch for evaluation.
995
+
996
+ Args:
997
+ dataloader (DataLoader): The dataloader to predict on.
998
+
999
+ Yields:
1000
+ (torch.Tensor): Attention weights
1001
+ """
1002
+ self.model.eval()
1003
+ self.model.to(self.device)
1004
+
1005
+ with torch.no_grad():
1006
+ for features, target in dataloader:
1007
+ features = features.to(self.device)
1008
+ attention_weights = None
1009
+
1010
+ # Get model output
1011
+ # Unpack logits and weights from the special forward method
1012
+ _output, attention_weights = self.model.forward_attention(features) # type: ignore
1013
+
1014
+ if attention_weights is not None:
1015
+ attention_weights = attention_weights.cpu()
1016
+
1017
+ yield attention_weights
1018
+
+     def explain_attention(self, save_dir: Union[str, Path],
+                           feature_names: Optional[list[str]] = None,
+                           explain_dataset: Optional[Dataset] = None,
+                           plot_n_features: int = 10):
+         """
+         Generates and saves a feature importance plot based on attention weights.
+
+         This method only works for models that expose 'has_interpretable_attention'.
+
+         Args:
+             save_dir (str | Path): Directory to save the plot and summary data.
+             feature_names (list[str] | None): Feature names for plot labeling. If None, the names are extracted from the dataset; an error is raised if that fails.
+             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's validation dataset is used.
+             plot_n_features (int): Number of top features to plot.
+         """
+         # --- Step 1: Check if the model supports this explanation ---
+         if not getattr(self.model, 'has_interpretable_attention', False):
+             _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
+             return
+
+         # --- Step 2: Set up the dataloader ---
+         dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
+         if not isinstance(dataset_to_use, Dataset):
+             _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
+             return
+
+         # Get feature names
+         if feature_names is None:
+             if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
+                 feature_names = dataset_to_use.feature_names # type: ignore
+             else:
+                 _LOGGER.error(f"Could not extract `feature_names` from the dataset for the attention plot. They must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
+                 raise ValueError()
+
+         explain_loader = DataLoader(
+             dataset=dataset_to_use, batch_size=32, shuffle=False,
+             num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+             pin_memory=("cuda" in self.device.type)
+         )
+
+         # --- Step 3: Collect weights ---
+         all_weights = []
+         for att_weights_b in self._attention_helper(explain_loader):
+             if att_weights_b is not None:
+                 all_weights.append(att_weights_b)
+
+         # --- Step 4: Call the plotting function ---
+         if all_weights:
+             plot_attention_importance(
+                 weights=all_weights,
+                 feature_names=feature_names,
+                 save_dir=save_dir,
+                 top_n=plot_n_features
+             )
+         else:
+             _LOGGER.error("No attention weights were collected from the model.")
+
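Assuming `trainer` is an instance of this class with a trained, attention-capable model, a call might look like the following (the path is a placeholder):

```python
# Hypothetical usage; relies only on the signature defined above.
trainer.explain_attention(
    save_dir="results/attention",   # plot and summary data are written here
    feature_names=None,             # fall back to the dataset's feature names
    plot_n_features=10,             # bar plot shows the top 10 features
)
```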
+     def finalize_model_training(self,
+                                 model_checkpoint: Union[Path, Literal['best', 'current']],
+                                 save_dir: Union[str, Path],
+                                 finalize_config: Union[FinalizeRegression,
+                                                        FinalizeMultiTargetRegression,
+                                                        FinalizeBinaryClassification,
+                                                        FinalizeBinaryImageClassification,
+                                                        FinalizeMultiClassClassification,
+                                                        FinalizeMultiClassImageClassification,
+                                                        FinalizeBinarySegmentation,
+                                                        FinalizeMultiClassSegmentation,
+                                                        FinalizeMultiLabelBinaryClassification]):
+         """
+         Saves a finalized, "inference-ready" model state to a .pth file.
+
+         This method saves the model's `state_dict`, the final epoch number, and optional task-specific configuration.
+
+         Args:
+             model_checkpoint (Path | "best" | "current"):
+                 - Path: Loads the model state from a specific checkpoint file.
+                 - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
+                 - "current": Uses the model's state as it is.
+             save_dir (str | Path): The directory to save the finalized model.
+             finalize_config (object): A data class instance specific to the ML task, containing the metadata required for inference.
+         """
+         # Validate that the config type matches the trainer's task
+         expected_config_types = {
+             MLTaskKeys.REGRESSION: FinalizeRegression,
+             MLTaskKeys.MULTITARGET_REGRESSION: FinalizeMultiTargetRegression,
+             MLTaskKeys.BINARY_CLASSIFICATION: FinalizeBinaryClassification,
+             MLTaskKeys.BINARY_IMAGE_CLASSIFICATION: FinalizeBinaryImageClassification,
+             MLTaskKeys.MULTICLASS_CLASSIFICATION: FinalizeMultiClassClassification,
+             MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION: FinalizeMultiClassImageClassification,
+             MLTaskKeys.BINARY_SEGMENTATION: FinalizeBinarySegmentation,
+             MLTaskKeys.MULTICLASS_SEGMENTATION: FinalizeMultiClassSegmentation,
+             MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION: FinalizeMultiLabelBinaryClassification,
+         }
+         expected_type = expected_config_types.get(self.kind)
+         if expected_type is not None and not isinstance(finalize_config, expected_type):
+             _LOGGER.error(f"For task {self.kind}, expected finalize_config of type '{expected_type.__name__}', but got {type(finalize_config).__name__}.")
+             raise TypeError()
+
+         # Handle save path
+         dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+         full_path = dir_path / finalize_config.filename
+
+         # Handle checkpoint
+         self._load_model_state_for_finalizing(model_checkpoint)
+
+         # Create finalized data
+         finalized_data = {
+             PyTorchCheckpointKeys.EPOCH: self.epoch,
+             PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
+             PyTorchCheckpointKeys.TASK: finalize_config.task
+         }
+
+         # Parse optional config fields
+         if finalize_config.target_name is not None:
+             finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
+         if finalize_config.target_names is not None:
+             finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
+         if finalize_config.classification_threshold is not None:
+             finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
+         if finalize_config.class_map is not None:
+             finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map
+
+         # Save model file
+         torch.save(finalized_data, full_path)
+
+         _LOGGER.info(f"Finalized model file saved to '{full_path}'")
+
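As a usage sketch: the constructor arguments for `FinalizeBinaryClassification` below are assumptions inferred from the attributes this method reads (`filename`, `target_name`, `classification_threshold`), not a confirmed signature:

```python
from pathlib import Path
import torch

# Hypothetical config construction; the real constructor may differ.
config = FinalizeBinaryClassification(
    filename="model_final.pth",
    target_name="label",
    classification_threshold=0.5,
)

trainer.finalize_model_training(
    model_checkpoint="best",       # use the DragonModelCheckpoint best state
    save_dir="artifacts",
    finalize_config=config,
)

# The result is a plain torch checkpoint that can be inspected directly.
state = torch.load(Path("artifacts") / "model_final.pth", map_location="cpu")
print(state.keys())  # epoch, model state, task, plus any optional metadata
```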