broccoli-ml 0.6.0-py3-none-any.whl → 0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -223,7 +223,7 @@ class MHAttention(nn.Module):
 
 class FeedforwardLayer(nn.Module):
     """
-    A denoising autoencoder, of the type used in transformer blocks.
+    ...
     """
 
     def __init__(
@@ -247,6 +247,7 @@ class FeedforwardLayer(nn.Module):
 
         self.process = nn.Sequential(
             *[
+                nn.LayerNorm(input_features),
                 linear_module(
                     input_features,
                     (
@@ -256,8 +257,8 @@ class FeedforwardLayer(nn.Module):
                     ),
                 ),
                 self.activation,
-                self.dropout,
                 linear_module(ratio * input_features, output_features),
+                self.dropout,
             ]
         )
 
@@ -323,25 +324,14 @@ class TransformerBlock(nn.Module):
         )
 
         # Submodules for the feedforward process
-        self.ff_process = nn.Sequential(
-            OrderedDict(
-                [
-                    ("layer_norm", nn.LayerNorm(d_model)),
-                    (
-                        "denoising_autoencoder",
-                        FeedforwardLayer(
-                            d_model,
-                            mlp_ratio,
-                            d_model,
-                            activation=activation,
-                            activation_kwargs=activation_kwargs,
-                            dropout=0.0,
-                            linear_module=linear_module,
-                        ),
-                    ),
-                    ("dropout", nn.Dropout(mlp_dropout)),
-                ]
-            )
+        self.ff = FeedforwardLayer(
+            d_model,
+            mlp_ratio,
+            d_model,
+            activation=activation,
+            activation_kwargs=activation_kwargs,
+            dropout=mlp_dropout,
+            linear_module=linear_module,
         )
 
     @property
@@ -366,7 +356,7 @@ class TransformerBlock(nn.Module):
         process_x = process_x + self.attn(
             norm_process_x, norm_process_x, norm_process_x
         )
-        process_x = process_x + self.ff_process(process_x)
+        process_x = process_x + self.ff(process_x)
        x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
 
        return x
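Taken together, these hunks fold the pre-normalization and dropout into `FeedforwardLayer` itself: the layer now opens with `nn.LayerNorm`, dropout moves after the second projection, and `TransformerBlock` applies the layer directly as a residual branch (`self.ff`) instead of wrapping it in an `OrderedDict`-based `nn.Sequential`. A minimal sketch of the resulting wiring, assuming plain `nn.Linear` for `linear_module` and `nn.GELU` for the configurable activation (not the package's actual code):

```python
import torch
import torch.nn as nn

class FeedforwardSketch(nn.Module):
    """Sketch of the 0.8.0 FeedforwardLayer ordering (illustrative only)."""

    def __init__(self, input_features, ratio, output_features, dropout=0.1):
        super().__init__()
        self.process = nn.Sequential(
            nn.LayerNorm(input_features),  # norm now lives inside the layer
            nn.Linear(input_features, ratio * input_features),
            nn.GELU(),  # stand-in for the configurable activation
            nn.Linear(ratio * input_features, output_features),
            nn.Dropout(dropout),  # dropout moved after the second projection
        )

    def forward(self, x):
        return self.process(x)

x = torch.randn(2, 16, 64)  # (batch, tokens, d_model)
out = x + FeedforwardSketch(64, 4, 64)(x)  # residual use, as in TransformerBlock.forward
```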
broccoli/vit.py CHANGED
@@ -66,8 +66,9 @@ class ViTEncoder(nn.Module):
     def __init__(
         self,
         input_size=(32, 32),
+        cnn=True,
         cnn_in_channels=3,
-        minimum_cnn_out_channels=16,
+        cnn_out_channels=16,
         cnn_kernel_size=3,
         cnn_kernel_stride=1,
         cnn_padding="same",
@@ -135,12 +136,49 @@ class ViTEncoder(nn.Module):
                 "`input_size` must be a tuple of length 1, 2, or 3."
             )
 
-        cnn_output_size = calculate_output_spatial_size(
-            input_size,
-            kernel_size=cnn_kernel_size,
-            stride=cnn_kernel_stride,
-            padding=cnn_padding,
-            dilation=cnn_kernel_dilation,
+        if cnn:
+            cnn_output_size = calculate_output_spatial_size(
+                input_size,
+                kernel_size=cnn_kernel_size,
+                stride=cnn_kernel_stride,
+                padding=cnn_padding,
+                dilation=cnn_kernel_dilation,
+            )
+            self.cnn = convxd(
+                cnn_in_channels,
+                cnn_out_channels,
+                cnn_kernel_size,
+                stride=cnn_kernel_stride,
+                padding=cnn_padding,
+                dilation=cnn_kernel_dilation,
+                groups=cnn_kernel_groups,
+                bias=True,
+                padding_mode="zeros",
+            )
+            cnn_activation_out_channels = cnn_out_channels
+            self.activate_and_dropout = nn.Sequential(
+                *[
+                    Rearrange(  # rearrange in case we're using XGLU activation
+                        f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
+                    ),
+                    self.cnn_activation,
+                    Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
+                    nn.Dropout(cnn_dropout),
+                    batchnormxd(cnn_activation_out_channels),
+                ]
+            )
+            # This block rhymes:
+            if cnn and cnn_activation.__name__.endswith("GLU"):
+                cnn_out_channels *= 2
+        else:
+            self.cnn = nn.Identity()
+            self.activate_and_dropout = nn.Identity()
+            cnn_output_size = input_size
+            cnn_out_channels = cnn_in_channels
+            cnn_activation_out_channels = cnn_in_channels
+
+        pooling_kernel_voxels = math.prod(
+            spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
         )
 
         pooling_output_size = (
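The `endswith("GLU")` check doubles the stem's output channels because GLU-family activations gate one half of the channel dimension with the other half, halving it on the way through; the convolution therefore has to emit twice as many channels for the activation to land back on `cnn_out_channels`. A one-off illustration with the built-in `torch.nn.GLU` (the package's XGLU variants are assumed to behave the same way on the channel axis):

```python
import torch
import torch.nn as nn

glu = nn.GLU(dim=-1)        # gates the first half of the last dim with the second half
x = torch.randn(4, 10, 32)  # channels-last tensor with 32 = 2 * 16 channels in
print(glu(x).shape)         # torch.Size([4, 10, 16]) -- GLU halves the channel dim
```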
@@ -155,59 +193,8 @@ class ViTEncoder(nn.Module):
             )
         )
 
-        self.sequence_length = math.prod(pooling_output_size)  # One token per voxel
-
-        pooling_kernel_voxels = math.prod(
-            spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
-        )
-
-        if pooling_type in ["max", "average", None]:
-            cnn_out_channels = transformer_embedding_size
-        elif pooling_type == "concat":
-            cnn_out_channels = max(
-                math.floor(transformer_embedding_size / pooling_kernel_voxels),
-                minimum_cnn_out_channels,
-            )
-        else:
-            raise NotImplementedError(
-                "Pooling type must be max, average, concat or None"
-            )
-
-        cnn_activation_out_channels = cnn_out_channels
-
-        # This block rhymes:
-        if cnn_activation.__name__.endswith("GLU"):
-            cnn_out_channels *= 2
-
-        self.cnn = convxd(
-            cnn_in_channels,
-            cnn_out_channels,
-            cnn_kernel_size,
-            stride=cnn_kernel_stride,
-            padding=cnn_padding,
-            dilation=cnn_kernel_dilation,
-            groups=cnn_kernel_groups,
-            bias=True,
-            padding_mode="zeros",
-        )
-
-        self.activate_and_dropout = nn.Sequential(
-            *[
-                Rearrange(  # rearrange in case we're using XGLU activation
-                    f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
-                ),
-                self.cnn_activation,
-                Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
-                nn.Dropout(cnn_dropout),
-                (
-                    batchnormxd(cnn_activation_out_channels)
-                    if initial_batch_norm
-                    else nn.Identity()
-                ),
-            ]
-        )
-
         if pooling_type is None:
+            pooling_out_channels = cnn_activation_out_channels
             self.pool = nn.Sequential(
                 *[
                     Rearrange(
@@ -215,70 +202,36 @@ class ViTEncoder(nn.Module):
                     ),  # for transformer
                 ]
             )
-            pooling_out_channels = transformer_embedding_size
 
         elif pooling_type == "max":
-            self.pool = nn.Sequential(
-                *[
-                    maxpoolxd(
-                        pooling_kernel_size,
-                        stride=pooling_kernel_stride,
-                        padding=pooling_padding,
-                    ),
-                    Rearrange(
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ),  # for transformer
-                ]
+            pooling_out_channels = cnn_activation_out_channels
+            self.pool = maxpoolxd(
+                pooling_kernel_size,
+                stride=pooling_kernel_stride,
+                padding=pooling_padding,
             )
-            pooling_out_channels = transformer_embedding_size
-
         elif pooling_type == "average":
-            self.pool = nn.Sequential(
-                *[
-                    avgpoolxd(
-                        pooling_kernel_size,
-                        stride=pooling_kernel_stride,
-                        padding=pooling_padding,
-                    ),
-                    Rearrange(
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ),  # for transformer
-                ]
+            pooling_out_channels = cnn_activation_out_channels
+            self.pool = avgpoolxd(
+                pooling_kernel_size,
+                stride=pooling_kernel_stride,
+                padding=pooling_padding,
             )
-            pooling_out_channels = transformer_embedding_size
-
         elif pooling_type == "concat":
-
-            if transformer_activation_kwargs is not None:
-                self.concatpool_activation = transformer_activation(
-                    **transformer_activation_kwargs
-                )
-            else:
-                self.concatpool_activation = transformer_activation()
-
             pooling_out_channels = pooling_kernel_voxels * cnn_activation_out_channels
-
-            self.pool = nn.Sequential(
-                *[
-                    SpaceToDepth(
-                        pooling_kernel_size,
-                        stride=pooling_kernel_stride,
-                        padding=pooling_padding,
-                        spatial_dimensions=self.spatial_dimensions,
-                    ),
-                    Rearrange(  # for transformer
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ),
-                    (
-                        PadTensor(
-                            (0, transformer_embedding_size - pooling_out_channels)
-                        )
-                        if not intermediate_feedforward_layer
-                        else nn.Identity()
-                    ),
-                ]
+            self.pool = SpaceToDepth(
+                pooling_kernel_size,
+                stride=pooling_kernel_stride,
+                padding=pooling_padding,
+                spatial_dimensions=self.spatial_dimensions,
+            )
+        else:
+            raise NotImplementedError(
+                "Pooling type must be max, average, concat or None"
             )
 
+        self.sequence_length = math.prod(pooling_output_size)  # One token per voxel
+
         if transformer_layers > 0:
             self.transformer = TransformerEncoder(
                 self.sequence_length,
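For `pooling_type="concat"`, the per-token channel count after pooling is `pooling_kernel_voxels * cnn_activation_out_channels`, because space-to-depth stacks every voxel in a kernel window into the channel axis. A quick check of that arithmetic, using `torch.nn.PixelUnshuffle` as a hypothetical 2D stand-in for the package's `SpaceToDepth`:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 16, 32, 32)    # N C H W: 16 channels out of the CNN stem
pooled = nn.PixelUnshuffle(2)(x)  # 2x2 window -> 4 voxels -> 16 * 4 = 64 channels
print(pooled.shape)               # torch.Size([1, 64, 16, 16])
```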
@@ -300,25 +253,43 @@ class ViTEncoder(nn.Module):
         else:
             self.transformer = nn.Identity()
 
+        if intermediate_feedforward_layer:
+            self.pooling_channels_padding = nn.Identity()
+            self.intermediate_feedforward_layer = FeedforwardLayer(
+                pooling_out_channels,
+                transformer_mlp_ratio,
+                transformer_embedding_size,
+                activation=transformer_activation,
+                activation_kwargs=transformer_activation_kwargs,
+                dropout=transformer_mlp_dropout,
+                linear_module=linear_module,
+            )
+        elif pooling_out_channels < transformer_embedding_size:
+            self.intermediate_feedforward_layer = nn.Identity()
+            self.pooling_channels_padding = PadTensor(
+                (0, transformer_embedding_size - pooling_out_channels)
+            )
+        else:
+            raise NotImplementedError(
+                "In a situation where the choice/parameters of the pooling and the"
+                + " `cnn_out_channels` (or the number of `input_channels` if"
+                + " `cnn`=False) means that the pooling will result"
+                + " in more channels per pixel/voxel than the size of the"
+                + " intended transformer embedding,"
+                + " `intermediate_feedforward_layer` must be set to True"
+            )
+
         self.encoder = nn.Sequential(
             *[
                 batchnormxd(cnn_in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
                 self.activate_and_dropout,
                 self.pool,
-                (
-                    FeedforwardLayer(
-                        pooling_out_channels,
-                        transformer_mlp_ratio,
-                        transformer_embedding_size,
-                        activation=transformer_activation,
-                        activation_kwargs=transformer_activation_kwargs,
-                        dropout=transformer_mlp_dropout,
-                        linear_module=linear_module,
-                    )
-                    if intermediate_feedforward_layer
-                    else nn.Identity()
+                Rearrange(  # for transformer
+                    f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
+                self.pooling_channels_padding,
+                self.intermediate_feedforward_layer,
                 self.transformer,
             ]
         )
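The new block makes the channel-matching step explicit and moves it out of the pooling branches: either an intermediate `FeedforwardLayer` projects the pooled channels up to `transformer_embedding_size`, or zero-padding makes up a shortfall, or construction fails because neither would fit. The padding branch amounts to the following (a sketch with `torch.nn.functional.pad`, on the assumption that `PadTensor` wraps the same call):

```python
import torch
import torch.nn.functional as F

tokens = torch.randn(1, 256, 64)       # (batch, sequence, 64 pooled channels)
padded = F.pad(tokens, (0, 256 - 64))  # zero-pad the channel dim up to 256
print(padded.shape)                    # torch.Size([1, 256, 256])
```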
@@ -339,6 +310,7 @@ class CCT(nn.Module):
     def __init__(
         self,
         input_size=(32, 32),
+        cnn=True,
         cnn_in_channels=3,
         minimum_cnn_out_channels=16,
         cnn_kernel_size=3,
@@ -391,6 +363,7 @@ class CCT(nn.Module):
 
         self.encoder = ViTEncoder(
             input_size=input_size,
+            cnn=cnn,
             cnn_in_channels=cnn_in_channels,
             minimum_cnn_out_channels=minimum_cnn_out_channels,
             cnn_kernel_size=cnn_kernel_size,
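`CCT` simply grows a matching `cnn` flag and forwards it to `ViTEncoder`. A hypothetical minimal call exercising the new switch (only arguments visible in this diff are used; everything else is left at its defaults):

```python
from broccoli.vit import ViTEncoder

# cnn=False swaps the convolutional stem for nn.Identity(), so tokens are built
# directly from the raw input channels (see the `else:` branch in the diff above).
encoder = ViTEncoder(input_size=(32, 32), cnn=False)
```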
broccoli_ml-0.6.0.dist-info/METADATA → broccoli_ml-0.8.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.6.0
+Version: 0.8.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-0.6.0.dist-info/RECORD → broccoli_ml-0.8.0.dist-info/RECORD CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
-broccoli/transformer.py,sha256=SwvutiYOiPlqLzbO_twye7Hna5DsJukVOzzAx9CTCyU,16417
+broccoli/transformer.py,sha256=RSZpbHs_K4ts5os6lWxcGDI3p0zreRwQNnk6mV8HJnk,15930
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=_oL0NRUJakyIke2g8WK5eWaiEh06gAhI67l6Wl7k1oM,15659
-broccoli_ml-0.6.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.6.0.dist-info/METADATA,sha256=b0RyaSofwkIM2H86MeOop_-VVJMkEBuJdTxFqCNAahY,1256
-broccoli_ml-0.6.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.6.0.dist-info/RECORD,,
+broccoli/vit.py,sha256=qNavAe_jNlslYlLsXmScHTnLuL3-MAVCAhMsJt3v5Rg,15209
+broccoli_ml-0.8.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.8.0.dist-info/METADATA,sha256=LVv2HQIjuvpHEH9uixMzbdE-VeJ44E2TizerorwHYS0,1256
+broccoli_ml-0.8.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.8.0.dist-info/RECORD,,