congrads-0.1.0-py3-none-any.whl → congrads-0.3.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -20
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +147 -43
- congrads/metrics.py +116 -41
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -507
- congrads/core.py +0 -211
- congrads/datasets.py +0 -742
- congrads/learners.py +0 -233
- congrads/networks.py +0 -91
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- congrads-0.1.0.dist-info/WHEEL +0 -5
- congrads-0.1.0.dist-info/top_level.txt +0 -1
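The layout change above splits the flat 0.1.0 modules (constraints.py, core.py, datasets.py, networks.py) into sub-packages. For orientation, the imports below use paths inferred from the file list and from the relative imports in the hunks that follow; whether congrads/__init__.py also re-exports these names at the package root is not visible in this diff.

# 0.3.0 package layout (paths inferred from the file list above):
from congrads.constraints.base import Constraint      # constraints.py was split into constraints/
from congrads.core.congradscore import CongradsCore   # core.py was split into core/
from congrads.callbacks.base import CallbackManager   # callbacks/ is new in 0.3.0
from congrads.metrics import MetricManager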
congrads/core/batch_runner.py
@@ -0,0 +1,200 @@
+"""Defines the BatchRunner, which executes individual batches for training, validation, and testing.
+
+Responsibilities:
+- Move batch data to the appropriate device
+- Run forward passes through the network
+- Compute base and constraint-adjusted losses
+- Perform backpropagation during training
+- Accumulate metrics for loss and other monitored quantities
+- Trigger callbacks at key points in the batch lifecycle
+"""
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torch.optim import Optimizer
+
+from ..callbacks.base import CallbackManager
+from ..core.constraint_engine import ConstraintEngine
+from ..metrics import MetricManager
+
+
+class BatchRunner:
+    """Executes a single batch for training, validation, or testing.
+
+    The BatchRunner handles moving data to the correct device, running the network
+    forward, computing base and constraint-adjusted losses, performing backpropagation
+    during training, accumulating metrics, and dispatching callbacks at key points
+    in the batch lifecycle.
+    """
+
+    def __init__(
+        self,
+        network: Module,
+        criterion,
+        optimizer: Optimizer,
+        constraint_engine: ConstraintEngine,
+        metric_manager: MetricManager | None,
+        callback_manager: CallbackManager | None,
+        device: torch.device,
+    ):
+        """Initialize the BatchRunner.
+
+        Args:
+            network: The neural network module to execute.
+            criterion: Loss function callable accepting (output, target, data=batch).
+            optimizer: Optimizer for updating network parameters.
+            constraint_engine: ConstraintEngine instance for evaluating and enforcing constraints.
+            metric_manager: Optional MetricManager for logging batch metrics.
+            callback_manager: Optional CallbackManager for triggering hooks during batch processing.
+            device: Torch device on which to place data and network.
+        """
+        self.network = network
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.constraint_engine = constraint_engine
+        self.metric_manager = metric_manager
+        self.callback_manager = callback_manager
+        self.device = device
+
+    def _to_device(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Move all tensors in the batch to the BatchRunner's device.
+
+        Args:
+            batch: Dictionary of tensors for a single batch.
+
+        Returns:
+            Dictionary of tensors moved to the target device.
+        """
+        return {k: v.to(self.device) for k, v in batch.items()}
+
+    def _run_callbacks(self, hook: str, data: dict) -> dict:
+        """Run the specified callback hook on the batch data.
+
+        Args:
+            hook: Name of the callback hook to run.
+            data: Dictionary containing batch data.
+
+        Returns:
+            Potentially modified batch data after callback execution.
+        """
+        if self.callback_manager is None:
+            return data
+        return self.callback_manager.run(hook, data)
+
+    def train_batch(self, batch: dict[str, Tensor]) -> Tensor:
+        """Run a single training batch.
+
+        Steps performed:
+        1. Move batch to device and run "on_train_batch_start" callback.
+        2. Forward pass through the network.
+        3. Compute base loss using the criterion and accumulate metric.
+        4. Apply constraint-based adjustments to the loss.
+        5. Perform backward pass and optimizer step.
+        6. Run "on_train_batch_end" callback.
+
+        Args:
+            batch: Dictionary of input and target tensors for the batch.
+
+        Returns:
+            Tensor: The base loss computed before constraint adjustments.
+        """
+        batch = self._to_device(batch)
+        batch = self._run_callbacks("on_train_batch_start", batch)
+
+        # Forward
+        batch = self.network(batch)
+        batch = self._run_callbacks("after_train_forward", batch)
+
+        # Base loss
+        loss: Tensor = self.criterion(
+            batch["output"],
+            batch["target"],
+            data=batch,
+        )
+
+        if self.metric_manager is not None:
+            self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+        # Constraint-adjusted loss
+        combined_loss = self.constraint_engine.train(batch, loss)
+
+        # Backward
+        self.optimizer.zero_grad()
+        combined_loss.backward()
+        self.optimizer.step()
+
+        batch = self._run_callbacks("on_train_batch_end", batch)
+        return loss
+
+    def valid_batch(self, batch: dict[str, Tensor]) -> Tensor:
+        """Run a single validation batch.
+
+        Steps performed:
+        1. Move batch to device and run "on_valid_batch_start" callback.
+        2. Forward pass through the network.
+        3. Compute base loss using the criterion and accumulate metric.
+        4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+        5. Run "on_valid_batch_end" callback.
+
+        Args:
+            batch: Dictionary of input and target tensors for the batch.
+
+        Returns:
+            Tensor: The base loss computed for the batch.
+        """
+        batch = self._to_device(batch)
+        batch = self._run_callbacks("on_valid_batch_start", batch)
+
+        batch = self.network(batch)
+        batch = self._run_callbacks("after_valid_forward", batch)
+
+        loss: Tensor = self.criterion(
+            batch["output"],
+            batch["target"],
+            data=batch,
+        )
+
+        if self.metric_manager is not None:
+            self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+
+        self.constraint_engine.validate(batch, loss)
+
+        batch = self._run_callbacks("on_valid_batch_end", batch)
+        return loss
+
+    def test_batch(self, batch: dict[str, Tensor]) -> Tensor:
+        """Run a single test batch.
+
+        Steps performed:
+        1. Move batch to device and run "on_test_batch_start" callback.
+        2. Forward pass through the network.
+        3. Compute base loss using the criterion and accumulate metric.
+        4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+        5. Run "on_test_batch_end" callback.
+
+        Args:
+            batch: Dictionary of input and target tensors for the batch.
+
+        Returns:
+            Tensor: The base loss computed for the batch.
+        """
+        batch = self._to_device(batch)
+        batch = self._run_callbacks("on_test_batch_start", batch)
+
+        batch = self.network(batch)
+        batch = self._run_callbacks("after_test_forward", batch)
+
+        loss: Tensor = self.criterion(
+            batch["output"],
+            batch["target"],
+            data=batch,
+        )
+
+        if self.metric_manager is not None:
+            self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
+
+        self.constraint_engine.test(batch, loss)
+
+        batch = self._run_callbacks("on_test_batch_end", batch)
+        return loss
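For orientation, the sketch below shows the dict-in/dict-out convention BatchRunner relies on: the network receives the whole batch dict and must populate batch["output"], and the criterion is called as criterion(output, target, data=batch). The DictNet module, its layer sizes, and the plain-function criterion are hypothetical stand-ins; in normal use these objects are wired together by CongradsCore (next hunk) rather than by hand.

import torch
from torch import nn
from torch.nn.functional import mse_loss


class DictNet(nn.Module):
    # Hypothetical dict-in/dict-out module: BatchRunner calls network(batch)
    # on the whole dict and afterwards reads batch["output"] and batch["target"].
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        batch["output"] = self.linear(batch["input"])
        return batch


def criterion(output, target, data=None):
    # Matches the (output, target, data=batch) call signature used by the runner.
    return mse_loss(output, target)


# Equivalent of BatchRunner.train_batch with no constraints or callbacks:
net = DictNet()
opt = torch.optim.SGD(net.parameters(), lr=0.1)
batch = {"input": torch.randn(8, 4), "target": torch.randn(8, 2)}

batch = net(batch)
loss = criterion(batch["output"], batch["target"], data=batch)
opt.zero_grad()
loss.backward()
opt.step()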
congrads/core/congradscore.py
@@ -0,0 +1,271 @@
+"""This module provides the core CongradsCore class for the main training functionality.
+
+It is designed to integrate constraint-guided optimization into neural network training.
+It extends traditional training processes by enforcing specific constraints
+on the model's outputs, ensuring that the network satisfies domain-specific
+requirements during both training and evaluation.
+
+The `CongradsCore` class serves as the central engine for managing the
+training, validation, and testing phases of a neural network model,
+incorporating constraints that influence the loss function and model updates.
+The model is trained with standard loss functions while also incorporating
+constraint-based adjustments, which are tracked and logged
+throughout the process.
+
+Key features:
+- Support for various constraints that can influence the training process.
+- Integration with PyTorch's `DataLoader` for efficient batch processing.
+- Metric management for tracking loss and constraint satisfaction.
+- Checkpoint management for saving and evaluating model states.
+"""
+
+from collections.abc import Callable
+
+import torch
+from torch import Tensor, sum
+from torch.nn import Module
+from torch.nn.modules.loss import _Loss
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from ..callbacks.base import CallbackManager
+from ..checkpoints import CheckpointManager
+from ..constraints.base import Constraint
+from ..core.batch_runner import BatchRunner
+from ..core.constraint_engine import ConstraintEngine
+from ..core.epoch_runner import EpochRunner
+from ..descriptor import Descriptor
+from ..metrics import MetricManager
+from ..utils.utility import LossWrapper
+
+
+class CongradsCore:
+    """The CongradsCore class is the central training engine for constraint-guided optimization.
+
+    It integrates standard neural network training
+    with additional constraint-driven adjustments to the loss function, ensuring
+    that the network satisfies domain-specific constraints during training.
+    """
+
+    def __init__(
+        self,
+        descriptor: Descriptor,
+        constraints: list[Constraint],
+        network: Module,
+        criterion: _Loss,
+        optimizer: Optimizer,
+        device: torch.device,
+        dataloader_train: DataLoader,
+        dataloader_valid: DataLoader | None = None,
+        dataloader_test: DataLoader | None = None,
+        metric_manager: MetricManager | None = None,
+        callback_manager: CallbackManager | None = None,
+        checkpoint_manager: CheckpointManager | None = None,
+        network_uses_grad: bool = False,
+        epsilon: float = 1e-6,
+        constraint_aggregator: Callable[..., Tensor] = sum,
+        enforce_all: bool = True,
+        disable_progress_bar_epoch: bool = False,
+        disable_progress_bar_batch: bool = False,
+        epoch_runner_cls: type["EpochRunner"] | None = None,
+        batch_runner_cls: type["BatchRunner"] | None = None,
+        constraint_engine_cls: type["ConstraintEngine"] | None = None,
+    ):
+        """Initialize the CongradsCore object.
+
+        Args:
+            descriptor (Descriptor): Describes variable layers in the network.
+            constraints (list[Constraint]): List of constraints to guide training.
+            network (Module): The neural network model to train.
+            criterion (_Loss): The loss function used for training and validation.
+            optimizer (Optimizer): The optimizer used for updating model parameters.
+            device (torch.device): The device (e.g., CPU or GPU) for computations.
+            dataloader_train (DataLoader): DataLoader for training data.
+            dataloader_valid (DataLoader, optional): DataLoader for validation data.
+                If not provided, validation is skipped.
+            dataloader_test (DataLoader, optional): DataLoader for test data.
+                If not provided, testing is skipped.
+            metric_manager (MetricManager, optional): Manages metric tracking and recording.
+            callback_manager (CallbackManager, optional): Manages training callbacks.
+            checkpoint_manager (CheckpointManager, optional): Manages checkpointing.
+                If not set, no checkpointing is done.
+            network_uses_grad (bool, optional): A flag indicating if the network
+                contains gradient calculation computations. Default is False.
+            epsilon (float, optional): A small value to avoid division by zero
+                in gradient calculations. Default is 1e-6.
+            constraint_aggregator (Callable[..., Tensor], optional): A function
+                to aggregate the constraint rescale loss. Default is `sum`.
+            enforce_all (bool, optional): If set to False, constraints will only be monitored and
+                not influence the training process. Overrides constraint-specific `enforce` parameters.
+                Defaults to True.
+            disable_progress_bar_epoch (bool, optional): If set to True, the epoch
+                progress bar will not show. Defaults to False.
+            disable_progress_bar_batch (bool, optional): If set to True, the batch
+                progress bar will not show. Defaults to False.
+            epoch_runner_cls (type[EpochRunner], optional): Custom EpochRunner class.
+                If not provided, the default EpochRunner is used.
+            batch_runner_cls (type[BatchRunner], optional): Custom BatchRunner class.
+                If not provided, the default BatchRunner is used.
+            constraint_engine_cls (type[ConstraintEngine], optional): Custom ConstraintEngine class.
+                If not provided, the default ConstraintEngine is used.
+
+        Note:
+            A warning is logged if the descriptor has no variable layers,
+            as at least one variable layer is required for the constraint logic
+            to influence the training process.
+        """
+        # Init object variables
+        self.device = device
+        self.network = network.to(device)
+        self.criterion = LossWrapper(criterion)
+        self.optimizer = optimizer
+        self.descriptor = descriptor
+
+        self.constraints = constraints or []
+        self.epsilon = epsilon
+        self.aggregator = constraint_aggregator
+        self.enforce_all = enforce_all
+
+        self.metric_manager = metric_manager
+        self.callback_manager = callback_manager
+        self.checkpoint_manager = checkpoint_manager
+
+        self.disable_progress_bar_epoch = disable_progress_bar_epoch
+        self.disable_progress_bar_batch = disable_progress_bar_batch
+
+        # Initialize constraint engine
+        self.constraint_engine = (constraint_engine_cls or ConstraintEngine)(
+            constraints=self.constraints,
+            descriptor=self.descriptor,
+            device=self.device,
+            epsilon=self.epsilon,
+            aggregator=self.aggregator,
+            enforce_all=self.enforce_all,
+            metric_manager=self.metric_manager,
+        )
+
+        # Initialize runners
+        self.batch_runner = (batch_runner_cls or BatchRunner)(
+            network=self.network,
+            criterion=self.criterion,
+            optimizer=self.optimizer,
+            constraint_engine=self.constraint_engine,
+            metric_manager=self.metric_manager,
+            callback_manager=self.callback_manager,
+            device=self.device,
+        )
+
+        self.epoch_runner = (epoch_runner_cls or EpochRunner)(
+            batch_runner=self.batch_runner,
+            network=self.network,
+            train_loader=dataloader_train,
+            valid_loader=dataloader_valid,
+            test_loader=dataloader_test,
+            network_uses_grad=network_uses_grad,
+            disable_progress_bar=self.disable_progress_bar_batch,
+        )
+
+        # Initialize constraint metrics
+        if self.metric_manager is not None:
+            self._initialize_metrics()
+
+    def _initialize_metrics(self) -> None:
+        """Register metrics for loss, constraint satisfaction ratio (CSR), and constraints.
+
+        This method registers the following metrics:
+
+        - Loss/train: Training loss.
+        - Loss/valid: Validation loss.
+        - Loss/test: Test loss after training.
+        - CSR/train: Constraint satisfaction ratio during training.
+        - CSR/valid: Constraint satisfaction ratio during validation.
+        - CSR/test: Constraint satisfaction ratio after training.
+        - One metric per constraint, for training, validation, and testing.
+        """
+        self.metric_manager.register("Loss/train", "during_training")
+        self.metric_manager.register("Loss/valid", "during_training")
+        self.metric_manager.register("Loss/test", "after_training")
+
+        if len(self.constraints) > 0:
+            self.metric_manager.register("CSR/train", "during_training")
+            self.metric_manager.register("CSR/valid", "during_training")
+            self.metric_manager.register("CSR/test", "after_training")
+
+        for constraint in self.constraints:
+            self.metric_manager.register(f"{constraint.name}/train", "during_training")
+            self.metric_manager.register(f"{constraint.name}/valid", "during_training")
+            self.metric_manager.register(f"{constraint.name}/test", "after_training")
+
+    def fit(
+        self,
+        *,
+        start_epoch: int = 0,
+        max_epochs: int = 100,
+        test_model: bool = True,
+        final_checkpoint_name: str = "checkpoint_final.pth",
+    ) -> None:
+        """Run the full training loop, including optional validation, testing, and checkpointing.
+
+        This method performs training over multiple epochs with the following steps:
+        1. Trigger "on_train_start" callbacks if a callback manager is present.
+        2. For each epoch:
+            - Trigger "on_epoch_start" callbacks.
+            - Run a training epoch via the EpochRunner.
+            - Run a validation epoch via the EpochRunner.
+            - Evaluate checkpoint criteria if a checkpoint manager is present.
+            - Trigger "on_epoch_end" callbacks.
+        3. Trigger "on_train_end" callbacks after all epochs.
+        4. Optionally run a test epoch via the EpochRunner if `test_model` is True,
+           with corresponding "on_test_start" and "on_test_end" callbacks.
+        5. Save a final checkpoint using the checkpoint manager.
+
+        Args:
+            start_epoch: Index of the starting epoch (default 0). Useful for resuming training.
+            max_epochs: Maximum number of epochs to run (default 100).
+            test_model: Whether to run a test epoch after training (default True).
+            final_checkpoint_name: Filename for the final checkpoint saved at the end of training
+                (default "checkpoint_final.pth").
+
+        Returns:
+            None
+        """
+        if self.callback_manager:
+            self.callback_manager.run("on_train_start", {"epoch": start_epoch})
+
+        for epoch in tqdm(
+            range(start_epoch, max_epochs),
+            initial=start_epoch,
+            desc="Epoch",
+            disable=self.disable_progress_bar_epoch,
+        ):
+            if self.callback_manager:
+                self.callback_manager.run("on_epoch_start", {"epoch": epoch})
+
+            self.epoch_runner.train()
+            self.epoch_runner.validate()
+
+            if self.checkpoint_manager:
+                self.checkpoint_manager.evaluate_criteria(epoch)
+
+            if self.callback_manager:
+                self.callback_manager.run("on_epoch_end", {"epoch": epoch})
+
+        if self.callback_manager:
+            self.callback_manager.run("on_train_end", {"epoch": epoch})
+
+        if test_model:
+            if self.callback_manager:
+                self.callback_manager.run("on_test_start", {"epoch": epoch})
+
+            self.epoch_runner.test()
+
+            if self.callback_manager:
+                self.callback_manager.run("on_test_end", {"epoch": epoch})
+
+        if self.checkpoint_manager:
+            self.checkpoint_manager.save(epoch, final_checkpoint_name)
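A minimal end-to-end sketch of driving `fit`. Only parameters shown in this hunk are used; `DictDataset` is hypothetical, and the `Descriptor` construction is elided because its API lives in congrads/descriptor.py, which this diff does not show, so treat this as a template rather than a runnable script.

import torch
from torch.utils.data import DataLoader, Dataset

from congrads.core.congradscore import CongradsCore


class DictDataset(Dataset):
    # Hypothetical dataset yielding the dict batches the runners expect;
    # the default collate function stacks each key into a batched tensor.
    def __init__(self, n: int = 64) -> None:
        self.x, self.y = torch.randn(n, 4), torch.randn(n, 2)

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, i: int) -> dict[str, torch.Tensor]:
        return {"input": self.x[i], "target": self.y[i]}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = DictNet()  # dict-in/dict-out module from the previous sketch

core = CongradsCore(
    descriptor=descriptor,  # Descriptor built elsewhere; API not shown in this diff
    constraints=[],         # Constraint objects from congrads.constraints
    network=net,
    criterion=torch.nn.MSELoss(),  # an _Loss; wrapped in LossWrapper internally
    optimizer=torch.optim.Adam(net.parameters()),
    device=device,
    dataloader_train=DataLoader(DictDataset(), batch_size=8),
)
core.fit(max_epochs=10, test_model=False)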
congrads/core/constraint_engine.py
@@ -0,0 +1,209 @@
+"""Manages the evaluation and optional enforcement of constraints on neural network outputs.
+
+Responsibilities:
+- Compute and log Constraint Satisfaction Rate (CSR) for training, validation, and test batches.
+- Optionally adjust loss during training based on constraint directions and rescale factors.
+- Handle gradient computation and CGGD application.
+"""
+
+from collections.abc import Callable
+
+import torch
+from torch import Tensor, no_grad
+from torch.linalg import vector_norm
+
+from ..constraints.base import Constraint
+from ..descriptor import Descriptor
+from ..metrics import MetricManager
+
+
+class ConstraintEngine:
+    """Manages constraint evaluation and enforcement for a neural network.
+
+    The ConstraintEngine coordinates constraints defined in Constraint objects,
+    computes gradients for layers that affect the loss, logs metrics, and optionally
+    modifies the loss during training according to the constraints. It supports
+    separate phases for training, validation, and testing.
+    """
+
+    def __init__(
+        self,
+        *,
+        constraints: list[Constraint],
+        descriptor: Descriptor,
+        metric_manager: MetricManager | None,
+        device: torch.device,
+        epsilon: float,
+        aggregator: Callable[..., Tensor],
+        enforce_all: bool,
+    ) -> None:
+        """Initialize the ConstraintEngine.
+
+        Args:
+            constraints: List of Constraint objects to evaluate and optionally enforce.
+            descriptor: Descriptor containing metadata about network layers and which
+                variables affect the loss.
+            metric_manager: Optional MetricManager instance for logging CSR metrics.
+            device: Torch device where tensors will be allocated (CPU or GPU).
+            epsilon: Small positive value to avoid division by zero in gradient norms.
+            aggregator: Callable used to reduce per-layer constraint contributions
+                to a scalar loss adjustment.
+            enforce_all: Whether to enforce all constraints during training.
+        """
+        self.constraints = constraints
+        self.descriptor = descriptor
+        self.metric_manager = metric_manager
+        self.device = device
+        self.epsilon = epsilon
+        self.enforce_all = enforce_all
+        self.aggregator = aggregator
+
+        self.norm_loss_grad: dict[str, Tensor] = {}
+
+    def train(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Apply all active constraints during training.
+
+        Computes the original loss gradients for layers that affect the loss,
+        evaluates each constraint, logs the Constraint Satisfaction Rate (CSR),
+        and adjusts the loss according to constraint satisfaction.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The loss tensor after applying constraint-based adjustments.
+        """
+        return self._apply_constraints(data, loss, phase="train", enforce=True)
+
+    def validate(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Evaluate constraints during validation without modifying the loss.
+
+        Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+        but does not apply rescale adjustments to the loss.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The original loss tensor, unchanged.
+        """
+        return self._apply_constraints(data, loss, phase="valid", enforce=False)
+
+    def test(self, data: dict[str, Tensor], loss: Tensor) -> Tensor:
+        """Evaluate constraints during testing without modifying the loss.
+
+        Computes and logs the Constraint Satisfaction Rate (CSR) for each constraint,
+        but does not apply rescale adjustments to the loss.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: The original loss tensor computed from the network output.
+
+        Returns:
+            Tensor: The original loss tensor, unchanged.
+        """
+        return self._apply_constraints(data, loss, phase="test", enforce=False)
+
+    def _apply_constraints(
+        self, data: dict[str, Tensor], loss: Tensor, phase: str, enforce: bool
+    ) -> Tensor:
+        """Evaluate constraints, log CSR, and optionally adjust the loss.
+
+        During training, computes loss gradients for variable layers that affect the loss.
+        Iterates over all constraints, logging the Constraint Satisfaction Rate (CSR)
+        and, if enforcement is enabled, adjusts the loss using constraint-specific
+        directions and rescale factors.
+
+        Args:
+            data: Dictionary containing input and prediction tensors for the batch.
+            loss: Original loss tensor computed from the network output.
+            phase: Current phase, one of "train", "valid", or "test".
+            enforce: If True, constraint-based adjustments are applied to the loss.
+
+        Returns:
+            Tensor: The combined loss after applying constraints (or the original loss
+            if enforce is False or not in the training phase).
+        """
+        total_rescale_loss = torch.tensor(0.0, device=self.device, dtype=loss.dtype)
+
+        if phase == "train":
+            norm_loss_grad = self._calculate_loss_gradients(loss, data)
+            norm_loss_grad = self._override_loss_gradients(norm_loss_grad, loss, data)
+
+        # Iterate constraints
+        for constraint in self.constraints:
+            checks, mask = constraint.check_constraint(data)
+            directions = constraint.calculate_direction(data)
+
+            # Log CSR
+            csr = (torch.sum(checks * mask) / torch.sum(mask)).unsqueeze(0)
+            if self.metric_manager is not None:
+                self.metric_manager.accumulate(f"{constraint.name}/{phase}", csr)
+                self.metric_manager.accumulate(f"CSR/{phase}", csr)
+
+            # Skip adjustment if not enforcing
+            if not enforce or not constraint.enforce or not self.enforce_all or phase != "train":
+                continue
+
+            # Compute constraint-based rescale loss
+            for key in constraint.layers & self.descriptor.variable_keys:
+                with no_grad():
+                    rescale = (1 - checks) * directions[key] * constraint.rescale_factor
+                total_rescale_loss += self.aggregator(data[key] * rescale * norm_loss_grad[key])
+
+        return loss + total_rescale_loss
+
+    def _calculate_loss_gradients(
+        self, loss: Tensor, data: dict[str, Tensor]
+    ) -> dict[str, Tensor]:
+        """Compute and return per-sample loss-gradient norms for variable layers.
+
+        For each layer that affects the loss, computes the gradient of the loss
+        with respect to that layer's output. The per-sample gradient norms are
+        clamped from below by a small epsilon to avoid division by zero.
+
+        Args:
+            loss: The original loss tensor computed from the network output.
+            data: Dictionary containing input and prediction tensors for the batch.
+
+        Returns:
+            Dictionary mapping layer keys to detached, clamped gradient-norm tensors.
+        """
+        # Precompute gradients for variable layers affecting the loss
+        norm_loss_grad = {}
+
+        variable_keys = self.descriptor.variable_keys & self.descriptor.affects_loss_keys
+        for key in variable_keys:
+            if data[key].requires_grad is False:
+                raise RuntimeError(
+                    f"Layer '{key}' does not require gradients. Is this an input? "
+                    "Set constant=True in Descriptor if this layer is an input."
+                )
+
+            grad = torch.autograd.grad(
+                outputs=loss, inputs=data[key], retain_graph=True, allow_unused=True
+            )[0]
+
+            if grad is None:
+                raise RuntimeError(
+                    f"Unable to compute loss gradients for layer '{key}'. "
+                    "Set has_loss=False in Descriptor if this layer does not affect loss."
+                )
+
+            grad_flat = grad.view(grad.shape[0], -1)
+            norm_loss_grad[key] = (
+                vector_norm(grad_flat, dim=1, ord=2, keepdim=True).clamp(min=self.epsilon).detach()
+            )
+
+        return norm_loss_grad
+
+    def _override_loss_gradients(
+        self, norm_loss_grad: dict[str, Tensor], loss: Tensor, data: dict[str, Tensor]
+    ) -> dict[str, Tensor]:
+        """Override the standard normalized loss gradient computation for custom functionality.
+
+        Args:
+            norm_loss_grad: Dictionary mapping parameter or component names to normalized
+                gradient tensors.
+            loss: Scalar loss value for the current training step.
+            data: Dictionary containing the batch data used to compute the loss and any
+                constraint-related signals.
+
+        Returns:
+            A dictionary of modified normalized gradients.
+        """
+        return norm_loss_grad
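`_override_loss_gradients` is a deliberate no-op extension hook, and CongradsCore's `constraint_engine_cls` parameter injects a replacement engine class. Below is a minimal sketch of using the two together; the clamping threshold is illustrative only, not a value recommended by the package.

from torch import Tensor

from congrads.core.constraint_engine import ConstraintEngine


class ClampedConstraintEngine(ConstraintEngine):
    # Hypothetical subclass: caps the per-layer gradient-norm scale so a single
    # batch cannot produce an outsized constraint rescale adjustment.
    def _override_loss_gradients(
        self, norm_loss_grad: dict[str, Tensor], loss: Tensor, data: dict[str, Tensor]
    ) -> dict[str, Tensor]:
        return {k: v.clamp(max=10.0) for k, v in norm_loss_grad.items()}


# Injected at construction time:
# core = CongradsCore(..., constraint_engine_cls=ClampedConstraintEngine)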