congrads 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
congrads/__init__.py CHANGED
@@ -1,21 +1,22 @@
  # __init__.py
+ version = "0.2.0"

  # Only expose the submodules, not individual classes
- from . import core
  from . import constraints
+ from . import core
  from . import datasets
  from . import descriptor
- from . import learners
  from . import metrics
  from . import networks
+ from . import utils

  # Define __all__ to specify that the submodules are accessible, but not classes directly.
  __all__ = [
-     "core",
      "constraints",
+     "core",
      "datasets",
      "descriptor",
-     "learners",
      "metrics",
-     "networks"
- ]
+     "networks",
+     "utils",
+ ]
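In sum, 0.2.0 removes the `learners` submodule, adds `utils`, re-sorts the exports alphabetically, and exposes a plain `version` string at the package root. A quick sketch of the resulting import surface, assuming an installed 0.2.0 wheel:

    import congrads

    print(congrads.version)  # -> "0.2.0"
    print(congrads.__all__)
    # -> ["constraints", "core", "datasets", "descriptor", "metrics", "networks", "utils"]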
congrads/constraints.py CHANGED
@@ -3,7 +3,18 @@ from numbers import Number
  import random
  import string
  from typing import Callable, Dict
- from torch import Tensor, ge, gt, lt, le, zeros, FloatTensor, ones, tensor, float32
+ from torch import (
+     Tensor,
+     ge,
+     gt,
+     lt,
+     le,
+     reshape,
+     stack,
+     ones,
+     tensor,
+     zeros_like,
+ )
  import logging
  from torch.nn.functional import normalize

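The import list swaps `zeros`, `FloatTensor`, and `float32` for `reshape`, `stack`, and `zeros_like`, matching the refactored method bodies below, which derive shape, dtype, and device from existing tensors instead of spelling them out. A minimal equivalence sketch in plain PyTorch, independent of this package:

    import torch

    t = torch.rand(4, 3)
    a = torch.zeros(t[0].size(), device=t.device)  # 0.1.0 style: explicit size and device
    b = torch.zeros_like(t[0])                     # 0.2.0 style: inherited from t[0]
    assert torch.equal(a, b)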
@@ -11,154 +22,84 @@ from .descriptor import Descriptor


  class Constraint(ABC):
-     """
-     Abstract base class for defining constraints that can be applied during optimization in the constraint-guided gradient descent process.
-
-     A constraint guides the optimization by evaluating the model's predictions and adjusting the loss based on certain conditions.
-     Constraints can be applied to specific layers or neurons of the model, and they are scaled by a rescale factor to control the influence of the constraint on the overall loss.
-
-     Attributes:
-         descriptor (Descriptor): The descriptor object that provides a mapping of neurons to layers.
-         constraint_name (str): A unique name for the constraint, which can be provided or generated automatically.
-         rescale_factor (float): A factor used to scale the influence of the constraint on the overall loss.
-         neuron_names (set[str]): A set of neuron names that are involved in the constraint.
-         layers (set): A set of layers associated with the neurons specified in `neuron_names`.
-     """

      descriptor: Descriptor = None
+     device = None

      def __init__(
          self,
-         neuron_names: set[str],
-         constraint_name: str = None,
+         neurons: set[str],
+         name: str = None,
          rescale_factor: float = 1.5,
      ) -> None:
-         """
-         Initializes the Constraint object with the given neuron names, constraint name, and rescale factor.
-
-         Args:
-             neuron_names (set[str]): A set of neuron names that are affected by the constraint.
-             constraint_name (str, optional): A custom name for the constraint. If not provided, a random name is generated.
-             rescale_factor (float, optional): A factor that scales the influence of the constraint. Defaults to 1.5.
-
-         Raises:
-             ValueError: If the descriptor has not been set or if a neuron name is not found in the descriptor.
-         """

          # Init parent class
          super().__init__()

          # Init object variables
+         self.neurons = neurons
          self.rescale_factor = rescale_factor
-         self.neuron_names = neuron_names

          # Perform checks
          if rescale_factor <= 1:
              logging.warning(
-                 f"Rescale factor for constraint {constraint_name} is <= 1. The network will favour general loss over the constraint-adjusted loss. Is this intended behaviour? Normally, the loss should always be larger than 1."
+                 f"Rescale factor for constraint {name} is <= 1. The network will favor general loss over the constraint-adjusted loss. Is this intended behaviour? Normally, the loss should always be larger than 1."
              )

          # If no constraint_name is set, generate one based on the class name and a random suffix
-         if constraint_name:
-             self.constraint_name = constraint_name
+         if name:
+             self.name = name
          else:
              random_suffix = "".join(
                  random.choices(string.ascii_uppercase + string.digits, k=6)
              )
-             self.constraint_name = f"{self.__class__.__name__}_{random_suffix}"
-             logging.warning(
-                 f"Name for constraint is not set. Using {self.constraint_name}."
-             )
-
-         if self.descriptor == None:
-             raise ValueError(
-                 "The descriptor of the base Constraint class in not set. Please assign the descriptor to the general Constraint class with 'Constraint.descriptor = descriptor' before defining network-specific contraints."
-             )
+             self.name = f"{self.__class__.__name__}_{random_suffix}"
+             logging.warning(f"Name for constraint is not set. Using {self.name}.")

+         # If rescale factor is not larger than 1, warn user and adjust
          if not rescale_factor > 1:
              self.rescale_factor = abs(rescale_factor) + 1.5
              logging.warning(
-                 f"Rescale factor for constraint {constraint_name} is < 1, adjusted value {rescale_factor} to {self.rescale_factor}."
+                 f"Rescale factor for constraint {name} is < 1, adjusted value {rescale_factor} to {self.rescale_factor}."
              )
          else:
              self.rescale_factor = rescale_factor

-         self.neuron_names = neuron_names
-
-         self.run_init_descriptor()
-
-     def run_init_descriptor(self) -> None:
-         """
-         Initializes the layers associated with the constraint by mapping the neuron names to their corresponding layers
-         from the descriptor.
-
-         This method populates the `layers` attribute with layers associated with the neuron names provided in the constraint.
-
-         Raises:
-             ValueError: If a neuron name is not found in the descriptor's mapping of neurons to layers.
-         """
-
+         # Infer layers from descriptor and neurons
          self.layers = set()
-         for neuron_name in self.neuron_names:
-             if neuron_name in self.descriptor.neuron_to_layer.keys():
-                 self.layers.add(self.descriptor.neuron_to_layer[neuron_name])
-             else:
+         for neuron in self.neurons:
+             if neuron not in self.descriptor.neuron_to_layer.keys():
                  raise ValueError(
-                     f'The neuron name {neuron_name} used with constraint {self.constraint_name} is not defined in the descriptor. Please add it to the correct layer using descriptor.add("layer", ...).'
+                     f'The neuron name {neuron} used with constraint {self.name} is not defined in the descriptor. Please add it to the correct layer using descriptor.add("layer", ...).'
                  )

+             self.layers.add(self.descriptor.neuron_to_layer[neuron])

-     @abstractmethod
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Abstract method to check if the constraint is satisfied based on the model's predictions.
-
-         This method should be implemented in subclasses to define the specific logic for evaluating the constraint based on the model's predictions.
-
-         Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
-
-         Returns:
-             dict[str, Tensor]: A dictionary containing the satisfaction status of the constraint for each layer or neuron.
-
-         Raises:
-             NotImplementedError: If the method is not implemented in a subclass.
-         """
+     # TODO only denormalize if required for efficiency
+     def _denormalize(self, input: Tensor, neuron_names: list[str]):
+         # Extract min and max for each neuron
+         min_values = tensor(
+             [self.descriptor.neuron_to_minmax[name][0] for name in neuron_names],
+             device=input.device,
+         )
+         max_values = tensor(
+             [self.descriptor.neuron_to_minmax[name][1] for name in neuron_names],
+             device=input.device,
+         )

+         # Apply vectorized denormalization
+         return input * (max_values - min_values) + min_values

+     @abstractmethod
+     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:
          raise NotImplementedError

      @abstractmethod
      def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Abstract method to calculate the direction in which the model's predictions need to be adjusted to satisfy the constraint.
-
-         This method should be implemented in subclasses to define how to adjust the model's predictions based on the constraint.
-
-         Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
-
-         Returns:
-             dict[str, Tensor]: A dictionary containing the direction for each layer or neuron, to adjust the model's predictions.
-
-         Raises:
-             NotImplementedError: If the method is not implemented in a subclass.
-         """
          raise NotImplementedError


  class ScalarConstraint(Constraint):
-     """
-     A subclass of the `Constraint` class that applies a scalar constraint on a specific neuron in the model.
-
-     This constraint compares the value of a specific neuron in the model to a scalar value using a specified comparator (e.g., greater than, less than).
-     If the constraint is violated, it adjusts the loss according to the direction defined by the comparator.
-
-     Attributes:
-         comparator (Callable[[Tensor, Number], Tensor]): A comparator function (e.g., greater than, less than) to evaluate the constraint.
-         scalar (Number): The scalar value to compare the neuron value against.
-         direction (int): The direction in which the constraint should adjust the model's predictions (either 1 or -1 based on the comparator).
-         layer (str): The layer associated with the specified neuron.
-         index (int): The index of the specified neuron within the layer.
-     """

      def __init__(
          self,
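Besides the parameter renames (`neuron_names` to `neurons`, `constraint_name` to `name`), the base class gains a class-level `device` attribute and a `_denormalize` helper that maps a normalized prediction x back to its original range as x * (max - min) + min, using the (min, max) pairs from `descriptor.neuron_to_minmax`. A minimal sketch of that arithmetic, with hypothetical neuron ranges:

    from torch import tensor

    neuron_to_minmax = {"price": (0.0, 200.0), "volume": (10.0, 50.0)}
    normalized = tensor([[0.5, 0.25]])        # one batch row: price, volume
    mins = tensor([v[0] for v in neuron_to_minmax.values()])
    maxs = tensor([v[1] for v in neuron_to_minmax.values()])
    print(normalized * (maxs - mins) + mins)  # -> tensor([[100.,  20.]])

Note that with the `if self.descriptor == None` guard removed, an unset descriptor now surfaces as an `AttributeError` on `neuron_to_layer` rather than the earlier explicit `ValueError`.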
@@ -166,23 +107,8 @@ class ScalarConstraint(Constraint):
          comparator: Callable[[Tensor, Number], Tensor],
          scalar: Number,
          name: str = None,
-         descriptor: Descriptor = None,
          rescale_factor: float = 1.5,
      ) -> None:
-         """
-         Initializes the ScalarConstraint with the given neuron name, comparator, scalar value, and other optional parameters.
-
-         Args:
-             neuron_name (str): The name of the neuron that the constraint applies to.
-             comparator (Callable[[Tensor, Number], Tensor]): The comparator function used to evaluate the constraint (e.g., ge, le, gt, lt).
-             scalar (Number): The scalar value that the neuron value is compared to.
-             name (str, optional): A custom name for the constraint. If not provided, a name is generated based on the neuron name, comparator, and scalar.
-             descriptor (Descriptor, optional): The descriptor that maps neurons to layers. If not provided, the global descriptor is used.
-             rescale_factor (float, optional): A factor that scales the influence of the constraint on the overall loss. Defaults to 1.5.
-
-         Raises:
-             ValueError: If the comparator function is not one of the supported comparison operators (ge, le, gt, lt).
-         """

          # Compose constraint name
          name = f"{neuron_name}_{comparator.__name__}_{str(scalar)}"
@@ -194,10 +120,6 @@ class ScalarConstraint(Constraint):
          self.comparator = comparator
          self.scalar = scalar

-         if descriptor != None:
-             self.descriptor = descriptor
-             self.run_init_descriptor()
-
          # Get layer name and feature index from neuron_name
          self.layer = self.descriptor.neuron_to_layer[neuron_name]
          self.index = self.descriptor.neuron_to_index[neuron_name]
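With the per-instance `descriptor` argument removed, constraints now resolve layers and indices solely through the shared class attribute. An illustrative 0.2.0 construction, assuming a `descriptor` object has already been built and maps the hypothetical neuron name "price":

    from torch import ge
    from congrads.constraints import Constraint, ScalarConstraint

    Constraint.descriptor = descriptor      # set once, shared by all constraints
    c = ScalarConstraint("price", ge, 0.0)  # auto-named "price_ge_0.0"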
@@ -210,69 +132,31 @@ class ScalarConstraint(Constraint):

          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
-             self.direction = 1
-         elif self.comparator in [gt, ge]:
              self.direction = -1
+         elif self.comparator in [gt, ge]:
+             self.direction = 1

-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Checks if the constraint is satisfied based on the model's predictions.
-
-         The constraint is evaluated by applying the comparator to the value of the specified neuron and the scalar value.
-
-         Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
-
-         Returns:
-             dict[str, Tensor]: A dictionary containing the constraint satisfaction result for the specified layer.
-         """
-
-         result = ~self.comparator(prediction[self.layer][:, self.index], self.scalar)
+     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:

-         return {self.layer: result}
+         return ~self.comparator(prediction[self.layer][:, self.index], self.scalar)

      def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Calculates the direction in which the model's predictions need to be adjusted to satisfy the constraint.
-
-         The direction is determined by the comparator and represents either a positive or negative adjustment.
-
-         Args:
-             prediction (dict[str, Tensor]): A dictionary of model predictions, indexed by layer names.
-
-         Returns:
-             dict[str, Tensor]: A dictionary containing the direction for each layer or neuron, to adjust the model's predictions.
-         """
-
-         output = zeros(
-             prediction[self.layer].size(),
-             device=prediction[self.layer].device,
-         )
-         output[:, self.index] = self.direction
+         # NOTE currently only works for dense layers due to neuron to index translation
+
+         output = {}
+
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer][0])
+
+         output[self.layer][self.index] = self.direction
+
+         for layer in self.layers:
+             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

-         return {self.layer: output}
+         return output


  class BinaryConstraint(Constraint):
-     """
-     A class representing a binary constraint between two neurons in a neural network.
-
-     This class checks and enforces a constraint between two neurons using a
-     comparator function. The constraint is applied between two neurons located
-     in different layers of the neural network. The class also calculates the
-     direction for gradient adjustment based on the comparator.
-
-     Attributes:
-         neuron_name_left (str): The name of the first neuron involved in the constraint.
-         neuron_name_right (str): The name of the second neuron involved in the constraint.
-         comparator (Callable[[Tensor, Number], Tensor]): A function that compares the values of the two neurons.
-         layer_left (str): The layer name for the first neuron.
-         layer_right (str): The layer name for the second neuron.
-         index_left (int): The index of the first neuron within its layer.
-         index_right (int): The index of the second neuron within its layer.
-         direction_left (float): The normalized direction for gradient adjustment of the first neuron.
-         direction_right (float): The normalized direction for gradient adjustment of the second neuron.
-     """

      def __init__(
          self,
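Two behavioural changes land in the ScalarConstraint hunk above: the direction signs flip (lt/le now maps to -1 and gt/ge to +1), and `check_constraint` returns the Boolean violation mask directly instead of wrapping it in a per-layer dict. `calculate_direction` likewise now returns one unit-normalized vector of shape [1, n] per layer rather than a batch-sized tensor. A small sketch of the new normalization step on a three-neuron layer:

    from torch import rand, reshape, zeros_like
    from torch.nn.functional import normalize

    layer_out = rand(8, 3)                # batch of 8, three neurons
    direction = zeros_like(layer_out[0])  # one entry per neuron, not per batch row
    direction[1] = 1                      # e.g. a gt/ge constraint on neuron 1
    print(normalize(reshape(direction, [1, -1]), dim=1))  # -> tensor([[0., 1., 0.]])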
@@ -280,34 +164,22 @@ class BinaryConstraint(Constraint):
          comparator: Callable[[Tensor, Number], Tensor],
          neuron_name_right: str,
          name: str = None,
-         descriptor: Descriptor = None,
          rescale_factor: float = 1.5,
      ) -> None:
-         """
-         Initializes the binary constraint with two neurons, a comparator, and other configuration options.
-
-         Args:
-             neuron_name_left (str): The name of the first neuron in the constraint.
-             comparator (Callable[[Tensor, Number], Tensor]): A function that compares the values of the two neurons.
-             neuron_name_right (str): The name of the second neuron in the constraint.
-             name (str, optional): The name of the constraint. If not provided, a default name is generated.
-             descriptor (Descriptor, optional): The descriptor containing the mapping of neurons to layers.
-             rescale_factor (float, optional): A factor to rescale the constraint value. Default is 1.5.
-         """

          # Compose constraint name
          name = f"{neuron_name_left}_{comparator.__name__}_{neuron_name_right}"

          # Init parent class
-         super().__init__({neuron_name_left, neuron_name_right}, name, rescale_factor)
+         super().__init__(
+             {neuron_name_left, neuron_name_right},
+             name,
+             rescale_factor,
+         )

          # Init variables
          self.comparator = comparator

-         if descriptor != None:
-             self.descriptor = descriptor
-             self.run_init_descriptor()
-
          # Get layer name and feature index from neuron_name
          self.layer_left = self.descriptor.neuron_to_layer[neuron_name_left]
          self.layer_right = self.descriptor.neuron_to_layer[neuron_name_right]
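BinaryConstraint gets the same treatment: the `descriptor` parameter disappears and the `super().__init__` call is merely reflowed. A hypothetical construction under the shared-descriptor setup (neuron names are placeholders):

    from torch import le
    from congrads.constraints import BinaryConstraint

    c = BinaryConstraint("min_price", le, "max_price")
    # auto-named "min_price_le_max_price"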
@@ -328,65 +200,30 @@ class BinaryConstraint(Constraint):
          self.direction_left = 1
          self.direction_right = -1

-         # Normalize directions
-         normalized_directions = normalize(
-             tensor([self.direction_left, self.direction_right]).type(float32),
-             p=2,
-             dim=0,
-         )
-         self.direction_left = normalized_directions[0]
-         self.direction_right = normalized_directions[1]
-
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Checks whether the binary constraint is satisfied between the two neurons.
-
-         This function applies the comparator to the output values of the two neurons
-         and returns a Boolean result for each neuron.
-
-         Args:
-             prediction (dict[str, Tensor]): A dictionary containing the predictions for each layer.
-
-         Returns:
-             dict[str, Tensor]: A dictionary with the layer names as keys and the constraint satisfaction results as values.
-         """
+     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:

-         result = ~self.comparator(
+         return ~self.comparator(
              prediction[self.layer_left][:, self.index_left],
              prediction[self.layer_right][:, self.index_right],
          )

-         return {self.layer_left: result, self.layer_right: result}
-
      def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         """
-         Calculates the direction for gradient adjustment for both neurons involved in the constraint.
-
-         The directions are normalized and represent the direction in which the constraint should be enforced.
-
-         Args:
-             prediction (dict[str, Tensor]): A dictionary containing the predictions for each layer.
-
-         Returns:
-             dict[str, Tensor]: A dictionary with the layer names as keys and the direction vectors as values.
-         """
-
-         output_left = zeros(
-             prediction[self.layer_left].size(),
-             device=prediction[self.layer_left].device,
-         )
-         output_left[:, self.index_left] = self.direction_left
-
-         output_right = zeros(
-             prediction[self.layer_right].size(),
-             device=prediction[self.layer_right].device,
-         )
-         output_right[:, self.index_right] = self.direction_right
+         # NOTE currently only works for dense layers due to neuron to index translation
+
+         output = {}
+
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer][0])
+
+         output[self.layer_left][self.index_left] = self.direction_left
+         output[self.layer_right][self.index_right] = self.direction_right
+
+         for layer in self.layers:
+             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)

-         return {self.layer_left: output_left, self.layer_right: output_right}
+         return output

-
- # FIXME
  class SumConstraint(Constraint):
      def __init__(
          self,
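Where 0.1.0 pre-normalized the (left, right) direction pair at construction time, 0.2.0 keeps the raw ±1 signs and L2-normalizes each layer's full direction vector inside `calculate_direction`. When both neurons live in the same layer the two approaches agree; a sketch of that case:

    from torch import reshape, tensor
    from torch.nn.functional import normalize

    direction = tensor([1.0, -1.0, 0.0])  # left neuron +1, right neuron -1, third untouched
    print(normalize(reshape(direction, [1, -1]), dim=1))
    # -> tensor([[ 0.7071, -0.7071,  0.0000]]), the same 1/sqrt(2) scaling 0.1.0 baked in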
@@ -396,34 +233,17 @@ class SumConstraint(Constraint):
          weights_left: list[float] = None,
          weights_right: list[float] = None,
          name: str = None,
-         descriptor: Descriptor = None,
          rescale_factor: float = 1.5,
      ) -> None:

          # Init parent class
-         super().__init__(
-             set(neuron_names_left) & set(neuron_names_right), name, rescale_factor
-         )
+         neuron_names = set(neuron_names_left) | set(neuron_names_right)
+         super().__init__(neuron_names, name, rescale_factor)

          # Init variables
          self.comparator = comparator
-
-         if descriptor != None:
-             self.descriptor = descriptor
-             self.run_init_descriptor()
-
-         # Get layer names and feature indices from neuron_name
-         self.layers_left = []
-         self.indices_left = []
-         for neuron_name in neuron_names_left:
-             self.layers_left.append(self.descriptor.neuron_to_layer[neuron_name])
-             self.indices_left.append(self.descriptor.neuron_to_index[neuron_name])
-
-         self.layers_right = []
-         self.indices_right = []
-         for neuron_name in neuron_names_right:
-             self.layers_right.append(self.descriptor.neuron_to_layer[neuron_name])
-             self.indices_right.append(self.descriptor.neuron_to_index[neuron_name])
+         self.neuron_names_left = neuron_names_left
+         self.neuron_names_right = neuron_names_right

          # If comparator function is not supported, raise error
          if comparator not in [ge, le, gt, lt]:
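The parent call previously registered `set(neuron_names_left) & set(neuron_names_right)`, the intersection, which is empty whenever the two sides name disjoint neurons; 0.2.0 registers the union instead. A one-line demonstration of the fix:

    left, right = {"a", "b"}, {"c"}
    print(left & right)  # 0.1.0: set(), no neurons registered
    print(left | right)  # 0.2.0: {'a', 'b', 'c'}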
@@ -443,13 +263,13 @@ class SumConstraint(Constraint):

          # If weights are provided for summation, transform them to Tensors
          if weights_left:
-             self.weights_left = FloatTensor(weights_left)
+             self.weights_left = tensor(weights_left, device=self.device)
          else:
-             self.weights_left = ones(len(neuron_names_left))
+             self.weights_left = ones(len(neuron_names_left), device=self.device)
          if weights_right:
-             self.weights_right = FloatTensor(weights_right)
+             self.weights_right = tensor(weights_right, device=self.device)
          else:
-             self.weights_right = ones(len(neuron_names_right))
+             self.weights_right = ones(len(neuron_names_right), device=self.device)

          # Calculate directions based on constraint operator
          if self.comparator in [lt, le]:
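Weight tensors are now built with `tensor(..., device=self.device)` rather than the legacy `FloatTensor` constructor, which always allocates float32 on CPU and forced the ad hoc `.to(device)` calls visible in the removed 0.1.0 code below. A small sketch of the difference:

    import torch

    weights = [0.5, 2.0]
    a = torch.FloatTensor(weights)            # 0.1.0: always CPU float32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    b = torch.tensor(weights, device=device)  # 0.2.0: placed on the target device up front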
@@ -459,49 +279,111 @@ class SumConstraint(Constraint):
          self.direction_left = 1
          self.direction_right = -1

-         # Normalize directions
-         normalized_directions = normalize(
-             tensor(self.direction_left, self.direction_right), p=2, dim=0
-         )
-         self.direction_left = normalized_directions[0]
-         self.direction_right = normalized_directions[1]
-
-     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         raise NotImplementedError
-         # # TODO remove the dynamic to device conversion and do this in initialization one way or another
-         # weighted_sum_left = (
-         #     prediction[layer_left][:, index_left]
-         #     * self.weights_left.to(prediction[layer_left].device)
-         # ).sum(dim=1)
-         # weighted_sum_right = (
-         #     prediction[layer_right][:, index_right]
-         #     * self.weights_right.to(prediction[layer_right].device)
-         # ).sum(dim=1)
-
-         # result = ~self.comparator(weighted_sum_left, weighted_sum_right)
-
-         # return {layer_left: result, layer_right: result}
-         pass
+     def check_constraint(self, prediction: dict[str, Tensor]) -> Tensor:
+
+         def compute_weighted_sum(neuron_names: list[str], weights: tensor) -> tensor:
+             layers = [
+                 self.descriptor.neuron_to_layer[neuron_name]
+                 for neuron_name in neuron_names
+             ]
+             indices = [
+                 self.descriptor.neuron_to_index[neuron_name]
+                 for neuron_name in neuron_names
+             ]
+
+             # Extract predictions for all neurons and apply weights in bulk
+             predictions = stack(
+                 [prediction[layer][:, index] for layer, index in zip(layers, indices)],
+                 dim=1,
+             )
+
+             # Denormalize if required
+             predictions_denorm = self._denormalize(predictions, neuron_names)
+
+             # Calculate weighted sum
+             weighted_sum = (predictions_denorm * weights.unsqueeze(0)).sum(dim=1)
+
+             return weighted_sum
+
+         weighted_sum_left = compute_weighted_sum(
+             self.neuron_names_left, self.weights_left
+         )
+         weighted_sum_right = compute_weighted_sum(
+             self.neuron_names_right, self.weights_right
+         )
+
+         # Apply the comparator and calculate the result
+         return ~self.comparator(weighted_sum_left, weighted_sum_right)

      def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
-         raise NotImplementedError
-         # # TODO move this to constructor somehow
-         # layer_left = prediction.neuron_to_layer[self.neuron_name_left]
-         # layer_right = prediction.neuron_to_layer[self.neuron_name_right]
-         # index_left = prediction.neuron_to_index[self.neuron_name_left]
-         # index_right = prediction.neuron_to_index[self.neuron_name_right]
-
-         # output_left = zeros(
-         #     prediction[layer_left].size(),
-         #     device=prediction[layer_left].device,
-         # )
-         # output_left[:, index_left] = self.direction_left
-
-         # output_right = zeros(
-         #     prediction.layer_to_data[layer_right].size(),
-         #     device=prediction.layer_to_data[layer_right].device,
-         # )
-         # output_right[:, index_right] = self.direction_right
-
-         # return {layer_left: output_left, layer_right: output_right}
-         pass
+         # NOTE currently only works for dense layers due to neuron to index translation
+
+         output = {}
+
+         for layer in self.layers:
+             output[layer] = zeros_like(prediction[layer][0])
+
+         for neuron_name_left in self.neuron_names_left:
+             layer = self.descriptor.neuron_to_layer[neuron_name_left]
+             index = self.descriptor.neuron_to_index[neuron_name_left]
+             output[layer][index] = self.direction_left
+
+         for neuron_name_right in self.neuron_names_right:
+             layer = self.descriptor.neuron_to_layer[neuron_name_right]
+             index = self.descriptor.neuron_to_index[neuron_name_right]
+             output[layer][index] = self.direction_right
+
+         for layer in self.layers:
+             output[layer] = normalize(reshape(output[layer], [1, -1]), dim=1)
+
+         return output
+
+
+ # class MonotonicityConstraint(Constraint):
+ #     # TODO docstring
+
+ #     def __init__(
+ #         self,
+ #         neuron_name: str,
+ #         name: str = None,
+ #         descriptor: Descriptor = None,
+ #         rescale_factor: float = 1.5,
+ #     ) -> None:
+
+ #         # Compose constraint name
+ #         name = f"Monotonicity_{neuron_name}"
+
+ #         # Init parent class
+ #         super().__init__({neuron_name}, name, rescale_factor)
+
+ #         # Init variables
+ #         if descriptor != None:
+ #             self.descriptor = descriptor
+ #             self.run_init_descriptor()
+
+ #         # Get layer name and feature index from neuron_name
+ #         self.layer = self.descriptor.neuron_to_layer[neuron_name]
+ #         self.index = self.descriptor.neuron_to_index[neuron_name]
+
+ #     def check_constraint(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
+ #         # Check if values for column in batch are only increasing
+ #         result = ~ge(
+ #             diff(
+ #                 prediction[self.layer][:, self.index],
+ #                 prepend=zeros_like(
+ #                     prediction[self.layer][:, self.index][:1],
+ #                     device=prediction[self.layer].device,
+ #                 ),
+ #             ),
+ #             0,
+ #         )
+
+ #         return {self.layer: result}
+
+ #     def calculate_direction(self, prediction: dict[str, Tensor]) -> Dict[str, Tensor]:
+ #         # TODO implement
+
+ #         output = {self.layer: zeros_like(prediction[self.layer][0])}
+ #         output[self.layer][self.index] = 1
+
+ #         return output
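Taken together, SumConstraint's new `check_constraint` denormalizes each side's neurons, forms the weighted sums, and flags the rows where the comparator fails; in effect it computes violated = not comparator(sum_i w_i * denorm(x_i), sum_j w_j * denorm(y_j)) per batch row. An illustrative end-to-end use, assuming the shared descriptor maps the hypothetical neuron names and that the constructor takes (neuron_names_left, comparator, neuron_names_right), mirroring BinaryConstraint:

    from torch import le
    from congrads.constraints import SumConstraint

    # e.g. enforce part_a + part_b <= total, with weights defaulting to ones
    c = SumConstraint(["part_a", "part_b"], le, ["total"])
    violated = c.check_constraint(prediction)  # prediction: dict of layer name -> Tensor
    # violated is a Boolean mask with one entry per batch row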