congrads 1.1.2-py3-none-any.whl → 1.2.0-py3-none-any.whl
- congrads/__init__.py +0 -17
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +1 -1
- congrads/constraints/base.py +174 -0
- congrads/{constraints.py → constraints/registry.py} +120 -158
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +170 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/descriptor.py +1 -1
- congrads/metrics.py +1 -1
- congrads/transformations/base.py +37 -0
- congrads/{transformations.py → transformations/registry.py} +3 -33
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +194 -0
- {congrads-1.1.2.dist-info → congrads-1.2.0.dist-info}/METADATA +1 -1
- congrads-1.2.0.dist-info/RECORD +23 -0
- congrads/core.py +0 -773
- congrads/utils.py +0 -1078
- congrads-1.1.2.dist-info/RECORD +0 -14
- /congrads/{datasets.py → datasets/registry.py} +0 -0
- /congrads/{networks.py → networks/registry.py} +0 -0
- {congrads-1.1.2.dist-info → congrads-1.2.0.dist-info}/WHEEL +0 -0
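The headline change in 1.2.0 is the removal of the monolithic `congrads/core.py` (shown below) and `congrads/utils.py`, whose responsibilities move into the new `congrads/core/`, `congrads/callbacks/`, and `congrads/utils/` submodules. The deleted training loop is built around a dict-in/dict-out batch contract: the network's `forward` receives the whole batch dictionary and must return it with an `"output"` entry, and the (wrapped) criterion is called as `criterion(output, target, data=data)`. The following is a minimal, self-contained sketch of that contract in plain PyTorch; the `"input"` key and the toy shapes are illustrative assumptions — only `"output"` and `"target"` appear in the diff below.

```python
import torch
from torch import Tensor, nn


class DictNet(nn.Module):
    """Toy network following the dict-in/dict-out convention used by CongradsCore."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        # Add predictions to the batch dict under the "output" key.
        data["output"] = self.linear(data["input"])  # "input" is a hypothetical key
        return data


def criterion(output: Tensor, target: Tensor, data: dict[str, Tensor]) -> Tensor:
    """Custom loss exposing the (output, target, data=...) signature the core calls."""
    return torch.mean((output - target) ** 2)


batch = {"input": torch.randn(8, 4), "target": torch.randn(8, 2)}
batch = DictNet()(batch)
loss = criterion(batch["output"], batch["target"], data=batch)
print(f"loss: {loss.item():.4f}")
```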
congrads/core.py
DELETED
@@ -1,773 +0,0 @@
-"""This module provides the core CongradsCore class for the main training functionality.
-
-It is designed to integrate constraint-guided optimization into neural network training.
-It extends traditional training processes by enforcing specific constraints
-on the model's outputs, ensuring that the network satisfies domain-specific
-requirements during both training and evaluation.
-
-The `CongradsCore` class serves as the central engine for managing the
-training, validation, and testing phases of a neural network model,
-incorporating constraints that influence the loss function and model updates.
-The model is trained with standard loss functions while also incorporating
-constraint-based adjustments, which are tracked and logged
-throughout the process.
-
-Key features:
-- Support for various constraints that can influence the training process.
-- Integration with PyTorch's `DataLoader` for efficient batch processing.
-- Metric management for tracking loss and constraint satisfaction.
-- Checkpoint management for saving and evaluating model states.
-
-The `CongradsCore` class allows for the use of additional callback functions
-at different stages of the training process to customize behavior for
-specific needs. These include callbacks for the start and end of epochs, as
-well as the start and end of the entire training process.
-
-"""
-
-import warnings
-from collections.abc import Callable
-
-import torch
-from torch import Tensor, float32, no_grad, sum, tensor
-from torch.linalg import vector_norm
-from torch.nn import Module
-from torch.nn.modules.loss import _Loss
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-
-from .checkpoints import CheckpointManager
-from .constraints import Constraint
-from .descriptor import Descriptor
-from .metrics import MetricManager
-from .utils import (
-    is_torch_loss,
-    torch_loss_wrapper,
-    validate_callable,
-    validate_callable_iterable,
-    validate_iterable,
-    validate_loaders,
-    validate_type,
-)
-
-
-class CongradsCore:
-    """The CongradsCore class is the central training engine for constraint-guided optimization.
-
-    It integrates standard neural network training
-    with additional constraint-driven adjustments to the loss function, ensuring
-    that the network satisfies domain-specific constraints during training.
-    """
-
-    def __init__(
-        self,
-        descriptor: Descriptor,
-        constraints: list[Constraint],
-        loaders: tuple[DataLoader, DataLoader, DataLoader],
-        network: Module,
-        criterion: _Loss,
-        optimizer: Optimizer,
-        metric_manager: MetricManager,
-        device: torch.device,
-        network_uses_grad: bool = False,
-        checkpoint_manager: CheckpointManager = None,
-        epsilon: float = 1e-6,
-        constraint_aggregator: Callable[..., Tensor] = sum,
-        disable_progress_bar_epoch: bool = False,
-        disable_progress_bar_batch: bool = False,
-        enforce_all: bool = True,
-    ):
-        """Initialize the CongradsCore object.
-
-        Args:
-            descriptor (Descriptor): Describes variable layers in the network.
-            constraints (list[Constraint]): List of constraints to guide training.
-            loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
-                training, validation, and testing.
-            network (Module): The neural network model to train.
-            criterion (callable): The loss function used for
-                training and validation.
-            optimizer (Optimizer): The optimizer used for updating model parameters.
-            metric_manager (MetricManager): Manages metric tracking and recording.
-            device (torch.device): The device (e.g., CPU or GPU) for computations.
-            network_uses_grad (bool, optional): A flag indicating if the network
-                contains gradient calculation computations. Default is False.
-            checkpoint_manager (CheckpointManager, optional): Manages
-                checkpointing. If not set, no checkpointing is done.
-            epsilon (float, optional): A small value to avoid division by zero
-                in gradient calculations. Default is 1e-6.
-            constraint_aggregator (Callable[..., Tensor], optional): A function
-                to aggregate the constraint rescale loss. Default is `sum`.
-            disable_progress_bar_epoch (bool, optional): If set to True, the epoch
-                progress bar will not show. Defaults to False.
-            disable_progress_bar_batch (bool, optional): If set to True, the batch
-                progress bar will not show. Defaults to False.
-            enforce_all (bool, optional): If set to False, constraints will only be monitored and
-                not influence the training process. Overrides constraint-specific `enforce` parameters.
-                Defaults to True.
-
-        Note:
-            A warning is logged if the descriptor has no variable layers,
-            as at least one variable layer is required for the constraint logic
-            to influence the training process.
-        """
-        # Type checking
-        validate_type("descriptor", descriptor, Descriptor)
-        validate_iterable("constraints", constraints, Constraint, allow_empty=True)
-        validate_loaders("loaders", loaders)
-        validate_type("network", network, Module)
-        validate_type("criterion", criterion, _Loss)
-        validate_type("optimizer", optimizer, Optimizer)
-        validate_type("metric_manager", metric_manager, MetricManager)
-        validate_type("device", device, torch.device)
-        validate_type("network_uses_grad", network_uses_grad, bool)
-        validate_type(
-            "checkpoint_manager",
-            checkpoint_manager,
-            CheckpointManager,
-            allow_none=True,
-        )
-        validate_type("epsilon", epsilon, float)
-        validate_callable("constraint_aggregator", constraint_aggregator, allow_none=True)
-        validate_type("disable_progress_bar_epoch", disable_progress_bar_epoch, bool)
-        validate_type("disable_progress_bar_batch", disable_progress_bar_batch, bool)
-        validate_type("enforce_all", enforce_all, bool)
-
-        # Init object variables
-        self.descriptor = descriptor
-        self.constraints = constraints
-        self.train_loader = loaders[0]
-        self.valid_loader = loaders[1]
-        self.test_loader = loaders[2]
-        self.network = network
-        self.optimizer = optimizer
-        self.metric_manager = metric_manager
-        self.device = device
-        self.network_uses_grad = network_uses_grad
-        self.checkpoint_manager = checkpoint_manager
-        self.epsilon = epsilon
-        self.constraint_aggregator = constraint_aggregator
-        self.disable_progress_bar_epoch = disable_progress_bar_epoch
-        self.disable_progress_bar_batch = disable_progress_bar_batch
-        self.enforce_all = enforce_all
-
-        # Check if criterion is a torch loss function
-        if is_torch_loss(criterion):
-            # If so, wrap it in a custom loss function
-            self.criterion = torch_loss_wrapper(criterion)
-        else:
-            self.criterion = criterion
-
-        # Perform checks
-        if len(self.descriptor.variable_keys) == 0:
-            warnings.warn(
-                "The descriptor object has no variable layers. The constraint \
-                guided loss adjustment is therefore not used. \
-                Is this the intended behavior?",
-                stacklevel=2,
-            )
-
-        # Initialize constraint metrics
-        self._initialize_metrics()
-
-    def _initialize_metrics(self) -> None:
-        """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.
-
-        This method registers the following metrics:
-
-        - Loss/train: Training loss.
-        - Loss/valid: Validation loss.
-        - Loss/test: Test loss after training.
-        - CSR/train: Constraint satisfaction ratio during training.
-        - CSR/valid: Constraint satisfaction ratio during validation.
-        - CSR/test: Constraint satisfaction ratio after training.
-        - One metric per constraint, for both training and validation.
-
-        """
-        self.metric_manager.register("Loss/train", "during_training")
-        self.metric_manager.register("Loss/valid", "during_training")
-        self.metric_manager.register("Loss/test", "after_training")
-
-        if len(self.constraints) > 0:
-            self.metric_manager.register("CSR/train", "during_training")
-            self.metric_manager.register("CSR/valid", "during_training")
-            self.metric_manager.register("CSR/test", "after_training")
-
-        for constraint in self.constraints:
-            self.metric_manager.register(f"{constraint.name}/train", "during_training")
-            self.metric_manager.register(f"{constraint.name}/valid", "during_training")
-            self.metric_manager.register(f"{constraint.name}/test", "after_training")
-
-    def fit(
-        self,
-        start_epoch: int = 0,
-        max_epochs: int = 100,
-        test_model: bool = True,
-        on_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_train_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_train_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_valid_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_valid_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_test_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_test_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-        on_epoch_start: list[Callable[[int], None]] | None = None,
-        on_epoch_end: list[Callable[[int], None]] | None = None,
-        on_train_start: list[Callable[[int], None]] | None = None,
-        on_train_end: list[Callable[[int], None]] | None = None,
-        on_train_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
-        | None = None,
-        on_val_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
-        | None = None,
-        on_test_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
-        | None = None,
-        on_test_start: list[Callable[[int], None]] | None = None,
-        on_test_end: list[Callable[[int], None]] | None = None,
-    ) -> None:
-        """Train the model over multiple epochs with optional validation and testing.
-
-        This method manages the full training loop, including:
-
-        - Executing epoch-level and batch-level callbacks.
-        - Training and validating the model each epoch.
-        - Adjusting losses according to constraints.
-        - Logging metrics via the metric manager.
-        - Optional evaluation on the test set.
-        - Checkpointing the model during and after training.
-
-        Args:
-            start_epoch (int, optional): Epoch number to start training from. Defaults to 0.
-            max_epochs (int, optional): Total number of epochs to train. Defaults to 100.
-            test_model (bool, optional): If True, evaluate the model on the test set after training. Defaults to True.
-            on_batch_start (list[Callable], optional): Callbacks executed at the start of every batch. Defaults to None.
-            on_batch_end (list[Callable], optional): Callbacks executed at the end of every batch. Defaults to None.
-            on_train_batch_start (list[Callable], optional): Callbacks executed at the start of each training batch. Defaults to `on_batch_start` if not provided.
-            on_train_batch_end (list[Callable], optional): Callbacks executed at the end of each training batch. Defaults to `on_batch_end` if not provided.
-            on_valid_batch_start (list[Callable], optional): Callbacks executed at the start of each validation batch. Defaults to `on_batch_start` if not provided.
-            on_valid_batch_end (list[Callable], optional): Callbacks executed at the end of each validation batch. Defaults to `on_batch_end` if not provided.
-            on_test_batch_start (list[Callable], optional): Callbacks executed at the start of each test batch. Defaults to `on_batch_start` if not provided.
-            on_test_batch_end (list[Callable], optional): Callbacks executed at the end of each test batch. Defaults to `on_batch_end` if not provided.
-            on_epoch_start (list[Callable], optional): Callbacks executed at the start of each epoch. Defaults to None.
-            on_epoch_end (list[Callable], optional): Callbacks executed at the end of each epoch. Defaults to None.
-            on_train_start (list[Callable], optional): Callbacks executed before training starts. Defaults to None.
-            on_train_end (list[Callable], optional): Callbacks executed after training ends. Defaults to None.
-            on_train_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during training. Defaults to None.
-            on_val_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during validation. Defaults to None.
-            on_test_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during testing. Defaults to None.
-            on_test_start (list[Callable], optional): Callbacks executed before testing starts. Defaults to None.
-            on_test_end (list[Callable], optional): Callbacks executed after testing ends. Defaults to None.
-
-        Notes:
-            - If phase-specific callbacks (train/valid/test) are not provided, the global `on_batch_start` and `on_batch_end` are used.
-            - Training metrics, loss adjustments, and constraint satisfaction ratios are automatically logged via the metric manager.
-            - The final model checkpoint is saved if a checkpoint manager is configured.
-        """
-        # Type checking
-        validate_type("start_epoch", start_epoch, int)
-        validate_type("max_epochs", max_epochs, int)
-        validate_type("test_model", test_model, bool)
-        validate_callable_iterable("on_batch_start", on_batch_start, allow_none=True)
-        validate_callable_iterable("on_batch_end", on_batch_end, allow_none=True)
-        validate_callable_iterable("on_train_batch_start", on_train_batch_start, allow_none=True)
-        validate_callable_iterable("on_train_batch_end", on_train_batch_end, allow_none=True)
-        validate_callable_iterable("on_valid_batch_start", on_valid_batch_start, allow_none=True)
-        validate_callable_iterable("on_valid_batch_end", on_valid_batch_end, allow_none=True)
-        validate_callable_iterable("on_test_batch_start", on_test_batch_start, allow_none=True)
-        validate_callable_iterable("on_test_batch_end", on_test_batch_end, allow_none=True)
-        validate_callable_iterable("on_epoch_start", on_epoch_start, allow_none=True)
-        validate_callable_iterable("on_epoch_end", on_epoch_end, allow_none=True)
-        validate_callable_iterable("on_train_start", on_train_start, allow_none=True)
-        validate_callable_iterable("on_train_end", on_train_end, allow_none=True)
-        validate_callable_iterable(
-            "on_train_completion_forward_pass",
-            on_train_completion_forward_pass,
-            allow_none=True,
-        )
-        validate_callable_iterable(
-            "on_val_completion_forward_pass",
-            on_val_completion_forward_pass,
-            allow_none=True,
-        )
-        validate_callable_iterable(
-            "on_test_completion_forward_pass",
-            on_test_completion_forward_pass,
-            allow_none=True,
-        )
-        validate_callable_iterable("on_test_start", on_test_start, allow_none=True)
-        validate_callable_iterable("on_test_end", on_test_end, allow_none=True)
-
-        # Use global batch callback if phase-specific callback is unset
-        # Init callbacks as empty list if None
-        on_train_batch_start = on_train_batch_start or on_batch_start or []
-        on_train_batch_end = on_train_batch_end or on_batch_end or []
-        on_valid_batch_start = on_valid_batch_start or on_batch_start or []
-        on_valid_batch_end = on_valid_batch_end or on_batch_end or []
-        on_test_batch_start = on_test_batch_start or on_batch_start or []
-        on_test_batch_end = on_test_batch_end or on_batch_end or []
-        on_batch_start = on_batch_start or []
-        on_batch_end = on_batch_end or []
-        on_epoch_start = on_epoch_start or []
-        on_epoch_end = on_epoch_end or []
-        on_train_start = on_train_start or []
-        on_train_end = on_train_end or []
-        on_train_completion_forward_pass = on_train_completion_forward_pass or []
-        on_val_completion_forward_pass = on_val_completion_forward_pass or []
-        on_test_completion_forward_pass = on_test_completion_forward_pass or []
-        on_test_start = on_test_start or []
-        on_test_end = on_test_end or []
-
-        # Keep track of epoch
-        epoch = start_epoch
-
-        # Execute training start hook if set
-        for callback in on_train_start:
-            callback(epoch)
-
-        for i in tqdm(
-            range(epoch, max_epochs),
-            initial=epoch,
-            desc="Epoch",
-            disable=self.disable_progress_bar_epoch,
-        ):
-            epoch = i
-
-            # Execute epoch start hook if set
-            for callback in on_epoch_start:
-                callback(epoch)
-
-            # Execute training and validation epoch
-            self._train_epoch(
-                on_train_batch_start,
-                on_train_batch_end,
-                on_train_completion_forward_pass,
-            )
-            self._validate_epoch(
-                on_valid_batch_start,
-                on_valid_batch_end,
-                on_val_completion_forward_pass,
-            )
-
-            # Checkpointing
-            if self.checkpoint_manager:
-                self.checkpoint_manager.evaluate_criteria(epoch)
-
-            # Execute epoch end hook if set
-            for callback in on_epoch_end:
-                callback(epoch)
-
-        # Execute training end hook if set
-        for callback in on_train_end:
-            callback(epoch)
-
-        # Evaluate model performance on unseen test set if required
-        if test_model:
-            # Execute test start hook if set
-            for callback in on_test_start:
-                callback(epoch)
-
-            self._test_model(
-                on_test_batch_start,
-                on_test_batch_end,
-                on_test_completion_forward_pass,
-            )
-
-            # Execute test end hook if set
-            for callback in on_test_end:
-                callback(epoch)
-
-        # Save final model
-        if self.checkpoint_manager:
-            self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
-
-    def _train_epoch(
-        self,
-        on_train_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-        on_train_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-        on_train_completion_forward_pass: tuple[
-            Callable[[dict[str, Tensor]], dict[str, Tensor]], ...
-        ],
-    ) -> None:
-        """Perform a single training epoch over all batches.
-
-        This method sets the network to training mode, iterates over the training
-        DataLoader, computes predictions, evaluates losses, applies constraint-based
-        adjustments, performs backpropagation, and updates model parameters. It also
-        supports executing optional callbacks at different stages of the batch
-        processing.
-
-        Args:
-            on_train_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                Callbacks executed at the start of each batch. Each callback receives the
-                data dictionary and returns updated versions.
-            on_train_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                Callbacks executed at the end of each batch. Each callback receives the
-                data dictionary and returns updated versions.
-            on_train_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                Callbacks executed immediately after the forward pass of the batch.
-                Each callback receives the data dictionary and returns updated versions.
-
-        Returns:
-            None
-        """
-        # Set model in training mode
-        self.network.train()
-
-        for data in tqdm(
-            self.train_loader,
-            desc="Training batches",
-            leave=False,
-            disable=self.disable_progress_bar_batch,
-        ):
-            # Transfer batch data to GPU
-            data: dict[str, Tensor] = {key: value.to(self.device) for key, value in data.items()}
-
-            # Execute on batch start callbacks
-            for callback in on_train_batch_start:
-                data = callback(data)
-
-            # Model computations
-            data = self.network(data)
-
-            # Execute on completion forward pass callbacks
-            for callback in on_train_completion_forward_pass:
-                data = callback(data)
-
-            # Calculate loss
-            loss = self.criterion(
-                data["output"],
-                data["target"],
-                data=data,
-            )
-            self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
-
-            # Adjust loss based on constraints
-            combined_loss = self.train_step(
-                data,
-                loss,
-                self.constraints,
-                self.descriptor,
-                self.metric_manager,
-                self.device,
-                constraint_aggregator=self.constraint_aggregator,
-                epsilon=self.epsilon,
-                enforce_all=self.enforce_all,
-            )
-
-            # Backprop
-            self.optimizer.zero_grad()
-            combined_loss.backward(retain_graph=False, inputs=list(self.network.parameters()))
-            self.optimizer.step()
-
-            # Execute on batch end callbacks
-            for callback in on_train_batch_end:
-                data = callback(data)
-
-    def _validate_epoch(
-        self,
-        on_valid_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
-        on_valid_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
-        on_valid_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
-    ) -> None:
-        """Perform a single validation epoch over all batches.
-
-        This method sets the network to evaluation mode, iterates over the validation
-        DataLoader, computes predictions, evaluates losses, and logs constraint
-        satisfaction. Optional callbacks can be executed at the start and end of each
-        batch, as well as after the forward pass.
-
-        Args:
-            on_valid_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
-                Callbacks executed at the start of each validation batch. Each callback
-                receives the data dictionary and returns updated versions.
-            on_valid_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
-                Callbacks executed at the end of each validation batch. Each callback
-                receives the data dictionary and returns updated versions.
-            on_valid_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
-                Callbacks executed immediately after the forward pass of the validation batch.
-                Each callback receives the data dictionary and returns updated versions.
-
-        Returns:
-            None
-        """
-        # Set model in evaluation mode
-        self.network.eval()
-
-        # Enable or disable gradient tracking for validation pass
-        with torch.set_grad_enabled(self.network_uses_grad):
-            # Loop over validation batches
-            for data in tqdm(
-                self.valid_loader,
-                desc="Validation batches",
-                leave=False,
-                disable=self.disable_progress_bar_batch,
-            ):
-                # Transfer batch data to GPU
-                data: dict[str, Tensor] = {
-                    key: value.to(self.device) for key, value in data.items()
-                }
-
-                # Execute on batch start callbacks
-                for callback in on_valid_batch_start:
-                    data = callback(data)
-
-                # Model computations
-                data = self.network(data)
-
-                # Execute on completion forward pass callbacks
-                for callback in on_valid_completion_forward_pass:
-                    data = callback(data)
-
-                # Calculate loss
-                loss = self.criterion(
-                    data["output"],
-                    data["target"],
-                    data=data,
-                )
-                self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
-
-                # Validate constraints
-                self.valid_step(
-                    data,
-                    loss,
-                    self.constraints,
-                    self.metric_manager,
-                )
-
-                # Execute on batch end callbacks
-                for callback in on_valid_batch_end:
-                    data = callback(data)
-
-    def _test_model(
-        self,
-        on_test_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
-        on_test_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
-        on_test_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
-    ) -> None:
-        """Evaluate the model on the test dataset.
-
-        This method sets the network to evaluation mode, iterates over the test
-        DataLoader, computes predictions, evaluates losses, and logs constraint
-        satisfaction. Optional callbacks can be executed at the start and end of
-        each batch, as well as after the forward pass.
-
-        Args:
-            on_test_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
-                Callbacks executed at the start of each test batch. Each callback
-                receives the data dictionary and returns updated versions.
-            on_test_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
-                Callbacks executed at the end of each test batch. Each callback
-                receives the data dictionary and returns updated versions.
-            on_test_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
-                Callbacks executed immediately after the forward pass of the test batch.
-                Each callback receives the data dictionary and returns updated versions.
-
-        Returns:
-            None
-        """
-        # Set model in evaluation mode
-        self.network.eval()
-
-        # Enable or disable gradient tracking for test pass
-        with torch.set_grad_enabled(self.network_uses_grad):
-            # Loop over test batches
-            for data in tqdm(
-                self.test_loader,
-                desc="Test batches",
-                leave=False,
-                disable=self.disable_progress_bar_batch,
-            ):
-                # Transfer batch data to GPU
-                data: dict[str, Tensor] = {
-                    key: value.to(self.device) for key, value in data.items()
-                }
-
-                # Execute on batch start callbacks
-                for callback in on_test_batch_start:
-                    data = callback(data)
-
-                # Model computations
-                data = self.network(data)
-
-                # Execute on completion forward pass callbacks
-                for callback in on_test_completion_forward_pass:
-                    data = callback(data)
-
-                # Calculate loss
-                loss = self.criterion(
-                    data["output"],
-                    data["target"],
-                    data=data,
-                )
-                self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
-
-                # Validate constraints
-                self.test_step(
-                    data,
-                    loss,
-                    self.constraints,
-                    self.metric_manager,
-                )
-
-                # Execute on batch end callbacks
-                for callback in on_test_batch_end:
-                    data = callback(data)
-
-    @staticmethod
-    def train_step(
-        data: dict[str, Tensor],
-        loss: Tensor,
-        constraints: list[Constraint],
-        descriptor: Descriptor,
-        metric_manager: MetricManager,
-        device: torch.device,
-        constraint_aggregator: Callable = torch.sum,
-        epsilon: float = 1e-6,
-        enforce_all: bool = True,
-    ) -> Tensor:
-        """Adjust the training loss based on constraints and compute the combined loss.
-
-        This method calculates the directions in which the network outputs should be
-        adjusted to satisfy constraints, scales these adjustments according to the
-        constraint's rescale factor and gradient norms, and adds the result to the
-        base loss. It also logs the constraint satisfaction ratio (CSR) for monitoring.
-
-        Args:
-            data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
-            loss (Tensor): The base loss computed by the criterion.
-            constraints (list[Constraint]): List of constraints to enforce during training.
-            descriptor (Descriptor): Descriptor containing layer metadata and variable/loss layer info.
-            metric_manager (MetricManager): Metric manager for logging loss and CSR.
-            device (torch.device): Device on which computations are performed.
-            constraint_aggregator (Callable, optional): Function to aggregate per-layer rescaled losses. Defaults to `torch.sum`.
-            epsilon (float, optional): Small value to prevent division by zero in gradient normalization. Defaults to 1e-6.
-            enforce_all (bool, optional): If False, constraints are only monitored and do not influence the loss. Defaults to True.
-
-        Returns:
-            Tensor: The combined loss including the original loss and constraint-based adjustments.
-        """
-        # Init scalar tensor for loss
-        total_rescale_loss = tensor(0, dtype=float32, device=device)
-        norm_loss_grad: dict[str, Tensor] = {}
-
-        # Precalculate loss gradients for each variable layer
-        for key in descriptor.variable_keys & descriptor.affects_loss_keys:
-            # Calculate gradients of loss w.r.t. predictions
-            grad = torch.autograd.grad(
-                outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
-            )[0]
-
-            # If the gradient is None, report an error
-            if grad is None:
-                raise RuntimeError(
-                    f"Unable to compute loss gradients for layer '{key}'. "
-                    "For layers not connected to the loss, set has_loss=False "
-                    "when defining them in the Descriptor."
-                )
-
-            # Flatten batch and compute L2 norm along each item
-            grad_flat = grad.view(grad.shape[0], -1)
-            norm_loss_grad[key] = (
-                vector_norm(grad_flat, dim=1, ord=2, keepdim=True).clamp(min=epsilon).detach()
-            )
-
-        for constraint in constraints:
-            # Check if constraints are satisfied and calculate directions
-            checks, mask = constraint.check_constraint(data)
-            directions = constraint.calculate_direction(data)
-
-            # Log constraint satisfaction ratio
-            csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
-            metric_manager.accumulate(f"{constraint.name}/train", csr)
-            metric_manager.accumulate("CSR/train", csr)
-
-            # Skip the loss adjustment if the constraint is only monitored
-            if not enforce_all or not constraint.enforce:
-                continue
-
-            # Only do direction calculations for variable layers affecting constraint
-            for key in constraint.layers & descriptor.variable_keys:
-                with no_grad():
-                    # Multiply direction modifiers with constraint result
-                    constraint_result = (1 - checks) * directions[key]
-
-                    # Multiply result with rescale factor of constraint
-                    constraint_result *= constraint.rescale_factor
-
-                # Calculate rescale loss
-                total_rescale_loss += constraint_aggregator(
-                    data[key] * constraint_result * norm_loss_grad[key],
-                )
-
-        # Return combined loss
-        return loss + total_rescale_loss
-
-    @staticmethod
-    def valid_step(
-        data: dict[str, Tensor],
-        loss: Tensor,
-        constraints: list[Constraint],
-        metric_manager: MetricManager,
-    ) -> Tensor:
-        """Evaluate constraints during validation and log constraint satisfaction metrics.
-
-        This method checks whether each constraint is satisfied for the given
-        data, computes the constraint satisfaction ratio (CSR),
-        and logs it using the metric manager. The base loss is not modified.
-
-        Args:
-            data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
-            loss (Tensor): The base loss computed by the criterion.
-            constraints (list[Constraint]): List of constraints to evaluate.
-            metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.
-
-        Returns:
-            Tensor: The original, unchanged base loss.
-        """
-        # For each constraint, evaluate and log satisfaction
-        for constraint in constraints:
-            # Check if constraints are satisfied
-            checks, mask = constraint.check_constraint(data)
-
-            # Log constraint satisfaction ratio
-            csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
-            metric_manager.accumulate(f"{constraint.name}/valid", csr)
-            metric_manager.accumulate("CSR/valid", csr)
-
-        # Return original loss
-        return loss
-
-    @staticmethod
-    def test_step(
-        data: dict[str, Tensor],
-        loss: Tensor,
-        constraints: list[Constraint],
-        metric_manager: MetricManager,
-    ) -> Tensor:
-        """Evaluate constraints during testing and log constraint satisfaction metrics.
-
-        This method checks whether each constraint is satisfied for the given
-        data, computes the constraint satisfaction ratio (CSR),
-        and logs it using the metric manager. The base loss is not modified.
-
-        Args:
-            data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
-            loss (Tensor): The base loss computed by the criterion.
-            constraints (list[Constraint]): List of constraints to evaluate.
-            metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.
-
-        Returns:
-            Tensor: The original, unchanged base loss.
-        """
-        # For each constraint, evaluate and log satisfaction
-        for constraint in constraints:
-            # Check if constraints are satisfied
-            checks, mask = constraint.check_constraint(data)
-
-            # Log constraint satisfaction ratio
-            csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
-            metric_manager.accumulate(f"{constraint.name}/test", csr)
-            metric_manager.accumulate("CSR/test", csr)
-
-        # Return original loss
-        return loss