epoch-engine 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Sergey Polivin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,68 @@
1
+ Metadata-Version: 2.4
2
+ Name: epoch-engine
3
+ Version: 0.1.0
4
+ Summary: Trainer and evaluator for PyTorch models with a focus on simplicity and flexibility.
5
+ Author-email: Sergey Polivin <s.polivin@gmail.com>
6
+ License: MIT
7
+ Keywords: pytorch,deep-learning
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE.txt
15
+ Requires-Dist: torch==2.5.0
16
+ Requires-Dist: torchvision==0.20.0
17
+ Requires-Dist: tqdm==4.67.0
18
+ Provides-Extra: build
19
+ Requires-Dist: setuptools; extra == "build"
20
+ Requires-Dist: wheel; extra == "build"
21
+ Requires-Dist: build; extra == "build"
22
+ Requires-Dist: twine; extra == "build"
23
+ Provides-Extra: linters
24
+ Requires-Dist: black; extra == "linters"
25
+ Requires-Dist: isort; extra == "linters"
26
+ Dynamic: license-file
27
+
28
+ # Epoch Engine - Python Library for training PyTorch models
29
+
30
+ This project represents my attempt to come up with a convenient way to train neural nets coded in Torch. While being aware of already existing libraries for training PyTorch models (e.g. PyTorch Lightning), my idea here is to make training of the models more visual and understandable as to what is going on during training.
31
+
32
+ The project is currently in its raw form, more changes expected.
33
+
34
+ ## Features
35
+
36
+ * TQDM-Progress bar support for both training and validation loops
37
+ * Intermediate metrics computations after each forward pass (currently it is based on computing loss and accuracy only)
38
+ * Saving/loading checkpoints from/into Trainer directly without having to touch model, optimizer or scheduler separately
39
+ * Resuming training from the loaded checkpoint with epoch number being remembered automatically to avoid having to remember from which epoch the training originally started
40
+ * Ready-to-use neural net architectures coded from scratch (currently only 4-layer Encoder-Decoder and ResNet with 20 layers architectures are available)
41
+
42
+ ## Installation
43
+
44
+ After cloning this repo, the package can be installed in the development mode as follows:
45
+
46
+ ```bash
47
+ # Installing the main package
48
+ pip install -e .
49
+
50
+ # Installing additional optional dependencies
51
+ pip install -e .[build,linters]
52
+ ```
53
+
54
+ ## Python API
55
+
56
+ The basics of the developed API are presented in the [test script](./run_trainer.py) I built. It can be run for instance as follows:
57
+
58
+ ```bash
59
+ python run_trainer.py --model=resnet --epochs=3 --batch-size=16
60
+ ```
61
+ > The training will be launched on the device automatically derived based on the CUDA availability and the final training checkpoint will be saved in `checkpoints` directory.
62
+
63
+ One can also resume the training from the saved checkpoint:
64
+
65
+ ```bash
66
+ python run_trainer.py --model=resnet --epochs=4 --resume-training=True --ckpt-path=checkpoints/ckpt_3.pt
67
+ ```
68
+ > The training will be resumed from the loaded checkpoint with TQDM-progress bar showing the next training epoch.
@@ -0,0 +1,41 @@
1
+ # Epoch Engine - Python Library for training PyTorch models
2
+
3
+ This project represents my attempt to come up with a convenient way to train neural nets coded in Torch. While being aware of already existing libraries for training PyTorch models (e.g. PyTorch Lightning), my idea here is to make training of the models more visual and understandable as to what is going on during training.
4
+
5
+ The project is currently in its raw form, more changes expected.
6
+
7
+ ## Features
8
+
9
+ * TQDM-Progress bar support for both training and validation loops
10
+ * Intermediate metrics computations after each forward pass (currently it is based on computing loss and accuracy only)
11
+ * Saving/loading checkpoints from/into Trainer directly without having to touch model, optimizer or scheduler separately
12
+ * Resuming training from the loaded checkpoint with epoch number being remembered automatically to avoid having to remember from which epoch the training originally started
13
+ * Ready-to-use neural net architectures coded from scratch (currently only 4-layer Encoder-Decoder and ResNet with 20 layers architectures are available)
14
+
15
+ ## Installation
16
+
17
+ After cloning this repo, the package can be installed in the development mode as follows:
18
+
19
+ ```bash
20
+ # Installing the main package
21
+ pip install -e .
22
+
23
+ # Installing additional optional dependencies
24
+ pip install -e .[build,linters]
25
+ ```
26
+
27
+ ## Python API
28
+
29
+ The basics of the developed API are presented in the [test script](./run_trainer.py) I built. It can be run for instance as follows:
30
+
31
+ ```bash
32
+ python run_trainer.py --model=resnet --epochs=3 --batch-size=16
33
+ ```
34
+ > The training will be launched on the device automatically derived based on the CUDA availability and the final training checkpoint will be saved in `checkpoints` directory.
35
+
36
+ One can also resume the training from the saved checkpoint:
37
+
38
+ ```bash
39
+ python run_trainer.py --model=resnet --epochs=4 --resume-training=True --ckpt-path=checkpoints/ckpt_3.pt
40
+ ```
41
+ > The training will be resumed from the loaded checkpoint with TQDM-progress bar showing the next training epoch.
File without changes
File without changes
@@ -0,0 +1,40 @@
1
+ """Module for handling saving/loading of checkpoints."""
2
+
3
+ # Author: Sergey Polivin <s.polivin@gmail.com>
4
+ # License: MIT License
5
+
6
+ import torch
7
+
8
+
9
class CheckpointHandler:
    """Persists and restores the mutable training state of a Trainer.

    A checkpoint bundles the last finished epoch together with the model
    and optimizer state dicts, so a run can be resumed later without
    touching the model, optimizer or epoch counter individually.
    """

    def save_checkpoint(self, trainer, path: str) -> None:
        """Writes the trainer state (epoch, model, optimizer) to disk.

        Args:
            trainer: Trainer instance whose state is serialized.
            path (str): Destination file for the checkpoint.
        """
        state = {
            "epoch": trainer.last_epoch,
            "model_state_dict": trainer.model.state_dict(),
            "optimizer_state_dict": trainer.optimizer.state_dict(),
        }
        torch.save(state, path)

    def load_checkpoint(self, trainer, path: str) -> None:
        """Reads a checkpoint file and restores it into the trainer.

        Args:
            trainer: Trainer instance receiving the restored state.
            path (str): Checkpoint file to read.
        """
        # weights_only=True restricts unpickling to tensors/plain containers.
        state = torch.load(path, weights_only=True)
        trainer.last_epoch = state["epoch"]
        trainer.model.load_state_dict(state["model_state_dict"])
        trainer.optimizer.load_state_dict(state["optimizer_state_dict"])
