congrads-0.2.0-py3-none-any.whl → congrads-1.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/core.py CHANGED
@@ -1,16 +1,100 @@
- import logging
- from torch import Tensor, float32, no_grad, norm, tensor
- from torch.optim import Optimizer
+ """
+ This module provides the CongradsCore class, which is designed to integrate
+ constraint-guided optimization into neural network training.
+ It extends traditional training processes by enforcing specific constraints
+ on the model's outputs, ensuring that the network satisfies domain-specific
+ requirements during both training and evaluation.
+
+ The `CongradsCore` class serves as the central engine for managing the
+ training, validation, and testing phases of a neural network model,
+ incorporating constraints that influence the loss function and model updates.
+ The model is trained with standard loss functions while also incorporating
+ constraint-based adjustments, which are tracked and logged
+ throughout the process.
+
+ Key features:
+ - Support for various constraints that can influence the training process.
+ - Integration with PyTorch's `DataLoader` for efficient batch processing.
+ - Metric management for tracking loss and constraint satisfaction.
+ - Checkpoint management for saving and evaluating model states.
+
+ Modules in this package provide the following:
+
+ - `Descriptor`: Describes variable layers in the network that are
+   subject to constraints.
+ - `Constraint`: Defines various constraints, which are used to guide
+   the training process.
+ - `MetricManager`: Manages and tracks performance metrics such as loss
+   and constraint satisfaction.
+ - `CheckpointManager`: Manages saving and loading model checkpoints
+   during training.
+ - Utility functions to validate inputs and configurations.
+
+ Dependencies:
+ - PyTorch (`torch`)
+ - tqdm (for progress tracking)
+
+ The `CongradsCore` class allows for the use of additional callback functions
+ at different stages of the training process to customize behavior for
+ specific needs. These include callbacks for the start and end of epochs, as
+ well as the start and end of the entire training process.
+ """
+
+ import warnings
+ from numbers import Number
+ from typing import Callable
+
+ import torch
+
+ # pylint: disable-next=redefined-builtin
+ from torch import Tensor, float32, maximum, no_grad, norm, numel, sum, tensor
  from torch.nn import Module
+ from torch.nn.modules.loss import _Loss
+ from torch.optim import Optimizer
  from torch.utils.data import DataLoader
- from time import time
+ from tqdm import tqdm

- from .metrics import MetricManager
+ from .checkpoints import CheckpointManager
  from .constraints import Constraint
  from .descriptor import Descriptor
+ from .metrics import MetricManager
+ from .utils import (
+     validate_callable,
+     validate_iterable,
+     validate_loaders,
+     validate_type,
+ )


  class CongradsCore:
+     """
+     The CongradsCore class is the central training engine for constraint-guided
+     neural network optimization. It integrates standard neural network training
+     with additional constraint-driven adjustments to the loss function, ensuring
+     that the network satisfies domain-specific constraints during training.
+
+     Args:
+         descriptor (Descriptor): Describes variable layers in the network.
+         constraints (list[Constraint]): List of constraints to guide training.
+         loaders (tuple[DataLoader, DataLoader, DataLoader]): DataLoaders for
+             training, validation, and testing.
+         network (Module): The neural network model to train.
+         criterion (_Loss): The loss function used for
+             training and validation.
+         optimizer (Optimizer): The optimizer used for updating model parameters.
+         metric_manager (MetricManager): Manages metric tracking and recording.
+         device (torch.device): The device (e.g., CPU or GPU) for computations.
+         checkpoint_manager (CheckpointManager, optional): Manages
+             checkpointing. If not set, no checkpointing is done.
+         epsilon (Number, optional): A small value to avoid division by zero
+             in gradient calculations. Default is 1e-6.
+
+     Note:
+         A warning is issued if the descriptor has no variable layers,
+         as at least one variable layer is required for the constraint logic
+         to influence the training process.
+     """

      def __init__(
          self,
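
Taken together with the constructor arguments documented above, driving the new API looks roughly like the sketch below. Only the CongradsCore and fit signatures are confirmed by this diff; the descriptor, constraints, network, metric_manager, and DataLoader objects are placeholders whose construction is not shown in this file.

    import torch
    from torch.nn import MSELoss
    from torch.optim import Adam

    from congrads.core import CongradsCore

    # descriptor, constraints, network, metric_manager and the three
    # DataLoaders are assumed to be built elsewhere.
    core = CongradsCore(
        descriptor,
        constraints,
        (train_loader, valid_loader, test_loader),
        network,
        criterion=MSELoss(),
        optimizer=Adam(network.parameters(), lr=1e-3),
        metric_manager=metric_manager,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
    core.fit(max_epochs=50)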
@@ -18,14 +102,33 @@ class CongradsCore:
          constraints: list[Constraint],
          loaders: tuple[DataLoader, DataLoader, DataLoader],
          network: Module,
-         criterion: callable,
+         criterion: _Loss,
          optimizer: Optimizer,
          metric_manager: MetricManager,
-         device,
+         device: torch.device,
+         checkpoint_manager: CheckpointManager = None,
+         epsilon: Number = 1e-6,
      ):
-
-         # Init parent class
-         super().__init__()
+         """
+         Initialize the CongradsCore object.
+         """
+
+         # Type checking
+         validate_type("descriptor", descriptor, Descriptor)
+         validate_iterable("constraints", constraints, Constraint)
+         validate_loaders(loaders)
+         validate_type("network", network, Module)
+         validate_type("criterion", criterion, _Loss)
+         validate_type("optimizer", optimizer, Optimizer)
+         validate_type("metric_manager", metric_manager, MetricManager)
+         validate_type("device", device, torch.device)
+         validate_type(
+             "checkpoint_manager",
+             checkpoint_manager,
+             CheckpointManager,
+             allow_none=True,
+         )
+         validate_type("epsilon", epsilon, Number)

          # Init object variables
          self.descriptor = descriptor
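
The constructor now fails fast on wrongly typed arguments instead of failing deep inside the training loop. The validators themselves live in congrads.utils and are not part of this diff; as a rough idea of the contract they enforce, a hypothetical stand-in for validate_type could look like this:

    from numbers import Number

    # Hypothetical stand-in; the real congrads.utils.validate_type may differ.
    def validate_type(name, value, expected, allow_none=False):
        if value is None and allow_none:
            return
        if not isinstance(value, expected):
            raise TypeError(
                f"Argument '{name}' must be of type {expected.__name__}, "
                f"got {type(value).__name__}."
            )

    validate_type("epsilon", 1e-6, Number)  # passes silently
    try:
        validate_type("epsilon", "1e-6", Number)
    except TypeError as err:
        print(err)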
@@ -38,105 +141,270 @@ class CongradsCore:
          self.optimizer = optimizer
          self.metric_manager = metric_manager
          self.device = device
+         self.checkpoint_manager = checkpoint_manager
+
+         # Init epsilon tensor
+         self.epsilon = tensor(epsilon, device=self.device)

          # Perform checks
          if len(self.descriptor.variable_layers) == 0:
-             logging.warning(
-                 "The descriptor object has no variable layers. The constraint guided loss adjustment is therefore not used. Is this the intended behaviour?"
+             warnings.warn(
+                 "The descriptor object has no variable layers. The"
+                 " constraint guided loss adjustment is therefore not used."
+                 " Is this the intended behavior?"
              )

          # Initialize constraint metrics
-         metric_manager.register("Loss/train")
-         metric_manager.register("Loss/valid")
-         metric_manager.register("CSR/train")
-         metric_manager.register("CSR/valid")
+         self._initialize_metrics()
+
+     def _initialize_metrics(self) -> None:
+         """
+         Register metrics for loss, constraint satisfaction ratio (CSR),
+         and individual constraints.
+
+         This method registers the following metrics:
+
+         - Loss/train: Training loss.
+         - Loss/valid: Validation loss.
+         - Loss/test: Test loss after training.
+         - CSR/train: Constraint satisfaction ratio during training.
+         - CSR/valid: Constraint satisfaction ratio during validation.
+         - CSR/test: Constraint satisfaction ratio after training.
+         - One metric per constraint, for training, validation, and testing.
+         """
+
+         self.metric_manager.register("Loss/train", "during_training")
+         self.metric_manager.register("Loss/valid", "during_training")
+         self.metric_manager.register("Loss/test", "after_training")
+
+         if len(self.constraints) > 0:
+             self.metric_manager.register("CSR/train", "during_training")
+             self.metric_manager.register("CSR/valid", "during_training")
+             self.metric_manager.register("CSR/test", "after_training")

          for constraint in self.constraints:
-             metric_manager.register(f"{constraint.name}/train")
-             metric_manager.register(f"{constraint.name}/valid")
+             self.metric_manager.register(
+                 f"{constraint.name}/train", "during_training"
+             )
+             self.metric_manager.register(
+                 f"{constraint.name}/valid", "during_training"
+             )
+             self.metric_manager.register(
+                 f"{constraint.name}/test", "after_training"
+             )
+
+     def fit(
+         self,
+         start_epoch: int = 0,
+         max_epochs: int = 100,
+         on_epoch_start: Callable[[int], None] = None,
+         on_epoch_end: Callable[[int], None] = None,
+         on_train_start: Callable[[int], None] = None,
+         on_train_end: Callable[[int], None] = None,
+     ) -> None:
+         """
+         Train the model for a given number of epochs.
+
+         Args:
+             start_epoch (int, optional): The epoch number to start
+                 training from. Default is 0.
+             max_epochs (int, optional): The number of epochs to train the
+                 model. Default is 100.
+             on_epoch_start (Callable[[int], None], optional): A callback
+                 function executed at the start of each epoch.
+             on_epoch_end (Callable[[int], None], optional): A callback
+                 function executed at the end of each epoch.
+             on_train_start (Callable[[int], None], optional): A callback
+                 function executed before training starts.
+             on_train_end (Callable[[int], None], optional): A callback
+                 function executed after training ends.
+         """
+
+         # Type checking
+         validate_type("start_epoch", start_epoch, int)
+         validate_callable("on_epoch_start", on_epoch_start, True)
+         validate_callable("on_epoch_end", on_epoch_end, True)
+         validate_callable("on_train_start", on_train_start, True)
+         validate_callable("on_train_end", on_train_end, True)
+
+         # Keep track of epoch
+         epoch = start_epoch
+
+         # Execute training start hook if set
+         if on_train_start:
+             on_train_start(epoch)
+
+         for i in tqdm(range(epoch, max_epochs), initial=epoch, desc="Epoch"):
+             epoch = i
+
+             # Execute epoch start hook if set
+             if on_epoch_start:
+                 on_epoch_start(epoch)
+
+             # Execute training and validation epoch
+             self._train_epoch()
+             self._validate_epoch()
+
+             # Checkpointing
+             if self.checkpoint_manager:
+                 self.checkpoint_manager.evaluate_criteria(epoch)
+
+             # Execute epoch end hook if set
+             if on_epoch_end:
+                 on_epoch_end(epoch)
+
+         # Evaluate model performance on unseen test set
+         self._test_model()
+
+         # Save final model
+         if self.checkpoint_manager:
+             self.checkpoint_manager.save(epoch, "checkpoint_final.pth")
+
+         # Execute training end hook if set
+         if on_train_end:
+             on_train_end(epoch)
+
+     def _train_epoch(self) -> None:
+         """
+         Perform training for a single epoch.
+
+         This method:
+         - Sets the model to training mode.
+         - Processes batches from the training DataLoader.
+         - Computes predictions and losses.
+         - Adjusts losses based on constraints.
+         - Updates model parameters using backpropagation.
+         """
+
+         # Set model in training mode
+         self.network.train()
+
+         for batch in tqdm(
+             self.train_loader, desc="Training batches", leave=False
+         ):
+
+             # Get input-output pairs from batch
+             inputs, outputs = batch
+
+             # Transfer to GPU
+             inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+
+             # Model computations
+             prediction = self.network(inputs)
+
+             # Calculate loss
+             loss = self.criterion(prediction["output"], outputs)
+             self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+             # Adjust loss based on constraints
+             combined_loss = self.train_step(prediction, loss)
+
+             # Backprop
+             self.optimizer.zero_grad()
+             combined_loss.backward(
+                 retain_graph=False, inputs=list(self.network.parameters())
+             )
+             self.optimizer.step()

-     def fit(self, max_epochs: int = 100):
-         # Loop over epochs
-         for epoch in range(max_epochs):
+     def _validate_epoch(self) -> None:
+         """
+         Perform validation for a single epoch.

-             # Log start time
-             start_time = time()
+         This method:
+         - Sets the model to evaluation mode.
+         - Processes batches from the validation DataLoader.
+         - Computes predictions and losses.
+         - Logs constraint satisfaction ratios.

-             # Training
-             for batch in self.train_loader:
+         """

-                 # Set model in training mode
-                 self.network.train()
+         # Set model in evaluation mode
+         self.network.eval()
+
+         with no_grad():
+             for batch in tqdm(
+                 self.valid_loader, desc="Validation batches", leave=False
+             ):

                  # Get input-output pairs from batch
                  inputs, outputs = batch

                  # Transfer to GPU
-                 inputs, outputs = inputs.to(self.device), outputs.to(self.device)
-
-                 # Log preparation time
-                 prepare_time = start_time - time()
+                 inputs, outputs = inputs.to(self.device), outputs.to(
+                     self.device
+                 )

                  # Model computations
                  prediction = self.network(inputs)

                  # Calculate loss
                  loss = self.criterion(prediction["output"], outputs)
-                 self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+                 self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))

-                 # Adjust loss based on constraints
-                 combined_loss = self.train_step(prediction, loss)
+                 # Validate constraints
+                 self.valid_step(prediction, loss)

-                 # Backpropx
-                 self.optimizer.zero_grad()
-                 combined_loss.backward(
-                     retain_graph=False, inputs=list(self.network.parameters())
-                 )
-                 self.optimizer.step()
-
-             # Validation
-             with no_grad():
-                 for batch in self.valid_loader:
+     def _test_model(self) -> None:
+         """
+         Evaluate model performance on the test set.

-                     # Set model in evaluation mode
-                     self.network.eval()
+         This method:
+         - Sets the model to evaluation mode.
+         - Processes batches from the test DataLoader.
+         - Computes predictions and losses.
+         - Logs constraint satisfaction ratios.

-                     # Get input-output pairs from batch
-                     inputs, outputs = batch
+         """

-                     # Transfer to GPU
-                     inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+         # Set model in evaluation mode
+         self.network.eval()

-                     # Model computations
-                     prediction = self.network(inputs)
+         with no_grad():
+             for batch in tqdm(
+                 self.test_loader, desc="Test batches", leave=False
+             ):

-                     # Calculate loss
-                     loss = self.criterion(prediction["output"], outputs)
-                     self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+                 # Get input-output pairs from batch
+                 inputs, outputs = batch

-                     # Validate constraints
-                     self.valid_step(prediction, loss)
+                 # Transfer to GPU
+                 inputs, outputs = inputs.to(self.device), outputs.to(
+                     self.device
+                 )

-             # TODO with valid loader, checkpoint model with best performance
+                 # Model computations
+                 prediction = self.network(inputs)

-             # Save metrics
-             self.metric_manager.record(epoch)
-             self.metric_manager.reset()
+                 # Calculate loss
+                 loss = self.criterion(prediction["output"], outputs)
+                 self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))

-             # Log compute and preparation time
-             process_time = start_time - time() - prepare_time
-             print(
-                 "Compute efficiency: {:.2f}, epoch: {}/{}:".format(
-                     process_time / (process_time + prepare_time), epoch, max_epochs
-                 )
-             )
-             start_time = time()
+                 # Validate constraints
+                 self.test_step(prediction, loss)

      def train_step(
          self,
          prediction: dict[str, Tensor],
          loss: Tensor,
-     ):
+     ) -> Tensor:
+         """
+         Adjust the training loss based on constraints
+         and compute the combined loss.
+
+         Args:
+             prediction (dict[str, Tensor]): Model predictions
+                 for variable layers.
+             loss (Tensor): The base loss computed by the criterion.
+
+         Returns:
+             Tensor: The combined loss (base loss + constraint adjustments).
+         """

          # Init scalar tensor for loss
          total_rescale_loss = tensor(0, dtype=float32, device=self.device)
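
The new fit() threads the current epoch number through four optional hooks, validated above via validate_callable. Assuming core is an already constructed CongradsCore, wiring them up could look like this; the callback bodies are purely illustrative:

    def log_epoch(epoch: int) -> None:
        # A convenient place to flush metrics or step an LR scheduler.
        print(f"Finished epoch {epoch}")

    core.fit(
        start_epoch=0,
        max_epochs=100,
        on_train_start=lambda epoch: print(f"Starting at epoch {epoch}"),
        on_epoch_end=log_epoch,
        on_train_end=lambda epoch: print(f"Done after epoch {epoch}"),
    )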
@@ -149,48 +417,76 @@ class CongradsCore:
              loss.backward(retain_graph=True, inputs=prediction[layer])
              loss_grads[layer] = prediction[layer].grad

-         # For each constraint, TODO split into real and validation only constraints
          for constraint in self.constraints:

              # Check if constraints are satisfied and calculate directions
              with no_grad():
-                 constraint_checks = constraint.check_constraint(prediction)
-                 constraint_directions = constraint.calculate_direction(prediction)
-
-             # Only do direction calculations for variable layers affecting constraint
-             for layer in constraint.layers & self.descriptor.variable_layers:
+                 constraint_checks, relevant_constraint_count = (
+                     constraint.check_constraint(prediction)
+                 )

+             # Only adjust the loss if the constraint is not monitor-only
+             if not constraint.monitor_only:
                  with no_grad():
-                     # Multiply direction modifiers with constraint result
-                     constraint_result = (
-                         constraint_checks.unsqueeze(1).type(float32)
-                         * constraint_directions[layer]
+                     constraint_directions = constraint.calculate_direction(
+                         prediction
                      )

-                     # Multiply result with rescale factor of constraint
-                     constraint_result *= constraint.rescale_factor
+                 # Only do direction calculations for variable
+                 # layers affecting constraint
+                 for layer in (
+                     constraint.layers & self.descriptor.variable_layers
+                 ):
+
+                     with no_grad():
+                         # Multiply direction modifiers with constraint result
+                         constraint_result = (
+                             1 - constraint_checks.unsqueeze(1)
+                         ) * constraint_directions[layer]
+
+                         # Multiply result with rescale factor of constraint
+                         constraint_result *= constraint.rescale_factor

-                 # Calculate loss gradient norm
-                 norm_loss_grad = norm(loss_grads[layer], dim=1, p=2, keepdim=True)
+                     # Calculate loss gradient norm
+                     norm_loss_grad = norm(
+                         loss_grads[layer], dim=1, p=2, keepdim=True
+                     )

-                 # Calculate rescale loss
-                 rescale_loss = (
-                     prediction[layer]
-                     * constraint_result
-                     * norm_loss_grad.detach().clone()
-                 ).mean()
+                     # Apply minimum epsilon
+                     norm_loss_grad = maximum(norm_loss_grad, self.epsilon)

-                 # Store rescale loss for this reference space
-                 total_rescale_loss += rescale_loss
+                     # Calculate rescale loss
+                     rescale_loss = (
+                         prediction[layer]
+                         * constraint_result
+                         * norm_loss_grad.detach().clone()
+                     ).mean()
+
+                     # Store rescale loss for this reference space
+                     total_rescale_loss += rescale_loss

              # Log constraint satisfaction ratio
              self.metric_manager.accumulate(
                  f"{constraint.name}/train",
-                 (~constraint_checks).type(float32),
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
              )
              self.metric_manager.accumulate(
                  "CSR/train",
-                 (~constraint_checks).type(float32),
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
              )

          # Return combined loss
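
The CSR expression above is worth unpacking. It assumes check_constraint returns a per-entry tensor of checks in which entries outside the constraint's scope count as satisfied, together with the number of in-scope entries; subtracting numel and adding back the relevant count then cancels the always-satisfied padding, leaving the satisfaction ratio over relevant entries only. A standalone numeric check of that arithmetic, with made-up values:

    import torch

    # 8 checks, of which only 5 are relevant to this constraint; the 3
    # irrelevant entries are assumed to be reported as satisfied (1).
    constraint_checks = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0])
    relevant_constraint_count = torch.tensor(5.0)

    csr = (
        constraint_checks.sum()
        - constraint_checks.numel()
        + relevant_constraint_count
    ) / relevant_constraint_count
    print(csr)  # tensor(0.6000): 3 of the 5 relevant checks passed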
@@ -200,26 +496,102 @@ class CongradsCore:
          self,
          prediction: dict[str, Tensor],
          loss: Tensor,
-     ):
+     ) -> Tensor:
+         """
+         Evaluate constraints during validation and log satisfaction metrics.

-         # Compute rescale loss without tracking gradients
-         with no_grad():
+         Args:
+             prediction (dict[str, Tensor]): Model predictions for
+                 variable layers.
+             loss (Tensor): The base loss computed by the criterion.

-             # For each constraint in this reference space, calculate directions
-             for constraint in self.constraints:
+         Returns:
+             Tensor: The unchanged base loss.
+         """

-                 # Check if constraints are satisfied for
-                 constraint_checks = constraint.check_constraint(prediction)
+         # For each constraint in this reference space, check satisfaction
+         for constraint in self.constraints:

-                 # Log constraint satisfaction ratio
-                 self.metric_manager.accumulate(
-                     f"{constraint.name}/valid",
-                     (~constraint_checks).type(float32),
-                 )
-                 self.metric_manager.accumulate(
-                     "CSR/valid",
-                     (~constraint_checks).type(float32),
-                 )
+             # Check if constraints are satisfied
+             constraint_checks, relevant_constraint_count = (
+                 constraint.check_constraint(prediction)
+             )
+
+             # Log constraint satisfaction ratio
+             self.metric_manager.accumulate(
+                 f"{constraint.name}/valid",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )
+             self.metric_manager.accumulate(
+                 "CSR/valid",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )
+
+         # Return loss
+         return loss
+
+     def test_step(
+         self,
+         prediction: dict[str, Tensor],
+         loss: Tensor,
+     ) -> Tensor:
+         """
+         Evaluate constraints during testing and log satisfaction metrics.
+
+         Args:
+             prediction (dict[str, Tensor]): Model predictions
+                 for variable layers.
+             loss (Tensor): The base loss computed by the criterion.
+
+         Returns:
+             Tensor: The unchanged base loss.
+         """
+
+         # For each constraint in this reference space, check satisfaction
+         for constraint in self.constraints:
+
+             # Check if constraints are satisfied
+             constraint_checks, relevant_constraint_count = (
+                 constraint.check_constraint(prediction)
+             )
+
+             # Log constraint satisfaction ratio
+             self.metric_manager.accumulate(
+                 f"{constraint.name}/test",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )
+             self.metric_manager.accumulate(
+                 "CSR/test",
+                 (
+                     (
+                         sum(constraint_checks)
+                         - numel(constraint_checks)
+                         + relevant_constraint_count
+                     )
+                     / relevant_constraint_count
+                 ).unsqueeze(0),
+             )

          # Return loss
          return loss
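
Finally, the epsilon introduced in this release guards the rescale term in train_step: the per-sample L2 norm of the loss gradient is clamped from below before it scales the constraint direction, so a vanishing loss gradient can no longer zero out the constraint adjustment. A minimal standalone illustration of the clamp:

    import torch

    eps = torch.tensor(1e-6)

    # Per-sample loss gradients w.r.t. one variable layer (batch of 3).
    loss_grads = torch.tensor([[3.0, 4.0], [0.0, 0.0], [1e-9, 0.0]])

    norm_loss_grad = torch.norm(loss_grads, p=2, dim=1, keepdim=True)
    norm_loss_grad = torch.maximum(norm_loss_grad, eps)
    print(norm_loss_grad)  # rows: 5.0, 1e-06 (clamped), 1e-06 (clamped)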