congrads 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- congrads/__init__.py +10 -20
- congrads/callbacks/base.py +357 -0
- congrads/callbacks/registry.py +106 -0
- congrads/checkpoints.py +178 -0
- congrads/constraints/base.py +242 -0
- congrads/constraints/registry.py +1255 -0
- congrads/core/batch_runner.py +200 -0
- congrads/core/congradscore.py +271 -0
- congrads/core/constraint_engine.py +209 -0
- congrads/core/epoch_runner.py +119 -0
- congrads/datasets/registry.py +799 -0
- congrads/descriptor.py +147 -43
- congrads/metrics.py +116 -41
- congrads/networks/registry.py +68 -0
- congrads/py.typed +0 -0
- congrads/transformations/base.py +37 -0
- congrads/transformations/registry.py +86 -0
- congrads/utils/preprocessors.py +439 -0
- congrads/utils/utility.py +506 -0
- congrads/utils/validation.py +182 -0
- congrads-0.3.0.dist-info/METADATA +234 -0
- congrads-0.3.0.dist-info/RECORD +23 -0
- congrads-0.3.0.dist-info/WHEEL +4 -0
- congrads/constraints.py +0 -507
- congrads/core.py +0 -211
- congrads/datasets.py +0 -742
- congrads/learners.py +0 -233
- congrads/networks.py +0 -91
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- congrads-0.1.0.dist-info/WHEEL +0 -5
- congrads-0.1.0.dist-info/top_level.txt +0 -1
congrads/core.py
DELETED
|
@@ -1,211 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from typing import Dict
|
|
3
|
-
from lightning import LightningModule
|
|
4
|
-
from torch import Tensor, float32, no_grad, norm, tensor
|
|
5
|
-
from torchmetrics import Metric
|
|
6
|
-
from torch.nn import ModuleDict
|
|
7
|
-
|
|
8
|
-
from .constraints import Constraint
|
|
9
|
-
from .metrics import ConstraintSatisfactionRatio
|
|
10
|
-
from .descriptor import Descriptor
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class CGGDModule(LightningModule):
    """
    A PyTorch Lightning module that integrates constraint-guided optimization into the training and validation steps.

    This module extends `LightningModule` and incorporates constraints on the neural network's predictions
    by adding a rescale loss to the base loss. For every constraint, the constraint checks and directions
    are computed, and a rescale term is added so that the optimization is guided by these constraints.

    Attributes:
        descriptor (Descriptor): The object that describes the layers and neurons of the network, including
            the categorization of variable layers.
        constraints (list[Constraint]): A list of constraints that define the conditions to guide the optimization.
        train_csr (Dict[str, Metric]): `ConstraintSatisfactionRatio` metrics tracked during training, keyed by
            constraint name, plus an aggregate entry under the key "global".
        valid_csr (Dict[str, Metric]): `ConstraintSatisfactionRatio` metrics tracked during validation, keyed by
            constraint name, plus an aggregate entry under the key "global".
    """

    def __init__(self, descriptor: Descriptor, constraints: list[Constraint]):
        """
        Initializes the CGGDModule with a descriptor and a list of constraints.

        Args:
            descriptor (Descriptor): The object that describes the network's layers and neurons, including their categorization.
            constraints (list[Constraint]): A list of constraints that will guide the optimization process.

        Note:
            Nothing is raised when the descriptor has no variable layers. A warning is logged instead,
            because no constraint will then ever contribute to the loss.
        """
        super().__init__()

        self.descriptor = descriptor
        self.constraints = constraints

        if len(self.descriptor.variable_layers) == 0:
            logging.warning(
                "The descriptor object has no variable layers. The constraint guided loss adjustment is therefore not used. Is this the intended behaviour?"
            )

        # Give every constraint access to the shared descriptor before use.
        for constraint in self.constraints:
            constraint.descriptor = descriptor
            constraint.run_init_descriptor()

        # Per-constraint satisfaction metrics. They are wrapped in ModuleDict so Lightning
        # registers them as submodules (device moves, state handling).
        # NOTE(review): a constraint literally named "global" would be overwritten by the aggregate entry.
        self.train_csr: Dict[str, Metric] = ModuleDict(
            {
                constraint.constraint_name: ConstraintSatisfactionRatio()
                for constraint in self.constraints
            }
        )
        self.train_csr["global"] = ConstraintSatisfactionRatio()
        self.valid_csr: Dict[str, Metric] = ModuleDict(
            {
                constraint.constraint_name: ConstraintSatisfactionRatio()
                for constraint in self.constraints
            }
        )
        self.valid_csr["global"] = ConstraintSatisfactionRatio()

    def training_step(
        self,
        prediction: dict[str, Tensor],
        loss: Tensor,
    ):
        """
        The training step where the standard loss is combined with a rescale loss based on the constraints.

        For each constraint and each variable layer it touches, the constraint checks are multiplied by the
        constraint directions and the constraint's rescale factor. The result is then weighted by the norm of
        the base loss gradient w.r.t. that layer's output. The rescale term keeps its autograd graph through
        `prediction[layer]`, so it contributes to the gradients of the returned loss.

        Args:
            prediction (dict[str, Tensor]): The model's predictions for each layer. These must be part of
                the autograd graph of `loss`.
            loss (Tensor): The base loss from the model's forward pass.

        Returns:
            Tensor: The combined loss, including both the original loss and the rescale loss from the constraints.
        """
        # Local import: autograd.grad returns the gradient directly instead of accumulating into `.grad`.
        from torch.autograd import grad as autograd_grad

        total_rescale_loss = tensor(0, dtype=float32, device=self.device)

        # TODO split into real and validation only constraints
        for constraint in self.constraints:
            # Checks and directions act as constant coefficients; no graph is needed for them.
            with no_grad():
                constraint_checks = constraint.check_constraint(prediction)
                constraint_directions = constraint.calculate_direction(prediction)

            # Only variable layers that this constraint touches are adjusted.
            for layer in constraint.layers & self.descriptor.variable_layers:
                with no_grad():
                    # NOTE(review): unsqueeze(1) assumes checks are indexed per sample along dim 0
                    # and broadcast over the directions' dim 1 — confirm against Constraint.
                    constraint_result = (
                        constraint_checks[layer].unsqueeze(1).type(float32)
                        * constraint_directions[layer]
                    )
                    constraint_result *= constraint.rescale_factor

                    # Gradient of the base loss w.r.t. this layer's output. retain_graph keeps the
                    # graph alive for the final backward pass performed by Lightning.
                    (loss_grad,) = autograd_grad(
                        loss, prediction[layer], retain_graph=True
                    )

                    # Reduces over dim 0 (presumably the batch axis — confirm).
                    norm_loss_grad = norm(loss_grad, dim=0, p=2, keepdim=True)

                # Built with grad enabled on purpose. Under no_grad this term would be a constant
                # and would not change the gradients of the returned loss at all.
                rescale_loss = (
                    (prediction[layer] * constraint_result * norm_loss_grad)
                    .sum()
                    .abs()
                )
                total_rescale_loss = total_rescale_loss + rescale_loss

                # NOTE does this take into account spaces with different dimensions?
                with no_grad():
                    self.train_csr[constraint.constraint_name](constraint_checks[layer])
                    self.train_csr["global"](constraint_checks[layer])
                # NOTE(review): the same metric object is logged under one key per layer here,
                # whereas validation uses a single key without the layer suffix.
                self.log(
                    f"train_csr_{constraint.constraint_name}_{layer}",
                    self.train_csr[constraint.constraint_name],
                    on_step=False,
                    on_epoch=True,
                )

        self.log(
            "train_csr_global",
            self.train_csr["global"],
            on_step=False,
            on_epoch=True,
        )

        return loss + total_rescale_loss

    def validation_step(
        self,
        prediction: dict[str, Tensor],
        loss: Tensor,
    ):
        """
        The validation step where constraint satisfaction is tracked without applying any rescale loss.

        Similar to the training step, but the loss is returned unchanged. Only the
        constraint satisfaction ratios are updated and logged.

        Args:
            prediction (dict[str, Tensor]): The model's predictions for each layer.
            loss (Tensor): The base loss from the model's forward pass.

        Returns:
            Tensor: The base loss value for validation.
        """
        # Satisfaction tracking only; no gradients are needed.
        with no_grad():
            for constraint in self.constraints:
                constraint_checks = constraint.check_constraint(prediction)

                # Only variable layers that this constraint touches are tracked.
                for layer in constraint.layers & self.descriptor.variable_layers:
                    # NOTE does this take into account spaces with different dimensions?
                    self.valid_csr[constraint.constraint_name](constraint_checks[layer])
                    self.valid_csr["global"](constraint_checks[layer])
                    self.log(
                        f"valid_csr_{constraint.constraint_name}",
                        self.valid_csr[constraint.constraint_name],
                        on_step=False,
                        on_epoch=True,
                    )

            self.log(
                "valid_csr_global",
                self.valid_csr["global"],
                on_step=False,
                on_epoch=True,
            )

        return loss
|