congrads 0.1.0-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +21 -13
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +728 -316
- congrads/core.py +525 -139
- congrads/datasets.py +273 -516
- congrads/descriptor.py +95 -30
- congrads/metrics.py +185 -38
- congrads/networks.py +51 -28
- congrads/requirements.txt +6 -0
- congrads/transformations.py +139 -0
- congrads/utils.py +710 -0
- congrads-1.0.1.dist-info/LICENSE +26 -0
- congrads-1.0.1.dist-info/METADATA +208 -0
- congrads-1.0.1.dist-info/RECORD +16 -0
- {congrads-0.1.0.dist-info → congrads-1.0.1.dist-info}/WHEEL +1 -1
- congrads/learners.py +0 -233
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- {congrads-0.1.0.dist-info → congrads-1.0.1.dist-info}/top_level.txt +0 -0
congrads/core.py
CHANGED
@@ -1,211 +1,597 @@
-
-
-
-
-
-
-
+"""
+This module provides the CongradsCore class, which is designed to integrate
+constraint-guided optimization into neural network training.
+It extends traditional training processes by enforcing specific constraints
+on the model's outputs, ensuring that the network satisfies domain-specific
+requirements during both training and evaluation.
+
+The `CongradsCore` class serves as the central engine for managing the
+training, validation, and testing phases of a neural network model,
+incorporating constraints that influence the loss function and model updates.
+The model is trained with standard loss functions while also incorporating
+constraint-based adjustments, which are tracked and logged
+throughout the process.
+
+Key features:
+- Support for various constraints that can influence the training process.
+- Integration with PyTorch's `DataLoader` for efficient batch processing.
+- Metric management for tracking loss and constraint satisfaction.
+- Checkpoint management for saving and evaluating model states.
+
+Modules in this package provide the following:
+
+- `Descriptor`: Describes variable layers in the network that are
+  subject to constraints.
+- `Constraint`: Defines various constraints, which are used to guide
+  the training process.
+- `MetricManager`: Manages and tracks performance metrics such as loss
+  and constraint satisfaction.
+- `CheckpointManager`: Manages saving and loading model checkpoints
+  during training.
+- Utility functions to validate inputs and configurations.
+
+Dependencies:
+- PyTorch (`torch`)
+- tqdm (for progress tracking)
+
+The `CongradsCore` class allows for the use of additional callback functions
+at different stages of the training process to customize behavior for
+specific needs. These include callbacks for the start and end of epochs, as
+well as the start and end of the entire training process.
+
+"""
+
+import warnings
+from numbers import Number
+from typing import Callable
+
+import torch
+
+# pylint: disable-next=redefined-builtin
+from torch import Tensor, float32, maximum, no_grad, norm, numel, sum, tensor
+from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from .checkpoints import CheckpointManager
 from .constraints import Constraint
-from .metrics import ConstraintSatisfactionRatio
 from .descriptor import Descriptor
+from .metrics import MetricManager
+from .utils import (
+    validate_callable,
+    validate_iterable,
+    validate_loaders,
+    validate_type,
+)
 
 
-class
+class CongradsCore:
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    The CongradsCore class is the central training engine for constraint-guided
+    neural network optimization. It integrates standard neural network training
+    with additional constraint-driven adjustments to the loss function, ensuring
+    that the network satisfies domain-specific constraints during training.
+
+    Args:
+        descriptor (Descriptor): Describes variable layers in the network.
+        constraints (list[Constraint]): List of constraints to guide training.
+        loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
+            training, validation, and testing.
+        network (Module): The neural network model to train.
+        criterion (callable): The loss function used for
+            training and validation.
+        optimizer (Optimizer): The optimizer used for updating model parameters.
+        metric_manager (MetricManager): Manages metric tracking and recording.
+        device (torch.device): The device (e.g., CPU or GPU) for computations.
+        checkpoint_manager (CheckpointManager, optional): Manages
+            checkpointing. If not set, no checkpointing is done.
+        epsilon (Number, optional): A small value to avoid division by zero
+            in gradient calculations. Default is 1e-6.
+
+    Note:
+        A warning is logged if the descriptor has no variable layers,
+        as at least one variable layer is required for the constraint logic
+        to influence the training process.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        descriptor: Descriptor,
+        constraints: list[Constraint],
+        loaders: tuple[DataLoader, DataLoader, DataLoader],
+        network: Module,
+        criterion: _Loss,
+        optimizer: Optimizer,
+        metric_manager: MetricManager,
+        device: torch.device,
+        checkpoint_manager: CheckpointManager = None,
+        epsilon: Number = 1e-6,
+    ):
         """
-
-
-        Args:
-            descriptor (Descriptor): The object that describes the network's layers and neurons, including their categorization.
-            constraints (list[Constraint]): A list of constraints that will guide the optimization process.
-
-        Raises:
-            Warning if there are no variable layers in the descriptor, as constraints will not be applied.
+        Initialize the CongradsCore object.
         """
 
