collie_mlops-0.1.1b0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- collie/__init__.py +69 -0
- collie/_common/__init__.py +0 -0
- collie/_common/decorator.py +53 -0
- collie/_common/exceptions.py +104 -0
- collie/_common/mlflow_model_io/__init__.py +0 -0
- collie/_common/mlflow_model_io/base_flavor_handler.py +26 -0
- collie/_common/mlflow_model_io/flavor_registry.py +72 -0
- collie/_common/mlflow_model_io/model_flavors.py +259 -0
- collie/_common/mlflow_model_io/model_io.py +65 -0
- collie/_common/utils.py +13 -0
- collie/contracts/__init__.py +0 -0
- collie/contracts/event.py +79 -0
- collie/contracts/mlflow.py +444 -0
- collie/contracts/orchestrator.py +79 -0
- collie/core/__init__.py +41 -0
- collie/core/enums/__init__.py +0 -0
- collie/core/enums/components.py +26 -0
- collie/core/enums/ml_models.py +20 -0
- collie/core/evaluator/__init__.py +0 -0
- collie/core/evaluator/evaluator.py +147 -0
- collie/core/models.py +125 -0
- collie/core/orchestrator/__init__.py +0 -0
- collie/core/orchestrator/orchestrator.py +47 -0
- collie/core/pusher/__init__.py +0 -0
- collie/core/pusher/pusher.py +98 -0
- collie/core/trainer/__init__.py +0 -0
- collie/core/trainer/trainer.py +78 -0
- collie/core/transform/__init__.py +0 -0
- collie/core/transform/transform.py +87 -0
- collie/core/tuner/__init__.py +0 -0
- collie/core/tuner/tuner.py +84 -0
- collie/helper/__init__.py +0 -0
- collie/helper/pytorch/__init__.py +0 -0
- collie/helper/pytorch/callback/__init__.py +0 -0
- collie/helper/pytorch/callback/callback.py +155 -0
- collie/helper/pytorch/callback/earlystop.py +54 -0
- collie/helper/pytorch/callback/model_checkpoint.py +100 -0
- collie/helper/pytorch/model/__init__.py +0 -0
- collie/helper/pytorch/model/loader.py +55 -0
- collie/helper/pytorch/trainer.py +304 -0
- collie_mlops-0.1.1b0.dist-info/LICENSE +21 -0
- collie_mlops-0.1.1b0.dist-info/METADATA +259 -0
- collie_mlops-0.1.1b0.dist-info/RECORD +45 -0
- collie_mlops-0.1.1b0.dist-info/WHEEL +5 -0
- collie_mlops-0.1.1b0.dist-info/top_level.txt +1 -0
collie/helper/pytorch/callback/callback.py
@@ -0,0 +1,155 @@
from typing import Optional, List, Literal

import torch

from collie import trainer


class Callback:

    def on_train_start(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_train_end(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_epoch_start(
        self,
        epoch_step: int,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_epoch_end(
        self,
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float],
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_batch_start(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...

    def on_batch_end(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        batch_train_loss: float,
        trainer: "trainer.PytorchTrainer"
    ) -> None:
        ...


class _CallbackManager:
    def __init__(self, callbacks: List[Callback]):
        self.callbacks = callbacks

    def on_train_start(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:

        self._execute_callbacks(
            "on_train_start",
            trainer=trainer
        )

    def on_train_end(
        self,
        trainer: "trainer.PytorchTrainer"
    ) -> None:

        self._execute_callbacks(
            "on_train_end",
            trainer=trainer
        )

    def on_epoch_start(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int
    ) -> None:

        self._execute_callbacks(
            "on_epoch_start",
            trainer=trainer,
            epoch_step=epoch_step
        )

    def on_epoch_end(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float],
    ) -> None:

        self._execute_callbacks(
            "on_epoch_end",
            trainer=trainer,
            epoch_step=epoch_step,
            epoch_train_loss=epoch_train_loss,
            epoch_val_loss=epoch_val_loss
        )

    def on_batch_start(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        trainer: "trainer.PytorchTrainer"
    ) -> None:

        self._execute_callbacks(
            "on_batch_start",
            trainer=trainer,
            batch_step=batch_step,
            batch_data=batch_data
        )

    def on_batch_end(
        self,
        batch_step: int,
        batch_data: torch.Tensor,
        batch_train_loss: float,
        trainer: "trainer.PytorchTrainer"
    ) -> None:

        self._execute_callbacks(
            "on_batch_end",
            trainer=trainer,
            batch_step=batch_step,
            batch_data=batch_data,
            batch_train_loss=batch_train_loss
        )

    def _execute_callbacks(
        self,
        method_name: Literal[
            "on_train_start",
            "on_train_end",
            "on_epoch_start",
            "on_epoch_end",
            "on_batch_start",
            "on_batch_end"],
        *args, **kwargs
    ) -> None:

        if self.callbacks is None:
            return

        for callback in self.callbacks:
            # Call the method specified by method_name on the callback
            getattr(callback, method_name)(*args, **kwargs)
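For orientation (not part of the packaged diff): a minimal sketch of a user-defined callback. The LossLogger name and its print-based reporting are illustrative assumptions; the hook signatures mirror the Callback base class above, and _CallbackManager invokes the hooks with keyword arguments, so the parameter names must match.

from collie.helper.pytorch.callback.callback import Callback

# Hypothetical callback: records per-batch training loss and prints an epoch summary.
class LossLogger(Callback):
    def __init__(self):
        self.batch_losses = []

    def on_batch_end(self, batch_step, batch_data, batch_train_loss, trainer) -> None:
        # batch_train_loss is the running loss PytorchTrainer passes to on_batch_end.
        self.batch_losses.append(batch_train_loss)

    def on_epoch_end(self, epoch_step, epoch_train_loss, epoch_val_loss, trainer) -> None:
        print(f"epoch {epoch_step}: train={epoch_train_loss:.4f}, val={epoch_val_loss}")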
collie/helper/pytorch/callback/earlystop.py
@@ -0,0 +1,54 @@
from typing import Optional

from collie.helper.pytorch.callback.callback import Callback
from collie import trainer
from collie._common.utils import get_logger

logger = get_logger()


class EarlyStopping(Callback):

    def __init__(
        self,
        patience_on_epoch: int,
        delta: float = 0.0,
        monitor: str = "val_loss"
    ):
        super().__init__()

        self.patience_on_epoch = patience_on_epoch
        self.delta = delta
        self.monitor = monitor
        self.best_score = float('inf')
        self.wait = 0
        self.stopped_epoch = 0
        self.early_stop = False

    def on_epoch_end(
        self,
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float],
        trainer: "trainer.PytorchTrainer"
    ) -> None:

        if self.monitor == "val_loss" and epoch_val_loss is not None:
            score = epoch_val_loss
        elif self.monitor == "train_loss":
            score = epoch_train_loss
        else:
            return

        if self.best_score - score > self.delta:
            self.best_score = score
            self.wait = 0
        else:
            self.wait += 1

        if self.wait >= self.patience_on_epoch:
            self.early_stop = True
            self.stopped_epoch = epoch_step
            trainer.should_stop = True
            logger.info(f"Epoch {epoch_step}: Early stopping triggered (patience {self.patience_on_epoch} epochs).")
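A brief usage sketch (illustrative, not part of the diff): a drop of more than delta below the best score seen so far counts as an improvement; otherwise wait is incremented, and once it reaches patience_on_epoch the callback flips trainer.should_stop.

from collie.helper.pytorch.callback.earlystop import EarlyStopping

early_stop = EarlyStopping(patience_on_epoch=3, delta=0.01, monitor="val_loss")
# Example trace with validation losses 0.50, 0.495, 0.499, 0.498:
#   epoch 1: inf - 0.50 > 0.01  -> best_score = 0.50, wait = 0
#   epochs 2-4: each loss is within 0.01 of 0.50 -> wait = 1, 2, 3
#   wait >= patience_on_epoch at epoch 4, so trainer.should_stop is set to True.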
collie/helper/pytorch/callback/model_checkpoint.py
@@ -0,0 +1,100 @@
from typing import Optional
from glob import glob
import os
import shutil

import torch

from collie.helper.pytorch.callback.callback import Callback
from collie._common.utils import get_logger
from collie import trainer

logger = get_logger()


class ModelCheckpoint(Callback):

    def __init__(self, topk_checkpoints: int):
        super().__init__()
        self.topk_checkpoints = topk_checkpoints
        self._best_checkpoints = []  # saved as (loss, epoch_idx, checkpoint_path)

        self.parent_dir = "./.checkpoint/"
        if os.path.exists(self.parent_dir):
            shutil.rmtree(self.parent_dir)
            logger.info(f"Directory {self.parent_dir} has been removed.")

    def on_epoch_end(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int,
        epoch_train_loss: float,
        epoch_val_loss: Optional[float]
    ) -> None:

        current_loss = epoch_val_loss if epoch_val_loss else epoch_train_loss
        if self._should_save_checkpoint(current_loss):
            self._save_checkpoint(
                trainer=trainer,
                epoch_step=epoch_step,
                epoch_loss=epoch_train_loss,
                loss_for_ckpt=current_loss
            )

    def on_train_end(self, trainer):
        for ckpt in glob(f"{self.parent_dir}/*.pt"):
            trainer.log_artifact(ckpt, "checkpoints")

    def _should_save_checkpoint(self, current_loss: float) -> bool:
        """
        Returns a boolean indicating if the current checkpoint should be saved.

        The current checkpoint should be saved if the topk_checkpoints list is not full.
        Otherwise, the current checkpoint should be saved if its loss is better than the
        loss of the worst checkpoint in the topk_checkpoints list.

        Args:
            current_loss (float): The loss of the current checkpoint.

        Returns:
            bool: True if the current checkpoint should be saved, False otherwise.
        """
        if len(self._best_checkpoints) < self.topk_checkpoints:
            # If the topk_checkpoints list is not full, save the current checkpoint.
            return True

        worst_loss = max(self._best_checkpoints, key=lambda x: x[0])[0]
        # If the current checkpoint has a better loss than the worst checkpoint in the
        # topk_checkpoints list, save the current checkpoint.
        return current_loss < worst_loss

    def _save_checkpoint(
        self,
        trainer: "trainer.PytorchTrainer",
        epoch_step: int,
        epoch_loss: float,
        loss_for_ckpt: float
    ) -> None:

        checkpoint_path = self.parent_dir + f'model_epoch{epoch_step}.pt'

        if not os.path.exists(self.parent_dir):
            os.makedirs(self.parent_dir)

        # https://zhuanlan.zhihu.com/p/136902153
        checkpoint = {
            "model_state_dict": trainer.model.state_dict(),
            "optimizer_state_dict": trainer.optimizer.state_dict(),
            "scheduler_state_dict": trainer.lr_scheduler.state_dict() if trainer.lr_scheduler else None,
            "epoch": epoch_step,
            "loss": epoch_loss
        }

        if len(self._best_checkpoints) < self.topk_checkpoints:
            self._best_checkpoints.append((loss_for_ckpt, epoch_step, checkpoint_path))
        else:
            worst_idx, (worst_loss, _, worst_path) = max(enumerate(self._best_checkpoints), key=lambda x: x[1][0])

            os.remove(worst_path)
            self._best_checkpoints[worst_idx] = (loss_for_ckpt, epoch_step, checkpoint_path)

        self._best_checkpoints.sort(key=lambda x: x[0])
        torch.save(checkpoint, f'{checkpoint_path}')
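The _should_save_checkpoint docstring above describes a top-k retention rule; the sketch below (illustrative only) traces it for topk_checkpoints=2. _should_save_checkpoint is a private helper, called here purely to show the decision.

from collie.helper.pytorch.callback.model_checkpoint import ModelCheckpoint

ckpt_cb = ModelCheckpoint(topk_checkpoints=2)
print(ckpt_cb._should_save_checkpoint(0.90))   # True: fewer than 2 checkpoints stored so far

# Hand trace of the retention rule with topk_checkpoints=2:
#   loss 0.90 -> list not full, save
#   loss 0.70 -> list not full, save
#   loss 0.80 -> better than the current worst (0.90), save; the 0.90 file is removed
#   loss 0.95 -> worse than the current worst (0.80), skip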
File without changes
collie/helper/pytorch/model/loader.py
@@ -0,0 +1,55 @@
from typing import Optional, Dict, Union
from collections import OrderedDict

import torch


def load_pytorch_model_checkpoint(
    ckpt_path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    is_ddp_model: bool = False
) -> Dict[str, Union[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, int]]:
    """
    Loads a PyTorch model checkpoint and optionally restores the optimizer and learning rate scheduler states.

    Args:
        ckpt_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the state dict into.
        optimizer (Optional[torch.optim.Optimizer], optional): Optimizer to restore its state. Default is None.
        lr_scheduler (Optional[torch.optim.lr_scheduler._LRScheduler], optional): Learning rate scheduler to restore its state. Default is None.
        is_ddp_model (bool, optional): Whether the model is wrapped in DistributedDataParallel. Default is False.

    Returns:
        Dict[str, Union[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler, int]]:
            A dictionary containing the restored model, optimizer, learning rate scheduler (if any), and the epoch.
    """

    res_dict = dict()

    checkpoint = torch.load(ckpt_path)
    state_dict = checkpoint['model_state_dict']

    # Handle DDP (DistributedDataParallel) model
    if is_ddp_model:
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # remove 'module.' of DataParallel/DistributedDataParallel
            name = k[7:]
            new_state_dict[name] = v
        state_dict = new_state_dict

    model.load_state_dict(state_dict, strict=False)
    res_dict["model"] = model.eval()
    res_dict["epoch"] = checkpoint.get("epoch", "")

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        res_dict["optimizer"] = optimizer

    if 'lr_scheduler_state_dict' in checkpoint and lr_scheduler:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
        res_dict["lr_scheduler"] = lr_scheduler

    return res_dict
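A minimal restore sketch (illustrative; the toy model and the checkpoint path are hypothetical, the path simply follows the model_epoch{N}.pt naming that ModelCheckpoint uses). Note that this loader looks for 'lr_scheduler_state_dict', while ModelCheckpoint stores the scheduler under 'scheduler_state_dict', so scheduler state from those files is not picked up here.

import torch
from collie.helper.pytorch.model.loader import load_pytorch_model_checkpoint

# Hypothetical model/optimizer matching whatever produced the checkpoint.
model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

restored = load_pytorch_model_checkpoint(
    ckpt_path="./.checkpoint/model_epoch3.pt",
    model=model,
    optimizer=optimizer,
    is_ddp_model=False,   # True would strip the "module." prefix from DDP state dicts
)
model = restored["model"]        # returned in eval() mode
start_epoch = restored["epoch"]  # epoch stored in the checkpoint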
collie/helper/pytorch/trainer.py
@@ -0,0 +1,304 @@
from typing import (
    Optional,
    Tuple,
    Dict,
    Any,
    List,
)
from abc import abstractmethod

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch import nn
from torch.amp import GradScaler, autocast
import torch

from collie.trainer.trainer import Trainer
from collie._common.decorator import type_checker
from collie._common.utils import get_logger
from collie.helper.pytorch.callback.callback import Callback, _CallbackManager
from collie.helper.pytorch.callback.model_checkpoint import ModelCheckpoint
from collie.helper.pytorch.callback.earlystop import EarlyStopping
from collie.contracts.event import Event
from collie.core.types import TrainerPayload, TransformerPayload


# TODO: Features to develop:
# 1. Train with multiple GPUs
logger = get_logger()


class _AbstractPytorchTrainer:

    @abstractmethod
    def configure_optimizers(
        self
    ) -> Tuple[
        torch.optim.Optimizer,
        Optional[torch.optim.lr_scheduler._LRScheduler]
    ]:
        """
        Configure the optimizer and learning rate scheduler.

        Returns:
            A tuple of two objects: the optimizer and, if needed, the learning rate scheduler.

        Example:
            .. code-block:: python

                def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[torch.optim.lr_scheduler._LRScheduler]]:
                    optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.initial_lr)
                    scheduler = get_cosine_schedule_with_warmup(
                        optimizer=optimizer,
                        num_warmup_steps=self.num_warmup_steps,
                        num_training_steps=self.num_training_steps
                    )

                    return optimizer, scheduler
        """
        raise NotImplementedError("Please implement the *configure_optimizers* method.")

    @type_checker(
        (DataLoader,),
        "The train_data in TransformerPayload should be type of DataLoader in PytorchTrainer."
    )
    def get_train_dataloader(self, event: Event[TransformerPayload]) -> DataLoader:
        transformer_payload = event.payload
        return transformer_payload.train_data

    @type_checker(
        (DataLoader, None),
        "The validation_data in TransformerPayload should be type of DataLoader in PytorchTrainer."
    )
    def get_val_dataloader(self, event: Event[TransformerPayload]) -> DataLoader:
        transformer_payload = event.payload
        return transformer_payload.validation_data


class PytorchTrainer(_AbstractPytorchTrainer):
    def __init__(
        self,
        model: nn.Module,
        epochs: int,
        device: Optional[str],
        use_amp: Optional[bool],
        topk_checkpoints: Optional[int],
        earlystop_patience_on_epoch: Optional[int] = None,
        accumulate_grad_batches: int = 1,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:

        super().__init__()

        self.model = model

        self.epochs = epochs
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.use_amp = use_amp
        self.topk_checkpoints = 3 if topk_checkpoints is None else topk_checkpoints
        self.earlystop_patience_on_epoch = earlystop_patience_on_epoch

        DEFAULT_CALLBACK = [ModelCheckpoint(topk_checkpoints=self.topk_checkpoints)]
        if self.earlystop_patience_on_epoch:
            DEFAULT_CALLBACK.append(EarlyStopping(self.earlystop_patience_on_epoch, delta=0.0))

        self.callbacks = [] if callbacks is None else callbacks
        self.callbacks = self.callbacks + DEFAULT_CALLBACK
        self.accumulate_grad_batches = accumulate_grad_batches

        self.train_data_loader = None
        self.val_data_loader = None
        self.optimizer = None
        self.lr_scheduler = None
        self.grad_scaler = None
        self.should_stop = False
        self.cb_manager = _CallbackManager(callbacks=self.callbacks)

        self.model = self.model.to(self.device)

    def run(
        self,
        event: Event[TransformerPayload]
    ) -> Event[TrainerPayload]:

        self.train_data_loader = self.get_train_dataloader(event=event)
        self.val_data_loader = self.get_val_dataloader(event=event)

        self.optimizer, self.lr_scheduler = self.configure_optimizers()

        trainstep_per_epoch = len(self.train_data_loader)
        total_train_step = trainstep_per_epoch * self.epochs

        self.log_param("epoch", self.epochs)
        self.log_param("batch size", self.train_data_loader.batch_size)
        self.log_param("total training step", total_train_step)

        self.cb_manager.on_train_start(self)
        self.optimizer.zero_grad()
        # TODO: the train loop is written by the user in the handle method.
        for epoch_idx in tqdm(range(1, self.epochs + 1)):
            # Modified by the EarlyStopping callback
            if self.should_stop:
                break

            self.cb_manager.on_epoch_start(self, epoch_idx)
            running_loss = 0
            self.model.train()
            for batch_idx, batch_data in enumerate(self.train_data_loader, start=1):
                self.cb_manager.on_batch_start(
                    batch_data=batch_data,
                    batch_step=batch_idx,
                    trainer=self
                )

                loss = self._get_train_step_result(epoch_idx, batch_data) / self.accumulate_grad_batches
                running_loss += loss.item() * self.accumulate_grad_batches

                self._backward(loss=loss, batch_idx=batch_idx)
                # The learning rate scheduler is batch level.
                if self.lr_scheduler and batch_idx % self.accumulate_grad_batches == 0:
                    self.lr_scheduler.step()

                self.cb_manager.on_batch_end(
                    batch_step=batch_idx,
                    batch_data=batch_data,
                    batch_train_loss=running_loss,
                    trainer=self
                )

            epoch_loss = running_loss / trainstep_per_epoch

            self.log_metric("train loss", epoch_loss, step=epoch_idx)
            self._log_lr(epoch_step=epoch_idx)

            epoch_val_loss = self.validation_loop(epoch_step=epoch_idx)

            self.cb_manager.on_epoch_end(
                trainer=self,
                epoch_step=epoch_idx,
                epoch_train_loss=epoch_loss,
                epoch_val_loss=epoch_val_loss
            )

        # TODO: Fix the issue of loading the model from the following result model.
        self.log_model(self.model, model_type="pt")

        self.cb_manager.on_train_end(self)

        # return {"model": self.model, "loss": epoch_loss}
        trainer_payload = TrainerPayload(
            model=self.model,
            train_loss=epoch_loss,
            val_loss=epoch_val_loss
        )
        return Event(
            payload=trainer_payload
        )

    def validation_loop(self, epoch_step: int) -> Optional[float]:

        if self.val_data_loader is None:
            return None

        valstep_per_epoch = len(self.val_data_loader)

        self.model.eval()
        val_running_loss = 0
        with torch.no_grad():

            for val_batch_data in self.val_data_loader:

                val_loss = self.validation_step(epoch_step, val_batch_data)
                val_running_loss += val_loss.item()

        epoch_val_loss = val_running_loss / valstep_per_epoch

        self.log_metric("val_loss", epoch_val_loss, step=epoch_step)
        return epoch_val_loss

    def _backward(
        self,
        loss: torch.Tensor,
        batch_idx: int
    ) -> None:

        def should_step_optimizer(batch_idx: int) -> bool:
            return batch_idx % self.accumulate_grad_batches == 0 or batch_idx == len(self.train_data_loader)

        if self.use_amp:
            if self.grad_scaler is None:
                self.grad_scaler = GradScaler(device=self.device)

            self.grad_scaler.scale(loss).backward()
            if should_step_optimizer(batch_idx):

                if self._has_invalid_gradients():
                    # gradient clipping
                    self.grad_scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

                self.grad_scaler.step(self.optimizer)
                self.grad_scaler.update()
                self.optimizer.zero_grad()
        else:
            loss.backward()
            if should_step_optimizer(batch_idx):
                if self._has_invalid_gradients():
                    # gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.optimizer.zero_grad()

    def _has_invalid_gradients(self) -> bool:

        for param in self.model.parameters():
            if param.grad is not None:
                if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
                    logger.warning(f"Detected NaN or Inf gradients for parameter {param}. Zeroing out the gradients.")
                    # side effect
                    param.grad.zero_()
                    return True
        return False

    def _log_lr(self, epoch_step: int) -> None:

        lr = self.optimizer.param_groups[0]['lr'] if not self.lr_scheduler else self.lr_scheduler.get_last_lr()[0]
        self.log_metric("learning rate", lr, step=epoch_step)

    @type_checker(
        (torch.Tensor,),
        "The return type of *train_step* method must be 'Tensor'."
    )
    def _get_train_step_result(
        self,
        epoch_step: int,
        batch_data: torch.Tensor
    ):

        if self.use_amp:
            with autocast(device_type=self.device):
                loss = self.handle(epoch_step, batch_data)
        else:
            loss = self.handle(epoch_step, batch_data)

        return loss

    def _get_optimizers(
        self
    ) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.LRScheduler]]:
        """
        Get the optimizer and learning rate scheduler from the
        *configure_optimizers* method.
        """
        result = self.configure_optimizers()
        if isinstance(result, tuple):
            optimizers, lr_scheduler = result
            if not isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
                raise TypeError(
                    "learning rate scheduler must be type of "
                    "torch.optim.lr_scheduler._LRScheduler"
                )
        else:
            optimizers, lr_scheduler = result, None
        return optimizers, lr_scheduler
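For context, a sketch of how a subclass might be wired up (illustrative only, not part of the wheel). PytorchTrainer delegates the per-batch loss computation to self.handle and validation to self.validation_step, neither of which is defined in this file, and run() also calls log_param, log_metric, log_model, and log_artifact, which are likewise provided elsewhere in collie. The sketch therefore assumes those hooks and stops at construction; actually running it would additionally need an Event[TransformerPayload] carrying DataLoaders and the logging backend.

import torch
from torch import nn
from collie.helper.pytorch.trainer import PytorchTrainer

class MyTrainer(PytorchTrainer):
    def configure_optimizers(self):
        # run() unpacks the return value as (optimizer, lr_scheduler).
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3)
        return optimizer, None

    def handle(self, epoch_step, batch_data):
        # Assumed per-batch hook, called by _get_train_step_result; must return a Tensor.
        x, y = batch_data
        return nn.functional.mse_loss(self.model(x.to(self.device)), y.to(self.device))

    def validation_step(self, epoch_step, batch_data):
        # Assumed validation hook, called inside validation_loop under torch.no_grad().
        x, y = batch_data
        return nn.functional.mse_loss(self.model(x.to(self.device)), y.to(self.device))

trainer = MyTrainer(
    model=nn.Linear(16, 1),
    epochs=5,
    device=None,           # falls back to "cuda" if available, else "cpu"
    use_amp=False,
    topk_checkpoints=2,
    earlystop_patience_on_epoch=3,
)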
collie_mlops-0.1.1b0.dist-info/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 ChingHuanChiu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.