congrads-1.1.1-py3-none-any.whl → congrads-1.2.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- congrads/__init__.py +0 -17
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +1 -1
- congrads/constraints/base.py +174 -0
- congrads/{constraints.py → constraints/registry.py} +120 -158
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +170 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/descriptor.py +1 -1
- congrads/metrics.py +1 -1
- congrads/transformations/base.py +37 -0
- congrads/{transformations.py → transformations/registry.py} +3 -33
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +194 -0
- {congrads-1.1.1.dist-info → congrads-1.2.0.dist-info}/METADATA +2 -2
- congrads-1.2.0.dist-info/RECORD +23 -0
- congrads/core.py +0 -773
- congrads/utils.py +0 -1078
- congrads-1.1.1.dist-info/RECORD +0 -14
- congrads/{datasets.py → datasets/registry.py} +0 -0
- congrads/{networks.py → networks/registry.py} +0 -0
- {congrads-1.1.1.dist-info → congrads-1.2.0.dist-info}/WHEEL +0 -0
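
Since `congrads/__init__.py` drops all 17 of its lines and several flat modules become subpackages, downstream import paths will likely change. A hypothetical migration sketch follows; the 1.1.1 paths are assumptions based on the old flat layout, not something this diff confirms:

```python
# 1.1.1 (flat modules) -- assumed, not confirmed by this diff:
# from congrads.constraints import MonotonicityConstraint

# 1.2.0 (subpackages, per the file list above):
from congrads.constraints.registry import MonotonicityConstraint
from congrads.core.batch_runner import BatchRunner
```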
congrads/{constraints.py → constraints/registry.py}

@@ -24,10 +24,6 @@ Usage:
 
 """
 
-import random
-import string
-import warnings
-from abc import ABC, abstractmethod
 from collections.abc import Callable
 from numbers import Number
 from typing import Literal
@@ -54,156 +50,10 @@ from torch import (
 )
 from torch.nn.functional import normalize
 
-from .
-from .
-from .
-
-
-class Constraint(ABC):
-    """Abstract base class for defining constraints applied to neural networks.
-
-    A `Constraint` specifies conditions that the neural network outputs
-    should satisfy. It supports monitoring constraint satisfaction
-    during training and can adjust loss to enforce constraints. Subclasses
-    must implement the `check_constraint` and `calculate_direction` methods.
-
-    Args:
-        tags (set[str]): Tags referencing parts of the network where this constraint applies to.
-        name (str, optional): A unique name for the constraint. If not provided,
-            a name is generated based on the class name and a random suffix.
-        enforce (bool, optional): If False, only monitor the constraint
-            without adjusting the loss. Defaults to True.
-        rescale_factor (Number, optional): Factor to scale the
-            constraint-adjusted loss. Defaults to 1.5. Should be greater
-            than 1 to give weight to the constraint.
-
-    Raises:
-        TypeError: If a provided attribute has an incompatible type.
-        ValueError: If any tag in `tags` is not
-            defined in the `descriptor`.
-
-    Note:
-        - If `rescale_factor <= 1`, a warning is issued.
-        - If `name` is not provided, a name is auto-generated,
-          and a warning is logged.
-
-    """
-
-    descriptor: Descriptor = None
-    device = None
-
-    def __init__(
-        self, tags: set[str], name: str = None, enforce: bool = True, rescale_factor: Number = 1.5
-    ) -> None:
-        """Initializes a new Constraint instance.
-
-        Args:
-            tags (set[str]): Tags referencing parts of the network where this constraint applies to.
-            name (str, optional): A unique name for the constraint. If not
-                provided, a name is generated based on the class name and a
-                random suffix.
-            enforce (bool, optional): If False, only monitor the constraint
-                without adjusting the loss. Defaults to True.
-            rescale_factor (Number, optional): Factor to scale the
-                constraint-adjusted loss. Defaults to 1.5. Should be greater
-                than 1 to give weight to the constraint.
-
-        Raises:
-            TypeError: If a provided attribute has an incompatible type.
-            ValueError: If any tag in `tags` is not defined in the `descriptor`.
-
-        Note:
-            - If `rescale_factor <= 1`, a warning is issued.
-            - If `name` is not provided, a name is auto-generated, and a
-              warning is logged.
-        """
-        # Init parent class
-        super().__init__()
-
-        # Type checking
-        validate_iterable("tags", tags, str)
-        validate_type("name", name, str, allow_none=True)
-        validate_type("enforce", enforce, bool)
-        validate_type("rescale_factor", rescale_factor, Number)
-
-        # Init object variables
-        self.tags = tags
-        self.rescale_factor = rescale_factor
-        self.initial_rescale_factor = rescale_factor
-        self.enforce = enforce
-
-        # Perform checks
-        if rescale_factor <= 1:
-            warnings.warn(
-                f"Rescale factor for constraint {name} is <= 1. The network "
-                "will favor general loss over the constraint-adjusted loss. "
-                "Is this intended behavior? Normally, the rescale factor "
-                "should always be larger than 1.",
-                stacklevel=2,
-            )
-
-        # If no constraint_name is set, generate one based
-        # on the class name and a random suffix
-        if name:
-            self.name = name
-        else:
-            random_suffix = "".join(random.choices(string.ascii_uppercase + string.digits, k=6))
-            self.name = f"{self.__class__.__name__}_{random_suffix}"
-            warnings.warn(f"Name for constraint is not set. Using {self.name}.", stacklevel=2)
-
-        # Infer layers from descriptor and tags
-        self.layers = set()
-        for tag in self.tags:
-            if not self.descriptor.exists(tag):
-                raise ValueError(
-                    f"The tag {tag} used with constraint "
-                    f"{self.name} is not defined in the descriptor. Please "
-                    "add it to the correct layer using "
-                    "descriptor.add('layer', ...)."
-                )
-
-            layer, _ = self.descriptor.location(tag)
-            self.layers.add(layer)
-
-    @abstractmethod
-    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
-        """Evaluates whether the given model predictions satisfy the constraint.
-
-        1 IS SATISFIED, 0 IS NOT SATISFIED
-
-        Args:
-            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
-
-        Returns:
-            tuple[Tensor, Tensor]: A tuple where the first element is a tensor of floats
-                indicating whether the constraint is satisfied (with value 1.0
-                for satisfaction, and 0.0 for non-satisfaction), and the second element is a tensor
-                mask that indicates the relevance of each sample (`True` for relevant
-                samples and `False` for irrelevant ones).
-
-        Raises:
-            NotImplementedError: If not implemented in a subclass.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
-        """Compute adjustment directions to better satisfy the constraint.
-
-        Given the model predictions, input batch, and context, this method calculates the direction
-        in which the predictions referenced by a tag should be adjusted to satisfy the constraint.
-
-        Args:
-            data (dict[str, Tensor]): Dictionary that holds batch data, model predictions and context.
-
-        Returns:
-            dict[str, Tensor]: Dictionary mapping network layers to tensors that
-                specify the adjustment direction for each tag.
-
-        Raises:
-            NotImplementedError: Must be implemented by subclasses.
-        """
-        raise NotImplementedError
+from ..transformations.base import Transformation
+from ..transformations.registry import IdentityTransformation
+from ..utils.validation import validate_comparator_pytorch, validate_iterable, validate_type
+from .base import Constraint
 
 
 class ImplicationConstraint(Constraint):
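The removed block is the old `Constraint` ABC, which this release relocates to `congrads/constraints/base.py` (imported above as `from .base import Constraint`). Going only by the docstrings preserved in this hunk, a custom constraint would look roughly like the sketch below; the tag name, the non-negativity rule, and the `**kwargs` pass-through are invented for illustration, and the sketch assumes a `Descriptor` has been attached to the class and that the relocated base keeps the documented signature:

```python
import torch
from torch import Tensor

from congrads.constraints.base import Constraint  # new location per this diff


class NonNegativityConstraint(Constraint):
    """Toy constraint: activations selected by `tag` should be >= 0."""

    def __init__(self, tag: str, **kwargs):
        super().__init__(tags={tag}, **kwargs)
        self.tag = tag

    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        preds = self.descriptor.select(self.tag, data)
        satisfied = (preds >= 0).float()                      # 1.0 satisfied, 0.0 violated
        relevance = torch.ones_like(preds, dtype=torch.bool)  # every sample is relevant here
        return satisfied, relevance

    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
        layer, _ = self.descriptor.location(self.tag)
        preds = self.descriptor.select(self.tag, data)
        # Push violating samples upward; satisfied samples get no adjustment.
        return {layer: (preds < 0).float()}
```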
congrads/{constraints.py → constraints/registry.py} (continued)

@@ -918,11 +768,23 @@ class MonotonicityConstraint(Constraint):
         return {layer: self.compared_rankings}
 
 
-class
-    """
+class PerGroupMonotonicityConstraint(MonotonicityConstraint):
+    """Group-wise monotonicity constraint enforced independently per group.
 
-    This constraint
-
+    This constraint enforces a monotonic relationship between a prediction tag
+    (`tag_prediction`) and a reference tag (`tag_reference`) **within each group**
+    identified by `tag_group_identifier`.
+
+    For each unique group identifier, the base `MonotonicityConstraint` is applied
+    independently to the corresponding subset of samples. This makes the behavior
+    semantically explicit and easy to reason about, as no interaction or ordering
+    is introduced across different groups.
+
+    Notes:
+        - Groups are treated as fully independent constraint instances.
+        - This implementation prioritizes correctness and interpretability.
+        - It may be less efficient for large numbers of groups due to explicit
+          iteration and repeated constraint evaluation.
     """
 
     def __init__(
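To make "applied independently to the corresponding subset of samples" concrete, the per-group evaluation amounts to boolean-mask slicing per unique identifier. A toy sketch with invented tensors, not the library's actual loop:

```python
import torch

ids = torch.tensor([0, 0, 1, 1])
preds = torch.tensor([0.1, 0.4, 0.9, 0.2])
targets = torch.tensor([1.0, 2.0, 1.0, 2.0])

for group in ids.unique():
    mask = ids == group
    # The base MonotonicityConstraint would see only preds[mask] and
    # targets[mask], so different groups never interact.
    print(group.item(), preds[mask], targets[mask])
```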
congrads/{constraints.py → constraints/registry.py} (continued)

@@ -1010,6 +872,106 @@ class GroupedMonotonicityConstraint(MonotonicityConstraint):
         return {layer: self.directions}
 
 
+class EncodedGroupedMonotonicityConstraint(MonotonicityConstraint):
+    """Group-wise monotonicity constraint enforced via rank encoding.
+
+    This constraint enforces a monotonic relationship between a prediction tag
+    (`tag_prediction`) and a reference tag (`tag_reference`) within each group
+    identified by `tag_group_identifier`, using a fully vectorized approach.
+
+    Group independence is achieved by encoding the group identifiers into the
+    prediction and reference values via large offsets, effectively separating
+    the rank spaces of different groups. This allows the base
+    `MonotonicityConstraint` to be applied once to the entire batch without
+    explicit per-group iteration.
+
+    Notes:
+        - Groups are isolated implicitly through rank-space separation.
+        - The logic is less explicit than per-group evaluation and relies on
+          offset-based rank encoding for correctness.
+        - This constraint might cause floating point errors if
+          the prediction or target range is very large.
+    """
+
+    def __init__(
+        self,
+        tag_prediction: str,
+        tag_reference: str,
+        tag_group_identifier: str,
+        rescale_factor_lower: float = 1.5,
+        rescale_factor_upper: float = 1.75,
+        stable: bool = True,
+        direction: Literal["ascending", "descending"] = "ascending",
+        name: str = None,
+        enforce: bool = True,
+    ):
+        """Constraint that enforces monotonicity on a predicted output.
+
+        This constraint ensures that the activations of a prediction tag (`tag_prediction`)
+        are monotonically ascending or descending with respect to a target tag (`tag_reference`).
+
+        Args:
+            tag_prediction (str): Name of the tag whose activations should follow the monotonic relationship.
+            tag_reference (str): Name of the tag that acts as the monotonic reference.
+            tag_group_identifier (str): Name of the tag that identifies groups for separate monotonicity enforcement.
+            rescale_factor_lower (float, optional): Lower bound for rescaling rank differences. Defaults to 1.5.
+            rescale_factor_upper (float, optional): Upper bound for rescaling rank differences. Defaults to 1.75.
+            stable (bool, optional): Whether to use stable sorting when ranking. Defaults to True.
+            direction (str, optional): Direction of monotonicity to enforce, either 'ascending' or 'descending'. Defaults to 'ascending'.
+            name (str, optional): Custom name for the constraint. If None, a descriptive name is auto-generated.
+            enforce (bool, optional): If False, the constraint is only monitored (not enforced). Defaults to True.
+        """
+        # Compose constraint name
+        if name is None:
+            name = f"{tag_prediction} for each {tag_group_identifier} monotonically {direction} by {tag_reference}"
+
+        # Init parent class
+        super().__init__(
+            tag_prediction=tag_prediction,
+            tag_reference=tag_reference,
+            rescale_factor_lower=rescale_factor_lower,
+            rescale_factor_upper=rescale_factor_upper,
+            stable=stable,
+            direction=direction,
+            name=name,
+            enforce=enforce,
+        )
+
+        # Init variables
+        self.tag_prediction = tag_prediction
+        self.tag_reference = tag_reference
+        self.tag_group_identifier = tag_group_identifier
+
+        # Initialize negation factor based on direction
+        self.negation = 1 if direction == "ascending" else -1
+
+    def check_constraint(self, data: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
+        """Evaluate whether the monotonicity constraint is satisfied."""
+        # Get data and keys
+        ids = self.descriptor.select(self.tag_group_identifier, data)
+        preds = self.descriptor.select(self.tag_prediction, data)
+        targets = self.descriptor.select(self.tag_reference, data)
+        preds_key, _ = self.descriptor.location(self.tag_prediction)
+        targets_key, _ = self.descriptor.location(self.tag_reference)
+
+        new_preds = preds + ids * (preds.max() - preds.min() + 1)
+        new_targets = targets + self.negation * ids * (targets.max() - targets.min() + 1)
+
+        # Create new batch for child constraint
+        new_data = {preds_key: new_preds, targets_key: new_targets}
+
+        # Call super on the adjusted batch
+        checks, _ = super().check_constraint(new_data)
+        self.directions = self.compared_rankings
+
+        return checks, ones_like(checks)
+
+    def calculate_direction(self, data: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Calculates ranking adjustments for monotonicity enforcement."""
+        layer, _ = self.descriptor.location(self.tag_prediction)
+        return {layer: self.directions}
+
+
 class ANDConstraint(Constraint):
     """A composite constraint that enforces the logical AND of multiple constraints.
 
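The offset encoding in `check_constraint` above can be verified by hand: shifting each group by `group_id * (max - min + 1)` places groups in disjoint value bands, so one global sort ranks samples within a group while never interleaving groups. A small numeric check with invented values:

```python
import torch

ids = torch.tensor([0, 0, 1, 1])
preds = torch.tensor([0.3, 0.7, 0.2, 0.9])

# Same encoding as in the hunk: span = max - min + 1 = 0.9 - 0.2 + 1 = 1.7
new_preds = preds + ids * (preds.max() - preds.min() + 1)
print(new_preds)            # tensor([0.3000, 0.7000, 1.9000, 2.6000])
print(new_preds.argsort())  # group 0 samples rank strictly before group 1
```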
congrads/core/batch_runner.py (new file)

@@ -0,0 +1,200 @@
+"""Defines the BatchRunner, which executes individual batches for training, validation, and testing.
+
+Responsibilities:
+- Move batch data to the appropriate device
+- Run forward passes through the network
+- Compute base and constraint-adjusted losses
+- Perform backpropagation during training
+- Accumulate metrics for loss and other monitored quantities
+- Trigger callbacks at key points in the batch lifecycle
+"""
+
+import torch
+from torch import Tensor
+from torch.nn import Module
+from torch.optim import Optimizer
+
+from ..callbacks.base import CallbackManager
+from ..core.constraint_engine import ConstraintEngine
+from ..metrics import MetricManager
+
+
+class BatchRunner:
+    """Executes a single batch for training, validation, or testing.
+
+    The BatchRunner handles moving data to the correct device, running the network
+    forward, computing base and constraint-adjusted losses, performing backpropagation
+    during training, accumulating metrics, and dispatching callbacks at key points
+    in the batch lifecycle.
+    """
+
+    def __init__(
+        self,
+        network: Module,
+        criterion,
+        optimizer: Optimizer,
+        constraint_engine: ConstraintEngine,
+        metric_manager: MetricManager | None,
+        callback_manager: CallbackManager | None,
+        device: torch.device,
+    ):
+        """Initialize the BatchRunner.
+
+        Args:
+            network: The neural network module to execute.
+            criterion: Loss function callable accepting (output, target, data=batch).
+            optimizer: Optimizer for updating network parameters.
+            constraint_engine: ConstraintEngine instance for evaluating and enforcing constraints.
+            metric_manager: Optional MetricManager for logging batch metrics.
+            callback_manager: Optional CallbackManager for triggering hooks during batch processing.
+            device: Torch device on which to place data and network.
+        """
+        self.network = network
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.constraint_engine = constraint_engine
+        self.metric_manager = metric_manager
+        self.callback_manager = callback_manager
+        self.device = device
+
+    def _to_device(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Move all tensors in the batch to the BatchRunner's device.
+
+        Args:
+            batch: Dictionary of tensors for a single batch.
+
+        Returns:
+            Dictionary of tensors moved to the target device.
+        """
+        return {k: v.to(self.device) for k, v in batch.items()}
+
+    def _run_callbacks(self, hook: str, data: dict) -> dict:
+        """Run the specified callback hook on the batch data.
+
+        Args:
+            hook: Name of the callback hook to run.
+            data: Dictionary containing batch data.
+
+        Returns:
+            Potentially modified batch data after callback execution.
+        """
+        if self.callback_manager is None:
+            return data
+        return self.callback_manager.run(hook, data)
+
+    def train_batch(self, batch: dict[str, Tensor]) -> Tensor:
+        """Run a single training batch.
+
+        Steps performed:
+        1. Move batch to device and run "on_train_batch_start" callback.
+        2. Forward pass through the network.
+        3. Compute base loss using the criterion and accumulate metric.
+        4. Apply constraint-based adjustments to the loss.
+        5. Perform backward pass and optimizer step.
+        6. Run "on_train_batch_end" callback.
+
+        Args:
+            batch: Dictionary of input and target tensors for the batch.
+
+        Returns:
+            Tensor: The base loss computed before constraint adjustments.
+        """
+        batch = self._to_device(batch)
+        batch = self._run_callbacks("on_train_batch_start", batch)
+
+        # Forward
+        batch = self.network(batch)
+        batch = self._run_callbacks("after_train_forward", batch)
+
+        # Base loss
+        loss: Tensor = self.criterion(
+            batch["output"],
+            batch["target"],
+            data=batch,
+        )
+
+        if self.metric_manager is not None:
+            self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+        # Constraint-adjusted loss
+        combined_loss = self.constraint_engine.train(batch, loss)
+
+        # Backward
+        self.optimizer.zero_grad()
+        combined_loss.backward()
+        self.optimizer.step()
+
+        batch = self._run_callbacks("on_train_batch_end", batch)
+        return loss
+
+    def valid_batch(self, batch: dict[str, Tensor]) -> Tensor:
+        """Run a single validation batch.
+
+        Steps performed:
+        1. Move batch to device and run "on_valid_batch_start" callback.
+        2. Forward pass through the network.
+        3. Compute base loss using the criterion and accumulate metric.
+        4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+        5. Run "on_valid_batch_end" callback.
+
+        Args:
+            batch: Dictionary of input and target tensors for the batch.
+
+        Returns:
+            Tensor: The base loss computed for the batch.
+        """
+        batch = self._to_device(batch)
+        batch = self._run_callbacks("on_valid_batch_start", batch)
+
+        batch = self.network(batch)
+        batch = self._run_callbacks("after_valid_forward", batch)
+
+        loss: Tensor = self.criterion(
+            batch["output"],
+            batch["target"],
+            data=batch,
+        )
+
+        if self.metric_manager is not None:
+            self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+
+        self.constraint_engine.validate(batch, loss)
+
+        batch = self._run_callbacks("on_valid_batch_end", batch)
+        return loss
+
+    def test_batch(self, batch: dict[str, Tensor]) -> Tensor:
+        """Run a single test batch.
+
+        Steps performed:
+        1. Move batch to device and run "on_test_batch_start" callback.
+        2. Forward pass through the network.
+        3. Compute base loss using the criterion and accumulate metric.
+        4. Evaluate constraints via the ConstraintEngine (does not modify loss).
+        5. Run "on_test_batch_end" callback.
+
+        Args:
+            batch: Dictionary of input and target tensors for the batch.
+
+        Returns:
+            Tensor: The base loss computed for the batch.
+        """
+        batch = self._to_device(batch)
+        batch = self._run_callbacks("on_test_batch_start", batch)
+
+        batch = self.network(batch)
+        batch = self._run_callbacks("after_test_forward", batch)
+
+        loss: Tensor = self.criterion(
+            batch["output"],
+            batch["target"],
+            data=batch,
+        )
+
+        if self.metric_manager is not None:
+            self.metric_manager.accumulate("Loss/test", loss.unsqueeze(0))
+
+        self.constraint_engine.test(batch, loss)
+
+        batch = self._run_callbacks("on_test_batch_end", batch)
+        return loss
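
To show how the three entry points are meant to be driven, here is a minimal, runnable sketch with stub collaborators; only `BatchRunner`'s constructor and method signatures come from the hunk above, while the network, criterion, and engine stubs are invented:

```python
import torch
from torch import Tensor, nn

from congrads.core.batch_runner import BatchRunner


class DictNet(nn.Module):
    """Toy network operating on the dict-style batches BatchRunner expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, batch: dict) -> dict:
        batch["output"] = self.linear(batch["input"])
        return batch


class PassThroughEngine:
    """Stub ConstraintEngine: no constraints, loss returned unchanged."""

    def train(self, batch, loss):
        return loss

    def validate(self, batch, loss):
        return loss

    def test(self, batch, loss):
        return loss


def criterion(output: Tensor, target: Tensor, data=None) -> Tensor:
    return torch.nn.functional.mse_loss(output, target)


net = DictNet()
runner = BatchRunner(
    network=net,
    criterion=criterion,
    optimizer=torch.optim.SGD(net.parameters(), lr=0.01),
    constraint_engine=PassThroughEngine(),
    metric_manager=None,    # None skips metric accumulation
    callback_manager=None,  # None skips all hooks
    device=torch.device("cpu"),
)

batch = {"input": torch.randn(8, 4), "target": torch.randn(8, 1)}
loss = runner.train_batch(batch)  # forward, base loss, backward, optimizer step
```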