dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1901
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,540 @@
|
|
|
1
|
+
from typing import Literal, Union, Optional
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from torch.utils.data import DataLoader, Dataset
|
|
4
|
+
import torch
|
|
5
|
+
from torch import nn
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from ..ML_callbacks._base import _Callback
|
|
9
|
+
from ..ML_callbacks._checkpoint import DragonModelCheckpoint
|
|
10
|
+
from ..ML_callbacks._early_stop import _DragonEarlyStopping
|
|
11
|
+
from ..ML_callbacks._scheduler import _DragonLRScheduler
|
|
12
|
+
from ..ML_evaluation import sequence_to_sequence_metrics, sequence_to_value_metrics
|
|
13
|
+
from ..ML_evaluation_captum import captum_feature_importance
|
|
14
|
+
from ..ML_configuration import (FormatSequenceValueMetrics,
|
|
15
|
+
FormatSequenceSequenceMetrics,
|
|
16
|
+
|
|
17
|
+
FinalizeSequenceSequencePrediction,
|
|
18
|
+
FinalizeSequenceValuePrediction)
|
|
19
|
+
|
|
20
|
+
from ..path_manager import make_fullpath
|
|
21
|
+
from ..keys._keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys, MLTaskKeys, MagicWords, DragonTrainerKeys, ScalerKeys
|
|
22
|
+
from .._core import get_logger
|
|
23
|
+
|
|
24
|
+
from ._base_trainer import _BaseDragonTrainer
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
_LOGGER = get_logger("DragonSequenceTrainer")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
__all__ = [
|
|
31
|
+
"DragonSequenceTrainer"
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# --- DragonSequenceTrainer ----
|
|
36
|
+
class DragonSequenceTrainer(_BaseDragonTrainer):
|
|
37
|
+
def __init__(self,
             model: nn.Module,
             train_dataset: Dataset,
             validation_dataset: Dataset,
             kind: Literal["sequence-to-sequence", "sequence-to-value"],
             optimizer: torch.optim.Optimizer,
             device: Union[Literal['cuda', 'mps', 'cpu'],str],
             checkpoint_callback: Optional[DragonModelCheckpoint],
             early_stopping_callback: Optional[_DragonEarlyStopping],
             lr_scheduler_callback: Optional[_DragonLRScheduler],
             extra_callbacks: Optional[list[_Callback]] = None,
             criterion: Union[nn.Module,Literal["auto"]] = "auto",
             dataloader_workers: int = 2):
    """
    Automates the training process of a PyTorch Sequence Model.

    Built-in Callbacks: `History`, `TqdmProgressBar`

    Args:
        model (nn.Module): The PyTorch model to train.
        train_dataset (Dataset): The training dataset.
        validation_dataset (Dataset): The validation dataset.
        kind (str): Task selector, either 'sequence-to-sequence' or 'sequence-to-value'.
        optimizer (torch.optim.Optimizer): The optimizer.
        device (str): The device to run training on ('cpu', 'cuda', 'mps').
        checkpoint_callback: Optional model-checkpoint callback.
        early_stopping_callback: Optional early-stopping callback.
        lr_scheduler_callback: Optional LR-scheduler callback.
        extra_callbacks (List[Callback] | None): Extra callbacks to use during training.
        criterion (nn.Module | "auto"): Loss function; "auto" selects MSE for both sequence tasks.
        dataloader_workers (int): Subprocesses for data loading.

    Raises:
        ValueError: If `kind` is not a recognized sequence task.
        RuntimeError: If the model declares a `prediction_mode` that conflicts with `kind`.
    """
    # Common trainer plumbing (device placement, callback wiring, workers).
    super().__init__(
        model=model,
        optimizer=optimizer,
        device=device,
        dataloader_workers=dataloader_workers,
        checkpoint_callback=checkpoint_callback,
        early_stopping_callback=early_stopping_callback,
        lr_scheduler_callback=lr_scheduler_callback,
        extra_callbacks=extra_callbacks
    )

    if kind not in (MLTaskKeys.SEQUENCE_SEQUENCE, MLTaskKeys.SEQUENCE_VALUE):
        raise ValueError(f"'{kind}' is not a valid task type for DragonSequenceTrainer.")

    self.train_dataset = train_dataset
    self.validation_dataset = validation_dataset
    self.kind = kind

    # If the model advertises its prediction mode (Dragon sequence models do),
    # cross-check it against the requested task.
    if hasattr(self.model, "prediction_mode"):
        declared_mode: str = self.model.prediction_mode # type: ignore
        if declared_mode != self.kind:
            _LOGGER.error(f"Trainer was set for '{self.kind}', but model architecture '{self.model}' is built for '{declared_mode}'.")
            raise RuntimeError()

    # Both sequence tasks are treated as regression problems, hence MSE by default.
    self.criterion = nn.MSELoss() if criterion == "auto" else criterion
def _create_dataloaders(self, batch_size: int, shuffle: bool):
|
|
100
|
+
"""Initializes the DataLoaders."""
|
|
101
|
+
# Ensure stability on MPS devices by setting num_workers to 0
|
|
102
|
+
loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
|
|
103
|
+
|
|
104
|
+
self.train_loader = DataLoader(
|
|
105
|
+
dataset=self.train_dataset,
|
|
106
|
+
batch_size=batch_size,
|
|
107
|
+
shuffle=shuffle,
|
|
108
|
+
num_workers=loader_workers,
|
|
109
|
+
pin_memory=("cuda" in self.device.type),
|
|
110
|
+
drop_last=True # Drops the last batch if incomplete, selecting a good batch size is key.
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
self.validation_loader = DataLoader(
|
|
114
|
+
dataset=self.validation_dataset,
|
|
115
|
+
batch_size=batch_size,
|
|
116
|
+
shuffle=False,
|
|
117
|
+
num_workers=loader_workers,
|
|
118
|
+
pin_memory=("cuda" in self.device.type)
|
|
119
|
+
)
|
|
120
|
+
|
|
121
|
+
def _train_step(self):
    """Runs one training epoch and returns the sample-weighted mean train loss."""
    self.model.train()
    loss_sum = 0.0
    seen_samples = 0

    for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
        n = features.size(0)

        # Per-batch log dict handed to the callback hooks.
        batch_logs = {
            PyTorchLogKeys.BATCH_INDEX: batch_idx,
            PyTorchLogKeys.BATCH_SIZE: n
        }
        self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

        features = features.to(self.device)
        target = target.to(self.device)
        self.optimizer.zero_grad()

        output = self.model(features)

        # --- Label Type/Shape Correction ---
        # MSELoss requires float targets.
        target = target.float()

        if self.kind == MLTaskKeys.SEQUENCE_VALUE:
            # Models may emit [N, 1] while the target is [N].
            if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
                output = output.squeeze(1)
        elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
            # Models may emit [N, Seq, 1] while the target is [N, Seq].
            if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
                output = output.squeeze(-1)

        loss = self.criterion(output, target)
        loss.backward()
        self.optimizer.step()

        # Accumulate the epoch totals, weighting each batch by its size.
        batch_loss = loss.item()
        loss_sum += batch_loss * n
        seen_samples += n

        batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
        self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

    if seen_samples == 0:
        _LOGGER.warning("No samples processed in a train_step. Returning 0 loss.")
        return {PyTorchLogKeys.TRAIN_LOSS: 0.0}

    return {PyTorchLogKeys.TRAIN_LOSS: loss_sum / seen_samples} # type: ignore
def _validation_step(self):
    """
    Runs one pass over the validation loader and returns the sample-weighted
    mean validation loss.

    The running loss is divided by the number of samples actually iterated,
    mirroring `_train_step`. For map-style datasets this equals the previous
    `len(self.validation_loader.dataset)` denominator (the validation loader
    does not drop batches), but counting during iteration also works for
    datasets without `__len__` and replaces the loose truthiness check on the
    dataset object.
    """
    self.model.eval()
    running_loss = 0.0
    total_samples = 0

    with torch.no_grad():
        for features, target in self.validation_loader: # type: ignore
            features, target = features.to(self.device), target.to(self.device)

            output = self.model(features)

            # --- Label Type/Shape Correction ---
            # MSELoss requires float targets.
            target = target.float()

            # For seq-to-val, models might output [N, 1] but target is [N].
            if self.kind == MLTaskKeys.SEQUENCE_VALUE:
                if output.ndim == 2 and output.shape[1] == 1 and target.ndim == 1:
                    output = output.squeeze(1)

            # For seq-to-seq, models might output [N, Seq, 1] but target is [N, Seq].
            elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
                if output.ndim == 3 and output.shape[2] == 1 and target.ndim == 2:
                    output = output.squeeze(-1)

            loss = self.criterion(output, target)

            batch_size = features.size(0)
            running_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        _LOGGER.warning("No samples processed in _validation_step. Returning 0 loss.")
        return {PyTorchLogKeys.VAL_LOSS: 0.0}

    return {PyTorchLogKeys.VAL_LOSS: running_loss / total_samples}
def _predict_for_eval(self, dataloader: DataLoader):
    """
    Private method to yield model predictions batch by batch for evaluation.

    Automatically checks the training dataset for a target scaler (attribute
    named by `ScalerKeys.TARGET_SCALER`); when present, predictions and
    targets are inverse-transformed back to the original scale.

    Yields:
        tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch)
        as CPU numpy arrays. y_prob_batch is always None for sequence tasks.
    """
    self.model.eval()
    self.model.to(self.device)

    # --- Check for Scaler ---
    # DragonDatasetSequence stores it as 'scaler'
    scaler = None
    if hasattr(self.train_dataset, ScalerKeys.TARGET_SCALER):
        scaler = getattr(self.train_dataset, ScalerKeys.TARGET_SCALER)
        if scaler is not None:
            _LOGGER.debug("Sequence scaler detected. Un-scaling predictions and targets.")

    with torch.no_grad():
        for features, target in dataloader:
            features = features.to(self.device)
            target = target.to(self.device)

            output = self.model(features)

            # --- Automatic Un-scaling Logic ---
            # FIX: explicit None check (was `if scaler:`), matching the check
            # above and avoiding surprises with scalers that define truthiness.
            if scaler is not None:
                # 1. Flatten sequence dims to (N, 1) or (N*Seq, 1) for the scaler.
                original_out_shape = output.shape
                original_target_shape = target.shape
                output_flat = output.reshape(-1, 1)
                target_flat = target.reshape(-1, 1)

                # 2. Inverse transform.
                # NOTE(review): assumes the scaler accepts device tensors and
                # returns something reshapeable — true for the project scaler;
                # confirm for user-supplied scalers.
                output_flat = scaler.inverse_transform(output_flat)
                target_flat = scaler.inverse_transform(target_flat)

                # 3. Restore original shapes.
                output = output_flat.reshape(original_out_shape)
                target = target_flat.reshape(original_target_shape)

            # Move to CPU before handing off to numpy-based metrics.
            y_pred_batch = output.cpu().numpy()
            y_true_batch = target.cpu().numpy()
            y_prob_batch = None

            yield y_pred_batch, y_prob_batch, y_true_batch
def evaluate(self,
             save_dir: Union[str, Path],
             model_checkpoint: Union[Path, Literal["best", "current"]],
             test_data: Optional[Union[DataLoader, Dataset]] = None,
             val_format_configuration: Optional[Union[FormatSequenceValueMetrics,
                                                      FormatSequenceSequenceMetrics]]=None,
             test_format_configuration: Optional[Union[FormatSequenceValueMetrics,
                                                       FormatSequenceSequenceMetrics]]=None):
    """
    Evaluates the model, routing to the correct evaluation function.

    Args:
        save_dir (str | Path): Directory to save all reports and plots.
        model_checkpoint (Path | "best" | "current"):
            - Path to a valid checkpoint for the model.
            - If 'best', the best checkpoint will be loaded.
            - If 'current', use the current state of the trained model.
        test_data (DataLoader | Dataset | None): Optional test data. When given,
            validation metrics and test metrics are written to separate
            subdirectories of `save_dir`.
        val_format_configuration: Optional configuration for validation metrics.
            An invalid type raises (hard error).
        test_format_configuration: Optional configuration for test metrics.
            An invalid type only warns and falls back to the validation
            configuration (or the default format).

    Raises:
        ValueError: On an invalid `model_checkpoint`, `val_format_configuration`,
            or `test_data` type.
    """
    # Validate model checkpoint
    if isinstance(model_checkpoint, Path):
        checkpoint_validated = make_fullpath(model_checkpoint, enforce="file")
    elif model_checkpoint in [MagicWords.BEST, MagicWords.CURRENT]:
        checkpoint_validated = model_checkpoint
    else:
        _LOGGER.error(f"'model_checkpoint' must be a Path object, or '{MagicWords.BEST}', or '{MagicWords.CURRENT}'.")
        raise ValueError()

    # Validate val configuration (hard error by design, unlike the test config below)
    if val_format_configuration is not None:
        if not isinstance(val_format_configuration, (FormatSequenceValueMetrics, FormatSequenceSequenceMetrics)):
            _LOGGER.error(f"Invalid 'val_format_configuration': '{type(val_format_configuration)}'.")
            raise ValueError()

    # Validate directory
    save_path = make_fullpath(save_dir, make=True, enforce="directory")

    # Validate test data and dispatch
    if test_data is not None:
        if not isinstance(test_data, (DataLoader, Dataset)):
            _LOGGER.error(f"Invalid type for 'test_data': '{type(test_data)}'.")
            raise ValueError()
        test_data_validated = test_data

        validation_metrics_path = save_path / DragonTrainerKeys.VALIDATION_METRICS_DIR
        test_metrics_path = save_path / DragonTrainerKeys.TEST_METRICS_DIR

        # Dispatch validation set first; this loads the requested checkpoint.
        _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{DragonTrainerKeys.VALIDATION_METRICS_DIR}'")
        self._evaluate(save_dir=validation_metrics_path,
                       model_checkpoint=checkpoint_validated,
                       data=None,
                       format_configuration=val_format_configuration)

        # Validate test configuration (soft failure: warn and fall back)
        test_configuration_validated = None
        if test_format_configuration is not None:
            if not isinstance(test_format_configuration, (FormatSequenceValueMetrics, FormatSequenceSequenceMetrics)):
                # BUGFIX: warning text was missing the opening quote around the
                # parameter name ("Invalid test_format_configuration'").
                warning_message_type = f"Invalid 'test_format_configuration': '{type(test_format_configuration)}'."
                if val_format_configuration is not None:
                    warning_message_type += " 'val_format_configuration' will be used."
                    test_configuration_validated = val_format_configuration
                else:
                    warning_message_type += " Using default format."
                _LOGGER.warning(warning_message_type)
            else:
                test_configuration_validated = test_format_configuration

        # Dispatch test set, reusing the checkpoint already loaded above.
        _LOGGER.info(f"Evaluating on test dataset. Metrics will be saved to '{DragonTrainerKeys.TEST_METRICS_DIR}'")
        self._evaluate(save_dir=test_metrics_path,
                       model_checkpoint="current",
                       data=test_data_validated,
                       format_configuration=test_configuration_validated)
    else:
        # Dispatch validation set only
        _LOGGER.info(f"Evaluating on validation dataset. Metrics will be saved to '{save_path.name}'")
        self._evaluate(save_dir=save_path,
                       model_checkpoint=checkpoint_validated,
                       data=None,
                       format_configuration=val_format_configuration)
def _evaluate(self,
              save_dir: Union[str, Path],
              model_checkpoint: Union[Path, Literal["best", "current"]],
              data: Optional[Union[DataLoader, Dataset]],
              format_configuration: object):
    """
    Private evaluation helper.

    Loads the requested checkpoint (if any), builds an evaluation DataLoader
    from `data` (or the trainer's validation dataset when `data` is None),
    gathers predictions via `_predict_for_eval`, and routes the concatenated
    arrays to the metric function matching `self.kind`.
    """
    # load model checkpoint; 'current' falls through and evaluates as-is
    if isinstance(model_checkpoint, Path):
        self._load_checkpoint(path=model_checkpoint)
    elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback:
        self._load_checkpoint(self._checkpoint_callback.best_checkpoint_path)
    elif model_checkpoint == MagicWords.BEST and self._checkpoint_callback is None:
        _LOGGER.error(f"'model_checkpoint' set to '{MagicWords.BEST}' but no checkpoint callback was found.")
        raise ValueError()

    def _make_eval_loader(dataset) -> DataLoader:
        # Single place for the eval-loader settings: never shuffle, and keep
        # MPS on single-process loading for stability.
        return DataLoader(dataset,
                          batch_size=self._batch_size,
                          shuffle=False,
                          num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
                          pin_memory=(self.device.type == "cuda"))

    # Dataloader
    if isinstance(data, DataLoader):
        eval_loader = data
    elif isinstance(data, Dataset):
        eval_loader = _make_eval_loader(data)
    else: # data is None, use the trainer's default validation dataset
        if self.validation_dataset is None:
            _LOGGER.error("Cannot evaluate. No data provided and no validation_dataset available in the trainer.")
            raise ValueError()
        eval_loader = _make_eval_loader(self.validation_dataset)

    if eval_loader is None:
        _LOGGER.error("Cannot evaluate. No valid data was provided or found.")
        raise ValueError()

    all_preds, all_true = [], []
    for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
        if y_pred_b is not None:
            all_preds.append(y_pred_b)
        if y_true_b is not None:
            all_true.append(y_true_b)

    if not all_true:
        _LOGGER.error("Evaluation failed: No data was processed.")
        return

    y_pred = np.concatenate(all_preds)
    y_true = np.concatenate(all_true)

    # --- Routing Logic ---
    if self.kind == MLTaskKeys.SEQUENCE_VALUE:
        config = None
        if format_configuration and isinstance(format_configuration, FormatSequenceValueMetrics):
            config = format_configuration
        elif format_configuration:
            # BUGFIX: message previously named a non-existent class
            # 'SequenceValueMetricsFormat'; the real class is FormatSequenceValueMetrics.
            _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected FormatSequenceValueMetrics.")

        sequence_to_value_metrics(y_true=y_true,
                                  y_pred=y_pred,
                                  save_dir=save_dir,
                                  config=config)

    elif self.kind == MLTaskKeys.SEQUENCE_SEQUENCE:
        config = None
        if format_configuration and isinstance(format_configuration, FormatSequenceSequenceMetrics):
            config = format_configuration
        elif format_configuration:
            # BUGFIX: same as above for the seq-to-seq class name.
            _LOGGER.warning(f"Wrong config type: Received {type(format_configuration).__name__}, expected FormatSequenceSequenceMetrics.")

        sequence_to_sequence_metrics(y_true=y_true,
                                     y_pred=y_pred,
                                     save_dir=save_dir,
                                     config=config)
def explain_captum(self,
                   save_dir: Union[str, Path],
                   explain_dataset: Optional[Dataset] = None,
                   n_samples: int = 100,
                   feature_names: Optional[list[str]] = None,
                   target_names: Optional[list[str]] = None,
                   n_steps: int = 50):
    """
    Explains sequence model predictions using Captum's Integrated Gradients.

    Computes global feature importance by aggregating attributions over the
    time dimension:
    - **Multivariate** sequences: highlights the most influential variables (channels).
    - **Univariate** sequences: attributes importance to the single signal feature.

    Args:
        save_dir (str | Path): Directory to save the importance plots and CSV reports.
        explain_dataset (Dataset | None): A specific dataset to sample from. If None, the
            trainer's validation dataset is used.
        n_samples (int): The number of samples to use for the explanation (background + inputs).
        feature_names (List[str] | None): Names of the features (signals). If None, attempts to extract them from the dataset attribute.
        target_names (List[str] | None): Names of the model outputs (e.g., for Seq2Seq or Multivariate output). If None, attempts to extract them from the dataset attribute.
        n_steps (int): Number of integral approximation steps.

    Note:
        For univariate data (Shape: N, Seq_Len), the 'feature' is the signal itself.
    """
    # Fall back to the trainer's validation split when no dataset is given.
    ds = self.validation_dataset if explain_dataset is None else explain_dataset
    if ds is None:
        _LOGGER.error("No dataset available for explanation.")
        return

    # Draw a single shuffled batch of up to n_samples rows (targets are discarded).
    sample_loader = DataLoader(ds, batch_size=n_samples, shuffle=True, num_workers=0)
    input_data, _ = next(iter(sample_loader))

    # Resolve feature names from the dataset when the caller did not supply them.
    if feature_names is None and hasattr(ds, DatasetKeys.FEATURE_NAMES):
        feature_names = ds.feature_names  # type: ignore
    elif feature_names is None:
        # Leave as None; downstream code generates generic names.
        _LOGGER.warning("'feature_names' not provided and not found in dataset. Generic names will be used.")

    # Same resolution strategy for the output (target) names.
    if target_names is None and hasattr(ds, DatasetKeys.TARGET_NAMES):
        target_names = ds.target_names  # type: ignore
    elif target_names is None:
        # Leave as None; downstream code generates generic names.
        _LOGGER.warning("'target_names' not provided and not found in dataset. Generic names will be used.")

    # Sequence models usually output [N, 1] (Value) or [N, Seq, 1] (Seq2Seq);
    # captum_feature_importance handles the aggregation.
    captum_feature_importance(model=self.model,
                              input_data=input_data,
                              feature_names=feature_names,
                              save_dir=save_dir,
                              target_names=target_names,
                              n_steps=n_steps,
                              device=self.device)
|
|
493
|
+
|
|
494
|
+
def finalize_model_training(self,
                            save_dir: Union[str, Path],
                            model_checkpoint: Union[Path, Literal['best', 'current']],
                            finalize_config: Union[FinalizeSequenceSequencePrediction, FinalizeSequenceValuePrediction]):
    """
    Saves a finalized, "inference-ready" model state to a .pth file.

    The saved payload contains the model's `state_dict`, the final epoch
    number, the task key, and any optional sequence metadata present in
    `finalize_config`.

    Args:
        save_dir (Union[str, Path]): The directory to save the finalized model.
        model_checkpoint (Union[Path, Literal["best", "current"]]):
            - Path: Loads the model state from a specific checkpoint file.
            - "best": Loads the best model state saved by the `DragonModelCheckpoint` callback.
            - "current": Uses the model's state as it is.
        finalize_config (FinalizeSequencePrediction): A data class instance specific to the ML task containing task-specific metadata required for inference.

    Raises:
        TypeError: If `finalize_config` does not match the trainer's task kind.
    """
    # Each task kind accepts exactly one finalize-config class.
    expected_config = {
        MLTaskKeys.SEQUENCE_SEQUENCE: FinalizeSequenceSequencePrediction,
        MLTaskKeys.SEQUENCE_VALUE: FinalizeSequenceValuePrediction,
    }.get(self.kind)

    if expected_config is not None and not isinstance(finalize_config, expected_config):
        _LOGGER.error(f"Received a wrong finalize configuration for task {self.kind}: {type(finalize_config).__name__}.")
        raise TypeError()

    # Resolve (and create if needed) the output directory and target file path.
    target_dir = make_fullpath(save_dir, make=True, enforce="directory")
    target_file = target_dir / finalize_config.filename

    # Restore the requested weights ("best" / "current" / explicit path).
    self._load_model_state_for_finalizing(model_checkpoint)

    # Assemble the inference-ready payload.
    payload = {
        PyTorchCheckpointKeys.EPOCH: self.epoch,
        PyTorchCheckpointKeys.MODEL_STATE: self.model.state_dict(),
        PyTorchCheckpointKeys.TASK: finalize_config.task,
    }

    # Optional sequence metadata is only persisted when provided.
    if finalize_config.sequence_length is not None:
        payload[PyTorchCheckpointKeys.SEQUENCE_LENGTH] = finalize_config.sequence_length
    if finalize_config.initial_sequence is not None:
        payload[PyTorchCheckpointKeys.INITIAL_SEQUENCE] = finalize_config.initial_sequence

    torch.save(payload, target_file)

    _LOGGER.info(f"Finalized model file saved to '{target_file}'")
|
|
540
|
+
|