epoch-engine 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- epoch_engine-0.1.0/LICENSE.txt +21 -0
- epoch_engine-0.1.0/PKG-INFO +68 -0
- epoch_engine-0.1.0/README.md +41 -0
- epoch_engine-0.1.0/epoch_engine/__init__.py +0 -0
- epoch_engine-0.1.0/epoch_engine/core/__init__.py +0 -0
- epoch_engine-0.1.0/epoch_engine/core/checkpoint_handler.py +40 -0
- epoch_engine-0.1.0/epoch_engine/core/metrics_tracker.py +100 -0
- epoch_engine-0.1.0/epoch_engine/core/trainer.py +294 -0
- epoch_engine-0.1.0/epoch_engine/models/__init__.py +0 -0
- epoch_engine-0.1.0/epoch_engine/models/architectures.py +313 -0
- epoch_engine-0.1.0/epoch_engine.egg-info/PKG-INFO +68 -0
- epoch_engine-0.1.0/epoch_engine.egg-info/SOURCES.txt +15 -0
- epoch_engine-0.1.0/epoch_engine.egg-info/dependency_links.txt +1 -0
- epoch_engine-0.1.0/epoch_engine.egg-info/requires.txt +13 -0
- epoch_engine-0.1.0/epoch_engine.egg-info/top_level.txt +1 -0
- epoch_engine-0.1.0/pyproject.toml +32 -0
- epoch_engine-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Sergey Polivin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: epoch-engine
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Trainer and evaluator for PyTorch models with a focus on simplicity and flexibility.
|
|
5
|
+
Author-email: Sergey Polivin <s.polivin@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Keywords: pytorch,deep-learning
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE.txt
|
|
15
|
+
Requires-Dist: torch==2.5.0
|
|
16
|
+
Requires-Dist: torchvision==0.20.0
|
|
17
|
+
Requires-Dist: tqdm==4.67.0
|
|
18
|
+
Provides-Extra: build
|
|
19
|
+
Requires-Dist: setuptools; extra == "build"
|
|
20
|
+
Requires-Dist: wheel; extra == "build"
|
|
21
|
+
Requires-Dist: build; extra == "build"
|
|
22
|
+
Requires-Dist: twine; extra == "build"
|
|
23
|
+
Provides-Extra: linters
|
|
24
|
+
Requires-Dist: black; extra == "linters"
|
|
25
|
+
Requires-Dist: isort; extra == "linters"
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
|
|
28
|
+
# Epoch Engine - Python Library for training PyTorch models
|
|
29
|
+
|
|
30
|
+
This project represents my attempt to come up with a convenient way to train neural nets coded in Torch. While being aware of already existing libraries for training PyTorch models (e.g. PyTorch Lightning), my idea here is to make training of the models more visual and understandable as to what is going on during training.
|
|
31
|
+
|
|
32
|
+
The project is currently in its raw form, more changes expected.
|
|
33
|
+
|
|
34
|
+
## Features
|
|
35
|
+
|
|
36
|
+
* TQDM-Progress bar support for both training and validation loops
|
|
37
|
+
* Intermediate metrics computations after each forward pass (currently it is based on computing loss and accuracy only)
|
|
38
|
+
* Saving/loading checkpoints from/into Trainer directly without having to touch model, optimizer or scheduler separately
|
|
39
|
+
* Resuming training from the loaded checkpoint with epoch number being remembered automatically to avoid having to remember from which epoch the training originally started
|
|
40
|
+
* Ready-to-use neural net architectures coded from scratch (currently only 4-layer Encoder-Decoder and ResNet with 20 layers architectures are available)
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
After cloning this repo, the package can be installed in the development mode as follows:
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
# Installing the main package
|
|
48
|
+
pip install -e .
|
|
49
|
+
|
|
50
|
+
# Installing additional optional dependencies
|
|
51
|
+
pip install -e .[build,linters]
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Python API
|
|
55
|
+
|
|
56
|
+
The basics of the developed API are presented in the [test script](./run_trainer.py) I built. It can be run for instance as follows:
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
python run_trainer.py --model=resnet --epochs=3 --batch-size=16
|
|
60
|
+
```
|
|
61
|
+
> The training will be launched on the device automatically derived based on the CUDA availability and the final training checkpoint will be saved in `checkpoints` directory.
|
|
62
|
+
|
|
63
|
+
One can also resume the training from the saved checkpoint:
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
python run_trainer.py --model=resnet --epochs=4 --resume-training=True --ckpt-path=checkpoints/ckpt_3.pt
|
|
67
|
+
```
|
|
68
|
+
> The training will be resumed from the loaded checkpoint with TQDM-progress bar showing the next training epoch.
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Epoch Engine - Python Library for training PyTorch models
|
|
2
|
+
|
|
3
|
+
This project represents my attempt to come up with a convenient way to train neural nets coded in Torch. While being aware of already existing libraries for training PyTorch models (e.g. PyTorch Lightning), my idea here is to make training of the models more visual and understandable as to what is going on during training.
|
|
4
|
+
|
|
5
|
+
The project is currently in its raw form, more changes expected.
|
|
6
|
+
|
|
7
|
+
## Features
|
|
8
|
+
|
|
9
|
+
* TQDM-Progress bar support for both training and validation loops
|
|
10
|
+
* Intermediate metrics computations after each forward pass (currently it is based on computing loss and accuracy only)
|
|
11
|
+
* Saving/loading checkpoints from/into Trainer directly without having to touch model, optimizer or scheduler separately
|
|
12
|
+
* Resuming training from the loaded checkpoint with epoch number being remembered automatically to avoid having to remember from which epoch the training originally started
|
|
13
|
+
* Ready-to-use neural net architectures coded from scratch (currently only 4-layer Encoder-Decoder and ResNet with 20 layers architectures are available)
|
|
14
|
+
|
|
15
|
+
## Installation
|
|
16
|
+
|
|
17
|
+
After cloning this repo, the package can be installed in the development mode as follows:
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
# Installing the main package
|
|
21
|
+
pip install -e .
|
|
22
|
+
|
|
23
|
+
# Installing additional optional dependencies
|
|
24
|
+
pip install -e .[build,linters]
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
## Python API
|
|
28
|
+
|
|
29
|
+
The basics of the developed API are presented in the [test script](./run_trainer.py) I built. It can be run for instance as follows:
|
|
30
|
+
|
|
31
|
+
```bash
|
|
32
|
+
python run_trainer.py --model=resnet --epochs=3 --batch-size=16
|
|
33
|
+
```
|
|
34
|
+
> The training will be launched on the device automatically derived based on the CUDA availability and the final training checkpoint will be saved in `checkpoints` directory.
|
|
35
|
+
|
|
36
|
+
One can also resume the training from the saved checkpoint:
|
|
37
|
+
|
|
38
|
+
```bash
|
|
39
|
+
python run_trainer.py --model=resnet --epochs=4 --resume-training=True --ckpt-path=checkpoints/ckpt_3.pt
|
|
40
|
+
```
|
|
41
|
+
> The training will be resumed from the loaded checkpoint with TQDM-progress bar showing the next training epoch.
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
"""Module for handling saving/loading of checkpoints."""
|
|
2
|
+
|
|
3
|
+
# Author: Sergey Polivin <s.polivin@gmail.com>
|
|
4
|
+
# License: MIT License
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CheckpointHandler:
    """Class for saving and loading trainer checkpoints."""

    def save_checkpoint(self, trainer, path: str) -> None:
        """Saves the checkpoint: last trained epoch, model state and optimizer state.

        Args:
            trainer: Trainer instance.
            path (str): Path of the checkpoint file to which data are to be saved.
        """
        torch.save(
            {
                "epoch": trainer.last_epoch,
                "model_state_dict": trainer.model.state_dict(),
                "optimizer_state_dict": trainer.optimizer.state_dict(),
            },
            path,
        )

    def load_checkpoint(self, trainer, path: str) -> None:
        """Loads the checkpoint: last trained epoch, model state and optimizer state.

        Args:
            trainer: Trainer instance.
            path (str): Path of the checkpoint file from which data are to be loaded.
        """
        # Remapping tensors onto the trainer's device so that a checkpoint
        # saved on GPU can be restored on a CPU-only machine (and vice versa).
        # `getattr` keeps this working for trainer-like objects without `device`.
        map_location = getattr(trainer, "device", None)
        # Loading the checkpoint file (weights_only guards against pickled code)
        checkpoint = torch.load(path, weights_only=True, map_location=map_location)
        # Loading the data into the Trainer instance attributes
        trainer.last_epoch = checkpoint["epoch"]
        trainer.model.load_state_dict(checkpoint["model_state_dict"])
        trainer.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Module for handling tracking and computation of metrics."""
|
|
2
|
+
|
|
3
|
+
# Author: Sergey Polivin <s.polivin@gmail.com>
|
|
4
|
+
# License: MIT License
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MetricsTracker:
    """Class for keeping track of metrics."""

    def __init__(self) -> None:
        """Initializes a class instance.

        Attributes:
            metrics (dict[str, list[float]]): Batch-level metrics in a dict format.
            total_correct_train (int): Number of correct predictions in train set.
            total_samples_train (int): Total number of examples in train set.
            total_correct_valid (int): Number of correct predictions in valid set.
            total_samples_valid (int): Total number of examples in valid set.
        """
        self.reset()

    def reset(self) -> None:
        """Resets all metrics and counters."""
        self.metrics = {}
        self.total_correct_train = 0
        self.total_samples_train = 0
        self.total_correct_valid = 0
        self.total_samples_valid = 0

    def update(self, name: str, value: float) -> None:
        """Updates a specific metric.

        Args:
            name (str): Metric name.
            value (float): Value of a metric.
        """
        # Lazily creating the per-metric list on first update
        self.metrics.setdefault(name, []).append(value)

    def update_accuracy(self, correct: int, total: int, split: str) -> None:
        """Updates accuracy counters.

        Args:
            correct (int): Number of correct predictions.
            total (int): Number of total examples.
            split (str): Train or valid split indicator.

        Raises:
            ValueError: Error in case some other split name is passed.
        """
        if split == "train":
            self.total_correct_train += correct
            self.total_samples_train += total
        elif split == "valid":
            self.total_correct_valid += correct
            self.total_samples_valid += total
        else:
            raise ValueError(f"Unknown split {split!r}; expected 'train' or 'valid'")

    def compute_accuracy(self, split: str) -> float:
        """Computes accuracy across all batches.

        Args:
            split (str): Train or valid split indicator.

        Raises:
            ValueError: Error in case some other split name is passed.

        Returns:
            float: Accuracy score (0.0 when no samples were seen for the split).
        """
        if split == "train":
            correct, total = self.total_correct_train, self.total_samples_train
        elif split == "valid":
            correct, total = self.total_correct_valid, self.total_samples_valid
        else:
            raise ValueError(f"Unknown split {split!r}; expected 'train' or 'valid'")
        # Guarding against division by zero per split: previously the guard
        # checked BOTH splits, so train accuracy was wrongly reported as 0.0
        # whenever the valid counters were still empty (and vice versa).
        return correct / total if total else 0.0

    def get_all_metrics(self) -> dict[str, float]:
        """Computes all metrics as averages.

        Returns:
            dict[str, float]: Epoch-level metrics.
        """
        # Aggregating loss across batches
        metrics = {
            name: sum(values) / len(values) for name, values in self.metrics.items()
        }
        # Computing accuracy scores
        metrics["accuracy/train"] = self.compute_accuracy(split="train")
        metrics["accuracy/valid"] = self.compute_accuracy(split="valid")

        return metrics
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
"""Module for Trainer functionality."""
|
|
2
|
+
|
|
3
|
+
# Author: Sergey Polivin <s.polivin@gmail.com>
|
|
4
|
+
# License: MIT License
|
|
5
|
+
|
|
6
|
+
from typing import Any, TypeAlias
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from .checkpoint_handler import CheckpointHandler
|
|
12
|
+
from .metrics_tracker import MetricsTracker
|
|
13
|
+
|
|
14
|
+
TorchDataloader: TypeAlias = torch.utils.data.dataloader.DataLoader
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Trainer:
    """Trainer of PyTorch models."""

    def __init__(
        self,
        model: Any,
        criterion: Any,
        optimizer: Any,
        train_loader: TorchDataloader,
        valid_loader: TorchDataloader,
        train_on: str = "auto",
    ) -> None:
        """Initializes a class instance.

        Args:
            model (Any): PyTorch model instance.
            criterion (Any): PyTorch loss function instance.
            optimizer (Any): PyTorch optimizer instance.
            train_loader (TorchDataloader): Torch Dataloader for training set.
            valid_loader (TorchDataloader): Torch Dataloader for validation set.
            train_on (str, optional): Indicator of CPU- or GPU-based training. Defaults to "auto".

        Attributes:
            model (Any): PyTorch model instance (moved to `device`).
            criterion (Any): PyTorch loss function instance.
            optimizer (Any): PyTorch optimizer instance (re-created over the moved parameters).
            train_loader (TorchDataloader): Torch Dataloader for training set.
            valid_loader (TorchDataloader): Torch Dataloader for validation set.
            device (torch.device): Device to be used to train the model.
            scheduler (Any): Instance of PyTorch scheduler for learning rate. Defaults to None.
            scheduler_level (str | None): Level on which to apply scheduler ("epoch" or "batch"). Defaults to None.
            last_epoch (int): Last epoch when training has been successfully finished. Defaults to 0.
            metrics_tracker (MetricsTracker): Class for handling computing metrics.
            ckpt_handler (CheckpointHandler): Class for saving/loading checkpoints.
        """
        # Using the CPU or GPU device or inferring one
        self.device = (
            torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if train_on == "auto"
            else torch.device(train_on)
        )
        # Moving model to the device specified
        self.model = model.to(self.device)
        # Reinitializing the optimizer with model parameters having been moved to device.
        # NOTE: this relies on `optimizer.defaults`, so per-parameter-group options
        # set on the passed-in optimizer are not preserved.
        self.optimizer = type(optimizer)(self.model.parameters(), **optimizer.defaults)

        # Setting loss function
        self.criterion = criterion

        # Setting dataloaders for training/validation sets
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # Setting scheduler-related attributes
        self.scheduler = None
        self.scheduler_level = None

        # Setting epoch indicator
        self.last_epoch = 0

        # Setting tracker for metrics and checkpoints handler
        self.metrics_tracker = MetricsTracker()
        self.ckpt_handler = CheckpointHandler()

    def set_scheduler(
        self,
        scheduler_class: Any,
        scheduler_params: dict[str, Any],
        level: str = "epoch",
    ) -> None:
        """Sets and initializes a scheduler.

        Args:
            scheduler_class (Any): Scheduler class from PyTorch.
            scheduler_params (dict[str, Any]): Scheduler parameters.
            level (str, optional): Level on which to apply scheduler. Defaults to "epoch".

        Raises:
            ValueError: Error raised if the specified level name is unexpected.
        """
        # Setting scheduler level on batch- or epoch-level
        if level not in ("epoch", "batch"):
            raise ValueError(f"Invalid scheduler level {level!r}; expected 'epoch' or 'batch'")
        self.scheduler_level = level
        # Initializing a scheduler
        self.scheduler = scheduler_class(self.optimizer, **scheduler_params)

    def reset_scheduler(self) -> None:
        """Resetting and removing scheduler."""
        self.scheduler = None
        self.scheduler_level = None

    def __call__(
        self, x_batch: torch.Tensor, y_batch: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Computes loss and model outputs.

        Args:
            x_batch (torch.Tensor): Batch of examples.
            y_batch (torch.Tensor): Batch of labels.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Tuple containing value of loss function and output tensor.
        """
        outputs = self.model(x_batch)
        # Labels are cast to int64 as expected by classification criteria
        loss = self.criterion(outputs, y_batch.long())

        return loss, outputs

    def run(
        self,
        epochs: int,
        seed: int = 42,
        enable_tqdm: bool = True,
    ) -> None:
        """Launches the trainer.

        Args:
            epochs (int): Number of epochs to train.
            seed (int, optional): Random number generator seed; pass None to skip
                seeding. Defaults to 42.
            enable_tqdm (bool, optional): Indicator to turn on tqdm progress bar. Defaults to True.
        """
        # Setting seed (explicit None-check: seed=0 is a valid seed and
        # previously was silently skipped by a truthiness test)
        if seed is not None:
            torch.manual_seed(seed)
        # Computing the number of total epochs to be trained (useful in case of additional trainings)
        total_epochs = epochs + self.last_epoch
        # Running training/validation loops
        for epoch in range(self.last_epoch, total_epochs):
            # Processing all training batches
            self._train_one_epoch(
                epoch=epoch, epochs=total_epochs, enable_tqdm=enable_tqdm
            )
            # Processing all validation batches
            self._validate_one_epoch(
                epoch=epoch, epochs=total_epochs, enable_tqdm=enable_tqdm
            )
            # Incrementing last epoch after successful training/validation
            self.last_epoch = epoch + 1
            # Resetting metrics counters after successful training/validation
            self.metrics_tracker.reset()

    def _train_one_epoch(self, epoch: int, epochs: int, enable_tqdm: bool) -> None:
        """Processes all training batches for one epoch.

        Args:
            epoch (int): Current epoch.
            epochs (int): Total number of epochs.
            enable_tqdm (bool): Indicator to turn on tqdm progress bar.
        """

        with tqdm(
            total=len(self.train_loader),
            desc=f"Epoch {epoch + 1}/{epochs} [Training]",
            position=0,
            leave=True,
            unit="batches",
            disable=not enable_tqdm,
        ) as pbar:
            # Setting the learning rate in progress bar
            pbar.set_postfix({"lr": self.optimizer.param_groups[0]["lr"]})
            # Setting model in training mode
            self.model.train()
            # Processing all training batches
            for batch in self.train_loader:
                self._process_batch(batch=batch, split="train")
                pbar.update(1)
            # Adjusting learning rate if set
            if self.scheduler and self.scheduler_level == "epoch":
                self.scheduler.step()

    def _validate_one_epoch(self, epoch: int, epochs: int, enable_tqdm: bool) -> None:
        """Processes all validation batches for one epoch.

        Args:
            epoch (int): Current epoch.
            epochs (int): Total number of epochs.
            enable_tqdm (bool): Indicator to turn on tqdm progress bar.
        """
        with tqdm(
            total=len(self.valid_loader),
            desc=f"Epoch {epoch + 1}/{epochs} [Validation]",
            position=0,
            leave=True,
            unit="batches",
            disable=not enable_tqdm,
        ) as pbar:
            # Setting model to evaluation mode
            self.model.eval()
            # Processing all validation batches (grad compute turned off)
            with torch.no_grad():
                for batch in self.valid_loader:
                    self._process_batch(batch=batch, split="valid")
                    pbar.update(1)

            # Aggregating batch-level metrics (loss, accuracy) to epoch-level
            self.metrics = self.metrics_tracker.get_all_metrics()
            # Setting the epoch-level stats to the progress bar
            pbar.set_postfix(
                {
                    "loss": (
                        round(self.metrics["loss/train"], 4),
                        round(self.metrics["loss/valid"], 4),
                    ),
                    "acc": (
                        round(self.metrics["accuracy/train"], 4),
                        round(self.metrics["accuracy/valid"], 4),
                    ),
                }
            )

    def _process_batch(self, batch: list[torch.Tensor], split: str) -> None:
        """Backpropagates model during training and collects batch-level metrics.

        Args:
            batch (list[torch.Tensor]): Examples and labels.
            split (str): Indicator of training or validation set.
        """

        # Moving examples and labels batches to device chosen
        x_batch, y_batch = batch
        x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
        # Computing loss and output tensor for the batch
        loss, outputs = self(x_batch=x_batch, y_batch=y_batch)

        # If training set, backpropagate
        if split == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Adjusting learning rate at batch-level
            if self.scheduler and self.scheduler_level == "batch":
                self.scheduler.step()

        # Collecting batch-level metrics
        self._collect_metrics(loss=loss, outputs=outputs, labels=y_batch, split=split)

    def _collect_metrics(
        self,
        loss: torch.Tensor,
        outputs: torch.Tensor,
        labels: torch.Tensor,
        split: str,
    ) -> None:
        """Records loss and accuracy.

        Args:
            loss (torch.Tensor): Value of loss function.
            outputs (torch.Tensor): Output tensor.
            labels (torch.Tensor): Batch of labels.
            split (str): Indicator of training or validation set.
        """
        # Adding batch-level loss
        loss = loss.item()
        self.metrics_tracker.update(f"loss/{split}", loss)

        # Adding batch-level accuracy
        predictions = torch.argmax(outputs, dim=1)
        correct = (predictions == labels).sum().cpu().item()
        total = labels.size(0)
        self.metrics_tracker.update_accuracy(correct, total, split=split)

    def save_checkpoint(self, path: str) -> None:
        """Saves trainer checkpoint.

        Args:
            path (str): Path to the checkpoint to be saved.
        """
        self.ckpt_handler.save_checkpoint(path=path, trainer=self)

    def load_checkpoint(self, path: str) -> None:
        """Loads trainer checkpoint.

        Args:
            path (str): Path to the checkpoint to be loaded.
        """
        self.ckpt_handler.load_checkpoint(path=path, trainer=self)