-        #
-
+        # Type checking
+        validate_type("descriptor", descriptor, Descriptor)
+        validate_iterable("constraints", constraints, Constraint)
+        validate_loaders()
+        validate_type("network", network, Module)
+        validate_type("criterion", criterion, _Loss)
+        validate_type("optimizer", optimizer, Optimizer)
+        validate_type("metric_manager", metric_manager, MetricManager)
+        validate_type("device", device, torch.device)
+        validate_type(
+            "checkpoint_manager",
+            checkpoint_manager,
+            CheckpointManager,
+            allow_none=True,
+        )
+        validate_type("epsilon", epsilon, Number)
 
         # Init object variables
         self.descriptor = descriptor
         self.constraints = constraints
+        self.train_loader = loaders[0]
+        self.valid_loader = loaders[1]
+        self.test_loader = loaders[2]
+        self.network = network
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.metric_manager = metric_manager
+        self.device = device
+        self.checkpoint_manager = checkpoint_manager
+
+        # Init epsilon tensor
+        self.epsilon = tensor(epsilon, device=self.device)
 
         # Perform checks
         if len(self.descriptor.variable_layers) == 0:
-
-            "The descriptor object has no variable layers. The constraint
+            warnings.warn(
+                "The descriptor object has no variable layers. The constraint \
+                guided loss adjustment is therefore not used. \
+                Is this the intended behavior?"
             )
 
-        #
+        # Initialize constraint metrics
+        self._initialize_metrics()
+
+    def _initialize_metrics(self) -> None:
+        """
+        Register metrics for loss, constraint satisfaction ratio (CSR),
+        and individual constraints.
+
+        This method registers the following metrics:
+
+        - Loss/train: Training loss.
+        - Loss/valid: Validation loss.
+        - Loss/test: Test loss after training.
+        - CSR/train: Constraint satisfaction ratio during training.
+        - CSR/valid: Constraint satisfaction ratio during validation.
+        - CSR/test: Constraint satisfaction ratio after training.
+        - One metric per constraint, for training, validation, and testing.
+
+        """
+
+        self.metric_manager.register("Loss/train", "during_training")
+        self.metric_manager.register("Loss/valid", "during_training")
+        self.metric_manager.register("Loss/test", "after_training")
+
+        if len(self.constraints) > 0:
+            self.metric_manager.register("CSR/train", "during_training")
+            self.metric_manager.register("CSR/valid", "during_training")
+            self.metric_manager.register("CSR/test", "after_training")
+
         for constraint in self.constraints:
-
-
-
-
-
-
-
-
-
-            )
-            self.train_csr["global"] = ConstraintSatisfactionRatio()
-        self.valid_csr: Dict[str, Metric] = ModuleDict(
-            {
-                constraint.constraint_name: ConstraintSatisfactionRatio()
-                for constraint in self.constraints
-            }
-        )
-        self.valid_csr["global"] = ConstraintSatisfactionRatio()
+            self.metric_manager.register(
+                f"{constraint.name}/train", "during_training"
+            )
+            self.metric_manager.register(
+                f"{constraint.name}/valid", "during_training"
+            )
+            self.metric_manager.register(
+                f"{constraint.name}/test", "after_training"
+            )
 
