congrads-0.1.0-py3-none-any.whl → congrads-0.2.0-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- congrads/__init__.py +7 -6
- congrads/constraints.py +182 -300
- congrads/core.py +158 -144
- congrads/datasets.py +12 -559
- congrads/descriptor.py +20 -35
- congrads/metrics.py +37 -52
- congrads/networks.py +5 -6
- congrads/utils.py +310 -0
- congrads-0.2.0.dist-info/LICENSE +26 -0
- congrads-0.2.0.dist-info/METADATA +222 -0
- congrads-0.2.0.dist-info/RECORD +13 -0
- congrads/learners.py +0 -233
- congrads-0.1.0.dist-info/LICENSE +0 -34
- congrads-0.1.0.dist-info/METADATA +0 -196
- congrads-0.1.0.dist-info/RECORD +0 -13
- {congrads-0.1.0.dist-info → congrads-0.2.0.dist-info}/WHEEL +0 -0
- {congrads-0.1.0.dist-info → congrads-0.2.0.dist-info}/top_level.txt +0 -0
congrads/core.py
CHANGED
@@ -1,44 +1,28 @@
 import logging
-from typing import Dict
-from lightning import LightningModule
 from torch import Tensor, float32, no_grad, norm, tensor
-from
-from torch.nn import
+from torch.optim import Optimizer
+from torch.nn import Module
+from torch.utils.data import DataLoader
+from time import time

+from .metrics import MetricManager
 from .constraints import Constraint
-from .metrics import ConstraintSatisfactionRatio
 from .descriptor import Descriptor


-class CGGDModule(LightningModule):
-    """
-    A PyTorch Lightning module that integrates constraint-guided optimization into the training and validation steps.
+class CongradsCore:

-
-
-
-
-
-
-
-
-
-
-
-    during validation, indexed by constraint name.
-    """
-
-    def __init__(self, descriptor: Descriptor, constraints: list[Constraint]):
-        """
-        Initializes the CGGDModule with a descriptor and a list of constraints.
-
-        Args:
-            descriptor (Descriptor): The object that describes the network's layers and neurons, including their categorization.
-            constraints (list[Constraint]): A list of constraints that will guide the optimization process.
-
-        Raises:
-            Warning if there are no variable layers in the descriptor, as constraints will not be applied.
-        """
+    def __init__(
+        self,
+        descriptor: Descriptor,
+        constraints: list[Constraint],
+        loaders: tuple[DataLoader, DataLoader, DataLoader],
+        network: Module,
+        criterion: callable,
+        optimizer: Optimizer,
+        metric_manager: MetricManager,
+        device,
+    ):

         # Init parent class
         super().__init__()
@@ -46,6 +30,14 @@ class CGGDModule(LightningModule):
         # Init object variables
         self.descriptor = descriptor
         self.constraints = constraints
+        self.train_loader = loaders[0]
+        self.valid_loader = loaders[1]
+        self.test_loader = loaders[2]
+        self.network = network
+        self.criterion = criterion
+        self.optimizer = optimizer
+        self.metric_manager = metric_manager
+        self.device = device

         # Perform checks
         if len(self.descriptor.variable_layers) == 0:
@@ -53,128 +45,162 @@ class CGGDModule(LightningModule):
                 "The descriptor object has no variable layers. The constraint guided loss adjustment is therefore not used. Is this the intended behaviour?"
             )

-        #
+        # Initialize constraint metrics
+        metric_manager.register("Loss/train")
+        metric_manager.register("Loss/valid")
+        metric_manager.register("CSR/train")
+        metric_manager.register("CSR/valid")
+
         for constraint in self.constraints:
-            constraint.
-            constraint.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            metric_manager.register(f"{constraint.name}/train")
+            metric_manager.register(f"{constraint.name}/valid")
+
+    def fit(self, max_epochs: int = 100):
+        # Loop over epochs
+        for epoch in range(max_epochs):
+
+            # Log start time
+            start_time = time()
+
+            # Training
+            for batch in self.train_loader:
+
+                # Set model in training mode
+                self.network.train()
+
+                # Get input-output pairs from batch
+                inputs, outputs = batch
+
+                # Transfer to GPU
+                inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+
+                # Log preparation time
+                prepare_time = start_time - time()
+
+                # Model computations
+                prediction = self.network(inputs)
+
+                # Calculate loss
+                loss = self.criterion(prediction["output"], outputs)
+                self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
+
+                # Adjust loss based on constraints
+                combined_loss = self.train_step(prediction, loss)
+
+                # Backprop
+                self.optimizer.zero_grad()
+                combined_loss.backward(
+                    retain_graph=False, inputs=list(self.network.parameters())
+                )
+                self.optimizer.step()
+
+            # Validation
+            with no_grad():
+                for batch in self.valid_loader:
+
+                    # Set model in evaluation mode
+                    self.network.eval()
+
+                    # Get input-output pairs from batch
+                    inputs, outputs = batch
+
+                    # Transfer to GPU
+                    inputs, outputs = inputs.to(self.device), outputs.to(self.device)
+
+                    # Model computations
+                    prediction = self.network(inputs)
+
+                    # Calculate loss
+                    loss = self.criterion(prediction["output"], outputs)
+                    self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
+
+                    # Validate constraints
+                    self.valid_step(prediction, loss)
+
+            # TODO with valid loader, checkpoint model with best performance
+
+            # Save metrics
+            self.metric_manager.record(epoch)
+            self.metric_manager.reset()
+
+            # Log compute and preparation time
+            process_time = start_time - time() - prepare_time
+            print(
+                "Compute efficiency: {:.2f}, epoch: {}/{}:".format(
+                    process_time / (process_time + prepare_time), epoch, max_epochs
+                )
+            )
+            start_time = time()
+
+    def train_step(
         self,
         prediction: dict[str, Tensor],
         loss: Tensor,
     ):
-        """
-        The training step where the standard loss is combined with rescale loss based on the constraints.
-
-        For each constraint, the satisfaction ratio is checked, and the loss is adjusted by adding a rescale loss
-        based on the directions calculated by the constraint.
-
-        Args:
-            prediction (dict[str, Tensor]): The model's predictions for each layer.
-            loss (Tensor): The base loss from the model's forward pass.
-
-        Returns:
-            Tensor: The combined loss, including both the original loss and the rescale loss from the constraints.
-        """

         # Init scalar tensor for loss
         total_rescale_loss = tensor(0, dtype=float32, device=self.device)
+        loss_grads = {}

-        #
+        # Precalculate loss gradients for each variable layer
         with no_grad():
+            for layer in self.descriptor.variable_layers:
+                self.optimizer.zero_grad()
+                loss.backward(retain_graph=True, inputs=prediction[layer])
+                loss_grads[layer] = prediction[layer].grad

-
-
+        # For each constraint, TODO split into real and validation only constraints
+        for constraint in self.constraints:

-
+            # Check if constraints are satisfied and calculate directions
+            with no_grad():
                 constraint_checks = constraint.check_constraint(prediction)
                 constraint_directions = constraint.calculate_direction(prediction)

-
-
+            # Only do direction calculations for variable layers affecting constraint
+            for layer in constraint.layers & self.descriptor.variable_layers:

+                with no_grad():
                     # Multiply direction modifiers with constraint result
                     constraint_result = (
-                        constraint_checks
+                        constraint_checks.unsqueeze(1).type(float32)
                         * constraint_directions[layer]
                     )

-                    # Multiply result with rescale factor
+                    # Multiply result with rescale factor of constraint
                     constraint_result *= constraint.rescale_factor

-                    # Calculate gradients of general loss for each sample
-                    loss.backward(retain_graph=True, inputs=prediction[layer])
-                    loss_grad = prediction[layer].grad
-
                     # Calculate loss gradient norm
-                    norm_loss_grad = norm(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    )
-
-                    # Log global constraint satisfaction ratio
-                    self.log(
-                        "train_csr_global",
-                        self.train_csr["global"],
-                        on_step=False,
-                        on_epoch=True,
-                    )
+                    norm_loss_grad = norm(loss_grads[layer], dim=1, p=2, keepdim=True)
+
+                # Calculate rescale loss
+                rescale_loss = (
+                    prediction[layer]
+                    * constraint_result
+                    * norm_loss_grad.detach().clone()
+                ).mean()
+
+                # Store rescale loss for this reference space
+                total_rescale_loss += rescale_loss
+
+            # Log constraint satisfaction ratio
+            self.metric_manager.accumulate(
+                f"{constraint.name}/train",
+                (~constraint_checks).type(float32),
+            )
+            self.metric_manager.accumulate(
+                "CSR/train",
+                (~constraint_checks).type(float32),
+            )

         # Return combined loss
         return loss + total_rescale_loss

-    def
+    def valid_step(
         self,
         prediction: dict[str, Tensor],
         loss: Tensor,
     ):
-        """
-        The validation step where the satisfaction of constraints is checked without applying the rescale loss.
-
-        Similar to the training step, but without updating the loss, this method tracks the constraint satisfaction
-        during validation.
-
-        Args:
-            prediction (dict[str, Tensor]): The model's predictions for each layer.
-            loss (Tensor): The base loss from the model's forward pass.
-
-        Returns:
-            Tensor: The base loss value for validation.
-        """

         # Compute rescale loss without tracking gradients
         with no_grad():
@@ -185,27 +211,15 @@
                 # Check if constraints are satisfied for
                 constraint_checks = constraint.check_constraint(prediction)

-                #
-
-
-
-
-
-
-
-                    self.valid_csr[constraint.constraint_name],
-                    on_step=False,
-                    on_epoch=True,
-                )
-
-                # Log global constraint satisfaction ratio
-                self.log(
-                    "valid_csr_global",
-                    self.valid_csr["global"],
-                    on_step=False,
-                    on_epoch=True,
-                )
+                # Log constraint satisfaction ratio
+                self.metric_manager.accumulate(
+                    f"{constraint.name}/valid",
+                    (~constraint_checks).type(float32),
+                )
+                self.metric_manager.accumulate(
+                    "CSR/valid",
+                    (~constraint_checks).type(float32),
+                )

         # Return loss
         return loss
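The core of the rewrite is visible in train_step: the gradient of the base loss with respect to each variable layer's output is precomputed, each constraint's direction is masked to the samples that fail the check, and a rescale loss is added whose gradient pushes those samples along the constraint direction, scaled by the norm of the base-loss gradient. The following is a self-contained PyTorch sketch of that mechanism, independent of the congrads API: the toy network, data, and non-negativity constraint are invented for illustration, and torch.autograd.grad stands in for the package's pattern of calling loss.backward(retain_graph=True, inputs=prediction[layer]) and then reading .grad.

import torch
from torch import nn

# Illustrative stand-ins, not congrads API: a toy network and regression batch.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
criterion = nn.MSELoss()
inputs, targets = torch.randn(16, 4), torch.randn(16, 2)

prediction = net(inputs)              # plays the role of prediction[layer]
loss = criterion(prediction, targets)

# Gradient of the base loss w.r.t. the layer output, then its per-sample L2 norm,
# mirroring the precalculation loop and norm(loss_grads[layer], dim=1, p=2, keepdim=True).
(loss_grad,) = torch.autograd.grad(loss, prediction, retain_graph=True)
norm_loss_grad = torch.norm(loss_grad, p=2, dim=1, keepdim=True)

with torch.no_grad():
    # Invented example constraint: outputs should be non-negative. True marks a
    # violation, consistent with core.py logging (~constraint_checks) as the CSR.
    constraint_checks = (prediction < 0).any(dim=1)
    constraint_directions = torch.ones_like(prediction)  # push violating samples up
    rescale_factor = 1.0
    constraint_result = (
        constraint_checks.unsqueeze(1).type(torch.float32) * constraint_directions
    ) * rescale_factor

# Rescale loss as in train_step: linear in the prediction, with the base-loss
# gradient norm detached so it acts as a per-sample scaling constant.
rescale_loss = (prediction * constraint_result * norm_loss_grad.detach().clone()).mean()
combined_loss = loss + rescale_loss   # what fit() backpropagates
combined_loss.backward()

Because rescale_loss is linear in prediction, backpropagating it adds a term proportional to constraint_result times norm_loss_grad to each violating sample's output gradient, keeping the constraint correction commensurate with the ordinary loss signal; that is presumably why train_step precomputes the gradient norms per variable layer before looping over constraints.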