congrads 1.1.2-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +0 -17
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +1 -1
- congrads/constraints/base.py +174 -0
- congrads/{constraints.py → constraints/registry.py} +120 -158
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +170 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/descriptor.py +1 -1
- congrads/metrics.py +1 -1
- congrads/transformations/base.py +37 -0
- congrads/{transformations.py → transformations/registry.py} +3 -33
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +194 -0
- {congrads-1.1.2.dist-info → congrads-1.2.0.dist-info}/METADATA +1 -1
- congrads-1.2.0.dist-info/RECORD +23 -0
- congrads/core.py +0 -773
- congrads/utils.py +0 -1078
- congrads-1.1.2.dist-info/RECORD +0 -14
- /congrads/{datasets.py → datasets/registry.py} +0 -0
- /congrads/{networks.py → networks/registry.py} +0 -0
- {congrads-1.1.2.dist-info → congrads-1.2.0.dist-info}/WHEEL +0 -0
congrads/core/congradscore.py
@@ -0,0 +1,271 @@
+"""This module provides the core CongradsCore class for the main training functionality.
+
+It is designed to integrate constraint-guided optimization into neural network training.
+It extends traditional training processes by enforcing specific constraints
+on the model's outputs, ensuring that the network satisfies domain-specific
+requirements during both training and evaluation.
+
+The `CongradsCore` class serves as the central engine for managing the
+training, validation, and testing phases of a neural network model,
+incorporating constraints that influence the loss function and model updates.
+The model is trained with standard loss functions while also incorporating
+constraint-based adjustments, which are tracked and logged
+throughout the process.
+
+Key features:
+- Support for various constraints that can influence the training process.
+- Integration with PyTorch's `DataLoader` for efficient batch processing.
+- Metric management for tracking loss and constraint satisfaction.
+- Checkpoint management for saving and evaluating model states.
+
+"""
+
+from collections.abc import Callable
+
+import torch
+from torch import Tensor, sum
+from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from congrads.utils.utility import LossWrapper
+
+from ..callbacks.base import CallbackManager
+from ..checkpoints import CheckpointManager
+from ..constraints.base import Constraint
+from ..core.batch_runner import BatchRunner
+from ..core.constraint_engine import ConstraintEngine
+from ..core.epoch_runner import EpochRunner
+from ..descriptor import Descriptor
+from ..metrics import MetricManager
+
+
+class CongradsCore:
+    """The CongradsCore class is the central training engine for constraint-guided optimization.
+
+    It integrates standard neural network training
+    with additional constraint-driven adjustments to the loss function, ensuring
+    that the network satisfies domain-specific constraints during training.
+    """
+
+    def __init__(
+        self,
+        descriptor: Descriptor,
+        constraints: list[Constraint],
+        network: Module,
+        criterion: _Loss,
+        optimizer: Optimizer,
+        device: torch.device,
+        dataloader_train: DataLoader,
+        dataloader_valid: DataLoader | None = None,
+        dataloader_test: DataLoader | None = None,
+        metric_manager: MetricManager | None = None,
+        callback_manager: CallbackManager | None = None,
+        checkpoint_manager: CheckpointManager | None = None,
+        network_uses_grad: bool = False,
+        epsilon: float = 1e-6,
+        constraint_aggregator: Callable[..., Tensor] = sum,
+        enforce_all: bool = True,
+        disable_progress_bar_epoch: bool = False,
+        disable_progress_bar_batch: bool = False,
+        epoch_runner_cls: type["EpochRunner"] | None = None,
+        batch_runner_cls: type["BatchRunner"] | None = None,
+        constraint_engine_cls: type["ConstraintEngine"] | None = None,
+    ):
+        """Initialize the CongradsCore object.
+
+        Args:
+            descriptor (Descriptor): Describes variable layers in the network.
+            constraints (list[Constraint]): List of constraints to guide training.
+            network (Module): The neural network model to train.
+            criterion (_Loss): The loss function used for
+                training and validation.
+            optimizer (Optimizer): The optimizer used for updating model parameters.
+            device (torch.device): The device (e.g., CPU or GPU) for computations.
+            dataloader_train (DataLoader): DataLoader for training data.
+            dataloader_valid (DataLoader, optional): DataLoader for validation data.
+                If not provided, validation is skipped.
+            dataloader_test (DataLoader, optional): DataLoader for test data.
+                If not provided, testing is skipped.
+            metric_manager (MetricManager, optional): Manages metric tracking and recording.
+            callback_manager (CallbackManager, optional): Manages training callbacks.
+            checkpoint_manager (CheckpointManager, optional): Manages
+                checkpointing. If not set, no checkpointing is done.
+            network_uses_grad (bool, optional): A flag indicating if the network
+                contains gradient calculation computations. Default is False.
+            epsilon (float, optional): A small value to avoid division by zero
+                in gradient calculations. Default is 1e-6.
+            constraint_aggregator (Callable[..., Tensor], optional): A function
+                to aggregate the constraint rescale loss. Default is `sum`.
+            enforce_all (bool, optional): If set to False, constraints are only monitored and
+                do not influence the training process. Overrides constraint-specific
+                `enforce` parameters. Defaults to True.
+            disable_progress_bar_epoch (bool, optional): If set to True, the epoch
+                progress bar is not shown. Defaults to False.
+            disable_progress_bar_batch (bool, optional): If set to True, the batch
+                progress bar is not shown. Defaults to False.
+            epoch_runner_cls (type[EpochRunner], optional): Custom EpochRunner class.
+                If not provided, the default EpochRunner is used.
+            batch_runner_cls (type[BatchRunner], optional): Custom BatchRunner class.
+                If not provided, the default BatchRunner is used.
+            constraint_engine_cls (type[ConstraintEngine], optional): Custom ConstraintEngine class.
+                If not provided, the default ConstraintEngine is used.
+
+        Note:
+            A warning is logged if the descriptor has no variable layers,
+            as at least one variable layer is required for the constraint logic
+            to influence the training process.
+        """
+        # Init object variables
+        self.device = device
+        self.network = network.to(device)
+        self.criterion = LossWrapper(criterion)
+        self.optimizer = optimizer
+        self.descriptor = descriptor
+
+        self.constraints = constraints or []
+        self.epsilon = epsilon
+        self.aggregator = constraint_aggregator
+        self.enforce_all = enforce_all
+
+        self.metric_manager = metric_manager
+        self.callback_manager = callback_manager
+        self.checkpoint_manager = checkpoint_manager
+
+        self.disable_progress_bar_epoch = disable_progress_bar_epoch
+        self.disable_progress_bar_batch = disable_progress_bar_batch
+
+        # Initialize constraint engine
+        self.constraint_engine = (constraint_engine_cls or ConstraintEngine)(
+            constraints=self.constraints,
+            descriptor=self.descriptor,
+            device=self.device,
+            epsilon=self.epsilon,
+            aggregator=self.aggregator,
+            enforce_all=self.enforce_all,
+            metric_manager=self.metric_manager,
+        )
+
+        # Initialize runners
+        self.batch_runner = (batch_runner_cls or BatchRunner)(
+            network=self.network,
+            criterion=self.criterion,
+            optimizer=self.optimizer,
+            constraint_engine=self.constraint_engine,
+            metric_manager=self.metric_manager,
+            callback_manager=self.callback_manager,
+            device=self.device,
+        )
+
+        self.epoch_runner = (epoch_runner_cls or EpochRunner)(
+            batch_runner=self.batch_runner,
+            network=self.network,
+            train_loader=dataloader_train,
+            valid_loader=dataloader_valid,
+            test_loader=dataloader_test,
+            network_uses_grad=network_uses_grad,
+            disable_progress_bar=self.disable_progress_bar_batch,
+        )
+
+        # Initialize constraint metrics
+        if self.metric_manager is not None:
+            self._initialize_metrics()
+
+    def _initialize_metrics(self) -> None:
+        """Register metrics for loss, constraint satisfaction rate (CSR), and constraints.
+
+        This method registers the following metrics:
+
+        - Loss/train: Training loss.
+        - Loss/valid: Validation loss.
+        - Loss/test: Test loss after training.
+        - CSR/train: Constraint satisfaction rate during training.
+        - CSR/valid: Constraint satisfaction rate during validation.
+        - CSR/test: Constraint satisfaction rate after training.
+        - One metric per constraint, for both training and validation.
+
+        """
+        self.metric_manager.register("Loss/train", "during_training")
+        self.metric_manager.register("Loss/valid", "during_training")
+        self.metric_manager.register("Loss/test", "after_training")
+
+        if len(self.constraints) > 0:
+            self.metric_manager.register("CSR/train", "during_training")
+            self.metric_manager.register("CSR/valid", "during_training")
+            self.metric_manager.register("CSR/test", "after_training")
+
+        for constraint in self.constraints:
+            self.metric_manager.register(f"{constraint.name}/train", "during_training")
+            self.metric_manager.register(f"{constraint.name}/valid", "during_training")
+            self.metric_manager.register(f"{constraint.name}/test", "after_training")
+
+    def fit(
+        self,
+        *,
+        start_epoch: int = 0,
+        max_epochs: int = 100,
+        test_model: bool = True,
+        final_checkpoint_name: str = "checkpoint_final.pth",
+    ) -> None:
+        """Run the full training loop, including optional validation, testing, and checkpointing.
+
+        This method performs training over multiple epochs with the following steps:
+        1. Trigger "on_train_start" callbacks if a callback manager is present.
+        2. For each epoch:
+            - Trigger "on_epoch_start" callbacks.
+            - Run a training epoch via the EpochRunner.
+            - Run a validation epoch via the EpochRunner.
+            - Evaluate checkpoint criteria if a checkpoint manager is present.
+            - Trigger "on_epoch_end" callbacks.
+        3. Trigger "on_train_end" callbacks after all epochs.
+        4. Optionally run a test epoch via the EpochRunner if `test_model` is True,
+           with corresponding "on_test_start" and "on_test_end" callbacks.
+        5. Save a final checkpoint using the checkpoint manager.
+
+        Args:
+            start_epoch: Index of the starting epoch (default 0). Useful for resuming training.
+            max_epochs: Maximum number of epochs to run (default 100).
+            test_model: Whether to run a test epoch after training (default True).
+            final_checkpoint_name: Filename for the final checkpoint saved at the end of training
+                (default "checkpoint_final.pth").
+
+        Returns:
+            None
+        """
+        if self.callback_manager:
+            self.callback_manager.run("on_train_start", {"epoch": start_epoch})
+
+        for epoch in tqdm(
+            range(start_epoch, max_epochs),
+            initial=start_epoch,
+            desc="Epoch",
+            disable=self.disable_progress_bar_epoch,
+        ):
+            if self.callback_manager:
+                self.callback_manager.run("on_epoch_start", {"epoch": epoch})
+
+            self.epoch_runner.train()
+            self.epoch_runner.validate()
+
+            if self.checkpoint_manager:
+                self.checkpoint_manager.evaluate_criteria(epoch)
+
+            if self.callback_manager:
+                self.callback_manager.run("on_epoch_end", {"epoch": epoch})
+
+        if self.callback_manager:
+            self.callback_manager.run("on_train_end", {"epoch": epoch})
+
+        if test_model:
+            if self.callback_manager:
+                self.callback_manager.run("on_test_start", {"epoch": epoch})
+
+            self.epoch_runner.test()
+
+            if self.callback_manager:
+                self.callback_manager.run("on_test_end", {"epoch": epoch})
+
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(epoch, final_checkpoint_name)
congrads/core/constraint_engine.py
@@ -0,0 +1,170 @@
+"""Manages the evaluation and optional enforcement of constraints on neural network outputs.
+
+Responsibilities:
+- Compute and log the Constraint Satisfaction Rate (CSR) for training, validation, and test batches.
+- Optionally adjust the loss during training based on constraint directions and rescale factors.
+- Handle gradient computation and CGGD application.
+"""
+
+import torch
+from torch import Tensor, no_grad
+from torch.linalg import vector_norm
+
+from ..constraints.base import Constraint
+from ..descriptor import Descriptor
+from ..metrics import MetricManager
+
+
+class ConstraintEngine:
+    """Manages constraint evaluation and enforcement for a neural network.
+
+    The ConstraintEngine coordinates constraints defined in Constraint objects,
+    computes gradients for layers that affect the loss, logs metrics, and optionally
+    modifies the loss during training according to the constraints. It supports
+    separate phases for training, validation, and testing.
+    """
+
+    def __init__(
+        self,
+        *,
+        constraints: list[Constraint],
+        descriptor: Descriptor,
+        metric_manager: MetricManager,
+        device: torch.device,
+        epsilon: float,
+        aggregator: callable,
+        enforce_all: bool,
+    ) -> None:
+        """Initialize the ConstraintEngine.
+
+        Args:
+            constraints: List of Constraint objects to evaluate and optionally enforce.
+            descriptor: Descriptor containing metadata about network layers and which
+                variables affect the loss.
+            metric_manager: MetricManager instance for logging CSR metrics.
+            device: Torch device where tensors will be allocated (CPU or GPU).
+            epsilon: Small positive value to avoid division by zero in gradient norms.
+            aggregator: Callable used to reduce per-layer constraint contributions
+                to a scalar loss adjustment.
+            enforce_all: Whether to enforce all constraints during training.
+        """
+        self.constraints = constraints
+        self.descriptor = descriptor
+        self.metric_manager = metric_manager
+        self.device = device
+        self.epsilon = epsilon
+        self.enforce_all = enforce_all
+        self.aggregator = aggregator
+
+    def train(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Apply all active constraints during training.
+
+        Computes the original loss gradients for layers that affect the loss,
+        evaluates each constraint, logs the Constraint Satisfaction Rate (CSR),
+        and adjusts the loss according to constraint satisfaction.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The loss tensor after applying constraint-based adjustments.
+        """
+        return self._apply_constraints(data, loss, phase="train", enforce=True)
+
+    def validate(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Evaluate constraints during validation without modifying the loss.
+
+        Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+        but does not apply rescale adjustments to the loss.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The original loss tensor, unchanged.
+        """
+        return self._apply_constraints(data, loss, phase="valid", enforce=False)
+
+    def test(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Evaluate constraints during testing without modifying the loss.
+
+        Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+        but does not apply rescale adjustments to the loss.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The original loss tensor, unchanged.
+        """
+        return self._apply_constraints(data, loss, phase="test", enforce=False)
+
+    def _apply_constraints(
+        self, data: dict[str, Tensor], loss: Tensor, phase: str, enforce: bool
+    ) -> Tensor:
+        """Evaluate constraints, log CSR, and optionally adjust the loss.
+
+        During training, computes loss gradients for variable layers that affect the loss.
+        Iterates over all constraints, logging the Constraint Satisfaction Rate (CSR)
+        and, if enforcement is enabled, adjusts the loss using constraint-specific
+        directions and rescale factors.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: Original loss tensor computed from the network output.
+            phase: Current phase, one of "train", "valid", or "test".
+            enforce: If True, constraint-based adjustments are applied to the loss.
+
+        Returns:
+            Tensor: The combined loss after applying constraints (or the original loss
+            if enforce is False or not in the training phase).
+        """
+        total_rescale_loss = torch.tensor(0.0, device=self.device, dtype=loss.dtype)
+
+        # Precompute gradients for variable layers affecting the loss
+        if phase == "train":
+            norm_loss_grad = {}
+            variable_keys = self.descriptor.variable_keys & self.descriptor.affects_loss_keys
+
+            for key in variable_keys:
+                grad = torch.autograd.grad(
+                    outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
+                )[0]
+
+                if grad is None:
+                    raise RuntimeError(
+                        f"Unable to compute loss gradients for layer '{key}'. "
+                        "Set has_loss=False in Descriptor if this layer does not affect loss."
+                    )
+
+                grad_flat = grad.view(grad.shape[0], -1)
+                norm_loss_grad[key] = (
+                    vector_norm(grad_flat, dim=1, ord=2, keepdim=True)
+                    .clamp(min=self.epsilon)
+                    .detach()
+                )
+
+        # Iterate constraints
+        for constraint in self.constraints:
+            checks, mask = constraint.check_constraint(data)
+            directions = constraint.calculate_direction(data)
+
+            # Log CSR
+            csr = (torch.sum(checks * mask) / torch.sum(mask)).unsqueeze(0)
+            self.metric_manager.accumulate(f"{constraint.name}/{phase}", csr)
+            self.metric_manager.accumulate(f"CSR/{phase}", csr)
+
+            # Skip adjustment if not enforcing
+            if not enforce or not constraint.enforce or not self.enforce_all or phase != "train":
+                continue
+
+            # Compute constraint-based rescale loss
+            for key in constraint.layers & self.descriptor.variable_keys:
+                with no_grad():
+                    rescale = (1 - checks) * directions[key] * constraint.rescale_factor
+                total_rescale_loss += self.aggregator(data[key] * rescale * norm_loss_grad[key])
+
+        return loss + total_rescale_loss
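In short, for every violated sample `_apply_constraints` adds `data[key] * direction * rescale_factor * ||∂loss/∂data[key]||` to the loss, so the adjustment scales with the magnitude of the original loss gradient. The engine touches only a small surface of the Constraint API: `name`, `enforce`, `rescale_factor`, `layers`, `check_constraint`, and `calculate_direction`. The following hypothetical non-negativity constraint is written against just that surface; the library's real base class in congrads/constraints/base.py is not shown in full in this diff and may differ.

# Hypothetical constraint sketch: "the 'output' layer must be non-negative".
# Attribute and method names mirror only what ConstraintEngine reads above.
import torch
from torch import Tensor


class NonNegativeOutput:
    name = "NonNegativeOutput"
    enforce = True
    rescale_factor = 1.0
    layers = {"output"}  # intersected with descriptor.variable_keys by the engine

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        # checks: 1.0 where the constraint holds, 0.0 where it is violated;
        # mask: which entries the constraint applies to (all of them here).
        checks = (data["output"] >= 0).float()
        mask = torch.ones_like(checks)
        return checks, mask

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        # A negative direction makes the added loss term decrease as the
        # output grows, so gradient descent nudges violating outputs upward.
        return {"output": -torch.ones_like(data["output"])}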
congrads/core/epoch_runner.py
@@ -0,0 +1,119 @@
+"""Defines the EpochRunner class for running full training, validation, and test epochs.
+
+This module handles:
+- Switching the network between training and evaluation modes
+- Iterating over DataLoaders with optional progress bars
+- Delegating per-batch processing to a BatchRunner instance
+- Optional gradient tracking control for evaluation phases
+- Warnings when validation or test loaders are not provided
+"""
+
+import warnings
+
+from torch import set_grad_enabled
+from torch.nn import Module
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from ..core.batch_runner import BatchRunner
+
+
+class EpochRunner:
+    """Runs full epochs over DataLoaders.
+
+    Responsibilities:
+    - Model mode switching
+    - Iteration over DataLoaders
+    - Delegation to the BatchRunner
+    - Progress bars
+    """
+
+    def __init__(
+        self,
+        network: Module,
+        batch_runner: BatchRunner,
+        train_loader: DataLoader,
+        valid_loader: DataLoader | None = None,
+        test_loader: DataLoader | None = None,
+        *,
+        network_uses_grad: bool = False,
+        disable_progress_bar: bool = False,
+    ):
+        """Initialize the EpochRunner.
+
+        Args:
+            network: The neural network module to train/validate/test.
+            batch_runner: The BatchRunner instance for processing batches.
+            train_loader: DataLoader for training data.
+            valid_loader: DataLoader for validation data. Defaults to None.
+            test_loader: DataLoader for test data. Defaults to None.
+            network_uses_grad: Whether the network uses gradient computation. Defaults to False.
+            disable_progress_bar: Whether to disable progress bar display. Defaults to False.
+        """
+        self.network = network
+        self.batch_runner = batch_runner
+        self.train_loader = train_loader
+        self.valid_loader = valid_loader
+        self.test_loader = test_loader
+        self.network_uses_grad = network_uses_grad
+        self.disable_progress_bar = disable_progress_bar
+
+    def train(self) -> None:
+        """Run a training epoch over the training DataLoader.
+
+        Sets the network to training mode and iterates over batches,
+        delegating each batch to the BatchRunner for processing.
+        """
+        self.network.train()
+
+        for batch in tqdm(
+            self.train_loader,
+            desc="Training batches",
+            leave=False,
+            disable=self.disable_progress_bar,
+        ):
+            self.batch_runner.train_batch(batch)
+
+    def validate(self) -> None:
+        """Run a validation epoch over the validation DataLoader.
+
+        Sets the network to evaluation mode and iterates over batches,
+        delegating each batch to the BatchRunner for processing.
+        Skips validation if no valid_loader is provided.
+        """
+        if self.valid_loader is None:
+            warnings.warn("Validation skipped: no valid_loader provided.", stacklevel=2)
+            return
+
+        with set_grad_enabled(self.network_uses_grad):
+            self.network.eval()
+
+            for batch in tqdm(
+                self.valid_loader,
+                desc="Validation batches",
+                leave=False,
+                disable=self.disable_progress_bar,
+            ):
+                self.batch_runner.valid_batch(batch)
+
+    def test(self) -> None:
+        """Run a test epoch over the test DataLoader.
+
+        Sets the network to evaluation mode and iterates over batches,
+        delegating each batch to the BatchRunner for processing.
+        Skips testing if no test_loader is provided.
+        """
+        if self.test_loader is None:
+            warnings.warn("Testing skipped: no test_loader provided.", stacklevel=2)
+            return
+
+        with set_grad_enabled(self.network_uses_grad):
+            self.network.eval()
+
+            for batch in tqdm(
+                self.test_loader,
+                desc="Test batches",
+                leave=False,
+                disable=self.disable_progress_bar,
+            ):
+                self.batch_runner.test_batch(batch)
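Because `CongradsCore` accepts runner classes rather than instances (`epoch_runner_cls`, `batch_runner_cls`, `constraint_engine_cls`), behaviour can be customised by subclassing. A small illustrative sketch, assuming nothing beyond the constructor signature above:

# Sketch of the epoch_runner_cls extension point: a subclass that times each
# training epoch. Only train() is overridden; __init__ is inherited unchanged,
# so CongradsCore can instantiate it with its usual keyword arguments.
import time

from congrads.core.epoch_runner import EpochRunner


class TimedEpochRunner(EpochRunner):
    def train(self) -> None:
        start = time.perf_counter()
        super().train()
        print(f"Training epoch took {time.perf_counter() - start:.2f}s")


# Passed to CongradsCore as a class, not an instance:
# core = CongradsCore(..., epoch_runner_cls=TimedEpochRunner)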
congrads/descriptor.py CHANGED
congrads/metrics.py CHANGED
congrads/transformations/base.py
@@ -0,0 +1,37 @@
+"""Module defining transformations and components."""
+
+from abc import ABC, abstractmethod
+
+from torch import Tensor
+
+from ..utils.validation import validate_type
+
+
+class Transformation(ABC):
+    """Abstract base class for tag data transformations."""
+
+    def __init__(self, tag: str):
+        """Initialize a Transformation.
+
+        Args:
+            tag (str): Tag this transformation applies to.
+        """
+        validate_type("tag", tag, str)
+
+        super().__init__()
+        self.tag = tag
+
+    @abstractmethod
+    def __call__(self, data: Tensor) -> Tensor:
+        """Apply the transformation to the input tensor.
+
+        Args:
+            data (Tensor): Input tensor representing network data.
+
+        Returns:
+            Tensor: Transformed tensor.
+
+        Raises:
+            NotImplementedError: Must be implemented by subclasses.
+        """
+        raise NotImplementedError
congrads/{transformations.py → transformations/registry.py}
@@ -1,41 +1,11 @@
-"""Module
+"""Module holding specific transformation implementations."""
 
-from abc import ABC, abstractmethod
 from numbers import Number
 
 from torch import Tensor
 
-from .
-
-
-class Transformation(ABC):
-    """Abstract base class for tag data transformations."""
-
-    def __init__(self, tag: str):
-        """Initialize a Transformation.
-
-        Args:
-            tag (str): Tag this transformation applies to.
-        """
-        validate_type("tag", tag, str)
-
-        super().__init__()
-        self.tag = tag
-
-    @abstractmethod
-    def __call__(self, data: Tensor) -> Tensor:
-        """Apply the transformation to the input tensor.
-
-        Args:
-            data (Tensor): Input tensor representing network data.
-
-        Returns:
-            Tensor: Transformed tensor.
-
-        Raises:
-            NotImplementedError: Must be implemented by subclasses.
-        """
-        raise NotImplementedError
+from ..utils.validation import validate_callable, validate_type
+from .base import Transformation
 
 
 class IdentityTransformation(Transformation):
|