broccoli-ml 0.1.40__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/cnn.py CHANGED
@@ -9,355 +9,437 @@ from typing import Type, Union, Tuple, Optional, Literal
  from einops.layers.torch import Rearrange
 
 
- # Helper function to calculate padding for 'same' mode
- # Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
- def _calculate_same_padding(
-     input_size: Tuple[int, int],
-     kernel_size: Tuple[int, int],
-     stride: Tuple[int, int],
-     dilation: Tuple[int, int],
- ) -> Tuple[int, int, int, int]:
-     """Calculates padding for 'same' output shape."""
-     ih, iw = input_size
-     kh, kw = kernel_size
-     sh, sw = stride
-     dh, dw = dilation
-
-     # Effective kernel size
-     eff_kh = (kh - 1) * dh + 1
-     eff_kw = (kw - 1) * dw + 1
-
-     # Calculate required total padding
-     out_h = (ih + sh - 1) // sh
-     out_w = (iw + sw - 1) // sw
-     pad_h = max((out_h - 1) * sh + eff_kh - ih, 0)
-     pad_w = max((out_w - 1) * sw + eff_kw - iw, 0)
-
-     # Distribute padding (similar to TensorFlow 'SAME' behavior)
-     pad_top = pad_h // 2
-     pad_bottom = pad_h - pad_top
-     pad_left = pad_w // 2
-     pad_right = pad_w - pad_left
-     return (pad_left, pad_right, pad_top, pad_bottom)
-
-
- # Custom Convolution Layer
- class ConvLayer(nn.Module):
+ # # Helper function to calculate padding for 'same' mode
+ # # Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
+ # def _calculate_same_padding(
+ #     input_size: Tuple[int, int],
+ #     kernel_size: Tuple[int, int],
+ #     stride: Tuple[int, int],
+ #     dilation: Tuple[int, int],
+ # ) -> Tuple[int, int, int, int]:
+ #     """Calculates padding for 'same' output shape."""
+ #     ih, iw = input_size
+ #     kh, kw = kernel_size
+ #     sh, sw = stride
+ #     dh, dw = dilation
+
+ #     # Effective kernel size
+ #     eff_kh = (kh - 1) * dh + 1
+ #     eff_kw = (kw - 1) * dw + 1
+
+ #     # Calculate required total padding
+ #     out_h = (ih + sh - 1) // sh
+ #     out_w = (iw + sw - 1) // sw
+ #     pad_h = max((out_h - 1) * sh + eff_kh - ih, 0)
+ #     pad_w = max((out_w - 1) * sw + eff_kw - iw, 0)
+
+ #     # Distribute padding (similar to TensorFlow 'SAME' behavior)
+ #     pad_top = pad_h // 2
+ #     pad_bottom = pad_h - pad_top
+ #     pad_left = pad_w // 2
+ #     pad_right = pad_w - pad_left
+ #     return (pad_left, pad_right, pad_top, pad_bottom)
+
+
+ # # Custom Convolution Layer
+ # class ConvLayer(nn.Module):
+ #     """
+ #     A 2D Convolution layer implemented using torch.nn.Unfold and a custom linear layer.
+
+ #     This layer mimics the behavior of torch.nn.Conv2d but allows injecting
+ #     a different linear layer implementation for processing the unfolded patches.
+
+ #     Args:
+ #         in_channels (int): Number of channels in the input image.
+ #         out_channels (int): Number of channels produced by the convolution.
+ #         kernel_size (int or tuple): Size of the convolving kernel.
+ #         stride (int or tuple, optional): Stride of the convolution. Default: 1.
+ #         padding (int, tuple or str, optional): Padding added to all four sides
+ #             of the input. Can be an int, a tuple of two ints (padH, padW),
+ #             a tuple of four ints (padLeft, padRight, padTop, padBottom),
+ #             or the strings 'valid' (no padding) or 'same' (padding for same
+ #             output spatial dims as input). Default: 0 ('valid').
+ #         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1.
+ #         bias (bool, optional): If True, adds a learnable bias to the output.
+ #             The bias is handled by the underlying linear layer. Default: True.
+ #         linear (Type[nn.Module], optional): The class of the linear layer
+ #             to use for the kernel operation. Must accept (in_features, out_features, bias)
+ #             in its constructor. Defaults to torch.nn.Linear.
+ #     """
+
+ #     def __init__(
+ #         self,
+ #         in_channels: int,
+ #         out_channels: int,
+ #         kernel_size: Union[int, Tuple[int, int]],
+ #         stride: Union[int, Tuple[int, int]] = 1,
+ #         padding: Union[
+ #             int, Tuple[int, int], Tuple[int, int, int, int], Literal["valid", "same"]
+ #         ] = 0,
+ #         dilation: Union[int, Tuple[int, int]] = 1,
+ #         bias: bool = True,
+ #         linear_module: Type[nn.Module] = nn.Linear,
+ #     ):
+ #         super().__init__()
+ #         self.in_channels = in_channels
+ #         self.out_channels = out_channels
+ #         self.kernel_size = _pair(kernel_size)
+ #         self.stride = _pair(stride)
+ #         self.dilation = _pair(dilation)
+ #         self.bias = bias
+ #         self.linear_module = linear_module
+ #         self.padding_mode = (
+ #             padding  # Store the original padding mode ('same', 'valid', int, or tuple)
+ #         )
+
+ #         # Calculate the number of input features for the linear layer
+ #         # It's the number of channels times the kernel area
+ #         self.linear_in_features = (
+ #             in_channels * self.kernel_size[0] * self.kernel_size[1]
+ #         )
+
+ #         # Instantiate the linear layer (kernel)
+ #         self.kernel = self.linear_module(
+ #             self.linear_in_features, out_channels, bias=bias
+ #         )
+
+ #         # We will use F.pad for manual padding, so unfold padding is 0
+ #         self.unfold = nn.Unfold(
+ #             kernel_size=self.kernel_size,
+ #             dilation=self.dilation,
+ #             padding=0,  # Manual padding handled in forward
+ #             stride=self.stride,
+ #         )
+
+ #         # Determine numeric padding values for F.pad
+ #         if isinstance(padding, str):
+ #             if padding not in ["valid", "same"]:
+ #                 raise ValueError("padding must be 'valid', 'same', an int, or a tuple")
+ #             # 'same' padding calculation depends on input size, defer to forward pass
+ #             # 'valid' padding means 0
+ #             self._padding_val = (
+ #                 (0, 0, 0, 0) if padding == "valid" else None
+ #             )  # None indicates 'same'
+ #         elif isinstance(padding, int):
+ #             self._padding_val = (padding,) * 4
+ #         elif isinstance(padding, tuple) and len(padding) == 2:
+ #             # (padH, padW) -> (padW_left, padW_right, padH_top, padH_bottom)
+ #             self._padding_val = (padding[1], padding[1], padding[0], padding[0])
+ #         elif isinstance(padding, tuple) and len(padding) == 4:
+ #             # (padLeft, padRight, padTop, padBottom) - already in F.pad format
+ #             self._padding_val = padding
+ #         else:
+ #             raise TypeError(
+ #                 "padding must be 'valid', 'same', an int, or a tuple of 2 or 4 ints"
+ #             )
+
+ #     def _calculate_output_shape(self, h_in: int, w_in: int) -> Tuple[int, int]:
+ #         """Calculates the output height and width."""
+ #         if self._padding_val is None:  # 'same' padding
+ #             # For 'same' padding, output size matches input size if stride is 1.
+ #             # If stride > 1, output size is ceil(input_size / stride)
+ #             # The _calculate_same_padding helper ensures this behavior.
+ #             oh = math.ceil(h_in / self.stride[0])
+ #             ow = math.ceil(w_in / self.stride[1])
+ #             return oh, ow
+ #         else:
+ #             # Use the standard formula with the calculated numeric padding
+ #             pad_h = self._padding_val[2] + self._padding_val[3]  # top + bottom
+ #             pad_w = self._padding_val[0] + self._padding_val[1]  # left + right
+ #             kh, kw = self.kernel_size
+ #             sh, sw = self.stride
+ #             dh, dw = self.dilation
+
+ #             eff_kh = (kh - 1) * dh + 1
+ #             eff_kw = (kw - 1) * dw + 1
+
+ #             oh = math.floor((h_in + pad_h - eff_kh) / sh + 1)
+ #             ow = math.floor((w_in + pad_w - eff_kw) / sw + 1)
+ #             return oh, ow
+
+ #     def forward(self, x: torch.Tensor) -> torch.Tensor:
+ #         """
+ #         Performs the forward pass.
+
+ #         Args:
+ #             x (torch.Tensor): Input tensor of shape (N, C_in, H_in, W_in).
+
+ #         Returns:
+ #             torch.Tensor: Output tensor of shape (N, C_out, H_out, W_out).
+ #         """
+ #         _, C, H, W = x.shape
+ #         if C != self.in_channels:
+ #             raise ValueError(
+ #                 f"Input channels {C} does not match expected {self.in_channels}"
+ #             )
+
+ #         # 1. Calculate and Apply Padding
+ #         if self._padding_val is None:  # 'same' padding mode
+ #             pad_l, pad_r, pad_t, pad_b = _calculate_same_padding(
+ #                 (H, W), self.kernel_size, self.stride, self.dilation
+ #             )
+ #             padded_x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
+ #             # Update H, W for output shape calculation after padding
+ #             # Note: _calculate_output_shape will correctly handle 'same' based on original H, W
+ #         elif self._padding_val != (0, 0, 0, 0):
+ #             padded_x = F.pad(x, self._padding_val)
+ #         else:  # No padding ('valid' or explicit 0)
+ #             padded_x = x
+
+ #         # 2. Unfold to extract patches
+ #         # Input: (N, C_in, H_pad, W_pad)
+ #         # Output: (N, C_in * K_h * K_w, L), where L is the number of patches (H_out * W_out)
+ #         patches = self.unfold(padded_x)
+ #         num_patches = patches.shape[-1]  # L
+
+ #         # 3. Reshape for the linear layer
+ #         # We want (N, L, C_in * K_h * K_w) to apply the linear layer patch-wise
+ #         # transpose switches the last two dimensions
+ #         patches_transposed = patches.transpose(1, 2)  # Shape: (N, L, C_in * K_h * K_w)
+
+ #         # 4. Apply the linear layer (kernel) to each patch
+ #         # Input: (N, L, linear_in_features)
+ #         # Output: (N, L, out_channels)
+ #         linear_output = self.kernel(patches_transposed)
+
+ #         # 5. Reshape back to image format
+ #         # We need (N, out_channels, L) first
+ #         output_transposed = linear_output.transpose(1, 2)  # Shape: (N, out_channels, L)
+
+ #         # Calculate output spatial dimensions
+ #         out_h, out_w = self._calculate_output_shape(H, W)  # Use original H, W
+
+ #         # Check if the number of patches matches the calculated output dimensions
+ #         if num_patches != out_h * out_w:
+ #             # This might happen with certain combinations of stride/padding/dilation/input size
+ #             # if the calculation logic has an issue. nn.Unfold is usually robust.
+ #             print(
+ #                 f"Warning: Mismatch in calculated patches. "
+ #                 f"Expected L={out_h * out_w}, got {num_patches}. "
+ #                 f"Using unfolded L={num_patches} to determine output shape."
+ #             )
+ #             # Attempt recovery if possible, though might indicate upstream calculation error
+ #             # Find factors of num_patches close to expected out_h, out_w
+ #             # This part is tricky and might not always yield the desired shape.
+ #             # For simplicity, we'll rely on nn.Unfold's L and reshape.
+ #             # A more robust solution might re-calculate H_out, W_out based *only* on L.
+ #             # For now, let's stick to the reshape based on calculated out_h, out_w,
+ #             # assuming they match L. If they don't, the reshape will fail.
+ #             pass  # Proceed with calculated out_h, out_w
+
+ #         # Reshape using einops (or tensor.view)
+ #         # Input: (N, C_out, L) -> Output: (N, C_out, H_out, W_out)
+ #         output = rearrange(output_transposed, "n c (h w) -> n c h w", h=out_h, w=out_w)
+ #         # Alternative using view:
+ #         # output = output_transposed.view(N, self.out_channels, out_h, out_w)
+
+ #         return output
+
+ #     def extra_repr(self) -> str:
+ #         s = (
+ #             "{in_channels}, {out_channels}, kernel_size={kernel_size}"
+ #             ", stride={stride}"
+ #         )
+ #         if self.padding_mode != 0 and self.padding_mode != "valid":
+ #             s += ", padding={padding_mode}"
+ #         if self.dilation != (1,) * len(self.dilation):
+ #             s += ", dilation={dilation}"
+ #         # if self.groups != 1:  # Not implemented
+ #         #     s += ', groups={groups}'
+ #         if self.bias is False:
+ #             s += ", bias=False"
+ #         if self.linear_module != nn.Linear:
+ #             s += f", linear={self.linear.__name__}"
+ #         return s.format(**self.__dict__)
+
+
+ # class WhiteningConv(ConvLayer):
+ #     def __init__(
+ #         self,
+ #         in_channels: int,
+ #         kernel_size: int,
+ #         eigenvectors: torch.Tensor,
+ #         bias: bool = True,
+ #         linear_module: Type[nn.Module] = nn.Linear,
+ #     ):
+ #         """
+ #         We end up using a concatenation of the eigenvector tensor with its negation,
+ #         as the tendency to use e.g. ReLU in neural networks means that useful
+ #         data may otherwise be lost (if one orientation of an eigenvector produces
+ #         a strong negative signal, this will be clipped to zero by ReLU, but a
+ #         strong positive signal from the negation of the eigenvector will be
+ #         preserved). Assuming a square kernel, out channels is thus
+
+ #             (kernel_size ** 2) * in_channels * 2
+
+ #         where the trailing "* 2" accounts for the doubling of the size of the
+ #         eigenvector tensor we're using by including the negative of each eigenvector
+ #         as well.
+ #         """
+ #         out_channels = kernel_size**2 * in_channels * 2
+ #         super().__init__(
+ #             in_channels,
+ #             out_channels,
+ #             kernel_size,
+ #             padding="same",
+ #             bias=bias,
+ #             linear_module=linear_module,
+ #         )
+ #         self.eigenvectors = torch.cat([eigenvectors, -eigenvectors], dim=0)
+ #         # bias updates if `bias`=True but weight doesn't,
+ #         # per Jordan (2024) https://arxiv.org/abs/2404.00498
+ #         # but weight is set to `requires_grad = False`:
+ #         # self.kernel.weight.requires_grad = False
+ #         with torch.no_grad():
+ #             self.kernel.weight.copy_(self.eigenvectors)
+ #         assert self.kernel.weight.requires_grad
+
+
+ def spatial_tuple(size: Union[int, tuple], spatial_dimensions):
      """
-     A 2D Convolution layer implemented using torch.nn.Unfold and a custom linear layer.
-
-     This layer mimics the behavior of torch.nn.Conv2d but allows injecting
-     a different linear layer implementation for processing the unfolded patches.
-
-     Args:
-         in_channels (int): Number of channels in the input image.
-         out_channels (int): Number of channels produced by the convolution.
-         kernel_size (int or tuple): Size of the convolving kernel.
-         stride (int or tuple, optional): Stride of the convolution. Default: 1.
-         padding (int, tuple or str, optional): Padding added to all four sides
-             of the input. Can be an int, a tuple of two ints (padH, padW),
-             a tuple of four ints (padLeft, padRight, padTop, padBottom),
-             or the strings 'valid' (no padding) or 'same' (padding for same
-             output spatial dims as input). Default: 0 ('valid').
-         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1.
-         bias (bool, optional): If True, adds a learnable bias to the output.
-             The bias is handled by the underlying linear layer. Default: True.
-         linear (Type[nn.Module], optional): The class of the linear layer
-             to use for the kernel operation. Must accept (in_features, out_features, bias)
-             in its constructor. Defaults to torch.nn.Linear.
+     Converts an integer x to `tuple([x] * spatial_dimensions)`.
+     Performs no operation (i.e. the identity operation) on tuples of length `spatial_dimensions`.
+     Otherwise raises a ValueError.
      """
+     if isinstance(size, int):
+         return tuple([size] * spatial_dimensions)
+     elif isinstance(size, tuple) and (len(size) == spatial_dimensions):
+         return size
+     else:
+         raise ValueError(
+             f"For {spatial_dimensions} spatial dimensions, `size` must be "
+             f"an integer or a tuple of length {spatial_dimensions}."
+         )
+
+
+ def padding_tensor(padding: tuple):
+     """
+     Converts a tuple of ints (x, y, z) into a tuple of 2-tuples,
+     like ((x, x), (y, y), (z, z)).
+
+     Performs no operation (i.e. the identity operation) on a tuple of 2-tuples.
 
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: Union[int, Tuple[int, int]],
-         stride: Union[int, Tuple[int, int]] = 1,
-         padding: Union[
-             int, Tuple[int, int], Tuple[int, int, int, int], Literal["valid", "same"]
-         ] = 0,
-         dilation: Union[int, Tuple[int, int]] = 1,
-         bias: bool = True,
-         linear_module: Type[nn.Module] = nn.Linear,
+     Otherwise raises an error.
+     """
+     if all(isinstance(x, int) for x in padding):
+         return tuple([tuple([p] * 2) for p in padding])
+     elif (
+         all(isinstance(p, tuple) for p in padding)
+         and all(len(p) == 2 for p in padding)
+         and all(all(isinstance(x, int) for x in p) for p in padding)
      ):
-         super().__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.kernel_size = _pair(kernel_size)
-         self.stride = _pair(stride)
-         self.dilation = _pair(dilation)
-         self.bias = bias
-         self.linear_module = linear_module
-         self.padding_mode = (
-             padding  # Store the original padding mode ('same', 'valid', int, or tuple)
+         return padding
+     else:
+         raise ValueError(
+             "Padding must be a tuple of ints or a tuple of 2-tuples of ints. "
+             f"It was {padding}."
          )
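
These two small helpers normalise every size-like argument in the new code. A quick sketch of how they behave (a hypothetical REPL session; it assumes both functions are importable from broccoli.cnn):

    from broccoli.cnn import spatial_tuple, padding_tensor  # assumed import path

    spatial_tuple(3, 2)               # -> (3, 3)
    spatial_tuple((3, 5), 2)          # -> (3, 5), returned unchanged
    spatial_tuple((3, 5), 3)          # raises ValueError
    padding_tensor((1, 2))            # -> ((1, 1), (2, 2))
    padding_tensor(((0, 1), (2, 2)))  # -> ((0, 1), (2, 2)), returned unchanged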
 
-         # Calculate the number of input features for the linear layer
-         # It's the number of channels times the kernel area
-         self.linear_in_features = (
-             in_channels * self.kernel_size[0] * self.kernel_size[1]
-         )
 
-         # Instantiate the linear layer (kernel)
-         self.kernel = self.linear_module(
-             self.linear_in_features, out_channels, bias=bias
+ def kd_unfold(t: torch.Tensor, kernel_size=1, stride=1, padding=0, k=2):
+     """
+     Unfold operation with k spatial dimensions.
+     Does not support dilation.
+     Only supports equal padding at top and bottom.
+     """
+     if len(t.size()[2:]) != k:
+         raise ValueError(
+             f"Input tensor size should be (N, channels, spatial dims...), so "
+             f"for k = {k}, t.size() should be a tuple of length {k + 2}."
          )
 
-         # We will use F.pad for manual padding, so unfold padding is 0
-         self.unfold = nn.Unfold(
-             kernel_size=self.kernel_size,
-             dilation=self.dilation,
-             padding=0,  # Manual padding handled in forward
-             stride=self.stride,
-         )
+     N, C = t.size(0), t.size(1)
 
-         # Determine numeric padding values for F.pad
-         if isinstance(padding, str):
-             if padding not in ["valid", "same"]:
-                 raise ValueError("padding must be 'valid', 'same', an int, or a tuple")
-             # 'same' padding calculation depends on input size, defer to forward pass
-             # 'valid' padding means 0
-             self._padding_val = (
-                 (0, 0, 0, 0) if padding == "valid" else None
-             )  # None indicates 'same'
-         elif isinstance(padding, int):
-             self._padding_val = (padding,) * 4
-         elif isinstance(padding, tuple) and len(padding) == 2:
-             # (padH, padW) -> (padW_left, padW_right, padH_top, padH_bottom)
-             self._padding_val = (padding[1], padding[1], padding[0], padding[0])
-         elif isinstance(padding, tuple) and len(padding) == 4:
-             # (padLeft, padRight, padTop, padBottom) - already in F.pad format
-             self._padding_val = padding
-         else:
-             raise TypeError(
-                 "padding must be 'valid', 'same', an int, or a tuple of 2 or 4 ints"
-             )
-
-     def _calculate_output_shape(self, h_in: int, w_in: int) -> Tuple[int, int]:
-         """Calculates the output height and width."""
-         if self._padding_val is None:  # 'same' padding
-             # For 'same' padding, output size matches input size if stride is 1.
-             # If stride > 1, output size is ceil(input_size / stride)
-             # The _calculate_same_padding helper ensures this behavior.
-             oh = math.ceil(h_in / self.stride[0])
-             ow = math.ceil(w_in / self.stride[1])
-             return oh, ow
-         else:
-             # Use the standard formula with the calculated numeric padding
-             pad_h = self._padding_val[2] + self._padding_val[3]  # top + bottom
-             pad_w = self._padding_val[0] + self._padding_val[1]  # left + right
-             kh, kw = self.kernel_size
-             sh, sw = self.stride
-             dh, dw = self.dilation
-
-             eff_kh = (kh - 1) * dh + 1
-             eff_kw = (kw - 1) * dw + 1
-
-             oh = math.floor((h_in + pad_h - eff_kh) / sh + 1)
-             ow = math.floor((w_in + pad_w - eff_kw) / sw + 1)
-             return oh, ow
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         """
-         Performs the forward pass.
+     kernel_size = spatial_tuple(kernel_size, k)
+     stride = spatial_tuple(stride, k)
+     padding = padding_tensor(spatial_tuple(padding, k))
 
-         Args:
-             x (torch.Tensor): Input tensor of shape (N, C_in, H_in, W_in).
+     output = t
+     output = F.pad(output, sum(reversed(padding), ()))  # () is sum's start value; F.pad expects the last dimension's padding first
 
-         Returns:
-             torch.Tensor: Output tensor of shape (N, C_out, H_out, W_out).
-         """
-         _, C, H, W = x.shape
-         if C != self.in_channels:
-             raise ValueError(
-                 f"Input channels {C} does not match expected {self.in_channels}"
-             )
-
-         # 1. Calculate and Apply Padding
-         if self._padding_val is None:  # 'same' padding mode
-             pad_l, pad_r, pad_t, pad_b = _calculate_same_padding(
-                 (H, W), self.kernel_size, self.stride, self.dilation
-             )
-             padded_x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
-             # Update H, W for output shape calculation after padding
-             # Note: _calculate_output_shape will correctly handle 'same' based on original H, W
-         elif self._padding_val != (0, 0, 0, 0):
-             padded_x = F.pad(x, self._padding_val)
-         else:  # No padding ('valid' or explicit 0)
-             padded_x = x
-
-         # 2. Unfold to extract patches
-         # Input: (N, C_in, H_pad, W_pad)
-         # Output: (N, C_in * K_h * K_w, L), where L is the number of patches (H_out * W_out)
-         patches = self.unfold(padded_x)
-         num_patches = patches.shape[-1]  # L
-
-         # 3. Reshape for the linear layer
-         # We want (N, L, C_in * K_h * K_w) to apply the linear layer patch-wise
-         # transpose switches the last two dimensions
-         patches_transposed = patches.transpose(1, 2)  # Shape: (N, L, C_in * K_h * K_w)
-
-         # 4. Apply the linear layer (kernel) to each patch
-         # Input: (N, L, linear_in_features)
-         # Output: (N, L, out_channels)
-         linear_output = self.kernel(patches_transposed)
-
-         # 5. Reshape back to image format
-         # We need (N, out_channels, L) first
-         output_transposed = linear_output.transpose(1, 2)  # Shape: (N, out_channels, L)
-
-         # Calculate output spatial dimensions
-         out_h, out_w = self._calculate_output_shape(H, W)  # Use original H, W
-
-         # Check if the number of patches matches the calculated output dimensions
-         if num_patches != out_h * out_w:
-             # This might happen with certain combinations of stride/padding/dilation/input size
-             # if the calculation logic has an issue. nn.Unfold is usually robust.
-             print(
-                 f"Warning: Mismatch in calculated patches. "
-                 f"Expected L={out_h * out_w}, got {num_patches}. "
-                 f"Using unfolded L={num_patches} to determine output shape."
-             )
-             # Attempt recovery if possible, though might indicate upstream calculation error
-             # Find factors of num_patches close to expected out_h, out_w
-             # This part is tricky and might not always yield the desired shape.
-             # For simplicity, we'll rely on nn.Unfold's L and reshape.
-             # A more robust solution might re-calculate H_out, W_out based *only* on L.
-             # For now, let's stick to the reshape based on calculated out_h, out_w,
-             # assuming they match L. If they don't, the reshape will fail.
-             pass  # Proceed with calculated out_h, out_w
-
-         # Reshape using einops (or tensor.view)
-         # Input: (N, C_out, L) -> Output: (N, C_out, H_out, W_out)
-         output = rearrange(output_transposed, "n c (h w) -> n c h w", h=out_h, w=out_w)
-         # Alternative using view:
-         # output = output_transposed.view(N, self.out_channels, out_h, out_w)
-
-         return output
-
-     def extra_repr(self) -> str:
-         s = (
-             "{in_channels}, {out_channels}, kernel_size={kernel_size}"
-             ", stride={stride}"
-         )
-         if self.padding_mode != 0 and self.padding_mode != "valid":
-             s += ", padding={padding_mode}"
-         if self.dilation != (1,) * len(self.dilation):
-             s += ", dilation={dilation}"
-         # if self.groups != 1:  # Not implemented
-         #     s += ', groups={groups}'
-         if self.bias is False:
-             s += ", bias=False"
-         if self.linear_module != nn.Linear:
-             s += f", linear={self.linear.__name__}"
-         return s.format(**self.__dict__)
-
-
- class WhiteningConv(ConvLayer):
-     def __init__(
-         self,
-         in_channels: int,
-         kernel_size: int,
-         eigenvectors: torch.Tensor,
-         bias: bool = True,
-         linear_module: Type[nn.Module] = nn.Linear,
-     ):
-         """
-         We end up using a concatenation of the eigenvector tensor with its negation,
-         as the tendency to use e.g. ReLU in neural networks means that useful
-         data may otherwise be lost (if one orientation of an eigenvector produces
-         a strong negative signal, this will be clipped to zero by ReLU, but a
-         strong positive signal from the negation of the eigenvector will be
-         preserved). Assuming a square kernel, out channels is thus
-
-             (kernel_size ** 2) * in_channels * 2
-
-         where the trailing "* 2" accounts for the doubling of the size of the
-         eigenvector tensor we're using by including the negative of each eigenvector
-         as well.
-         """
-         out_channels = kernel_size**2 * in_channels * 2
-         super().__init__(
-             in_channels,
-             out_channels,
-             kernel_size,
-             padding="same",
-             bias=bias,
-             linear_module=linear_module,
+     for i, _ in enumerate(kernel_size):
+         output = output.unfold(i + 2, kernel_size[i], stride[i])
+
+     permutation = [0, 1] + [i + k + 2 for i in range(k)] + [i + 2 for i in range(k)]
+
+     return output.permute(*permutation).reshape(N, math.prod(kernel_size) * C, -1)
+
+
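
For k = 2, kd_unfold is meant to line up with torch.nn.Unfold (dilation fixed at 1): both emit patches as (N, C * prod(kernel_size), L) with channel-major ordering inside each patch. A minimal equivalence check (a sketch; it assumes kd_unfold is importable from broccoli.cnn):

    import torch
    from torch import nn
    from broccoli.cnn import kd_unfold  # assumed import path

    x = torch.randn(2, 3, 8, 8)
    ours = kd_unfold(x, kernel_size=2, stride=2, padding=1, k=2)
    ref = nn.Unfold(kernel_size=2, stride=2, padding=1)(x)
    assert ours.shape == (2, 12, 25)   # C*kh*kw = 12, L = 5*5 patches
    assert torch.allclose(ours, ref)

The payoff of reimplementing unfold via repeated Tensor.unfold calls is the k parameter: the same function handles 1D and 3D (volumetric) inputs, which nn.Unfold does not.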
+ def calculate_output_spatial_size(
+     input_spatial_size, kernel_size=1, stride=1, padding=0, dilation=1
+ ):
+     """
+     Calculate the output size for the spatial dimensions of a convolutional operation.
+     """
+     stride = spatial_tuple(stride, len(input_spatial_size))
+
+     # Handle padding keywords that are sometimes used
+     if padding == "same":
+         output_size = ()
+         for i, in_length in enumerate(input_spatial_size):
+             output_size += (math.ceil(in_length / stride[i]),)
+         return output_size
+     elif padding == "valid":
+         padding = 0
+
+     kernel_size = spatial_tuple(kernel_size, len(input_spatial_size))
+     padding = spatial_tuple(padding, len(input_spatial_size))
+     dilation = spatial_tuple(dilation, len(input_spatial_size))
+
+     output_size = ()
+
+     for i, in_length in enumerate(input_spatial_size):
+         output_size += (
+             math.floor(
+                 (in_length + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1)
+                 / stride[i]
+                 + 1
+             ),
          )
-         self.eigenvectors = torch.cat([eigenvectors, -eigenvectors], dim=0)
-         # bias updates if `bias`=True but weight doesn't,
-         # per Jordan (2024) https://arxiv.org/abs/2404.00498
-         # but weight is set to `requires_grad = False`:
-         # self.kernel.weight.requires_grad = False
-         with torch.no_grad():
-             self.kernel.weight.copy_(self.eigenvectors)
-         assert self.kernel.weight.requires_grad
+     return output_size
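
As a sanity check on the formula, a 3x3 kernel with stride 2 and padding 1 on a 32x32 input gives floor((32 + 2*1 - 1*(3 - 1) - 1) / 2 + 1) = 16 per dimension. A sketch (assuming the function is importable from broccoli.cnn):

    from broccoli.cnn import calculate_output_spatial_size  # assumed import path

    calculate_output_spatial_size((32, 32), kernel_size=3, stride=2, padding=1, dilation=1)
    # -> (16, 16)
    calculate_output_spatial_size((32, 32), kernel_size=3, stride=2, padding="same")
    # -> (16, 16), i.e. ceil(32 / 2) per dimension; 'same' ignores kernel size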
 
 
- class ConcatPool(nn.Module):
+ class SpaceToDepth(nn.Module):
      """
-     A "pooling" layer that extracts patches from an image-like tensor and stacks
+     An operation that extracts patches from an image-like tensor and stacks
      them channel-wise.
      """
 
-     # TODO: change this to use nn.Fold instead of view, which is equivalent but more readable
+     def __init__(self, kernel_size, stride=1, padding=0, spatial_dimensions=2):
+         """
+         Input shape should be in order (channels, spatial dims...),
+         e.g. (channels, height, width)
+         """
 
-     def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
          super().__init__()
 
-         # Ensure kernel_size, stride, etc. are tuples
-         self.kernel_size = (
-             (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
-         )
-         self.stride = (stride, stride) if isinstance(stride, int) else stride
-         self.padding = (padding, padding) if isinstance(padding, int) else padding
-         self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.padding = padding
+         self.spatial_dimensions = spatial_dimensions
+
+     def forward(self, x):
 
-         # The core patch extraction layer
-         self.unfold = nn.Unfold(
+         N, C, *input_spatial_size = x.size()
+
+         patches = kd_unfold(
+             x,
              kernel_size=self.kernel_size,
-             dilation=self.dilation,
-             padding=self.padding,
              stride=self.stride,
+             padding=self.padding,
+             k=self.spatial_dimensions,
          )
 
-     def forward(self, x):
-         # Input shape: (N, C_in, H_in, W_in)
-         N, C_in, H_in, W_in = x.shape
-
-         # 1. Unfold the image to extract patches
-         # Output shape: (N, C_in * k * k, L)
-         # where L is the number of patches, L = H_out * W_out
-         patches = self.unfold(x)
-
-         # New channel dimension
-         C_out = C_in * self.kernel_size[0] * self.kernel_size[1]
-
-         # 2. Calculate the output spatial dimensions
-         H_out = math.floor(
-             (
-                 H_in
-                 + 2 * self.padding[0]
-                 - self.dilation[0] * (self.kernel_size[0] - 1)
-                 - 1
-             )
-             / self.stride[0]
-             + 1
-         )
-         W_out = math.floor(
-             (
-                 W_in
-                 + 2 * self.padding[1]
-                 - self.dilation[1] * (self.kernel_size[1] - 1)
-                 - 1
-             )
-             / self.stride[1]
-             + 1
+         output_spatial_size = calculate_output_spatial_size(
+             input_spatial_size=input_spatial_size,
+             kernel_size=self.kernel_size,
+             stride=self.stride,
+             padding=self.padding,
+             dilation=1,  # kd_unfold doesn't support dilation
          )
 
-         # 3. Reshape to the final 4D tensor
-         # (N, C_in * k * k, L) -> (N, C_out, H_out, W_out)
-         out = patches.view(N, C_out, H_out, W_out)
+         output_channels = C * math.prod(
+             spatial_tuple(self.kernel_size, self.spatial_dimensions)
+         )
 
-         return out
+         return patches.view(N, output_channels, *output_spatial_size)
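
Taken together, SpaceToDepth (which replaces the removed ConcatPool) with kernel_size=2 and stride=2 performs the familiar space-to-depth rearrangement: each non-overlapping 2x2 patch is stacked channel-wise. A usage sketch (assuming the class is importable from broccoli.cnn):

    import torch
    from broccoli.cnn import SpaceToDepth  # assumed import path

    layer = SpaceToDepth(kernel_size=2, stride=2)
    x = torch.randn(1, 3, 8, 8)
    y = layer(x)
    print(y.shape)  # torch.Size([1, 12, 4, 4])

Unlike the old ConcatPool, the constructor takes spatial_dimensions rather than dilation, so the same module can be applied to 1D, 2D or 3D inputs via kd_unfold.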