congrads 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/core.py CHANGED
@@ -1,44 +1,28 @@
1
1
  import logging
2
- from typing import Dict
3
- from lightning import LightningModule
4
2
  from torch import Tensor, float32, no_grad, norm, tensor
5
- from torchmetrics import Metric
6
- from torch.nn import ModuleDict
3
+ from torch.optim import Optimizer
4
+ from torch.nn import Module
5
+ from torch.utils.data import DataLoader
6
+ from time import time
7
7
 
8
+ from .metrics import MetricManager
8
9
  from .constraints import Constraint
9
- from .metrics import ConstraintSatisfactionRatio
10
10
  from .descriptor import Descriptor
11
11
 
12
12
 
13
- class CGGDModule(LightningModule):
14
- """
15
- A PyTorch Lightning module that integrates constraint-guided optimization into the training and validation steps.
13
+ class CongradsCore:
16
14
 
17
- This module extends the `LightningModule` and incorporates constraints on the neural network's predictions
18
- by adjusting the loss using a rescale factor. The constraints are checked, and the loss is modified to guide
19
- the optimization process based on these constraints.
20
-
21
- Attributes:
22
- descriptor (Descriptor): The object that describes the layers and neurons of the network, including
23
- the categorization of variable layers.
24
- constraints (list[Constraint]): A list of constraints that define the conditions to guide the optimization.
25
- train_csr (Dict[str, Metric]): A dictionary of `ConstraintSatisfactionRatio` metrics to track constraint satisfaction
26
- during training, indexed by constraint name.
27
- valid_csr (Dict[str, Metric]): A dictionary of `ConstraintSatisfactionRatio` metrics to track constraint satisfaction
28
- during validation, indexed by constraint name.
29
- """
30
-
31
- def __init__(self, descriptor: Descriptor, constraints: list[Constraint]):
32
- """
33
- Initializes the CGGDModule with a descriptor and a list of constraints.
34
-
35
- Args:
36
- descriptor (Descriptor): The object that describes the network's layers and neurons, including their categorization.
37
- constraints (list[Constraint]): A list of constraints that will guide the optimization process.
38
-
39
- Raises:
40
- Warning if there are no variable layers in the descriptor, as constraints will not be applied.
41
- """
15
+ def __init__(
16
+ self,
17
+ descriptor: Descriptor,
18
+ constraints: list[Constraint],
19
+ loaders: tuple[DataLoader, DataLoader, DataLoader],
20
+ network: Module,
21
+ criterion: callable,
22
+ optimizer: Optimizer,
23
+ metric_manager: MetricManager,
24
+ device,
25
+ ):
42
26
 
43
27
  # Init parent class
44
28
  super().__init__()
@@ -46,6 +30,14 @@ class CGGDModule(LightningModule):
46
30
  # Init object variables
47
31
  self.descriptor = descriptor
48
32
  self.constraints = constraints
33
+ self.train_loader = loaders[0]
34
+ self.valid_loader = loaders[1]
35
+ self.test_loader = loaders[2]
36
+ self.network = network
37
+ self.criterion = criterion
38
+ self.optimizer = optimizer
39
+ self.metric_manager = metric_manager
40
+ self.device = device
49
41
 
50
42
  # Perform checks
51
43
  if len(self.descriptor.variable_layers) == 0:
@@ -53,128 +45,162 @@ class CGGDModule(LightningModule):
53
45
  "The descriptor object has no variable layers. The constraint guided loss adjustment is therefore not used. Is this the intended behaviour?"
54
46
  )
55
47
 
56
- # Assign descriptor to constraints
48
+ # Initialize constraint metrics
49
+ metric_manager.register("Loss/train")
50
+ metric_manager.register("Loss/valid")
51
+ metric_manager.register("CSR/train")
52
+ metric_manager.register("CSR/valid")
53
+
57
54
  for constraint in self.constraints:
