congrads 1.0.6-py3-none-any.whl → 1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/core.py CHANGED
@@ -1,15 +1,15 @@
- """
- This module provides the CongradsCore class, which is designed to integrate
- constraint-guided optimization into neural network training.
- It extends traditional training processes by enforcing specific constraints
- on the model's outputs, ensuring that the network satisfies domain-specific
+ """This module provides the core CongradsCore class for the main training functionality.
+
+ It is designed to integrate constraint-guided optimization into neural network training.
+ It extends traditional training processes by enforcing specific constraints
+ on the model's outputs, ensuring that the network satisfies domain-specific
  requirements during both training and evaluation.

- The `CongradsCore` class serves as the central engine for managing the
- training, validation, and testing phases of a neural network model,
- incorporating constraints that influence the loss function and model updates.
- The model is trained with standard loss functions while also incorporating
- constraint-based adjustments, which are tracked and logged
+ The `CongradsCore` class serves as the central engine for managing the
+ training, validation, and testing phases of a neural network model,
+ incorporating constraints that influence the loss function and model updates.
+ The model is trained with standard loss functions while also incorporating
+ constraint-based adjustments, which are tracked and logged
  throughout the process.

 Key features:
@@ -18,37 +18,19 @@ Key features:
 - Metric management for tracking loss and constraint satisfaction.
 - Checkpoint management for saving and evaluating model states.

- Modules in this package provide the following:
-
- - `Descriptor`: Describes variable layers in the network that are
- subject to constraints.
- - `Constraint`: Defines various constraints, which are used to guide
- the training process.
- - `MetricManager`: Manages and tracks performance metrics such as loss
- and constraint satisfaction.
- - `CheckpointManager`: Manages saving and loading model checkpoints
- during training.
- - Utility functions to validate inputs and configurations.
-
- Dependencies:
- - PyTorch (`torch`)
- - tqdm (for progress tracking)
-
- The `CongradsCore` class allows for the use of additional callback functions
- at different stages of the training process to customize behavior for
- specific needs. These include callbacks for the start and end of epochs, as
+ The `CongradsCore` class allows for the use of additional callback functions
+ at different stages of the training process to customize behavior for
+ specific needs. These include callbacks for the start and end of epochs, as
  well as the start and end of the entire training process.

 """

 import warnings
- from numbers import Number
- from typing import Callable
+ from collections.abc import Callable

 import torch
-
- # pylint: disable-next=redefined-builtin
- from torch import Tensor, float32, maximum, no_grad, norm, numel, sum, tensor
+ from torch import Tensor, float32, no_grad, sum, tensor
+ from torch.linalg import vector_norm
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
@@ -60,7 +42,10 @@ from .constraints import Constraint
 from .descriptor import Descriptor
 from .metrics import MetricManager
 from .utils import (
+ is_torch_loss,
+ torch_loss_wrapper,
 validate_callable,
+ validate_callable_iterable,
 validate_iterable,
 validate_loaders,
 validate_type,
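One note on the import swap above: the 1.0.6 code used `torch.norm` (whose generic form is deprecated), while 1.1.0 imports `torch.linalg.vector_norm`, which `train_step` later uses for per-sample gradient norms. A quick equivalence sketch for the two call patterns that appear in this diff:

    # torch.linalg.vector_norm replaces the older torch.norm call pattern;
    # both compute a per-row L2 norm here.
    import torch
    from torch.linalg import vector_norm

    g = torch.randn(4, 3)
    old_style = torch.norm(g, dim=1, p=2, keepdim=True)     # 1.0.6 pattern
    new_style = vector_norm(g, dim=1, ord=2, keepdim=True)  # 1.1.0 pattern
    assert torch.allclose(old_style, new_style)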
@@ -68,32 +53,11 @@ from .utils import (


 class CongradsCore:
- """
- The CongradsCore class is the central training engine for constraint-guided
- neural network optimization. It integrates standard neural network training
+ """The CongradsCore class is the central training engine for constraint-guided optimization.
+
+ It integrates standard neural network training
  with additional constraint-driven adjustments to the loss function, ensuring
  that the network satisfies domain-specific constraints during training.
-
- Args:
- descriptor (Descriptor): Describes variable layers in the network.
- constraints (list[Constraint]): List of constraints to guide training.
- loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
- training, validation, and testing.
- network (Module): The neural network model to train.
- criterion (callable): The loss function used for
- training and validation.
- optimizer (Optimizer): The optimizer used for updating model parameters.
- metric_manager (MetricManager): Manages metric tracking and recording.
- device (torch.device): The device (e.g., CPU or GPU) for computations.
- checkpoint_manager (CheckpointManager, optional): Manages
- checkpointing. If not set, no checkpointing is done.
- epsilon (Number, optional): A small value to avoid division by zero
- in gradient calculations. Default is 1e-10.
-
- Note:
- A warning is logged if the descriptor has no variable layers,
- as at least one variable layer is required for the constraint logic
- to influence the training process.
 """

 def __init__(
@@ -106,29 +70,69 @@ class CongradsCore:
 optimizer: Optimizer,
 metric_manager: MetricManager,
 device: torch.device,
+ network_uses_grad: bool = False,
 checkpoint_manager: CheckpointManager = None,
- epsilon: Number = 1e-6,
+ epsilon: float = 1e-6,
+ constraint_aggregator: Callable[..., Tensor] = sum,
+ disable_progress_bar_epoch: bool = False,
+ disable_progress_bar_batch: bool = False,
+ enforce_all: bool = True,
 ):
- """
- Initialize the CongradsCore object.
- """
+ """Initialize the CongradsCore object.

+ Args:
+ descriptor (Descriptor): Describes variable layers in the network.
+ constraints (list[Constraint]): List of constraints to guide training.
+ loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
+ training, validation, and testing.
+ network (Module): The neural network model to train.
+ criterion (callable): The loss function used for
+ training and validation.
+ optimizer (Optimizer): The optimizer used for updating model parameters.
+ metric_manager (MetricManager): Manages metric tracking and recording.
+ device (torch.device): The device (e.g., CPU or GPU) for computations.
+ network_uses_grad (bool, optional): A flag indicating if the network
+ contains gradient calculation computations. Default is False.
+ checkpoint_manager (CheckpointManager, optional): Manages
+ checkpointing. If not set, no checkpointing is done.
+ epsilon (float, optional): A small value to avoid division by zero
+ in gradient calculations. Default is 1e-6.
+ constraint_aggregator (Callable[..., Tensor], optional): A function
+ to aggregate the constraint rescale loss. Default is `sum`.
+ disable_progress_bar_epoch (bool, optional): If set to True, the epoch
+ progress bar will not show. Defaults to False.
+ disable_progress_bar_batch (bool, optional): If set to True, the batch
+ progress bar will not show. Defaults to False.
+ enforce_all (bool, optional): If set to False, constraints will only be monitored and
+ not influence the training process. Overrides constraint-specific `enforce` parameters.
+ Defaults to True.
+
+ Note:
+ A warning is logged if the descriptor has no variable layers,
+ as at least one variable layer is required for the constraint logic
+ to influence the training process.
+ """
 # Type checking
 validate_type("descriptor", descriptor, Descriptor)
- validate_iterable("constraints", constraints, Constraint)
- validate_loaders()
+ validate_iterable("constraints", constraints, Constraint, allow_empty=True)
+ validate_loaders("loaders", loaders)
 validate_type("network", network, Module)
 validate_type("criterion", criterion, _Loss)
 validate_type("optimizer", optimizer, Optimizer)
 validate_type("metric_manager", metric_manager, MetricManager)
 validate_type("device", device, torch.device)
+ validate_type("network_uses_grad", network_uses_grad, bool)
 validate_type(
 "checkpoint_manager",
 checkpoint_manager,
 CheckpointManager,
 allow_none=True,
 )
- validate_type("epsilon", epsilon, Number)
+ validate_type("epsilon", epsilon, float)
+ validate_callable("constraint_aggregator", constraint_aggregator, allow_none=True)
+ validate_type("disable_progress_bar_epoch", disable_progress_bar_epoch, bool)
+ validate_type("disable_progress_bar_batch", disable_progress_bar_batch, bool)
+ validate_type("enforce_all", enforce_all, bool)

 # Init object variables
 self.descriptor = descriptor
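For readers tracking the API change: a minimal construction sketch against the 1.1.0 signature shown above. The descriptor, constraints, loaders, and metric manager are placeholders here, not defined in this diff.

    # Sketch only: my_descriptor, my_constraints, the three DataLoaders,
    # and metric_manager are placeholders, not part of this diff.
    import torch

    core = CongradsCore(
        descriptor=my_descriptor,
        constraints=my_constraints,        # may now be empty (allow_empty=True)
        loaders=(train_loader, valid_loader, test_loader),
        network=my_network,
        criterion=torch.nn.MSELoss(),      # plain torch losses get wrapped internally
        optimizer=torch.optim.Adam(my_network.parameters()),
        metric_manager=metric_manager,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        network_uses_grad=False,           # new in 1.1.0
        epsilon=1e-6,                      # now a float, not a Number
        constraint_aggregator=torch.sum,   # new in 1.1.0
        disable_progress_bar_epoch=False,  # new in 1.1.0
        disable_progress_bar_batch=False,  # new in 1.1.0
        enforce_all=True,                  # new in 1.1.0
    )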
@@ -137,30 +141,38 @@ class CongradsCore:
 self.valid_loader = loaders[1]
 self.test_loader = loaders[2]
 self.network = network
- self.criterion = criterion
 self.optimizer = optimizer
 self.metric_manager = metric_manager
 self.device = device
+ self.network_uses_grad = network_uses_grad
 self.checkpoint_manager = checkpoint_manager
-
- # Init epsilon tensor
- self.epsilon = tensor(epsilon, device=self.device)
+ self.epsilon = epsilon
+ self.constraint_aggregator = constraint_aggregator
+ self.disable_progress_bar_epoch = disable_progress_bar_epoch
+ self.disable_progress_bar_batch = disable_progress_bar_batch
+ self.enforce_all = enforce_all
+
+ # Check if criterion is a torch loss function
+ if is_torch_loss(criterion):
+ # If so, wrap it in a custom loss function
+ self.criterion = torch_loss_wrapper(criterion)
+ else:
+ self.criterion = criterion

 # Perform checks
- if len(self.descriptor.variable_layers) == 0:
+ if len(self.descriptor.variable_keys) == 0:
 warnings.warn(
 "The descriptor object has no variable layers. The constraint \
 guided loss adjustment is therefore not used. \
- Is this the intended behavior?"
+ Is this the intended behavior?",
+ stacklevel=2,
 )

 # Initialize constraint metrics
 self._initialize_metrics()

 def _initialize_metrics(self) -> None:
- """
- Register metrics for loss, constraint satisfaction ratio (CSR),
- and individual constraints.
+ """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.

 This method registers the following metrics:

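The criterion wrapping above lets the rest of the core call every loss with a uniform `criterion(output, target, data=data)` signature, as the loss calls later in this diff show. The internals of `is_torch_loss` and `torch_loss_wrapper` live in `congrads/utils.py` and are not part of this hunk; a plausible sketch of the wrapper's shape, under that assumption:

    # Hypothetical sketch only: the real torch_loss_wrapper is defined in
    # congrads/utils.py and is not shown in this diff.
    from torch import Tensor
    from torch.nn.modules.loss import _Loss

    def torch_loss_wrapper_sketch(criterion: _Loss):
        # Adapt a standard torch loss so it tolerates the extra
        # `data` keyword that the core passes to every criterion.
        def wrapped(output: Tensor, target: Tensor, **kwargs) -> Tensor:
            return criterion(output, target)  # ignore data=... kwarg
        return wrapped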
@@ -173,7 +185,6 @@
 - One metric per constraint, for both training and validation.

 """
-
 self.metric_manager.register("Loss/train", "during_training")
 self.metric_manager.register("Loss/valid", "during_training")
 self.metric_manager.register("Loss/test", "after_training")
@@ -184,414 +195,579 @@
 self.metric_manager.register("CSR/test", "after_training")

 for constraint in self.constraints:
- self.metric_manager.register(
- f"{constraint.name}/train", "during_training"
- )
- self.metric_manager.register(
- f"{constraint.name}/valid", "during_training"
- )
- self.metric_manager.register(
- f"{constraint.name}/test", "after_training"
- )
+ self.metric_manager.register(f"{constraint.name}/train", "during_training")
+ self.metric_manager.register(f"{constraint.name}/valid", "during_training")
+ self.metric_manager.register(f"{constraint.name}/test", "after_training")

 def fit(
 self,
 start_epoch: int = 0,
 max_epochs: int = 100,
- on_epoch_start: Callable[[int], None] = None,
- on_epoch_end: Callable[[int], None] = None,
- on_train_start: Callable[[int], None] = None,
- on_train_end: Callable[[int], None] = None,
+ test_model: bool = True,
+ on_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_train_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_train_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_valid_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_valid_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_test_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_test_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
+ on_epoch_start: list[Callable[[int], None]] | None = None,
+ on_epoch_end: list[Callable[[int], None]] | None = None,
+ on_train_start: list[Callable[[int], None]] | None = None,
+ on_train_end: list[Callable[[int], None]] | None = None,
+ on_train_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
+ | None = None,
+ on_val_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
+ | None = None,
+ on_test_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
+ | None = None,
+ on_test_start: list[Callable[[int], None]] | None = None,
+ on_test_end: list[Callable[[int], None]] | None = None,
 ) -> None:
- """
- Train the model for a given number of epochs.
+ """Train the model over multiple epochs with optional validation and testing.
+
+ This method manages the full training loop, including:
+
+ - Executing epoch-level and batch-level callbacks.
+ - Training and validating the model each epoch.
+ - Adjusting losses according to constraints.
+ - Logging metrics via the metric manager.
+ - Optional evaluation on the test set.
+ - Checkpointing the model during and after training.

 Args:
- start_epoch (int, optional): The epoch number to start the training
- with. Default is 0.
- max_epochs (int, optional): The number of epochs to train the
- model. Default is 100.
- on_epoch_start (Callable[[int], None], optional): A callback
- function that will be executed at the start of each epoch.
- on_epoch_end (Callable[[int], None], optional): A callback
- function that will be executed at the end of each epoch.
- on_train_start (Callable[[int], None], optional): A callback
- function that will be executed before the training starts.
- on_train_end (Callable[[int], None], optional): A callback
- function that will be executed after training ends.
+ start_epoch (int, optional): Epoch number to start training from. Defaults to 0.
+ max_epochs (int, optional): Total number of epochs to train. Defaults to 100.
+ test_model (bool, optional): If True, evaluate the model on the test set after training. Defaults to True.
+ on_batch_start (list[Callable], optional): Callbacks executed at the start of every batch. Defaults to None.
+ on_batch_end (list[Callable], optional): Callbacks executed at the end of every batch. Defaults to None.
+ on_train_batch_start (list[Callable], optional): Callbacks executed at the start of each training batch. Defaults to `on_batch_start` if not provided.
+ on_train_batch_end (list[Callable], optional): Callbacks executed at the end of each training batch. Defaults to `on_batch_end` if not provided.
+ on_valid_batch_start (list[Callable], optional): Callbacks executed at the start of each validation batch. Defaults to `on_batch_start` if not provided.
+ on_valid_batch_end (list[Callable], optional): Callbacks executed at the end of each validation batch. Defaults to `on_batch_end` if not provided.
+ on_test_batch_start (list[Callable], optional): Callbacks executed at the start of each test batch. Defaults to `on_batch_start` if not provided.
+ on_test_batch_end (list[Callable], optional): Callbacks executed at the end of each test batch. Defaults to `on_batch_end` if not provided.
+ on_epoch_start (list[Callable], optional): Callbacks executed at the start of each epoch. Defaults to None.
+ on_epoch_end (list[Callable], optional): Callbacks executed at the end of each epoch. Defaults to None.
+ on_train_start (list[Callable], optional): Callbacks executed before training starts. Defaults to None.
+ on_train_end (list[Callable], optional): Callbacks executed after training ends. Defaults to None.
+ on_train_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during training. Defaults to None.
+ on_val_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during validation. Defaults to None.
+ on_test_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during testing. Defaults to None.
+ on_test_start (list[Callable], optional): Callbacks executed before testing starts. Defaults to None.
+ on_test_end (list[Callable], optional): Callbacks executed after testing ends. Defaults to None.
+
+ Notes:
+ - If phase-specific callbacks (train/valid/test) are not provided, the global `on_batch_start` and `on_batch_end` are used.
+ - Training metrics, loss adjustments, and constraint satisfaction ratios are automatically logged via the metric manager.
+ - The final model checkpoint is saved if a checkpoint manager is configured.
 """
-
 # Type checking
 validate_type("start_epoch", start_epoch, int)
- validate_callable("on_epoch_start", on_epoch_start, True)
- validate_callable("on_epoch_end", on_epoch_end, True)
- validate_callable("on_train_start", on_train_start, True)
- validate_callable("on_train_end", on_train_end, True)
+ validate_type("max_epochs", max_epochs, int)
+ validate_type("test_model", test_model, bool)
+ validate_callable_iterable("on_batch_start", on_batch_start, allow_none=True)
+ validate_callable_iterable("on_batch_end", on_batch_end, allow_none=True)
+ validate_callable_iterable("on_train_batch_start", on_train_batch_start, allow_none=True)
+ validate_callable_iterable("on_train_batch_end", on_train_batch_end, allow_none=True)
+ validate_callable_iterable("on_valid_batch_start", on_valid_batch_start, allow_none=True)
+ validate_callable_iterable("on_valid_batch_end", on_valid_batch_end, allow_none=True)
+ validate_callable_iterable("on_test_batch_start", on_test_batch_start, allow_none=True)
+ validate_callable_iterable("on_test_batch_end", on_test_batch_end, allow_none=True)
+ validate_callable_iterable("on_epoch_start", on_epoch_start, allow_none=True)
+ validate_callable_iterable("on_epoch_end", on_epoch_end, allow_none=True)
+ validate_callable_iterable("on_train_start", on_train_start, allow_none=True)
+ validate_callable_iterable("on_train_end", on_train_end, allow_none=True)
+ validate_callable_iterable(
+ "on_train_completion_forward_pass",
+ on_train_completion_forward_pass,
+ allow_none=True,
+ )
+ validate_callable_iterable(
+ "on_val_completion_forward_pass",
+ on_val_completion_forward_pass,
+ allow_none=True,
+ )
+ validate_callable_iterable(
+ "on_test_completion_forward_pass",
+ on_test_completion_forward_pass,
+ allow_none=True,
+ )
+ validate_callable_iterable("on_test_start", on_test_start, allow_none=True)
+ validate_callable_iterable("on_test_end", on_test_end, allow_none=True)
+
+ # Use global batch callback if phase-specific callback is unset
+ # Init callbacks as empty list if None
+ on_train_batch_start = on_train_batch_start or on_batch_start or []
+ on_train_batch_end = on_train_batch_end or on_batch_end or []
+ on_valid_batch_start = on_valid_batch_start or on_batch_start or []
+ on_valid_batch_end = on_valid_batch_end or on_batch_end or []
+ on_test_batch_start = on_test_batch_start or on_batch_start or []
+ on_test_batch_end = on_test_batch_end or on_batch_end or []
+ on_batch_start = on_batch_start or []
+ on_batch_end = on_batch_end or []
+ on_epoch_start = on_epoch_start or []
+ on_epoch_end = on_epoch_end or []
+ on_train_start = on_train_start or []
+ on_train_end = on_train_end or []
+ on_train_completion_forward_pass = on_train_completion_forward_pass or []
+ on_val_completion_forward_pass = on_val_completion_forward_pass or []
+ on_test_completion_forward_pass = on_test_completion_forward_pass or []
+ on_test_start = on_test_start or []
+ on_test_end = on_test_end or []
319
 
231
320
  # Keep track of epoch
232
321
  epoch = start_epoch
233
322
 
234
323
  # Execute training start hook if set
235
- if on_train_start:
236
- on_train_start(epoch)
237
-
238
- for i in tqdm(range(epoch, max_epochs), initial=epoch, desc="Epoch"):
324
+ for callback in on_train_start:
325
+ callback(epoch)
326
+
327
+ for i in tqdm(
328
+ range(epoch, max_epochs),
329
+ initial=epoch,
330
+ desc="Epoch",
331
+ disable=self.disable_progress_bar_epoch,
332
+ ):
239
333
  epoch = i
240
334
 
241
335
  # Execute epoch start hook if set
242
- if on_epoch_start:
243
- on_epoch_start(epoch)
336
+ for callback in on_epoch_start:
337
+ callback(epoch)
244
338
 
245
339
  # Execute training and validation epoch
246
- self._train_epoch()
247
- self._validate_epoch()
340
+ self._train_epoch(
341
+ on_train_batch_start,
342
+ on_train_batch_end,
343
+ on_train_completion_forward_pass,
344
+ )
345
+ self._validate_epoch(
346
+ on_valid_batch_start,
347
+ on_valid_batch_end,
348
+ on_val_completion_forward_pass,
349
+ )
248
350
 
249
351
  # Checkpointing
250
352
  if self.checkpoint_manager:
251
353
  self.checkpoint_manager.evaluate_criteria(epoch)
252
354
 
253
355
  # Execute epoch end hook if set
254
- if on_epoch_end:
255
- on_epoch_end(epoch)
356
+ for callback in on_epoch_end:
357
+ callback(epoch)
358
+
359
+ # Execute training end hook if set
360
+ for callback in on_train_end:
361
+ callback(epoch)
362
+
363
+ # Evaluate model performance on unseen test set if required
364
+ if test_model:
365
+ # Execute test end hook if set
366
+ for callback in on_test_start:
367
+ callback(epoch)
368
+
369
+ self._test_model(
370
+ on_test_batch_start,
371
+ on_test_batch_end,
372
+ on_test_completion_forward_pass,
373
+ )
256
374
 
257
- # Evaluate model performance on unseen test set
258
- self._test_model()
375
+ # Execute test end hook if set
376
+ for callback in on_test_end:
377
+ callback(epoch)
259
378
 
260
379
  # Save final model
261
380
  if self.checkpoint_manager:
262
381
  self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
263
382
 
264
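A usage sketch for the reworked `fit()`: testing now runs by default after training, and every hook takes a list of callables rather than a single function. This assumes a configured `core` as in the constructor sketch earlier.

    # Usage sketch for the 1.1.0 fit() signature.
    def log_epoch(epoch: int) -> None:
        print(f"finished epoch {epoch}")

    core.fit(
        start_epoch=0,
        max_epochs=50,
        test_model=True,           # run _test_model after training (new default)
        on_epoch_end=[log_epoch],  # callbacks are now lists, not single callables
    )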
- # Execute training end hook if set
265
- if on_train_end:
266
- on_train_end(epoch)
267
-
268
- def _train_epoch(self) -> None:
269
- """
270
- Perform training for a single epoch.
383
+ def _train_epoch(
384
+ self,
385
+ on_train_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
386
+ on_train_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
387
+ on_train_completion_forward_pass: tuple[
388
+ Callable[[dict[str, Tensor]], dict[str, Tensor]], ...
389
+ ],
390
+ ) -> None:
391
+ """Perform a single training epoch over all batches.
271
392
 
272
- This method:
273
- - Sets the model to training mode.
274
- - Processes batches from the training DataLoader.
275
- - Computes predictions and losses.
276
- - Adjusts losses based on constraints.
277
- - Updates model parameters using backpropagation.
393
+ This method sets the network to training mode, iterates over the training
394
+ DataLoader, computes predictions, evaluates losses, applies constraint-based
395
+ adjustments, performs backpropagation, and updates model parameters. It also
396
+ supports executing optional callbacks at different stages of the batch
397
+ processing.
278
398
 
279
399
  Args:
280
- epoch (int): The current epoch number.
281
- """
400
+ on_train_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
401
+ Callbacks executed at the start of each batch. Each callback receives the
402
+ data dictionary and returns updated versions.
403
+ on_train_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
404
+ Callbacks executed at the end of each batch. Each callback receives the
405
+ data dictionary and returns updated versions.
406
+ on_train_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
407
+ Callbacks executed immediately after the forward pass of the batch.
408
+ Each callback receives the data dictionary and returns updated versions.
282
409
 
410
+ Returns:
411
+ None
412
+ """
283
413
  # Set model in training mode
284
414
  self.network.train()
285
415
 
286
- for batch in tqdm(
287
- self.train_loader, desc="Training batches", leave=False
416
+ for data in tqdm(
417
+ self.train_loader,
418
+ desc="Training batches",
419
+ leave=False,
420
+ disable=self.disable_progress_bar_batch,
288
421
  ):
422
+ # Transfer batch data to GPU
423
+ data: dict[str, Tensor] = {key: value.to(self.device) for key, value in data.items()}
289
424
 
290
- # Get input-output pairs from batch
291
- inputs, outputs = batch
292
-
293
- # Transfer to GPU
294
- inputs, outputs = inputs.to(self.device), outputs.to(self.device)
425
+ # Execute on batch start callbacks
426
+ for callback in on_train_batch_start:
427
+ data = callback(data)
295
428
 
296
429
  # Model computations
297
- prediction = self.network(inputs)
430
+ data = self.network(data)
431
+
432
+ # Execute on completion forward pass callbacks
433
+ for callback in on_train_completion_forward_pass:
434
+ data = callback(data)
298
435
 
299
436
  # Calculate loss
300
- loss = self.criterion(prediction["output"], outputs)
437
+ loss = self.criterion(
438
+ data["output"],
439
+ data["target"],
440
+ data=data,
441
+ )
301
442
  self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
302
443
 
303
444
  # Adjust loss based on constraints
304
- combined_loss = self.train_step(prediction, loss)
445
+ combined_loss = self.train_step(
446
+ data,
447
+ loss,
448
+ self.constraints,
449
+ self.descriptor,
450
+ self.metric_manager,
451
+ self.device,
452
+ constraint_aggregator=self.constraint_aggregator,
453
+ epsilon=self.epsilon,
454
+ enforce_all=self.enforce_all,
455
+ )
305
456
 
306
457
  # Backprop
307
458
  self.optimizer.zero_grad()
308
- combined_loss.backward(
309
- retain_graph=False, inputs=list(self.network.parameters())
310
- )
459
+ combined_loss.backward(retain_graph=False, inputs=list(self.network.parameters()))
311
460
  self.optimizer.step()
312
461
 
313
- def _validate_epoch(self) -> None:
314
- """
315
- Perform validation for a single epoch.
462
+ # Execute on batch end callbacks
463
+ for callback in on_train_batch_end:
464
+ data = callback(data)
465
+
466
+ def _validate_epoch(
467
+ self,
468
+ on_valid_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
469
+ on_valid_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
470
+ on_valid_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
471
+ ) -> None:
472
+ """Perform a single validation epoch over all batches.
316
473
 
317
- This method:
318
- - Sets the model to evaluation mode.
319
- - Processes batches from the validation DataLoader.
320
- - Computes predictions and losses.
321
- - Logs constraint satisfaction ratios.
474
+ This method sets the network to evaluation mode, iterates over the validation
475
+ DataLoader, computes predictions, evaluates losses, and logs constraint
476
+ satisfaction. Optional callbacks can be executed at the start and end of each
477
+ batch, as well as after the forward pass.
322
478
 
323
479
  Args:
324
- epoch (int): The current epoch number.
325
- """
480
+ on_valid_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
481
+ Callbacks executed at the start of each validation batch. Each callback
482
+ receives the data dictionary and returns updated versions.
483
+ on_valid_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
484
+ Callbacks executed at the end of each validation batch. Each callback
485
+ receives the data dictionary and returns updated versions.
486
+ on_valid_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
487
+ Callbacks executed immediately after the forward pass of the validation batch.
488
+ Each callback receives the data dictionary and returns updated versions.
326
489
 
490
+ Returns:
491
+ None
492
+ """
327
493
  # Set model in evaluation mode
328
494
  self.network.eval()
329
495
 
330
- with no_grad():
331
- for batch in tqdm(
332
- self.valid_loader, desc="Validation batches", leave=False
496
+ # Enable or disable gradient tracking for validation pass
497
+ with torch.set_grad_enabled(self.network_uses_grad):
498
+ # Loop over validation batches
499
+ for data in tqdm(
500
+ self.valid_loader,
501
+ desc="Validation batches",
502
+ leave=False,
503
+ disable=self.disable_progress_bar_batch,
333
504
  ):
505
+ # Transfer batch data to GPU
506
+ data: dict[str, Tensor] = {
507
+ key: value.to(self.device) for key, value in data.items()
508
+ }
334
509
 
335
- # Get input-output pairs from batch
336
- inputs, outputs = batch
337
-
338
- # Transfer to GPU
339
- inputs, outputs = inputs.to(self.device), outputs.to(
340
- self.device
341
- )
510
+ # Execute on batch start callbacks
511
+ for callback in on_valid_batch_start:
512
+ data = callback(data)
342
513
 
343
514
  # Model computations
344
- prediction = self.network(inputs)
515
+ data = self.network(data)
516
+
517
+ # Execute on completion forward pass callbacks
518
+ for callback in on_valid_completion_forward_pass:
519
+ data = callback(data)
345
520
 
346
521
  # Calculate loss
347
- loss = self.criterion(prediction["output"], outputs)
522
+ loss = self.criterion(
523
+ data["output"],
524
+ data["target"],
525
+ data=data,
526
+ )
348
527
  self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
349
528
 
350
529
  # Validate constraints
351
- self.valid_step(prediction, loss)
530
+ self.valid_step(
531
+ data,
532
+ loss,
533
+ self.constraints,
534
+ self.metric_manager,
535
+ )
352
536
 
353
- def _test_model(self) -> None:
354
- """
355
- Evaluate model performance on the test set.
537
+ # Execute on batch end callbacks
538
+ for callback in on_valid_batch_end:
539
+ data = callback(data)
540
+
+ def _test_model(
+ self,
+ on_test_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+ on_test_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+ on_test_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]],
+ ) -> None:
+ """Evaluate the model on the test dataset.

- This method:
- - Sets the model to evaluation mode.
- - Processes batches from the test DataLoader.
- - Computes predictions and losses.
- - Logs constraint satisfaction ratios.
+ This method sets the network to evaluation mode, iterates over the test
+ DataLoader, computes predictions, evaluates losses, and logs constraint
+ satisfaction. Optional callbacks can be executed at the start and end of
+ each batch, as well as after the forward pass.

- """
+ Args:
+ on_test_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+ Callbacks executed at the start of each test batch. Each callback
+ receives the data dictionary and returns updated versions.
+ on_test_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+ Callbacks executed at the end of each test batch. Each callback
+ receives the data dictionary and returns updated versions.
+ on_test_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]]]):
+ Callbacks executed immediately after the forward pass of the test batch.
+ Each callback receives the data dictionary and returns updated versions.

+ Returns:
+ None
+ """
 # Set model in evaluation mode
 self.network.eval()

- with no_grad():
- for batch in tqdm(
- self.test_loader, desc="Test batches", leave=False
+ # Enable or disable gradient tracking for test pass
+ with torch.set_grad_enabled(self.network_uses_grad):
+ # Loop over test batches
+ for data in tqdm(
+ self.test_loader,
+ desc="Test batches",
+ leave=False,
+ disable=self.disable_progress_bar_batch,
 ):
+ # Transfer batch data to GPU
+ data: dict[str, Tensor] = {
+ key: value.to(self.device) for key, value in data.items()
+ }

- # Get input-output pairs from batch
- inputs, outputs = batch
-
- # Transfer to GPU
- inputs, outputs = inputs.to(self.device), outputs.to(
- self.device
- )
+ # Execute on batch start callbacks
+ for callback in on_test_batch_start:
+ data = callback(data)

 # Model computations
- prediction = self.network(inputs)
+ data = self.network(data)
+
+ # Execute on completion forward pass callbacks
+ for callback in on_test_completion_forward_pass:
+ data = callback(data)

 # Calculate loss
- loss = self.criterion(prediction["output"], outputs)
+ loss = self.criterion(
+ data["output"],
+ data["target"],
+ data=data,
+ )
 self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))

 # Validate constraints
- self.test_step(prediction, loss)
+ self.test_step(
+ data,
+ loss,
+ self.constraints,
+ self.metric_manager,
+ )
+
+ # Execute on batch end callbacks
+ for callback in on_test_batch_end:
+ data = callback(data)

+ @staticmethod
 def train_step(
- self,
- prediction: dict[str, Tensor],
+ data: dict[str, Tensor],
 loss: Tensor,
+ constraints: list[Constraint],
+ descriptor: Descriptor,
+ metric_manager: MetricManager,
+ device: torch.device,
+ constraint_aggregator: Callable = torch.sum,
+ epsilon: float = 1e-6,
+ enforce_all: bool = True,
 ) -> Tensor:
- """
- Adjust the training loss based on constraints
- and compute the combined loss.
+ """Adjust the training loss based on constraints and compute the combined loss.
+
+ This method calculates the directions in which the network outputs should be
+ adjusted to satisfy constraints, scales these adjustments according to the
+ constraint's rescale factor and gradient norms, and adds the result to the
+ base loss. It also logs the constraint satisfaction ratio (CSR) for monitoring.

 Args:
- prediction (dict[str, Tensor]): Model predictions
- for variable layers.
+ data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
 loss (Tensor): The base loss computed by the criterion.
+ constraints (list[Constraint]): List of constraints to enforce during training.
+ descriptor (Descriptor): Descriptor containing layer metadata and variable/loss layer info.
+ metric_manager (MetricManager): Metric manager for logging loss and CSR.
+ device (torch.device): Device on which computations are performed.
+ constraint_aggregator (Callable, optional): Function to aggregate per-layer rescaled losses. Defaults to `torch.sum`.
+ epsilon (float, optional): Small value to prevent division by zero in gradient normalization. Defaults to 1e-6.
+ enforce_all (bool, optional): If False, constraints are only monitored and do not influence the loss. Defaults to True.

 Returns:
- Tensor: The combined loss (base loss + constraint adjustments).
+ Tensor: The combined loss including the original loss and constraint-based adjustments.
 """
-
 # Init scalar tensor for loss
- total_rescale_loss = tensor(0, dtype=float32, device=self.device)
- loss_grads = {}
+ total_rescale_loss = tensor(0, dtype=float32, device=device)
+ norm_loss_grad: dict[str, Tensor] = {}

 # Precalculate loss gradients for each variable layer
- with no_grad():
- for layer in self.descriptor.variable_layers:
- self.optimizer.zero_grad()
- loss.backward(retain_graph=True, inputs=prediction[layer])
- loss_grads[layer] = prediction[layer].grad
+ for key in descriptor.variable_keys & descriptor.affects_loss_keys:
+ # Calculate gradients of loss w.r.t. predictions
+ grad = torch.autograd.grad(
+ outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
+ )[0]
+
+ # If the gradient is None, report an error
+ if grad is None:
+ raise RuntimeError(
+ f"Unable to compute loss gradients for layer '{key}'. "
+ "For layers not connected to the loss, set has_loss=False "
+ "when defining them in the Descriptor."
+ )

- for constraint in self.constraints:
+ # Flatten batch and compute L2 norm along each item
+ grad_flat = grad.view(grad.shape[0], -1)
+ norm_loss_grad[key] = (
+ vector_norm(grad_flat, dim=1, ord=2, keepdim=True).clamp(min=epsilon).detach()
+ )

+ for constraint in constraints:
 # Check if constraints are satisfied and calculate directions
- with no_grad():
- constraint_checks, relevant_constraint_count = (
- constraint.check_constraint(prediction)
- )
+ checks, mask = constraint.check_constraint(data)
+ directions = constraint.calculate_direction(data)
+
+ # Log constraint satisfaction ratio
+ csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
+ metric_manager.accumulate(f"{constraint.name}/train", csr)
+ metric_manager.accumulate("CSR/train", csr)

 # Only do adjusting calculation if constraint is not observant
- if not constraint.monitor_only:
+ if not enforce_all or not constraint.enforce:
+ continue
+
+ # Only do direction calculations for variable layers affecting constraint
+ for key in constraint.layers & descriptor.variable_keys:
 with no_grad():
- constraint_directions = constraint.calculate_direction(
- prediction
- )
-
- # Only do direction calculations for variable
- # layers affecting constraint
- for layer in (
- constraint.layers & self.descriptor.variable_layers
- ):
-
- with no_grad():
- # Multiply direction modifiers with constraint result
- constraint_result = (
- 1 - constraint_checks.unsqueeze(1)
- ) * constraint_directions[layer]
-
- # Multiply result with rescale factor of constraint
- constraint_result *= constraint.rescale_factor
-
- # Calculate loss gradient norm
- norm_loss_grad = norm(
- loss_grads[layer], dim=1, p=2, keepdim=True
- )
-
- # Apply minimum epsilon
- norm_loss_grad = maximum(norm_loss_grad, self.epsilon)
-
- # Calculate rescale loss
- rescale_loss = (
- prediction[layer]
- * constraint_result
- * norm_loss_grad.detach().clone()
- ).mean()
-
- # Store rescale loss for this reference space
- total_rescale_loss += rescale_loss
+ # Multiply direction modifiers with constraint result
+ constraint_result = (1 - checks) * directions[key]

- # Log constraint satisfaction ratio
- self.metric_manager.accumulate(
- f"{constraint.name}/train",
- (
- (
- sum(constraint_checks)
- - numel(constraint_checks)
- + relevant_constraint_count
- )
- / relevant_constraint_count
- ).unsqueeze(0),
- )
- self.metric_manager.accumulate(
- "CSR/train",
- (
- (
- sum(constraint_checks)
- - numel(constraint_checks)
- + relevant_constraint_count
- )
- / relevant_constraint_count
- ).unsqueeze(0),
- )
+ # Multiply result with rescale factor of constraint
+ constraint_result *= constraint.rescale_factor
+
+ # Calculate rescale loss
+ total_rescale_loss += constraint_aggregator(
+ data[key] * constraint_result * norm_loss_grad[key],
+ )

 # Return combined loss
 return loss + total_rescale_loss

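The new CSR expression replaces the old `sum(checks) - numel(checks) + relevant_count` arithmetic with an explicit relevance mask. A toy check of the formula, under the assumption that `check_constraint` returns a 0/1 per-sample `checks` tensor and a 0/1 relevance `mask`:

    # Toy check of the CSR formula used above:
    # csr = sum(checks * mask) / sum(mask)
    import torch

    checks = torch.tensor([1.0, 0.0, 1.0, 1.0])  # constraint satisfied per sample
    mask = torch.tensor([1.0, 1.0, 1.0, 0.0])    # sample 4 is irrelevant to the constraint
    csr = (torch.sum(checks * mask) / torch.sum(mask)).unsqueeze(0)
    print(csr)  # tensor([0.6667]): 2 of 3 relevant samples satisfied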
+ @staticmethod
 def valid_step(
- self,
- prediction: dict[str, Tensor],
+ data: dict[str, Tensor],
 loss: Tensor,
+ constraints: list[Constraint],
+ metric_manager: MetricManager,
 ) -> Tensor:
- """
- Evaluate constraints during validation and log satisfaction metrics.
+ """Evaluate constraints during validation and log constraint satisfaction metrics.
+
+ This method checks whether each constraint is satisfied for the given
+ data, computes the constraint satisfaction ratio (CSR),
+ and logs it using the metric manager. The base loss is not modified.

 Args:
- prediction (dict[str, Tensor]): Model predictions for
- variable layers.
+ data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
 loss (Tensor): The base loss computed by the criterion.
+ constraints (list[Constraint]): List of constraints to evaluate.
+ metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.

 Returns:
- Tensor: The unchanged base loss.
+ Tensor: The original, unchanged base loss.
 """
-
 # For each constraint in this reference space, calculate directions
- for constraint in self.constraints:
-
+ for constraint in constraints:
 # Check if constraints are satisfied
- constraint_checks, relevant_constraint_count = (
- constraint.check_constraint(prediction)
- )
+ checks, mask = constraint.check_constraint(data)

 # Log constraint satisfaction ratio
- self.metric_manager.accumulate(
- f"{constraint.name}/valid",
- (
- (
- sum(constraint_checks)
- - numel(constraint_checks)
- + relevant_constraint_count
- )
- / relevant_constraint_count
- ).unsqueeze(0),
- )
- self.metric_manager.accumulate(
- "CSR/valid",
- (
- (
- sum(constraint_checks)
- - numel(constraint_checks)
- + relevant_constraint_count
- )
- / relevant_constraint_count
- ).unsqueeze(0),
- )
+ csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
+ metric_manager.accumulate(f"{constraint.name}/valid", csr)
+ metric_manager.accumulate("CSR/valid", csr)

- # Return loss
+ # Return original loss
 return loss

+ @staticmethod
 def test_step(
- self,
- prediction: dict[str, Tensor],
+ data: dict[str, Tensor],
 loss: Tensor,
+ constraints: list[Constraint],
+ metric_manager: MetricManager,
 ) -> Tensor:
- """
- Evaluate constraints during test and log satisfaction metrics.
+ """Evaluate constraints during testing and log constraint satisfaction metrics.
+
+ This method checks whether each constraint is satisfied for the given
+ data, computes the constraint satisfaction ratio (CSR),
+ and logs it using the metric manager. The base loss is not modified.

 Args:
- prediction (dict[str, Tensor]): Model predictions
- for variable layers.
+ data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
 loss (Tensor): The base loss computed by the criterion.
+ constraints (list[Constraint]): List of constraints to evaluate.
+ metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.

 Returns:
- Tensor: The unchanged base loss.
+ Tensor: The original, unchanged base loss.
 """
-
 # For each constraint in this reference space, calculate directions
- for constraint in self.constraints:
-
+ for constraint in constraints:
 # Check if constraints are satisfied
- constraint_checks, relevant_constraint_count = (
- constraint.check_constraint(prediction)
- )
+ checks, mask = constraint.check_constraint(data)

 # Log constraint satisfaction ratio
- self.metric_manager.accumulate(
- f"{constraint.name}/test",
- (
- (
- sum(constraint_checks)
- - numel(constraint_checks)
- + relevant_constraint_count
- )
- / relevant_constraint_count
- ).unsqueeze(0),
- )
- self.metric_manager.accumulate(
- "CSR/test",
- (
- (
- sum(constraint_checks)
- - numel(constraint_checks)
- + relevant_constraint_count
- )
- / relevant_constraint_count
- ).unsqueeze(0),
- )
+ csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
+ metric_manager.accumulate(f"{constraint.name}/test", csr)
+ metric_manager.accumulate("CSR/test", csr)

- # Return loss
+ # Return original loss
 return loss
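A closing note on `constraint_aggregator`: it decides how the per-element rescale terms in `train_step` collapse into the scalar added to the loss. The 1.0.6 code hard-coded `.mean()`; 1.1.0 defaults to `torch.sum` but accepts any reduction mapping a Tensor to a scalar Tensor, so passing `torch.mean` to `CongradsCore` restores the old averaging behavior. A toy comparison of the two reductions:

    import torch

    rescale_terms = torch.tensor([[0.2, 0.0], [0.4, 0.1]])  # toy per-element terms
    print(torch.sum(rescale_terms))   # tensor(0.7000), 1.1.0 default
    print(torch.mean(rescale_terms))  # tensor(0.1750), 1.0.6-style averaging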