congrads-0.2.0-py3-none-any.whl → congrads-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,200 @@
+ """Defines the BatchRunner, which executes individual batches for training, validation, and testing.
+
+ Responsibilities:
+ - Move batch data to the appropriate device
+ - Run forward passes through the network
+ - Compute base and constraint-adjusted losses
+ - Perform backpropagation during training
+ - Accumulate metrics for loss and other monitored quantities
+ - Trigger callbacks at key points in the batch lifecycle
+ """
+
+ import torch
+ from torch import Tensor
+ from torch.nn import Module
+ from torch.optim import Optimizer
+
+ from ..callbacks.base import CallbackManager
+ from ..core.constraint_engine import ConstraintEngine
+ from ..metrics import MetricManager
+
+
+ class BatchRunner:
+     """Executes a single batch for training, validation, or testing.
+
+     The BatchRunner handles moving data to the correct device, running the network
+     forward, computing base and constraint-adjusted losses, performing backpropagation
+     during training, accumulating metrics, and dispatching callbacks at key points
+     in the batch lifecycle.
+     """
+
+     def __init__(
+         self,
+         network: Module,
+         criterion,
+         optimizer: Optimizer,
+         constraint_engine: ConstraintEngine,
+         metric_manager: MetricManager | None,
+         callback_manager: CallbackManager | None,
+         device: torch.device,
+     ):
+         """Initialize the BatchRunner.
+
+         Args:
+             network: The neural network module to execute.
+             criterion: Loss function callable accepting (output, target, data=batch).
+             optimizer: Optimizer for updating network parameters.
+             constraint_engine: ConstraintEngine instance for evaluating and enforcing constraints.
+             metric_manager: Optional MetricManager for logging batch metrics.
+             callback_manager: Optional CallbackManager for triggering hooks during batch processing.
+             device: Torch device on which to place data and network.
+         """
+         self.network = network
+         self.criterion = criterion
+         self.optimizer = optimizer
+         self.constraint_engine = constraint_engine
+         self.metric_manager = metric_manager
+         self.callback_manager = callback_manager
+         self.device = device
+
+     def _to_device(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Move all tensors in the batch to the BatchRunner's device.
+
+         Args:
+             batch: Dictionary of tensors for a single batch.
+
+         Returns:
+             Dictionary of tensors moved to the target device.
+         """
+         return {k: v.to(self.device) for k, v in batch.items()}
+
+     def _run_callbacks(self, hook: str, data: dict) -> dict:
+         """Run the specified callback hook on the batch data.
+
+         Args:
+             hook: Name of the callback hook to run.
+             data: Dictionary containing batch data.
+
+         Returns:
+             Potentially modified batch data after callback execution.
+         """
+         if self.callback_manager is None:
+             return data
+         return self.callback_manager.run(hook, data)
+
+     def train_batch(self, batch: dict[str, Tensor]) -> Tensor:
+         """Run a single training batch.
+
+         Steps performed:
+         1. Move batch to device and run "on_train_batch_start" callback.
+         2. Forward pass through the network.
+         3. Compute base loss using the criterion and accumulate metric.
+         4. Apply constraint-based adjustments to the loss.
+         5. Perform backward pass and optimizer step.
+         6. Run "on_train_batch_end" callback.
+
+         Args:
+             batch: Dictionary of input and target tensors for the batch.
+
+         Returns:
+             Tensor: The base loss computed before constraint adjustments.
+         """
+         batch = self._to_device(batch)
+         batch = self._run_callbacks("on_train_batch_start", batch)
+
+         # Forward
+         batch = self.network(batch)
+         batch = self._run_callbacks("after_train_forward", batch)
+
+         # Base loss
+         loss: Tensor = self.criterion(
+             batch["output"],
+             batch["target"],
+             data=batch,
+         )
+
+         if self.metric_manager is not None:
+             self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+         # Constraint-adjusted loss
+         combined_loss = self.constraint_engine.train(batch, loss)
+
+         # Backward
+         self.optimizer.zero_grad()
+         combined_loss.backward()
+         self.optimizer.step()
+
+         batch = self._run_callbacks("on_train_batch_end", batch)
+         return loss
+
+     def valid_batch(self, batch: dict[str, Tensor]) -> Tensor:
+         """Run a single validation batch.
+
+         Steps performed:
+         1. Move batch to device and run "on_valid_batch_start" callback.
+         2. Forward pass through the network.
+         3. Compute base loss using the criterion and accumulate metric.
+         4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+         5. Run "on_valid_batch_end" callback.
+
+         Args:
+             batch: Dictionary of input and target tensors for the batch.
+
+         Returns:
+             Tensor: The base loss computed for the batch.
+         """
+         batch = self._to_device(batch)
+         batch = self._run_callbacks("on_valid_batch_start", batch)
+
+         batch = self.network(batch)
+         batch = self._run_callbacks("after_valid_forward", batch)
+
+         loss: Tensor = self.criterion(
+             batch["output"],
+             batch["target"],
+             data=batch,
+         )
+
+         if self.metric_manager is not None:
+             self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+
+         self.constraint_engine.validate(batch, loss)
+
+         batch = self._run_callbacks("on_valid_batch_end", batch)
+         return loss
+
+     def test_batch(self, batch: dict[str, Tensor]) -> Tensor:
+         """Run a single test batch.
+
+         Steps performed:
+         1. Move batch to device and run "on_test_batch_start" callback.
+         2. Forward pass through the network.
+         3. Compute base loss using the criterion and accumulate metric.
+         4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+         5. Run "on_test_batch_end" callback.
+
+         Args:
+             batch: Dictionary of input and target tensors for the batch.
+
+         Returns:
+             Tensor: The base loss computed for the batch.
+         """
+         batch = self._to_device(batch)
+         batch = self._run_callbacks("on_test_batch_start", batch)
+
+         batch = self.network(batch)
+         batch = self._run_callbacks("after_test_forward", batch)
+
+         loss: Tensor = self.criterion(
+             batch["output"],
+             batch["target"],
+             data=batch,
+         )
+
+         if self.metric_manager is not None:
+             self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
+
+         self.constraint_engine.test(batch, loss)
+
+         batch = self._run_callbacks("on_test_batch_end", batch)
+         return loss
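Note: the following is an illustrative sketch, not part of the package diff. It shows how a BatchRunner's training entry point might be driven directly, assuming `runner` is an already-constructed BatchRunner, `dataloader_train` yields batch dictionaries containing a "target" tensor, and the network returns the batch dictionary with an "output" key added (as the criterion call above expects).

    # Illustrative usage sketch (assumed names: runner, dataloader_train).
    # train_batch() moves data to the device, runs callbacks, applies the
    # constraint-adjusted backward pass, and returns the base loss.
    for batch in dataloader_train:
        base_loss = runner.train_batch(batch)
        print(f"base loss: {base_loss.item():.4f}")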
@@ -0,0 +1,271 @@
+ """This module provides the core CongradsCore class for the main training functionality.
+
+ It is designed to integrate constraint-guided optimization into neural network training.
+ It extends traditional training processes by enforcing specific constraints
+ on the model's outputs, ensuring that the network satisfies domain-specific
+ requirements during both training and evaluation.
+
+ The `CongradsCore` class serves as the central engine for managing the
+ training, validation, and testing phases of a neural network model,
+ incorporating constraints that influence the loss function and model updates.
+ The model is trained with standard loss functions while also incorporating
+ constraint-based adjustments, which are tracked and logged
+ throughout the process.
+
+ Key features:
+ - Support for various constraints that can influence the training process.
+ - Integration with PyTorch's `DataLoader` for efficient batch processing.
+ - Metric management for tracking loss and constraint satisfaction.
+ - Checkpoint management for saving and evaluating model states.
+
+ """
+
+ from collections.abc import Callable
+
+ import torch
+ from torch import Tensor, sum
+ from torch.nn import Module
+ from torch.nn.modules.loss import _Loss
+ from torch.optim import Optimizer
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+
+ from congrads.utils.utility import LossWrapper
+
+ from ..callbacks.base import CallbackManager
+ from ..checkpoints import CheckpointManager
+ from ..constraints.base import Constraint
+ from ..core.batch_runner import BatchRunner
+ from ..core.constraint_engine import ConstraintEngine
+ from ..core.epoch_runner import EpochRunner
+ from ..descriptor import Descriptor
+ from ..metrics import MetricManager
+
+
+ class CongradsCore:
+     """The CongradsCore class is the central training engine for constraint-guided optimization.
+
+     It integrates standard neural network training
+     with additional constraint-driven adjustments to the loss function, ensuring
+     that the network satisfies domain-specific constraints during training.
+     """
+
+     def __init__(
+         self,
+         descriptor: Descriptor,
+         constraints: list[Constraint],
+         network: Module,
+         criterion: _Loss,
+         optimizer: Optimizer,
+         device: torch.device,
+         dataloader_train: DataLoader,
+         dataloader_valid: DataLoader | None = None,
+         dataloader_test: DataLoader | None = None,
+         metric_manager: MetricManager | None = None,
+         callback_manager: CallbackManager | None = None,
+         checkpoint_manager: CheckpointManager | None = None,
+         network_uses_grad: bool = False,
+         epsilon: float = 1e-6,
+         constraint_aggregator: Callable[..., Tensor] = sum,
+         enforce_all: bool = True,
+         disable_progress_bar_epoch: bool = False,
+         disable_progress_bar_batch: bool = False,
+         epoch_runner_cls: type["EpochRunner"] | None = None,
+         batch_runner_cls: type["BatchRunner"] | None = None,
+         constraint_engine_cls: type["ConstraintEngine"] | None = None,
+     ):
+         """Initialize the CongradsCore object.
+
+         Args:
+             descriptor (Descriptor): Describes variable layers in the network.
+             constraints (list[Constraint]): List of constraints to guide training.
+             network (Module): The neural network model to train.
+             criterion (callable): The loss function used for
+                 training and validation.
+             optimizer (Optimizer): The optimizer used for updating model parameters.
+             device (torch.device): The device (e.g., CPU or GPU) for computations.
+             dataloader_train (DataLoader): DataLoader for training data.
+             dataloader_valid (DataLoader, optional): DataLoader for validation data.
+                 If not provided, validation is skipped.
+             dataloader_test (DataLoader, optional): DataLoader for test data.
+                 If not provided, testing is skipped.
+             metric_manager (MetricManager, optional): Manages metric tracking and recording.
+             callback_manager (CallbackManager, optional): Manages training callbacks.
+             checkpoint_manager (CheckpointManager, optional): Manages
+                 checkpointing. If not set, no checkpointing is done.
+             network_uses_grad (bool, optional): A flag indicating if the network
+                 contains gradient calculation computations. Default is False.
+             epsilon (float, optional): A small value to avoid division by zero
+                 in gradient calculations. Default is 1e-6.
+             constraint_aggregator (Callable[..., Tensor], optional): A function
+                 to aggregate the constraint rescale loss. Default is `sum`.
+             enforce_all (bool, optional): If set to False, constraints will only be monitored and
+                 not influence the training process. Overrides constraint-specific `enforce` parameters.
+                 Defaults to True.
+             disable_progress_bar_epoch (bool, optional): If set to True, the epoch
+                 progress bar will not show. Defaults to False.
+             disable_progress_bar_batch (bool, optional): If set to True, the batch
+                 progress bar will not show. Defaults to False.
+             epoch_runner_cls (type[EpochRunner], optional): Custom EpochRunner class.
+                 If not provided, the default EpochRunner is used.
+             batch_runner_cls (type[BatchRunner], optional): Custom BatchRunner class.
+                 If not provided, the default BatchRunner is used.
+             constraint_engine_cls (type[ConstraintEngine], optional): Custom ConstraintEngine class.
+                 If not provided, the default ConstraintEngine is used.
+
+         Note:
+             A warning is logged if the descriptor has no variable layers,
+             as at least one variable layer is required for the constraint logic
+             to influence the training process.
+         """
+         # Init object variables
+         self.device = device
+         self.network = network.to(device)
+         self.criterion = LossWrapper(criterion)
+         self.optimizer = optimizer
+         self.descriptor = descriptor
+
+         self.constraints = constraints or []
+         self.epsilon = epsilon
+         self.aggregator = constraint_aggregator
+         self.enforce_all = enforce_all
+
+         self.metric_manager = metric_manager
+         self.callback_manager = callback_manager
+         self.checkpoint_manager = checkpoint_manager
+
+         self.disable_progress_bar_epoch = disable_progress_bar_epoch
+         self.disable_progress_bar_batch = disable_progress_bar_batch
+
+         # Initialize constraint engine
+         self.constraint_engine = (constraint_engine_cls or ConstraintEngine)(
+             constraints=self.constraints,
+             descriptor=self.descriptor,
+             device=self.device,
+             epsilon=self.epsilon,
+             aggregator=self.aggregator,
+             enforce_all=self.enforce_all,
+             metric_manager=self.metric_manager,
+         )
+
+         # Initialize runners
+         self.batch_runner = (batch_runner_cls or BatchRunner)(
+             network=self.network,
+             criterion=self.criterion,
+             optimizer=self.optimizer,
+             constraint_engine=self.constraint_engine,
+             metric_manager=self.metric_manager,
+             callback_manager=self.callback_manager,
+             device=self.device,
+         )
+
+         self.epoch_runner = (epoch_runner_cls or EpochRunner)(
+             batch_runner=self.batch_runner,
+             network=self.network,
+             train_loader=dataloader_train,
+             valid_loader=dataloader_valid,
+             test_loader=dataloader_test,
+             network_uses_grad=network_uses_grad,
+             disable_progress_bar=self.disable_progress_bar_batch,
+         )
+
+         # Initialize constraint metrics
+         if self.metric_manager is not None:
+             self._initialize_metrics()
+
+     def _initialize_metrics(self) -> None:
+         """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.
+
+         This method registers the following metrics:
+
+         - Loss/train: Training loss.
+         - Loss/valid: Validation loss.
+         - Loss/test: Test loss after training.
+         - CSR/train: Constraint satisfaction ratio during training.
+         - CSR/valid: Constraint satisfaction ratio during validation.
+         - CSR/test: Constraint satisfaction ratio after training.
+         - One metric per constraint, for training, validation, and testing.
+
+         """
+         self.metric_manager.register("Loss/train", "during_training")
+         self.metric_manager.register("Loss/valid", "during_training")
+         self.metric_manager.register("Loss/test", "after_training")
+
+         if len(self.constraints) > 0:
+             self.metric_manager.register("CSR/train", "during_training")
+             self.metric_manager.register("CSR/valid", "during_training")
+             self.metric_manager.register("CSR/test", "after_training")
+
+         for constraint in self.constraints:
+             self.metric_manager.register(f"{constraint.name}/train", "during_training")
+             self.metric_manager.register(f"{constraint.name}/valid", "during_training")
+             self.metric_manager.register(f"{constraint.name}/test", "after_training")
+
+     def fit(
+         self,
+         *,
+         start_epoch: int = 0,
+         max_epochs: int = 100,
+         test_model: bool = True,
+         final_checkpoint_name: str = "checkpoint_final.pth",
+     ) -> None:
+         """Run the full training loop, including optional validation, testing, and checkpointing.
+
+         This method performs training over multiple epochs with the following steps:
+         1. Trigger "on_train_start" callbacks if a callback manager is present.
+         2. For each epoch:
+             - Trigger "on_epoch_start" callbacks.
+             - Run a training epoch via the EpochRunner.
+             - Run a validation epoch via the EpochRunner.
+             - Evaluate checkpoint criteria if a checkpoint manager is present.
+             - Trigger "on_epoch_end" callbacks.
+         3. Trigger "on_train_end" callbacks after all epochs.
+         4. Optionally run a test epoch via the EpochRunner if `test_model` is True,
+             with corresponding "on_test_start" and "on_test_end" callbacks.
+         5. Save a final checkpoint using the checkpoint manager, if one is present.
+
+         Args:
+             start_epoch: Index of the starting epoch (default 0). Useful for resuming training.
+             max_epochs: Maximum number of epochs to run (default 100).
+             test_model: Whether to run a test epoch after training (default True).
+             final_checkpoint_name: Filename for the final checkpoint saved at the end of training
+                 (default "checkpoint_final.pth").
+
+         Returns:
+             None
+         """
+         if self.callback_manager:
+             self.callback_manager.run("on_train_start", {"epoch": start_epoch})
+
+         for epoch in tqdm(
+             range(start_epoch, max_epochs),
+             initial=start_epoch,
+             desc="Epoch",
+             disable=self.disable_progress_bar_epoch,
+         ):
+             if self.callback_manager:
+                 self.callback_manager.run("on_epoch_start", {"epoch": epoch})
+
+             self.epoch_runner.train()
+             self.epoch_runner.validate()
+
+             if self.checkpoint_manager:
+                 self.checkpoint_manager.evaluate_criteria(epoch)
+
+             if self.callback_manager:
+                 self.callback_manager.run("on_epoch_end", {"epoch": epoch})
+
+         if self.callback_manager:
+             self.callback_manager.run("on_train_end", {"epoch": epoch})
+
+         if test_model:
+             if self.callback_manager:
+                 self.callback_manager.run("on_test_start", {"epoch": epoch})
+
+             self.epoch_runner.test()
+
+             if self.callback_manager:
+                 self.callback_manager.run("on_test_end", {"epoch": epoch})
+
+         if self.checkpoint_manager:
+             self.checkpoint_manager.save(epoch, final_checkpoint_name)
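Note: an illustrative wiring sketch, not part of the package diff. It assumes `network`, `train_loader`, and `valid_loader` already exist, and that `descriptor` and `constraints` were built with the library's Descriptor and Constraint APIs, which are not shown in this diff; only the keyword names of the CongradsCore constructor and of fit() are taken from the code above.

    # Illustrative usage sketch (assumed objects: network, descriptor,
    # constraints, train_loader, valid_loader).
    import torch
    from torch.nn import MSELoss
    from torch.optim import Adam

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    core = CongradsCore(
        descriptor=descriptor,
        constraints=constraints,
        network=network,
        criterion=MSELoss(),
        optimizer=Adam(network.parameters(), lr=1e-3),
        device=device,
        dataloader_train=train_loader,
        dataloader_valid=valid_loader,
    )
    core.fit(max_epochs=50, test_model=False)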
@@ -0,0 +1,209 @@
+ """Manages the evaluation and optional enforcement of constraints on neural network outputs.
+
+ Responsibilities:
+ - Compute and log Constraint Satisfaction Rate (CSR) for training, validation, and test batches.
+ - Optionally adjust loss during training based on constraint directions and rescale factors.
+ - Handle gradient computation and CGGD application.
+ """
+
+ import torch
+ from torch import Tensor, no_grad
+ from torch.linalg import vector_norm
+
+ from ..constraints.base import Constraint
+ from ..descriptor import Descriptor
+ from ..metrics import MetricManager
+
+
+ class ConstraintEngine:
+     """Manages constraint evaluation and enforcement for a neural network.
+
+     The ConstraintEngine coordinates constraints defined in Constraint objects,
+     computes gradients for layers that affect the loss, logs metrics, and optionally
+     modifies the loss during training according to the constraints. It supports
+     separate phases for training, validation, and testing.
+     """
+
+     def __init__(
+         self,
+         *,
+         constraints: list[Constraint],
+         descriptor: Descriptor,
+         metric_manager: MetricManager,
+         device: torch.device,
+         epsilon: float,
+         aggregator: callable,
+         enforce_all: bool,
+     ) -> None:
+         """Initialize the ConstraintEngine.
+
+         Args:
+             constraints: List of Constraint objects to evaluate and optionally enforce.
+             descriptor: Descriptor containing metadata about network layers and which
+                 variables affect the loss.
+             metric_manager: MetricManager instance for logging CSR metrics.
+             device: Torch device where tensors will be allocated (CPU or GPU).
+             epsilon: Small positive value to avoid division by zero in gradient norms.
+             aggregator: Callable used to reduce per-layer constraint contributions
+                 to a scalar loss adjustment.
+             enforce_all: Whether to enforce all constraints during training.
+         """
+         self.constraints = constraints
+         self.descriptor = descriptor
+         self.metric_manager = metric_manager
+         self.device = device
+         self.epsilon = epsilon
+         self.enforce_all = enforce_all
+         self.aggregator = aggregator
+
+         self.norm_loss_grad: dict[str, Tensor] = {}
+
+     def train(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+         """Apply all active constraints during training.
+
+         Computes the original loss gradients for layers that affect the loss,
+         evaluates each constraint, logs the Constraint Satisfaction Rate (CSR),
+         and adjusts the loss according to constraint satisfaction.
+
+         Args:
+             data: Dictionary containing input and prediction tensors for the batch.
+             loss: The original loss tensor computed from the network output.
+
+         Returns:
+             Tensor: The loss tensor after applying constraint-based adjustments.
+         """
+         return self._apply_constraints(data, loss, phase="train", enforce=True)
+
+     def validate(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+         """Evaluate constraints during validation without modifying the loss.
+
+         Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+         but does not apply rescale adjustments to the loss.
+
+         Args:
+             data: Dictionary containing input and prediction tensors for the batch.
+             loss: The original loss tensor computed from the network output.
+
+         Returns:
+             Tensor: The original loss tensor, unchanged.
+         """
+         return self._apply_constraints(data, loss, phase="valid", enforce=False)
+
+     def test(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+         """Evaluate constraints during testing without modifying the loss.
+
+         Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+         but does not apply rescale adjustments to the loss.
+
+         Args:
+             data: Dictionary containing input and prediction tensors for the batch.
+             loss: The original loss tensor computed from the network output.
+
+         Returns:
+             Tensor: The original loss tensor, unchanged.
+         """
+         return self._apply_constraints(data, loss, phase="test", enforce=False)
+
+     def _apply_constraints(
+         self, data: dict[str, Tensor], loss: Tensor, phase: str, enforce: bool
+     ) -> Tensor:
+         """Evaluate constraints, log CSR, and optionally adjust the loss.
+
+         During training, computes loss gradients for variable layers that affect the loss.
+         Iterates over all constraints, logging the Constraint Satisfaction Rate (CSR)
+         and, if enforcement is enabled, adjusts the loss using constraint-specific
+         directions and rescale factors.
+
+         Args:
+             data: Dictionary containing input and prediction tensors for the batch.
+             loss: Original loss tensor computed from the network output.
+             phase: Current phase, one of "train", "valid", or "test".
+             enforce: If True, constraint-based adjustments are applied to the loss.
+
+         Returns:
+             Tensor: The combined loss after applying constraints (or the original loss
+                 if enforce is False or not in training phase).
+         """
+         total_rescale_loss = torch.tensor(0.0, device=self.device, dtype=loss.dtype)
+
+         if phase == "train":
+             norm_loss_grad = self._calculate_loss_gradients(loss, data)
+             norm_loss_grad = self._override_loss_gradients(norm_loss_grad, loss, data)
+
+         # Iterate constraints
+         for constraint in self.constraints:
+             checks, mask = constraint.check_constraint(data)
+             directions = constraint.calculate_direction(data)
+
+             # Log CSR
+             csr = (torch.sum(checks * mask) / torch.sum(mask)).unsqueeze(0)
+             self.metric_manager.accumulate(f"{constraint.name}/{phase}", csr)
+             self.metric_manager.accumulate(f"CSR/{phase}", csr)
+
+             # Skip adjustment if not enforcing
+             if not enforce or not constraint.enforce or not self.enforce_all or phase != "train":
+                 continue
+
+             # Compute constraint-based rescale loss
+             for key in constraint.layers & self.descriptor.variable_keys:
+                 with no_grad():
+                     rescale = (1 - checks) * directions[key] * constraint.rescale_factor
+                 total_rescale_loss += self.aggregator(data[key] * rescale * norm_loss_grad[key])
+
+         return loss + total_rescale_loss
+
+     def _calculate_loss_gradients(self, loss: Tensor, data: dict[str, Tensor]) -> dict[str, Tensor]:
+         """Compute and return per-layer loss gradient norms for variable layers.
+
+         For each variable layer that affects the loss, computes the gradient of the
+         loss with respect to that layer's output and returns its per-sample L2 norm,
+         clamped below by a small epsilon to avoid division by zero.
+
+         Args:
+             loss: The original loss tensor computed from the network output.
+             data: Dictionary containing input and prediction tensors for the batch.
+         """
+         # Precompute gradients for variable layers affecting the loss
+         norm_loss_grad = {}
+
+         variable_keys = self.descriptor.variable_keys & self.descriptor.affects_loss_keys
+         for key in variable_keys:
+             if data[key].requires_grad is False:
+                 raise RuntimeError(
+                     f"Layer '{key}' does not require gradients. Is this an input? "
+                     "Set constant=True in Descriptor if this layer is an input."
+                 )
+
+             grad = torch.autograd.grad(
+                 outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
+             )[0]
+
+             if grad is None:
+                 raise RuntimeError(
+                     f"Unable to compute loss gradients for layer '{key}'. "
+                     "Set has_loss=False in Descriptor if this layer does not affect loss."
+                 )
+
+             grad_flat = grad.view(grad.shape[0], -1)
+             norm_loss_grad[key] = (
+                 vector_norm(grad_flat, dim=1, ord=2, keepdim=True).clamp(min=self.epsilon).detach()
+             )
+
+         return norm_loss_grad
+
+     def _override_loss_gradients(
+         self, norm_loss_grad: dict[str, Tensor], loss: Tensor, data: dict[str, Tensor]
+     ) -> dict[str, Tensor]:
+         """Override the standard normalized loss gradient computation for custom functionality.
+
+         Args:
+             norm_loss_grad: Dictionary mapping parameter or component names to normalized
+                 gradient tensors.
+             loss: Scalar loss value for the current training step.
+             data: Dictionary containing the batch data used to compute the loss and any
+                 constraint-related signals.
+
+         Returns:
+             A dictionary of modified normalized gradients.
+         """
+         return norm_loss_grad
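Note: a toy numeric illustration of the rescale term assembled in _apply_constraints, not part of the package diff. Values and shapes are invented; only the arithmetic mirrors the method above: the direction and rescale factor are combined under no_grad, so gradients flow only through the layer output.

    # Toy illustration of the per-layer rescale term (invented values).
    import torch

    checks = torch.tensor([[1.0], [0.0]])      # second sample violates the constraint
    direction = torch.tensor([[0.0], [-1.0]])  # push the violating output downward
    grad_norm = torch.tensor([[0.5], [0.5]])   # clamped per-sample loss-gradient norm
    output = torch.tensor([[0.2], [0.9]], requires_grad=True)
    rescale_factor = 2.0

    with torch.no_grad():
        rescale = (1 - checks) * direction * rescale_factor
    rescale_loss = torch.sum(output * rescale * grad_norm)
    # Only the violating sample contributes: 0.9 * (-2.0) * 0.5 = -0.9
    print(rescale_loss)  # tensor(-0.9000, grad_fn=<SumBackward0>)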