-    def
+    def fit(
+        self,
+        start_epoch: int = 0,
+        max_epochs: int = 100,
+        on_epoch_start: Callable[[int], None] = None,
+        on_epoch_end: Callable[[int], None] = None,
+        on_train_start: Callable[[int], None] = None,
+        on_train_end: Callable[[int], None] = None,
+    ) -> None:
+        """
+        Train the model for a given number of epochs.
+
+        Args:
+            start_epoch (int, optional): The epoch number to start the training
+                with. Default is 0.
+            max_epochs (int, optional): The number of epochs to train the
+                model. Default is 100.
+            on_epoch_start (Callable[[int], None], optional): A callback
+                function that will be executed at the start of each epoch.
+            on_epoch_end (Callable[[int], None], optional): A callback
+                function that will be executed at the end of each epoch.
+            on_train_start (Callable[[int], None], optional): A callback
+                function that will be executed before the training starts.
+            on_train_end (Callable[[int], None], optional): A callback
+                function that will be executed after training ends.
+        """
+
+        # Type checking
+        validate_type("start_epoch", start_epoch, int)
+        validate_callable("on_epoch_start", on_epoch_start, True)
+        validate_callable("on_epoch_end", on_epoch_end, True)
+        validate_callable("on_train_start", on_train_start, True)
+        validate_callable("on_train_end", on_train_end, True)
+
+        # Keep track of epoch
+        epoch = start_epoch
+
+        # Execute training start hook if set
+        if on_train_start:
+            on_train_start(epoch)
+
+        for i in tqdm(range(epoch, max_epochs), initial=epoch, desc="Epoch"):
+            epoch = i
+
+            # Execute epoch start hook if set
+            if on_epoch_start:
+                on_epoch_start(epoch)
+
+            # Execute training and validation epoch
+            self._train_epoch()
+            self._validate_epoch()
+
+            # Checkpointing
+            if self.checkpoint_manager:
+                self.checkpoint_manager.evaluate_criteria(epoch)
+
+            # Execute epoch end hook if set
+            if on_epoch_end:
+                on_epoch_end(epoch)
+
+        # Evaluate model performance on unseen test set
+        self._test_model()
+
+        # Save final model
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
+
+        # Execute training end hook if set
+        if on_train_end:
+            on_train_end(epoch)
+
+    def _train_epoch(self) -> None:
+        """
+        Perform training for a single epoch.
+
+        This method:
+        - Sets the model to training mode.
+        - Processes batches from the training DataLoader.
+        - Computes predictions and losses.
+        - Adjusts losses based on constraints.
+        - Updates model parameters using backpropagation.
+        """
+
+        # Set model in training mode
+        self.network.train()
+
+        for batch in tqdm(
+            self.train_loader, desc="Training batches", leave=False
+        ):
+
+            # Get input-output pairs from batch
+            inputs, outputs = batch
+
+            # Transfer to GPU
+            inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+
+            # Model computations
+            prediction = self.network(inputs)
+
+            # Calculate loss
+            loss = self.criterion(prediction["output"], outputs)
+            self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+            # Adjust loss based on constraints
+            combined_loss = self.train_step(prediction, loss)
+
+            # Backprop
+            self.optimizer.zero_grad()
+            combined_loss.backward(
+                retain_graph=False, inputs=list(self.network.parameters())
+            )
+            self.optimizer.step()
+
+    def _validate_epoch(self) -> None:
+        """
+        Perform validation for a single epoch.
+
+        This method:
+        - Sets the model to evaluation mode.
+        - Processes batches from the validation DataLoader.
+        - Computes predictions and losses.
+        - Logs constraint satisfaction ratios.
+        """
+
+        # Set model in evaluation mode
+        self.network.eval()
+
+        with no_grad():
+            for batch in tqdm(
+                self.valid_loader, desc="Validation batches", leave=False
+            ):
+
+                # Get input-output pairs from batch
+                inputs, outputs = batch
+
+                # Transfer to GPU
+                inputs, outputs = inputs.to(self.device), outputs.to(
+                    self.device
+                )
+
+                # Model computations
+                prediction = self.network(inputs)
+
+                # Calculate loss
+                loss = self.criterion(prediction["output"], outputs)
+                self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+
+                # Validate constraints
+                self.valid_step(prediction, loss)
+
+    def _test_model(self) -> None:
+        """
+        Evaluate model performance on the test set.
+
+        This method:
+        - Sets the model to evaluation mode.
+        - Processes batches from the test DataLoader.
+        - Computes predictions and losses.
+        - Logs constraint satisfaction ratios.
+
+        """
+
+        # Set model in evaluation mode
+        self.network.eval()
+
+        with no_grad():
+            for batch in tqdm(
+                self.test_loader, desc="Test batches", leave=False
+            ):
+
+                # Get input-output pairs from batch
+                inputs, outputs = batch
+
+                # Transfer to GPU
+                inputs, outputs = inputs.to(self.device), outputs.to(
+                    self.device
+                )
+
+                # Model computations
+                prediction = self.network(inputs)
+
+                # Calculate loss
+                loss = self.criterion(prediction["output"], outputs)
+                self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
+
+                # Validate constraints
+                self.test_step(prediction, loss)
+
+    def train_step(
         self,
         prediction: dict[str, Tensor],
         loss: Tensor,
-    ):
+    ) -> Tensor:
         """
-
-
-        For each constraint, the satisfaction ratio is checked, and the loss is adjusted by adding a rescale loss
-        based on the directions calculated by the constraint.
+        Adjust the training loss based on constraints
+        and compute the combined loss.
 
         Args:
-            prediction (dict[str, Tensor]):
-
+            prediction (dict[str, Tensor]): Model predictions
+                for variable layers.
+            loss (Tensor): The base loss computed by the criterion.
 
         Returns:
-            Tensor: The combined loss
+            Tensor: The combined loss (base loss + constraint adjustments).
         """
 
         # Init scalar tensor for loss
         total_rescale_loss = tensor(0, dtype=float32, device=self.device)