@@ -0,0 +1,100 @@
1
+ """Module for handling tracking and computation of metrics."""
2
+
3
+ # Author: Sergey Polivin <s.polivin@gmail.com>
4
+ # License: MIT License
5
+
6
+
7
class MetricsTracker:
    """Accumulates batch-level metrics and reduces them to epoch-level values.

    Loss-like metrics are stored per batch under a name (e.g. "loss/train")
    and averaged on demand; accuracy is tracked with running correct/total
    counters kept separately for the train and valid splits.
    """

    def __init__(self) -> None:
        """Initializes empty metric storage.

        Attributes:
            metrics (dict[str, list[float]]): Per-batch metric values keyed by name.
            total_correct_train (int): Correct predictions seen in the train split.
            total_samples_train (int): Examples seen in the train split.
            total_correct_valid (int): Correct predictions seen in the valid split.
            total_samples_valid (int): Examples seen in the valid split.
        """
        # Delegating to reset() keeps the initial and cleared states identical.
        self.reset()

    def reset(self) -> None:
        """Clears all recorded metrics and accuracy counters."""
        self.metrics = {}
        self.total_correct_train = 0
        self.total_samples_train = 0
        self.total_correct_valid = 0
        self.total_samples_valid = 0

    def update(self, name: str, value: float) -> None:
        """Appends one batch-level value to a named metric.

        Args:
            name (str): Metric name.
            value (float): Metric value for the current batch.
        """
        self.metrics.setdefault(name, []).append(value)

    def update_accuracy(self, correct: int, total: int, split: str) -> None:
        """Adds one batch's prediction counts to the accuracy counters.

        Args:
            correct (int): Number of correct predictions in the batch.
            total (int): Number of examples in the batch.
            split (str): Either "train" or "valid".

        Raises:
            ValueError: If `split` is neither "train" nor "valid".
        """
        if split == "train":
            self.total_correct_train += correct
            self.total_samples_train += total
        elif split == "valid":
            self.total_correct_valid += correct
            self.total_samples_valid += total
        else:
            raise ValueError(f"Unknown split {split!r}; expected 'train' or 'valid'")

    def compute_accuracy(self, split: str) -> float:
        """Computes the running accuracy for one split.

        Args:
            split (str): Either "train" or "valid".

        Raises:
            ValueError: If `split` is neither "train" nor "valid".

        Returns:
            float: Accuracy for the split, or 0.0 if it has seen no samples.
        """
        if split == "train":
            correct, total = self.total_correct_train, self.total_samples_train
        elif split == "valid":
            correct, total = self.total_correct_valid, self.total_samples_valid
        else:
            raise ValueError(f"Unknown split {split!r}; expected 'train' or 'valid'")
        # Fixed: only the requested split's counter is checked for emptiness,
        # so an empty valid split no longer zeroes out the train accuracy
        # (the old guard tested both splits with `or`).
        return correct / total if total else 0.0

    def get_all_metrics(self) -> dict[str, float]:
        """Reduces all batch-level records to epoch-level values.

        Returns:
            dict[str, float]: Mean of each recorded metric plus
                "accuracy/train" and "accuracy/valid" scores.
        """
        # Averaging each recorded metric across its batches.
        summary = {name: sum(vals) / len(vals) for name, vals in self.metrics.items()}
        # Appending the running accuracy scores for both splits.
        summary["accuracy/train"] = self.compute_accuracy(split="train")
        summary["accuracy/valid"] = self.compute_accuracy(split="valid")
        return summary
