dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1909
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
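Most of this release is a mechanical restructuring: each flat `ml_tools/<name>.py` module becomes a `ml_tools/<name>/` package whose implementation lives in private `_*.py` submodules, and the old `ml_tools/_core/*` mirror modules are removed. A minimal sketch of what the rename listing implies for imports, assuming each new `__init__.py` re-exports the same public names as the old flat module (inferred from the file layout above, not verified against the 20.0.0 API):

```python
# 19.14.0: ml_tools/ML_trainer.py was a flat module backed by ml_tools/_core/_ML_trainer.py.
# 20.0.0: ml_tools/ML_trainer/ is a package; if its __init__.py re-exports the
# trainer classes (as the split into _dragon_trainer.py, _dragon_detection_trainer.py,
# and _dragon_sequence_trainer.py suggests), the public import path is unchanged:
from ml_tools.ML_trainer import DragonTrainer  # assumed stable across 19.14.0 -> 20.0.0

# Implementation modules are now private package members
# (e.g. ml_tools/ML_trainer/_dragon_trainer.py); importing them directly,
# like the removed ml_tools._core._ML_trainer, remains unsupported.
```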
ml_tools/_core/_ML_trainer.py
DELETED
@@ -1,2323 +0,0 @@
-from typing import List, Literal, Union, Optional, Callable, Dict, Any
-from pathlib import Path
-from torch.utils.data import DataLoader, Dataset
-import torch
-from torch import nn
-import numpy as np
-from abc import ABC, abstractmethod
-
-from ._path_manager import make_fullpath
-from ._ML_callbacks import _Callback, History, TqdmProgressBar, DragonModelCheckpoint, _DragonEarlyStopping, _DragonLRScheduler
-from ._ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
-from ._ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
-from ._ML_vision_evaluation import segmentation_metrics, object_detection_metrics
-from ._ML_sequence_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
-from ._ML_evaluation_captum import captum_feature_importance, _is_captum_available, captum_segmentation_heatmap, captum_image_heatmap
-from ._ML_configuration import (RegressionMetricsFormat,
-                                MultiTargetRegressionMetricsFormat,
-                                BinaryClassificationMetricsFormat,
-                                MultiClassClassificationMetricsFormat,
-                                BinaryImageClassificationMetricsFormat,
-                                MultiClassImageClassificationMetricsFormat,
-                                MultiLabelBinaryClassificationMetricsFormat,
-                                BinarySegmentationMetricsFormat,
-                                MultiClassSegmentationMetricsFormat,
-                                SequenceValueMetricsFormat,
-                                SequenceSequenceMetricsFormat,
-
-                                FinalizeBinaryClassification,
-                                FinalizeBinarySegmentation,
-                                FinalizeBinaryImageClassification,
-                                FinalizeMultiClassClassification,
-                                FinalizeMultiClassImageClassification,
-                                FinalizeMultiClassSegmentation,
-                                FinalizeMultiLabelBinaryClassification,
-                                FinalizeMultiTargetRegression,
-                                FinalizeRegression,
-                                FinalizeObjectDetection,
-                                FinalizeSequenceSequencePrediction,
-                                FinalizeSequenceValuePrediction)
-
-from ._script_info import _script_info
-from ._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys, SequenceDatasetKeys, ScalerKeys
-from ._logger import get_logger
-
-
-_LOGGER = get_logger("DragonTrainer")
-
-
-__all__ = [
-    "DragonTrainer",
-    "DragonDetectionTrainer",
-    "DragonSequenceTrainer"
-]
-
-class _BaseDragonTrainer(ABC):
-    """
-    Abstract base class for Dragon Trainers.
-
-    Handles the common training loop orchestration, checkpointing, callback
-    management, and device handling. Subclasses must implement the
-    task-specific logic (dataloaders, train/val steps, evaluation).
-    """
-    def __init__(self,
-                 model: nn.Module,
-                 optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
-                 dataloader_workers: int = 2,
-                 checkpoint_callback: Optional[DragonModelCheckpoint] = None,
-                 early_stopping_callback: Optional[_DragonEarlyStopping] = None,
-                 lr_scheduler_callback: Optional[_DragonLRScheduler] = None,
-                 extra_callbacks: Optional[List[_Callback]] = None):
-
-        self.model = model
-        self.optimizer = optimizer
-        self.scheduler = None
-        self.device = self._validate_device(device)
-        self.dataloader_workers = dataloader_workers
-
-        # Callback handler
-        default_callbacks = [History(), TqdmProgressBar()]
-
-        self._checkpoint_callback = None
-        if checkpoint_callback:
-            default_callbacks.append(checkpoint_callback)
-            self._checkpoint_callback = checkpoint_callback
-        if early_stopping_callback:
-            default_callbacks.append(early_stopping_callback)
-        if lr_scheduler_callback:
-            default_callbacks.append(lr_scheduler_callback)
-
-        user_callbacks = extra_callbacks if extra_callbacks is not None else []
-        self.callbacks = default_callbacks + user_callbacks
-        self._set_trainer_on_callbacks()
-
-        # Internal state
-        self.train_loader: Optional[DataLoader] = None
-        self.validation_loader: Optional[DataLoader] = None
-        self.history: Dict[str, List[Any]] = {}
-        self.epoch = 0
-        self.epochs = 0  # Total epochs for the fit run
-        self.start_epoch = 1
-        self.stop_training = False
-        self._batch_size = 10
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device = "cpu"
-        return torch.device(device)
-
-    def _set_trainer_on_callbacks(self):
-        """Gives each callback a reference to this trainer instance."""
-        for callback in self.callbacks:
-            callback.set_trainer(self)
-
-    def _load_checkpoint(self, path: Union[str, Path]):
-        """Loads a training checkpoint to resume training."""
-        p = make_fullpath(path, enforce="file")
-        _LOGGER.info(f"Loading checkpoint from '{p.name}'...")
-
-        try:
-            checkpoint = torch.load(p, map_location=self.device)
-
-            if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
-                _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
-                raise KeyError()
-
-            self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
-            self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
-            self.epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0)
-            self.start_epoch = self.epoch + 1  # Resume on the *next* epoch
-
-            # --- Load History ---
-            if PyTorchCheckpointKeys.HISTORY in checkpoint:
-                self.history = checkpoint[PyTorchCheckpointKeys.HISTORY]
-                _LOGGER.info(f"Restored training history up to epoch {self.epoch}.")
-            else:
-                _LOGGER.warning("No 'history' found in checkpoint. A new history will be started.")
-                self.history = {}  # Ensure it's at least an empty dict
-
-            # --- Scheduler State Loading Logic ---
-            scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
-            scheduler_object_exists = self.scheduler is not None
-
-            if scheduler_object_exists and scheduler_state_exists:
-                # Case 1: Both exist. Attempt to load.
-                try:
-                    self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE])  # type: ignore
-                    scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
-                except Exception as e:
-                    # Loading failed, likely a mismatch
-                    scheduler_name = self.scheduler.__class__.__name__
-                    _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
-                    raise e
-
-            elif scheduler_object_exists and not scheduler_state_exists:
-                # Case 2: Scheduler provided, but no state in checkpoint.
-                scheduler_name = self.scheduler.__class__.__name__
-                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
-
-            elif not scheduler_object_exists and scheduler_state_exists:
-                # Case 3: State in checkpoint, but no scheduler provided.
-                _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
-                raise ValueError()
-
-            # Restore callback states
-            for cb in self.callbacks:
-                if isinstance(cb, DragonModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
-                    cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
-                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
-
-            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
-
-        except Exception as e:
-            _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
-            raise
-
-    def fit(self,
-            save_dir: Union[str, Path],
-            epochs: int = 100,
-            batch_size: int = 10,
-            shuffle: bool = True,
-            resume_from_checkpoint: Optional[Union[str, Path]] = None):
-        """
-        Starts the training-validation process of the model.
-
-        Returns the "History" callback dictionary.
-
-        Args:
-            save_dir (str | Path): Directory to save the loss plot.
-            epochs (int): The total number of epochs to train for.
-            batch_size (int): The number of samples per batch.
-            shuffle (bool): Whether to shuffle the training data at each epoch.
-            resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
-        """
-        self.epochs = epochs
-        self._batch_size = batch_size
-        self._create_dataloaders(self._batch_size, shuffle)  # type: ignore
-        self.model.to(self.device)
-
-        if resume_from_checkpoint:
-            self._load_checkpoint(resume_from_checkpoint)
-
-        # Reset stop_training flag on the trainer
-        self.stop_training = False
-
-        self._callbacks_hook('on_train_begin')
-
-        if not self.train_loader:
-            _LOGGER.error("Train loader is not initialized.")
-            raise ValueError()
-
-        if not self.validation_loader:
-            _LOGGER.error("Validation loader is not initialized.")
-            raise ValueError()
-
-        for epoch in range(self.start_epoch, self.epochs + 1):
-            self.epoch = epoch
-            epoch_logs: Dict[str, Any] = {}
-            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
-
-            train_logs = self._train_step()
-            epoch_logs.update(train_logs)
-
-            val_logs = self._validation_step()
-            epoch_logs.update(val_logs)
-
-            self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
-
-            # Check the early stopping flag
-            if self.stop_training:
-                break
-
-        self._callbacks_hook('on_train_end')
-
-        # Training History
-        plot_losses(self.history, save_dir=save_dir)
-
-        return self.history
-
-    def _callbacks_hook(self, method_name: str, *args, **kwargs):
-        """Calls the specified method on all callbacks."""
-        for callback in self.callbacks:
-            method = getattr(callback, method_name)
-            method(*args, **kwargs)
-
-    def to_cpu(self):
-        """
-        Moves the model to the CPU and updates the trainer's device setting.
-
-        This is useful for running operations that require the CPU.
-        """
-        self.device = torch.device('cpu')
-        self.model.to(self.device)
-        _LOGGER.info("Trainer and model moved to CPU.")
-
-    def to_device(self, device: str):
-        """
-        Moves the model to the specified device and updates the trainer's device setting.
-
-        Args:
-            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
-        """
-        self.device = self._validate_device(device)
-        self.model.to(self.device)
-        _LOGGER.info(f"Trainer and model moved to {self.device}.")
-
-    def _load_model_state_for_finalizing(self, model_checkpoint: Union[Path, Literal['best', 'current']]):
-        """
-        Private helper to load the correct model state_dict based on the user's choice.
-        This is called by finalize_model_training() in subclasses.
-        """
-        if isinstance(model_checkpoint, Path):
-            self._load_checkpoint(path=model_checkpoint)
-        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
-            path_to_latest = self._checkpoint_callback.best_checkpoint_path
-            self._load_checkpoint(path_to_latest)
-        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
-            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
-            raise ValueError()
-        elif model_checkpoint == MagicWords.CURRENT:
-            pass
-        else:
-            _LOGGER.error(f"Unknown 'model_checkpoint' received '{model_checkpoint}'.")
-            raise ValueError()
-
-    # --- Abstract Methods ---
-    # These must be implemented by subclasses
-
-    @abstractmethod
-    def _create_dataloaders(self, batch_size: int, shuffle: bool):
-        """Initializes the DataLoaders."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def _train_step(self) -> Dict[str, float]:
-        """Runs a single training epoch."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def _validation_step(self) -> Dict[str, float]:
-        """Runs a single validation epoch."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def evaluate(self, *args, **kwargs):
-        """Runs the full model evaluation."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def _evaluate(self, *args, **kwargs):
-        """Internal evaluation helper."""
-        raise NotImplementedError
-
-    @abstractmethod
-    def finalize_model_training(self, *args, **kwargs):
-        """Saves the finalized model for inference."""
-        raise NotImplementedError
-
-
-# --- DragonTrainer ----
-class DragonTrainer(_BaseDragonTrainer):
-    def __init__(self,
-                 model: nn.Module,
-                 train_dataset: Dataset,
-                 validation_dataset: Dataset,
-                 kind: Literal["regression", "binary classification", "multiclass classification",
-                               "multitarget regression", "multilabel binary classification",
-                               "binary segmentation", "multiclass segmentation", "binary image classification", "multiclass image classification"],
-                 optimizer: torch.optim.Optimizer,
-                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
-                 checkpoint_callback: Optional[DragonModelCheckpoint],
-                 early_stopping_callback: Optional[_DragonEarlyStopping],
-                 lr_scheduler_callback: Optional[_DragonLRScheduler],
-                 extra_callbacks: Optional[List[_Callback]] = None,
-                 criterion: Union[nn.Module, Literal["auto"]] = "auto",
-                 dataloader_workers: int = 2):
-        """
-        Automates the training process of a PyTorch Model.
-
-        Built-in Callbacks: `History`, `TqdmProgressBar`
-
-        Args:
-            model (nn.Module): The PyTorch model to train.
-            train_dataset (Dataset): The training dataset.
-            validation_dataset (Dataset): The validation dataset.
-            kind (str): Used to redirect to the correct process.
-            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
-            optimizer (torch.optim.Optimizer): The optimizer.
-            device (str): The device to run training on ('cpu', 'cuda', 'mps').
-            dataloader_workers (int): Subprocesses for data loading.
-            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
-
-        Note:
-            - For **regression** and **multi_target_regression** tasks, suggested criteria include `nn.MSELoss` or `nn.L1Loss`. The model should output as many logits as existing targets.
-
-            - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.
-
-            - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.
-
-            - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.
-
-            - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.
-
-            - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
-        """
-        # Call the base class constructor with common parameters
-        super().__init__(
-            model=model,
-            optimizer=optimizer,
-            device=device,
-            dataloader_workers=dataloader_workers,
-            checkpoint_callback=checkpoint_callback,
-            early_stopping_callback=early_stopping_callback,
-            lr_scheduler_callback=lr_scheduler_callback,
-            extra_callbacks=extra_callbacks
-        )
-
-        if kind not in [MLTaskKeys.REGRESSION,
-                        MLTaskKeys.BINARY_CLASSIFICATION,
-                        MLTaskKeys.MULTICLASS_CLASSIFICATION,
-                        MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
-                        MLTaskKeys.MULTITARGET_REGRESSION,
-                        MLTaskKeys.BINARY_SEGMENTATION,
-                        MLTaskKeys.MULTICLASS_SEGMENTATION,
-                        MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
-                        MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
-            raise ValueError(f"'{kind}' is not a valid task type.")
-
-        self.train_dataset = train_dataset
-        self.validation_dataset = validation_dataset
-        self.kind = kind
-        self._classification_threshold: float = 0.5
-
-        # loss function
-        if criterion == "auto":
-            if kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
-                self.criterion = nn.MSELoss()
-            elif kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
-                self.criterion = nn.BCEWithLogitsLoss()
-            elif kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
-                self.criterion = nn.CrossEntropyLoss()
-        else:
-            self.criterion = criterion
-
-    def _create_dataloaders(self, batch_size: int, shuffle: bool):
-        """Initializes the DataLoaders."""
-        # Ensure stability on MPS devices by setting num_workers to 0
-        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
-
-        self.train_loader = DataLoader(
-            dataset=self.train_dataset,
-            batch_size=batch_size,
-            shuffle=shuffle,
-            num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type),
-            drop_last=True  # Drops the last batch if incomplete, selecting a good batch size is key.
-        )
-
-        self.validation_loader = DataLoader(
-            dataset=self.validation_dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            num_workers=loader_workers,
-            pin_memory=("cuda" in self.device.type)
-        )
-
-    def _train_step(self):
-        self.model.train()
-        running_loss = 0.0
-        total_samples = 0
-
-        for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
-            # Create a log dictionary for the batch
-            batch_logs = {
-                PyTorchLogKeys.BATCH_INDEX: batch_idx,
-                PyTorchLogKeys.BATCH_SIZE: features.size(0)
-            }
-            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
-
-            features, target = features.to(self.device), target.to(self.device)
-            self.optimizer.zero_grad()
-
-            output = self.model(features)
-
-            # --- Label Type/Shape Correction ---
-            # Cast target to float for BCE-based losses
-            if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
-                target = target.float()
-
-            # Reshape output to match target for single-logit tasks
-            if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
-                # If model outputs [N, 1] and target is [N], squeeze output
-                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
-                    output = output.squeeze(1)
-
-            if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
-                # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
-                if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
-                    output = output.squeeze(1)
-
-            loss = self.criterion(output, target)
-
-            loss.backward()
-            self.optimizer.step()
-
-            # Calculate batch loss and update running loss for the epoch
-            batch_loss = loss.item()
-            batch_size = features.size(0)
-            running_loss += batch_loss * batch_size  # Accumulate total loss
-            total_samples += batch_size  # total samples
-
-            # Add the batch loss to the logs and call the end-of-batch hook
-            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
-            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
-
-        if total_samples == 0:
-            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
-            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}
-
-        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
-
-    def _validation_step(self):
-        self.model.eval()
-        running_loss = 0.0
-
-        with torch.no_grad():
-            for features, target in self.validation_loader:  # type: ignore
-                features, target = features.to(self.device), target.to(self.device)
-
-                output = self.model(features)
-
-                # --- Label Type/Shape Correction ---
-                # Cast target to float for BCE-based losses
-                if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
-                    target = target.float()
-
-                # Reshape output to match target for single-logit tasks
-                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
-                    # If model outputs [N, 1] and target is [N], squeeze output
-                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
-                        output = output.squeeze(1)
-
-                if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
-                    # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
-                    if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
-                        output = output.squeeze(1)
-
-                loss = self.criterion(output, target)
-
-                running_loss += loss.item() * features.size(0)
-
-        if not self.validation_loader.dataset:  # type: ignore
-            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
-            return {PyTorchLogKeys.VAL_LOSS: 0.0}
-
-        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
-        return logs
-
-    def _predict_for_eval(self, dataloader: DataLoader):
-        """
-        Private method to yield model predictions batch by batch for evaluation.
-
-        Automatically detects if `target_scaler` is present in the training dataset
-        and applies inverse transformation for Regression tasks.
-
-        Yields:
-            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
-
-            - y_prob_batch is None for regression tasks.
-        """
-        self.model.eval()
-        self.model.to(self.device)
-
-        # --- Check for Target Scaler (for Regression Un-scaling) ---
-        target_scaler = None
-        if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
-            # Try to get the scaler from the dataset attached to the trainer
-            if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
-                target_scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
-                if target_scaler is not None:
-                    _LOGGER.debug("Target scaler detected. Un-scaling predictions and targets for metric calculation.")
-
-        with torch.no_grad():
-            for features, target in dataloader:
-                features = features.to(self.device)
-                # Keep target on device initially for potential un-scaling
-                target = target.to(self.device)
-
-                output = self.model(features)
-
-                y_pred_batch = None
-                y_prob_batch = None
-                y_true_batch = None
-
-                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
-
-                    # --- Automatic Un-scaling Logic ---
-                    if target_scaler:
-                        # 1. Reshape output/target if flattened (common in single regression)
-                        # Scaler expects [N, Features]
-                        original_out_shape = output.shape
-                        original_target_shape = target.shape
-
-                        if output.ndim == 1: output = output.reshape(-1, 1)
-                        if target.ndim == 1: target = target.reshape(-1, 1)
-
-                        # 2. Apply Inverse Transform
-                        output = target_scaler.inverse_transform(output)
-                        target = target_scaler.inverse_transform(target)
-
-                        # 3. Restore shapes (optional, but good for consistency)
-                        if len(original_out_shape) == 1: output = output.flatten()
-                        if len(original_target_shape) == 1: target = target.flatten()
-
-                    y_pred_batch = output.cpu().numpy()
-                    y_true_batch = target.cpu().numpy()
-
-                elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
-                    if output.ndim == 2 and output.shape[1] == 1:
-                        output = output.squeeze(1)
-
-                    probs_pos = torch.sigmoid(output)
-                    preds = (probs_pos >= self._classification_threshold).int()
-                    y_pred_batch = preds.cpu().numpy()
-
-                    probs_neg = 1.0 - probs_pos
-                    y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).cpu().numpy()
-                    y_true_batch = target.cpu().numpy()
-
-                elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
-                    probs = torch.softmax(output, dim=1)
-                    preds = torch.argmax(probs, dim=1)
-                    y_pred_batch = preds.cpu().numpy()
-                    y_prob_batch = probs.cpu().numpy()
-                    y_true_batch = target.cpu().numpy()
-
-                elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
-                    probs = torch.sigmoid(output)
-                    preds = (probs >= self._classification_threshold).int()
-                    y_pred_batch = preds.cpu().numpy()
-                    y_prob_batch = probs.cpu().numpy()
-                    y_true_batch = target.cpu().numpy()
-
-                elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
-                    probs_pos = torch.sigmoid(output)
-                    preds = (probs_pos >= self._classification_threshold).int()
-                    y_pred_batch = preds.squeeze(1).cpu().numpy()
-
-                    probs_neg = 1.0 - probs_pos
-                    y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).cpu().numpy()
-
-                    if target.ndim == 4 and target.shape[1] == 1:
-                        target = target.squeeze(1)
-                    y_true_batch = target.cpu().numpy()
-
-                elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
-                    probs = torch.softmax(output, dim=1)
-                    preds = torch.argmax(probs, dim=1)
-                    y_pred_batch = preds.cpu().numpy()
-                    y_prob_batch = probs.cpu().numpy()
-
-                    if target.ndim == 4 and target.shape[1] == 1:
-                        target = target.squeeze(1)
-                    y_true_batch = target.cpu().numpy()
-
-                yield y_pred_batch, y_prob_batch, y_true_batch
-
-    def evaluate(self,
-                 save_dir: Union[str, Path],
-                 model_checkpoint: Union[Path, Literal["best", "current"]],
-                 classification_threshold: Optional[float] = None,
-                 test_data: Optional[Union[DataLoader, Dataset]] = None,
-                 val_format_configuration: Optional[Union[
-                     RegressionMetricsFormat,
-                     MultiTargetRegressionMetricsFormat,
-                     BinaryClassificationMetricsFormat,
-                     MultiClassClassificationMetricsFormat,
-                     BinaryImageClassificationMetricsFormat,
-                     MultiClassImageClassificationMetricsFormat,
-                     MultiLabelBinaryClassificationMetricsFormat,
-                     BinarySegmentationMetricsFormat,
-                     MultiClassSegmentationMetricsFormat
-                 ]] = None,
-                 test_format_configuration: Optional[Union[
-                     RegressionMetricsFormat,
-                     MultiTargetRegressionMetricsFormat,
-                     BinaryClassificationMetricsFormat,
-                     MultiClassClassificationMetricsFormat,
-                     BinaryImageClassificationMetricsFormat,
-                     MultiClassImageClassificationMetricsFormat,
-                     MultiLabelBinaryClassificationMetricsFormat,
-                     BinarySegmentationMetricsFormat,
-                     MultiClassSegmentationMetricsFormat,
-                 ]] = None):
-        """
-        Evaluates the model, routing to the correct evaluation function based on task `kind`.
-
-        Args:
-            model_checkpoint (Path | "best" | "current"):
-                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
-                - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
-                - If 'current', use the current state of the trained model up to the latest trained epoch.
-            save_dir (str | Path): Directory to save all reports and plots.
-            classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification).
-            test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
-            val_format_configuration (object): Optional configuration for metric format output for the validation set.
-            test_format_configuration (object): Optional configuration for metric format output for the test set.
-        """
-        # Validate model checkpoint
-        if isinstance(model_checkpoint, Path):
-            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
-        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
-            checkpoint_validated = model_checkpoint
-        else:
-            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
-            raise ValueError()
-
-        # Validate classification threshold
-        if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
-            # dummy value for tasks that do not need it
-            threshold_validated = 0.5
-        elif classification_threshold is None:
-            # it should have been provided for binary tasks
-            _LOGGER.error(f"The classification threshold must be provided for '{self.kind}'.")
-            raise ValueError()
-        elif classification_threshold <= 0.0 or classification_threshold >= 1.0:
-            # Invalid float
-            _LOGGER.error(f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0).")
-            raise ValueError()
-        else:
-            threshold_validated = classification_threshold
-
-        # Validate val configuration
-        if val_format_configuration is not None:
-            if not isinstance(val_format_configuration, (RegressionMetricsFormat,
-                                                         MultiTargetRegressionMetricsFormat,
-                                                         BinaryClassificationMetricsFormat,
-                                                         MultiClassClassificationMetricsFormat,
-                                                         BinaryImageClassificationMetricsFormat,
-                                                         MultiClassImageClassificationMetricsFormat,
-                                                         MultiLabelBinaryClassificationMetricsFormat,
-                                                         BinarySegmentationMetricsFormat,
-                                                         MultiClassSegmentationMetricsFormat)):
-                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
-                raise ValueError()
-            else:
-                val_configuration_validated = val_format_configuration
-        else:  # config is None
-            val_configuration_validated = None
-
-        # Validate directory
-        save_path = make_fullpath(save_dir, make=True, enforce="directory")
-
-        # Validate test data and dispatch
-        if test_data is not None:
-            if not isinstance(test_data, (DataLoader, Dataset)):
-                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
-                raise ValueError()
-            test_data_validated = test_data
-
-            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
-            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR
-
-            # Dispatch validation set
-            _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
-            self._evaluate(save_dir=validation_metrics_path,
-                           model_checkpoint=checkpoint_validated,
-                           classification_threshold=threshold_validated,
-                           data=None,
-                           format_configuration=val_configuration_validated)
-
-            # Validate test configuration
-            if test_format_configuration is not None:
-                if not isinstance(test_format_configuration, (RegressionMetricsFormat,
-                                                              MultiTargetRegressionMetricsFormat,
-                                                              BinaryClassificationMetricsFormat,
-                                                              MultiClassClassificationMetricsFormat,
-                                                              BinaryImageClassificationMetricsFormat,
-                                                              MultiClassImageClassificationMetricsFormat,
-                                                              MultiLabelBinaryClassificationMetricsFormat,
-                                                              BinarySegmentationMetricsFormat,
-                                                              MultiClassSegmentationMetricsFormat)):
-                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
-                    if val_configuration_validated is not None:
-                        warning_message_type += " 'val_format_configuration' will be used for the test set metrics output."
-                        test_configuration_validated = val_configuration_validated
-                    else:
-                        warning_message_type += " Using default format."
-                        test_configuration_validated = None
-                    _LOGGER.warning(warning_message_type)
-                else:
-                    test_configuration_validated = test_format_configuration
-            else:  # config is None
-                test_configuration_validated = None
-
-            # Dispatch test set
-            _LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
-            self._evaluate(save_dir=test_metrics_path,
-                           model_checkpoint="current",
-                           classification_threshold=threshold_validated,
-                           data=test_data_validated,
-                           format_configuration=test_configuration_validated)
-        else:
-            # Dispatch validation set
-            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
-            self._evaluate(save_dir=save_path,
-                           model_checkpoint=checkpoint_validated,
-                           classification_threshold=threshold_validated,
-                           data=None,
-                           format_configuration=val_configuration_validated)
-
-    def _evaluate(self,
-                  save_dir: Union[str, Path],
-                  model_checkpoint: Union[Path, Literal["best", "current"]],
-                  classification_threshold: float,
-                  data: Optional[Union[DataLoader, Dataset]],
-                  format_configuration: Optional[Union[
-                      RegressionMetricsFormat,
-                      MultiTargetRegressionMetricsFormat,
-                      BinaryClassificationMetricsFormat,
-                      MultiClassClassificationMetricsFormat,
-                      BinaryImageClassificationMetricsFormat,
-                      MultiClassImageClassificationMetricsFormat,
-                      MultiLabelBinaryClassificationMetricsFormat,
-                      BinarySegmentationMetricsFormat,
-                      MultiClassSegmentationMetricsFormat
-                  ]] = None):
-        """
-        Changed to a private helper function.
-        """
-        dataset_for_artifacts = None
-        eval_loader = None
-
-        # set threshold
-        self._classification_threshold = classification_threshold
-
-        # load model checkpoint
-        if isinstance(model_checkpoint, Path):
-            self._load_checkpoint(path=model_checkpoint)
-        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
-            path_to_latest = self._checkpoint_callback.best_checkpoint_path
-            self._load_checkpoint(path_to_latest)
-        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
-            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
-            raise ValueError()
-
-        # Dataloader
-        if isinstance(data, DataLoader):
-            eval_loader = data
-            # Try to get the dataset from the loader for fetching target names
-            if hasattr(data, 'dataset'):
-                dataset_for_artifacts = data.dataset  # type: ignore
-        elif isinstance(data, Dataset):
-            # Create a new loader from the provided dataset
-            eval_loader = DataLoader(data,
-                                     batch_size=self._batch_size,
-                                     shuffle=False,
-                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                                     pin_memory=(self.device.type == "cuda"))
-            dataset_for_artifacts = data
-        else:  # data is None, use the trainer's default test dataset
-            if self.validation_dataset is None:
-                _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
-                raise ValueError()
-            # Create a fresh DataLoader from the test_dataset
-            eval_loader = DataLoader(self.validation_dataset,
-                                     batch_size=self._batch_size,
-                                     shuffle=False,
-                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                                     pin_memory=(self.device.type == "cuda"))
-
-            dataset_for_artifacts = self.validation_dataset
-
-        if eval_loader is None:
-            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
-            raise ValueError()
-
-        # print("\n--- Model Evaluation ---")
-
-        all_preds, all_probs, all_true = [], [], []
-        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
-            if y_pred_b is not None: all_preds.append(y_pred_b)
-            if y_prob_b is not None: all_probs.append(y_prob_b)
-            if y_true_b is not None: all_true.append(y_true_b)
-
-        if not all_true:
-            _LOGGER.error("Evaluation failed: No data was processed.")
-            return
-
-        y_pred = np.concatenate(all_preds)
-        y_true = np.concatenate(all_true)
-        y_prob = np.concatenate(all_probs) if all_probs else None
-
-        # --- Routing Logic ---
-        # Single-target regression
-        if self.kind == MLTaskKeys.REGRESSION:
-            # Check configuration
-            config = None
-            if format_configuration and isinstance(format_configuration, RegressionMetricsFormat):
-                config = format_configuration
-            elif format_configuration:
-                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
-
-            regression_metrics(y_true=y_true.flatten(),
-                               y_pred=y_pred.flatten(),
-                               save_dir=save_dir,
-                               config=config)
-
-        # single target classification
-        elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
-                           MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
-                           MLTaskKeys.MULTICLASS_CLASSIFICATION,
-                           MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
-            # get the class map if it exists
-            try:
-                class_map = dataset_for_artifacts.class_map  # type: ignore
-            except AttributeError:
-                _LOGGER.warning("Dataset has no 'class_map' attribute. Using generics.")
-                class_map = None
-            else:
-                if not isinstance(class_map, dict):
-                    _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
-                    class_map = None
-
-            # Check configuration
-            config = None
-            if format_configuration:
-                if self.kind == MLTaskKeys.BINARY_CLASSIFICATION and isinstance(format_configuration, BinaryClassificationMetricsFormat):
-                    config = format_configuration
-                elif self.kind == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION and isinstance(format_configuration, BinaryImageClassificationMetricsFormat):
-                    config = format_configuration
-                elif self.kind == MLTaskKeys.MULTICLASS_CLASSIFICATION and isinstance(format_configuration, MultiClassClassificationMetricsFormat):
-                    config = format_configuration
-                elif self.kind == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION and isinstance(format_configuration, MultiClassImageClassificationMetricsFormat):
-                    config = format_configuration
-                else:
-                    _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
-
-            classification_metrics(save_dir=save_dir,
-                                   y_true=y_true,
-                                   y_pred=y_pred,
-                                   y_prob=y_prob,
-                                   class_map=class_map,
-                                   config=config)
-
-        # multitarget regression
-        elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
-            try:
-                target_names = dataset_for_artifacts.target_names  # type: ignore
-            except AttributeError:
-                num_targets = y_true.shape[1]
-                target_names = [f"target_{i}" for i in range(num_targets)]
-                _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
-
-            # Check configuration
-            config = None
-            if format_configuration and isinstance(format_configuration, MultiTargetRegressionMetricsFormat):
-                config = format_configuration
-            elif format_configuration:
-                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
-
-            multi_target_regression_metrics(y_true=y_true,
-                                            y_pred=y_pred,
-                                            target_names=target_names,
-                                            save_dir=save_dir,
-                                            config=config)
-
-        # multi-label binary classification
-        elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
-            try:
-                target_names = dataset_for_artifacts.target_names  # type: ignore
-            except AttributeError:
-                num_targets = y_true.shape[1]
-                target_names = [f"label_{i}" for i in range(num_targets)]
-                _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
-
-            if y_prob is None:
-                _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
-                return
-
-            # Check configuration
-            config = None
-            if format_configuration and isinstance(format_configuration, MultiLabelBinaryClassificationMetricsFormat):
-                config = format_configuration
-            elif format_configuration:
-                _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
-
-            multi_label_classification_metrics(y_true=y_true,
-                                               y_pred=y_pred,
-                                               y_prob=y_prob,
-                                               target_names=target_names,
-                                               save_dir=save_dir,
-                                               config=config)
-
-        # Segmentation tasks
-        elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
-            class_names = None
-            try:
-                # Try to get 'classes' from VisionDatasetMaker
-                if hasattr(dataset_for_artifacts, 'classes'):
-                    class_names = dataset_for_artifacts.classes  # type: ignore
-                # Fallback for Subset
-                elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'):  # type: ignore
-                    class_names = dataset_for_artifacts.dataset.classes  # type: ignore
-            except AttributeError:
-                pass  # class_names is still None
-
-            if class_names is None:
-                try:
-                    # Fallback to 'target_names'
-                    class_names = dataset_for_artifacts.target_names  # type: ignore
-                except AttributeError:
-                    # Fallback to inferring from labels
-                    labels = np.unique(y_true)
-                    class_names = [f"Class {i}" for i in labels]
-                    _LOGGER.warning("Dataset has no 'classes' or 'target_names' attribute. Using generic names.")
-
-            # Check configuration
-            config = None
-            if format_configuration and isinstance(format_configuration, (BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat)):
|
|
988
|
-
config = format_configuration
|
|
989
|
-
elif format_configuration:
|
|
990
|
-
_LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
|
|
991
|
-
|
|
992
|
-
segmentation_metrics(y_true=y_true,
|
|
993
|
-
y_pred=y_pred,
|
|
994
|
-
save_dir=save_dir,
|
|
995
|
-
class_names=class_names,
|
|
996
|
-
config=config)
|
|
997
|
-
|
|
998
|
-
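    # Illustrative usage (not part of the original source): a minimal sketch of how
    # the routing above is typically driven, assuming a fitted trainer whose public
    # `evaluate()` forwards a metrics-format object down to this method:
    #
    #   fmt = RegressionMetricsFormat()   # hypothetical default construction
    #   trainer.evaluate(save_dir="metrics/",
    #                    model_checkpoint="best",
    #                    val_format_configuration=fmt)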
    def explain_shap(self,
                     save_dir: Union[str, Path],
                     explain_dataset: Optional[Dataset] = None,
                     n_samples: int = 300,
                     feature_names: Optional[List[str]] = None,
                     target_names: Optional[List[str]] = None,
                     explainer_type: Literal['deep', 'kernel'] = 'kernel'):
        """
        Explains model predictions using SHAP and saves all artifacts.

        NOTE: SHAP support is limited to single-target tasks (Regression, Binary/Multiclass Classification).
        For complex tasks (Multi-target, Multi-label, Sequences, Images), please use `explain_captum()`.

        The background data is automatically sampled from the trainer's training dataset.

        This method automatically routes to the appropriate SHAP summary plot function
        based on the task. If `feature_names` or `target_names` (multi-target) are not
        provided, it will attempt to extract them from the dataset.

        Args:
            save_dir (str | Path): Directory to save all SHAP artifacts.
            explain_dataset (Dataset | None): A specific dataset to explain.
                If None, the trainer's validation dataset is used.
            n_samples (int): The number of samples to use for both background and explanation.
            feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset; an error is raised on failure.
            target_names (list[str] | None): Target names for multi-target tasks.
            explainer_type (Literal['deep', 'kernel']): The explainer to use.
                - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
                - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low `n_samples` (< 100).
        """
        # --- 1. Compatibility Guard ---
        valid_shap_tasks = [
            MLTaskKeys.REGRESSION,
            MLTaskKeys.BINARY_CLASSIFICATION,
            MLTaskKeys.MULTICLASS_CLASSIFICATION
        ]

        if self.kind not in valid_shap_tasks:
            _LOGGER.warning(f"SHAP explanation is deprecated for task '{self.kind}' due to instability. Please use 'explain_captum()' instead.")
            return

        # Memory-efficient sampling helper
        def _get_random_sample(dataset: Dataset, num_samples: int):
            """Memory-efficiently samples data from a dataset."""
            if dataset is None:
                return None

            dataset_len = len(dataset)  # type: ignore
            if dataset_len == 0:
                return None

            # For MPS devices, num_workers must be 0 to ensure stability
            loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

            # Ensure batch_size is not larger than the dataset itself
            batch_size = min(num_samples, 64, dataset_len)

            loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,  # Shuffle to get random samples
                num_workers=loader_workers
            )

            collected_features = []
            num_collected = 0

            for features, _ in loader:
                collected_features.append(features)
                num_collected += features.size(0)
                if num_collected >= num_samples:
                    break  # Stop once we have enough samples

            if not collected_features:
                return None

            full_data = torch.cat(collected_features, dim=0)

            # If we collected more than needed, trim it down
            if full_data.size(0) > num_samples:
                return full_data[:num_samples]

            return full_data

        # print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")

        # 1. Get background data from the trainer's train_dataset
        background_data = _get_random_sample(self.train_dataset, n_samples)
        if background_data is None:
            _LOGGER.error("Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
            return

        # 2. Determine target dataset and get explanation instances
        target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
        instances_to_explain = _get_random_sample(target_dataset, n_samples)
        if instances_to_explain is None:
            _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
            return

        # Attempt to get feature names
        if feature_names is None:
            # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
            if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
                feature_names = target_dataset.feature_names  # type: ignore
            else:
                _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
                raise ValueError()

        # Move model to device
        self.model.to(self.device)

        # 3. Call the plotting function
        if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
            shap_summary_plot(
                model=self.model,
                background_data=background_data,
                instances_to_explain=instances_to_explain,
                feature_names=feature_names,
                save_dir=save_dir,
                explainer_type=explainer_type,
                device=self.device
            )
        # DEPRECATED: Multi-target SHAP support is unstable; recommend Captum instead.
        elif self.kind in [MLTaskKeys.MULTITARGET_REGRESSION, MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
            # Try to get target names
            if target_names is None:
                target_names = []
                if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
                    target_names = target_dataset.target_names  # type: ignore
                else:
                    # Infer number of targets from the model's output layer
                    try:
                        num_targets = self.model.output_layer.out_features  # type: ignore
                        target_names = [f"target_{i}" for i in range(num_targets)]  # type: ignore
                        _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
                    except AttributeError:
                        _LOGGER.error("Cannot determine target names for multi-target SHAP plot. Skipping.")
                        return

            multi_target_shap_summary_plot(
                model=self.model,
                background_data=background_data,
                instances_to_explain=instances_to_explain,
                feature_names=feature_names,  # type: ignore
                target_names=target_names,  # type: ignore
                save_dir=save_dir,
                explainer_type=explainer_type,
                device=self.device
            )
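    # Illustrative usage (not part of the original source): KernelExplainer is flagged
    # above as extremely slow, so a sketch would cap the sample count:
    #
    #   trainer.explain_shap(save_dir="shap/",
    #                        n_samples=64,              # keep well under 100 for 'kernel'
    #                        explainer_type="kernel")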
    def explain_captum(self,
                       save_dir: Union[str, Path],
                       explain_dataset: Optional[Dataset] = None,
                       n_samples: int = 100,
                       feature_names: Optional[List[str]] = None,
                       target_names: Optional[List[str]] = None,
                       n_steps: int = 50):
        """
        Explains model predictions using Captum's Integrated Gradients.

        - **Tabular/Classification:** Generates Feature Importance Bar Charts.
        - **Segmentation:** Generates Spatial Heatmaps for each class.

        Args:
            save_dir (str | Path): Directory to save artifacts.
            explain_dataset (Dataset | None): Dataset to sample from. Defaults to the validation set.
            n_samples (int): Number of samples to evaluate.
            feature_names (list[str] | None): Feature names.
                - Required for tabular tasks.
                - Ignored/Optional for image tasks (defaults to channel names).
            target_names (list[str] | None): Names for the model outputs (or class names).
                - If None, attempts to extract from dataset attributes (`target_names`, `classes`, or `class_map`).
                - If extraction fails, generates generic names (e.g. "Output_0").
            n_steps (int): Number of interpolation steps.
        """
        # 1. Check availability
        if not _is_captum_available():
            _LOGGER.error("Captum is not installed or could not be imported.")
            return

        # 2. Prepare Data
        dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
        if dataset_to_use is None:
            _LOGGER.error("No dataset available for explanation.")
            return

        # Efficient sampling helper
        def _get_samples(ds, n):
            # Use num_workers=0 for stability during ad-hoc sampling
            loader = DataLoader(ds, batch_size=n, shuffle=True, num_workers=0)
            data_iter = iter(loader)
            features, targets = next(data_iter)
            return features, targets

        input_data, _ = _get_samples(dataset_to_use, n_samples)

        # 3. Get feature names (only if NOT segmentation AND NOT image classification).
        # Image tasks generally don't have explicit feature names; Captum will default to "Channel_X".
        is_segmentation = self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]
        is_image_classification = self.kind in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]

        if feature_names is None and not is_segmentation and not is_image_classification:
            if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
                feature_names = dataset_to_use.feature_names  # type: ignore
            else:
                _LOGGER.error("Could not extract `feature_names`. It must be provided if the dataset does not have it.")
                raise ValueError()

        # 4. Handle target names (or class names)
        if target_names is None:
            # A. Try dataset attributes first
            if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
                target_names = dataset_to_use.target_names  # type: ignore
            elif hasattr(dataset_to_use, "classes"):
                target_names = dataset_to_use.classes  # type: ignore
            elif hasattr(dataset_to_use, "class_map") and isinstance(dataset_to_use.class_map, dict):  # type: ignore
                # Sort by value (index) to ensure correct order: {name: index} -> [name_at_0, name_at_1...]
                sorted_items = sorted(dataset_to_use.class_map.items(), key=lambda item: item[1])  # type: ignore
                target_names = [k for k, v in sorted_items]

            # B. Infer based on task
            if target_names is None:
                if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
                    target_names = ["Output"]
                elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
                    target_names = ["Foreground"]

            # For multiclass/multitarget without names, leave it None and let the evaluation function generate generics.

        # 5. Dispatch based on task
        if is_segmentation:
            # Lower n_steps for segmentation to save memory
            if n_steps > 30:
                n_steps = 30
                _LOGGER.warning(f"Segmentation task detected: Reducing Captum n_steps to {n_steps} to prevent OOM. If you encounter OOM errors, consider lowering this further.")

            captum_segmentation_heatmap(
                model=self.model,
                input_data=input_data,
                save_dir=save_dir,
                target_names=target_names,  # Can be None, helper handles it
                n_steps=n_steps,
                device=self.device
            )

        elif is_image_classification:
            captum_image_heatmap(
                model=self.model,
                input_data=input_data,
                save_dir=save_dir,
                target_names=target_names,
                n_steps=n_steps,
                device=self.device
            )

        else:
            # Standard tabular tasks
            captum_feature_importance(
                model=self.model,
                input_data=input_data,
                feature_names=feature_names,
                save_dir=save_dir,
                target_names=target_names,
                n_steps=n_steps,
                device=self.device
            )
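    # Illustrative usage (not part of the original source): Integrated Gradients
    # approximates attributions by integrating gradients along `n_steps` points on
    # the path from a baseline to the input, so larger values trade memory for accuracy:
    #
    #   trainer.explain_captum(save_dir="captum/", n_samples=100, n_steps=50)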
    def _attention_helper(self, dataloader: DataLoader):
        """
        Private method to yield model attention weights batch by batch for evaluation.

        Args:
            dataloader (DataLoader): The dataloader to predict on.

        Yields:
            (torch.Tensor): Attention weights.
        """
        self.model.eval()
        self.model.to(self.device)

        with torch.no_grad():
            for features, target in dataloader:
                features = features.to(self.device)
                attention_weights = None

                # Unpack logits and attention weights from the special forward method
                _output, attention_weights = self.model.forward_attention(features)  # type: ignore

                if attention_weights is not None:
                    attention_weights = attention_weights.cpu()

                yield attention_weights
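    # Assumed model contract (not shown in this file): `forward_attention` returns a
    # (logits, attention_weights) pair. A compatible model might look roughly like:
    #
    #   def forward_attention(self, x):
    #       weights = self.attention(x)          # hypothetical attention module
    #       return self.head(x * weights), weights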
    def explain_attention(self, save_dir: Union[str, Path],
                          feature_names: Optional[List[str]] = None,
                          explain_dataset: Optional[Dataset] = None,
                          plot_n_features: int = 10):
        """
        Generates and saves a feature importance plot based on attention weights.

        This method only works for models with 'has_interpretable_attention'.

        Args:
            save_dir (str | Path): Directory to save the plot and summary data.
            feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset; an error is raised on failure.
            explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's validation dataset is used.
            plot_n_features (int): Number of top features to plot.
        """
        # print("\n--- Attention Analysis ---")

        # --- Step 1: Check if the model supports this explanation ---
        if not getattr(self.model, 'has_interpretable_attention', False):
            _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
            return

        # --- Step 2: Set up the dataloader ---
        dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
        if not isinstance(dataset_to_use, Dataset):
            _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
            return

        # Get feature names
        if feature_names is None:
            if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
                feature_names = dataset_to_use.feature_names  # type: ignore
            else:
                _LOGGER.error(f"Could not extract `feature_names` from the dataset for the attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
                raise ValueError()

        explain_loader = DataLoader(
            dataset=dataset_to_use, batch_size=32, shuffle=False,
            num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
            pin_memory=("cuda" in self.device.type)
        )

        # --- Step 3: Collect weights ---
        all_weights = []
        for att_weights_b in self._attention_helper(explain_loader):
            if att_weights_b is not None:
                all_weights.append(att_weights_b)

        # --- Step 4: Call the plotting function ---
        if all_weights:
            plot_attention_importance(
                weights=all_weights,
                feature_names=feature_names,
                save_dir=save_dir,
                top_n=plot_n_features
            )
        else:
            _LOGGER.error("No attention weights were collected from the model.")
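    # Illustrative usage (not part of the original source):
    #
    #   trainer.explain_attention(save_dir="attention/", plot_n_features=15)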
    def finalize_model_training(self,
                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                save_dir: Union[str, Path],
                                finalize_config: Union[FinalizeRegression,
                                                       FinalizeMultiTargetRegression,
                                                       FinalizeBinaryClassification,
                                                       FinalizeBinaryImageClassification,
                                                       FinalizeMultiClassClassification,
                                                       FinalizeMultiClassImageClassification,
                                                       FinalizeBinarySegmentation,
                                                       FinalizeMultiClassSegmentation,
                                                       FinalizeMultiLabelBinaryClassification]):
        """
        Saves a finalized, "inference-ready" model state to a .pth file.

        This method saves the model's `state_dict`, the final epoch number, and optional configuration for the task at hand.

        Args:
            model_checkpoint (Path | "best" | "current"):
                - Path: Loads the model state from a specific checkpoint file.
                - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
                - "current": Uses the model's state as it is.
            save_dir (str | Path): The directory to save the finalized model.
            finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.
        """
        # Validate that the config type matches the trainer's task
        expected_config_types = {
            MLTaskKeys.REGRESSION: FinalizeRegression,
            MLTaskKeys.MULTITARGET_REGRESSION: FinalizeMultiTargetRegression,
            MLTaskKeys.BINARY_CLASSIFICATION: FinalizeBinaryClassification,
            MLTaskKeys.BINARY_IMAGE_CLASSIFICATION: FinalizeBinaryImageClassification,
            MLTaskKeys.MULTICLASS_CLASSIFICATION: FinalizeMultiClassClassification,
            MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION: FinalizeMultiClassImageClassification,
            MLTaskKeys.BINARY_SEGMENTATION: FinalizeBinarySegmentation,
            MLTaskKeys.MULTICLASS_SEGMENTATION: FinalizeMultiClassSegmentation,
            MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION: FinalizeMultiLabelBinaryClassification,
        }
        expected_type = expected_config_types.get(self.kind)
        if expected_type is not None and not isinstance(finalize_config, expected_type):
            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type '{expected_type.__name__}', but got {type(finalize_config).__name__}.")
            raise TypeError()

        # Handle save path
        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
        full_path = dir_path / finalize_config.filename

        # Handle checkpoint
        self._load_model_state_for_finalizing(model_checkpoint)

        # Create finalized data
        finalized_data = {
            PyTorchCheckpointKeys.EPOCH: self.epoch,
            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
            PyTorchCheckpointKeys.TASK: finalize_config.task
        }

        # Parse config
        if finalize_config.target_name is not None:
            finalized_data[PyTorchCheckpointKeys.TARGET_NAME] = finalize_config.target_name
        if finalize_config.target_names is not None:
            finalized_data[PyTorchCheckpointKeys.TARGET_NAMES] = finalize_config.target_names
        if finalize_config.classification_threshold is not None:
            finalized_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD] = finalize_config.classification_threshold
        if finalize_config.class_map is not None:
            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map

        # Save model file
        torch.save(finalized_data, full_path)

        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
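# Illustrative usage (not part of the original source): reading a finalized file back,
# assuming the same PyTorchCheckpointKeys constants used above:
#
#   payload = torch.load("finalized_model.pth", map_location="cpu")
#   model.load_state_dict(payload[PyTorchCheckpointKeys.MODEL_STATE])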
# Object Detection Trainer
class DragonDetectionTrainer(_BaseDragonTrainer):
    def __init__(self, model: nn.Module,
                 train_dataset: Dataset,
                 validation_dataset: Dataset,
                 collate_fn: Callable, optimizer: torch.optim.Optimizer,
                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
                 checkpoint_callback: Optional[DragonModelCheckpoint],
                 early_stopping_callback: Optional[_DragonEarlyStopping],
                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                 extra_callbacks: Optional[List[_Callback]] = None,
                 dataloader_workers: int = 2):
        """
        Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).

        Built-in Callbacks: `History`, `TqdmProgressBar`

        Args:
            model (nn.Module): The PyTorch object detection model to train.
            train_dataset (Dataset): The training dataset.
            validation_dataset (Dataset): The testing/validation dataset.
            collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
            optimizer (torch.optim.Optimizer): The optimizer.
            device (str): The device to run training on ('cpu', 'cuda', 'mps').
            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
            dataloader_workers (int): Subprocesses for data loading.

        ## Note:
        This trainer is specialized. It does not take a `criterion` because object detection models like Faster R-CNN return a dictionary of losses directly from their forward pass during training.
        """
        # Call the base class constructor with common parameters
        super().__init__(
            model=model,
            optimizer=optimizer,
            device=device,
            dataloader_workers=dataloader_workers,
            checkpoint_callback=checkpoint_callback,
            early_stopping_callback=early_stopping_callback,
            lr_scheduler_callback=lr_scheduler_callback,
            extra_callbacks=extra_callbacks
        )

        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.kind = MLTaskKeys.OBJECT_DETECTION
        self.collate_fn = collate_fn
        self.criterion = None  # Criterion is handled inside the model
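    # Note (assumption, not from the original source): detection dataloaders need a
    # collate_fn that keeps variable-sized images and per-image target dicts as
    # tuples instead of stacking them; a typical implementation is roughly:
    #
    #   def collate_fn(batch):
    #       return tuple(zip(*batch))   # -> (images, targets), one entry per sample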
    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        """Initializes the DataLoaders with the object detection collate_fn."""
        # Ensure stability on MPS devices by setting num_workers to 0
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

        self.train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            collate_fn=self.collate_fn,  # Use the provided collate function
            drop_last=True
        )

        self.validation_loader = DataLoader(
            dataset=self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            collate_fn=self.collate_fn  # Use the provided collate function
        )
    def _train_step(self):
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        for batch_idx, (images, targets) in enumerate(self.train_loader):  # type: ignore
            # images is a tuple of tensors, targets is a tuple of dicts
            batch_size = len(images)

            # Create a log dictionary for the batch
            batch_logs = {
                PyTorchLogKeys.BATCH_INDEX: batch_idx,
                PyTorchLogKeys.BATCH_SIZE: batch_size
            }
            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

            # Move data to device
            images = list(img.to(self.device) for img in images)
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

            self.optimizer.zero_grad()

            # Model returns a loss dict when in train() mode and targets are passed
            loss_dict = self.model(images, targets)

            if not loss_dict:
                # No losses returned, skip batch
                _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
                batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
                self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
                continue

            # Sum all losses
            loss: torch.Tensor = sum(l for l in loss_dict.values())  # type: ignore

            loss.backward()
            self.optimizer.step()

            # Calculate batch loss and update running loss for the epoch
            batch_loss = loss.item()
            running_loss += batch_loss * batch_size
            total_samples += batch_size  # Accumulate total samples

            # Add the batch loss to the logs and call the end-of-batch hook
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss  # type: ignore
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

        # Calculate loss using the correct denominator
        if total_samples == 0:
            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}
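    # For reference (a torchvision convention, not defined in this file): Faster R-CNN
    # style models return a loss dict such as
    #   {'loss_classifier': ..., 'loss_box_reg': ..., 'loss_objectness': ..., 'loss_rpn_box_reg': ...}
    # which is why the step above sums the dict values into a single scalar before backprop.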
    def _validation_step(self):
        # Keep the model in train() mode so it still returns a loss dict; model
        # internals (e.g., proposals) may differ, but the loss_dict is needed here.
        # torch.no_grad() below prevents any gradient updates.
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for images, targets in self.validation_loader:  # type: ignore
                batch_size = len(images)

                # Move data to device
                images = list(img.to(self.device) for img in images)
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

                # Get loss dict
                loss_dict = self.model(images, targets)

                if not loss_dict:
                    _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
                    continue  # Skip if no losses

                # Sum all losses
                loss: torch.Tensor = sum(l for l in loss_dict.values())  # type: ignore

                running_loss += loss.item() * batch_size
                total_samples += batch_size  # Accumulate total samples

        # Calculate loss using the correct denominator
        if total_samples == 0:
            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
            return {PyTorchLogKeys.VAL_LOSS: 0.0}

        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
        return logs
    def evaluate(self,
                 save_dir: Union[str, Path],
                 model_checkpoint: Union[Path, Literal["best", "current"]],
                 test_data: Optional[Union[DataLoader, Dataset]] = None):
        """
        Evaluates the model using object detection mAP metrics.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            model_checkpoint (Path | "best" | "current"):
                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
                - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
                - If 'current', use the current state of the trained model up to the latest trained epoch.
            test_data (DataLoader | Dataset | None): Optional test data to evaluate the model performance. Validation and test metrics will be saved to subdirectories.
        """
        # Validate model checkpoint
        if isinstance(model_checkpoint, Path):
            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
            checkpoint_validated = model_checkpoint
        else:
            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
            raise ValueError()

        # Validate directory
        save_path = make_fullpath(save_dir, make=True, enforce="directory")

        # Validate test data and dispatch
        if test_data is not None:
            if not isinstance(test_data, (DataLoader, Dataset)):
                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
                raise ValueError()
            test_data_validated = test_data

            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

            # Dispatch validation set
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
            self._evaluate(save_dir=validation_metrics_path,
                           model_checkpoint=checkpoint_validated,
                           data=None)  # 'None' triggers use of the trainer's validation dataset

            # Dispatch test set
            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
            self._evaluate(save_dir=test_metrics_path,
                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
                           data=test_data_validated)
        else:
            # Dispatch validation set
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
            self._evaluate(save_dir=save_path,
                           model_checkpoint=checkpoint_validated,
                           data=None)  # 'None' triggers use of the trainer's validation dataset
    def _evaluate(self,
                  save_dir: Union[str, Path],
                  model_checkpoint: Union[Path, Literal["best", "current"]],
                  data: Optional[Union[DataLoader, Dataset]]):
        """
        Private helper that evaluates the model using object detection mAP metrics.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            model_checkpoint (Path | "best" | "current"):
                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
                - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
                - If 'current', use the current state of the trained model up to the latest trained epoch.
            data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's validation dataset.
        """
        dataset_for_artifacts = None
        eval_loader = None

        # Load model checkpoint
        if isinstance(model_checkpoint, Path):
            self._load_checkpoint(path=model_checkpoint)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
            path_to_latest = self._checkpoint_callback.best_checkpoint_path
            self._load_checkpoint(path_to_latest)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
            raise ValueError()

        # Dataloader
        if isinstance(data, DataLoader):
            eval_loader = data
            if hasattr(data, 'dataset'):
                dataset_for_artifacts = data.dataset  # type: ignore
        elif isinstance(data, Dataset):
            # Create a new loader from the provided dataset
            eval_loader = DataLoader(data,
                                     batch_size=self._batch_size,
                                     shuffle=False,
                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                     pin_memory=(self.device.type == "cuda"),
                                     collate_fn=self.collate_fn)
            dataset_for_artifacts = data
        else:  # data is None, use the trainer's validation dataset
            if self.validation_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no validation dataset available in the trainer.")
                raise ValueError()
            # Create a fresh DataLoader from the validation dataset
            eval_loader = DataLoader(
                self.validation_dataset,
                batch_size=self._batch_size,
                shuffle=False,
                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                pin_memory=(self.device.type == "cuda"),
                collate_fn=self.collate_fn
            )
            dataset_for_artifacts = self.validation_dataset

        if eval_loader is None:
            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
            raise ValueError()

        # print("\n--- Model Evaluation ---")

        all_predictions = []
        all_targets = []

        self.model.eval()  # Set model to evaluation mode
        self.model.to(self.device)

        with torch.no_grad():
            for images, targets in eval_loader:
                # Move images to device
                images = list(img.to(self.device) for img in images)

                # Model returns predictions when in eval() mode
                predictions = self.model(images)

                # Move predictions and targets to CPU for aggregation
                cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
                cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]

                all_predictions.extend(cpu_preds)
                all_targets.extend(cpu_targets)

        if not all_targets:
            _LOGGER.error("Evaluation failed: No data was processed.")
            return

        # Get class names from the dataset for the report
        class_names = None
        try:
            # Try to get 'classes' from ObjectDetectionDatasetMaker
            if hasattr(dataset_for_artifacts, 'classes'):
                class_names = dataset_for_artifacts.classes  # type: ignore
            # Fallback for Subset
            elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'):  # type: ignore
                class_names = dataset_for_artifacts.dataset.classes  # type: ignore
        except AttributeError:
            _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
            pass  # class_names is still None

        # --- Routing Logic ---
        object_detection_metrics(
            preds=all_predictions,
            targets=all_targets,
            save_dir=save_dir,
            class_names=class_names,
            print_output=False
        )
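    # For reference (the convention assumed by standard mAP implementations such as
    # torchmetrics' MeanAveragePrecision): each entry in `preds` is a dict with
    # 'boxes' (N, 4), 'scores' (N,) and 'labels' (N,), and each entry in `targets`
    # carries 'boxes' and 'labels' for the same image.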
    def finalize_model_training(self,
                                save_dir: Union[str, Path],
                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                finalize_config: FinalizeObjectDetection):
        """
        Saves a finalized, "inference-ready" model state to a .pth file.

        This method saves the model's `state_dict` and the final epoch number.

        Args:
            save_dir (str | Path): The directory to save the finalized model.
            model_checkpoint (Path | "best" | "current"):
                - Path: Loads the model state from a specific checkpoint file.
                - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
                - "current": Uses the model's state as it is.
            finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
        """
        if not isinstance(finalize_config, FinalizeObjectDetection):
            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
            raise TypeError()

        # Handle save path
        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
        full_path = dir_path / finalize_config.filename

        # Handle checkpoint
        self._load_model_state_for_finalizing(model_checkpoint)

        # Create finalized data
        finalized_data = {
            PyTorchCheckpointKeys.EPOCH: self.epoch,
            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
            PyTorchCheckpointKeys.TASK: finalize_config.task
        }

        if finalize_config.class_map is not None:
            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map

        torch.save(finalized_data, full_path)

        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
# --- DragonSequenceTrainer ---
class DragonSequenceTrainer(_BaseDragonTrainer):
    def __init__(self,
                 model: nn.Module,
                 train_dataset: Dataset,
                 validation_dataset: Dataset,
                 kind: Literal["sequence-to-sequence", "sequence-to-value"],
                 optimizer: torch.optim.Optimizer,
                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
                 checkpoint_callback: Optional[DragonModelCheckpoint],
                 early_stopping_callback: Optional[_DragonEarlyStopping],
                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                 extra_callbacks: Optional[List[_Callback]] = None,
                 criterion: Union[nn.Module, Literal["auto"]] = "auto",
                 dataloader_workers: int = 2):
        """
        Automates the training process of a PyTorch Sequence Model.

        Built-in Callbacks: `History`, `TqdmProgressBar`

        Args:
            model (nn.Module): The PyTorch model to train.
            train_dataset (Dataset): The training dataset.
            validation_dataset (Dataset): The validation dataset.
            kind (str): Used to redirect to the correct process ('sequence-to-sequence' or 'sequence-to-value').
            optimizer (torch.optim.Optimizer): The optimizer.
            device (str): The device to run training on ('cpu', 'cuda', 'mps').
            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
            criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task.
            dataloader_workers (int): Subprocesses for data loading.
        """
        # Call the base class constructor with common parameters
        super().__init__(
            model=model,
            optimizer=optimizer,
            device=device,
            dataloader_workers=dataloader_workers,
            checkpoint_callback=checkpoint_callback,
            early_stopping_callback=early_stopping_callback,
            lr_scheduler_callback=lr_scheduler_callback,
            extra_callbacks=extra_callbacks
        )

        if kind not in [MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE]:
            raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")

        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.kind = kind

        # Try to validate against a Dragon sequence model
        if hasattr(self.model, "prediction_mode"):
            key_to_check: str = self.model.prediction_mode  # type: ignore
            if key_to_check != self.kind:
                _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{key_to_check}'.")
                raise RuntimeError()

        # Loss function
        if criterion == "auto":
            # Both sequence tasks are treated as regression problems
            self.criterion = nn.MSELoss()
        else:
            self.criterion = criterion
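    # Illustrative construction (not part of the original source), assuming datasets
    # built elsewhere in the toolbox:
    #
    #   trainer = DragonSequenceTrainer(model=my_lstm,
    #                                   train_dataset=train_ds,
    #                                   validation_dataset=val_ds,
    #                                   kind="sequence-to-value",
    #                                   optimizer=torch.optim.Adam(my_lstm.parameters()),
    #                                   device="cuda",
    #                                   checkpoint_callback=None,
    #                                   early_stopping_callback=None,
    #                                   lr_scheduler_callback=None)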
    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        """Initializes the DataLoaders."""
        # Ensure stability on MPS devices by setting num_workers to 0
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

        self.train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            drop_last=True  # Drops the last batch if incomplete; selecting a good batch size is key.
        )

        self.validation_loader = DataLoader(
            dataset=self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type)
        )
    def _train_step(self):
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
            # Create a log dictionary for the batch
            batch_logs = {
                PyTorchLogKeys.BATCH_INDEX: batch_idx,
                PyTorchLogKeys.BATCH_SIZE: features.size(0)
            }
            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

            features, target = features.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()

            output = self.model(features)

            # --- Label Type/Shape Correction ---
            # Ensure target is float for MSELoss
            target = target.float()

            # For seq-to-val, models might output [N, 1] but target is [N].
            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
                    output = output.squeeze(1)

            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
                    output = output.squeeze(-1)

            loss = self.criterion(output, target)

            loss.backward()
            self.optimizer.step()

            # Calculate batch loss and update running loss for the epoch
            batch_loss = loss.item()
            batch_size = features.size(0)
            running_loss += batch_loss * batch_size  # Accumulate total loss
            total_samples += batch_size  # Accumulate total samples

            # Add the batch loss to the logs and call the end-of-batch hook
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

        if total_samples == 0:
            _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}  # type: ignore
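    # Worked example of the shape correction above (illustrative): with MSELoss, an
    # output of shape (32, 1) against a target of shape (32,) would silently
    # broadcast to a (32, 32) loss matrix; squeezing to (32,) avoids that.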
    def _validation_step(self):
        self.model.eval()
        running_loss = 0.0

        with torch.no_grad():
            for features, target in self.validation_loader:  # type: ignore
                features, target = features.to(self.device), target.to(self.device)

                output = self.model(features)

                # --- Label Type/Shape Correction ---
                target = target.float()

                # For seq-to-val, models might output [N, 1] but target is [N].
                if self.kind == MLTaskKeys.SEQUENCE_VALUE:
                    if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
                        output = output.squeeze(1)

                # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
                elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
                    if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
                        output = output.squeeze(-1)

                loss = self.criterion(output, target)

                running_loss += loss.item() * features.size(0)

        if not self.validation_loader.dataset:  # type: ignore
            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
            return {PyTorchLogKeys.VAL_LOSS: 0.0}

        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.validation_loader.dataset)}  # type: ignore
        return logs
    def _predict_for_eval(self, dataloader: DataLoader):
        """
        Private method to yield model predictions batch by batch for evaluation.

        Automatically checks for a target scaler.

        Yields:
            tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
                y_prob_batch is always None for sequence tasks.
        """
        self.model.eval()
        self.model.to(self.device)

        # --- Check for Scaler ---
        # DragonDatasetSequence stores it as 'scaler'
        scaler = None
        if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
            scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
            if scaler is not None:
                _LOGGER.debug("Sequence scaler detected. Un-scaling predictions and targets.")

        with torch.no_grad():
            for features, target in dataloader:
                features = features.to(self.device)
                target = target.to(self.device)

                output = self.model(features)

                # --- Automatic Un-scaling Logic ---
                if scaler:
                    # 1. Reshape for the scaler: (N, 1) or (N*Seq, 1)
                    original_out_shape = output.shape
                    original_target_shape = target.shape

                    # Flatten sequence dims
                    output_flat = output.reshape(-1, 1)
                    target_flat = target.reshape(-1, 1)

                    # 2. Inverse transform
                    output_flat = scaler.inverse_transform(output_flat)
                    target_flat = scaler.inverse_transform(target_flat)

                    # 3. Restore original shapes
                    output = output_flat.reshape(original_out_shape)
                    target = target_flat.reshape(original_target_shape)

                # Move to CPU
                y_pred_batch = output.cpu().numpy()
                y_true_batch = target.cpu().numpy()
                y_prob_batch = None

                yield y_pred_batch, y_prob_batch, y_true_batch
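    # Assumed scaler contract (not shown in this file): the object stored under
    # ScalerKeys.TARGET_SCALER exposes an sklearn-style inverse_transform over
    # (N, 1)-shaped data, e.g.:
    #
    #   y_original = scaler.inverse_transform(y_scaled.reshape(-1, 1))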
    def evaluate(self,
                 save_dir: Union[str, Path],
                 model_checkpoint: Union[Path, Literal["best", "current"]],
                 test_data: Optional[Union[DataLoader, Dataset]] = None,
                 val_format_configuration: Optional[Union[SequenceValueMetricsFormat,
                                                          SequenceSequenceMetricsFormat]] = None,
                 test_format_configuration: Optional[Union[SequenceValueMetricsFormat,
                                                           SequenceSequenceMetricsFormat]] = None):
        """
        Evaluates the model, routing to the correct evaluation function.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            model_checkpoint (Path | "best" | "current"):
                - Path to a valid checkpoint for the model.
                - If 'best', the best checkpoint will be loaded.
                - If 'current', use the current state of the trained model.
            test_data (DataLoader | Dataset | None): Optional test data.
            val_format_configuration: Optional configuration for validation metrics.
            test_format_configuration: Optional configuration for test metrics.
        """
        # Validate model checkpoint
        if isinstance(model_checkpoint, Path):
            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
            checkpoint_validated = model_checkpoint
        else:
            _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
            raise ValueError()

        # Validate val configuration
        if val_format_configuration is not None:
            if not isinstance(val_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
                _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
                raise ValueError()

        # Validate directory
        save_path = make_fullpath(save_dir, make=True, enforce="directory")

        # Validate test data and dispatch
        if test_data is not None:
            if not isinstance(test_data, (DataLoader, Dataset)):
                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
                raise ValueError()
            test_data_validated = test_data

            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

            # Dispatch validation set
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
            self._evaluate(save_dir=validation_metrics_path,
                           model_checkpoint=checkpoint_validated,
                           data=None,
                           format_configuration=val_format_configuration)

            # Validate test configuration
            test_configuration_validated = None
            if test_format_configuration is not None:
                if not isinstance(test_format_configuration, (SequenceValueMetricsFormat, SequenceSequenceMetricsFormat)):
                    warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
                    if val_format_configuration is not None:
                        warning_message_type += " 'val_format_configuration' will be used."
                        test_configuration_validated = val_format_configuration
                    else:
                        warning_message_type += " Using default format."
                    _LOGGER.warning(warning_message_type)
                else:
                    test_configuration_validated = test_format_configuration

            # Dispatch test set
            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
|
|
2109
|
-
self._evaluate(save_dir=test_metrics_path,
|
|
2110
|
-
model_checkpoint="current",
|
|
2111
|
-
data=test_data_validated,
|
|
2112
|
-
format_configuration=test_configuration_validated)
|
|
2113
|
-
else:
|
|
2114
|
-
# Dispatch validation set
|
|
2115
|
-
_LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
|
|
2116
|
-
self._evaluate(save_dir=save_path,
|
|
2117
|
-
model_checkpoint=checkpoint_validated,
|
|
2118
|
-
data=None,
|
|
2119
|
-
format_configuration=val_format_configuration)
|
|
2120
|
-
|
|
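A typical call under this signature might look as follows; `trainer` and `test_ds` are illustrative names, not objects defined in this file:

```python
# Hypothetical usage: `trainer` is a fitted sequence trainer, `test_ds` a Dataset.
from pathlib import Path

trainer.evaluate(
    save_dir=Path("reports/run_01"),
    model_checkpoint="best",        # or a Path to a specific checkpoint file
    test_data=test_ds,              # optional; validation-only when omitted
    val_format_configuration=None,  # None falls back to default metric formatting
)
```

When `test_data` is given, validation and test metrics land in separate subdirectories; otherwise `save_dir` itself receives the validation metrics.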
```python
    def _evaluate(self,
                  save_dir: Union[str, Path],
                  model_checkpoint: Union[Path, Literal["best", "current"]],
                  data: Optional[Union[DataLoader, Dataset]],
                  format_configuration: object):
        """
        Private evaluation helper.
        """
        eval_loader = None

        # load model checkpoint
        if isinstance(model_checkpoint, Path):
            self._load_checkpoint(path=model_checkpoint)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
            path_to_best = self._checkpoint_callback.best_checkpoint_path
            self._load_checkpoint(path_to_best)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
            raise ValueError()

        # Dataloader
        if isinstance(data, DataLoader):
            eval_loader = data
        elif isinstance(data, Dataset):
            # Create a new loader from the provided dataset
            eval_loader = DataLoader(data,
                                     batch_size=self._batch_size,
                                     shuffle=False,
                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                     pin_memory=(self.device.type == "cuda"))
        else:  # data is None, use the trainer's default validation dataset
            if self.validation_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
                raise ValueError()
            eval_loader = DataLoader(self.validation_dataset,
                                     batch_size=self._batch_size,
                                     shuffle=False,
                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                     pin_memory=(self.device.type == "cuda"))

        if eval_loader is None:
            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
            raise ValueError()

        all_preds, _, all_true = [], [], []
        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
            if y_pred_b is not None:
                all_preds.append(y_pred_b)
            if y_true_b is not None:
                all_true.append(y_true_b)

        if not all_true:
            _LOGGER.error("Evaluation failed: No data was processed.")
            return

        y_pred = np.concatenate(all_preds)
        y_true = np.concatenate(all_true)

        # --- Routing Logic ---
        if self.kind == MLTaskKeys.SEQUENCE_VALUE:
            config = None
            if format_configuration and isinstance(format_configuration, SequenceValueMetricsFormat):
                config = format_configuration
            elif format_configuration:
                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceValueMetricsFormat.")

            sequence_to_value_metrics(y_true=y_true,
                                      y_pred=y_pred,
                                      save_dir=save_dir,
                                      config=config)

        elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
            config = None
            if format_configuration and isinstance(format_configuration, SequenceSequenceMetricsFormat):
                config = format_configuration
            elif format_configuration:
                _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected SequenceSequenceMetricsFormat.")

            sequence_to_sequence_metrics(y_true=y_true,
                                         y_pred=y_pred,
                                         save_dir=save_dir,
                                         config=config)
```
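The loader construction above encodes two device quirks: multiprocessing workers are disabled on Apple's MPS backend, and page-locked (pinned) host memory is only requested for CUDA, where it speeds up host-to-device copies. A standalone sketch of the same rule; `make_eval_loader` is a hypothetical helper, not part of the package:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_eval_loader(dataset, batch_size, device, workers=2):
    """Hypothetical helper mirroring the device handling above."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep evaluation order deterministic
        num_workers=0 if device.type == "mps" else workers,  # fork workers are fragile on MPS
        pin_memory=(device.type == "cuda"),  # pinning only helps host-to-CUDA transfers
    )

ds = TensorDataset(torch.randn(32, 10, 3), torch.randn(32))
loader = make_eval_loader(ds, batch_size=8, device=torch.device("cpu"))
```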
```python
    def explain_captum(self,
                       save_dir: Union[str, Path],
                       explain_dataset: Optional[Dataset] = None,
                       n_samples: int = 100,
                       feature_names: Optional[List[str]] = None,
                       target_names: Optional[List[str]] = None,
                       n_steps: int = 50):
        """
        Explains sequence model predictions using Captum's Integrated Gradients.

        This method calculates global feature importance by aggregating attributions across
        the time dimension.
        - For **multivariate** sequences, it highlights which variables (channels) are most influential.
        - For **univariate** sequences, it attributes importance to the single signal feature.

        Args:
            save_dir (str | Path): Directory to save the importance plots and CSV reports.
            explain_dataset (Dataset | None): A specific dataset to sample from. If None, the
                trainer's validation dataset is used.
            n_samples (int): The number of samples to use for the explanation (background + inputs).
            feature_names (List[str] | None): Names of the features (signals). If None, attempts to extract them from the dataset attribute.
            target_names (List[str] | None): Names of the model outputs (e.g., for Seq2Seq or multivariate output). If None, attempts to extract them from the dataset attribute.
            n_steps (int): Number of integral approximation steps.

        Note:
            For univariate data (Shape: N, Seq_Len), the 'feature' is the signal itself.
        """
        if not _is_captum_available():
            _LOGGER.error("Captum is not installed.")
            return

        dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
        if dataset_to_use is None:
            _LOGGER.error("No dataset available for explanation.")
            return

        # Helper to sample data (same as DragonTrainer)
        def _get_samples(ds, n):
            loader = DataLoader(ds, batch_size=n, shuffle=True, num_workers=0)
            data_iter = iter(loader)
            features, targets = next(data_iter)
            return features, targets

        input_data, _ = _get_samples(dataset_to_use, n_samples)

        if feature_names is None:
            if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
                feature_names = dataset_to_use.feature_names  # type: ignore
            else:
                # If retrieval fails, leave it as None.
                _LOGGER.warning("'feature_names' not provided and not found in dataset. Generic names will be used.")

        if target_names is None:
            if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
                target_names = dataset_to_use.target_names  # type: ignore
            else:
                # If retrieval fails, leave it as None.
                _LOGGER.warning("'target_names' not provided and not found in dataset. Generic names will be used.")

        # Sequence models usually output [N, 1] (Value) or [N, Seq, 1] (Seq2Seq);
        # captum_feature_importance handles the aggregation.
        captum_feature_importance(
            model=self.model,
            input_data=input_data,
            feature_names=feature_names,
            save_dir=save_dir,
            target_names=target_names,
            n_steps=n_steps,
            device=self.device
        )
```
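For intuition on the "aggregate across the time dimension" step, here is a minimal standalone sketch using Captum directly on a toy model; it illustrates the general idea only and is not the package's `captum_feature_importance` implementation:

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

# Toy stand-in for a sequence regressor: [N, Seq, F] -> [N, 1].
model = nn.Sequential(nn.Flatten(), nn.Linear(12 * 3, 1))
inputs = torch.randn(16, 12, 3)

ig = IntegratedGradients(model)
attributions = ig.attribute(inputs,
                            baselines=torch.zeros_like(inputs),
                            target=0,
                            n_steps=50)  # same shape as inputs: [N, Seq, F]

# Aggregate |attribution| over samples and time to get one score per channel.
per_feature = attributions.abs().mean(dim=(0, 1))
print(per_feature)  # one importance value per signal (feature)
```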
```python
    def finalize_model_training(self,
                                save_dir: Union[str, Path],
                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                finalize_config: Union[FinalizeSequenceSequencePrediction, FinalizeSequenceValuePrediction]):
        """
        Saves a finalized, "inference-ready" model state to a .pth file.

        This method saves the model's `state_dict` and the final epoch number.

        Args:
            save_dir (Union[str, Path]): The directory to save the finalized model.
            model_checkpoint (Union[Path, Literal["best", "current"]]):
                - Path: Loads the model state from a specific checkpoint file.
                - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
                - "current": Uses the model's state as it is.
            finalize_config (FinalizeSequenceSequencePrediction | FinalizeSequenceValuePrediction): A data class instance matching the ML task; it carries the task-specific metadata required for inference.
        """
        if self.kind == MLTaskKeys.SEQUENCE_SEQUENCE and not isinstance(finalize_config, FinalizeSequenceSequencePrediction):
            _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
            raise TypeError()
        elif self.kind == MLTaskKeys.SEQUENCE_VALUE and not isinstance(finalize_config, FinalizeSequenceValuePrediction):
            _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
            raise TypeError()

        # handle save path
        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
        full_path = dir_path / finalize_config.filename

        # handle checkpoint
        self._load_model_state_for_finalizing(model_checkpoint)

        # Create finalized data
        finalized_data = {
            PyTorchCheckpointKeys.EPOCH: self.epoch,
            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
            PyTorchCheckpointKeys.TASK: finalize_config.task
        }

        if finalize_config.sequence_length is not None:
            finalized_data[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
        if finalize_config.initial_sequence is not None:
            finalized_data[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence

        torch.save(finalized_data, full_path)

        _LOGGER.info(f"Finalized model file saved to '{full_path}'")


def info():
    _script_info(__all__)
```
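A consumer of the finalized file would load the same dictionary back; a minimal sketch, assuming the key constants mirror the `finalized_data` mapping above and that the import path for `PyTorchCheckpointKeys` is adjusted to wherever the package exposes it:

```python
import torch
from ml_tools.keys import PyTorchCheckpointKeys  # assumed import path; adjust as needed

artifact = torch.load("finalized_model.pth", map_location="cpu")

# `model` must be an instance of the same architecture that was trained.
model.load_state_dict(artifact[PyTorchCheckpointKeys.MODEL_STATE])
model.eval()

task = artifact[PyTorchCheckpointKeys.TASK]
seq_len = artifact.get(PyTorchCheckpointKeys.SEQUENCE_LENGTH)  # only present when the config supplied it
```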