dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1160 @@
|
|
|
1
|
+
from typing import Literal, Union, Optional
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from torch.utils.data import DataLoader, Dataset
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ..ML_callbacks._base import _Callback
|
|
9
|
+
from ..ML_callbacks._checkpoint import DragonModelCheckpoint
|
|
10
|
+
from ..ML_callbacks._early_stop import _DragonEarlyStopping
|
|
11
|
+
from ..ML_callbacks._scheduler import _DragonLRScheduler
|
|
12
|
+
from ..ML_evaluation import classification_metrics, regression_metrics, shap_summary_plot, plot_attention_importance
|
|
13
|
+
from ..ML_evaluation import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
|
|
14
|
+
from ..ML_evaluation import segmentation_metrics
|
|
15
|
+
from ..ML_evaluation_captum import captum_feature_importance, captum_segmentation_heatmap, captum_image_heatmap
|
|
16
|
+
from ..ML_configuration import (FormatRegressionMetrics,
|
|
17
|
+
FormatMultiTargetRegressionMetrics,
|
|
18
|
+
FormatBinaryClassificationMetrics,
|
|
19
|
+
FormatMultiClassClassificationMetrics,
|
|
20
|
+
FormatBinaryImageClassificationMetrics,
|
|
21
|
+
FormatMultiClassImageClassificationMetrics,
|
|
22
|
+
FormatMultiLabelBinaryClassificationMetrics,
|
|
23
|
+
FormatBinarySegmentationMetrics,
|
|
24
|
+
FormatMultiClassSegmentationMetrics,
|
|
25
|
+
|
|
26
|
+
FinalizeBinaryClassification,
|
|
27
|
+
FinalizeBinarySegmentation,
|
|
28
|
+
FinalizeBinaryImageClassification,
|
|
29
|
+
FinalizeMultiClassClassification,
|
|
30
|
+
FinalizeMultiClassImageClassification,
|
|
31
|
+
FinalizeMultiClassSegmentation,
|
|
32
|
+
FinalizeMultiLabelBinaryClassification,
|
|
33
|
+
FinalizeMultiTargetRegression,
|
|
34
|
+
FinalizeRegression)
|
|
35
|
+
|
|
36
|
+
from ..path_manager import make_fullpath
|
|
37
|
+
from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys, ScalerKeys
|
|
38
|
+
from .._core import get_logger
|
|
39
|
+
|
|
40
|
+
from ._base_trainer import _BaseDragonTrainer
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Logger used by this module, obtained from the package's `get_logger` helper.
_LOGGER = get_logger("DragonTrainer")


