congrads-0.1.0-py3-none-any.whl → congrads-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
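The headline change in this release is the removal of the Lightning dependency: the old CGGDModule, a LightningModule subclass driven by an external trainer, is replaced by the self-contained CongradsCore engine with its own fit() loop. For orientation, here is a minimal, hypothetical sketch of driving the 1.0.1 API. Only the CongradsCore constructor and fit() signatures are taken from the diff below; the Descriptor and MetricManager constructor arguments are not part of this diff and are left as labeled placeholders, and the network and data are toy stand-ins.

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from congrads.core import CongradsCore
from congrads.descriptor import Descriptor
from congrads.metrics import MetricManager


class TinyNet(nn.Module):
    # CongradsCore reads the main prediction from prediction["output"],
    # so forward() returns a dict of named outputs.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return {"output": self.linear(x)}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = TinyNet().to(device)

# Identical toy loaders for train/valid/test, purely for illustration.
data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loaders = tuple(DataLoader(data, batch_size=16) for _ in range(3))

core = CongradsCore(
    descriptor=Descriptor(...),  # placeholder: constructor not shown in this diff
    constraints=[],  # no constraints: plain training, CSR metrics are skipped
    loaders=loaders,
    network=network,
    criterion=nn.MSELoss(),
    optimizer=torch.optim.Adam(network.parameters()),
    metric_manager=MetricManager(...),  # placeholder: constructor not shown in this diff
    device=device,
)

core.fit(
    max_epochs=5,
    on_epoch_end=lambda epoch: print(f"Finished epoch {epoch}"),
)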
congrads/core.py CHANGED
@@ -1,211 +1,597 @@
- import logging
- from typing import Dict
- from lightning import LightningModule
- from torch import Tensor, float32, no_grad, norm, tensor
- from torchmetrics import Metric
- from torch.nn import ModuleDict
-
+ """
+ This module provides the CongradsCore class, which is designed to integrate
+ constraint-guided optimization into neural network training.
+ It extends traditional training processes by enforcing specific constraints
+ on the model's outputs, ensuring that the network satisfies domain-specific
+ requirements during both training and evaluation.
+
+ The `CongradsCore` class serves as the central engine for managing the
+ training, validation, and testing phases of a neural network model,
+ incorporating constraints that influence the loss function and model updates.
+ The model is trained with standard loss functions while also incorporating
+ constraint-based adjustments, which are tracked and logged
+ throughout the process.
+
+ Key features:
+ - Support for various constraints that can influence the training process.
+ - Integration with PyTorch's `DataLoader` for efficient batch processing.
+ - Metric management for tracking loss and constraint satisfaction.
+ - Checkpoint management for saving and evaluating model states.
+
+ Modules in this package provide the following:
+
+ - `Descriptor`: Describes variable layers in the network that are
+   subject to constraints.
+ - `Constraint`: Defines various constraints, which are used to guide
+   the training process.
+ - `MetricManager`: Manages and tracks performance metrics such as loss
+   and constraint satisfaction.
+ - `CheckpointManager`: Manages saving and loading model checkpoints
+   during training.
+ - Utility functions to validate inputs and configurations.
+
+ Dependencies:
+ - PyTorch (`torch`)
+ - tqdm (for progress tracking)
+
+ The `CongradsCore` class allows for the use of additional callback functions
+ at different stages of the training process to customize behavior for
+ specific needs. These include callbacks for the start and end of epochs, as
+ well as the start and end of the entire training process.
+
+ """
+
+ import warnings
+ from numbers import Number
+ from typing import Callable
+
+ import torch
+
+ # pylint: disable-next=redefined-builtin
+ from torch import Tensor, float32, maximum, no_grad, norm, numel, sum, tensor
+ from torch.nn import Module
+ from torch.nn.modules.loss import _Loss
+ from torch.optim import Optimizer
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from .checkpoints import CheckpointManager
  from .constraints import Constraint
- from .metrics import ConstraintSatisfactionRatio
  from .descriptor import Descriptor
+ from .metrics import MetricManager
+ from .utils import (
+     validate_callable,
+     validate_iterable,
+     validate_loaders,
+     validate_type,
+ )


- class CGGDModule(LightningModule):
+ class CongradsCore:
      """
-     A PyTorch Lightning module that integrates constraint-guided optimization into the training and validation steps.
-
-     This module extends the `LightningModule` and incorporates constraints on the neural network's predictions
-     by adjusting the loss using a rescale factor. The constraints are checked, and the loss is modified to guide
-     the optimization process based on these constraints.
-
-     Attributes:
-         descriptor (Descriptor): The object that describes the layers and neurons of the network, including
-             the categorization of variable layers.
-         constraints (list[Constraint]): A list of constraints that define the conditions to guide the optimization.
-         train_csr (Dict[str, Metric]): A dictionary of `ConstraintSatisfactionRatio` metrics to track constraint satisfaction
-             during training, indexed by constraint name.
-         valid_csr (Dict[str, Metric]): A dictionary of `ConstraintSatisfactionRatio` metrics to track constraint satisfaction
-             during validation, indexed by constraint name.
+     The CongradsCore class is the central training engine for constraint-guided
+     neural network optimization. It integrates standard neural network training
+     with additional constraint-driven adjustments to the loss function, ensuring
+     that the network satisfies domain-specific constraints during training.
+
+     Args:
+         descriptor (Descriptor): Describes variable layers in the network.
+         constraints (list[Constraint]): List of constraints to guide training.
+         loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
+             training, validation, and testing.
+         network (Module): The neural network model to train.
+         criterion (_Loss): The loss function used for
+             training and validation.
+         optimizer (Optimizer): The optimizer used for updating model parameters.
+         metric_manager (MetricManager): Manages metric tracking and recording.
+         device (torch.device): The device (e.g., CPU or GPU) for computations.
+         checkpoint_manager (CheckpointManager, optional): Manages
+             checkpointing. If not set, no checkpointing is done.
+         epsilon (Number, optional): A small value to avoid division by zero
+             in gradient calculations. Default is 1e-6.
+
93
+ Note:
94
+ A warning is logged if the descriptor has no variable layers,
95
+ as at least one variable layer is required for the constraint logic
96
+ to influence the training process.
29
97
  """
30
98
 
31
- def __init__(self, descriptor: Descriptor, constraints: list[Constraint]):
99
+ def __init__(
100
+ self,
101
+ descriptor: Descriptor,
102
+ constraints: list[Constraint],
103
+ loaders: tuple[DataLoader, DataLoader, DataLoader],
104
+ network: Module,
105
+ criterion: _Loss,
106
+ optimizer: Optimizer,
107
+ metric_manager: MetricManager,
108
+ device: torch.device,
109
+ checkpoint_manager: CheckpointManager = None,
110
+ epsilon: Number = 1e-6,
111
+ ):
32
112
  """
33
- Initializes the CGGDModule with a descriptor and a list of constraints.
34
-
35
- Args:
36
- descriptor (Descriptor): The object that describes the network's layers and neurons, including their categorization.
37
- constraints (list[Constraint]): A list of constraints that will guide the optimization process.
38
-
39
- Raises:
40
- Warning if there are no variable layers in the descriptor, as constraints will not be applied.
113
+ Initialize the CongradsCore object.
41
114
  """
42
115
 
43
- # Init parent class
44
- super().__init__()
116
+ # Type checking
117
+ validate_type("descriptor", descriptor, Descriptor)
118
+ validate_iterable("constraints", constraints, Constraint)
119
+ validate_loaders()
120
+ validate_type("network", network, Module)
121
+ validate_type("criterion", criterion, _Loss)
122
+ validate_type("optimizer", optimizer, Optimizer)
123
+ validate_type("metric_manager", metric_manager, MetricManager)
124
+ validate_type("device", device, torch.device)
125
+ validate_type(
126
+ "checkpoint_manager",
127
+ checkpoint_manager,
128
+ CheckpointManager,
129
+ allow_none=True,
130
+ )
131
+ validate_type("epsilon", epsilon, Number)
45
132
 
46
133
  # Init object variables
47
134
  self.descriptor = descriptor
48
135
  self.constraints = constraints
136
+ self.train_loader = loaders[0]
137
+ self.valid_loader = loaders[1]
138
+ self.test_loader = loaders[2]
139
+ self.network = network
140
+ self.criterion = criterion
141
+ self.optimizer = optimizer
142
+ self.metric_manager = metric_manager
143
+ self.device = device
144
+ self.checkpoint_manager = checkpoint_manager
145
+
146
+ # Init epsilon tensor
147
+ self.epsilon = tensor(epsilon, device=self.device)
49
148
 
50
149
  # Perform checks
51
150
  if len(self.descriptor.variable_layers) == 0:
52
- logging.warning(
53
- "The descriptor object has no variable layers. The constraint guided loss adjustment is therefore not used. Is this the intended behaviour?"
151
+ warnings.warn(
152
+ "The descriptor object has no variable layers. The constraint \
153
+ guided loss adjustment is therefore not used. \
154
+ Is this the intended behavior?"
54
155
  )
55
156
 
56
- # Assign descriptor to constraints
157
+ # Initialize constraint metrics
158
+ self._initialize_metrics()
159
+
160
+ def _initialize_metrics(self) -> None:
161
+ """
162
+ Register metrics for loss, constraint satisfaction ratio (CSR),
163
+ and individual constraints.
164
+
165
+ This method registers the following metrics:
166
+
167
+ - Loss/train: Training loss.
168
+ - Loss/valid: Validation loss.
169
+ - Loss/test: Test loss after training.
170
+ - CSR/train: Constraint satisfaction ratio during training.
171
+ - CSR/valid: Constraint satisfaction ratio during validation.
172
+ - CSR/test: Constraint satisfaction ratio after training.
173
+ - One metric per constraint, for both training and validation.
174
+
175
+ """
176
+
177
+ self.metric_manager.register("Loss/train", "during_training")
178
+ self.metric_manager.register("Loss/valid", "during_training")
179
+ self.metric_manager.register("Loss/test", "after_training")
180
+
181
+ if len(self.constraints) > 0:
182
+ self.metric_manager.register("CSR/train", "during_training")
183
+ self.metric_manager.register("CSR/valid", "during_training")
184
+ self.metric_manager.register("CSR/test", "after_training")
185
+
57
186
  for constraint in self.constraints:
58
- constraint.descriptor = descriptor
59
- constraint.run_init_descriptor()
60
-
61
- # Init constraint metric logging
62
- self.train_csr: Dict[str, Metric] = ModuleDict(
63
- {
64
- constraint.constraint_name: ConstraintSatisfactionRatio()
65
- for constraint in self.constraints
66
- }
67
- )
68
- self.train_csr["global"] = ConstraintSatisfactionRatio()
69
- self.valid_csr: Dict[str, Metric] = ModuleDict(
70
- {
71
- constraint.constraint_name: ConstraintSatisfactionRatio()
72
- for constraint in self.constraints
73
- }
74
- )
75
- self.valid_csr["global"] = ConstraintSatisfactionRatio()
187
+ self.metric_manager.register(
188
+ f"{constraint.name}/train", "during_training"
189
+ )
190
+ self.metric_manager.register(
191
+ f"{constraint.name}/valid", "during_training"
192
+ )
193
+ self.metric_manager.register(
194
+ f"{constraint.name}/test", "after_training"
195
+ )
76
196
 
77
- def training_step(
197
+ def fit(
198
+ self,
199
+ start_epoch: int = 0,
200
+ max_epochs: int = 100,
201
+ on_epoch_start: Callable[[int], None] = None,
202
+ on_epoch_end: Callable[[int], None] = None,
203
+ on_train_start: Callable[[int], None] = None,
204
+ on_train_end: Callable[[int], None] = None,
205
+ ) -> None:
206
+ """
207
+ Train the model for a given number of epochs.
208
+
209
+ Args:
210
+ start_epoch (int, optional): The epoch number to start the training
211
+ with. Default is 0.
212
+ max_epochs (int, optional): The number of epochs to train the
213
+ model. Default is 100.
214
+ on_epoch_start (Callable[[int], None], optional): A callback
215
+ function that will be executed at the start of each epoch.
216
+ on_epoch_end (Callable[[int], None], optional): A callback
217
+ function that will be executed at the end of each epoch.
218
+ on_train_start (Callable[[int], None], optional): A callback
219
+ function that will be executed before the training starts.
220
+ on_train_end (Callable[[int], None], optional): A callback
221
+ function that will be executed after training ends.
222
+ """
223
+
224
+ # Type checking
225
+ validate_type("start_epoch", start_epoch, int)
226
+ validate_callable("on_epoch_start", on_epoch_start, True)
227
+ validate_callable("on_epoch_end", on_epoch_end, True)
228
+ validate_callable("on_train_start", on_train_start, True)
229
+ validate_callable("on_train_end", on_train_end, True)
230
+
231
+ # Keep track of epoch
232
+ epoch = start_epoch
233
+
234
+ # Execute training start hook if set
235
+ if on_train_start:
236
+ on_train_start(epoch)
237
+
238
+ for i in tqdm(range(epoch, max_epochs), initial=epoch, desc="Epoch"):
239
+ epoch = i
240
+
241
+ # Execute epoch start hook if set
242
+ if on_epoch_start:
243
+ on_epoch_start(epoch)
244
+
245
+ # Execute training and validation epoch
246
+ self._train_epoch()
247
+ self._validate_epoch()
248
+
249
+ # Checkpointing
250
+ if self.checkpoint_manager:
251
+ self.checkpoint_manager.evaluate_criteria(epoch)
252
+
253
+ # Execute epoch end hook if set
254
+ if on_epoch_end:
255
+ on_epoch_end(epoch)
256
+
257
+ # Evaluate model performance on unseen test set
258
+ self._test_model()
259
+
260
+ # Save final model
261
+ if self.checkpoint_manager:
262
+ self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
263
+
264
+ # Execute training end hook if set
265
+ if on_train_end:
266
+ on_train_end(epoch)
267
+
268
+ def _train_epoch(self) -> None:
269
+ """
270
+ Perform training for a single epoch.
271
+
272
+ This method:
273
+ - Sets the model to training mode.
274
+ - Processes batches from the training DataLoader.
275
+ - Computes predictions and losses.
276
+ - Adjusts losses based on constraints.
277
+ - Updates model parameters using backpropagation.
+         """
+
+         # Set model in training mode
+         self.network.train()
+
+         for batch in tqdm(
+             self.train_loader, desc="Training batches", leave=False
+         ):
+
+             # Get input-output pairs from batch
+             inputs, outputs = batch
+
+             # Transfer to GPU
+             inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+
+             # Model computations
+             prediction = self.network(inputs)
+
+             # Calculate loss
+             loss = self.criterion(prediction["output"], outputs)
+             self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+             # Adjust loss based on constraints
+             combined_loss = self.train_step(prediction, loss)
+
+             # Backprop
+             self.optimizer.zero_grad()
+             combined_loss.backward(
+                 retain_graph=False, inputs=list(self.network.parameters())
+             )
+             self.optimizer.step()
+
+     def _validate_epoch(self) -> None:
+         """
+         Perform validation for a single epoch.
+
+         This method:
+         - Sets the model to evaluation mode.
+         - Processes batches from the validation DataLoader.
+         - Computes predictions and losses.
+         - Logs constraint satisfaction ratios.
+         """
326
+
327
+ # Set model in evaluation mode
328
+ self.network.eval()
329
+
330
+ with no_grad():
331
+ for batch in tqdm(
332
+ self.valid_loader, desc="Validation batches", leave=False
333
+ ):
334
+
335
+ # Get input-output pairs from batch
336
+ inputs, outputs = batch
337
+
338
+ # Transfer to GPU
339
+ inputs, outputs = inputs.to(self.device), outputs.to(
340
+ self.device
341
+ )
342
+
343
+ # Model computations
344
+ prediction = self.network(inputs)
345
+
346
+ # Calculate loss
347
+ loss = self.criterion(prediction["output"], outputs)
348
+ self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
349
+
350
+ # Validate constraints
351
+ self.valid_step(prediction, loss)
352
+
353
+ def _test_model(self) -> None:
354
+ """
355
+ Evaluate model performance on the test set.
356
+
357
+ This method:
358
+ - Sets the model to evaluation mode.
359
+ - Processes batches from the test DataLoader.
360
+ - Computes predictions and losses.
361
+ - Logs constraint satisfaction ratios.
362
+
363
+ """
364
+
365
+ # Set model in evaluation mode
366
+ self.network.eval()
367
+
368
+ with no_grad():
369
+ for batch in tqdm(
370
+ self.test_loader, desc="Test batches", leave=False
371
+ ):
372
+
373
+ # Get input-output pairs from batch
374
+ inputs, outputs = batch
375
+
376
+ # Transfer to GPU
377
+ inputs, outputs = inputs.to(self.device), outputs.to(
378
+ self.device
379
+ )
380
+
381
+ # Model computations
382
+ prediction = self.network(inputs)
383
+
384
+ # Calculate loss
385
+ loss = self.criterion(prediction["output"], outputs)
386
+ self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
387
+
388
+ # Validate constraints
389
+ self.test_step(prediction, loss)
390
+
391
+ def train_step(
78
392
  self,
79
393
  prediction: dict[str, Tensor],
80
394
  loss: Tensor,
81
- ):
395
+ ) -> Tensor:
82
396
  """
83
- The training step where the standard loss is combined with rescale loss based on the constraints.
84
-
85
- For each constraint, the satisfaction ratio is checked, and the loss is adjusted by adding a rescale loss
86
- based on the directions calculated by the constraint.
397
+ Adjust the training loss based on constraints
398
+ and compute the combined loss.
87
399
 
88
400
  Args:
89
- prediction (dict[str, Tensor]): The model's predictions for each layer.
90
- loss (Tensor): The base loss from the model's forward pass.
401
+ prediction (dict[str, Tensor]): Model predictions
402
+ for variable layers.
403
+ loss (Tensor): The base loss computed by the criterion.
91
404
 
92
405
  Returns:
93
- Tensor: The combined loss, including both the original loss and the rescale loss from the constraints.
406
+ Tensor: The combined loss (base loss + constraint adjustments).
94
407
  """
95
408
 
96
409
  # Init scalar tensor for loss
97
410
  total_rescale_loss = tensor(0, dtype=float32, device=self.device)
411
+ loss_grads = {}
98
412
 
99
- # Compute rescale loss without tracking gradients
413
+ # Precalculate loss gradients for each variable layer
100
414
  with no_grad():
415
+ for layer in self.descriptor.variable_layers:
416
+ self.optimizer.zero_grad()
417
+ loss.backward(retain_graph=True, inputs=prediction[layer])
418
+ loss_grads[layer] = prediction[layer].grad
101
419
 
102
- # For each constraint, TODO split into real and validation only constraints
103
- for constraint in self.constraints:
420
+ for constraint in self.constraints:
104
421
 
105
- # Check if constraints are satisfied and calculate directions
106
- constraint_checks = constraint.check_constraint(prediction)
107
- constraint_directions = constraint.calculate_direction(prediction)
422
+ # Check if constraints are satisfied and calculate directions
423
+ with no_grad():
424
+ constraint_checks, relevant_constraint_count = (
425
+ constraint.check_constraint(prediction)
426
+ )
427
+
+             # Only do the adjusting calculation if the constraint is not monitor-only
+             if not constraint.monitor_only:
+                 with no_grad():
+                     constraint_directions = constraint.calculate_direction(
+                         prediction
+                     )

-                 # Only do direction calculations for variable layers affecting constraint
-                 for layer in constraint.layers & self.descriptor.variable_layers:
+                 # Only do direction calculations for variable
+                 # layers affecting constraint
+                 for layer in (
+                     constraint.layers & self.descriptor.variable_layers
+                 ):

-                     # Multiply direction modifiers with constraint result
-                     constraint_result = (
-                         constraint_checks[layer].unsqueeze(1).type(float32)
-                         * constraint_directions[layer]
-                     )
+                     with no_grad():
+                         # Multiply direction modifiers with constraint result
+                         constraint_result = (
+                             1 - constraint_checks.unsqueeze(1)
+                         ) * constraint_directions[layer]

-                     # Multiply result with rescale factor o constraint
-                     constraint_result *= constraint.rescale_factor
+                         # Multiply result with rescale factor of constraint
+                         constraint_result *= constraint.rescale_factor

-                     # Calculate gradients of general loss for each sample
-                     loss.backward(retain_graph=True, inputs=prediction[layer])
-                     loss_grad = prediction[layer].grad
+                         # Calculate loss gradient norm
+                         norm_loss_grad = norm(
+                             loss_grads[layer], dim=1, p=2, keepdim=True
+                         )

-                     # Calculate loss gradient norm
-                     norm_loss_grad = norm(loss_grad, dim=0, p=2, keepdim=True)
+                         # Apply minimum epsilon
+                         norm_loss_grad = maximum(norm_loss_grad, self.epsilon)

                      # Calculate rescale loss
                      rescale_loss = (
-                         (prediction[layer] * constraint_result * norm_loss_grad)
-                         .sum()
-                         .abs()
-                     )
+                         prediction[layer]
+                         * constraint_result
+                         * norm_loss_grad.detach().clone()
+                     ).mean()

                      # Store rescale loss for this reference space
                      total_rescale_loss += rescale_loss

-                     # Log constraint satisfaction ratio
-                     # NOTE does this take into account spaces with different dimensions?
-                     self.train_csr[constraint.constraint_name](constraint_checks[layer])
-                     self.train_csr["global"](constraint_checks[layer])
-                     self.log(
-                         f"train_csr_{constraint.constraint_name}_{layer}",
-                         self.train_csr[constraint.constraint_name],
-                         on_step=False,
-                         on_epoch=True,
+             # Log constraint satisfaction ratio
+             self.metric_manager.accumulate(
+                 f"{constraint.name}/train",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
                      )
-
-             # Log global constraint satisfaction ratio
-             self.log(
-                 "train_csr_global",
-                 self.train_csr["global"],
-                 on_step=False,
-                 on_epoch=True,
-             )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )
+             self.metric_manager.accumulate(
+                 "CSR/train",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )

          # Return combined loss
          return loss + total_rescale_loss

-     def validation_step(
+     def valid_step(
          self,
          prediction: dict[str, Tensor],
          loss: Tensor,
-     ):
+     ) -> Tensor:
          """
-         The validation step where the satisfaction of constraints is checked without applying the rescale loss.
-
-         Similar to the training step, but without updating the loss, this method tracks the constraint satisfaction
-         during validation.
+         Evaluate constraints during validation and log satisfaction metrics.

          Args:
-             prediction (dict[str, Tensor]): The model's predictions for each layer.
-             loss (Tensor): The base loss from the model's forward pass.
+             prediction (dict[str, Tensor]): Model predictions for
+                 variable layers.
+             loss (Tensor): The base loss computed by the criterion.

          Returns:
-             Tensor: The base loss value for validation.
+             Tensor: The unchanged base loss.
          """

-         # Compute rescale loss without tracking gradients
-         with no_grad():
+         # For each constraint in this reference space, calculate directions
+         for constraint in self.constraints:

-             # For each constraint in this reference space, calculate directions
-             for constraint in self.constraints:
+             # Check if the constraints are satisfied
+             constraint_checks, relevant_constraint_count = (
+                 constraint.check_constraint(prediction)
+             )

-                 # Check if constraints are satisfied for
-                 constraint_checks = constraint.check_constraint(prediction)
+             # Log constraint satisfaction ratio
+             self.metric_manager.accumulate(
+                 f"{constraint.name}/valid",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )
+             self.metric_manager.accumulate(
+                 "CSR/valid",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )

-                 # Only do direction calculations for variable layers affecting constraint
-                 for layer in constraint.layers & self.descriptor.variable_layers:
+         # Return loss
+         return loss

-                     # Log constraint satisfaction ratio
-                     # NOTE does this take into account spaces with different dimensions?
-                     self.valid_csr[constraint.constraint_name](constraint_checks[layer])
-                     self.valid_csr["global"](constraint_checks[layer])
-                     self.log(
-                         f"valid_csr_{constraint.constraint_name}",
-                         self.valid_csr[constraint.constraint_name],
-                         on_step=False,
-                         on_epoch=True,
-                     )
+     def test_step(
+         self,
+         prediction: dict[str, Tensor],
+         loss: Tensor,
+     ) -> Tensor:
+         """
+         Evaluate constraints during test and log satisfaction metrics.

-             # Log global constraint satisfaction ratio
-             self.log(
-                 "valid_csr_global",
-                 self.valid_csr["global"],
-                 on_step=False,
-                 on_epoch=True,
-             )
+         Args:
+             prediction (dict[str, Tensor]): Model predictions
+                 for variable layers.
+             loss (Tensor): The base loss computed by the criterion.
+
+         Returns:
+             Tensor: The unchanged base loss.
+         """
+
+         # For each constraint in this reference space, calculate directions
+         for constraint in self.constraints:
+
+             # Check if the constraints are satisfied
+             constraint_checks, relevant_constraint_count = (
+                 constraint.check_constraint(prediction)
+             )
+
+             # Log constraint satisfaction ratio
+             self.metric_manager.accumulate(
+                 f"{constraint.name}/test",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )
+             self.metric_manager.accumulate(
+                 "CSR/test",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )

          # Return loss
          return loss
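Each of the six metric-accumulation sites above evaluates the same expression, (sum(constraint_checks) - numel(constraint_checks) + relevant_constraint_count) / relevant_constraint_count, where sum and numel are the torch functions imported at the top of the module (which is why the redefined-builtin pylint warning is suppressed there). Assuming check_constraint returns a tensor of pass/fail flags in which non-applicable entries are marked as passed, plus the count of applicable entries, the expression reduces to the fraction of applicable checks that passed. A minimal numeric sketch under that assumption:

import torch

# Hypothetical output of check_constraint(): four checks, of which the
# last does not apply to this batch and is therefore flagged as passed.
constraint_checks = torch.tensor([1.0, 0.0, 1.0, 1.0])
relevant_constraint_count = 3

csr = (
    torch.sum(constraint_checks)      # 3.0 = 2 real passes + 1 non-applicable
    - torch.numel(constraint_checks)  # minus all 4 entries
    + relevant_constraint_count       # plus the 3 applicable entries
) / relevant_constraint_count

print(csr)  # tensor(0.6667): 2 of the 3 applicable checks passed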