congrads 1.0.6-py3-none-any.whl → 1.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +2 -3
- congrads/checkpoints.py +73 -127
- congrads/constraints.py +813 -476
- congrads/core.py +521 -345
- congrads/datasets.py +491 -191
- congrads/descriptor.py +118 -82
- congrads/metrics.py +55 -127
- congrads/networks.py +35 -81
- congrads/py.typed +0 -0
- congrads/transformations.py +65 -88
- congrads/utils.py +499 -131
- {congrads-1.0.6.dist-info → congrads-1.1.0.dist-info}/METADATA +48 -41
- congrads-1.1.0.dist-info/RECORD +14 -0
- congrads-1.1.0.dist-info/WHEEL +4 -0
- congrads-1.0.6.dist-info/LICENSE +0 -26
- congrads-1.0.6.dist-info/RECORD +0 -15
- congrads-1.0.6.dist-info/WHEEL +0 -5
- congrads-1.0.6.dist-info/top_level.txt +0 -1
congrads/core.py
CHANGED
@@ -1,15 +1,15 @@
-"""
-
-constraint-guided optimization into neural network training.
-It extends traditional training processes by enforcing specific constraints
-on the model's outputs, ensuring that the network satisfies domain-specific
+"""This module provides the core CongradsCore class for the main training functionality.
+
+It is designed to integrate constraint-guided optimization into neural network training.
+It extends traditional training processes by enforcing specific constraints
+on the model's outputs, ensuring that the network satisfies domain-specific
 requirements during both training and evaluation.
 
-The `CongradsCore` class serves as the central engine for managing the
-training, validation, and testing phases of a neural network model,
-incorporating constraints that influence the loss function and model updates.
-The model is trained with standard loss functions while also incorporating
-constraint-based adjustments, which are tracked and logged
+The `CongradsCore` class serves as the central engine for managing the
+training, validation, and testing phases of a neural network model,
+incorporating constraints that influence the loss function and model updates.
+The model is trained with standard loss functions while also incorporating
+constraint-based adjustments, which are tracked and logged
 throughout the process.
 
 Key features:
@@ -18,37 +18,19 @@ Key features:
 - Metric management for tracking loss and constraint satisfaction.
 - Checkpoint management for saving and evaluating model states.
 
-
-
-
-  subject to constraints.
-- `Constraint`: Defines various constraints, which are used to guide
-  the training process.
-- `MetricManager`: Manages and tracks performance metrics such as loss
-  and constraint satisfaction.
-- `CheckpointManager`: Manages saving and loading model checkpoints
-  during training.
-- Utility functions to validate inputs and configurations.
-
-Dependencies:
-- PyTorch (`torch`)
-- tqdm (for progress tracking)
-
-The `CongradsCore` class allows for the use of additional callback functions
-at different stages of the training process to customize behavior for
-specific needs. These include callbacks for the start and end of epochs, as
+The `CongradsCore` class allows for the use of additional callback functions
+at different stages of the training process to customize behavior for
+specific needs. These include callbacks for the start and end of epochs, as
 well as the start and end of the entire training process.
 
 """
 
 import warnings
-from
-from typing import Callable
+from collections.abc import Callable
 
 import torch
-
-
-from torch import Tensor, float32, maximum, no_grad, norm, numel, sum, tensor
+from torch import Tensor, float32, no_grad, sum, tensor
+from torch.linalg import vector_norm
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -60,7 +42,10 @@ from .constraints import Constraint
 from .descriptor import Descriptor
 from .metrics import MetricManager
 from .utils import (
+    is_torch_loss,
+    torch_loss_wrapper,
     validate_callable,
+    validate_callable_iterable,
     validate_iterable,
     validate_loaders,
     validate_type,
@@ -68,32 +53,11 @@ from .utils import (
 
 
 class CongradsCore:
-    """
-
-
+    """The CongradsCore class is the central training engine for constraint-guided optimization.
+
+    It integrates standard neural network training
     with additional constraint-driven adjustments to the loss function, ensuring
     that the network satisfies domain-specific constraints during training.
-
-    Args:
-        descriptor (Descriptor): Describes variable layers in the network.
-        constraints (list[Constraint]): List of constraints to guide training.
-        loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
-            training, validation, and testing.
-        network (Module): The neural network model to train.
-        criterion (callable): The loss function used for
-            training and validation.
-        optimizer (Optimizer): The optimizer used for updating model parameters.
-        metric_manager (MetricManager): Manages metric tracking and recording.
-        device (torch.device): The device (e.g., CPU or GPU) for computations.
-        checkpoint_manager (CheckpointManager, optional): Manages
-            checkpointing. If not set, no checkpointing is done.
-        epsilon (Number, optional): A small value to avoid division by zero
-            in gradient calculations. Default is 1e-10.
-
-    Note:
-        A warning is logged if the descriptor has no variable layers,
-        as at least one variable layer is required for the constraint logic
-        to influence the training process.
     """
 
     def __init__(
@@ -106,29 +70,69 @@ class CongradsCore:
         optimizer: Optimizer,
         metric_manager: MetricManager,
         device: torch.device,
+        network_uses_grad: bool = False,
         checkpoint_manager: CheckpointManager = None,
-        epsilon:
+        epsilon: float = 1e-6,
+        constraint_aggregator: Callable[..., Tensor] = sum,
+        disable_progress_bar_epoch: bool = False,
+        disable_progress_bar_batch: bool = False,
+        enforce_all: bool = True,
     ):
-        """
-        Initialize the CongradsCore object.
-        """
+        """Initialize the CongradsCore object.
 
+        Args:
+            descriptor (Descriptor): Describes variable layers in the network.
+            constraints (list[Constraint]): List of constraints to guide training.
+            loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
+                training, validation, and testing.
+            network (Module): The neural network model to train.
+            criterion (callable): The loss function used for
+                training and validation.
+            optimizer (Optimizer): The optimizer used for updating model parameters.
+            metric_manager (MetricManager): Manages metric tracking and recording.
+            device (torch.device): The device (e.g., CPU or GPU) for computations.
+            network_uses_grad (bool, optional): A flag indicating if the network
+                contains gradient calculation computations. Default is False.
+            checkpoint_manager (CheckpointManager, optional): Manages
+                checkpointing. If not set, no checkpointing is done.
+            epsilon (float, optional): A small value to avoid division by zero
+                in gradient calculations. Default is 1e-10.
+            constraint_aggregator (Callable[..., Tensor], optional): A function
+                to aggregate the constraint rescale loss. Default is `sum`.
+            disable_progress_bar_epoch (bool, optional): If set to True, the epoch
+                progress bar will not show. Defaults to False.
+            disable_progress_bar_batch (bool, optional): If set to True, the batch
+                progress bar will not show. Defaults to False.
+            enforce_all (bool, optional): If set to False, constraints will only be monitored and
+                not influence the training process. Overrides constraint-specific `enforce` parameters.
+                Defaults to True.
+
+        Note:
+            A warning is logged if the descriptor has no variable layers,
+            as at least one variable layer is required for the constraint logic
+            to influence the training process.
+        """
         # Type checking
         validate_type("descriptor", descriptor, Descriptor)
-        validate_iterable("constraints", constraints, Constraint)
-        validate_loaders()
+        validate_iterable("constraints", constraints, Constraint, allow_empty=True)
+        validate_loaders("loaders", loaders)
         validate_type("network", network, Module)
         validate_type("criterion", criterion, _Loss)
         validate_type("optimizer", optimizer, Optimizer)
         validate_type("metric_manager", metric_manager, MetricManager)
         validate_type("device", device, torch.device)
+        validate_type("network_uses_grad", network_uses_grad, bool)
         validate_type(
             "checkpoint_manager",
             checkpoint_manager,
             CheckpointManager,
             allow_none=True,
         )
-        validate_type("epsilon", epsilon,
+        validate_type("epsilon", epsilon, float)
+        validate_callable("constraint_aggregator", constraint_aggregator, allow_none=True)
+        validate_type("disable_progress_bar_epoch", disable_progress_bar_epoch, bool)
+        validate_type("disable_progress_bar_batch", disable_progress_bar_batch, bool)
+        validate_type("enforce_all", enforce_all, bool)
 
         # Init object variables
         self.descriptor = descriptor
@@ -137,30 +141,38 @@ class CongradsCore:
         self.valid_loader = loaders[1]
         self.test_loader = loaders[2]
         self.network = network
-        self.criterion = criterion
         self.optimizer = optimizer
         self.metric_manager = metric_manager
         self.device = device
+        self.network_uses_grad = network_uses_grad
         self.checkpoint_manager = checkpoint_manager
-
-
-        self.
+        self.epsilon = epsilon
+        self.constraint_aggregator = constraint_aggregator
+        self.disable_progress_bar_epoch = disable_progress_bar_epoch
+        self.disable_progress_bar_batch = disable_progress_bar_batch
+        self.enforce_all = enforce_all
+
+        # Check if criterion is a torch loss function
+        if is_torch_loss(criterion):
+            # If so, wrap it in a custom loss function
+            self.criterion = torch_loss_wrapper(criterion)
+        else:
+            self.criterion = criterion
 
         # Perform checks
-        if len(self.descriptor.
+        if len(self.descriptor.variable_keys) == 0:
             warnings.warn(
                 "The descriptor object has no variable layers. The constraint \
                 guided loss adjustment is therefore not used. \
-                Is this the intended behavior?"
+                Is this the intended behavior?",
+                stacklevel=2,
             )
 
         # Initialize constraint metrics
         self._initialize_metrics()
 
     def _initialize_metrics(self) -> None:
-        """
-        Register metrics for loss, constraint satisfaction ratio (CSR),
-        and individual constraints.
+        """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.
 
         This method registers the following metrics:
 
@@ -173,7 +185,6 @@
         - One metric per constraint, for both training and validation.
 
         """
-
         self.metric_manager.register("Loss/train", "during_training")
         self.metric_manager.register("Loss/valid", "during_training")
         self.metric_manager.register("Loss/test", "after_training")
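Editor's note on the criterion change above: in 1.1.0 a plain torch loss (a `_Loss` subclass) is detected with `is_torch_loss` and wrapped via `torch_loss_wrapper`, so the training loop can always call `criterion(output, target, data=data)`. The snippet below is a minimal illustrative stand-in written for this note, not the actual `congrads.utils.torch_loss_wrapper` implementation; only the calling convention is taken from the diff.

import torch
from torch import Tensor
from torch.nn import MSELoss
from torch.nn.modules.loss import _Loss

def loss_wrapper_sketch(criterion: _Loss):
    # Illustrative stand-in for congrads.utils.torch_loss_wrapper (its real body is
    # not part of this diff): adapt a plain torch loss so it tolerates the extra
    # data= keyword that CongradsCore passes from 1.1.0 onwards.
    def wrapped(output: Tensor, target: Tensor, data: dict[str, Tensor] | None = None) -> Tensor:
        # A standard torch loss ignores the batch dictionary entirely.
        return criterion(output, target)
    return wrapped

criterion = loss_wrapper_sketch(MSELoss())
out, tgt = torch.randn(4, 2), torch.randn(4, 2)
print(criterion(out, tgt, data={"output": out, "target": tgt}))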
@@ -184,414 +195,579 @@
         self.metric_manager.register("CSR/test", "after_training")
 
         for constraint in self.constraints:
-            self.metric_manager.register(
-
-            )
-            self.metric_manager.register(
-                f"{constraint.name}/valid", "during_training"
-            )
-            self.metric_manager.register(
-                f"{constraint.name}/test", "after_training"
-            )
+            self.metric_manager.register(f"{constraint.name}/train", "during_training")
+            self.metric_manager.register(f"{constraint.name}/valid", "during_training")
+            self.metric_manager.register(f"{constraint.name}/test", "after_training")
 
     def fit(
         self,
         start_epoch: int = 0,
         max_epochs: int = 100,
-
-
-
-
+        test_model: bool = True,
+        on_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_train_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_train_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_valid_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_valid_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_test_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_test_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+        on_epoch_start: list[Callable[[int], None]] | None = None,
+        on_epoch_end: list[Callable[[int], None]] | None = None,
+        on_train_start: list[Callable[[int], None]] | None = None,
+        on_train_end: list[Callable[[int], None]] | None = None,
+        on_train_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
+        | None = None,
+        on_val_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
+        | None = None,
+        on_test_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
+        | None = None,
+        on_test_start: list[Callable[[int], None]] | None = None,
+        on_test_end: list[Callable[[int], None]] | None = None,
     ) -> None:
-        """
-
+        """Train the model over multiple epochs with optional validation and testing.
+
+        This method manages the full training loop, including:
+
+        - Executing epoch-level and batch-level callbacks.
+        - Training and validating the model each epoch.
+        - Adjusting losses according to constraints.
+        - Logging metrics via the metric manager.
+        - Optional evaluation on the test set.
+        - Checkpointing the model during and after training.
 
         Args:
-            start_epoch (int, optional):
-
-
-
-
-
-
-
-
-
-
-
+            start_epoch (int, optional): Epoch number to start training from. Defaults to 0.
+            max_epochs (int, optional): Total number of epochs to train. Defaults to 100.
+            test_model (bool, optional): If True, evaluate the model on the test set after training. Defaults to True.
+            on_batch_start (list[Callable], optional): Callbacks executed at the start of every batch. Defaults to None.
+            on_batch_end (list[Callable], optional): Callbacks executed at the end of every batch. Defaults to None.
+            on_train_batch_start (list[Callable], optional): Callbacks executed at the start of each training batch. Defaults to `on_batch_start` if not provided.
+            on_train_batch_end (list[Callable], optional): Callbacks executed at the end of each training batch. Defaults to `on_batch_end` if not provided.
+            on_valid_batch_start (list[Callable], optional): Callbacks executed at the start of each validation batch. Defaults to `on_batch_start` if not provided.
+            on_valid_batch_end (list[Callable], optional): Callbacks executed at the end of each validation batch. Defaults to `on_batch_end` if not provided.
+            on_test_batch_start (list[Callable], optional): Callbacks executed at the start of each test batch. Defaults to `on_batch_start` if not provided.
+            on_test_batch_end (list[Callable], optional): Callbacks executed at the end of each test batch. Defaults to `on_batch_end` if not provided.
+            on_epoch_start (list[Callable], optional): Callbacks executed at the start of each epoch. Defaults to None.
+            on_epoch_end (list[Callable], optional): Callbacks executed at the end of each epoch. Defaults to None.
+            on_train_start (list[Callable], optional): Callbacks executed before training starts. Defaults to None.
+            on_train_end (list[Callable], optional): Callbacks executed after training ends. Defaults to None.
+            on_train_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during training. Defaults to None.
+            on_val_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during validation. Defaults to None.
+            on_test_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during testing. Defaults to None.
+            on_test_start (list[Callable], optional): Callbacks executed before testing starts. Defaults to None.
+            on_test_end (list[Callable], optional): Callbacks executed after testing ends. Defaults to None.
+
+        Notes:
+            - If phase-specific callbacks (train/valid/test) are not provided, the global `on_batch_start` and `on_batch_end` are used.
+            - Training metrics, loss adjustments, and constraint satisfaction ratios are automatically logged via the metric manager.
+            - The final model checkpoint is saved if a checkpoint manager is configured.
         """
-
         # Type checking
         validate_type("start_epoch", start_epoch, int)
-
-
-
-
+        validate_type("max_epochs", max_epochs, int)
+        validate_type("test_model", test_model, bool)
+        validate_callable_iterable("on_batch_start", on_batch_start, allow_none=True)
+        validate_callable_iterable("on_batch_end", on_batch_end, allow_none=True)
+        validate_callable_iterable("on_train_batch_start", on_train_batch_start, allow_none=True)
+        validate_callable_iterable("on_train_batch_end", on_train_batch_end, allow_none=True)
+        validate_callable_iterable("on_valid_batch_start", on_valid_batch_start, allow_none=True)
+        validate_callable_iterable("on_valid_batch_end", on_valid_batch_end, allow_none=True)
+        validate_callable_iterable("on_test_batch_start", on_test_batch_start, allow_none=True)
+        validate_callable_iterable("on_test_batch_end", on_test_batch_end, allow_none=True)
+        validate_callable_iterable("on_epoch_start", on_epoch_start, allow_none=True)
+        validate_callable_iterable("on_epoch_end", on_epoch_end, allow_none=True)
+        validate_callable_iterable("on_train_start", on_train_start, allow_none=True)
+        validate_callable_iterable("on_train_end", on_train_end, allow_none=True)
+        validate_callable_iterable(
+            "on_train_completion_forward_pass",
+            on_train_completion_forward_pass,
+            allow_none=True,
+        )
+        validate_callable_iterable(
+            "on_val_completion_forward_pass",
+            on_val_completion_forward_pass,
+            allow_none=True,
+        )
+        validate_callable_iterable(
+            "on_test_completion_forward_pass",
+            on_test_completion_forward_pass,
+            allow_none=True,
+        )
+        validate_callable_iterable("on_test_start", on_test_start, allow_none=True)
+        validate_callable_iterable("on_test_end", on_test_end, allow_none=True)
+
+        # Use global batch callback if phase-specific callback is unset
+        # Init callbacks as empty list if None
+        on_train_batch_start = on_train_batch_start or on_batch_start or []
+        on_train_batch_end = on_train_batch_end or on_batch_end or []
+        on_valid_batch_start = on_valid_batch_start or on_batch_start or []
+        on_valid_batch_end = on_valid_batch_end or on_batch_end or []
+        on_test_batch_start = on_test_batch_start or on_batch_start or []
+        on_test_batch_end = on_test_batch_end or on_batch_end or []
+        on_batch_start = on_batch_start or []
+        on_batch_end = on_batch_end or []
+        on_epoch_start = on_epoch_start or []
+        on_epoch_end = on_epoch_end or []
+        on_train_start = on_train_start or []
+        on_train_end = on_train_end or []
+        on_train_completion_forward_pass = on_train_completion_forward_pass or []
+        on_val_completion_forward_pass = on_val_completion_forward_pass or []
+        on_test_completion_forward_pass = on_test_completion_forward_pass or []
+        on_test_start = on_test_start or []
+        on_test_end = on_test_end or []
 
         # Keep track of epoch
         epoch = start_epoch
 
         # Execute training start hook if set
-
-
-
-        for i in tqdm(
+        for callback in on_train_start:
+            callback(epoch)
+
+        for i in tqdm(
+            range(epoch, max_epochs),
+            initial=epoch,
+            desc="Epoch",
+            disable=self.disable_progress_bar_epoch,
+        ):
             epoch = i
 
             # Execute epoch start hook if set
-
-
+            for callback in on_epoch_start:
+                callback(epoch)
 
             # Execute training and validation epoch
-            self._train_epoch(
-
+            self._train_epoch(
+                on_train_batch_start,
+                on_train_batch_end,
+                on_train_completion_forward_pass,
+            )
+            self._validate_epoch(
+                on_valid_batch_start,
+                on_valid_batch_end,
+                on_val_completion_forward_pass,
+            )
 
             # Checkpointing
             if self.checkpoint_manager:
                 self.checkpoint_manager.evaluate_criteria(epoch)
 
             # Execute epoch end hook if set
-
-
+            for callback in on_epoch_end:
+                callback(epoch)
+
+        # Execute training end hook if set
+        for callback in on_train_end:
+            callback(epoch)
+
+        # Evaluate model performance on unseen test set if required
+        if test_model:
+            # Execute test end hook if set
+            for callback in on_test_start:
+                callback(epoch)
+
+            self._test_model(
+                on_test_batch_start,
+                on_test_batch_end,
+                on_test_completion_forward_pass,
+            )
 
-
-
+            # Execute test end hook if set
+            for callback in on_test_end:
+                callback(epoch)
 
         # Save final model
         if self.checkpoint_manager:
             self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
 
-
-
-
-
-
-
-
+    def _train_epoch(
+        self,
+        on_train_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
+        on_train_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
+        on_train_completion_forward_pass: tuple[
+            Callable[[dict[str, Tensor]], dict[str, Tensor]], ...
+        ],
+    ) -> None:
+        """Perform a single training epoch over all batches.
 
-        This method
-
-
-
-
-        - Updates model parameters using backpropagation.
+        This method sets the network to training mode, iterates over the training
+        DataLoader, computes predictions, evaluates losses, applies constraint-based
+        adjustments, performs backpropagation, and updates model parameters. It also
+        supports executing optional callbacks at different stages of the batch
+        processing.
 
         Args:
-
-
+            on_train_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
+                Callbacks executed at the start of each batch. Each callback receives the
+                data dictionary and returns updated versions.
+            on_train_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
+                Callbacks executed at the end of each batch. Each callback receives the
+                data dictionary and returns updated versions.
+            on_train_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
+                Callbacks executed immediately after the forward pass of the batch.
+                Each callback receives the data dictionary and returns updated versions.
 
+        Returns:
+            None
+        """
         # Set model in training mode
         self.network.train()
 
-        for
-            self.train_loader,
+        for data in tqdm(
+            self.train_loader,
+            desc="Training batches",
+            leave=False,
+            disable=self.disable_progress_bar_batch,
         ):
+            # Transfer batch data to GPU
+            data: dict[str, Tensor] = {key: value.to(self.device) for key, value in data.items()}
 
-            #
-
-
-            # Transfer to GPU
-            inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+            # Execute on batch start callbacks
+            for callback in on_train_batch_start:
+                data = callback(data)
 
             # Model computations
-
+            data = self.network(data)
+
+            # Execute on completion forward pass callbacks
+            for callback in on_train_completion_forward_pass:
+                data = callback(data)
 
             # Calculate loss
-            loss = self.criterion(
+            loss = self.criterion(
+                data["output"],
+                data["target"],
+                data=data,
+            )
             self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
 
             # Adjust loss based on constraints
-            combined_loss = self.train_step(
+            combined_loss = self.train_step(
+                data,
+                loss,
+                self.constraints,
+                self.descriptor,
+                self.metric_manager,
+                self.device,
+                constraint_aggregator=self.constraint_aggregator,
+                epsilon=self.epsilon,
+                enforce_all=self.enforce_all,
+            )
 
             # Backprop
             self.optimizer.zero_grad()
-            combined_loss.backward(
-                retain_graph=False, inputs=list(self.network.parameters())
-            )
+            combined_loss.backward(retain_graph=False, inputs=list(self.network.parameters()))
             self.optimizer.step()
 
-
-
-
+            # Execute on batch end callbacks
+            for callback in on_train_batch_end:
+                data = callback(data)
+
+    def _validate_epoch(
+        self,
+        on_valid_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+        on_valid_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+        on_valid_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+    ) -> None:
+        """Perform a single validation epoch over all batches.
 
-        This method
-
-
-
-        - Logs constraint satisfaction ratios.
+        This method sets the network to evaluation mode, iterates over the validation
+        DataLoader, computes predictions, evaluates losses, and logs constraint
+        satisfaction. Optional callbacks can be executed at the start and end of each
+        batch, as well as after the forward pass.
 
         Args:
-
-
+            on_valid_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+                Callbacks executed at the start of each validation batch. Each callback
+                receives the data dictionary and returns updated versions.
+            on_valid_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+                Callbacks executed at the end of each validation batch. Each callback
+                receives the data dictionary and returns updated versions.
+            on_valid_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+                Callbacks executed immediately after the forward pass of the validation batch.
+                Each callback receives the data dictionary and returns updated versions.
 
+        Returns:
+            None
+        """
         # Set model in evaluation mode
         self.network.eval()
 
-
-
-
+        # Enable or disable gradient tracking for validation pass
+        with torch.set_grad_enabled(self.network_uses_grad):
+            # Loop over validation batches
+            for data in tqdm(
+                self.valid_loader,
+                desc="Validation batches",
+                leave=False,
+                disable=self.disable_progress_bar_batch,
             ):
+                # Transfer batch data to GPU
+                data: dict[str, Tensor] = {
+                    key: value.to(self.device) for key, value in data.items()
+                }
 
-                #
-
-
-                # Transfer to GPU
-                inputs, outputs = inputs.to(self.device), outputs.to(
-                    self.device
-                )
+                # Execute on batch start callbacks
+                for callback in on_valid_batch_start:
+                    data = callback(data)
 
                 # Model computations
-
+                data = self.network(data)
+
+                # Execute on completion forward pass callbacks
+                for callback in on_valid_completion_forward_pass:
+                    data = callback(data)
 
                 # Calculate loss
-                loss = self.criterion(
+                loss = self.criterion(
+                    data["output"],
+                    data["target"],
+                    data=data,
+                )
                 self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
 
                 # Validate constraints
-                self.valid_step(
+                self.valid_step(
+                    data,
+                    loss,
+                    self.constraints,
+                    self.metric_manager,
+                )
 
-
-
-
+                # Execute on batch end callbacks
+                for callback in on_valid_batch_end:
+                    data = callback(data)
+
+    def _test_model(
+        self,
+        on_test_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+        on_test_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+        on_test_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+    ) -> None:
+        """Evaluate the model on the test dataset.
 
-        This method
-
-
-
-        - Logs constraint satisfaction ratios.
+        This method sets the network to evaluation mode, iterates over the test
+        DataLoader, computes predictions, evaluates losses, and logs constraint
+        satisfaction. Optional callbacks can be executed at the start and end of
+        each batch, as well as after the forward pass.
 
-
+        Args:
+            on_test_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+                Callbacks executed at the start of each test batch. Each callback
+                receives the data dictionary and returns updated versions.
+            on_test_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+                Callbacks executed at the end of each test batch. Each callback
+                receives the data dictionary and returns updated versions.
+            on_test_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+                Callbacks executed immediately after the forward pass of the test batch.
+                Each callback receives the data dictionary and returns updated versions.
 
+        Returns:
+            None
+        """
         # Set model in evaluation mode
         self.network.eval()
 
-
-
-
+        # Enable or disable gradient tracking for validation pass
+        with torch.set_grad_enabled(self.network_uses_grad):
+            # Loop over test batches
+            for data in tqdm(
+                self.test_loader,
+                desc="Test batches",
+                leave=False,
+                disable=self.disable_progress_bar_batch,
            ):
+                # Transfer batch data to GPU
+                data: dict[str, Tensor] = {
+                    key: value.to(self.device) for key, value in data.items()
+                }
 
-                #
-
-
-                # Transfer to GPU
-                inputs, outputs = inputs.to(self.device), outputs.to(
-                    self.device
-                )
+                # Execute on batch start callbacks
+                for callback in on_test_batch_start:
+                    data = callback(data)
 
                 # Model computations
-
+                data = self.network(data)
+
+                # Execute on completion forward pass callbacks
+                for callback in on_test_completion_forward_pass:
+                    data = callback(data)
 
                 # Calculate loss
-                loss = self.criterion(
+                loss = self.criterion(
+                    data["output"],
+                    data["target"],
+                    data=data,
+                )
                 self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
 
                 # Validate constraints
-                self.test_step(
+                self.test_step(
+                    data,
+                    loss,
+                    self.constraints,
+                    self.metric_manager,
+                )
+
+                # Execute on batch end callbacks
+                for callback in on_test_batch_end:
+                    data = callback(data)
 
+    @staticmethod
     def train_step(
-
-        prediction: dict[str, Tensor],
+        data: dict[str, Tensor],
         loss: Tensor,
+        constraints: list[Constraint],
+        descriptor: Descriptor,
+        metric_manager: MetricManager,
+        device: torch.device,
+        constraint_aggregator: Callable = torch.sum,
+        epsilon: float = 1e-6,
+        enforce_all: bool = True,
     ) -> Tensor:
-        """
-
-
+        """Adjust the training loss based on constraints and compute the combined loss.
+
+        This method calculates the directions in which the network outputs should be
+        adjusted to satisfy constraints, scales these adjustments according to the
+        constraint's rescale factor and gradient norms, and adds the result to the
+        base loss. It also logs the constraint satisfaction ratio (CSR) for monitoring.
 
         Args:
-
-                for variable layers.
+            data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
             loss (Tensor): The base loss computed by the criterion.
+            constraints (list[Constraint]): List of constraints to enforce during training.
+            descriptor (Descriptor): Descriptor containing layer metadata and variable/loss layer info.
+            metric_manager (MetricManager): Metric manager for logging loss and CSR.
+            device (torch.device): Device on which computations are performed.
+            constraint_aggregator (Callable, optional): Function to aggregate per-layer rescaled losses. Defaults to `torch.mean`.
+            epsilon (float, optional): Small value to prevent division by zero in gradient normalization. Defaults to 1e-6.
+            enforce_all (bool, optional): If False, constraints are only monitored and do not influence the loss. Defaults to True.
 
         Returns:
-            Tensor: The combined loss
+            Tensor: The combined loss including the original loss and constraint-based adjustments.
         """
-
         # Init scalar tensor for loss
-        total_rescale_loss = tensor(0, dtype=float32, device=
-
+        total_rescale_loss = tensor(0, dtype=float32, device=device)
+        norm_loss_grad: dict[str, Tensor] = {}
 
         # Precalculate loss gradients for each variable layer
-
-
-
-        loss
-
+        for key in descriptor.variable_keys & descriptor.affects_loss_keys:
+            # Calculate gradients of loss w.r.t. predictions
+            grad = torch.autograd.grad(
+                outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
+            )[0]
+
+            # If gradients is None, report error
+            if grad is None:
+                raise RuntimeError(
+                    f"Unable to compute loss gradients for layer '{key}'. "
+                    "For layers not connected to the loss, set has_loss=False "
+                    "when defining them in the Descriptor."
+                )
 
-
+            # Flatten batch and compute L2 norm along each item
+            grad_flat = grad.view(grad.shape[0], -1)
+            norm_loss_grad[key] = (
+                vector_norm(grad_flat, dim=1, ord=2, keepdim=True).clamp(min=epsilon).detach()
+            )
 
+        for constraint in constraints:
             # Check if constraints are satisfied and calculate directions
-
-
-
-
+            checks, mask = constraint.check_constraint(data)
+            directions = constraint.calculate_direction(data)
+
+            # Log constraint satisfaction ratio
+            csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
+            metric_manager.accumulate(f"{constraint.name}/train", csr)
+            metric_manager.accumulate("CSR/train", csr)
 
             # Only do adjusting calculation if constraint is not observant
-            if not constraint.
+            if not enforce_all or not constraint.enforce:
+                continue
+
+            # Only do direction calculations for variable layers affecting constraint
+            for key in constraint.layers & descriptor.variable_keys:
                 with no_grad():
-
-
-                    )
-
-                    # Only do direction calculations for variable
-                    # layers affecting constraint
-                    for layer in (
-                        constraint.layers & self.descriptor.variable_layers
-                    ):
-
-                        with no_grad():
-                            # Multiply direction modifiers with constraint result
-                            constraint_result = (
-                                1 - constraint_checks.unsqueeze(1)
-                            ) * constraint_directions[layer]
-
-                            # Multiply result with rescale factor of constraint
-                            constraint_result *= constraint.rescale_factor
-
-                            # Calculate loss gradient norm
-                            norm_loss_grad = norm(
-                                loss_grads[layer], dim=1, p=2, keepdim=True
-                            )
-
-                            # Apply minimum epsilon
-                            norm_loss_grad = maximum(norm_loss_grad, self.epsilon)
-
-                            # Calculate rescale loss
-                            rescale_loss = (
-                                prediction[layer]
-                                * constraint_result
-                                * norm_loss_grad.detach().clone()
-                            ).mean()
-
-                            # Store rescale loss for this reference space
-                            total_rescale_loss += rescale_loss
+                    # Multiply direction modifiers with constraint result
+                    constraint_result = (1 - checks) * directions[key]
 
-
-
-
-
-
-
-
-                            + relevant_constraint_count
-                        )
-                        / relevant_constraint_count
-                    ).unsqueeze(0),
-                )
-                self.metric_manager.accumulate(
-                    "CSR/train",
-                    (
-                        (
-                            sum(constraint_checks)
-                            - numel(constraint_checks)
-                            + relevant_constraint_count
-                        )
-                        / relevant_constraint_count
-                    ).unsqueeze(0),
-                )
+                    # Multiply result with rescale factor of constraint
+                    constraint_result *= constraint.rescale_factor
+
+                    # Calculate rescale loss
+                    total_rescale_loss += constraint_aggregator(
+                        data[key] * constraint_result * norm_loss_grad[key],
+                    )
 
         # Return combined loss
         return loss + total_rescale_loss
 
+    @staticmethod
     def valid_step(
-
-        prediction: dict[str, Tensor],
+        data: dict[str, Tensor],
         loss: Tensor,
+        constraints: list[Constraint],
+        metric_manager: MetricManager,
     ) -> Tensor:
-        """
-
+        """Evaluate constraints during validation and log constraint satisfaction metrics.
+
+        This method checks whether each constraint is satisfied for the given
+        data, computes the constraint satisfaction ratio (CSR),
+        and logs it using the metric manager. The base loss is not modified.
 
         Args:
-
-                variable layers.
+            data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
             loss (Tensor): The base loss computed by the criterion.
+            constraints (list[Constraint]): List of constraints to evaluate.
+            metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.
 
         Returns:
-            Tensor: The unchanged base loss.
+            Tensor: The original, unchanged base loss.
         """
-
         # For each constraint in this reference space, calculate directions
-        for constraint in
-
+        for constraint in constraints:
             # Check if constraints are satisfied for
-
-                constraint.check_constraint(prediction)
-            )
+            checks, mask = constraint.check_constraint(data)
 
             # Log constraint satisfaction ratio
-
-
-
-                    (
-                        sum(constraint_checks)
-                        - numel(constraint_checks)
-                        + relevant_constraint_count
-                    )
-                    / relevant_constraint_count
-                ).unsqueeze(0),
-            )
-            self.metric_manager.accumulate(
-                "CSR/valid",
-                (
-                    (
-                        sum(constraint_checks)
-                        - numel(constraint_checks)
-                        + relevant_constraint_count
-                    )
-                    / relevant_constraint_count
-                ).unsqueeze(0),
-            )
+            csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
+            metric_manager.accumulate(f"{constraint.name}/valid", csr)
+            metric_manager.accumulate("CSR/valid", csr)
 
-        # Return loss
+        # Return original loss
         return loss
 
+    @staticmethod
     def test_step(
-
-        prediction: dict[str, Tensor],
+        data: dict[str, Tensor],
         loss: Tensor,
+        constraints: list[Constraint],
+        metric_manager: MetricManager,
     ) -> Tensor:
-        """
-
+        """Evaluate constraints during testing and log constraint satisfaction metrics.
+
+        This method checks whether each constraint is satisfied for the given
+        data, computes the constraint satisfaction ratio (CSR),
+        and logs it using the metric manager. The base loss is not modified.
 
         Args:
-
-            for variable layers.
+            data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
             loss (Tensor): The base loss computed by the criterion.
+            constraints (list[Constraint]): List of constraints to evaluate.
+            metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.
 
         Returns:
-            Tensor: The unchanged base loss.
+            Tensor: The original, unchanged base loss.
         """
-
         # For each constraint in this reference space, calculate directions
-        for constraint in
-
+        for constraint in constraints:
             # Check if constraints are satisfied for
-
-                constraint.check_constraint(prediction)
-            )
+            checks, mask = constraint.check_constraint(data)
 
             # Log constraint satisfaction ratio
-
-
-
-                    (
-                        sum(constraint_checks)
-                        - numel(constraint_checks)
-                        + relevant_constraint_count
-                    )
-                    / relevant_constraint_count
-                ).unsqueeze(0),
-            )
-            self.metric_manager.accumulate(
-                "CSR/test",
-                (
-                    (
-                        sum(constraint_checks)
-                        - numel(constraint_checks)
-                        + relevant_constraint_count
-                    )
-                    / relevant_constraint_count
-                ).unsqueeze(0),
-            )
+            csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
+            metric_manager.accumulate(f"{constraint.name}/test", csr)
+            metric_manager.accumulate("CSR/test", csr)
 
-        # Return loss
+        # Return original loss
         return loss
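Editor's note on the new fit() callback surface: epoch-level hooks (on_epoch_start, on_train_end, ...) are called with the epoch index, while batch-level hooks receive the batch dictionary and must return it, and phase-specific hooks fall back to the global on_batch_start/on_batch_end when unset. A hedged usage sketch; the "input" key and the already-constructed `core` object are illustration-only assumptions that do not appear in this diff.

import torch
from torch import Tensor

def log_epoch(epoch: int) -> None:
    # Epoch-level hooks receive only the epoch index.
    print(f"finished epoch {epoch}")

def jitter_inputs(data: dict[str, Tensor]) -> dict[str, Tensor]:
    # Batch-level hooks receive the batch dictionary and must return it; the
    # "input" key is assumed here, the diff itself only shows "output" and "target".
    data["input"] = data["input"] + 0.01 * torch.randn_like(data["input"])
    return data

# Hypothetical call on an already-constructed CongradsCore instance `core`:
# core.fit(max_epochs=50, on_epoch_end=[log_epoch], on_train_batch_start=[jitter_inputs])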
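Editor's note on the rewritten train_step: per variable layer the constraint adjustment is built as data[key] * (1 - checks) * directions[key] * rescale_factor * ||dLoss/d data[key]||_2, with the gradient norm clamped at epsilon and the result reduced by constraint_aggregator before being added to the base loss. The standalone toy below reproduces that arithmetic with made-up tensors in place of a real Constraint and Descriptor; it is a sketch of the formula, not the library code.

import torch
from torch.linalg import vector_norm

torch.manual_seed(0)
epsilon, rescale_factor = 1e-6, 1.0

pred = torch.randn(4, 3, requires_grad=True)          # stands in for data[key]
loss = torch.nn.functional.mse_loss(pred, torch.zeros(4, 3))

# Per-sample L2 norm of the loss gradient w.r.t. the layer, clamped and detached
grad = torch.autograd.grad(loss, pred, retain_graph=True)[0]
norm_loss_grad = vector_norm(grad.view(grad.shape[0], -1), dim=1, ord=2, keepdim=True)
norm_loss_grad = norm_loss_grad.clamp(min=epsilon).detach()

# Pretend constraint: satisfied for samples 0 and 2, pushes outputs upward otherwise
checks = torch.tensor([[1.0], [0.0], [1.0], [0.0]])
directions = torch.ones_like(pred)
with torch.no_grad():
    constraint_result = (1 - checks) * directions * rescale_factor

# Default aggregator in 1.1.0 is torch.sum; the term is added on top of the base loss
rescale_term = torch.sum(pred * constraint_result * norm_loss_grad)
combined_loss = loss + rescale_term
combined_loss.backward()
print(combined_loss.item(), pred.grad.shape)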