congrads 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,119 @@
1
+ """Defines the EpochRunner class for running full training, validation, and test epochs.
2
+
3
+ This module handles:
4
+ - Switching the network between training and evaluation modes
5
+ - Iterating over DataLoaders with optional progress bars
6
+ - Delegating per-batch processing to a BatchRunner instance
7
+ - Optional gradient tracking control for evaluation phases
8
+ - Warnings when validation or test loaders are not provided
9
+ """
10
+
11
+ import warnings
12
+
13
+ from torch import set_grad_enabled
14
+ from torch.nn import Module
15
+ from torch.utils.data import DataLoader
16
+ from tqdm import tqdm
17
+
18
+ from ..core.batch_runner import BatchRunner
19
+
20
+
21
class EpochRunner:
    """Runs full epochs over DataLoaders.

    Responsibilities:
    - Model mode switching (``train()`` / ``eval()``)
    - Iteration over DataLoaders
    - Delegation of per-batch work to a BatchRunner
    - Progress bars
    """

    def __init__(
        self,
        network: Module,
        batch_runner: BatchRunner,
        train_loader: DataLoader,
        valid_loader: DataLoader | None = None,
        test_loader: DataLoader | None = None,
        *,
        network_uses_grad: bool = False,
        disable_progress_bar: bool = False,
    ) -> None:
        """Initialize the EpochRunner.

        Args:
            network: The neural network module to train/validate/test.
            batch_runner: The BatchRunner instance that processes individual batches.
            train_loader: DataLoader for training data.
            valid_loader: DataLoader for validation data. Defaults to None, in which
                case ``validate()`` warns and does nothing.
            test_loader: DataLoader for test data. Defaults to None, in which case
                ``test()`` warns and does nothing.
            network_uses_grad: Value passed to ``torch.set_grad_enabled`` while running
                validation and test epochs; False disables gradient tracking there.
                Training epochs do not change the grad mode. Defaults to False.
            disable_progress_bar: Whether to disable the tqdm progress bars.
                Defaults to False.
        """
        self.network = network
        self.batch_runner = batch_runner
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.network_uses_grad = network_uses_grad
        self.disable_progress_bar = disable_progress_bar

    def train(self) -> None:
        """Run a training epoch over the training DataLoader.

        Sets the network to training mode and iterates over batches,
        delegating each batch to ``BatchRunner.train_batch``.
        """
        self.network.train()

        for batch in tqdm(
            self.train_loader,
            desc="Training batches",
            leave=False,
            disable=self.disable_progress_bar,
        ):
            self.batch_runner.train_batch(batch)

    def validate(self) -> None:
        """Run a validation epoch over the validation DataLoader.

        Sets the network to evaluation mode and iterates over batches,
        delegating each batch to ``BatchRunner.valid_batch``.
        Emits a warning and returns early if no valid_loader is provided.
        """
        if self.valid_loader is None:
            # Warn here (not in the helper) so stacklevel=2 points at the caller.
            warnings.warn("Validation skipped: no valid_loader provided.", stacklevel=2)
            return

        self._run_eval_epoch(
            self.valid_loader,
            "Validation batches",
            self.batch_runner.valid_batch,
        )

    def test(self) -> None:
        """Run a test epoch over the test DataLoader.

        Sets the network to evaluation mode and iterates over batches,
        delegating each batch to ``BatchRunner.test_batch``.
        Emits a warning and returns early if no test_loader is provided.
        """
        if self.test_loader is None:
            # Warn here (not in the helper) so stacklevel=2 points at the caller.
            warnings.warn("Testing skipped: no test_loader provided.", stacklevel=2)
            return

        self._run_eval_epoch(
            self.test_loader,
            "Test batches",
            self.batch_runner.test_batch,
        )

    def _run_eval_epoch(self, loader, desc, process_batch) -> None:
        """Iterate ``loader`` in eval mode, passing every batch to ``process_batch``.

        Shared by ``validate`` and ``test`` so both phases apply identical
        grad-mode and network-mode handling.

        Args:
            loader: The DataLoader to iterate over.
            desc: Description shown on the progress bar.
            process_batch: Callable invoked once per batch with that batch.
        """
        # Grad tracking is only kept on during evaluation when explicitly requested.
        with set_grad_enabled(self.network_uses_grad):
            self.network.eval()

            for batch in tqdm(
                loader,
                desc=desc,
                leave=False,
                disable=self.disable_progress_bar,
            ):
                process_batch(batch)