+        loss_grads = {}
 
-        #
+        # Precalculate loss gradients for each variable layer
         with no_grad():
+            for layer in self.descriptor.variable_layers:
+                self.optimizer.zero_grad()
+                loss.backward(retain_graph=True, inputs=prediction[layer])
+                loss_grads[layer] = prediction[layer].grad
 
-
-        for constraint in self.constraints:
+        for constraint in self.constraints:
 
-
-
-
+            # Check if constraints are satisfied and calculate directions
+            with no_grad():
+                constraint_checks, relevant_constraint_count = (
+                    constraint.check_constraint(prediction)
+                )
+
+            # Only adjust the loss if the constraint is not monitor-only
+            if not constraint.monitor_only:
+                with no_grad():
+                    constraint_directions = constraint.calculate_direction(
+                        prediction
+                    )
 
-                # Only do direction calculations for variable
-
+                # Only do direction calculations for variable
+                # layers affecting the constraint
+                for layer in (
+                    constraint.layers & self.descriptor.variable_layers
+                ):
 
-
-
-
-
-
+                    with no_grad():
+                        # Multiply direction modifiers with constraint result
+                        constraint_result = (
+                            1 - constraint_checks.unsqueeze(1)
+                        ) * constraint_directions[layer]
 
-
-
+                        # Multiply result with rescale factor of constraint
+                        constraint_result *= constraint.rescale_factor
 
-
-
-
+                        # Calculate loss gradient norm
+                        norm_loss_grad = norm(
+                            loss_grads[layer], dim=1, p=2, keepdim=True
+                        )
 
-
-
+                        # Apply minimum epsilon
+                        norm_loss_grad = maximum(norm_loss_grad, self.epsilon)
 
                     # Calculate rescale loss
                     rescale_loss = (
-
-
-    .
-    )
+                        prediction[layer]
+                        * constraint_result
+                        * norm_loss_grad.detach().clone()
+                    ).mean()
 
                     # Store rescale loss for this reference space
                     total_rescale_loss += rescale_loss
 
-
-
-
-
-
-
-
-
-                on_epoch=True,
+            # Log constraint satisfaction ratio
+            self.metric_manager.accumulate(
+                f"{constraint.name}/train",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
                     )
-
-
-
-
-
-
-
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
+            self.metric_manager.accumulate(
+                "CSR/train",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
 
         # Return combined loss
         return loss + total_rescale_loss
 
-    def
+    def valid_step(
         self,
         prediction: dict[str, Tensor],
         loss: Tensor,
-    ):
+    ) -> Tensor:
         """
-
-
-        Similar to the training step, but without updating the loss, this method tracks the constraint satisfaction
-        during validation.
+        Evaluate constraints during validation and log satisfaction metrics.
 
         Args:
-            prediction (dict[str, Tensor]):
-
+            prediction (dict[str, Tensor]): Model predictions for
+                variable layers.
+            loss (Tensor): The base loss computed by the criterion.
 
         Returns:
-            Tensor: The base loss
+            Tensor: The unchanged base loss.
         """
 
-        #
-
+        # For each constraint, check satisfaction and log metrics
+        for constraint in self.constraints:
 
-        #
-
+            # Check if constraints are satisfied
+            constraint_checks, relevant_constraint_count = (
+                constraint.check_constraint(prediction)
+            )
 
-
-
+            # Log constraint satisfaction ratio
+            self.metric_manager.accumulate(
+                f"{constraint.name}/valid",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
+            self.metric_manager.accumulate(
+                "CSR/valid",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
 
-
-
+        # Return loss
+        return loss
 
-
-
-
-
-
-
-
-                on_step=False,
-                on_epoch=True,
-            )
+    def test_step(
+        self,
+        prediction: dict[str, Tensor],
+        loss: Tensor,
+    ) -> Tensor:
+        """
+        Evaluate constraints during testing and log satisfaction metrics.
 
-
-
-
-
-
-
-
+        Args:
+            prediction (dict[str, Tensor]): Model predictions
+                for variable layers.
+            loss (Tensor): The base loss computed by the criterion.
+
+        Returns:
+            Tensor: The unchanged base loss.
+        """
+
+        # For each constraint, check satisfaction and log metrics
+        for constraint in self.constraints:
+
+            # Check if constraints are satisfied
+            constraint_checks, relevant_constraint_count = (
+                constraint.check_constraint(prediction)
+            )
+
+            # Log constraint satisfaction ratio
+            self.metric_manager.accumulate(
+                f"{constraint.name}/test",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
+            self.metric_manager.accumulate(
+                "CSR/test",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
 
         # Return loss
         return loss