collie_mlops-0.1.0b0-py3-none-any.whl
- collie/__init__.py +69 -0
- collie/_common/__init__.py +0 -0
- collie/_common/decorator.py +53 -0
- collie/_common/exceptions.py +104 -0
- collie/_common/mlflow_model_io/__init__.py +0 -0
- collie/_common/mlflow_model_io/base_flavor_handler.py +26 -0
- collie/_common/mlflow_model_io/flavor_registry.py +72 -0
- collie/_common/mlflow_model_io/model_flavors.py +259 -0
- collie/_common/mlflow_model_io/model_io.py +65 -0
- collie/_common/utils.py +13 -0
- collie/contracts/__init__.py +0 -0
- collie/contracts/event.py +79 -0
- collie/contracts/mlflow.py +444 -0
- collie/contracts/orchestrator.py +79 -0
- collie/core/__init__.py +41 -0
- collie/core/enums/__init__.py +0 -0
- collie/core/enums/components.py +26 -0
- collie/core/enums/ml_models.py +20 -0
- collie/core/evaluator/__init__.py +0 -0
- collie/core/evaluator/evaluator.py +147 -0
- collie/core/models.py +125 -0
- collie/core/orchestrator/__init__.py +0 -0
- collie/core/orchestrator/orchestrator.py +47 -0
- collie/core/pusher/__init__.py +0 -0
- collie/core/pusher/pusher.py +98 -0
- collie/core/trainer/__init__.py +0 -0
- collie/core/trainer/trainer.py +78 -0
- collie/core/transform/__init__.py +0 -0
- collie/core/transform/transform.py +87 -0
- collie/core/tuner/__init__.py +0 -0
- collie/core/tuner/tuner.py +84 -0
- collie/helper/__init__.py +0 -0
- collie/helper/pytorch/__init__.py +0 -0
- collie/helper/pytorch/callback/__init__.py +0 -0
- collie/helper/pytorch/callback/callback.py +155 -0
- collie/helper/pytorch/callback/earlystop.py +54 -0
- collie/helper/pytorch/callback/model_checkpoint.py +100 -0
- collie/helper/pytorch/model/__init__.py +0 -0
- collie/helper/pytorch/model/loader.py +55 -0
- collie/helper/pytorch/trainer.py +304 -0
- collie_mlops-0.1.0b0.dist-info/METADATA +217 -0
- collie_mlops-0.1.0b0.dist-info/RECORD +45 -0
- collie_mlops-0.1.0b0.dist-info/WHEEL +5 -0
- collie_mlops-0.1.0b0.dist-info/licenses/LICENSE +21 -0
- collie_mlops-0.1.0b0.dist-info/top_level.txt +1 -0
collie/helper/pytorch/callback/callback.py
@@ -0,0 +1,155 @@
from typing import Optional, List, Literal

import torch

from collie import trainer


class Callback:

    def on_train_start(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_train_end(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_epoch_start(
        self,
        epoch_step: int,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_epoch_end(
        self,
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float],
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_batch_start(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_batch_end(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        batch_train_loss: float,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...


class _CallbackManager:

    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def on_train_start(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        self._execute_callbacks(
            "on_train_start",
            trainer=trainer
        )

    def on_train_end(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        self._execute_callbacks(
            "on_train_end",
            trainer=trainer
        )

    def on_epoch_start(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int
    ) -> None:
        self._execute_callbacks(
            "on_epoch_start",
            trainer=trainer,
            epoch_step=epoch_step
        )

    def on_epoch_end(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float],
    ) -> None:
        self._execute_callbacks(
            "on_epoch_end",
            trainer=trainer,
            epoch_step=epoch_step,
            epoch_train_loss=epoch_train_loss,
            epoch_val_loss=epoch_val_loss
        )

    def on_batch_start(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        self._execute_callbacks(
            "on_batch_start",
            trainer=trainer,
            batch_step=batch_step,
            batch_data=batch_data
        )

    def on_batch_end(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        batch_train_loss: float,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        self._execute_callbacks(
            "on_batch_end",
            trainer=trainer,
            batch_step=batch_step,
            batch_data=batch_data,
            batch_train_loss=batch_train_loss
        )

    def _execute_callbacks(
        self,
        method_name: Literal[
            "on_train_start",
            "on_train_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end"],
        *args, **kwargs
    ) -> None:
        if self.callbacks is None:
            return

        for callback in self.callbacks:
            # Call the method specified by method_name on the callback
            getattr(callback, method_name)(*args, **kwargs)
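For orientation, here is a minimal sketch of how a user-defined callback could hook into these events. The BatchLossLogger name and its call to the trainer's log_metric are illustrative assumptions (PytorchTrainer below invokes log_metric on itself, so a concrete trainer is expected to provide it); nothing in this sketch ships with the package.

# Hypothetical example: log the running training loss every N batches.
from collie.helper.pytorch.callback.callback import Callback


class BatchLossLogger(Callback):

    def __init__(self, every_n_batches: int = 100):
        self.every_n_batches = every_n_batches

    def on_batch_end(self, batch_step, batch_data, batch_train_loss, trainer) -> None:
        # Keep metric volume manageable by logging only every N batches.
        if batch_step % self.every_n_batches == 0:
            trainer.log_metric("running train loss", batch_train_loss, step=batch_step)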
collie/helper/pytorch/callback/earlystop.py
@@ -0,0 +1,54 @@
from typing import Optional

from collie.helper.pytorch.callback.callback import Callback
from collie import trainer
from collie._common.utils import get_logger

logger = get_logger()


class EarlyStopping(Callback):

    def __init__(
        self,
        patience_on_epoch: int,
        delta: float = 0.0,
        monitor: str = "val_loss"
    ):
        super().__init__()

        self.patience_on_epoch = patience_on_epoch
        self.delta = delta
        self.monitor = monitor
        self.best_score = float('inf')
        self.wait = 0
        self.stopped_epoch = 0
        self.early_stop = False

    def on_epoch_end(
        self,
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float],
        trainer: "trainer.PytorchTrainer"
    ) -> None:

        if self.monitor == "val_loss" and epoch_val_loss is not None:
            score = epoch_val_loss
        elif self.monitor == "train_loss":
            score = epoch_train_loss
        else:
            return

        if self.best_score - score > self.delta:
            self.best_score = score
            self.wait = 0
        else:
            self.wait += 1

        if self.wait >= self.patience_on_epoch:
            self.early_stop = True
            self.stopped_epoch = epoch_step
            trainer.should_stop = True
            logger.info(f"Epoch {epoch_step}: Early stopping triggered (patience {self.patience_on_epoch} epochs).")
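The stopping rule is counter-based: an epoch counts as an improvement only when the monitored loss drops by more than delta, and training is flagged to stop once patience_on_epoch non-improving epochs accumulate. A small, self-contained sketch of that bookkeeping follows; the FakeTrainer stand-in is hypothetical and exists only so the should_stop flag can be observed.

# Hypothetical demonstration of the patience logic in isolation.
from collie.helper.pytorch.callback.earlystop import EarlyStopping


class FakeTrainer:
    # Stand-in for PytorchTrainer; only the should_stop flag matters here.
    should_stop = False


stopper = EarlyStopping(patience_on_epoch=2, delta=0.01, monitor="val_loss")
fake_trainer = FakeTrainer()

# Improvements after epoch 2 are smaller than delta, so they do not reset the counter.
for epoch, val_loss in enumerate([0.90, 0.80, 0.795, 0.794], start=1):
    stopper.on_epoch_end(epoch, epoch_train_loss=0.0, epoch_val_loss=val_loss, trainer=fake_trainer)

print(fake_trainer.should_stop, stopper.stopped_epoch)  # True 4 (stopped after two non-improving epochs)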
collie/helper/pytorch/callback/model_checkpoint.py
@@ -0,0 +1,100 @@
from typing import Optional
from glob import glob
import os
import shutil

import torch

from collie.helper.pytorch.callback.callback import Callback
from collie._common.utils import get_logger
from collie import trainer

logger = get_logger()


class ModelCheckpoint(Callback):

    def __init__(self, topk_checkpoints: int):
        super().__init__()
        self.topk_checkpoints = topk_checkpoints
        self._best_checkpoints = []  # stored as (loss, epoch_idx, checkpoint_path)

        self.parent_dir = "./.checkpoint/"
        if os.path.exists(self.parent_dir):
            shutil.rmtree(self.parent_dir)
            logger.info(f"Directory {self.parent_dir} has been removed.")

    def on_epoch_end(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float]
    ) -> None:
        # Rank on the validation loss when one is available (explicit None check so that
        # a legitimate validation loss of 0.0 is not silently replaced by the train loss).
        current_loss = epoch_val_loss if epoch_val_loss is not None else epoch_train_loss
        if self._should_save_checkpoint(current_loss):
            self._save_checkpoint(
                trainer=trainer,
                epoch_step=epoch_step,
                epoch_loss=epoch_train_loss,
                loss_for_ckpt=current_loss
            )

    def on_train_end(self, trainer):
        for ckpt in glob(f"{self.parent_dir}/*.pt"):
            trainer.log_artifact(ckpt, "checkpoints")

    def _should_save_checkpoint(self, current_loss: float) -> bool:
        """
        Returns a boolean indicating whether the current checkpoint should be saved.

        The current checkpoint is saved if the topk_checkpoints list is not yet full.
        Otherwise, it is saved only if its loss is better than the loss of the worst
        checkpoint currently held in the topk_checkpoints list.

        Args:
            current_loss (float): The loss of the current checkpoint.

        Returns:
            bool: True if the current checkpoint should be saved, False otherwise.
        """
        if len(self._best_checkpoints) < self.topk_checkpoints:
            # The topk_checkpoints list is not full yet, so always save.
            return True

        worst_loss = max(self._best_checkpoints, key=lambda x: x[0])[0]
        # Save only if the current loss beats the worst checkpoint in the top-k list.
        return current_loss < worst_loss

    def _save_checkpoint(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int,
        epoch_loss: float,
        loss_for_ckpt: float
    ) -> None:

        checkpoint_path = self.parent_dir + f"model_epoch{epoch_step}.pt"

        if not os.path.exists(self.parent_dir):
            os.makedirs(self.parent_dir)

        # https://zhuanlan.zhihu.com/p/136902153
        checkpoint = {
            "model_state_dict": trainer.model.state_dict(),
            "optimizer_state_dict": trainer.optimizer.state_dict(),
            "scheduler_state_dict": trainer.lr_scheduler.state_dict() if trainer.lr_scheduler else None,
            "epoch": epoch_step,
            "loss": epoch_loss
        }

        if len(self._best_checkpoints) < self.topk_checkpoints:
            self._best_checkpoints.append((loss_for_ckpt, epoch_step, checkpoint_path))
        else:
            # Evict the worst of the current top-k checkpoints and take its slot.
            worst_idx, (worst_loss, _, worst_path) = max(enumerate(self._best_checkpoints), key=lambda x: x[1][0])
            os.remove(worst_path)
            self._best_checkpoints[worst_idx] = (loss_for_ckpt, epoch_step, checkpoint_path)

        self._best_checkpoints.sort(key=lambda x: x[0])
        torch.save(checkpoint, checkpoint_path)
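Each file written by this callback is an ordinary torch.save dictionary, so a saved checkpoint can be inspected or reloaded directly. A quick sketch; the file name is illustrative, following the model_epoch{N}.pt pattern above.

# Hypothetical: inspect one of the checkpoints written to ./.checkpoint/
import torch

ckpt = torch.load("./.checkpoint/model_epoch3.pt", map_location="cpu")
print(sorted(ckpt.keys()))
# ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict']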
collie/helper/pytorch/model/loader.py
@@ -0,0 +1,55 @@
from typing import Optional, Dict, Union
from collections import OrderedDict

import torch


def load_pytorch_model_checkpoint(
    ckpt_path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    is_ddp_model: bool = False
) -> Dict[str, Union[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, int]]:
    """
    Loads a PyTorch model checkpoint and optionally restores the optimizer and learning rate scheduler states.

    Args:
        ckpt_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the state dict into.
        optimizer (Optional[torch.optim.Optimizer], optional): Optimizer whose state should be restored. Defaults to None.
        lr_scheduler (Optional[torch.optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler whose state should be restored. Defaults to None.
        is_ddp_model (bool, optional): Whether the checkpoint was saved from a DistributedDataParallel-wrapped model. Defaults to False.

    Returns:
        Dict[str, Union[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, int]]:
            A dictionary containing the restored model, optimizer, learning rate scheduler (if any), and the epoch.
    """

    res_dict = dict()

    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['model_state_dict']

    # Handle DDP (DistributedDataParallel) models
    if is_ddp_model:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel
            name = k[7:]
            new_state_dict[name] = v
        state_dict = new_state_dict

    model.load_state_dict(state_dict, strict=False)
    res_dict["model"] = model.eval()
    res_dict["epoch"] = checkpoint.get("epoch", "")

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        res_dict["optimizer"] = optimizer

    # ModelCheckpoint stores the scheduler state under "scheduler_state_dict";
    # also accept "lr_scheduler_state_dict" for checkpoints that use that key.
    scheduler_key = "scheduler_state_dict" if "scheduler_state_dict" in checkpoint else "lr_scheduler_state_dict"
    if lr_scheduler and checkpoint.get(scheduler_key):
        lr_scheduler.load_state_dict(checkpoint[scheduler_key])
        res_dict["lr_scheduler"] = lr_scheduler

    return res_dict
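A resume-training sketch built on this loader. MyModel, its import path, and the checkpoint file name are hypothetical placeholders; the checkpoint format matches what ModelCheckpoint writes above.

# Hypothetical usage: restore model and optimizer state to resume training.
import torch

from collie.helper.pytorch.model.loader import load_pytorch_model_checkpoint
from my_project.models import MyModel  # hypothetical user-defined model

model = MyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

restored = load_pytorch_model_checkpoint(
    ckpt_path="./.checkpoint/model_epoch3.pt",  # illustrative path
    model=model,
    optimizer=optimizer,
    is_ddp_model=False,
)

model = restored["model"]        # returned in eval() mode
optimizer = restored["optimizer"]
start_epoch = restored["epoch"]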
collie/helper/pytorch/trainer.py
@@ -0,0 +1,304 @@
from typing import (
    Optional,
    Tuple,
    Dict,
    Any,
    List,
)
from abc import abstractmethod

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch import nn
from torch.amp import GradScaler, autocast
import torch

from collie.trainer.trainer import Trainer
from collie._common.decorator import type_checker
from collie._common.utils import get_logger
from collie.helper.pytorch.callback.callback import Callback, _CallbackManager
from collie.helper.pytorch.callback.model_checkpoint import ModelCheckpoint
from collie.helper.pytorch.callback.earlystop import EarlyStopping
from collie.contracts.event import Event
from collie.core.types import TrainerPayload, TransformerPayload


# TODO: Features to develop:
# 1. Train with multiple GPUs
logger = get_logger()


class _AbstractPytorchTrainer:

    @abstractmethod
    def configure_optimizers(
        self
    ) -> Tuple[
        torch.optim.Optimizer,
        Optional[torch.optim.lr_scheduler._LRScheduler]
    ]:
        """
        Configure the optimizer and learning rate scheduler.

        Returns:
            A tuple of two objects: the optimizer, and the learning rate scheduler if one is needed.

        Example:
            .. code-block:: python

                def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
                    optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.initial_lr)
                    scheduler = get_cosine_schedule_with_warmup(
                        optimizer=optimizer,
                        num_warmup_steps=self.num_warmup_steps,
                        num_training_steps=self.num_training_steps
                    )

                    return optimizer, scheduler
        """
        raise NotImplementedError("Please implement the *configure_optimizers* method.")

    @type_checker(
        (DataLoader,),
        "The train_data in TransformerPayload should be of type DataLoader in PytorchTrainer."
    )
    def get_train_dataloader(self, event: Event[TransformerPayload]) -> DataLoader:
        transformer_payload = event.payload
        return transformer_payload.train_data

    @type_checker(
        (DataLoader, None),
        "The validation_data in TransformerPayload should be of type DataLoader in PytorchTrainer."
    )
    def get_val_dataloader(self, event: Event[TransformerPayload]) -> Optional[DataLoader]:
        transformer_payload = event.payload
        return transformer_payload.validation_data


class PytorchTrainer(_AbstractPytorchTrainer):
    def __init__(
        self,
        model: nn.Module,
        epochs: int,
        device: Optional[str],
        use_amp: Optional[bool],
        topk_checkpoints: Optional[int],
        earlystop_patience_on_epoch: Optional[int] = None,
        accumulate_grad_batches: int = 1,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:

        super().__init__()

        self.model = model

        self.epochs = epochs
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.use_amp = use_amp
        self.topk_checkpoints = 3 if topk_checkpoints is None else topk_checkpoints
        self.earlystop_patience_on_epoch = earlystop_patience_on_epoch

        DEFAULT_CALLBACK = [ModelCheckpoint(topk_checkpoints=self.topk_checkpoints)]
        if self.earlystop_patience_on_epoch:
            DEFAULT_CALLBACK.append(EarlyStopping(self.earlystop_patience_on_epoch, delta=0.0))

        self.callbacks = [] if callbacks is None else callbacks
        self.callbacks = self.callbacks + DEFAULT_CALLBACK
        self.accumulate_grad_batches = accumulate_grad_batches

        self.train_data_loader = None
        self.val_data_loader = None
        self.optimizer = None
        self.lr_scheduler = None
        self.grad_scaler = None
        self.should_stop = False
        self.cb_manager = _CallbackManager(callbacks=self.callbacks)

        self.model = self.model.to(self.device)

    def run(
        self,
        event: Event[TransformerPayload]
    ) -> Event[TrainerPayload]:

        self.train_data_loader = self.get_train_dataloader(event=event)
        self.val_data_loader = self.get_val_dataloader(event=event)

        self.optimizer, self.lr_scheduler = self.configure_optimizers()

        trainstep_per_epoch = len(self.train_data_loader)
        total_train_step = trainstep_per_epoch * self.epochs

        self.log_param("epoch", self.epochs)
        self.log_param("batch size", self.train_data_loader.batch_size)
        self.log_param("total training step", total_train_step)

        self.cb_manager.on_train_start(self)
        self.optimizer.zero_grad()
        # TODO: the train loop is written by the user in the handle method.
        for epoch_idx in tqdm(range(1, self.epochs + 1)):
            # Modified by the EarlyStopping callback
            if self.should_stop:
                break

            self.cb_manager.on_epoch_start(self, epoch_idx)
            running_loss = 0
            self.model.train()
            for batch_idx, batch_data in enumerate(self.train_data_loader, start=1):
                self.cb_manager.on_batch_start(
                    batch_data=batch_data,
                    batch_step=batch_idx,
                    trainer=self
                )

                loss = self._get_train_step_result(epoch_idx, batch_data) / self.accumulate_grad_batches
                running_loss += loss.item() * self.accumulate_grad_batches

                self._backward(loss=loss, batch_idx=batch_idx)
                # The learning rate scheduler is stepped at batch level.
                if self.lr_scheduler and batch_idx % self.accumulate_grad_batches == 0:
                    self.lr_scheduler.step()

                self.cb_manager.on_batch_end(
                    batch_step=batch_idx,
                    batch_data=batch_data,
                    batch_train_loss=running_loss,
                    trainer=self
                )

            epoch_loss = running_loss / trainstep_per_epoch

            self.log_metric("train loss", epoch_loss, step=epoch_idx)
            self._log_lr(epoch_step=epoch_idx)

            epoch_val_loss = self.validation_loop(epoch_step=epoch_idx)

            self.cb_manager.on_epoch_end(
                trainer=self,
                epoch_step=epoch_idx,
                epoch_train_loss=epoch_loss,
                epoch_val_loss=epoch_val_loss
            )

        # TODO: Fix the issue of loading the model from the following result model.
        self.log_model(self.model, model_type="pt")

        self.cb_manager.on_train_end(self)

        trainer_payload = TrainerPayload(
            model=self.model,
            train_loss=epoch_loss,
            val_loss=epoch_val_loss
        )
        return Event(
            payload=trainer_payload
        )

    def validation_loop(self, epoch_step: int) -> Optional[float]:

        if self.val_data_loader is None:
            return None

        valstep_per_epoch = len(self.val_data_loader)

        self.model.eval()
        val_running_loss = 0
        with torch.no_grad():

            for val_batch_data in self.val_data_loader:

                val_loss = self.validation_step(epoch_step, val_batch_data)
                val_running_loss += val_loss.item()

        epoch_val_loss = val_running_loss / valstep_per_epoch

        self.log_metric("val_loss", epoch_val_loss, step=epoch_step)
        return epoch_val_loss

    def _backward(
        self,
        loss: torch.Tensor,
        batch_idx: int
    ) -> None:

        def should_step_optimizer(batch_idx: int) -> bool:
            return batch_idx % self.accumulate_grad_batches == 0 or batch_idx == len(self.train_data_loader)

        if self.use_amp:
            if self.grad_scaler is None:
                self.grad_scaler = GradScaler(device=self.device)

            self.grad_scaler.scale(loss).backward()
            if should_step_optimizer(batch_idx):

                if self._has_invalid_gradients():
                    # gradient clipping
                    self.grad_scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
                self.optimizer.zero_grad()
        else:
            loss.backward()
            if should_step_optimizer(batch_idx):
                if self._has_invalid_gradients():
                    # gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()

    def _has_invalid_gradients(self) -> bool:

        for param in self.model.parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                    logger.warning(f"Detected NaN or Inf gradients for parameter {param}. Zeroing out the gradients.")
                    # side effect
                    param.grad.zero_()
                    return True
        return False

    def _log_lr(self, epoch_step: int) -> None:

        lr = self.optimizer.param_groups[0]['lr'] if not self.lr_scheduler else self.lr_scheduler.get_last_lr()[0]
        self.log_metric("learning rate", lr, step=epoch_step)

    @type_checker(
        (torch.Tensor,),
        "The return type of the *train_step* method must be 'Tensor'."
    )
    def _get_train_step_result(
        self,
        epoch_step: int,
        batch_data: torch.Tensor
    ) -> torch.Tensor:

        if self.use_amp:
            with autocast(device_type=self.device):
                loss = self.handle(epoch_step, batch_data)
        else:
            loss = self.handle(epoch_step, batch_data)

        return loss

    def _get_optimizers(
        self
    ) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.LRScheduler]]:
        """
        Get the optimizer and learning rate scheduler from the
        *configure_optimizers* method.
        """
        result = self.configure_optimizers()
        if isinstance(result, tuple):
            optimizers, lr_scheduler = result
            if not isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
                raise TypeError(
                    "learning rate scheduler must be of type "
                    "torch.optim.lr_scheduler._LRScheduler"
                )
        else:
            optimizers, lr_scheduler = result, None
        return optimizers, lr_scheduler
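To tie the pieces together, a rough sketch of how a concrete trainer might be assembled. It assumes the user subclasses PytorchTrainer and supplies the methods this class calls on itself but does not define here: handle for the training step, validation_step, and the log_param / log_metric / log_model / log_artifact tracking methods, which appear to be contributed by collie's Trainer component. MyTrainer and its training logic are illustrative, not part of the package.

# Hypothetical sketch of a concrete trainer built on PytorchTrainer.
from typing import Optional, Tuple

import torch
from torch import nn

from collie.helper.pytorch.trainer import PytorchTrainer


class MyTrainer(PytorchTrainer):

    def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler._LRScheduler]]:
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=3e-4)
        return optimizer, None  # no scheduler in this sketch

    def handle(self, epoch_step: int, batch_data) -> torch.Tensor:
        # One training step: move the batch to the device and return the loss tensor.
        x, y = batch_data
        preds = self.model(x.to(self.device))
        return nn.functional.cross_entropy(preds, y.to(self.device))

    def validation_step(self, epoch_step: int, batch_data) -> torch.Tensor:
        x, y = batch_data
        preds = self.model(x.to(self.device))
        return nn.functional.cross_entropy(preds, y.to(self.device))

# Usage (illustrative): wrap the train/validation DataLoaders in the TransformerPayload
# event produced by the Transform step, then call trainer.run(event) to receive an
# Event[TrainerPayload] carrying the trained model and final losses.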