dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1909
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
from typing import Literal, Union, Optional, Callable
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from torch.utils.data import DataLoader, Dataset
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
|
|
7
|
+
from ..ML_callbacks._base import _Callback
|
|
8
|
+
from ..ML_callbacks._checkpoint import DragonModelCheckpoint
|
|
9
|
+
from ..ML_callbacks._early_stop import _DragonEarlyStopping
|
|
10
|
+
from ..ML_callbacks._scheduler import _DragonLRScheduler
|
|
11
|
+
from ..ML_evaluation import object_detection_metrics
|
|
12
|
+
from ..ML_configuration import FinalizeObjectDetection
|
|
13
|
+
|
|
14
|
+
from ..path_manager import make_fullpath
|
|
15
|
+
from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys, MLTaskKeys, MagicWords, DragonTrainerKeys
|
|
16
|
+
from .._core import get_logger
|
|
17
|
+
|
|
18
|
+
from ._base_trainer import _BaseDragonTrainer
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Module-level logger, named after the trainer class this module provides.
_LOGGER = get_logger("DragonDetectionTrainer")


# Public API of this module.
__all__ = [
    "DragonDetectionTrainer",
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Object Detection Trainer
class DragonDetectionTrainer(_BaseDragonTrainer):
    """
    Trainer specialized for object-detection models that return a loss dict
    (e.g., torchvision-style Faster R-CNN) from their forward pass in train mode.

    Unlike the generic trainers, no `criterion` is used: the model itself
    computes and returns its losses when called with `(images, targets)`.
    """

    def __init__(self, model: nn.Module,
                 train_dataset: Dataset,
                 validation_dataset: Dataset,
                 collate_fn: Callable,
                 optimizer: torch.optim.Optimizer,
                 device: Union[Literal['cuda', 'mps', 'cpu'], str],
                 checkpoint_callback: Optional[DragonModelCheckpoint],
                 early_stopping_callback: Optional[_DragonEarlyStopping],
                 lr_scheduler_callback: Optional[_DragonLRScheduler],
                 extra_callbacks: Optional[list[_Callback]] = None,
                 dataloader_workers: int = 2):
        """
        Automates the training process of an Object Detection Model (e.g., DragonFastRCNN).

        Built-in Callbacks: `History`, `TqdmProgressBar`

        Args:
            model (nn.Module): The PyTorch object detection model to train.
            train_dataset (Dataset): The training dataset.
            validation_dataset (Dataset): The testing/validation dataset.
            collate_fn (Callable): The collate function from `ObjectDetectionDatasetMaker.collate_fn`.
            optimizer (torch.optim.Optimizer): The optimizer.
            device (str): The device to run training on ('cpu', 'cuda', 'mps').
            dataloader_workers (int): Subprocesses for data loading.
            checkpoint_callback (DragonModelCheckpoint | None): Callback to save the model.
            early_stopping_callback (DragonEarlyStopping | None): Callback to stop training early.
            lr_scheduler_callback (DragonLRScheduler | None): Callback to manage the LR scheduler.
            extra_callbacks (List[Callback] | None): A list of extra callbacks to use during training.

        ## Note:
        This trainer is specialized. It does not take a `criterion` because object detection
        models like Faster R-CNN return a dictionary of losses directly from their forward
        pass during training.
        """
        # Call the base class constructor with common parameters
        # (device resolution, callback wiring, epoch bookkeeping live in the base).
        super().__init__(
            model=model,
            optimizer=optimizer,
            device=device,
            dataloader_workers=dataloader_workers,
            checkpoint_callback=checkpoint_callback,
            early_stopping_callback=early_stopping_callback,
            lr_scheduler_callback=lr_scheduler_callback,
            extra_callbacks=extra_callbacks
        )

        self.train_dataset = train_dataset
        self.validation_dataset = validation_dataset
        self.kind = MLTaskKeys.OBJECT_DETECTION
        self.collate_fn = collate_fn
        self.criterion = None  # Criterion is handled inside the model

    def _create_dataloaders(self, batch_size: int, shuffle: bool):
        """Initializes the DataLoaders with the object detection collate_fn."""
        # Ensure stability on MPS devices by setting num_workers to 0
        loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers

        # drop_last=True avoids degenerate (e.g. size-1) trailing batches during training.
        self.train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            collate_fn=self.collate_fn,  # Use the provided collate function
            drop_last=True
        )

        # Validation loader: never shuffled, keeps all samples (no drop_last).
        self.validation_loader = DataLoader(
            dataset=self.validation_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=loader_workers,
            pin_memory=("cuda" in self.device.type),
            collate_fn=self.collate_fn  # Use the provided collate function
        )

    def _train_step(self):
        """Runs one training epoch; returns {TRAIN_LOSS: sample-weighted mean loss}."""
        self.model.train()
        running_loss = 0.0
        total_samples = 0

        for batch_idx, (images, targets) in enumerate(self.train_loader):  # type: ignore
            # images is a tuple of tensors, targets is a tuple of dicts
            batch_size = len(images)

            # Create a log dictionary for the batch
            batch_logs = {
                PyTorchLogKeys.BATCH_INDEX: batch_idx,
                PyTorchLogKeys.BATCH_SIZE: batch_size
            }
            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

            # Move data to device
            images = list(img.to(self.device) for img in images)
            targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

            self.optimizer.zero_grad()

            # Model returns a loss dict when in train() mode and targets are passed
            loss_dict = self.model(images, targets)

            if not loss_dict:
                # No losses returned, skip batch (still fire on_batch_end so
                # progress-bar/history callbacks stay in sync).
                _LOGGER.warning(f"Model returned no losses for batch {batch_idx}. Skipping.")
                batch_logs[PyTorchLogKeys.BATCH_LOSS] = 0
                self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
                continue

            # Sum all losses (classification + box regression + RPN terms, etc.)
            loss: torch.Tensor = sum(l for l in loss_dict.values())  # type: ignore

            loss.backward()
            self.optimizer.step()

            # Calculate batch loss and update running loss for the epoch,
            # weighting by batch size so the epoch mean is per-sample.
            batch_loss = loss.item()
            running_loss += batch_loss * batch_size
            total_samples += batch_size  # Accumulate total samples

            # Add the batch loss to the logs and call the end-of-batch hook
            batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss  # type: ignore
            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

        # Calculate loss using the correct denominator (skipped batches excluded)
        if total_samples == 0:
            _LOGGER.warning("No samples processed in _train_step. Returning 0 loss.")
            return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

        return {PyTorchLogKeys.TRAIN_LOSS: running_loss / total_samples}

    def _validation_step(self):
        """Computes the validation loss; returns {VAL_LOSS: sample-weighted mean loss}."""
        self.model.train()  # Set to train mode even for validation loss calculation
        # as model internals (e.g., proposals) might differ, but we still need loss_dict.
        # use torch.no_grad() to prevent gradient updates.
        running_loss = 0.0
        total_samples = 0

        with torch.no_grad():
            for images, targets in self.validation_loader:  # type: ignore
                batch_size = len(images)

                # Move data to device
                images = list(img.to(self.device) for img in images)
                targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets]

                # Get loss dict
                loss_dict = self.model(images, targets)

                if not loss_dict:
                    _LOGGER.warning("Model returned no losses during validation step. Skipping batch.")
                    continue  # Skip if no losses

                # Sum all losses
                loss: torch.Tensor = sum(l for l in loss_dict.values())  # type: ignore

                running_loss += loss.item() * batch_size
                total_samples += batch_size  # Accumulate total samples

        # Calculate loss using the correct denominator
        if total_samples == 0:
            _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
            return {PyTorchLogKeys.VAL_LOSS: 0.0}

        logs = {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
        return logs

    def evaluate(self,
                 save_dir: Union[str, Path],
                 model_checkpoint: Union[Path, Literal["best", "current"]],
                 test_data: Optional[Union[DataLoader, Dataset]] = None):
        """
        Evaluates the model using object detection mAP metrics.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            model_checkpoint (Path | "best" | "current"):
                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
                - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
                - If 'current', use the current state of the trained model up to the latest trained epoch.
            test_data (DataLoader | Dataset | None): Optional Test data to evaluate the model performance. Validation and Test metrics will be saved to subdirectories.
        """
        # Validate model checkpoint
        if isinstance(model_checkpoint, Path):
            checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
        elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
            checkpoint_validated = model_checkpoint
        else:
            _LOGGER.error(f"'model_checkpoint' must be a Path object, or the string '{MagicWords.BEST}', or the string '{MagicWords.CURRENT}'.")
            raise ValueError()

        # Validate directory
        save_path = make_fullpath(save_dir, make=True, enforce="directory")

        # Validate test data and dispatch
        if test_data is not None:
            if not isinstance(test_data, (DataLoader, Dataset)):
                _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
                raise ValueError()
            test_data_validated = test_data

            # With test data present, validation and test metrics go to subdirectories.
            validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
            test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

            # Dispatch validation set
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
            self._evaluate(save_dir=validation_metrics_path,
                           model_checkpoint=checkpoint_validated,
                           data=None)  # 'None' triggers use of self.validation_dataset

            # Dispatch test set
            _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
            self._evaluate(save_dir=test_metrics_path,
                           model_checkpoint="current",  # Use 'current' state after loading checkpoint once
                           data=test_data_validated)
        else:
            # Dispatch validation set
            _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
            self._evaluate(save_dir=save_path,
                           model_checkpoint=checkpoint_validated,
                           data=None)  # 'None' triggers use of self.validation_dataset

    def _evaluate(self,
                  save_dir: Union[str, Path],
                  model_checkpoint: Union[Path, Literal["best", "current"]],
                  data: Optional[Union[DataLoader, Dataset]]):
        """
        Private helper that runs a single evaluation pass.

        Evaluates the model using object detection mAP metrics.

        Args:
            save_dir (str | Path): Directory to save all reports and plots.
            data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal validation_dataset.
            model_checkpoint (Path | "best" | "current"):
                - Path to a valid checkpoint for the model. The state of the trained model will be overwritten in place.
                - If 'best', the best checkpoint will be loaded if a DragonModelCheckpoint was provided. The state of the trained model will be overwritten in place.
                - If 'current', use the current state of the trained model up to the latest trained epoch.
        """
        dataset_for_artifacts = None
        eval_loader = None

        # load model checkpoint
        # NOTE(review): 'current' falls through all branches and keeps the
        # in-memory model state — presumably intentional; confirm against base class.
        if isinstance(model_checkpoint, Path):
            self._load_checkpoint(path=model_checkpoint)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
            path_to_latest = self._checkpoint_callback.best_checkpoint_path
            self._load_checkpoint(path_to_latest)
        elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
            _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
            raise ValueError()

        # Dataloader
        if isinstance(data, DataLoader):
            eval_loader = data
            if hasattr(data, 'dataset'):
                dataset_for_artifacts = data.dataset  # type: ignore
        elif isinstance(data, Dataset):
            # Create a new loader from the provided dataset
            eval_loader = DataLoader(data,
                                     batch_size=self._batch_size,
                                     shuffle=False,
                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                                     pin_memory=(self.device.type == "cuda"),
                                     collate_fn=self.collate_fn)
            dataset_for_artifacts = data
        else:  # data is None, use the trainer's validation dataset
            if self.validation_dataset is None:
                _LOGGER.error("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
                raise ValueError()
            # Create a fresh DataLoader from the validation dataset
            eval_loader = DataLoader(
                self.validation_dataset,
                batch_size=self._batch_size,
                shuffle=False,
                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                pin_memory=(self.device.type == "cuda"),
                collate_fn=self.collate_fn
            )
            dataset_for_artifacts = self.validation_dataset

        if eval_loader is None:
            _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
            raise ValueError()

        all_predictions = []
        all_targets = []

        self.model.eval()  # Set model to evaluation mode
        self.model.to(self.device)

        with torch.no_grad():
            for images, targets in eval_loader:
                # Move images to device
                images = list(img.to(self.device) for img in images)

                # Model returns predictions when in eval() mode
                predictions = self.model(images)

                # Move predictions and targets to CPU for aggregation
                cpu_preds = [{k: v.to('cpu') for k, v in p.items()} for p in predictions]
                cpu_targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]

                all_predictions.extend(cpu_preds)
                all_targets.extend(cpu_targets)

        if not all_targets:
            _LOGGER.error("Evaluation failed: No data was processed.")
            return

        # Get class names from the dataset for the report
        class_names = None
        try:
            # Try to get 'classes' from ObjectDetectionDatasetMaker
            if hasattr(dataset_for_artifacts, 'classes'):
                class_names = dataset_for_artifacts.classes  # type: ignore
            # Fallback for Subset
            elif hasattr(dataset_for_artifacts, 'dataset') and hasattr(dataset_for_artifacts.dataset, 'classes'):  # type: ignore
                class_names = dataset_for_artifacts.dataset.classes  # type: ignore
        except AttributeError:
            _LOGGER.warning("Could not find 'classes' attribute on dataset. Per-class metrics will not be named.")
            pass  # class_names is still None

        # --- Routing Logic ---
        object_detection_metrics(
            preds=all_predictions,
            targets=all_targets,
            save_dir=save_dir,
            class_names=class_names,
            print_output=False
        )

    def finalize_model_training(self,
                                save_dir: Union[str, Path],
                                model_checkpoint: Union[Path, Literal['best', 'current']],
                                finalize_config: FinalizeObjectDetection
                                ):
        """
        Saves a finalized, "inference-ready" model state to a .pth file.

        This method saves the model's `state_dict` and the final epoch number.

        Args:
            save_dir (Union[str, Path]): The directory to save the finalized model.
            model_checkpoint (Union[Path, Literal["best", "current"]]):
                - Path: Loads the model state from a specific checkpoint file.
                - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
                - "current": Uses the model's state as it is.
            finalize_config (FinalizeObjectDetection): A data class instance specific to the ML task containing task-specific metadata required for inference.
        """
        if not isinstance(finalize_config, FinalizeObjectDetection):
            _LOGGER.error(f"For task {self.kind}, expected finalize_config of type 'FinalizeObjectDetection', but got {type(finalize_config).__name__}.")
            raise TypeError()

        # handle save path
        dir_path = make_fullpath(save_dir, make=True, enforce="directory")
        full_path = dir_path / finalize_config.filename

        # handle checkpoint (loading logic lives in the base class)
        self._load_model_state_for_finalizing(model_checkpoint)

        # Create finalized data: epoch, weights, and task tag for inference loaders.
        finalized_data = {
            PyTorchCheckpointKeys.EPOCH: self.epoch,
            PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
            PyTorchCheckpointKeys.TASK: finalize_config.task
        }

        # Class map is optional metadata; only embed it when provided.
        if finalize_config.class_map is not None:
            finalized_data[PyTorchCheckpointKeys.CLASS_MAP] = finalize_config.class_map

        torch.save(finalized_data, full_path)

        _LOGGER.info(f"Finalized model file saved to '{full_path}'")
|