congrads-1.1.2-py3-none-any.whl → congrads-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/core.py DELETED
@@ -1,773 +0,0 @@
- """This module provides the core CongradsCore class, the package's main training engine.
-
- It integrates constraint-guided optimization into neural network training,
- extending the traditional training process by enforcing specific constraints
- on the model's outputs, ensuring that the network satisfies domain-specific
- requirements during both training and evaluation.
-
- The `CongradsCore` class serves as the central engine for managing the
- training, validation, and testing phases of a neural network model,
- incorporating constraints that influence the loss function and model updates.
- The model is trained with standard loss functions while also incorporating
- constraint-based adjustments, which are tracked and logged
- throughout the process.
-
- Key features:
- - Support for various constraints that can influence the training process.
- - Integration with PyTorch's `DataLoader` for efficient batch processing.
- - Metric management for tracking loss and constraint satisfaction.
- - Checkpoint management for saving and evaluating model states.
-
- The `CongradsCore` class also supports callback functions that run
- at different stages of the training process to customize behavior for
- specific needs. These include callbacks for the start and end of epochs, as
- well as the start and end of the entire training process.
-
- """
-
- import warnings
- from collections.abc import Callable
-
- import torch
- from torch import Tensor, float32, no_grad, sum, tensor
- from torch.linalg import vector_norm
- from torch.nn import Module
- from torch.nn.modules.loss import _Loss
- from torch.optim import Optimizer
- from torch.utils.data import DataLoader
- from tqdm import tqdm
-
- from .checkpoints import CheckpointManager
- from .constraints import Constraint
- from .descriptor import Descriptor
- from .metrics import MetricManager
- from .utils import (
-     is_torch_loss,
-     torch_loss_wrapper,
-     validate_callable,
-     validate_callable_iterable,
-     validate_iterable,
-     validate_loaders,
-     validate_type,
- )
-
-
- class CongradsCore:
-     """The CongradsCore class is the central training engine for constraint-guided optimization.
-
-     It integrates standard neural network training
-     with additional constraint-driven adjustments to the loss function, ensuring
-     that the network satisfies domain-specific constraints during training.
-     """
-
-     def __init__(
-         self,
-         descriptor: Descriptor,
-         constraints: list[Constraint],
-         loaders: tuple[DataLoader, DataLoader, DataLoader],
-         network: Module,
-         criterion: _Loss,
-         optimizer: Optimizer,
-         metric_manager: MetricManager,
-         device: torch.device,
-         network_uses_grad: bool = False,
-         checkpoint_manager: CheckpointManager | None = None,
-         epsilon: float = 1e-6,
-         constraint_aggregator: Callable[..., Tensor] = sum,
-         disable_progress_bar_epoch: bool = False,
-         disable_progress_bar_batch: bool = False,
-         enforce_all: bool = True,
-     ):
-         """Initialize the CongradsCore object.
-
-         Args:
-             descriptor (Descriptor): Describes variable layers in the network.
-             constraints (list[Constraint]): List of constraints to guide training.
-             loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
-                 training, validation, and testing.
-             network (Module): The neural network model to train.
-             criterion (_Loss): The loss function used for
-                 training and validation.
-             optimizer (Optimizer): The optimizer used for updating model parameters.
-             metric_manager (MetricManager): Manages metric tracking and recording.
-             device (torch.device): The device (e.g., CPU or GPU) for computations.
-             network_uses_grad (bool, optional): A flag indicating whether the network
-                 itself performs gradient computations. Default is False.
-             checkpoint_manager (CheckpointManager, optional): Manages
-                 checkpointing. If not set, no checkpointing is done.
-             epsilon (float, optional): A small value to avoid division by zero
-                 in gradient calculations. Default is 1e-6.
-             constraint_aggregator (Callable[..., Tensor], optional): A function
-                 to aggregate the constraint rescale loss. Default is `sum`.
-             disable_progress_bar_epoch (bool, optional): If set to True, the epoch
-                 progress bar will not show. Defaults to False.
-             disable_progress_bar_batch (bool, optional): If set to True, the batch
-                 progress bar will not show. Defaults to False.
-             enforce_all (bool, optional): If set to False, constraints are only monitored and
-                 do not influence the training process. Overrides constraint-specific `enforce` parameters.
-                 Defaults to True.
-
-         Note:
-             A warning is issued if the descriptor has no variable layers,
-             as at least one variable layer is required for the constraint logic
-             to influence the training process.
-         """
-         # Type checking
-         validate_type("descriptor", descriptor, Descriptor)
-         validate_iterable("constraints", constraints, Constraint, allow_empty=True)
-         validate_loaders("loaders", loaders)
-         validate_type("network", network, Module)
-         validate_type("criterion", criterion, _Loss)
-         validate_type("optimizer", optimizer, Optimizer)
-         validate_type("metric_manager", metric_manager, MetricManager)
-         validate_type("device", device, torch.device)
-         validate_type("network_uses_grad", network_uses_grad, bool)
-         validate_type(
-             "checkpoint_manager",
-             checkpoint_manager,
-             CheckpointManager,
-             allow_none=True,
-         )
-         validate_type("epsilon", epsilon, float)
-         validate_callable("constraint_aggregator", constraint_aggregator, allow_none=True)
-         validate_type("disable_progress_bar_epoch", disable_progress_bar_epoch, bool)
-         validate_type("disable_progress_bar_batch", disable_progress_bar_batch, bool)
-         validate_type("enforce_all", enforce_all, bool)
-
-         # Init object variables
-         self.descriptor = descriptor
-         self.constraints = constraints
-         self.train_loader = loaders[0]
-         self.valid_loader = loaders[1]
-         self.test_loader = loaders[2]
-         self.network = network
-         self.optimizer = optimizer
-         self.metric_manager = metric_manager
-         self.device = device
-         self.network_uses_grad = network_uses_grad
-         self.checkpoint_manager = checkpoint_manager
-         self.epsilon = epsilon
-         self.constraint_aggregator = constraint_aggregator
-         self.disable_progress_bar_epoch = disable_progress_bar_epoch
-         self.disable_progress_bar_batch = disable_progress_bar_batch
-         self.enforce_all = enforce_all
-
-         # Check if criterion is a torch loss function
-         if is_torch_loss(criterion):
-             # If so, wrap it in a custom loss function
-             self.criterion = torch_loss_wrapper(criterion)
-         else:
-             self.criterion = criterion
-
-         # Perform checks
-         if len(self.descriptor.variable_keys) == 0:
-             warnings.warn(
-                 "The descriptor object has no variable layers. The constraint "
-                 "guided loss adjustment is therefore not used. "
-                 "Is this the intended behavior?",
-                 stacklevel=2,
-             )
-
-         # Initialize constraint metrics
-         self._initialize_metrics()
-
-     def _initialize_metrics(self) -> None:
-         """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.
-
-         This method registers the following metrics:
-
-         - Loss/train: Training loss.
-         - Loss/valid: Validation loss.
-         - Loss/test: Test loss after training.
-         - CSR/train: Constraint satisfaction ratio during training.
-         - CSR/valid: Constraint satisfaction ratio during validation.
-         - CSR/test: Constraint satisfaction ratio after training.
-         - One metric per constraint, for training, validation, and testing.
-
-         """
-         self.metric_manager.register("Loss/train", "during_training")
-         self.metric_manager.register("Loss/valid", "during_training")
-         self.metric_manager.register("Loss/test", "after_training")
-
-         if len(self.constraints) > 0:
-             self.metric_manager.register("CSR/train", "during_training")
-             self.metric_manager.register("CSR/valid", "during_training")
-             self.metric_manager.register("CSR/test", "after_training")
-
-         for constraint in self.constraints:
-             self.metric_manager.register(f"{constraint.name}/train", "during_training")
-             self.metric_manager.register(f"{constraint.name}/valid", "during_training")
-             self.metric_manager.register(f"{constraint.name}/test", "after_training")
-
-     def fit(
-         self,
-         start_epoch: int = 0,
-         max_epochs: int = 100,
-         test_model: bool = True,
-         on_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_train_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_train_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_valid_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_valid_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_test_batch_start: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_test_batch_end: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]] | None = None,
-         on_epoch_start: list[Callable[[int], None]] | None = None,
-         on_epoch_end: list[Callable[[int], None]] | None = None,
-         on_train_start: list[Callable[[int], None]] | None = None,
-         on_train_end: list[Callable[[int], None]] | None = None,
-         on_train_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
-         | None = None,
-         on_val_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
-         | None = None,
-         on_test_completion_forward_pass: list[Callable[[dict[str, Tensor]], dict[str, Tensor]]]
-         | None = None,
-         on_test_start: list[Callable[[int], None]] | None = None,
-         on_test_end: list[Callable[[int], None]] | None = None,
-     ) -> None:
-         """Train the model over multiple epochs with optional validation and testing.
-
-         This method manages the full training loop, including:
-
-         - Executing epoch-level and batch-level callbacks.
-         - Training and validating the model each epoch.
-         - Adjusting losses according to constraints.
-         - Logging metrics via the metric manager.
-         - Optional evaluation on the test set.
-         - Checkpointing the model during and after training.
-
-         Args:
-             start_epoch (int, optional): Epoch number to start training from. Defaults to 0.
-             max_epochs (int, optional): Total number of epochs to train. Defaults to 100.
-             test_model (bool, optional): If True, evaluate the model on the test set after training. Defaults to True.
-             on_batch_start (list[Callable], optional): Callbacks executed at the start of every batch. Defaults to None.
-             on_batch_end (list[Callable], optional): Callbacks executed at the end of every batch. Defaults to None.
-             on_train_batch_start (list[Callable], optional): Callbacks executed at the start of each training batch. Defaults to `on_batch_start` if not provided.
-             on_train_batch_end (list[Callable], optional): Callbacks executed at the end of each training batch. Defaults to `on_batch_end` if not provided.
-             on_valid_batch_start (list[Callable], optional): Callbacks executed at the start of each validation batch. Defaults to `on_batch_start` if not provided.
-             on_valid_batch_end (list[Callable], optional): Callbacks executed at the end of each validation batch. Defaults to `on_batch_end` if not provided.
-             on_test_batch_start (list[Callable], optional): Callbacks executed at the start of each test batch. Defaults to `on_batch_start` if not provided.
-             on_test_batch_end (list[Callable], optional): Callbacks executed at the end of each test batch. Defaults to `on_batch_end` if not provided.
-             on_epoch_start (list[Callable], optional): Callbacks executed at the start of each epoch. Defaults to None.
-             on_epoch_end (list[Callable], optional): Callbacks executed at the end of each epoch. Defaults to None.
-             on_train_start (list[Callable], optional): Callbacks executed before training starts. Defaults to None.
-             on_train_end (list[Callable], optional): Callbacks executed after training ends. Defaults to None.
-             on_train_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during training. Defaults to None.
-             on_val_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during validation. Defaults to None.
-             on_test_completion_forward_pass (list[Callable], optional): Callbacks executed after the forward pass during testing. Defaults to None.
-             on_test_start (list[Callable], optional): Callbacks executed before testing starts. Defaults to None.
-             on_test_end (list[Callable], optional): Callbacks executed after testing ends. Defaults to None.
-
-         Notes:
-             - If phase-specific callbacks (train/valid/test) are not provided, the global `on_batch_start` and `on_batch_end` are used.
-             - Training metrics, loss adjustments, and constraint satisfaction ratios are automatically logged via the metric manager.
-             - The final model checkpoint is saved if a checkpoint manager is configured.
-         """
-         # Type checking
-         validate_type("start_epoch", start_epoch, int)
-         validate_type("max_epochs", max_epochs, int)
-         validate_type("test_model", test_model, bool)
-         validate_callable_iterable("on_batch_start", on_batch_start, allow_none=True)
-         validate_callable_iterable("on_batch_end", on_batch_end, allow_none=True)
-         validate_callable_iterable("on_train_batch_start", on_train_batch_start, allow_none=True)
-         validate_callable_iterable("on_train_batch_end", on_train_batch_end, allow_none=True)
-         validate_callable_iterable("on_valid_batch_start", on_valid_batch_start, allow_none=True)
-         validate_callable_iterable("on_valid_batch_end", on_valid_batch_end, allow_none=True)
-         validate_callable_iterable("on_test_batch_start", on_test_batch_start, allow_none=True)
-         validate_callable_iterable("on_test_batch_end", on_test_batch_end, allow_none=True)
-         validate_callable_iterable("on_epoch_start", on_epoch_start, allow_none=True)
-         validate_callable_iterable("on_epoch_end", on_epoch_end, allow_none=True)
-         validate_callable_iterable("on_train_start", on_train_start, allow_none=True)
-         validate_callable_iterable("on_train_end", on_train_end, allow_none=True)
-         validate_callable_iterable(
-             "on_train_completion_forward_pass",
-             on_train_completion_forward_pass,
-             allow_none=True,
-         )
-         validate_callable_iterable(
-             "on_val_completion_forward_pass",
-             on_val_completion_forward_pass,
-             allow_none=True,
-         )
-         validate_callable_iterable(
-             "on_test_completion_forward_pass",
-             on_test_completion_forward_pass,
-             allow_none=True,
-         )
-         validate_callable_iterable("on_test_start", on_test_start, allow_none=True)
-         validate_callable_iterable("on_test_end", on_test_end, allow_none=True)
-
-         # Use global batch callback if phase-specific callback is unset
-         # Init callbacks as empty list if None
-         on_train_batch_start = on_train_batch_start or on_batch_start or []
-         on_train_batch_end = on_train_batch_end or on_batch_end or []
-         on_valid_batch_start = on_valid_batch_start or on_batch_start or []
-         on_valid_batch_end = on_valid_batch_end or on_batch_end or []
-         on_test_batch_start = on_test_batch_start or on_batch_start or []
-         on_test_batch_end = on_test_batch_end or on_batch_end or []
-         on_batch_start = on_batch_start or []
-         on_batch_end = on_batch_end or []
-         on_epoch_start = on_epoch_start or []
-         on_epoch_end = on_epoch_end or []
-         on_train_start = on_train_start or []
-         on_train_end = on_train_end or []
-         on_train_completion_forward_pass = on_train_completion_forward_pass or []
-         on_val_completion_forward_pass = on_val_completion_forward_pass or []
-         on_test_completion_forward_pass = on_test_completion_forward_pass or []
-         on_test_start = on_test_start or []
-         on_test_end = on_test_end or []
-
-         # Keep track of epoch
-         epoch = start_epoch
-
-         # Execute training start hook if set
-         for callback in on_train_start:
-             callback(epoch)
-
-         for i in tqdm(
-             range(epoch, max_epochs),
-             initial=epoch,
-             desc="Epoch",
-             disable=self.disable_progress_bar_epoch,
-         ):
-             epoch = i
-
-             # Execute epoch start hook if set
-             for callback in on_epoch_start:
-                 callback(epoch)
-
-             # Execute training and validation epoch
-             self._train_epoch(
-                 on_train_batch_start,
-                 on_train_batch_end,
-                 on_train_completion_forward_pass,
-             )
-             self._validate_epoch(
-                 on_valid_batch_start,
-                 on_valid_batch_end,
-                 on_val_completion_forward_pass,
-             )
-
-             # Checkpointing
-             if self.checkpoint_manager:
-                 self.checkpoint_manager.evaluate_criteria(epoch)
-
-             # Execute epoch end hook if set
-             for callback in on_epoch_end:
-                 callback(epoch)
-
-         # Execute training end hook if set
-         for callback in on_train_end:
-             callback(epoch)
-
-         # Evaluate model performance on unseen test set if required
-         if test_model:
-             # Execute test start hook if set
-             for callback in on_test_start:
-                 callback(epoch)
-
-             self._test_model(
-                 on_test_batch_start,
-                 on_test_batch_end,
-                 on_test_completion_forward_pass,
-             )
-
-             # Execute test end hook if set
-             for callback in on_test_end:
-                 callback(epoch)
-
-         # Save final model
-         if self.checkpoint_manager:
-             self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
-
-     def _train_epoch(
-         self,
-         on_train_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-         on_train_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-         on_train_completion_forward_pass: tuple[
-             Callable[[dict[str, Tensor]], dict[str, Tensor]], ...
-         ],
-     ) -> None:
-         """Perform a single training epoch over all batches.
-
-         This method sets the network to training mode, iterates over the training
-         DataLoader, computes predictions, evaluates losses, applies constraint-based
-         adjustments, performs backpropagation, and updates model parameters. It also
-         supports executing optional callbacks at different stages of batch
-         processing.
-
-         Args:
-             on_train_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed at the start of each batch. Each callback receives the
-                 data dictionary and returns an updated version.
-             on_train_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed at the end of each batch. Each callback receives the
-                 data dictionary and returns an updated version.
-             on_train_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed immediately after the forward pass of the batch.
-                 Each callback receives the data dictionary and returns an updated version.
-
-         Returns:
-             None
-         """
-         # Set model in training mode
-         self.network.train()
-
-         for data in tqdm(
-             self.train_loader,
-             desc="Training batches",
-             leave=False,
-             disable=self.disable_progress_bar_batch,
-         ):
-             # Transfer batch data to the target device
-             data: dict[str, Tensor] = {key: value.to(self.device) for key, value in data.items()}
-
-             # Execute on batch start callbacks
-             for callback in on_train_batch_start:
-                 data = callback(data)
-
-             # Model computations
-             data = self.network(data)
-
-             # Execute on completion forward pass callbacks
-             for callback in on_train_completion_forward_pass:
-                 data = callback(data)
-
-             # Calculate loss
-             loss = self.criterion(
-                 data["output"],
-                 data["target"],
-                 data=data,
-             )
-             self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
-
-             # Adjust loss based on constraints
-             combined_loss = self.train_step(
-                 data,
-                 loss,
-                 self.constraints,
-                 self.descriptor,
-                 self.metric_manager,
-                 self.device,
-                 constraint_aggregator=self.constraint_aggregator,
-                 epsilon=self.epsilon,
-                 enforce_all=self.enforce_all,
-             )
-
-             # Backprop
-             self.optimizer.zero_grad()
-             combined_loss.backward(retain_graph=False, inputs=list(self.network.parameters()))
-             self.optimizer.step()
-
-             # Execute on batch end callbacks
-             for callback in on_train_batch_end:
-                 data = callback(data)
-
-     def _validate_epoch(
-         self,
-         on_valid_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-         on_valid_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-         on_valid_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-     ) -> None:
-         """Perform a single validation epoch over all batches.
-
-         This method sets the network to evaluation mode, iterates over the validation
-         DataLoader, computes predictions, evaluates losses, and logs constraint
-         satisfaction. Optional callbacks can be executed at the start and end of each
-         batch, as well as after the forward pass.
-
-         Args:
-             on_valid_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed at the start of each validation batch. Each callback
-                 receives the data dictionary and returns an updated version.
-             on_valid_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed at the end of each validation batch. Each callback
-                 receives the data dictionary and returns an updated version.
-             on_valid_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed immediately after the forward pass of the validation batch.
-                 Each callback receives the data dictionary and returns an updated version.
-
-         Returns:
-             None
-         """
-         # Set model in evaluation mode
-         self.network.eval()
-
-         # Enable or disable gradient tracking for validation pass
-         with torch.set_grad_enabled(self.network_uses_grad):
-             # Loop over validation batches
-             for data in tqdm(
-                 self.valid_loader,
-                 desc="Validation batches",
-                 leave=False,
-                 disable=self.disable_progress_bar_batch,
-             ):
-                 # Transfer batch data to the target device
-                 data: dict[str, Tensor] = {
-                     key: value.to(self.device) for key, value in data.items()
-                 }
-
-                 # Execute on batch start callbacks
-                 for callback in on_valid_batch_start:
-                     data = callback(data)
-
-                 # Model computations
-                 data = self.network(data)
-
-                 # Execute on completion forward pass callbacks
-                 for callback in on_valid_completion_forward_pass:
-                     data = callback(data)
-
-                 # Calculate loss
-                 loss = self.criterion(
-                     data["output"],
-                     data["target"],
-                     data=data,
-                 )
-                 self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
-
-                 # Validate constraints
-                 self.valid_step(
-                     data,
-                     loss,
-                     self.constraints,
-                     self.metric_manager,
-                 )
-
-                 # Execute on batch end callbacks
-                 for callback in on_valid_batch_end:
-                     data = callback(data)
-
-     def _test_model(
-         self,
-         on_test_batch_start: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-         on_test_batch_end: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-         on_test_completion_forward_pass: tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...],
-     ) -> None:
-         """Evaluate the model on the test dataset.
-
-         This method sets the network to evaluation mode, iterates over the test
-         DataLoader, computes predictions, evaluates losses, and logs constraint
-         satisfaction. Optional callbacks can be executed at the start and end of
-         each batch, as well as after the forward pass.
-
-         Args:
-             on_test_batch_start (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed at the start of each test batch. Each callback
-                 receives the data dictionary and returns an updated version.
-             on_test_batch_end (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed at the end of each test batch. Each callback
-                 receives the data dictionary and returns an updated version.
-             on_test_completion_forward_pass (tuple[Callable[[dict[str, Tensor]], dict[str, Tensor]], ...]):
-                 Callbacks executed immediately after the forward pass of the test batch.
-                 Each callback receives the data dictionary and returns an updated version.
-
-         Returns:
-             None
-         """
-         # Set model in evaluation mode
-         self.network.eval()
-
-         # Enable or disable gradient tracking for test pass
-         with torch.set_grad_enabled(self.network_uses_grad):
-             # Loop over test batches
-             for data in tqdm(
-                 self.test_loader,
-                 desc="Test batches",
-                 leave=False,
-                 disable=self.disable_progress_bar_batch,
-             ):
-                 # Transfer batch data to the target device
-                 data: dict[str, Tensor] = {
-                     key: value.to(self.device) for key, value in data.items()
-                 }
-
-                 # Execute on batch start callbacks
-                 for callback in on_test_batch_start:
-                     data = callback(data)
-
-                 # Model computations
-                 data = self.network(data)
-
-                 # Execute on completion forward pass callbacks
-                 for callback in on_test_completion_forward_pass:
-                     data = callback(data)
-
-                 # Calculate loss
-                 loss = self.criterion(
-                     data["output"],
-                     data["target"],
-                     data=data,
-                 )
-                 self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
-
-                 # Validate constraints
-                 self.test_step(
-                     data,
-                     loss,
-                     self.constraints,
-                     self.metric_manager,
-                 )
-
-                 # Execute on batch end callbacks
-                 for callback in on_test_batch_end:
-                     data = callback(data)
-
-     @staticmethod
-     def train_step(
-         data: dict[str, Tensor],
-         loss: Tensor,
-         constraints: list[Constraint],
-         descriptor: Descriptor,
-         metric_manager: MetricManager,
-         device: torch.device,
-         constraint_aggregator: Callable = torch.sum,
-         epsilon: float = 1e-6,
-         enforce_all: bool = True,
-     ) -> Tensor:
-         """Adjust the training loss based on constraints and compute the combined loss.
-
-         This method calculates the directions in which the network outputs should be
-         adjusted to satisfy constraints, scales these adjustments according to the
-         constraint's rescale factor and gradient norms, and adds the result to the
-         base loss. It also logs the constraint satisfaction ratio (CSR) for monitoring.
-
-         Args:
-             data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
-             loss (Tensor): The base loss computed by the criterion.
-             constraints (list[Constraint]): List of constraints to enforce during training.
-             descriptor (Descriptor): Descriptor containing layer metadata and variable/loss layer info.
-             metric_manager (MetricManager): Metric manager for logging loss and CSR.
-             device (torch.device): Device on which computations are performed.
-             constraint_aggregator (Callable, optional): Function to aggregate per-layer rescaled losses. Defaults to `torch.sum`.
-             epsilon (float, optional): Small value to prevent division by zero in gradient normalization. Defaults to 1e-6.
-             enforce_all (bool, optional): If False, constraints are only monitored and do not influence the loss. Defaults to True.
-
-         Returns:
-             Tensor: The combined loss including the original loss and constraint-based adjustments.
-         """
-         # Init scalar tensor for loss
-         total_rescale_loss = tensor(0, dtype=float32, device=device)
-         norm_loss_grad: dict[str, Tensor] = {}
-
-         # Precalculate loss gradients for each variable layer
-         for key in descriptor.variable_keys & descriptor.affects_loss_keys:
-             # Calculate gradients of loss w.r.t. predictions
-             grad = torch.autograd.grad(
-                 outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
-             )[0]
-
-             # If the gradient is None, report an error
-             if grad is None:
-                 raise RuntimeError(
-                     f"Unable to compute loss gradients for layer '{key}'. "
-                     "For layers not connected to the loss, set has_loss=False "
-                     "when defining them in the Descriptor."
-                 )
-
-             # Flatten batch and compute L2 norm along each item
-             grad_flat = grad.view(grad.shape[0], -1)
-             norm_loss_grad[key] = (
-                 vector_norm(grad_flat, dim=1, ord=2, keepdim=True).clamp(min=epsilon).detach()
-             )
-
-         for constraint in constraints:
-             # Check if constraints are satisfied and calculate directions
-             checks, mask = constraint.check_constraint(data)
-             directions = constraint.calculate_direction(data)
-
-             # Log constraint satisfaction ratio
-             csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
-             metric_manager.accumulate(f"{constraint.name}/train", csr)
-             metric_manager.accumulate("CSR/train", csr)
-
-             # Skip the loss adjustment if the constraint is only monitored
-             if not enforce_all or not constraint.enforce:
-                 continue
-
-             # Only do direction calculations for variable layers affecting the constraint
-             for key in constraint.layers & descriptor.variable_keys:
-                 with no_grad():
-                     # Multiply direction modifiers with constraint result
-                     constraint_result = (1 - checks) * directions[key]
-
-                     # Multiply result with rescale factor of constraint
-                     constraint_result *= constraint.rescale_factor
-
-                 # Calculate rescale loss
-                 total_rescale_loss += constraint_aggregator(
-                     data[key] * constraint_result * norm_loss_grad[key],
-                 )
-
-         # Return combined loss
-         return loss + total_rescale_loss
-
-     @staticmethod
-     def valid_step(
-         data: dict[str, Tensor],
-         loss: Tensor,
-         constraints: list[Constraint],
-         metric_manager: MetricManager,
-     ) -> Tensor:
-         """Evaluate constraints during validation and log constraint satisfaction metrics.
-
-         This method checks whether each constraint is satisfied for the given
-         data, computes the constraint satisfaction ratio (CSR),
-         and logs it using the metric manager. The base loss is not modified.
-
-         Args:
-             data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
-             loss (Tensor): The base loss computed by the criterion.
-             constraints (list[Constraint]): List of constraints to evaluate.
-             metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.
-
-         Returns:
-             Tensor: The original, unchanged base loss.
-         """
-         # For each constraint, log the constraint satisfaction ratio
-         for constraint in constraints:
-             # Check if the constraint is satisfied
-             checks, mask = constraint.check_constraint(data)
-
-             # Log constraint satisfaction ratio
-             csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
-             metric_manager.accumulate(f"{constraint.name}/valid", csr)
-             metric_manager.accumulate("CSR/valid", csr)
-
-         # Return original loss
-         return loss
-
-     @staticmethod
-     def test_step(
-         data: dict[str, Tensor],
-         loss: Tensor,
-         constraints: list[Constraint],
-         metric_manager: MetricManager,
-     ) -> Tensor:
-         """Evaluate constraints during testing and log constraint satisfaction metrics.
-
-         This method checks whether each constraint is satisfied for the given
-         data, computes the constraint satisfaction ratio (CSR),
-         and logs it using the metric manager. The base loss is not modified.
-
-         Args:
-             data (dict[str, Tensor]): Dictionary containing the batch data, predictions and additional data.
-             loss (Tensor): The base loss computed by the criterion.
-             constraints (list[Constraint]): List of constraints to evaluate.
-             metric_manager (MetricManager): Metric manager for logging CSR and per-constraint metrics.
-
-         Returns:
-             Tensor: The original, unchanged base loss.
-         """
-         # For each constraint, log the constraint satisfaction ratio
-         for constraint in constraints:
-             # Check if the constraint is satisfied
-             checks, mask = constraint.check_constraint(data)
-
-             # Log constraint satisfaction ratio
-             csr = (sum(checks * mask) / sum(mask)).unsqueeze(0)
-             metric_manager.accumulate(f"{constraint.name}/test", csr)
-             metric_manager.accumulate("CSR/test", csr)
-
-         # Return original loss
-         return loss
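
For orientation when migrating off this module: below is a minimal usage sketch of the CongradsCore API removed here, reconstructed from the docstrings in the deleted file above. The descriptor, constraints, data loaders, network, and metric_manager objects are hypothetical placeholders (their construction lives in modules not shown in this diff), and the epoch callback is illustrative only.

import torch
from torch.nn import MSELoss
from torch.optim import Adam

from congrads.core import CongradsCore  # this module; removed in 1.2.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholders (not constructed here): descriptor (Descriptor), constraints
# (list[Constraint]), train/valid/test loaders (DataLoader), network (Module
# whose forward pass maps dict[str, Tensor] -> dict[str, Tensor]), and
# metric_manager (MetricManager).
core = CongradsCore(
    descriptor=descriptor,
    constraints=constraints,
    loaders=(train_loader, valid_loader, test_loader),
    network=network.to(device),
    criterion=MSELoss(),  # plain torch losses are wrapped via torch_loss_wrapper
    optimizer=Adam(network.parameters()),
    metric_manager=metric_manager,
    device=device,
)

# Train with per-epoch validation, then evaluate once on the test set;
# epoch-level callbacks receive the current epoch number.
core.fit(
    max_epochs=50,
    test_model=True,
    on_epoch_end=[lambda epoch: print(f"finished epoch {epoch}")],
)

As the training loops above show, batches flow through the system as dictionaries of tensors: the criterion reads data["output"] and data["target"] from the dictionary returned by the network's forward pass, so the loaders must supply a "target" entry and the network must produce an "output" entry.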