58
- constraint.descriptor = descriptor
59
- constraint.run_init_descriptor()
60
-
61
- # Init constraint metric logging
62
- self.train_csr: Dict[str, Metric] = ModuleDict(
63
- {
64
- constraint.constraint_name: ConstraintSatisfactionRatio()
65
- for constraint in self.constraints
66
- }
67
- )
68
- self.train_csr["global"] = ConstraintSatisfactionRatio()
69
- self.valid_csr: Dict[str, Metric] = ModuleDict(
70
- {
71
- constraint.constraint_name: ConstraintSatisfactionRatio()
72
- for constraint in self.constraints
73
- }
74
- )
75
- self.valid_csr["global"] = ConstraintSatisfactionRatio()
76
-
77
- def training_step(
55
+ metric_manager.register(f"{constraint.name}/train")
56
+ metric_manager.register(f"{constraint.name}/valid")
57
+
58
+ def fit(self, max_epochs: int = 100):
59
+ # Loop over epochs
60
+ for epoch in range(max_epochs):
61
+
62
+ # Log start time
63
+ start_time = time()
64
+
65
+ # Training
66
+ for batch in self.train_loader:
67
+
68
+ # Set model in training mode
69
+ self.network.train()
70
+
71
+ # Get input-output pairs from batch
72
+ inputs, outputs = batch
73
+
74
+ # Transfer to GPU
75
+ inputs, outputs = inputs.to(self.device), outputs.to(self.device)
76
+
77
+ # Log preparation time
78
+ prepare_time = start_time - time()
79
+
80
+ # Model computations
81
+ prediction = self.network(inputs)
82
+
83
+ # Calculate loss
84
+ loss = self.criterion(prediction["output"], outputs)
85
+ self.metric_manager.accumulate("Loss/train", loss.unsqueeze(0))
86
+
87
+ # Adjust loss based on constraints
88
+ combined_loss = self.train_step(prediction, loss)
89
+
90
+ # Backprop
91
+ self.optimizer.zero_grad()
92
+ combined_loss.backward(
93
+ retain_graph=False, inputs=list(self.network.parameters())
94
+ )
95
+ self.optimizer.step()
96
+
97
+ # Validation
98
+ with no_grad():
99
+ for batch in self.valid_loader:
100
+
101
+ # Set model in evaluation mode
102
+ self.network.eval()
103
+
104
+ # Get input-output pairs from batch
105
+ inputs, outputs = batch
106
+
107
+ # Transfer to GPU
108
+ inputs, outputs = inputs.to(self.device), outputs.to(self.device)
109
+
110
+ # Model computations
111
+ prediction = self.network(inputs)
112
+
113
+ # Calculate loss
114
+ loss = self.criterion(prediction["output"], outputs)
115
+ self.metric_manager.accumulate("Loss/valid", loss.unsqueeze(0))
116
+
117
+ # Validate constraints
118
+ self.valid_step(prediction, loss)
119
+
120
+ # TODO with valid loader, checkpoint model with best performance
121
+
122
+ # Save metrics
123
+ self.metric_manager.record(epoch)
124
+ self.metric_manager.reset()
125
+
126
+ # Log compute and preparation time
127
+ process_time = start_time - time() - prepare_time
128
+ print(
129
+ "Compute efficiency: {:.2f}, epoch: {}/{}:".format(
130
+ process_time / (process_time + prepare_time), epoch, max_epochs
131
+ )
132
+ )
133
+ start_time = time()
134
+
135
+ def train_step(
78
136
  self,
79
137
  prediction: dict[str, Tensor],
80
138
  loss: Tensor,
81
139
  ):
82
- """
83
- The training step where the standard loss is combined with rescale loss based on the constraints.
84
-
85
- For each constraint, the satisfaction ratio is checked, and the loss is adjusted by adding a rescale loss
86
- based on the directions calculated by the constraint.
87
-
88
- Args:
89
- prediction (dict[str, Tensor]): The model's predictions for each layer.
90
- loss (Tensor): The base loss from the model's forward pass.
91
-
92
- Returns:
93
- Tensor: The combined loss, including both the original loss and the rescale loss from the constraints.
94
- """
95
140
 
96
141
  # Init scalar tensor for loss
97
142
  total_rescale_loss = tensor(0, dtype=float32, device=self.device)
143
+ loss_grads = {}
98
144
 
99
- # Compute rescale loss without tracking gradients
145
+ # Precalculate loss gradients for each variable layer
100
146
  with no_grad():
147
+ for layer in self.descriptor.variable_layers:
148
+ self.optimizer.zero_grad()
149
+ loss.backward(retain_graph=True, inputs=prediction[layer])
150
+ loss_grads[layer] = prediction[layer].grad
101
151
 
102
- # For each constraint, TODO split into real and validation only constraints
103
- for constraint in self.constraints:
152
+ # For each constraint, TODO split into real and validation only constraints
153
+ for constraint in self.constraints:
104
154
 
105
- # Check if constraints are satisfied and calculate directions
155
+ # Check if constraints are satisfied and calculate directions
156
+ with no_grad():
106
157
  constraint_checks = constraint.check_constraint(prediction)
107
158
  constraint_directions = constraint.calculate_direction(prediction)
108
159
 
109
- # Only do direction calculations for variable layers affecting constraint
110
- for layer in constraint.layers & self.descriptor.variable_layers:
160
+ # Only do direction calculations for variable layers affecting constraint
161
+ for layer in constraint.layers & self.descriptor.variable_layers:
111
162
 
163
+ with no_grad():
112
164
  # Multiply direction modifiers with constraint result
113
165
  constraint_result = (
114
- constraint_checks[layer].unsqueeze(1).type(float32)
166
+ constraint_checks.unsqueeze(1).type(float32)
115
167
  * constraint_directions[layer]
116
168
  )
117
169
 
118
- # Multiply result with rescale factor o constraint
170
+ # Multiply result with rescale factor of constraint
119
171
  constraint_result *= constraint.rescale_factor
120
172
 
121
- # Calculate gradients of general loss for each sample
122
- loss.backward(retain_graph=True, inputs=prediction[layer])
123
- loss_grad = prediction[layer].grad
124
-
125
173
  # Calculate loss gradient norm
126
- norm_loss_grad = norm(loss_grad, dim=0, p=2, keepdim=True)
127
-
128
- # Calculate rescale loss
129
- rescale_loss = (
130
- (prediction[layer] * constraint_result * norm_loss_grad)
131
- .sum()
132
- .abs()
133
- )
134
-
135
- # Store rescale loss for this reference space
136
- total_rescale_loss += rescale_loss
137
-
138
- # Log constraint satisfaction ratio
139
- # NOTE does this take into account spaces with different dimensions?
140
- self.train_csr[constraint.constraint_name](constraint_checks[layer])
141
- self.train_csr["global"](constraint_checks[layer])
142
- self.log(
143
- f"train_csr_{constraint.constraint_name}_{layer}",
144
- self.train_csr[constraint.constraint_name],
145
- on_step=False,
146
- on_epoch=True,
147
- )
148
-
149
- # Log global constraint satisfaction ratio
150
- self.log(
151
- "train_csr_global",
152
- self.train_csr["global"],
153
- on_step=False,
154
- on_epoch=True,
155
- )
174
+ norm_loss_grad = norm(loss_grads[layer], dim=1, p=2, keepdim=True)
175
+
176
+ # Calculate rescale loss
177
+ rescale_loss = (
178
+ prediction[layer]
179
+ * constraint_result
180
+ * norm_loss_grad.detach().clone()
181
+ ).mean()
182
+
183
+ # Store rescale loss for this reference space
184
+ total_rescale_loss += rescale_loss
185
+
186
+ # Log constraint satisfaction ratio
187
+ self.metric_manager.accumulate(
188
+ f"{constraint.name}/train",
189
+ (~constraint_checks).type(float32),
190
+ )
191
+ self.metric_manager.accumulate(
192
+ "CSR/train",
193
+ (~constraint_checks).type(float32),
194
+ )
156
195
 
157
196
  # Return combined loss
158
197
  return loss + total_rescale_loss
159
198
 
160
- def validation_step(
199
+ def valid_step(
161
200
  self,
162
201
  prediction: dict[str, Tensor],
163
202
  loss: Tensor,
164
203
  ):
165
- """
166
- The validation step where the satisfaction of constraints is checked without applying the rescale loss.
167
-
168
- Similar to the training step, but without updating the loss, this method tracks the constraint satisfaction
169
- during validation.
170
-
171
- Args:
172
- prediction (dict[str, Tensor]): The model's predictions for each layer.
173
- loss (Tensor): The base loss from the model's forward pass.
174
-
175
- Returns:
176
- Tensor: The base loss value for validation.
177
- """
178
204
 
179
205
  # Compute rescale loss without tracking gradients
180
206
  with no_grad():
@@ -185,27 +211,15 @@ class CGGDModule(LightningModule):
185
211
  # Check if constraints are satisfied
186
212
  constraint_checks = constraint.check_constraint(prediction)
187
213
 
188
- # Only do direction calculations for variable layers affecting constraint
189
- for layer in constraint.layers & self.descriptor.variable_layers:
190
-
191
- # Log constraint satisfaction ratio
192
- # NOTE does this take into account spaces with different dimensions?
193
- self.valid_csr[constraint.constraint_name](constraint_checks[layer])
194
- self.valid_csr["global"](constraint_checks[layer])
195
- self.log(
196
- f"valid_csr_{constraint.constraint_name}",
197
- self.valid_csr[constraint.constraint_name],
198
- on_step=False,
199
- on_epoch=True,
200
- )
201
-
202
- # Log global constraint satisfaction ratio
203
- self.log(
204
- "valid_csr_global",
205
- self.valid_csr["global"],
206
- on_step=False,
207
- on_epoch=True,
208
- )
214
+ # Log constraint satisfaction ratio
215
+ self.metric_manager.accumulate(
216
+ f"{constraint.name}/valid",
217
+ (~constraint_checks).type(float32),
218
+ )
219
+ self.metric_manager.accumulate(
220
+ "CSR/valid",
221
+ (~constraint_checks).type(float32),
222
+ )
209
223
 
210
224
  # Return loss
211
225
  return loss