congrads-0.1.0-py3-none-any.whl

congrads/core.py ADDED
@@ -0,0 +1,211 @@
+ import logging
+ from typing import Dict
+ from lightning import LightningModule
+ from torch import Tensor, float32, no_grad, norm, tensor
+ from torch.autograd import grad
+ from torchmetrics import Metric
+ from torch.nn import ModuleDict
+
+ from .constraints import Constraint
+ from .metrics import ConstraintSatisfactionRatio
+ from .descriptor import Descriptor
+
+
+ class CGGDModule(LightningModule):
+     """
+     A PyTorch Lightning module that integrates constraint-guided optimization
+     into the training and validation steps.
+
+     This module extends `LightningModule` and incorporates constraints on the
+     neural network's predictions by adjusting the loss with a rescale factor.
+     The constraints are checked, and the loss is modified to guide the
+     optimization process based on these constraints.
+
+     Attributes:
+         descriptor (Descriptor): The object that describes the layers and
+             neurons of the network, including the categorization of variable
+             layers.
+         constraints (list[Constraint]): A list of constraints that define the
+             conditions guiding the optimization.
+         train_csr (Dict[str, Metric]): A dictionary of
+             `ConstraintSatisfactionRatio` metrics that track constraint
+             satisfaction during training, indexed by constraint name.
+         valid_csr (Dict[str, Metric]): A dictionary of
+             `ConstraintSatisfactionRatio` metrics that track constraint
+             satisfaction during validation, indexed by constraint name.
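+
+     Example:
+         A minimal sketch of the assumed subclassing pattern: the subclass
+         computes its own predictions and base loss, then delegates to this
+         class, as suggested by the method signatures below. ``MyNetwork`` and
+         the layer name ``"output"`` are hypothetical placeholders, not part
+         of this package::
+
+             import torch.nn.functional as F
+
+             class MyModel(CGGDModule):
+                 def __init__(self, descriptor, constraints):
+                     super().__init__(descriptor, constraints)
+                     self.network = MyNetwork()
+
+                 def training_step(self, batch, batch_idx):
+                     x, y = batch
+                     prediction = {"output": self.network(x)}
+                     loss = F.mse_loss(prediction["output"], y)
+                     return super().training_step(prediction, loss)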
+     """
+
+     def __init__(self, descriptor: Descriptor, constraints: list[Constraint]):
+         """
+         Initializes the CGGDModule with a descriptor and a list of constraints.
+
+         Args:
+             descriptor (Descriptor): The object that describes the network's
+                 layers and neurons, including their categorization.
+             constraints (list[Constraint]): A list of constraints that will
+                 guide the optimization process.
+
+         Note:
+             A warning is logged if the descriptor has no variable layers,
+             since the constraints cannot be applied in that case.
+         """
+
+         # Init parent class
+         super().__init__()
+
+         # Init object variables
+         self.descriptor = descriptor
+         self.constraints = constraints
+
+         # Perform checks
+         if len(self.descriptor.variable_layers) == 0:
+             logging.warning(
+                 "The descriptor object has no variable layers. The constraint-guided loss adjustment is therefore not used. Is this the intended behaviour?"
+             )
+
+         # Assign descriptor to constraints
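+         # (two-phase setup: each constraint first receives the descriptor,
+         # then runs any descriptor-dependent initialization)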
+         for constraint in self.constraints:
+             constraint.descriptor = descriptor
+             constraint.run_init_descriptor()
+
+         # Init constraint metric logging
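+         # One satisfaction ratio is tracked per constraint, plus a shared
+         # "global" entry that aggregates the checks of every constraint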
+         self.train_csr: Dict[str, Metric] = ModuleDict(
+             {
+                 constraint.constraint_name: ConstraintSatisfactionRatio()
+                 for constraint in self.constraints
+             }
+         )
+         self.train_csr["global"] = ConstraintSatisfactionRatio()
+         self.valid_csr: Dict[str, Metric] = ModuleDict(
+             {
+                 constraint.constraint_name: ConstraintSatisfactionRatio()
+                 for constraint in self.constraints
+             }
+         )
+         self.valid_csr["global"] = ConstraintSatisfactionRatio()
+
+     def training_step(
+         self,
+         prediction: dict[str, Tensor],
+         loss: Tensor,
+     ):
+         """
+         The training step, where the standard loss is combined with a rescale
+         loss based on the constraints.
+
+         For each constraint, satisfaction is checked, and the loss is adjusted
+         by adding a rescale loss based on the directions calculated by the
+         constraint.
+
+         Args:
+             prediction (dict[str, Tensor]): The model's predictions for each layer.
+             loss (Tensor): The base loss from the model's forward pass.
+
+         Returns:
+             Tensor: The combined loss, including both the original loss and
+                 the rescale loss from the constraints.
+         """
+
+         # Init scalar tensor that accumulates the rescale loss across
+         # constraints and layers
+         total_rescale_loss = tensor(0, dtype=float32, device=self.device)
+
+         # Compute rescale loss without tracking gradients
+         with no_grad():
+
+             # For each constraint, TODO split into real and validation only constraints
+             for constraint in self.constraints:
+
+                 # Check if constraints are satisfied and calculate directions
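+                 # (`check_constraint` flags, per sample, whether the
+                 # constraint holds; `calculate_direction` gives the
+                 # adjustment direction for each prediction)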
+                 constraint_checks = constraint.check_constraint(prediction)
+                 constraint_directions = constraint.calculate_direction(prediction)
+
+                 # Only do direction calculations for variable layers affecting the constraint
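+                 # (`constraint.layers` and `descriptor.variable_layers` appear
+                 # to be sets, so `&` keeps only layers that are both
+                 # constrained and variable)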
+                 for layer in constraint.layers & self.descriptor.variable_layers:
+
+                     # Multiply direction modifiers with constraint result
+                     constraint_result = (
+                         constraint_checks[layer].unsqueeze(1).type(float32)
+                         * constraint_directions[layer]
+                     )
+
+                     # Multiply result with rescale factor of constraint
+                     constraint_result *= constraint.rescale_factor
+
+                     # Calculate gradients of the general loss with respect to
+                     # this layer's predictions; autograd.grad is used because
+                     # .grad is only populated on leaf tensors
+                     (loss_grad,) = grad(loss, prediction[layer], retain_graph=True)
+
+                     # Calculate loss gradient norm
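+                     # (taken over dim 0, presumably the batch dimension, with
+                     # keepdim=True so it broadcasts against the predictions below)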
+                     norm_loss_grad = norm(loss_grad, dim=0, p=2, keepdim=True)
+
+                     # Calculate rescale loss
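+                     # i.e. |sum(prediction * direction * ||grad||)|, tying the
+                     # size of the adjustment to the magnitude of the base-loss
+                     # gradient for this layer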
+                     rescale_loss = (
+                         (prediction[layer] * constraint_result * norm_loss_grad)
+                         .sum()
+                         .abs()
+                     )
+
+                     # Accumulate rescale loss for this layer
+                     total_rescale_loss += rescale_loss
+
+                     # Log constraint satisfaction ratio
+                     # NOTE does this take into account spaces with different dimensions?
+                     self.train_csr[constraint.constraint_name](constraint_checks[layer])
+                     self.train_csr["global"](constraint_checks[layer])
+                     self.log(
+                         f"train_csr_{constraint.constraint_name}_{layer}",
+                         self.train_csr[constraint.constraint_name],
+                         on_step=False,
+                         on_epoch=True,
+                     )
+
+         # Log global constraint satisfaction ratio
+         self.log(
+             "train_csr_global",
+             self.train_csr["global"],
+             on_step=False,
+             on_epoch=True,
+         )
+
+         # Return combined loss
+         return loss + total_rescale_loss
+
+     def validation_step(
+         self,
+         prediction: dict[str, Tensor],
+         loss: Tensor,
+     ):
+         """
+         The validation step, where the satisfaction of constraints is checked
+         without applying the rescale loss.
+
+         Similar to the training step, but the loss is left unmodified; this
+         method only tracks constraint satisfaction during validation.
+
+         Args:
+             prediction (dict[str, Tensor]): The model's predictions for each layer.
+             loss (Tensor): The base loss from the model's forward pass.
+
+         Returns:
+             Tensor: The base loss value for validation.
+         """
+
+         # Check constraint satisfaction without tracking gradients
+         with no_grad():
+
+             # For each constraint
+             for constraint in self.constraints:
+
+                 # Check if constraints are satisfied
+                 constraint_checks = constraint.check_constraint(prediction)
+
+                 # Only log for variable layers affecting the constraint
+                 for layer in constraint.layers & self.descriptor.variable_layers:
+
+                     # Log constraint satisfaction ratio
+                     # NOTE does this take into account spaces with different dimensions?
+                     self.valid_csr[constraint.constraint_name](constraint_checks[layer])
+                     self.valid_csr["global"](constraint_checks[layer])
+                     self.log(
+                         f"valid_csr_{constraint.constraint_name}_{layer}",
+                         self.valid_csr[constraint.constraint_name],
+                         on_step=False,
+                         on_epoch=True,
+                     )
+
+         # Log global constraint satisfaction ratio
+         self.log(
+             "valid_csr_global",
+             self.valid_csr["global"],
+             on_step=False,
+             on_epoch=True,
+         )
+
+         # Return loss
+         return loss
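+
+ # A CGGDModule subclass trains with a regular Lightning Trainer. A minimal
+ # sketch, assuming `model` is a subclass like the one in the class docstring
+ # and `train_loader` / `val_loader` are standard DataLoaders (all three are
+ # placeholders, not part of this package):
+ #
+ #     from lightning import Trainer
+ #
+ #     trainer = Trainer(max_epochs=10)
+ #     trainer.fit(model, train_loader, val_loader)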