# Public names of this module (what `from <module> import *` exposes).
__all__ = [
    "DragonTrainer",
]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# --- DragonTrainer ----
|
|
52
|
+
class DragonTrainer(_BaseDragonTrainer):
|
|
53
|
+
def __init__(self,
             model: nn.Module,
             train_dataset: Dataset,
             validation_dataset: Dataset,
             kind: Literal["regression",
                           "binary classification",
                           "multiclass classification",
                           "multitarget regression",
                           "multilabel binary classification",
                           "binary segmentation",
                           "multiclass segmentation",
                           "binary image classification",
                           "multiclass image classification"],
             optimizer: torch.optim.Optimizer,
             device: Union[Literal['cuda', 'mps', 'cpu'],str],
             checkpoint_callback: Optional[DragonModelCheckpoint],
             early_stopping_callback: Optional[_DragonEarlyStopping],
             lr_scheduler_callback: Optional[_DragonLRScheduler],
             extra_callbacks: Optional[list[_Callback]] = None,
             criterion: Union[nn.Module,Literal["auto"]] = "auto",
             dataloader_workers: int = 2):
    """
    Automates the training process of a PyTorch Model.

    Built-in Callbacks: `History`, `TqdmProgressBar`

    Args:
        model (nn.Module): The PyTorch model to train.
        train_dataset (Dataset): The training dataset.
        validation_dataset (Dataset): The validation dataset.
        kind (str): Task type, used to redirect to the correct process.
        optimizer (torch.optim.Optimizer): The optimizer.
        device (str): The device to run training on ('cpu', 'cuda', 'mps').
        checkpoint_callback (DragonModelCheckpoint | None): Checkpoint callback, forwarded to the base trainer.
        early_stopping_callback (_DragonEarlyStopping | None): Early stopping callback, forwarded to the base trainer.
        lr_scheduler_callback (_DragonLRScheduler | None): Learning rate scheduler callback, forwarded to the base trainer.
        extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.
        criterion (nn.Module | "auto"): The loss function to use. If "auto", it will be inferred from the selected task:
            `nn.MSELoss` for regression tasks, `nn.BCEWithLogitsLoss` for binary and multi-label tasks,
            `nn.CrossEntropyLoss` for multi-class tasks.
        dataloader_workers (int): Subprocesses for data loading.

    Raises:
        ValueError: If `kind` is not a supported task type, or if `criterion` is a string other than "auto".

    Note:
        - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`. The model should output as many values as existing targets.

        - For **single-label, binary classification**, `nn.BCEWithLogitsLoss` is the standard choice. The model should output a single logit.

        - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice. The model should output as many logits as existing classes.

        - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem. The model should output 1 logit per binary target.

        - For **binary segmentation** tasks, `nn.BCEWithLogitsLoss` is common. The model should output a single logit.

        - For **multiclass segmentation** tasks, `nn.CrossEntropyLoss` is the standard. The model should output as many logits as existing classes.
    """
    # Call the base class constructor with common parameters
    super().__init__(
        model=model,
        optimizer=optimizer,
        device=device,
        dataloader_workers=dataloader_workers,
        checkpoint_callback=checkpoint_callback,
        early_stopping_callback=early_stopping_callback,
        lr_scheduler_callback=lr_scheduler_callback,
        extra_callbacks=extra_callbacks
    )

    if kind not in [MLTaskKeys.REGRESSION,
                    MLTaskKeys.BINARY_CLASSIFICATION,
                    MLTaskKeys.MULTICLASS_CLASSIFICATION,
                    MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION,
                    MLTaskKeys.MULTITARGET_REGRESSION,
                    MLTaskKeys.BINARY_SEGMENTATION,
                    MLTaskKeys.MULTICLASS_SEGMENTATION,
                    MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
                    MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
        raise ValueError(f"'{kind}' is not a valid task type.")

    self.train_dataset = train_dataset
    self.validation_dataset = validation_dataset
    self.kind = kind
    # Default decision threshold applied to sigmoid probabilities in the binary prediction branches
    self._classification_threshold: float = 0.5

    # loss function
    if isinstance(criterion, str):
        # A string can only be the "auto" sentinel; anything else is a typo that would otherwise
        # surface much later as "'str' object is not callable" during the first training step.
        if criterion != "auto":
            raise ValueError(f"'{criterion}' is not a valid criterion. Use 'auto' or provide a loss module.")

        mse_tasks = (MLTaskKeys.REGRESSION,
                     MLTaskKeys.MULTITARGET_REGRESSION)
        bce_tasks = (MLTaskKeys.BINARY_CLASSIFICATION,
                     MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
                     MLTaskKeys.BINARY_SEGMENTATION,
                     MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION)
        cross_entropy_tasks = (MLTaskKeys.MULTICLASS_CLASSIFICATION,
                               MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION,
                               MLTaskKeys.MULTICLASS_SEGMENTATION)

        if kind in mse_tasks:
            self.criterion = nn.MSELoss()
        elif kind in bce_tasks:
            self.criterion = nn.BCEWithLogitsLoss()
        elif kind in cross_entropy_tasks:
            self.criterion = nn.CrossEntropyLoss()
        else:
            # Defensive: keeps `self.criterion` from being silently left unset if a task type
            # is added to the validation list above but not to one of the groups here.
            raise ValueError(f"'{kind}' is not a valid task type.")
    else:
        self.criterion = criterion
|
|
141
|
+
|
|
142
|
+
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
143
|
+
"""Initializes the DataLoaders."""
|
|
144
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
145
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
146
|
+
|
|
147
|
+
self.train_loader = DataLoader(
|
|
148
|
+
dataset=self.train_dataset,
|
|
149
|
+
batch_size=batch_size,
|
|
150
|
+
shuffle=shuffle,
|
|
151
|
+
num_workers=loader_workers,
|
|
152
|
+
pin_memory=("cuda" in self.device.type),
|
|
153
|
+
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
self.validation_loader = DataLoader(
|
|
157
|
+
dataset=self.validation_dataset,
|
|
158
|
+
batch_size=batch_size,
|
|
159
|
+
shuffle=False,
|
|
160
|
+
num_workers=loader_workers,
|
|
161
|
+
pin_memory=("cuda" in self.device.type)
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
def _train_step(self):
    """
    Runs one pass over `self.train_loader`, updating the model weights after every batch.

    The `on_batch_begin` and `on_batch_end` callback hooks are fired around each batch.

    Returns:
        dict: `{PyTorchLogKeys.TRAIN_LOSS: value}` where the value is the loss averaged
        over all processed samples (0.0 if no sample was processed).
    """
    self.model.train()
    loss_sum = 0.0
    seen = 0

    for step, (inputs, labels) in enumerate(self.train_loader): # type: ignore
        n_batch = inputs.size(0)

        # Log dictionary handed to both batch hooks
        step_logs = {
            PyTorchLogKeys.BATCH_INDEX: step,
            PyTorchLogKeys.BATCH_SIZE: n_batch
        }
        self._callbacks_hook('on_batch_begin', step, logs=step_logs)

        inputs = inputs.to(self.device)
        labels = labels.to(self.device)
        self.optimizer.zero_grad()

        predictions = self.model(inputs)

        # --- Label Type/Shape Correction ---
        # BCE-based losses require float targets
        if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
            labels = labels.float()

        # Single-logit tasks: align a [N, 1] output with a [N] target
        if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
            if predictions.ndim == 2 and labels.ndim == 1 and predictions.shape[1] == 1:
                predictions = predictions.squeeze(1)

        # Binary segmentation: align a [N, 1, H, W] output with a [N, H, W] target
        if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
            if predictions.ndim == 4 and labels.ndim == 3 and predictions.shape[1] == 1:
                predictions = predictions.squeeze(1)

        step_loss = self.criterion(predictions, labels)

        step_loss.backward()
        self.optimizer.step()

        # Weight the batch loss by its size so the epoch average is per sample
        loss_value = step_loss.item()
        loss_sum += loss_value * n_batch
        seen += n_batch

        # Report the batch loss through the end-of-batch hook
        step_logs[PyTorchLogKeys.BATCH_LOSS] = loss_value
        self._callbacks_hook('on_batch_end', step, logs=step_logs)

    if not seen:
        _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
        return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

    epoch_loss = loss_sum / seen
    return {PyTorchLogKeys.TRAIN_LOSS: epoch_loss} # type: ignore
|
|
218
|
+
|
|
219
|
+
def _validation_step(self):
    """
    Runs one pass over `self.validation_loader` without gradient tracking.

    The loss is averaged over the samples that were actually processed (same approach
    as `_train_step`), so the result does not depend on the dataset implementing
    `__len__` or on the loader visiting every item of the dataset.

    Returns:
        dict: `{PyTorchLogKeys.VAL_LOSS: value}` where the value is the loss averaged
        over all processed samples (0.0 if no sample was processed).
    """
    self.model.eval()
    running_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for features, target in self.validation_loader: # type: ignore
            features, target = features.to(self.device), target.to(self.device)

            output = self.model(features)

            # --- Label Type/Shape Correction ---
            # Cast target to float for BCE-based losses
            if self.kind in MLTaskKeys.ALL_BINARY_TASKS:
                target = target.float()

            # Reshape output to match target for single-logit tasks
            if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
                # If model outputs [N, 1] and target is [N], squeeze output
                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
                    output = output.squeeze(1)

            if self.kind == MLTaskKeys.BINARY_SEGMENTATION:
                # If model outputs [N, 1, H, W] and target is [N, H, W], squeeze output
                if output.ndim == 4 and output.shape[1] == 1 and target.ndim == 3:
                    output = output.squeeze(1)

            loss = self.criterion(output, target)

            # Weight the batch loss by its size so the average is per sample
            batch_size = features.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
        return {PyTorchLogKeys.VAL_LOSS: 0.0}

    logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
    return logs
|
|
255
|
+
|
|
256
|
+
def _predict_for_eval(self, dataloader: DataLoader):
    """
    Private method to yield model predictions batch by batch for evaluation.

    Puts the model in eval mode on `self.device`, runs every batch of `dataloader`
    without gradient tracking and converts the raw model output according to `self.kind`.

    For regression tasks, automatically detects if a target scaler is present in the
    training dataset (attribute named by `ScalerKeys.TARGET_SCALER`) and applies its
    inverse transformation to both predictions and targets.

    Args:
        dataloader (DataLoader): Loader yielding `(features, target)` batches.

    Yields:
        tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).

        - y_prob_batch is None for regression tasks.
        - Binary tasks threshold the sigmoid probabilities with `self._classification_threshold`.
        - Multi-class tasks take the argmax of the softmax over dimension 1.
    """
    self.model.eval()
    self.model.to(self.device)

    # --- Check for Target Scaler (for Regression Un-scaling) ---
    target_scaler = None
    if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:
        # Try to get the scaler from the dataset attached to the trainer
        if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
            target_scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
            if target_scaler is not None:
                _LOGGER.debug("Target scaler detected. Un-scaling predictions and targets for metric calculation.")

    with torch.no_grad():
        for features, target in dataloader:
            features = features.to(self.device)
            # Keep target on device initially for potential un-scaling
            target = target.to(self.device)

            output = self.model(features)

            # Defaults; a branch below overwrites the ones that apply to the task
            y_pred_batch = None
            y_prob_batch = None
            y_true_batch = None

            if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.MULTITARGET_REGRESSION]:

                # --- Automatic Un-scaling Logic ---
                # NOTE(review): this is a truthiness test while the detection above uses `is not None`;
                # confirm the scaler type cannot evaluate as falsy.
                if target_scaler:
                    # 1. Reshape output/target if flattened (common in single regression)
                    # Scaler expects [N, Features]
                    original_out_shape = output.shape
                    original_target_shape = target.shape

                    if output.ndim == 1: output = output.reshape(-1, 1)
                    if target.ndim == 1: target = target.reshape(-1, 1)

                    # 2. Apply Inverse Transform
                    # NOTE(review): presumably `inverse_transform` accepts and returns torch tensors,
                    # since `.cpu().numpy()` is called on the results below - confirm against the scaler.
                    output = target_scaler.inverse_transform(output)
                    target = target_scaler.inverse_transform(target)

                    # 3. Restore shapes (optional, but good for consistency)
                    if len(original_out_shape) == 1: output = output.flatten()
                    if len(original_target_shape) == 1: target = target.flatten()

                # y_prob_batch stays None for regression
                y_pred_batch = output.cpu().numpy()
                y_true_batch = target.cpu().numpy()

            elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
                # A [N, 1] output is flattened to [N] before applying the sigmoid
                if output.ndim == 2 and output.shape[1] == 1:
                    output = output.squeeze(1)

                probs_pos = torch.sigmoid(output)
                preds = (probs_pos >= self._classification_threshold).int()
                y_pred_batch = preds.cpu().numpy()

                # Probabilities are reported as two columns: (negative class, positive class)
                probs_neg = 1.0 - probs_pos
                y_prob_batch = torch.stack([probs_neg, probs_pos], dim=1).cpu().numpy()
                y_true_batch = target.cpu().numpy()

            elif self.kind in [MLTaskKeys.MULTICLASS_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
                # Class probabilities over dimension 1; prediction is the most probable class index
                probs = torch.softmax(output, dim=1)
                preds = torch.argmax(probs, dim=1)
                y_pred_batch = preds.cpu().numpy()
                y_prob_batch = probs.cpu().numpy()
                y_true_batch = target.cpu().numpy()

            elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
                # Each output is thresholded independently
                probs = torch.sigmoid(output)
                preds = (probs >= self._classification_threshold).int()
                y_pred_batch = preds.cpu().numpy()
                y_prob_batch = probs.cpu().numpy()
                y_true_batch = target.cpu().numpy()

            elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
                # NOTE(review): `squeeze(1)` and `cat(..., dim=1)` below assume the model output is
                # [N, 1, H, W] - confirm; an output without the channel axis is not handled here.
                probs_pos = torch.sigmoid(output)
                preds = (probs_pos >= self._classification_threshold).int()
                y_pred_batch = preds.squeeze(1).cpu().numpy()

                # Negative and positive probabilities are concatenated along dimension 1
                probs_neg = 1.0 - probs_pos
                y_prob_batch = torch.cat([probs_neg, probs_pos], dim=1).cpu().numpy()

                # Drop a singleton channel axis from the target, if present
                if target.ndim == 4 and target.shape[1] == 1:
                    target = target.squeeze(1)
                y_true_batch = target.cpu().numpy()

            elif self.kind == MLTaskKeys.MULTICLASS_SEGMENTATION:
                # Class probabilities over dimension 1; prediction is the most probable class index
                probs = torch.softmax(output, dim=1)
                preds = torch.argmax(probs, dim=1)
                y_pred_batch = preds.cpu().numpy()
                y_prob_batch = probs.cpu().numpy()

                # Drop a singleton channel axis from the target, if present
                if target.ndim == 4 and target.shape[1] == 1:
                    target = target.squeeze(1)
                y_true_batch = target.cpu().numpy()

            yield y_pred_batch, y_prob_batch, y_true_batch
|
|
364
|
+
|
|
365
|
+
def evaluate(self,
             save_dir: Union[str, Path],
             model_checkpoint: Union[Path, Literal["best", "current"]],
             classification_threshold: Optional[float] = None,
             test_data: Optional[Union[DataLoader, Dataset]] = None,
             val_format_configuration: Optional[Union[
                 FormatRegressionMetrics,
                 FormatMultiTargetRegressionMetrics,
                 FormatBinaryClassificationMetrics,
                 FormatMultiClassClassificationMetrics,
                 FormatBinaryImageClassificationMetrics,
                 FormatMultiClassImageClassificationMetrics,
                 FormatMultiLabelBinaryClassificationMetrics,
                 FormatBinarySegmentationMetrics,
                 FormatMultiClassSegmentationMetrics
             ]]=None,
             test_format_configuration: Optional[Union[
                 FormatRegressionMetrics,
                 FormatMultiTargetRegressionMetrics,
                 FormatBinaryClassificationMetrics,
                 FormatMultiClassClassificationMetrics,
                 FormatBinaryImageClassificationMetrics,
                 FormatMultiClassImageClassificationMetrics,
                 FormatMultiLabelBinaryClassificationMetrics,
                 FormatBinarySegmentationMetrics,
                 FormatMultiClassSegmentationMetrics,
             ]]=None):
    """
    Evaluates the model, routing to the correct evaluation function based on task `kind`.

    The validation dataset is always evaluated. If `test_data` is given, it is evaluated as well and
    the validation and test metrics are written to separate subdirectories of `save_dir`.

    Args:
        model_checkpoint (Path | "best" | "current"):
            - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
            - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
            - If 'current', use the current state of the trained model up to the latest trained epoch.
        save_dir (str | Path): Directory to save all reports and plots.
        classification_threshold (float | None): Used for tasks using a binary approach (binary classification, binary segmentation, multilabel binary classification).
            Required for those tasks and must lie strictly between 0.0 and 1.0; ignored for every other task.
        test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
        val_format_configuration (object): Optional configuration for metric format output for the validation set.
        test_format_configuration (object): Optional configuration for metric format output for the test set.
            If it is not a recognized configuration object, a warning is logged and the validation configuration (or the default format) is used instead.

    Raises:
        ValueError: If `model_checkpoint` is invalid, the classification threshold is missing or out of range for a binary task,
            `val_format_configuration` is not a recognized configuration object, or `test_data` has an unsupported type.
    """
    # Every metric-format configuration class accepted for either data split.
    valid_config_types = (FormatRegressionMetrics,
                          FormatMultiTargetRegressionMetrics,
                          FormatBinaryClassificationMetrics,
                          FormatMultiClassClassificationMetrics,
                          FormatBinaryImageClassificationMetrics,
                          FormatMultiClassImageClassificationMetrics,
                          FormatMultiLabelBinaryClassificationMetrics,
                          FormatBinarySegmentationMetrics,
                          FormatMultiClassSegmentationMetrics)

    # Validate model checkpoint
    if isinstance(model_checkpoint, Path):
        checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
    elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
        checkpoint_validated = model_checkpoint
    else:
        message = f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'."
        _LOGGER.error(message)
        raise ValueError(message)

    # Validate classification threshold
    if self.kind not in MLTaskKeys.ALL_BINARY_TASKS:
        # dummy value for tasks that do not need it
        threshold_validated = 0.5
    elif classification_threshold is None:
        # it should have been provided for binary tasks
        message = f"The classification threshold must be provided for '{self.kind}'."
        _LOGGER.error(message)
        raise ValueError(message)
    elif not 0.0 < classification_threshold < 1.0:
        # Invalid float. The chained comparison also rejects NaN, which compares False against both bounds.
        message = f"A classification threshold of {classification_threshold} is invalid. Must be in the range (0.0 - 1.0)."
        _LOGGER.error(message)
        raise ValueError(message)
    else:
        threshold_validated = classification_threshold

    # Validate val configuration (None means "use the default format")
    if val_format_configuration is not None and not isinstance(val_format_configuration, valid_config_types):
        message = f"Invalid 'format_configuration': '{type(val_format_configuration)}'."
        _LOGGER.error(message)
        raise ValueError(message)
    val_configuration_validated = val_format_configuration

    # Validate directory
    save_path = make_fullpath(save_dir, make=True, enforce="directory")

    # No test data: evaluate the validation set only, directly into the save directory
    if test_data is None:
        _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
        self._evaluate(save_dir=save_path,
                       model_checkpoint=checkpoint_validated,
                       classification_threshold=threshold_validated,
                       data=None,
                       format_configuration=val_configuration_validated)
        return

    # Validate test data
    if not isinstance(test_data, (DataLoader, Dataset)):
        message = f"Invalid type for 'test_data': '{type(test_data)}'."
        _LOGGER.error(message)
        raise ValueError(message)
    test_data_validated = test_data

    validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
    test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

    # Dispatch validation set
    _LOGGER.info(f"🔎 Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
    self._evaluate(save_dir=validation_metrics_path,
                   model_checkpoint=checkpoint_validated,
                   classification_threshold=threshold_validated,
                   data=None,
                   format_configuration=val_configuration_validated)

    # Validate test configuration: an invalid object is not fatal, it only triggers a fallback
    test_configuration_validated = test_format_configuration
    if test_format_configuration is not None and not isinstance(test_format_configuration, valid_config_types):
        warning_message = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
        if val_configuration_validated is not None:
            warning_message += " 'val_format_configuration' will be used for the test set metrics output."
        else:
            warning_message += " Using default format."
        # Reuse the validation configuration; when that is None the default format applies
        test_configuration_validated = val_configuration_validated
        _LOGGER.warning(warning_message)

    # Dispatch test set
    # "current" is used because the requested checkpoint was already loaded by the validation pass above
    _LOGGER.info(f"🔎 Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
    self._evaluate(save_dir=test_metrics_path,
                   model_checkpoint="current",
                   classification_threshold=threshold_validated,
                   data=test_data_validated,
                   format_configuration=test_configuration_validated)
|
|
508
|
+
|
|
509
|
+
def _evaluate(self,
              save_dir: Union[str, Path],
              model_checkpoint: Union[Path, Literal["best", "current"]],
              classification_threshold: float,
              data: Optional[Union[DataLoader, Dataset]],
              format_configuration: Optional[Union[
                  FormatRegressionMetrics,
                  FormatMultiTargetRegressionMetrics,
                  FormatBinaryClassificationMetrics,
                  FormatMultiClassClassificationMetrics,
                  FormatBinaryImageClassificationMetrics,
                  FormatMultiClassImageClassificationMetrics,
                  FormatMultiLabelBinaryClassificationMetrics,
                  FormatBinarySegmentationMetrics,
                  FormatMultiClassSegmentationMetrics
              ]]=None):
    """
    Private helper that runs one evaluation pass and writes the metrics for the task `kind`.

    Args:
        save_dir (str | Path): Directory handed to the metrics function.
        model_checkpoint (Path | "best" | "current"): A checkpoint path to load, 'best' to load the best
            checkpoint tracked by the checkpoint callback, or 'current' to keep the model state as it is.
        classification_threshold (float): Stored on the trainer before predicting.
        data (DataLoader | Dataset | None): Data to evaluate. A DataLoader is used as given, a Dataset is wrapped
            in a new non-shuffling DataLoader, and any other value falls back to the trainer's validation dataset.
        format_configuration (object): Optional metric format configuration. If its type does not match the task,
            a warning is logged and the metrics function receives `None`.

    Raises:
        ValueError: If 'best' is requested but no checkpoint callback exists, or if no data is given and the
            trainer has no validation dataset.
    """
    def _select_config(expected_types):
        """Return `format_configuration` when it matches `expected_types`; otherwise warn (if one was given) and return None."""
        if not format_configuration:
            return None
        if isinstance(format_configuration, expected_types):
            return format_configuration
        _LOGGER.warning(f"Wrong configuration type: Received '{type(format_configuration).__name__}'.")
        return None

    # set threshold
    self._classification_threshold = classification_threshold

    # load model checkpoint ('current' leaves the model untouched)
    if isinstance(model_checkpoint, Path):
        self._load_checkpoint(path=model_checkpoint)
    elif model_checkpoint == MagicWords.BEST:
        if self._checkpoint_callback is None:
            message = f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found."
            _LOGGER.error(message)
            raise ValueError(message)
        self._load_checkpoint(self._checkpoint_callback.best_checkpoint_path)

    # Dataloader
    if isinstance(data, DataLoader):
        eval_loader = data
        # Try to get the dataset from the loader for fetching names/class maps
        dataset_for_artifacts = getattr(data, 'dataset', None)
    else:
        if isinstance(data, Dataset):
            dataset_for_artifacts = data
        else:
            # no usable data given: fall back to the trainer's validation dataset
            if self.validation_dataset is None:
                message = "Cannot evaluate. No data provided and no validation dataset available in the trainer."
                _LOGGER.error(message)
                raise ValueError(message)
            dataset_for_artifacts = self.validation_dataset

        # Create a fresh, non-shuffling DataLoader for the chosen dataset
        eval_loader = DataLoader(dataset_for_artifacts,
                                 batch_size=self._batch_size,
                                 shuffle=False,
                                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                 pin_memory=(self.device.type == "cuda"))

    # Collect per-batch outputs; any of the three pieces may be None for a given batch
    all_preds, all_probs, all_true = [], [], []
    for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
        if y_pred_b is not None:
            all_preds.append(y_pred_b)
        if y_prob_b is not None:
            all_probs.append(y_prob_b)
        if y_true_b is not None:
            all_true.append(y_true_b)

    if not all_true:
        _LOGGER.error("Evaluation failed: No data was processed.")
        return

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)
    y_prob = np.concatenate(all_probs) if all_probs else None

    # --- Routing Logic ---
    # Single-target regression
    if self.kind == MLTaskKeys.REGRESSION:
        config = _select_config(FormatRegressionMetrics)

        regression_metrics(y_true=y_true.flatten(),
                           y_pred=y_pred.flatten(),
                           save_dir=save_dir,
                           config=config)

    # single target classification
    elif self.kind in [MLTaskKeys.BINARY_CLASSIFICATION,
                       MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
                       MLTaskKeys.MULTICLASS_CLASSIFICATION,
                       MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
        # get the class map if it exists
        try:
            class_map = dataset_for_artifacts.class_map # type: ignore
        except AttributeError:
            _LOGGER.warning("Dataset has no 'class_map' attribute. Using generics.")
            class_map = None
        else:
            if not isinstance(class_map, dict):
                _LOGGER.warning(f"Dataset has a 'class_map' attribute, but it is not a dictionary: '{type(class_map)}'.")
                class_map = None

        # Each classification task accepts exactly one configuration class
        expected_config_by_kind = {
            MLTaskKeys.BINARY_CLASSIFICATION: FormatBinaryClassificationMetrics,
            MLTaskKeys.BINARY_IMAGE_CLASSIFICATION: FormatBinaryImageClassificationMetrics,
            MLTaskKeys.MULTICLASS_CLASSIFICATION: FormatMultiClassClassificationMetrics,
            MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION: FormatMultiClassImageClassificationMetrics,
        }
        config = _select_config(expected_config_by_kind[self.kind])

        classification_metrics(save_dir=save_dir,
                               y_true=y_true,
                               y_pred=y_pred,
                               y_prob=y_prob,
                               class_map=class_map,
                               config=config)

    # multitarget regression
    elif self.kind == MLTaskKeys.MULTITARGET_REGRESSION:
        try:
            target_names = dataset_for_artifacts.target_names # type: ignore
        except AttributeError:
            # one generic name per column of the ground-truth array
            target_names = [f"target_{i}" for i in range(y_true.shape[1])]
            _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")

        config = _select_config(FormatMultiTargetRegressionMetrics)

        multi_target_regression_metrics(y_true=y_true,
                                        y_pred=y_pred,
                                        target_names=target_names,
                                        save_dir=save_dir,
                                        config=config)

    # multi-label binary classification
    elif self.kind == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
        try:
            target_names = dataset_for_artifacts.target_names # type: ignore
        except AttributeError:
            # one generic name per column of the ground-truth array
            target_names = [f"label_{i}" for i in range(y_true.shape[1])]
            _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")

        if y_prob is None:
            _LOGGER.error("Evaluation for multi_label_classification requires probabilities (y_prob).")
            return

        config = _select_config(FormatMultiLabelBinaryClassificationMetrics)

        multi_label_classification_metrics(y_true=y_true,
                                           y_pred=y_pred,
                                           y_prob=y_prob,
                                           target_names=target_names,
                                           save_dir=save_dir,
                                           config=config)

    # Segmentation tasks
    elif self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]:
        class_names = None
        # Prefer a 'classes' attribute on the dataset itself
        if hasattr(dataset_for_artifacts, 'classes'):
            class_names = dataset_for_artifacts.classes # type: ignore
        # Otherwise look one level down, for wrappers that expose the wrapped dataset as '.dataset'
        elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'): # type: ignore
            class_names = dataset_for_artifacts.dataset.classes # type: ignore

        if class_names is None:
            try:
                # Fallback to 'target_names'
                class_names = dataset_for_artifacts.target_names # type: ignore
            except AttributeError:
                # Fallback to inferring from the labels present in the ground truth
                class_names = [f"Class {i}" for i in np.unique(y_true)]
                _LOGGER.warning("Dataset has no 'classes' or 'target_names' attribute. Using generic names.")

        # Either segmentation configuration class is accepted for both segmentation tasks
        config = _select_config((FormatBinarySegmentationMetrics, FormatMultiClassSegmentationMetrics))

        segmentation_metrics(y_true=y_true,
                             y_pred=y_pred,
                             save_dir=save_dir,
                             class_names=class_names,
                             config=config)
|
|
727
|
+
|
|
728
|
+
def explain_shap(self,
                 save_dir: Union[str,Path],
                 explain_dataset: Optional[Dataset] = None,
                 n_samples: int = 300,
                 feature_names: Optional[list[str]] = None,
                 target_names: Optional[list[str]] = None,
                 explainer_type: Literal['deep', 'kernel'] = 'kernel'):
    """
    Explains model predictions using SHAP and saves all artifacts.

    NOTE: SHAP support is limited to single-target tasks (Regression, Binary/Multiclass Classification).
    For complex tasks (Multi-target, Multi-label, Sequences, Images), please use `explain_captum()`.
    For any other task a warning is logged and nothing is produced.

    The background data is automatically sampled from the trainer's training dataset.

    If `feature_names` is not provided, it will attempt to extract them from the dataset.

    Args:
        explain_dataset (Dataset | None): A specific dataset to explain.
            If None, the trainer's validation dataset is used.
        n_samples (int): The number of samples to use for both background and explanation.
        feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
        target_names (list[str] | None): Kept for backward compatibility. It is not used, because
            multi-target and multi-label tasks are rejected by the compatibility guard.
        save_dir (str | Path): Directory to save all SHAP artifacts.
        explainer_type (Literal['deep', 'kernel']): The explainer to use.
            - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.

    Raises:
        ValueError: If `feature_names` is None and the dataset has no feature names attribute.
    """
    # --- 1. Compatibility Guard ---
    valid_shap_tasks = [
        MLTaskKeys.REGRESSION,
        MLTaskKeys.BINARY_CLASSIFICATION,
        MLTaskKeys.MULTICLASS_CLASSIFICATION
    ]

    if self.kind not in valid_shap_tasks:
        _LOGGER.warning(f"SHAP explanation is deprecated for task '{self.kind}' due to instability. Please use 'explain_captum()' instead.")
        return

    # memory efficient helper
    def _get_random_sample(dataset: Dataset, num_samples: int):
        """
        Collects up to `num_samples` feature rows from shuffled batches of `dataset`.

        Returns None when the dataset is None, empty, or yields no batches.
        """
        if dataset is None:
            return None

        dataset_len = len(dataset) # type: ignore
        if dataset_len == 0:
            return None

        # For MPS devices, num_workers must be 0 to ensure stability
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

        # Ensure batch_size is not larger than the dataset itself
        batch_size = min(num_samples, 64, dataset_len)

        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True, # Shuffle to get random samples
            num_workers=loader_workers
        )

        collected_features = []
        num_collected = 0

        # Only the features are needed; the targets of each batch are discarded
        for features, _ in loader:
            collected_features.append(features)
            num_collected += features.size(0)
            if num_collected >= num_samples:
                break # Stop once we have enough samples

        if not collected_features:
            return None

        # The last batch may overshoot the request, so trim to at most num_samples rows
        return torch.cat(collected_features, dim=0)[:num_samples]

    # 2. Get background data from the trainer's train_dataset
    background_data = _get_random_sample(self.train_dataset, n_samples)
    if background_data is None:
        _LOGGER.error("Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
        return

    # 3. Determine target dataset and get explanation instances
    target_dataset = explain_dataset if explain_dataset is not None else self.validation_dataset
    instances_to_explain = _get_random_sample(target_dataset, n_samples)
    if instances_to_explain is None:
        _LOGGER.error("Explanation dataset is empty or invalid. Skipping SHAP analysis.")
        return

    # attempt to get feature names
    if feature_names is None:
        if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
            feature_names = target_dataset.feature_names # type: ignore
        else:
            message = f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute."
            _LOGGER.error(message)
            raise ValueError(message)

    # move model to device
    self.model.to(self.device)

    # 4. Call the plotting function.
    # Only single-target tasks reach this point (see the compatibility guard), so no further routing is needed.
    shap_summary_plot(
        model=self.model,
        background_data=background_data,
        instances_to_explain=instances_to_explain,
        feature_names=feature_names,
        save_dir=save_dir,
        explainer_type=explainer_type,
        device=self.device
    )
|
|
879
|
+
|
|
880
|
+
def explain_captum(self,
                   save_dir: Union[str, Path],
                   explain_dataset: Optional[Dataset] = None,
                   n_samples: int = 100,
                   feature_names: Optional[list[str]] = None,
                   target_names: Optional[list[str]] = None,
                   n_steps: int = 50):
    """
    Explains model predictions using Captum's Integrated Gradients.

    - **Tabular/Classification:** Generates Feature Importance Bar Charts.
    - **Image Classification / Segmentation:** Generates heatmaps.

    Args:
        save_dir (str | Path): Directory to save artifacts.
        explain_dataset (Dataset | None): Dataset to sample from. Defaults to validation set.
        n_samples (int): Number of samples to evaluate. Must be at least 1.
        feature_names (list[str] | None): Feature names.
            - Required for Tabular tasks (taken from the dataset when it exposes them).
            - Not looked up for Image tasks.
        target_names (list[str] | None): Names for the model outputs (or Class names).
            - If None, attempts to extract from dataset attributes (`target_names`, `classes`, or `class_map`).
            - If extraction fails, a fixed name is used for single-output tasks; otherwise None is passed on.
        n_steps (int): Number of interpolation steps. Capped at 30 for segmentation tasks.

    Raises:
        ValueError: If `n_samples` is smaller than 1, or if `feature_names` is required but neither given
            nor available on the dataset.
    """
    # 1. Validate the sample count up front instead of failing inside the DataLoader
    if n_samples < 1:
        message = f"'n_samples' must be a positive integer, got {n_samples}."
        _LOGGER.error(message)
        raise ValueError(message)

    # 2. Prepare Data
    dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
    if dataset_to_use is None:
        _LOGGER.error("No dataset available for explanation.")
        return

    # An empty dataset cannot provide a batch; skip instead of failing while building or iterating the loader
    try:
        available_samples = len(dataset_to_use) # type: ignore
    except TypeError:
        available_samples = None # dataset does not define a length; let the DataLoader handle it
    if available_samples == 0:
        _LOGGER.error("The explanation dataset is empty. Skipping Captum analysis.")
        return

    # A single shuffled batch of up to n_samples items is the sample.
    # num_workers=0 is used for stability during ad-hoc sampling.
    sample_loader = DataLoader(dataset_to_use, batch_size=n_samples, shuffle=True, num_workers=0)
    first_batch = next(iter(sample_loader), None)
    if first_batch is None:
        _LOGGER.error("The explanation dataset is empty. Skipping Captum analysis.")
        return
    input_data, _ = first_batch

    # 3. Get Feature Names (Only if NOT segmentation AND NOT image classification)
    is_segmentation = self.kind in [MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION]
    is_image_classification = self.kind in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]

    if feature_names is None and not is_segmentation and not is_image_classification:
        if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
            feature_names = dataset_to_use.feature_names # type: ignore
        else:
            message = "Could not extract `feature_names`. It must be provided if the dataset does not have it."
            _LOGGER.error(message)
            raise ValueError(message)

    # 4. Handle Target Names (or Class Names)
    if target_names is None:
        # A. Try dataset attributes first
        if hasattr(dataset_to_use, DatasetKeys.TARGET_NAMES):
            target_names = dataset_to_use.target_names # type: ignore
        elif hasattr(dataset_to_use, "classes"):
            target_names = dataset_to_use.classes # type: ignore
        elif hasattr(dataset_to_use, "class_map") and isinstance(dataset_to_use.class_map, dict): # type: ignore
            # Sort by value (index) to ensure correct order: {name: index} -> [name_at_0, name_at_1...]
            sorted_items = sorted(dataset_to_use.class_map.items(), key=lambda item: item[1]) # type: ignore
            target_names = [name for name, _index in sorted_items]

        # B. Infer based on task
        if target_names is None:
            if self.kind in [MLTaskKeys.REGRESSION, MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.BINARY_IMAGE_CLASSIFICATION]:
                target_names = ["Output"]
            elif self.kind == MLTaskKeys.BINARY_SEGMENTATION:
                target_names = ["Foreground"]

            # For multiclass/multitarget without names, leave it None and pass it on as is.

    # 5. Dispatch based on Task
    if is_segmentation:
        # lower n_steps for segmentation to save memory
        if n_steps > 30:
            n_steps = 30
            _LOGGER.warning(f"Segmentation task detected: Reducing Captum n_steps to {n_steps} to prevent OOM. If you encounter OOM errors, consider lowering this further.")

        captum_segmentation_heatmap(
            model=self.model,
            input_data=input_data,
            save_dir=save_dir,
            target_names=target_names, # Can be None
            n_steps=n_steps,
            device=self.device
        )

    elif is_image_classification:
        captum_image_heatmap(
            model=self.model,
            input_data=input_data,
            save_dir=save_dir,
            target_names=target_names,
            n_steps=n_steps,
            device=self.device
        )

    else:
        # Every remaining (non-image) task: feature importance per named feature
        captum_feature_importance(
            model=self.model,
            input_data=input_data,
            feature_names=feature_names,
            save_dir=save_dir,
            target_names=target_names,
            n_steps=n_steps,
            device=self.device
        )
|
|
991
|
+
|
|
992
|
+
def _attention_helper(self, dataloader: DataLoader):
    """
    Private generator producing the model's attention weights, one item per batch.

    The model is put in evaluation mode and moved to the trainer's device, and
    every batch is run through the model's `forward_attention` with gradient
    tracking disabled.

    Args:
        dataloader (DataLoader): Source of `(features, target)` batches; the targets are not used.

    Yields:
        (torch.Tensor | None): The attention weights of the batch moved to the CPU,
        or None when `forward_attention` returned no weights for it.
    """
    self.model.eval()
    self.model.to(self.device)

    with torch.no_grad():
        for batch_features, _batch_target in dataloader:
            on_device = batch_features.to(self.device)

            # forward_attention hands back (model output, attention weights); only the weights matter here
            _logits, weights = self.model.forward_attention(on_device) # type: ignore

            yield weights.cpu() if weights is not None else None
|
|
1018
|
+
|
|
1019
|
+
def explain_attention(self, save_dir: Union[str, Path],
                      feature_names: Optional[list[str]] = None,
                      explain_dataset: Optional[Dataset] = None,
                      plot_n_features: int = 10):
    """
    Generates and saves a feature importance plot based on attention weights.

    This method only works for models whose 'has_interpretable_attention' attribute is truthy.
    For any other model a warning is logged and nothing is produced.

    Args:
        save_dir (str | Path): Directory to save the plot and summary data.
        feature_names (List[str] | None): Names for the features for plot labeling. If None, the names are read from the dataset attribute named by `DatasetKeys.FEATURE_NAMES`.
        explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's validation dataset is used.
        plot_n_features (int): Number of top features to plot.

    Raises:
        ValueError: If `feature_names` is None and the dataset has no attribute named by `DatasetKeys.FEATURE_NAMES`.
    """
    # --- Step 1: Check if the model supports this explanation ---
    if not getattr(self.model, 'has_interpretable_attention', False):
        _LOGGER.warning("Model is not compatible with interpretable attention analysis. Skipping.")
        return

    # --- Step 2: Set up the dataloader ---
    dataset_to_use = explain_dataset if explain_dataset is not None else self.validation_dataset
    if not isinstance(dataset_to_use, Dataset):
        _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
        return

    # Get feature names
    if feature_names is None:
        if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
            # Read through the same key that was just checked, so the presence test
            # and the lookup can never refer to different attribute names.
            feature_names = getattr(dataset_to_use, DatasetKeys.FEATURE_NAMES)
        else:
            _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
            raise ValueError()

    explain_loader = DataLoader(
        dataset=dataset_to_use, batch_size=32, shuffle=False,
        # Worker processes are disabled on 'mps' devices; otherwise the trainer's setting is used.
        num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
        pin_memory=("cuda" in self.device.type)
    )

    # --- Step 3: Collect weights ---
    # Batches for which the model produced no attention weights are dropped.
    all_weights = [
        att_weights_b
        for att_weights_b in self._attention_helper(explain_loader)
        if att_weights_b is not None
    ]

    # --- Step 4: Call the plotting function ---
    if all_weights:
        plot_attention_importance(
            weights=all_weights,
            feature_names=feature_names,
            save_dir=save_dir,
            top_n=plot_n_features
        )
    else:
        _LOGGER.error("No attention weights were collected from the model.")
|
|
1078
|
+
|
|
1079
|
+
def finalize_model_training(self,
                            model_checkpoint: Union[Path, Literal['best', 'current']],
                            save_dir: Union[str, Path],
                            finalize_config: Union[FinalizeRegression,
                                                   FinalizeMultiTargetRegression,
                                                   FinalizeBinaryClassification,
                                                   FinalizeBinaryImageClassification,
                                                   FinalizeMultiClassClassification,
                                                   FinalizeMultiClassImageClassification,
                                                   FinalizeBinarySegmentation,
                                                   FinalizeMultiClassSegmentation,
                                                   FinalizeMultiLabelBinaryClassification]):
    """
    Writes a finalized, "inference-ready" model state to a .pth file.

    The saved file holds the model's `state_dict`, the trainer's epoch number, the task
    declared by the config, and whichever optional config fields are not None.

    Args:
        model_checkpoint (Path | "best" | "current"):
            - Path: Loads the model state from a specific checkpoint file.
            - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
            - "current": Uses the model's state as it is.
        save_dir (str | Path): The directory to save the finalized model.
        finalize_config (object): A data class instance specific to the ML task containing task-specific metadata required for inference.

    Raises:
        TypeError: If the trainer's task kind is one of the kinds listed below and
            `finalize_config` is not an instance of the class paired with it.
    """
    # (task kind, required config class, class name used in the log message).
    # Task kinds that are not listed here are not type-checked.
    required_config_by_task = (
        (MLTaskKeys.REGRESSION, FinalizeRegression, 'FinalizeRegression'),
        (MLTaskKeys.MULTITARGET_REGRESSION, FinalizeMultiTargetRegression, 'FinalizeMultiTargetRegression'),
        (MLTaskKeys.BINARY_CLASSIFICATION, FinalizeBinaryClassification, 'FinalizeBinaryClassification'),
        (MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, FinalizeBinaryImageClassification, 'FinalizeBinaryImageClassification'),
        (MLTaskKeys.MULTICLASS_CLASSIFICATION, FinalizeMultiClassClassification, 'FinalizeMultiClassClassification'),
        (MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, FinalizeMultiClassImageClassification, 'FinalizeMultiClassImageClassification'),
        (MLTaskKeys.BINARY_SEGMENTATION, FinalizeBinarySegmentation, 'FinalizeBinarySegmentation'),
        (MLTaskKeys.MULTICLASS_SEGMENTATION, FinalizeMultiClassSegmentation, 'FinalizeMultiClassSegmentation'),
        (MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, FinalizeMultiLabelBinaryClassification, 'FinalizeMultiLabelBinaryClassification'),
    )
    for task_kind, config_cls, config_label in required_config_by_task:
        if self.kind == task_kind and not isinstance(finalize_config, config_cls):
            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type '{config_label}', but got {type(finalize_config).__name__}.")
            raise TypeError()

    # Resolve the output location before touching the model state.
    dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    full_path = dir_path / finalize_config.filename

    # Bring the model into the requested state (checkpoint file, best, or as-is).
    self._load_model_state_for_finalizing(model_checkpoint)

    # Mandatory entries of the finalized file.
    finalized_data = {
        PyTorchCheckpointKeys.EPOCH: self.epoch,
        PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
        PyTorchCheckpointKeys.TASK: finalize_config.task
    }

    # Optional, task-specific entries: stored only when the config provides a value.
    optional_entries = (
        (PyTorchCheckpointKeys.TARGET_NAME, finalize_config.target_name),
        (PyTorchCheckpointKeys.TARGET_NAMES, finalize_config.target_names),
        (PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD, finalize_config.classification_threshold),
        (PyTorchCheckpointKeys.CLASS_MAP, finalize_config.class_map),
    )
    finalized_data.update(
        {entry_key: entry_value for entry_key, entry_value in optional_entries if entry_value is not None}
    )

    # Save model file
    torch.save(finalized_data, full_path)

    _LOGGER.info(f"Finalized model file saved to '{full_path}'")
|
|
1160
|
+
|