@@ -0,0 +1,294 @@
1
+ """Module for Trainer functionality."""
2
+
3
+ # Author: Sergey Polivin <s.polivin@gmail.com>
4
+ # License: MIT License
5
+
6
+ from typing import Any, TypeAlias
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from .checkpoint_handler import CheckpointHandler
12
+ from .metrics_tracker import MetricsTracker
13
+
14
+ TorchDataloader: TypeAlias = torch.utils.data.dataloader.DataLoader
15
+
16
+
17
class Trainer:
    """Trainer of PyTorch models with progress bars, metrics and checkpointing."""

    def __init__(
        self,
        model: Any,
        criterion: Any,
        optimizer: Any,
        train_loader: TorchDataloader,
        valid_loader: TorchDataloader,
        train_on: str = "auto",
    ) -> None:
        """Initializes a class instance.

        Args:
            model (Any): PyTorch model instance.
            criterion (Any): PyTorch loss function instance.
            optimizer (Any): PyTorch optimizer instance.
            train_loader (TorchDataloader): Dataloader for the training set.
            valid_loader (TorchDataloader): Dataloader for the validation set.
            train_on (str, optional): Device spec ("cpu", "cuda", ...) or "auto"
                to prefer CUDA when available. Defaults to "auto".

        Attributes:
            model (Any): Model moved to the selected device.
            criterion (Any): Loss function instance.
            optimizer (Any): Optimizer rebuilt over the moved parameters.
            train_loader (TorchDataloader): Dataloader for the training set.
            valid_loader (TorchDataloader): Dataloader for the validation set.
            device (torch.device): Device used for training.
            scheduler (Any): Optional learning-rate scheduler. Defaults to None.
            scheduler_level (str): "epoch" or "batch" stepping level. Defaults to None.
            last_epoch (int): Last successfully finished epoch. Defaults to 0.
            metrics_tracker (MetricsTracker): Batch/epoch metric bookkeeping.
            ckpt_handler (CheckpointHandler): Checkpoint save/load helper.
        """
        # Resolving the target device ("auto" prefers CUDA when available).
        if train_on == "auto":
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(train_on)
        # Moving the model onto the selected device before optimizer setup.
        self.model = model.to(self.device)
        # Rebuilding the optimizer so it tracks the moved parameters.
        # NOTE(review): `optimizer.defaults` carries only the constructor
        # defaults, not per-parameter-group overrides — confirm acceptable.
        self.optimizer = type(optimizer)(self.model.parameters(), **optimizer.defaults)

        self.criterion = criterion

        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # Scheduler is optional; attach one later via `set_scheduler`.
        self.scheduler = None
        self.scheduler_level = None

        # Epoch counter used to resume training from checkpoints.
        self.last_epoch = 0

        self.metrics_tracker = MetricsTracker()
        self.ckpt_handler = CheckpointHandler()

    def set_scheduler(
        self,
        scheduler_class: Any,
        scheduler_params: dict[str, Any],
        level: str = "epoch",
    ) -> None:
        """Sets and initializes a learning-rate scheduler.

        Args:
            scheduler_class (Any): Scheduler class from PyTorch.
            scheduler_params (dict[str, Any]): Keyword arguments for the scheduler.
            level (str, optional): "epoch" to step once per epoch, "batch" to
                step after every training batch. Defaults to "epoch".

        Raises:
            ValueError: If `level` is neither "epoch" nor "batch".
        """
        if level not in ("epoch", "batch"):
            # Fixed: descriptive message instead of the bare "Invalid value".
            raise ValueError(
                f"Unexpected scheduler level {level!r}; expected 'epoch' or 'batch'"
            )
        self.scheduler_level = level
        self.scheduler = scheduler_class(self.optimizer, **scheduler_params)

    def reset_scheduler(self) -> None:
        """Detaches the scheduler so training continues with a fixed lr."""
        self.scheduler = None
        self.scheduler_level = None

    def __call__(
        self, x_batch: torch.Tensor, y_batch: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Runs a forward pass and computes the loss for one batch.

        Args:
            x_batch (torch.Tensor): Batch of examples.
            y_batch (torch.Tensor): Batch of labels.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Loss value and model outputs.
        """
        outputs = self.model(x_batch)
        # Labels are cast to long as required by classification criteria.
        loss = self.criterion(outputs, y_batch.long())
        return loss, outputs

    def run(
        self,
        epochs: int,
        seed: int = 42,
        enable_tqdm: bool = True,
    ) -> None:
        """Launches the training/validation loops.

        Args:
            epochs (int): Number of additional epochs to train.
            seed (int, optional): RNG seed; pass None to skip seeding.
                Defaults to 42.
            enable_tqdm (bool, optional): Show tqdm progress bars. Defaults to True.
        """
        # Fixed: explicit None-check so seed=0 is honored instead of being
        # silently ignored by the old truthiness test `if seed:`.
        if seed is not None:
            torch.manual_seed(seed)
        # Total epoch count accounts for epochs already trained before resuming.
        total_epochs = epochs + self.last_epoch
        for epoch in range(self.last_epoch, total_epochs):
            self._train_one_epoch(
                epoch=epoch, epochs=total_epochs, enable_tqdm=enable_tqdm
            )
            self._validate_one_epoch(
                epoch=epoch, epochs=total_epochs, enable_tqdm=enable_tqdm
            )
            # Remembering progress only after both loops finish cleanly.
            self.last_epoch = epoch + 1
            # Counters restart so the next epoch's metrics are not mixed in.
            self.metrics_tracker.reset()

    def _train_one_epoch(self, epoch: int, epochs: int, enable_tqdm: bool) -> None:
        """Processes all training batches for one epoch.

        Args:
            epoch (int): Current epoch (0-based).
            epochs (int): Total number of epochs, used for the progress bar label.
            enable_tqdm (bool): Show the tqdm progress bar.
        """
        with tqdm(
            total=len(self.train_loader),
            desc=f"Epoch {epoch + 1}/{epochs} [Training]",
            position=0,
            leave=True,
            unit="batches",
            disable=not enable_tqdm,
        ) as pbar:
            # Showing the current learning rate alongside the bar.
            pbar.set_postfix({"lr": self.optimizer.param_groups[0]["lr"]})
            self.model.train()
            for batch in self.train_loader:
                self._process_batch(batch=batch, split="train")
                pbar.update(1)
            # Epoch-level lr adjustment, if a scheduler is configured.
            if self.scheduler and self.scheduler_level == "epoch":
                self.scheduler.step()

    def _validate_one_epoch(self, epoch: int, epochs: int, enable_tqdm: bool) -> None:
        """Processes all validation batches for one epoch.

        Args:
            epoch (int): Current epoch (0-based).
            epochs (int): Total number of epochs, used for the progress bar label.
            enable_tqdm (bool): Show the tqdm progress bar.
        """
        with tqdm(
            total=len(self.valid_loader),
            desc=f"Epoch {epoch + 1}/{epochs} [Validation]",
            position=0,
            leave=True,
            unit="batches",
            disable=not enable_tqdm,
        ) as pbar:
            # Evaluation mode disables dropout/batch-norm statistic updates.
            self.model.eval()
            # Gradients are not needed for validation.
            with torch.no_grad():
                for batch in self.valid_loader:
                    self._process_batch(batch=batch, split="valid")
                    pbar.update(1)

            # Reducing batch-level records (loss, accuracy) to epoch-level values.
            self.metrics = self.metrics_tracker.get_all_metrics()
            # Reporting (train, valid) pairs on the progress bar.
            pbar.set_postfix(
                {
                    "loss": (
                        round(self.metrics["loss/train"], 4),
                        round(self.metrics["loss/valid"], 4),
                    ),
                    "acc": (
                        round(self.metrics["accuracy/train"], 4),
                        round(self.metrics["accuracy/valid"], 4),
                    ),
                }
            )

    def _process_batch(self, batch: list[torch.Tensor], split: str) -> None:
        """Runs one batch through the model and records its metrics.

        Backpropagates and steps the optimizer for training batches only.

        Args:
            batch (list[torch.Tensor]): Pair of example and label tensors.
            split (str): Either "train" or "valid".
        """
        # Moving the batch onto the training device.
        x_batch, y_batch = batch
        x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
        loss, outputs = self(x_batch=x_batch, y_batch=y_batch)

        if split == "train":
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Batch-level lr adjustment, if a scheduler is configured.
            if self.scheduler and self.scheduler_level == "batch":
                self.scheduler.step()

        self._collect_metrics(loss=loss, outputs=outputs, labels=y_batch, split=split)

    def _collect_metrics(
        self,
        loss: torch.Tensor,
        outputs: torch.Tensor,
        labels: torch.Tensor,
        split: str,
    ) -> None:
        """Records batch-level loss and accuracy for one split.

        Args:
            loss (torch.Tensor): Loss value for the batch.
            outputs (torch.Tensor): Model outputs (class scores).
            labels (torch.Tensor): Ground-truth labels.
            split (str): Either "train" or "valid".
        """
        self.metrics_tracker.update(f"loss/{split}", loss.item())

        # Class predictions are taken as the argmax over the score dimension.
        predictions = torch.argmax(outputs, dim=1)
        correct = (predictions == labels).sum().cpu().item()
        total = labels.size(0)
        self.metrics_tracker.update_accuracy(correct, total, split=split)

    def save_checkpoint(self, path: str) -> None:
        """Saves a trainer checkpoint (epoch, model and optimizer state).

        Args:
            path (str): Destination file for the checkpoint.
        """
        self.ckpt_handler.save_checkpoint(path=path, trainer=self)

    def load_checkpoint(self, path: str) -> None:
        """Loads a trainer checkpoint saved via `save_checkpoint`.

        Args:
            path (str): Checkpoint file to load.
        """
        self.ckpt_handler.load_checkpoint(path=path, trainer=self)
File without changes
@@ -0,0 +1,313 @@
1
+ """Module containing building blocks of models and functions to build neural networks."""
2
+
3
+ # Author: Sergey Polivin <s.polivin@gmail.com>
4
+ # License: MIT License
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
class Encoder(nn.Module):
    """Stacks Conv2d -> ReLU -> MaxPool blocks along a channel progression."""

    def __init__(self, encoder_channels: tuple[int]) -> None:
        """Initializes a class instance.

        Args:
            encoder_channels (tuple[int]): Channel counts; each adjacent pair
                becomes one convolutional block.
        """
        super().__init__()
        # One block per adjacent (in, out) channel pair, chained sequentially.
        stages = []
        for c_in, c_out in zip(encoder_channels[:-1], encoder_channels[1:]):
            stages.append(self._make_encoder_block(c_in, c_out))
        self.encoder_blocks = nn.Sequential(*stages)

    def _make_encoder_block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Builds one Conv2d(3x3, padding=1) + ReLU + MaxPool(3) block.

        Args:
            in_channels (int): Number of input channels for the convolution.
            out_channels (int): Number of output channels for the convolution.

        Returns:
            nn.Sequential: Sequentially connected convolution block.
        """
        conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )
        return nn.Sequential(conv, nn.ReLU(), nn.MaxPool2d(kernel_size=3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies all encoder blocks to the input.

        Args:
            x (torch.Tensor): Input tensor for the Encoder.

        Returns:
            torch.Tensor: Encoded feature maps.
        """
        return self.encoder_blocks(x)
60
+
61
+
62
class Decoder(nn.Module):
    """Maps flattened features to class scores via sigmoid-activated linear layers."""

    def __init__(self, decoder_features: tuple[int], num_labels: int) -> None:
        """Initializes a class instance.

        Args:
            decoder_features (tuple[int]): Feature counts; each adjacent pair
                becomes one Linear + Sigmoid block.
            num_labels (int): Number of output labels in the final layer.
        """
        super().__init__()
        # One block per adjacent (in, out) feature pair, chained sequentially.
        feature_pairs = zip(decoder_features[:-1], decoder_features[1:])
        self.decoder_blocks = nn.Sequential(
            *(self._make_decoder_block(f_in, f_out) for f_in, f_out in feature_pairs)
        )
        # Final projection into the label space (no activation applied here).
        self.last = nn.Linear(decoder_features[-1], num_labels)

    def _make_decoder_block(self, in_features: int, out_features: int) -> nn.Sequential:
        """Builds one Linear + Sigmoid block.

        Args:
            in_features (int): Number of input features.
            out_features (int): Number of output features.

        Returns:
            nn.Sequential: Sequentially connected feedforward block.
        """
        return nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies all decoder blocks and the final projection.

        Args:
            x (torch.Tensor): Input tensor for the Decoder.

        Returns:
            torch.Tensor: Class scores.
        """
        return self.last(self.decoder_blocks(x))
111
+
112
+
113
class EDNet(nn.Module):
    """Convolutional classifier chaining Encoder -> Flatten -> Decoder."""

    def __init__(
        self,
        in_channels: int,
        encoder_channels: tuple[int],
        decoder_features: tuple[int],
        num_labels: int,
    ) -> None:
        """Initializes a class instance.

        Args:
            in_channels (int): Number of channels in the input image.
            encoder_channels (tuple[int]): Channel progression for the Encoder.
            decoder_features (tuple[int]): Feature progression for the Decoder.
                NOTE(review): decoder_features[0] must match the flattened
                encoder output size — confirm against the chosen input shape.
            num_labels (int): Number of output labels.
        """
        super().__init__()
        # Input image channels are prepended so the first conv consumes them.
        self.encoder_channels = [in_channels, *encoder_channels]
        self.encoder = Encoder(self.encoder_channels)
        self.decoder_features = decoder_features
        self.decoder = Decoder(self.decoder_features, num_labels)
        # Bridges the encoder's 2D feature maps to the decoder's linear layers.
        self.flatten = nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass through encoder, flatten and decoder.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            torch.Tensor: Class scores.
        """
        features = self.encoder(x)
        return self.decoder(self.flatten(features))
155
+
156
+
157
class BasicBlock(nn.Module):
    """Two-convolution residual block: (Conv-BN) x2 plus identity/projection shortcut."""

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        """Initializes the block.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int, optional): Stride of the first convolution. Defaults to 1.
        """
        super().__init__()
        # First Conv-BN stage; `stride` may downsample here.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        # Second Conv-BN stage keeps the spatial size.
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut by default; a 1x1 projection when the output
        # shape (channels or spatial size) differs from the input.
        if stride == 1 and in_channels == out_channels:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies both stages, adds the shortcut, and returns ReLU of the sum.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        residual = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)
221
+
222
+
223
class ResNet(nn.Module):
    """Three-stage ResNet (16/32/64 channels) for small-image classification."""

    def __init__(
        self,
        in_channels: int,
        block: BasicBlock,
        num_blocks: list[int],
        num_classes: int = 10,
    ) -> None:
        """Initializes the network.

        Args:
            in_channels (int): Number of input channels.
            block (BasicBlock): Residual block class used to build each stage.
            num_blocks (list[int]): Blocks per stage (three entries).
            num_classes (int, optional): Number of output classes. Defaults to 10.
        """
        super().__init__()
        # Channel width entering the first residual stage.
        self.in_block_channels = 16

        # Stem: a single Conv-BN before the residual stages.
        self.conv1 = nn.Conv2d(
            in_channels,
            self.in_block_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(self.in_block_channels)

        # Residual stages; stages 2 and 3 halve the spatial size (stride=2).
        self.layer1 = self._make_resnet_layer(
            block=block, out_channels=16, num_blocks=num_blocks[0], stride=1
        )
        self.layer2 = self._make_resnet_layer(
            block=block, out_channels=32, num_blocks=num_blocks[1], stride=2
        )
        self.layer3 = self._make_resnet_layer(
            block=block, out_channels=64, num_blocks=num_blocks[2], stride=2
        )

        # Global average pooling down to 1x1, then the classification head.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_resnet_layer(
        self,
        block: BasicBlock,
        out_channels: int,
        num_blocks: int,
        stride: int,
    ) -> nn.Sequential:
        """Builds one residual stage by stacking `num_blocks` blocks.

        Args:
            block (BasicBlock): Residual block class.
            out_channels (int): Channel width of the stage.
            num_blocks (int): Number of blocks in the stage.
            stride (int): Stride of the stage's first block.

        Returns:
            nn.Sequential: The assembled stage.
        """
        # Only the first block of a stage may downsample; the rest use stride 1.
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_block_channels, out_channels, block_stride))
            # Subsequent blocks (and the next stage) consume this width.
            self.in_block_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Makes a forward pass.

        Args:
            x (torch.Tensor): Input image batch.

        Returns:
            torch.Tensor: Class scores.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer3(self.layer2(self.layer1(out)))
        out = self.avg_pool(out)
        return self.fc(torch.flatten(out, 1))
@@ -0,0 +1,68 @@
1
+ Metadata-Version: 2.4
2
+ Name: epoch-engine
3
+ Version: 0.1.0
4
+ Summary: Trainer and evaluator for PyTorch models with a focus on simplicity and flexibility.
5
+ Author-email: Sergey Polivin <s.polivin@gmail.com>
6
+ License: MIT
7
+ Keywords: pytorch,deep-learning
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: License :: OSI Approved :: MIT License
10
+ Classifier: Operating System :: OS Independent
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE.txt
15
+ Requires-Dist: torch==2.5.0
16
+ Requires-Dist: torchvision==0.20.0
17
+ Requires-Dist: tqdm==4.67.0
18
+ Provides-Extra: build
19
+ Requires-Dist: setuptools; extra == "build"
20
+ Requires-Dist: wheel; extra == "build"
21
+ Requires-Dist: build; extra == "build"
22
+ Requires-Dist: twine; extra == "build"
23
+ Provides-Extra: linters
24
+ Requires-Dist: black; extra == "linters"
25
+ Requires-Dist: isort; extra == "linters"
26
+ Dynamic: license-file
27
+
28
+ # Epoch Engine - Python Library for training PyTorch models
29
+
30
+ This project represents my attempt to come up with a convenient way to train neural nets coded in Torch. While being aware of already existing libraries for training PyTorch models (e.g. PyTorch Lightning), my idea here is to make training of the models more visual and understandable as to what is going on during training.
31
+
32
+ The project is currently in its raw form, more changes expected.
33
+
34
+ ## Features
35
+
36
+ * TQDM-Progress bar support for both training and validation loops
37
+ * Intermediate metrics computations after each forward pass (currently based on computing loss and accuracy only)
38
+ * Saving/loading checkpoints from/into Trainer directly without having to touch model, optimizer or scheduler separately
39
+ * Resuming training from the loaded checkpoint with epoch number being remembered automatically to avoid having to remember from which epoch the training originally started
40
+ * Ready-to-use neural net architectures coded from scratch (currently only a 4-layer Encoder-Decoder and a 20-layer ResNet are available)
41
+
42
+ ## Installation
43
+
44
+ After cloning this repo, the package can be installed in the development mode as follows:
45
+
46
+ ```bash
47
+ # Installing the main package
48
+ pip install -e .
49
+
50
+ # Installing additional optional dependencies
51
+ pip install -e .[build,linters]
52
+ ```
53
+
54
+ ## Python API
55
+
56
+ The basics of the developed API are presented in the [test script](./run_trainer.py) I built. It can be run for instance as follows:
57
+
58
+ ```bash
59
+ python run_trainer.py --model=resnet --epochs=3 --batch-size=16
60
+ ```
61
+ > The training will be launched on the device automatically derived based on the CUDA availability and the final training checkpoint will be saved in `checkpoints` directory.
62
+
63
+ One can also resume the training from the saved checkpoint:
64
+
65
+ ```bash
66
+ python run_trainer.py --model=resnet --epochs=4 --resume-training=True --ckpt-path=checkpoints/ckpt_3.pt
67
+ ```
68
+ > The training will be resumed from the loaded checkpoint with TQDM-progress bar showing the next training epoch.
@@ -0,0 +1,15 @@
1
+ LICENSE.txt
2
+ README.md
3
+ pyproject.toml
4
+ epoch_engine/__init__.py
5
+ epoch_engine.egg-info/PKG-INFO
6
+ epoch_engine.egg-info/SOURCES.txt
7
+ epoch_engine.egg-info/dependency_links.txt
8
+ epoch_engine.egg-info/requires.txt
9
+ epoch_engine.egg-info/top_level.txt
10
+ epoch_engine/core/__init__.py
11
+ epoch_engine/core/checkpoint_handler.py
12
+ epoch_engine/core/metrics_tracker.py
13
+ epoch_engine/core/trainer.py
14
+ epoch_engine/models/__init__.py
15
+ epoch_engine/models/architectures.py
@@ -0,0 +1,13 @@
1
+ torch==2.5.0
2
+ torchvision==0.20.0
3
+ tqdm==4.67.0
4
+
5
+ [build]
6
+ setuptools
7
+ wheel
8
+ build
9
+ twine
10
+
11
+ [linters]
12
+ black
13
+ isort
@@ -0,0 +1 @@
1
+ epoch_engine
@@ -0,0 +1,32 @@
1
+ [project]
2
+ name = "epoch-engine"
3
+ version = "0.1.0"
4
+ description = "Trainer and evaluator for PyTorch models with a focus on simplicity and flexibility."
5
+ authors = [{name = "Sergey Polivin", email = "s.polivin@gmail.com"}]
6
+ keywords = ["pytorch", "deep-learning"]
7
+ classifiers = [
8
+ "Programming Language :: Python :: 3",
9
+ "License :: OSI Approved :: MIT License",
10
+ "Operating System :: OS Independent",
11
+ "Development Status :: 3 - Alpha",
12
+ ]
13
+ readme = "README.md"
14
+ requires-python = ">=3.8"
15
+ license = {text = "MIT"}
16
+ dependencies = [
17
+ "torch==2.5.0",
18
+ "torchvision==0.20.0",
19
+ "tqdm==4.67.0"
20
+ ]
21
+
22
+ [project.optional-dependencies]
23
+ build = ["setuptools", "wheel", "build", "twine"]
24
+ linters = ["black", "isort"]
25
+
26
+ [build-system]
27
+ requires = ["setuptools>=61.0", "wheel"]
28
+ build-backend = "setuptools.build_meta"
29
+
30
+ [tool.setuptools.packages.find]
31
+ where = ["."]
32
+ include = ["epoch_engine*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+