congrads 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -21
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +148 -29
- congrads/metrics.py +109 -19
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/{utils.py → utils/preprocessors.py} +201 -72
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -389
- congrads/core.py +0 -225
- congrads/datasets.py +0 -195
- congrads/networks.py +0 -90
- congrads-0.2.0.dist-info/LICENSE +0 -26
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- congrads-0.2.0.dist-info/WHEEL +0 -5
- congrads-0.2.0.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""Defines the EpochRunner class for running full training, validation, and test epochs.
|
|
2
|
+
|
|
3
|
+
This module handles:
|
|
4
|
+
- Switching the network between training and evaluation modes
|
|
5
|
+
- Iterating over DataLoaders with optional progress bars
|
|
6
|
+
- Delegating per-batch processing to a BatchRunner instance
|
|
7
|
+
- Optional gradient tracking control for evaluation phases
|
|
8
|
+
- Warnings when validation or test loaders are not provided
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import warnings
|
|
12
|
+
|
|
13
|
+
from torch import set_grad_enabled
|
|
14
|
+
from torch.nn import Module
|
|
15
|
+
from torch.utils.data import DataLoader
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
|
|
18
|
+
from ..core.batch_runner import BatchRunner
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class EpochRunner:
    """Runs full epochs over DataLoaders.

    Responsibilities:
    - Model mode switching (``train()`` / ``eval()``)
    - Iteration over DataLoaders
    - Delegation of per-batch work to a BatchRunner
    - Progress bars
    """

    def __init__(
        self,
        network: Module,
        batch_runner: BatchRunner,
        train_loader: DataLoader,
        valid_loader: DataLoader | None = None,
        test_loader: DataLoader | None = None,
        *,
        network_uses_grad: bool = False,
        disable_progress_bar: bool = False,
    ):
        """Initialize the EpochRunner.

        Args:
            network: The neural network module to train/validate/test.
            batch_runner: The BatchRunner instance that processes each batch.
            train_loader: DataLoader for training data.
            valid_loader: DataLoader for validation data. Defaults to None,
                in which case ``validate`` emits a warning and does nothing.
            test_loader: DataLoader for test data. Defaults to None, in which
                case ``test`` emits a warning and does nothing.
            network_uses_grad: Whether autograd stays enabled while running
                validation and test epochs (the value is passed to
                ``torch.set_grad_enabled``). Training epochs are not affected
                by this flag. Defaults to False.
            disable_progress_bar: Whether to hide the tqdm progress bars.
                Defaults to False.
        """
        self.network = network
        self.batch_runner = batch_runner
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader
        self.network_uses_grad = network_uses_grad
        self.disable_progress_bar = disable_progress_bar

    def train(self) -> None:
        """Run a training epoch over the training DataLoader.

        Sets the network to training mode and iterates over batches,
        delegating each batch to ``BatchRunner.train_batch``.
        """
        self.network.train()

        for batch in tqdm(
            self.train_loader,
            desc="Training batches",
            leave=False,
            disable=self.disable_progress_bar,
        ):
            self.batch_runner.train_batch(batch)

    def validate(self) -> None:
        """Run a validation epoch over the validation DataLoader.

        Sets the network to evaluation mode and iterates over batches,
        delegating each batch to ``BatchRunner.valid_batch``. Autograd is
        enabled only if ``network_uses_grad`` is True. Skips validation
        (with a warning) if no valid_loader was provided.
        """
        if self.valid_loader is None:
            # Warn here rather than in the helper so stacklevel=2 points at
            # the code that called validate().
            warnings.warn("Validation skipped: no valid_loader provided.", stacklevel=2)
            return

        self._run_eval_epoch(
            self.valid_loader,
            "Validation batches",
            self.batch_runner.valid_batch,
        )

    def test(self) -> None:
        """Run a test epoch over the test DataLoader.

        Sets the network to evaluation mode and iterates over batches,
        delegating each batch to ``BatchRunner.test_batch``. Autograd is
        enabled only if ``network_uses_grad`` is True. Skips testing
        (with a warning) if no test_loader was provided.
        """
        if self.test_loader is None:
            # Warn here rather than in the helper so stacklevel=2 points at
            # the code that called test().
            warnings.warn("Testing skipped: no test_loader provided.", stacklevel=2)
            return

        self._run_eval_epoch(
            self.test_loader,
            "Test batches",
            self.batch_runner.test_batch,
        )

    def _run_eval_epoch(self, loader: DataLoader, desc: str, process_batch) -> None:
        """Run one evaluation-mode pass over ``loader``.

        Shared by ``validate`` and ``test``.

        Args:
            loader: DataLoader to iterate over.
            desc: Label shown on the progress bar.
            process_batch: Callable invoked once per batch with the batch
                object yielded by ``loader``.
        """
        # Grad tracking is restored to its previous state when the block exits.
        with set_grad_enabled(self.network_uses_grad):
            self.network.eval()

            for batch in tqdm(
                loader,
                desc=desc,
                leave=False,
                disable=self.disable_progress_bar,
            ):
                process_batch(batch)
|