congrads 0.2.0-py3-none-any.whl → 1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +17 -10
- congrads/checkpoints.py +232 -0
- congrads/constraints.py +664 -134
- congrads/core.py +482 -110
- congrads/datasets.py +315 -11
- congrads/descriptor.py +100 -20
- congrads/metrics.py +178 -16
- congrads/networks.py +47 -23
- congrads/requirements.txt +6 -0
- congrads/transformations.py +139 -0
- congrads/utils.py +439 -39
- congrads-1.0.1.dist-info/METADATA +208 -0
- congrads-1.0.1.dist-info/RECORD +16 -0
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/WHEEL +1 -1
- congrads-0.2.0.dist-info/METADATA +0 -222
- congrads-0.2.0.dist-info/RECORD +0 -13
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/LICENSE +0 -0
- {congrads-0.2.0.dist-info → congrads-1.0.1.dist-info}/top_level.txt +0 -0
congrads/core.py
CHANGED
@@ -1,16 +1,100 @@
-
-
-
+"""
+This module provides the CongradsCore class, which is designed to integrate
+constraint-guided optimization into neural network training.
+It extends traditional training processes by enforcing specific constraints
+on the model's outputs, ensuring that the network satisfies domain-specific
+requirements during both training and evaluation.
+
+The `CongradsCore` class serves as the central engine for managing the
+training, validation, and testing phases of a neural network model,
+incorporating constraints that influence the loss function and model updates.
+The model is trained with standard loss functions while also incorporating
+constraint-based adjustments, which are tracked and logged
+throughout the process.
+
+Key features:
+- Support for various constraints that can influence the training process.
+- Integration with PyTorch's `DataLoader` for efficient batch processing.
+- Metric management for tracking loss and constraint satisfaction.
+- Checkpoint management for saving and evaluating model states.
+
+Modules in this package provide the following:
+
+- `Descriptor`: Describes variable layers in the network that are
+  subject to constraints.
+- `Constraint`: Defines various constraints, which are used to guide
+  the training process.
+- `MetricManager`: Manages and tracks performance metrics such as loss
+  and constraint satisfaction.
+- `CheckpointManager`: Manages saving and loading model checkpoints
+  during training.
+- Utility functions to validate inputs and configurations.
+
+Dependencies:
+- PyTorch (`torch`)
+- tqdm (for progress tracking)
+
+The `CongradsCore` class allows for the use of additional callback functions
+at different stages of the training process to customize behavior for
+specific needs. These include callbacks for the start and end of epochs, as
+well as the start and end of the entire training process.
+
+"""
+
+import warnings
+from numbers import Number
+from typing import Callable
+
+import torch
+
+# pylint: disable-next=redefined-builtin
+from torch import Tensor, float32, maximum, no_grad, norm, numel, sum, tensor
 from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
 from torch.utils.data import DataLoader
-from
+from tqdm import tqdm
 
-from .
+from .checkpoints import CheckpointManager
 from .constraints import Constraint
 from .descriptor import Descriptor
+from .metrics import MetricManager
+from .utils import (
+    validate_callable,
+    validate_iterable,
+    validate_loaders,
+    validate_type,
+)
 
 
 class CongradsCore:
+    """
+    The CongradsCore class is the central training engine for constraint-guided
+    neural network optimization. It integrates standard neural network training
+    with additional constraint-driven adjustments to the loss function, ensuring
+    that the network satisfies domain-specific constraints during training.
+
+    Args:
+        descriptor (Descriptor): Describes variable layers in the network.
+        constraints (list[Constraint]): List of constraints to guide training.
+        loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
+            training, validation, and testing.
+        network (Module): The neural network model to train.
+        criterion (callable): The loss function used for
+            training and validation.
+        optimizer (Optimizer): The optimizer used for updating model parameters.
+        metric_manager (MetricManager): Manages metric tracking and recording.
+        device (torch.device): The device (e.g., CPU or GPU) for computations.
+        checkpoint_manager (CheckpointManager, optional): Manages
+            checkpointing. If not set, no checkpointing is done.
+        epsilon (Number, optional): A small value to avoid division by zero
+            in gradient calculations. Default is 1e-6.
+
+    Note:
+        A warning is logged if the descriptor has no variable layers,
+        as at least one variable layer is required for the constraint logic
+        to influence the training process.
+    """
 
     def __init__(
         self,
@@ -18,14 +102,33 @@ class CongradsCore:
         constraints: list[Constraint],
         loaders: tuple[DataLoader, DataLoader, DataLoader],
         network: Module,
-        criterion:
+        criterion: _Loss,
         optimizer: Optimizer,
         metric_manager: MetricManager,
-        device,
+        device: torch.device,
+        checkpoint_manager: CheckpointManager = None,
+        epsilon: Number = 1e-6,
     ):
-
-
-
+        """
+        Initialize the CongradsCore object.
+        """
+
+        # Type checking
+        validate_type("descriptor", descriptor, Descriptor)
+        validate_iterable("constraints", constraints, Constraint)
+        validate_loaders()
+        validate_type("network", network, Module)
+        validate_type("criterion", criterion, _Loss)
+        validate_type("optimizer", optimizer, Optimizer)
+        validate_type("metric_manager", metric_manager, MetricManager)
+        validate_type("device", device, torch.device)
+        validate_type(
+            "checkpoint_manager",
+            checkpoint_manager,
+            CheckpointManager,
+            allow_none=True,
+        )
+        validate_type("epsilon", epsilon, Number)
 
         # Init object variables
         self.descriptor = descriptor
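The hunk above reworks the constructor's public surface: criterion is now typed as a torch `_Loss`, device as `torch.device`, and two optional parameters (`checkpoint_manager`, `epsilon`) are new in 1.0.1, with every argument funneled through the new `validate_*` helpers. A minimal instantiation sketch against this signature; `MyNetwork`, `my_descriptor`, `my_constraints`, and the three loaders are hypothetical stand-ins, and the `MetricManager` constructor arguments are not shown in this diff:

    import torch
    from torch.nn import MSELoss
    from torch.optim import Adam

    from congrads.core import CongradsCore
    from congrads.metrics import MetricManager

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network = MyNetwork().to(device)  # hypothetical torch.nn.Module subclass

    core = CongradsCore(
        descriptor=my_descriptor,    # hypothetical Descriptor instance
        constraints=my_constraints,  # hypothetical list[Constraint]
        loaders=(train_loader, valid_loader, test_loader),
        network=network,
        criterion=MSELoss(),         # must subclass _Loss per the new check
        optimizer=Adam(network.parameters()),
        metric_manager=MetricManager(),  # real signature not shown in this diff
        device=device,
        epsilon=1e-6,                # default from the new signature
    )  # checkpoint_manager left unset: checkpointing is simply skipped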
@@ -38,105 +141,270 @@ class CongradsCore:
         self.optimizer = optimizer
         self.metric_manager = metric_manager
         self.device = device
+        self.checkpoint_manager = checkpoint_manager
+
+        # Init epsilon tensor
+        self.epsilon = tensor(epsilon, device=self.device)
 
         # Perform checks
         if len(self.descriptor.variable_layers) == 0:
-
-            "The descriptor object has no variable layers. The constraint
+            warnings.warn(
+                "The descriptor object has no variable layers. The constraint \
+                guided loss adjustment is therefore not used. \
+                Is this the intended behavior?"
             )
 
         # Initialize constraint metrics
-
-
-
-
+        self._initialize_metrics()
+
+    def _initialize_metrics(self) -> None:
+        """
+        Register metrics for loss, constraint satisfaction ratio (CSR),
+        and individual constraints.
+
+        This method registers the following metrics:
+
+        - Loss/train: Training loss.
+        - Loss/valid: Validation loss.
+        - Loss/test: Test loss after training.
+        - CSR/train: Constraint satisfaction ratio during training.
+        - CSR/valid: Constraint satisfaction ratio during validation.
+        - CSR/test: Constraint satisfaction ratio after training.
+        - One metric per constraint, for both training and validation.
+
+        """
+
+        self.metric_manager.register("Loss/train", "during_training")
+        self.metric_manager.register("Loss/valid", "during_training")
+        self.metric_manager.register("Loss/test", "after_training")
+
+        if len(self.constraints) > 0:
+            self.metric_manager.register("CSR/train", "during_training")
+            self.metric_manager.register("CSR/valid", "during_training")
+            self.metric_manager.register("CSR/test", "after_training")
 
         for constraint in self.constraints:
-            metric_manager.register(
-
+            self.metric_manager.register(
+                f"{constraint.name}/train", "during_training"
+            )
+            self.metric_manager.register(
+                f"{constraint.name}/valid", "during_training"
+            )
+            self.metric_manager.register(
+                f"{constraint.name}/test", "after_training"
+            )
+
+    def fit(
+        self,
+        start_epoch: int = 0,
+        max_epochs: int = 100,
+        on_epoch_start: Callable[[int], None] = None,
+        on_epoch_end: Callable[[int], None] = None,
+        on_train_start: Callable[[int], None] = None,
+        on_train_end: Callable[[int], None] = None,
+    ) -> None:
+        """
+        Train the model for a given number of epochs.
+
+        Args:
+            start_epoch (int, optional): The epoch number to start the training
+                with. Default is 0.
+            max_epochs (int, optional): The number of epochs to train the
+                model. Default is 100.
+            on_epoch_start (Callable[[int], None], optional): A callback
+                function that will be executed at the start of each epoch.
+            on_epoch_end (Callable[[int], None], optional): A callback
+                function that will be executed at the end of each epoch.
+            on_train_start (Callable[[int], None], optional): A callback
+                function that will be executed before the training starts.
+            on_train_end (Callable[[int], None], optional): A callback
+                function that will be executed after training ends.
+        """
+
+        # Type checking
+        validate_type("start_epoch", start_epoch, int)
+        validate_callable("on_epoch_start", on_epoch_start, True)
+        validate_callable("on_epoch_end", on_epoch_end, True)
+        validate_callable("on_train_start", on_train_start, True)
+        validate_callable("on_train_end", on_train_end, True)
+
+        # Keep track of epoch
+        epoch = start_epoch
+
+        # Execute training start hook if set
+        if on_train_start:
+            on_train_start(epoch)
+
+        for i in tqdm(range(epoch, max_epochs), initial=epoch, desc="Epoch"):
+            epoch = i
+
+            # Execute epoch start hook if set
+            if on_epoch_start:
+                on_epoch_start(epoch)
+
+            # Execute training and validation epoch
+            self._train_epoch()
+            self._validate_epoch()
+
+            # Checkpointing
+            if self.checkpoint_manager:
+                self.checkpoint_manager.evaluate_criteria(epoch)
+
+            # Execute epoch end hook if set
+            if on_epoch_end:
+                on_epoch_end(epoch)
+
+        # Evaluate model performance on unseen test set
+        self._test_model()
+
+        # Save final model
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
+
+        # Execute training end hook if set
+        if on_train_end:
+            on_train_end(epoch)
+
+    def _train_epoch(self) -> None:
+        """
+        Perform training for a single epoch.
+
+        This method:
+        - Sets the model to training mode.
+        - Processes batches from the training DataLoader.
+        - Computes predictions and losses.
+        - Adjusts losses based on constraints.
+        - Updates model parameters using backpropagation.
+
+        Args:
+            epoch (int): The current epoch number.
+        """
+
+        # Set model in training mode
+        self.network.train()
+
+        for batch in tqdm(
+            self.train_loader, desc="Training batches", leave=False
+        ):
+
+            # Get input-output pairs from batch
+            inputs, outputs = batch
+
+            # Transfer to GPU
+            inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+
+            # Model computations
+            prediction = self.network(inputs)
+
+            # Calculate loss
+            loss = self.criterion(prediction["output"], outputs)
+            self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+            # Adjust loss based on constraints
+            combined_loss = self.train_step(prediction, loss)
+
+            # Backprop
+            self.optimizer.zero_grad()
+            combined_loss.backward(
+                retain_graph=False, inputs=list(self.network.parameters())
+            )
+            self.optimizer.step()
 
-    def
-
-        for
+    def _validate_epoch(self) -> None:
+        """
+        Perform validation for a single epoch.
 
-
-
+        This method:
+        - Sets the model to evaluation mode.
+        - Processes batches from the validation DataLoader.
+        - Computes predictions and losses.
+        - Logs constraint satisfaction ratios.
 
-
-
+        Args:
+            epoch (int): The current epoch number.
+        """
 
-
-
+        # Set model in evaluation mode
+        self.network.eval()
+
+        with no_grad():
+            for batch in tqdm(
+                self.valid_loader, desc="Validation batches", leave=False
+            ):
 
                 # Get input-output pairs from batch
                 inputs, outputs = batch
 
                 # Transfer to GPU
-                inputs, outputs = inputs.to(self.device), outputs.to(
-
-
-                prepare_time = start_time - time()
+                inputs, outputs = inputs.to(self.device), outputs.to(
+                    self.device
+                )
 
                 # Model computations
                 prediction = self.network(inputs)
 
                 # Calculate loss
                 loss = self.criterion(prediction["output"], outputs)
-                self.metric_manager.accumulate("Loss/
+                self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
 
-                #
-
+                # Validate constraints
+                self.valid_step(prediction, loss)
 
-
-
-
-                retain_graph=False, inputs=list(self.network.parameters())
-            )
-            self.optimizer.step()
-
-            # Validation
-            with no_grad():
-                for batch in self.valid_loader:
+    def _test_model(self) -> None:
+        """
+        Evaluate model performance on the test set.
 
-
-
+        This method:
+        - Sets the model to evaluation mode.
+        - Processes batches from the test DataLoader.
+        - Computes predictions and losses.
+        - Logs constraint satisfaction ratios.
 
-
-                    inputs, outputs = batch
+        """
 
-
-
+        # Set model in evaluation mode
+        self.network.eval()
 
-
-
+        with no_grad():
+            for batch in tqdm(
+                self.test_loader, desc="Test batches", leave=False
+            ):
 
-
-
-                    self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+                # Get input-output pairs from batch
+                inputs, outputs = batch
 
-
-
+                # Transfer to GPU
+                inputs, outputs = inputs.to(self.device), outputs.to(
+                    self.device
+                )
 
-
+                # Model computations
+                prediction = self.network(inputs)
 
-
-
-
+                # Calculate loss
+                loss = self.criterion(prediction["output"], outputs)
+                self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
 
-
-
-            print(
-                "Compute efficiency: {:.2f}, epoch: {}/{}:".format(
-                    process_time / (process_time + prepare_time), epoch, max_epochs
-                )
-            )
-            start_time = time()
+                # Validate constraints
+                self.test_step(prediction, loss)
 
     def train_step(
         self,
         prediction: dict[str, Tensor],
         loss: Tensor,
-    ):
+    ) -> Tensor:
+        """
+        Adjust the training loss based on constraints
+        and compute the combined loss.
+
+        Args:
+            prediction (dict[str, Tensor]): Model predictions
+                for variable layers.
+            loss (Tensor): The base loss computed by the criterion.
+
+        Returns:
+            Tensor: The combined loss (base loss + constraint adjustments).
+        """
 
         # Init scalar tensor for loss
         total_rescale_loss = tensor(0, dtype=float32, device=self.device)
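The hunk above replaces the old inline loop with a `fit()` driver plus `_train_epoch`, `_validate_epoch`, and `_test_model` helpers, and adds four optional hooks that each receive the current epoch number. A short usage sketch of that callback API, continuing the hypothetical `core` object from the previous example; the logging callbacks themselves are illustrative, not part of the package:

    def log_epoch(epoch: int) -> None:
        # Runs after each train/validation pass and checkpoint evaluation.
        print(f"Finished epoch {epoch}")

    def log_done(epoch: int) -> None:
        # Per the loop above, this fires only after the final test pass
        # and after checkpoint_final.pth has been written (when enabled).
        print(f"Training stopped at epoch {epoch}")

    core.fit(
        start_epoch=0,
        max_epochs=50,
        on_epoch_end=log_epoch,
        on_train_end=log_done,
    )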
@@ -149,48 +417,76 @@ class CongradsCore:
             loss.backward(retain_graph=True, inputs=prediction[layer])
             loss_grads[layer] = prediction[layer].grad
 
-        # For each constraint, TODO split into real and validation only constraints
         for constraint in self.constraints:
 
             # Check if constraints are satisfied and calculate directions
             with no_grad():
-                constraint_checks =
-
-
-            # Only do direction calculations for variable layers affecting constraint
-            for layer in constraint.layers & self.descriptor.variable_layers:
+                constraint_checks, relevant_constraint_count = (
+                    constraint.check_constraint(prediction)
+                )
 
+            # Only do adjusting calculation if constraint is not observant
+            if not constraint.monitor_only:
                 with no_grad():
-
-
-                        constraint_checks.unsqueeze(1).type(float32)
-                        * constraint_directions[layer]
+                    constraint_directions = constraint.calculate_direction(
+                        prediction
                     )
 
-
-
+                # Only do direction calculations for variable
+                # layers affecting constraint
+                for layer in (
+                    constraint.layers & self.descriptor.variable_layers
+                ):
+
+                    with no_grad():
+                        # Multiply direction modifiers with constraint result
+                        constraint_result = (
+                            1 - constraint_checks.unsqueeze(1)
+                        ) * constraint_directions[layer]
+
+                        # Multiply result with rescale factor of constraint
+                        constraint_result *= constraint.rescale_factor
 
-
-
+                        # Calculate loss gradient norm
+                        norm_loss_grad = norm(
+                            loss_grads[layer], dim=1, p=2, keepdim=True
+                        )
 
-
-
-                    prediction[layer]
-                    * constraint_result
-                    * norm_loss_grad.detach().clone()
-                ).mean()
+                        # Apply minimum epsilon
+                        norm_loss_grad = maximum(norm_loss_grad, self.epsilon)
 
-
-
+                    # Calculate rescale loss
+                    rescale_loss = (
+                        prediction[layer]
+                        * constraint_result
+                        * norm_loss_grad.detach().clone()
+                    ).mean()
+
+                    # Store rescale loss for this reference space
+                    total_rescale_loss += rescale_loss
 
             # Log constraint satisfaction ratio
             self.metric_manager.accumulate(
                 f"{constraint.name}/train",
-                (
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
             )
             self.metric_manager.accumulate(
                 "CSR/train",
-                (
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
             )
 
         # Return combined loss
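Both `accumulate` calls in the hunk above log the same constraint satisfaction ratio (CSR). One consistent reading of the expression, assuming `check_constraint` returns a tensor with one entry per check in which irrelevant entries are reported as satisfied (an inference from the arithmetic, not documented in this diff): subtracting `numel` and adding back `relevant_constraint_count` strips the "free" passes, so the ratio is taken over relevant checks only.

    from torch import numel, sum, tensor

    # 6 checks, of which only 4 are relevant; the 2 irrelevant ones count as 1.
    constraint_checks = tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0])
    relevant_constraint_count = 4

    csr = (
        sum(constraint_checks)       # 5.0: satisfied plus irrelevant entries
        - numel(constraint_checks)   # remove all 6 potential passes...
        + relevant_constraint_count  # ...then re-add the 4 that matter
    ) / relevant_constraint_count

    print(csr)  # tensor(0.7500): 3 of the 4 relevant checks hold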
@@ -200,26 +496,102 @@ class CongradsCore:
         self,
         prediction: dict[str, Tensor],
         loss: Tensor,
-    ):
+    ) -> Tensor:
+        """
+        Evaluate constraints during validation and log satisfaction metrics.
 
-
-
+        Args:
+            prediction (dict[str, Tensor]): Model predictions for
+                variable layers.
+            loss (Tensor): The base loss computed by the criterion.
 
-
-
+        Returns:
+            Tensor: The unchanged base loss.
+        """
 
-
-
+        # For each constraint in this reference space, calculate directions
+        for constraint in self.constraints:
 
-
-
-
-
-
-
-
-
-
+            # Check if constraints are satisfied
+            constraint_checks, relevant_constraint_count = (
+                constraint.check_constraint(prediction)
+            )
+
+            # Log constraint satisfaction ratio
+            self.metric_manager.accumulate(
+                f"{constraint.name}/valid",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
+            self.metric_manager.accumulate(
+                "CSR/valid",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
+
+        # Return loss
+        return loss
+
+    def test_step(
+        self,
+        prediction: dict[str, Tensor],
+        loss: Tensor,
+    ) -> Tensor:
+        """
+        Evaluate constraints during test and log satisfaction metrics.
+
+        Args:
+            prediction (dict[str, Tensor]): Model predictions
+                for variable layers.
+            loss (Tensor): The base loss computed by the criterion.
+
+        Returns:
+            Tensor: The unchanged base loss.
+        """
+
+        # For each constraint in this reference space, calculate directions
+        for constraint in self.constraints:
+
+            # Check if constraints are satisfied
+            constraint_checks, relevant_constraint_count = (
+                constraint.check_constraint(prediction)
+            )
+
+            # Log constraint satisfaction ratio
+            self.metric_manager.accumulate(
+                f"{constraint.name}/test",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
+            self.metric_manager.accumulate(
+                "CSR/test",
+                (
+                    (
+                        sum(constraint_checks)
+                        - numel(constraint_checks)
+                        + relevant_constraint_count
+                    )
+                    / relevant_constraint_count
+                ).unsqueeze(0),
+            )
 
         # Return loss
         return loss