collie-mlops 0.1.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (45)
  1. collie/__init__.py +69 -0
  2. collie/_common/__init__.py +0 -0
  3. collie/_common/decorator.py +53 -0
  4. collie/_common/exceptions.py +104 -0
  5. collie/_common/mlflow_model_io/__init__.py +0 -0
  6. collie/_common/mlflow_model_io/base_flavor_handler.py +26 -0
  7. collie/_common/mlflow_model_io/flavor_registry.py +72 -0
  8. collie/_common/mlflow_model_io/model_flavors.py +259 -0
  9. collie/_common/mlflow_model_io/model_io.py +65 -0
  10. collie/_common/utils.py +13 -0
  11. collie/contracts/__init__.py +0 -0
  12. collie/contracts/event.py +79 -0
  13. collie/contracts/mlflow.py +444 -0
  14. collie/contracts/orchestrator.py +79 -0
  15. collie/core/__init__.py +41 -0
  16. collie/core/enums/__init__.py +0 -0
  17. collie/core/enums/components.py +26 -0
  18. collie/core/enums/ml_models.py +20 -0
  19. collie/core/evaluator/__init__.py +0 -0
  20. collie/core/evaluator/evaluator.py +147 -0
  21. collie/core/models.py +125 -0
  22. collie/core/orchestrator/__init__.py +0 -0
  23. collie/core/orchestrator/orchestrator.py +47 -0
  24. collie/core/pusher/__init__.py +0 -0
  25. collie/core/pusher/pusher.py +98 -0
  26. collie/core/trainer/__init__.py +0 -0
  27. collie/core/trainer/trainer.py +78 -0
  28. collie/core/transform/__init__.py +0 -0
  29. collie/core/transform/transform.py +87 -0
  30. collie/core/tuner/__init__.py +0 -0
  31. collie/core/tuner/tuner.py +84 -0
  32. collie/helper/__init__.py +0 -0
  33. collie/helper/pytorch/__init__.py +0 -0
  34. collie/helper/pytorch/callback/__init__.py +0 -0
  35. collie/helper/pytorch/callback/callback.py +155 -0
  36. collie/helper/pytorch/callback/earlystop.py +54 -0
  37. collie/helper/pytorch/callback/model_checkpoint.py +100 -0
  38. collie/helper/pytorch/model/__init__.py +0 -0
  39. collie/helper/pytorch/model/loader.py +55 -0
  40. collie/helper/pytorch/trainer.py +304 -0
  41. collie_mlops-0.1.0b0.dist-info/METADATA +217 -0
  42. collie_mlops-0.1.0b0.dist-info/RECORD +45 -0
  43. collie_mlops-0.1.0b0.dist-info/WHEEL +5 -0
  44. collie_mlops-0.1.0b0.dist-info/licenses/LICENSE +21 -0
  45. collie_mlops-0.1.0b0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,155 @@
+ from typing import Optional, List, Literal
+
+ import torch
+
+ from collie import trainer
+
+
+ class Callback:
+
+     def on_train_start(
+         self,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+         ...
+
+     def on_train_end(
+         self,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+         ...
+
+     def on_epoch_start(
+         self,
+         epoch_step: int,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+         ...
+
+     def on_epoch_end(
+         self,
+         epoch_step: int,
+         epoch_train_loss: float,
+         epoch_val_loss: Optional[float],
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+         ...
+
+     def on_batch_start(
+         self,
+         batch_step: int,
+         batch_data: torch.Tensor,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+         ...
+
+     def on_batch_end(
+         self,
+         batch_step: int,
+         batch_data: torch.Tensor,
+         batch_train_loss: float,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+         ...
+
+
+ class _CallbackManager:
+     def __init__(self, callbacks: List[Callback]):
+         self.callbacks = callbacks
+
+     def on_train_start(
+         self,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+
+         self._execute_callbacks(
+             "on_train_start",
+             trainer=trainer
+         )
+
+     def on_train_end(
+         self,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+
+         self._execute_callbacks(
+             "on_train_end",
+             trainer=trainer
+         )
+
+     def on_epoch_start(
+         self,
+         trainer: "trainer.PytorchTrainer",
+         epoch_step: int
+     ) -> None:
+
+         self._execute_callbacks(
+             "on_epoch_start",
+             trainer=trainer,
+             epoch_step=epoch_step
+         )
+
+     def on_epoch_end(
+         self,
+         trainer: "trainer.PytorchTrainer",
+         epoch_step: int,
+         epoch_train_loss: float,
+         epoch_val_loss: Optional[float],
+     ) -> None:
+
+         self._execute_callbacks(
+             "on_epoch_end",
+             trainer=trainer,
+             epoch_step=epoch_step,
+             epoch_train_loss=epoch_train_loss,
+             epoch_val_loss=epoch_val_loss
+         )
+
+     def on_batch_start(
+         self,
+         batch_step: int,
+         batch_data: torch.Tensor,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+
+         self._execute_callbacks(
+             "on_batch_start",
+             trainer=trainer,
+             batch_step=batch_step,
+             batch_data=batch_data
+         )
+
+     def on_batch_end(
+         self,
+         batch_step: int,
+         batch_data: torch.Tensor,
+         batch_train_loss: float,
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+
+         self._execute_callbacks(
+             "on_batch_end",
+             trainer=trainer,
+             batch_step=batch_step,
+             batch_data=batch_data,
+             batch_train_loss=batch_train_loss
+         )
+
+     def _execute_callbacks(
+         self,
+         method_name: Literal[
+             "on_train_start",
+             "on_train_end",
+             "on_epoch_start",
+             "on_epoch_end",
+             "on_batch_start",
+             "on_batch_end"],
+         *args, **kwargs
+     ) -> None:
+
+         if self.callbacks is None:
+             return
+
+         for callback in self.callbacks:
+             # Call the hook named by method_name on each registered callback
+             getattr(callback, method_name)(*args, **kwargs)
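For orientation: every hook on Callback is a no-op, and _CallbackManager dispatches by name via getattr, passing everything as keyword arguments, so a user callback only overrides the hooks it needs. A minimal sketch of such a callback, assuming the PytorchTrainer defined later in this diff is the trainer passed in; the BatchLossLogger name is hypothetical:

from collie.helper.pytorch.callback.callback import Callback

class BatchLossLogger(Callback):
    # Only the overridden hook does any work; the inherited hooks stay no-ops.
    def on_batch_end(self, batch_step, batch_data, batch_train_loss, trainer) -> None:
        if batch_step % 100 == 0:
            print(f"batch {batch_step}: running train loss {batch_train_loss:.4f}")

# Passed to the trainer alongside the default ModelCheckpoint/EarlyStopping callbacks, e.g.
# PytorchTrainer(model=model, epochs=3, device=None, use_amp=False,
#                topk_checkpoints=1, callbacks=[BatchLossLogger()])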
@@ -0,0 +1,54 @@
+ from typing import Optional
+
+ from collie.helper.pytorch.callback.callback import Callback
+ from collie import trainer
+ from collie._common.utils import get_logger
+
+ logger = get_logger()
+
+
+ class EarlyStopping(Callback):
+
+     def __init__(
+         self,
+         patience_on_epoch: int,
+         delta: float = 0.0,
+         monitor: str = "val_loss"
+     ):
+         super().__init__()
+
+         self.patience_on_epoch = patience_on_epoch
+         self.delta = delta
+         self.monitor = monitor
+         self.best_score = float('inf')
+         self.wait = 0
+         self.stopped_epoch = 0
+         self.early_stop = False
+
+     def on_epoch_end(
+         self,
+         epoch_step: int,
+         epoch_train_loss: float,
+         epoch_val_loss: Optional[float],
+         trainer: "trainer.PytorchTrainer"
+     ) -> None:
+
+         if self.monitor == "val_loss" and epoch_val_loss is not None:
+             score = epoch_val_loss
+         elif self.monitor == "train_loss":
+             score = epoch_train_loss
+         else:
+             return
+
+         if self.best_score - score > self.delta:
+             self.best_score = score
+             self.wait = 0
+         else:
+             self.wait += 1
+
+         if self.wait >= self.patience_on_epoch:
+             self.early_stop = True
+             self.stopped_epoch = epoch_step
+             trainer.should_stop = True
+             logger.info(f"Epoch {epoch_step}: Early stopping triggered (patience {self.patience_on_epoch} epochs).")
@@ -0,0 +1,100 @@
+ from typing import Optional
+ from glob import glob
+ import os
+ import shutil
+
+ import torch
+
+ from collie.helper.pytorch.callback.callback import Callback
+ from collie._common.utils import get_logger
+ from collie import trainer
+
+ logger = get_logger()
+
+ class ModelCheckpoint(Callback):
+
+     def __init__(self, topk_checkpoints: int):
+         super().__init__()
+         self.topk_checkpoints = topk_checkpoints
+         self._best_checkpoints = []  # stored as (loss, epoch_idx, checkpoint_path)
+
+         self.parent_dir = "./.checkpoint/"
+         if os.path.exists(self.parent_dir):
+             shutil.rmtree(self.parent_dir)
+             logger.info(f"Directory {self.parent_dir} has been removed.")
+
+     def on_epoch_end(
+         self,
+         trainer: "trainer.PytorchTrainer",
+         epoch_step: int,
+         epoch_train_loss: float,
+         epoch_val_loss: Optional[float]
+     ) -> None:
+
+         current_loss = epoch_val_loss if epoch_val_loss is not None else epoch_train_loss
+         if self._should_save_checkpoint(current_loss):
+             self._save_checkpoint(
+                 trainer=trainer,
+                 epoch_step=epoch_step,
+                 epoch_loss=epoch_train_loss,
+                 loss_for_ckpt=current_loss
+             )
+
+     def on_train_end(self, trainer):
+         for ckpt in glob(f"{self.parent_dir}/*.pt"):
+             trainer.log_artifact(ckpt, "checkpoints")
+
+     def _should_save_checkpoint(self, current_loss: float) -> bool:
+         """
+         Returns a boolean indicating whether the current checkpoint should be saved.
+
+         The current checkpoint is saved if the topk_checkpoints list is not yet full.
+         Otherwise, it is saved only if its loss is better than the loss of the worst
+         checkpoint already in the topk_checkpoints list.
+
+         Args:
+             current_loss (float): The loss of the current checkpoint.
+
+         Returns:
+             bool: True if the current checkpoint should be saved, False otherwise.
+         """
+         if len(self._best_checkpoints) < self.topk_checkpoints:
+             # The top-k list is not full yet, so always save.
+             return True
+
+         worst_loss = max(self._best_checkpoints, key=lambda x: x[0])[0]
+         # Save only if the current loss beats the worst checkpoint currently kept.
+         return current_loss < worst_loss
+
+     def _save_checkpoint(
+         self,
+         trainer: "trainer.PytorchTrainer",
+         epoch_step: int,
+         epoch_loss: float,
+         loss_for_ckpt: float
+     ) -> None:
+
+         checkpoint_path = self.parent_dir + f'model_epoch{epoch_step}.pt'
+
+         if not os.path.exists(self.parent_dir):
+             os.makedirs(self.parent_dir)
+
+         # https://zhuanlan.zhihu.com/p/136902153
+         checkpoint = {
+             "model_state_dict": trainer.model.state_dict(),
+             "optimizer_state_dict": trainer.optimizer.state_dict(),
+             "scheduler_state_dict": trainer.lr_scheduler.state_dict() if trainer.lr_scheduler else None,
+             "epoch": epoch_step,
+             "loss": epoch_loss
+         }
+
+         if len(self._best_checkpoints) < self.topk_checkpoints:
+             self._best_checkpoints.append((loss_for_ckpt, epoch_step, checkpoint_path))
+         else:
+             worst_idx, (worst_loss, _, worst_path) = max(enumerate(self._best_checkpoints), key=lambda x: x[1][0])
+
+             os.remove(worst_path)
+             self._best_checkpoints[worst_idx] = (loss_for_ckpt, epoch_step, checkpoint_path)
+
+         self._best_checkpoints.sort(key=lambda x: x[0])
+         torch.save(checkpoint, checkpoint_path)
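The top-k logic above keeps at most topk_checkpoints (loss, epoch, path) tuples, evicting the entry with the worst loss when a better one arrives and logging whatever remains under "checkpoints" at the end of training. A tiny sketch of just that selection policy, detached from the file I/O (loss values are illustrative):

topk = 2
best = []  # (loss, epoch)
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    if len(best) < topk:
        best.append((loss, epoch))
    elif loss < max(best)[0]:          # beats the worst kept checkpoint
        best.remove(max(best))
        best.append((loss, epoch))
    best.sort()
print(best)   # [(0.6, 4), (0.7, 2)]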
File without changes
@@ -0,0 +1,55 @@
+ from typing import Optional, Dict, Union
+ from collections import OrderedDict
+
+ import torch
+
+
+ def load_pytorch_model_checkpoint(
+     ckpt_path: str,
+     model: torch.nn.Module,
+     optimizer: Optional[torch.optim.Optimizer] = None,
+     lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
+     is_ddp_model: bool = False
+ ) -> Dict[str, Union[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, int]]:
+     """
+     Loads a PyTorch model checkpoint and optionally restores the optimizer and learning rate scheduler states.
+
+     Args:
+         ckpt_path (str): Path to the checkpoint file.
+         model (torch.nn.Module): The model to load the state dict into.
+         optimizer (Optional[torch.optim.Optimizer], optional): Optimizer whose state should be restored. Default is None.
+         lr_scheduler (Optional[torch.optim.lr_scheduler.LRScheduler], optional): Learning rate scheduler whose state should be restored. Default is None.
+         is_ddp_model (bool, optional): Whether the model was wrapped in DistributedDataParallel when the checkpoint was saved. Default is False.
+
+     Returns:
+         Dict[str, Union[torch.nn.Module, torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, int]]:
+             A dictionary containing the restored model, optimizer, learning rate scheduler (if any), and the epoch.
+     """
+
+     res_dict = dict()
+
+     checkpoint = torch.load(ckpt_path)
+     state_dict = checkpoint['model_state_dict']
+
+     # Handle DDP (DistributedDataParallel) models
+     if is_ddp_model:
+         new_state_dict = OrderedDict()
+         for k, v in state_dict.items():
+             # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel
+             name = k[7:]
+             new_state_dict[name] = v
+         state_dict = new_state_dict
+
+     model.load_state_dict(state_dict, strict=False)
+     res_dict["model"] = model.eval()
+     res_dict["epoch"] = checkpoint.get("epoch", "")
+
+     if optimizer:
+         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+         res_dict["optimizer"] = optimizer
+
+     # ModelCheckpoint writes the scheduler state under "scheduler_state_dict" (possibly None)
+     if checkpoint.get('scheduler_state_dict') is not None and lr_scheduler:
+         lr_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+         res_dict["lr_scheduler"] = lr_scheduler
+
+     return res_dict
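Used together with the ModelCheckpoint callback above, a resume might look like the following sketch; the checkpoint path is only an example of the model_epoch{N}.pt naming scheme, and MyModel is a placeholder for whatever architecture was actually trained:

import torch
from collie.helper.pytorch.model.loader import load_pytorch_model_checkpoint

model = MyModel()                                    # same architecture as when the checkpoint was written
optimizer = torch.optim.AdamW(model.parameters())    # state is overwritten from the checkpoint

restored = load_pytorch_model_checkpoint(
    ckpt_path="./.checkpoint/model_epoch3.pt",       # illustrative path
    model=model,
    optimizer=optimizer,
    is_ddp_model=False,
)
model, optimizer = restored["model"], restored["optimizer"]
start_epoch = restored["epoch"]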
@@ -0,0 +1,304 @@
+ from typing import (
+     Optional,
+     Tuple,
+     Dict,
+     Any,
+     List,
+ )
+ from abc import abstractmethod
+
+
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from torch import nn
+ from torch.amp import GradScaler, autocast
+ import torch
+
+ from collie.trainer.trainer import Trainer
+ from collie._common.decorator import type_checker
+ from collie._common.utils import get_logger
+ from collie.helper.pytorch.callback.callback import Callback, _CallbackManager
+ from collie.helper.pytorch.callback.model_checkpoint import ModelCheckpoint
+ from collie.helper.pytorch.callback.earlystop import EarlyStopping
+ from collie.contracts.event import Event
+ from collie.core.types import TrainerPayload, TransformerPayload
+
+
+ # TODO: Features to develop:
+ # 1. Train with multiple GPUs
+ logger = get_logger()
+
+
+ class _AbstractPytorchTrainer:
+
+     @abstractmethod
+     def configure_optimizers(
+         self
+     ) -> Tuple[
+         torch.optim.Optimizer,
+         Optional[torch.optim.lr_scheduler._LRScheduler]
+     ]:
+         """
+         Configure the optimizer and, optionally, the learning rate scheduler.
+
+         Returns:
+             A tuple of two objects: the optimizer and the learning rate scheduler (or None if one is not needed).
+
+         Example:
+             .. code-block:: python
+
+                 def configure_optimizers(self) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
+                     optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.initial_lr)
+                     scheduler = get_cosine_schedule_with_warmup(
+                         optimizer=optimizer,
+                         num_warmup_steps=self.num_warmup_steps,
+                         num_training_steps=self.num_training_steps
+                     )
+
+                     return optimizer, scheduler
+         """
+         raise NotImplementedError("Please implement the *configure_optimizers* method.")
+
+     @type_checker(
+         (DataLoader,),
+         "The train_data in TransformerPayload should be of type DataLoader in PytorchTrainer."
+     )
+     def get_train_dataloader(self, event: Event[TransformerPayload]) -> DataLoader:
+         transformer_payload = event.payload
+         return transformer_payload.train_data
+
+     @type_checker(
+         (DataLoader, None),
+         "The validation_data in TransformerPayload should be of type DataLoader in PytorchTrainer."
+     )
+     def get_val_dataloader(self, event: Event[TransformerPayload]) -> DataLoader:
+         transformer_payload = event.payload
+         return transformer_payload.validation_data
+
+
+ class PytorchTrainer(_AbstractPytorchTrainer):
+     def __init__(
+         self,
+         model: nn.Module,
+         epochs: int,
+         device: Optional[str],
+         use_amp: Optional[bool],
+         topk_checkpoints: Optional[int],
+         earlystop_patience_on_epoch: Optional[int] = None,
+         accumulate_grad_batches: int = 1,
+         callbacks: Optional[List[Callback]] = None,
+     ) -> None:
+
+         super().__init__()
+
+         self.model = model
+
+         self.epochs = epochs
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.use_amp = use_amp
+         self.topk_checkpoints = 3 if topk_checkpoints is None else topk_checkpoints
+         self.earlystop_patience_on_epoch = earlystop_patience_on_epoch
+
+         DEFAULT_CALLBACK = [ModelCheckpoint(topk_checkpoints=self.topk_checkpoints)]
+         if self.earlystop_patience_on_epoch:
+             DEFAULT_CALLBACK.append(EarlyStopping(self.earlystop_patience_on_epoch, delta=0.0))
+
+         self.callbacks = [] if callbacks is None else callbacks
+         self.callbacks = self.callbacks + DEFAULT_CALLBACK
+         self.accumulate_grad_batches = accumulate_grad_batches
+
+         self.train_data_loader = None
+         self.val_data_loader = None
+         self.optimizer = None
+         self.lr_scheduler = None
+         self.grad_scaler = None
+         self.should_stop = False
+         self.cb_manager = _CallbackManager(callbacks=self.callbacks)
+
+         self.model = self.model.to(self.device)
+
+     def run(
+         self,
+         event: Event[TransformerPayload]
+     ) -> Event[TrainerPayload]:
+
+         self.train_data_loader = self.get_train_dataloader(event=event)
+         self.val_data_loader = self.get_val_dataloader(event=event)
+
+         self.optimizer, self.lr_scheduler = self.configure_optimizers()
+
+         trainstep_per_epoch = len(self.train_data_loader)
+         total_train_step = trainstep_per_epoch * self.epochs
+
+         self.log_param("epoch", self.epochs)
+         self.log_param("batch size", self.train_data_loader.batch_size)
+         self.log_param("total training step", total_train_step)
+
+         self.cb_manager.on_train_start(self)
+         self.optimizer.zero_grad()
+         # TODO: the train loop is written by the user in the handle method.
+         for epoch_idx in tqdm(range(1, self.epochs + 1)):
+             # Set by the EarlyStopping callback
+             if self.should_stop:
+                 break
+
+             self.cb_manager.on_epoch_start(self, epoch_idx)
+             running_loss = 0
+             self.model.train()
+             for batch_idx, batch_data in enumerate(self.train_data_loader, start=1):
+                 self.cb_manager.on_batch_start(
+                     batch_data=batch_data,
+                     batch_step=batch_idx,
+                     trainer=self
+                 )
+
+                 loss = self._get_train_step_result(epoch_idx, batch_data) / self.accumulate_grad_batches
+                 running_loss += loss.item() * self.accumulate_grad_batches
+
+                 self._backward(loss=loss, batch_idx=batch_idx)
+                 # The learning rate scheduler is stepped at batch level.
+                 if self.lr_scheduler and batch_idx % self.accumulate_grad_batches == 0:
+                     self.lr_scheduler.step()
+
+                 self.cb_manager.on_batch_end(
+                     batch_step=batch_idx,
+                     batch_data=batch_data,
+                     batch_train_loss=running_loss,
+                     trainer=self
+                 )
+
+             epoch_loss = running_loss / trainstep_per_epoch
+
+             self.log_metric("train loss", epoch_loss, step=epoch_idx)
+             self._log_lr(epoch_step=epoch_idx)
+
+             epoch_val_loss = self.validation_loop(epoch_step=epoch_idx)
+
+             self.cb_manager.on_epoch_end(
+                 trainer=self,
+                 epoch_step=epoch_idx,
+                 epoch_train_loss=epoch_loss,
+                 epoch_val_loss=epoch_val_loss
+             )
+
+         # TODO: Fix the issue of loading the model from the following result model.
+         self.log_model(self.model, model_type="pt")
+
+         self.cb_manager.on_train_end(self)
+
+         # return {"model": self.model, "loss": epoch_loss}
+         trainer_payload = TrainerPayload(
+             model=self.model,
+             train_loss=epoch_loss,
+             val_loss=epoch_val_loss
+         )
+         return Event(
+             payload=trainer_payload
+         )
+
+     def validation_loop(self, epoch_step: int) -> Optional[float]:
+
+         if self.val_data_loader is None:
+             return None
+
+         valstep_per_epoch = len(self.val_data_loader)
+
+         self.model.eval()
+         val_running_loss = 0
+         with torch.no_grad():
+
+             for val_batch_data in self.val_data_loader:
+
+                 val_loss = self.validation_step(epoch_step, val_batch_data)
+                 val_running_loss += val_loss.item()
+
+         epoch_val_loss = val_running_loss / valstep_per_epoch
+
+         self.log_metric("val_loss", epoch_val_loss, step=epoch_step)
+         return epoch_val_loss
+
+     def _backward(
+         self,
+         loss: torch.Tensor,
+         batch_idx: int
+     ) -> None:
+
+         def should_step_optimizer(batch_idx: int) -> bool:
+             return batch_idx % self.accumulate_grad_batches == 0 or batch_idx == len(self.train_data_loader)
+
+         if self.use_amp:
+             if self.grad_scaler is None:
+                 self.grad_scaler = GradScaler(device=self.device)
+
+             self.grad_scaler.scale(loss).backward()
+             if should_step_optimizer(batch_idx):
+
+                 if self._has_invalid_gradients():
+                     # gradient clipping
+                     self.grad_scaler.unscale_(self.optimizer)
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+                 self.grad_scaler.step(self.optimizer)
+                 self.grad_scaler.update()
+                 self.optimizer.zero_grad()
+         else:
+             loss.backward()
+             if should_step_optimizer(batch_idx):
+                 if self._has_invalid_gradients():
+                     # gradient clipping
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+                 self.optimizer.step()
+                 self.optimizer.zero_grad()
+
+     def _has_invalid_gradients(self) -> bool:
+
+         for param in self.model.parameters():
+             if param.grad is not None:
+                 if torch.isnan(param.grad).any() or torch.isinf(param.grad).any():
+                     logger.warning(f"Detected NaN or Inf gradients for parameter {param}. Zeroing out the gradients.")
+                     # side effect
+                     param.grad.zero_()
+                     return True
+         return False
+
+     def _log_lr(self, epoch_step: int) -> None:
+
+         lr = self.optimizer.param_groups[0]['lr'] if not self.lr_scheduler else self.lr_scheduler.get_last_lr()[0]
+         self.log_metric("learning rate", lr, step=epoch_step)
+
+     @type_checker(
+         (torch.Tensor,),
+         "The return type of the *train_step* method must be 'Tensor'."
+     )
+     def _get_train_step_result(
+         self,
+         epoch_step: int,
+         batch_data: torch.Tensor
+     ):
+
+         if self.use_amp:
+             with autocast(device_type=self.device):
+                 loss = self.handle(epoch_step, batch_data)
+         else:
+             loss = self.handle(epoch_step, batch_data)
+
+         return loss
+
+     def _get_optimizers(
+         self
+     ) -> Tuple[torch.optim.Optimizer, Optional[torch.optim.lr_scheduler.LRScheduler]]:
+         """
+         Get the optimizer and learning rate scheduler from the
+         *configure_optimizers* method.
+         """
+         result = self.configure_optimizers()
+         if isinstance(result, tuple):
+             optimizers, lr_scheduler = result
+             if not isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
+                 raise TypeError(
+                     "learning rate scheduler must be of type "
+                     "torch.optim.lr_scheduler._LRScheduler"
+                 )
+         else:
+             optimizers, lr_scheduler = result, None
+         return optimizers, lr_scheduler
+
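Putting the trainer together: a subclass supplies configure_optimizers plus the per-batch handle (and validation_step when a validation loader is provided), while the loop above takes care of AMP, gradient accumulation, checkpointing, and early stopping. A hedged sketch, with the handle/validation_step signatures inferred only from how PytorchTrainer calls them and the loss computation purely illustrative:

import torch
from torch import nn
from collie.helper.pytorch.trainer import PytorchTrainer

class MyTrainer(PytorchTrainer):

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-3)
        return optimizer, None                       # no LR scheduler in this sketch

    def handle(self, epoch_step, batch_data):
        # batch_data is one element of the train DataLoader; must return a loss tensor
        inputs, targets = batch_data
        outputs = self.model(inputs.to(self.device))
        return nn.functional.cross_entropy(outputs, targets.to(self.device))

    def validation_step(self, epoch_step, batch_data):
        inputs, targets = batch_data
        outputs = self.model(inputs.to(self.device))
        return nn.functional.cross_entropy(outputs, targets.to(self.device))

Note that run() also calls self.log_param, self.log_metric, self.log_model, and (via ModelCheckpoint) self.log_artifact; those are not defined in this file, so the sketch assumes they are supplied elsewhere in the package, presumably by the MLflow-facing Trainer that this module imports from collie.trainer.trainer.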