broccoli-ml 0.21.1__py3-none-any.whl → 0.23.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -236,7 +236,7 @@ class FeedforwardBlock(nn.Module):
236
236
  activation_kwargs=None,
237
237
  dropout=0.0,
238
238
  linear_module=nn.Linear,
239
- regularise_values=True,
239
+ sigma_reparam=False,
240
240
  ):
241
241
  super().__init__()
242
242
 
@@ -253,25 +253,17 @@ class FeedforwardBlock(nn.Module):
253
253
  else ratio * output_features
254
254
  )
255
255
 
256
- if regularise_values:
256
+ if sigma_reparam:
257
257
  self.memory_type = SpectralNormLinear
258
258
  else:
259
- self.memory_type = nn.Linear
259
+ self.memory_type = linear_module
260
260
 
261
261
  self.process = nn.Sequential(
262
262
  *[
263
- (
264
- nn.LayerNorm(input_features)
265
- if not regularise_values
266
- else nn.Identity()
267
- ),
263
+ nn.LayerNorm(input_features),
268
264
  linear_module(input_features, self.max_features),
269
265
  self.activation,
270
- (
271
- nn.LayerNorm(input_features)
272
- if not regularise_values
273
- else nn.Identity()
274
- ),
266
+ nn.LayerNorm(ratio * output_features),
275
267
  self.memory_type(ratio * output_features, output_features),
276
268
  self.dropout,
277
269
  ]
broccoli/vit.py CHANGED
@@ -18,7 +18,10 @@ class PadTensor(nn.Module):
18
18
  self.kwargs = kwargs
19
19
 
20
20
  def forward(self, x):
21
- return F.pad(x, *self.args, **self.kwargs)
21
+ if sum(self.args[0]) == 0:
22
+ return x
23
+ else:
24
+ return F.pad(x, *self.args, **self.kwargs)
22
25
 
23
26
 
24
27
  class GetCLSToken(nn.Module):
@@ -97,9 +100,9 @@ class ViTEncoder(nn.Module):
97
100
  def __init__(
98
101
  self,
99
102
  input_size=(32, 32),
103
+ in_channels=3,
100
104
  initial_batch_norm=True,
101
105
  cnn=True,
102
- cnn_in_channels=3,
103
106
  cnn_out_channels=16,
104
107
  cnn_kernel_size=3,
105
108
  cnn_kernel_stride=1,
@@ -113,7 +116,7 @@ class ViTEncoder(nn.Module):
113
116
  pooling_kernel_size=3,
114
117
  pooling_kernel_stride=2,
115
118
  pooling_padding=1,
116
- intermediate_feedforward_layer=True,
119
+ transformer_feedforward_first=True,
117
120
  transformer_position_embedding="relative", # absolute or relative
118
121
  transformer_embedding_size=256,
119
122
  transformer_layers=7,
@@ -180,7 +183,7 @@ class ViTEncoder(nn.Module):
180
183
  dilation=cnn_kernel_dilation,
181
184
  )
182
185
  self.cnn = convxd(
183
- cnn_in_channels,
186
+ in_channels,
184
187
  cnn_out_channels,
185
188
  cnn_kernel_size,
186
189
  stride=cnn_kernel_stride,
@@ -206,8 +209,8 @@ class ViTEncoder(nn.Module):
206
209
  self.cnn = nn.Identity()
207
210
  self.activate_and_dropout = nn.Identity()
208
211
  cnn_output_size = input_size
209
- cnn_out_channels = cnn_in_channels
210
- cnn_activation_out_channels = cnn_in_channels
212
+ cnn_out_channels = in_channels
213
+ cnn_activation_out_channels = in_channels
211
214
 
212
215
  pooling_kernel_voxels = math.prod(
213
216
  spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
@@ -262,6 +265,10 @@ class ViTEncoder(nn.Module):
262
265
  "Pooling type must be max, average, concat or None"
263
266
  )
264
267
 
268
+ self.pooling_channels_padding = PadTensor(
269
+ (0, max(0, transformer_embedding_size - pooling_out_channels))
270
+ )
271
+
265
272
  self.sequence_length = math.prod(pooling_output_size) # One token per voxel
266
273
 
267
274
  if transformer_layers > 0:
@@ -286,38 +293,23 @@ class ViTEncoder(nn.Module):
286
293
  else:
287
294
  self.transformer = nn.Identity()
288
295
 
289
- if intermediate_feedforward_layer:
290
- self.pooling_channels_padding = nn.Identity()
291
- self.intermediate_feedforward_layer = FeedforwardBlock(
292
- pooling_out_channels,
296
+ if transformer_feedforward_first:
297
+ self.initial_ff = FeedforwardBlock(
298
+ transformer_embedding_size,
293
299
  transformer_mlp_ratio,
294
300
  transformer_embedding_size,
295
301
  activation=transformer_activation,
296
302
  activation_kwargs=transformer_activation_kwargs,
297
303
  dropout=transformer_mlp_dropout,
298
304
  linear_module=linear_module,
299
- )
300
- elif pooling_out_channels == transformer_embedding_size:
301
- self.intermediate_feedforward_layer = nn.Identity()
302
- self.pooling_channels_padding = nn.Identity()
303
- elif pooling_out_channels < transformer_embedding_size:
304
- self.intermediate_feedforward_layer = nn.Identity()
305
- self.pooling_channels_padding = PadTensor(
306
- (0, transformer_embedding_size - pooling_out_channels)
305
+ sigma_reparam=True,
307
306
  )
308
307
  else:
309
- raise NotImplementedError(
310
- "In a situation where the choice/parameters of the pooling and the"
311
- + " `cnn_out_channels` (or the number of `input_channels` if"
312
- + " `cnn`=False) means that the pooling will result"
313
- + " in more channels per pixel/voxel than the size of the"
314
- + " intended transformer embedding,"
315
- + " `intermediate_feedforward_layer` must be set to True"
316
- )
308
+ self.initial_ff = nn.Identity()
317
309
 
318
310
  self.encoder = nn.Sequential(
319
311
  *[
320
- batchnormxd(cnn_in_channels) if initial_batch_norm else nn.Identity(),
312
+ batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
321
313
  self.cnn,
322
314
  self.activate_and_dropout,
323
315
  self.pool,
@@ -325,7 +317,7 @@ class ViTEncoder(nn.Module):
325
317
  f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
326
318
  ),
327
319
  self.pooling_channels_padding,
328
- self.intermediate_feedforward_layer,
320
+ self.initial_ff,
329
321
  self.transformer,
330
322
  ]
331
323
  )
@@ -347,9 +339,9 @@ class ViT(nn.Module):
347
339
  self,
348
340
  input_size=(32, 32),
349
341
  image_classes=100,
342
+ in_channels=3,
350
343
  initial_batch_norm=True,
351
344
  cnn=True,
352
- cnn_in_channels=3,
353
345
  cnn_out_channels=16,
354
346
  cnn_kernel_size=3,
355
347
  cnn_kernel_stride=1,
@@ -363,7 +355,7 @@ class ViT(nn.Module):
363
355
  pooling_kernel_size=3,
364
356
  pooling_kernel_stride=2,
365
357
  pooling_padding=1,
366
- intermediate_feedforward_layer=True,
358
+ transformer_feedforward_first=True,
367
359
  transformer_position_embedding="relative", # absolute or relative
368
360
  transformer_embedding_size=256,
369
361
  transformer_layers=7,
@@ -402,8 +394,8 @@ class ViT(nn.Module):
402
394
  self.encoder = ViTEncoder(
403
395
  input_size=input_size,
404
396
  initial_batch_norm=initial_batch_norm,
397
+ in_channels=in_channels,
405
398
  cnn=cnn,
406
- cnn_in_channels=cnn_in_channels,
407
399
  cnn_out_channels=cnn_out_channels,
408
400
  cnn_kernel_size=cnn_kernel_size,
409
401
  cnn_kernel_stride=cnn_kernel_stride,
@@ -417,7 +409,7 @@ class ViT(nn.Module):
417
409
  pooling_kernel_size=pooling_kernel_size,
418
410
  pooling_kernel_stride=pooling_kernel_stride,
419
411
  pooling_padding=pooling_padding,
420
- intermediate_feedforward_layer=intermediate_feedforward_layer,
412
+ transformer_feedforward_first=transformer_feedforward_first,
421
413
  transformer_position_embedding=transformer_position_embedding,
422
414
  transformer_embedding_size=transformer_embedding_size,
423
415
  transformer_layers=transformer_layers,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.21.1
3
+ Version: 0.23.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=g8YrxNl6g_WcHrWVmbaBHJU5hv6daFS0r4TxAoPJ9UE,3012
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=MUvXtwD2f1sPTBym4FB0x_ZfsJUBNLgULUlN8btV8GI,1943
11
- broccoli/transformer.py,sha256=mFmQiuOu94yKVPZVsd0Cv5FGqkGXOX6zkGAQLqjThKg,16333
11
+ broccoli/transformer.py,sha256=NxOHP-XQRCtoiiTh7WJWNvSjpZzamiqQU966nQh5vhQ,16091
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=vQtOcC0Dd8y6PTWx0xCbnE4ymYkL_HfYrerqaJ0hs1k,16404
14
- broccoli_ml-0.21.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.21.1.dist-info/METADATA,sha256=ZhxdgQ_AVABg6P-rzyssN3j62lU21YWIgnphuf7v2o0,1257
16
- broccoli_ml-0.21.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.21.1.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=LRofeDyz0wMwClVBeuxnaPmuDpuu8UAVddAOxIOYK8Y,15625
14
+ broccoli_ml-0.23.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.23.0.dist-info/METADATA,sha256=GfhIQiwV5g480TcVdGdzL3AsgJbQmV89Rxc_oJOrwJo,1257
16
+ broccoli_ml-0.23.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.23.0.dist-info/RECORD,,