|
|
File without changes
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
"""Module containing building blocks of models and functions to build neural networks."""
|
|
2
|
+
|
|
3
|
+
# Author: Sergey Polivin <s.polivin@gmail.com>
|
|
4
|
+
# License: MIT License
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Encoder(nn.Module):
    """Stacks convolutional blocks derived from a channel specification."""

    def __init__(self, encoder_channels: tuple[int]) -> None:
        """Initializes a class instance.

        Args:
            encoder_channels (tuple[int]): Tuple of convolution channels.
        """
        super().__init__()
        # Pairing consecutive channel counts gives the (in, out) spec of each block
        channel_pairs = zip(encoder_channels, encoder_channels[1:])
        # Chaining the generated convolution blocks into one sequential module
        self.encoder_blocks = nn.Sequential(
            *(self._make_encoder_block(src, dst) for src, dst in channel_pairs)
        )

    def _make_encoder_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Creates a convolution block with ReLU activation and MaxPooling.

        Args:
            in_channels (int): Number of input channels for convolution.
            out_channels (int): Number of output channels for convolution.

        Returns:
            nn.Sequential: Convolution block of sequentially connected layers.
        """
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        return nn.Sequential(conv, nn.ReLU(), nn.MaxPool2d(kernel_size=3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass.

        Args:
            x (torch.Tensor): Input tensor for Encoder.

        Returns:
            torch.Tensor: Output tensor of Encoder.
        """
        return self.encoder_blocks(x)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class Decoder(nn.Module):
    """Stacks fully connected blocks and a final projection to the labels."""

    def __init__(self, decoder_features: tuple[int], num_labels: int) -> None:
        """Initializes a class instance.

        Args:
            decoder_features (tuple[int]): Tuple of features in the Decoder.
            num_labels (int): Number of output labels in the last layer.
        """
        super().__init__()
        # Consecutive feature sizes define the (in, out) spec of every hidden block
        feature_pairs = zip(decoder_features, decoder_features[1:])
        self.decoder_blocks = nn.Sequential(
            *(self._make_decoder_block(src, dst) for src, dst in feature_pairs)
        )
        # Final projection from the last hidden width onto the label space
        self.last = nn.Linear(decoder_features[-1], num_labels)

    def _make_decoder_block(self, in_features: int, out_features: int) -> nn.Sequential:
        """Creates a fully connected linear layer with Sigmoid activation.

        Args:
            in_features (int): Number of input features of the Decoder.
            out_features (int): Number of output features of the Decoder.

        Returns:
            nn.Sequential: Decoder block of sequentially connected layers.
        """
        return nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass.

        Args:
            x (torch.Tensor): Input tensor for Decoder.

        Returns:
            torch.Tensor: Output tensor of Decoder.
        """
        hidden = self.decoder_blocks(x)
        return self.last(hidden)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class EDNet(nn.Module):
    """Joins Encoder with Decoder to form a CNN."""

    def __init__(
        self,
        in_channels: int,
        encoder_channels: tuple[int],
        decoder_features: tuple[int],
        num_labels: int,
    ) -> None:
        """Initializes a class instance.

        Args:
            in_channels (int): Number of input channels for the image.
            encoder_channels (tuple[int]): Tuple of channels for the Encoder.
            decoder_features (tuple[int]): Tuple of channels for the Decoder.
            num_labels (int): Number of output labels.
        """
        super().__init__()
        # Encoder: the image channel count is prepended to the channel spec
        self.encoder_channels = [in_channels, *encoder_channels]
        self.encoder = Encoder(self.encoder_channels)
        # Decoder: maps flattened encoder features onto the label space
        self.decoder_features = decoder_features
        self.decoder = Decoder(self.decoder_features, num_labels)
        # Bridge between the convolutional and the fully connected parts
        self.flatten = nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass.

        Args:
            x (torch.Tensor): Input tensor for a CNN.

        Returns:
            torch.Tensor: Output tensor of a CNN.
        """
        features = self.encoder(x)
        flat = self.flatten(features)
        return self.decoder(flat)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
class BasicBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        """Builds a block of ResNet layer with 2 inner blocks (CONV + BN + SKIP CONNECTION).

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Value of stride. Defaults to 1.
        """
        super().__init__()
        # BLOCK 1 (CONV + BN): only this conv applies the (possibly > 1) stride
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # BLOCK 2 (CONV + BN): always stride 1, preserves the spatial size
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # SKIP CONNECTION: identity by default; a 1x1 conv projection is only
        # needed when the residual path changes spatial size or channel count.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            projection = [
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        # BLOCK 1 then BLOCK 2 (activation only after the first block)
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        # Add the (possibly projected) residual before the final activation
        h += self.shortcut(x)
        return F.relu(h)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class ResNet(nn.Module):
    """ResNet for small images: a stem conv followed by 3 residual stages,
    global average pooling and a linear classifier."""

    def __init__(
        self,
        in_channels: int,
        block: type[nn.Module],
        num_blocks: list[int],
        num_classes: int = 10,
    ) -> None:
        """Builds a ResNet model with 3 layers.

        Args:
            in_channels (int): Number of input channels.
            block (type[nn.Module]): Block class of ResNet layer (e.g. BasicBlock);
                instantiated as ``block(in_channels, out_channels, stride)``.
            num_blocks (list[int]): Number of blocks inside each of the 3 layers.
            num_classes (int, optional): Number of output classes. Defaults to 10.
        """
        super().__init__()
        # Specifying a number of initial input channels for ResNet block;
        # mutated by `_make_resnet_layer` as stages are built.
        self.in_block_channels = 16

        # LAYER 0 (CONV + BN) #################################################################################
        self.conv1 = nn.Conv2d(
            in_channels,
            self.in_block_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.in_block_channels)

        # LAYERS 1, 2, 3 (`num_blocks` blocks with 2 CONV/BN BLOCKS for 1 layer) ##############################
        # Stages 2 and 3 halve the spatial size (stride=2) while doubling channels.
        self.layer1 = self._make_resnet_layer(
            block=block, out_channels=16, num_blocks=num_blocks[0], stride=1
        )
        self.layer2 = self._make_resnet_layer(
            block=block, out_channels=32, num_blocks=num_blocks[1], stride=2
        )
        self.layer3 = self._make_resnet_layer(
            block=block, out_channels=64, num_blocks=num_blocks[2], stride=2
        )

        # GLOBAL POOLING LAYER ################################################################################
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # FC LAYER ############################################################################################
        self.fc = nn.Linear(64, num_classes)

    def _make_resnet_layer(
        self,
        block: type[nn.Module],
        out_channels: int,
        num_blocks: int,
        stride: int,
    ) -> nn.Sequential:
        """Creates a ResNet layer by stacking up BasicBlock-s.

        Args:
            block (type[nn.Module]): Block class inside ResNet layer.
            out_channels (int): Number of output channels.
            num_blocks (int): Number of blocks inside ResNet layer.
            stride (int): Value of stride for the first block; the remaining
                blocks always use stride 1.

        Returns:
            nn.Sequential: Layer of ResNet with blocks added inside.
        """
        # Only the first block of the stage downsamples; the rest keep stride 1.
        # Note: loop variable renamed so it does not shadow the `stride` parameter.
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_block_channels, out_channels, block_stride))
            # Subsequent blocks (and the next stage) consume `out_channels`.
            self.in_block_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor (class logits).
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        return self.fc(out)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: epoch-engine
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Trainer and evaluator for PyTorch models with a focus on simplicity and flexibility.
|
|
5
|
+
Author-email: Sergey Polivin <s.polivin@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Keywords: pytorch,deep-learning
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE.txt
|
|
15
|
+
Requires-Dist: torch==2.5.0
|
|
16
|
+
Requires-Dist: torchvision==0.20.0
|
|
17
|
+
Requires-Dist: tqdm==4.67.0
|
|
18
|
+
Provides-Extra: build
|
|
19
|
+
Requires-Dist: setuptools; extra == "build"
|
|
20
|
+
Requires-Dist: wheel; extra == "build"
|
|
21
|
+
Requires-Dist: build; extra == "build"
|
|
22
|
+
Requires-Dist: twine; extra == "build"
|
|
23
|
+
Provides-Extra: linters
|
|
24
|
+
Requires-Dist: black; extra == "linters"
|
|
25
|
+
Requires-Dist: isort; extra == "linters"
|
|
26
|
+
Dynamic: license-file
|
|
27
|
+
|
|
28
|
+
# Epoch Engine - Python Library for training PyTorch models
|
|
29
|
+
|
|
30
|
+
This project represents my attempt to come up with a convenient way to train neural nets coded in Torch. While being aware of already existing libraries for training PyTorch models (e.g. PyTorch Lightning), my idea here is to make training of the models more visual and understandable as to what is going on during training.
|
|
31
|
+
|
|
32
|
+
The project is currently in its raw form, more changes expected.
|
|
33
|
+
|
|
34
|
+
## Features
|
|
35
|
+
|
|
36
|
+
* TQDM-Progress bar support for both training and validation loops
|
|
37
|
+
* Intermediate metrics computation after each forward pass (currently based on computing loss and accuracy only)
|
|
38
|
+
* Saving/loading checkpoints from/into Trainer directly without having to touch model, optimizer or scheduler separately
|
|
39
|
+
* Resuming training from the loaded checkpoint with epoch number being remembered automatically to avoid having to remember from which epoch the training originally started
|
|
40
|
+
* Ready-to-use neural net architectures coded from scratch (currently only 4-layer Encoder-Decoder and ResNet with 20 layers architectures are available)
|
|
41
|
+
|
|
42
|
+
## Installation
|
|
43
|
+
|
|
44
|
+
After cloning this repo, the package can be installed in the development mode as follows:
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
# Installing the main package
|
|
48
|
+
pip install -e .
|
|
49
|
+
|
|
50
|
+
# Installing additional optional dependencies
|
|
51
|
+
pip install -e .[build,linters]
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Python API
|
|
55
|
+
|
|
56
|
+
The basics of the developed API are presented in the [test script](./run_trainer.py) I built. It can be run for instance as follows:
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
python run_trainer.py --model=resnet --epochs=3 --batch-size=16
|
|
60
|
+
```
|
|
61
|
+
> The training will be launched on the device automatically derived based on the CUDA availability and the final training checkpoint will be saved in `checkpoints` directory.
|
|
62
|
+
|
|
63
|
+
One can also resume the training from the saved checkpoint:
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
python run_trainer.py --model=resnet --epochs=4 --resume-training=True --ckpt-path=checkpoints/ckpt_3.pt
|
|
67
|
+
```
|
|
68
|
+
> The training will be resumed from the loaded checkpoint with TQDM-progress bar showing the next training epoch.
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE.txt
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
epoch_engine/__init__.py
|
|
5
|
+
epoch_engine.egg-info/PKG-INFO
|
|
6
|
+
epoch_engine.egg-info/SOURCES.txt
|
|
7
|
+
epoch_engine.egg-info/dependency_links.txt
|
|
8
|
+
epoch_engine.egg-info/requires.txt
|
|
9
|
+
epoch_engine.egg-info/top_level.txt
|
|
10
|
+
epoch_engine/core/__init__.py
|
|
11
|
+
epoch_engine/core/checkpoint_handler.py
|
|
12
|
+
epoch_engine/core/metrics_tracker.py
|
|
13
|
+
epoch_engine/core/trainer.py
|
|
14
|
+
epoch_engine/models/__init__.py
|
|
15
|
+
epoch_engine/models/architectures.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
epoch_engine
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "epoch-engine"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Trainer and evaluator for PyTorch models with a focus on simplicity and flexibility."
|
|
5
|
+
authors = [{name = "Sergey Polivin", email = "s.polivin@gmail.com"}]
|
|
6
|
+
keywords = ["pytorch", "deep-learning"]
|
|
7
|
+
classifiers = [
|
|
8
|
+
"Programming Language :: Python :: 3",
|
|
9
|
+
"License :: OSI Approved :: MIT License",
|
|
10
|
+
"Operating System :: OS Independent",
|
|
11
|
+
"Development Status :: 3 - Alpha",
|
|
12
|
+
]
|
|
13
|
+
readme = "README.md"
|
|
14
|
+
requires-python = ">=3.8"
|
|
15
|
+
license = {text = "MIT"}
|
|
16
|
+
dependencies = [
|
|
17
|
+
"torch==2.5.0",
|
|
18
|
+
"torchvision==0.20.0",
|
|
19
|
+
"tqdm==4.67.0"
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[project.optional-dependencies]
|
|
23
|
+
build = ["setuptools", "wheel", "build", "twine"]
|
|
24
|
+
linters = ["black", "isort"]
|
|
25
|
+
|
|
26
|
+
[build-system]
|
|
27
|
+
requires = ["setuptools>=61.0", "wheel"]
|
|
28
|
+
build-backend = "setuptools.build_meta"
|
|
29
|
+
|
|
30
|
+
[tool.setuptools.packages.find]
|
|
31
|
+
where = ["."]
|
|
32
|
+
include = ["epoch_engine*"]
|