broccoli-ml 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -395,6 +395,7 @@ class TransformerEncoder(nn.Module):
395
395
  causal=False,
396
396
  linear_module=nn.Linear,
397
397
  bos_tokens=0,
398
+ return_bos_tokens=False,
398
399
  ):
399
400
  if position_embedding_type == "relative":
400
401
  assert source_size is not None # TODO: make this a proper exception
@@ -403,6 +404,7 @@ class TransformerEncoder(nn.Module):
403
404
  self.seq_len = seq_len
404
405
  self.n_heads = n_heads
405
406
  self._bos_tokens = bos_tokens
407
+ self.return_bos_tokens = return_bos_tokens
406
408
 
407
409
  # Initialise BOS tokens with normal init, like usual Pytorch embeddings
408
410
  if self._bos_tokens:
@@ -479,7 +481,7 @@ class TransformerEncoder(nn.Module):
479
481
  for block in self.blocks:
480
482
  x = block(x)
481
483
 
482
- if self._bos_tokens:
484
+ if self._bos_tokens and not self.return_bos_tokens:
483
485
  return x[:, self._bos_tokens :, :]
484
486
  else:
485
487
  return x
broccoli/vit.py CHANGED
@@ -76,7 +76,7 @@ class ViTEncoder(nn.Module):
76
76
  cnn_activation: nn.Module = ReLU,
77
77
  cnn_activation_kwargs: Optional[dict] = None,
78
78
  cnn_dropout=0.0,
79
- pooling_type="concat", # maxpool or concat
79
+ pooling_type="concat", # max, average or concat
80
80
  pooling_kernel_size=3,
81
81
  pooling_kernel_stride=2,
82
82
  pooling_padding=1,
@@ -114,16 +114,19 @@ class ViTEncoder(nn.Module):
114
114
 
115
115
  if self.spatial_dimensions == 1:
116
116
  maxpoolxd = nn.MaxPool1d
117
+ avgpoolxd = nn.AvgPool1d
117
118
  convxd = nn.Conv1d
118
119
  batchnormxd = nn.BatchNorm1d
119
120
  spatial_dim_names = "D1"
120
121
  elif self.spatial_dimensions == 2:
121
122
  maxpoolxd = nn.MaxPool2d
123
+ avgpoolxd = nn.AvgPool2d
122
124
  convxd = nn.Conv2d
123
125
  batchnormxd = nn.BatchNorm2d
124
126
  spatial_dim_names = "D1 D2"
125
127
  elif self.spatial_dimensions == 3:
126
128
  maxpoolxd = nn.MaxPool3d
129
+ avgpoolxd = nn.AvgPool3d
127
130
  convxd = nn.Conv3d
128
131
  batchnormxd = nn.BatchNorm3d
129
132
  spatial_dim_names = "D1 D2 D3"
@@ -158,7 +161,7 @@ class ViTEncoder(nn.Module):
158
161
  spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
159
162
  )
160
163
 
161
- if pooling_type in ["maxpool", None]:
164
+ if pooling_type in ["max", "average", None]:
162
165
  cnn_out_channels = transformer_embedding_size
163
166
  elif pooling_type == "concat":
164
167
  cnn_out_channels = max(
@@ -166,7 +169,9 @@ class ViTEncoder(nn.Module):
166
169
  minimum_cnn_out_channels,
167
170
  )
168
171
  else:
169
- raise NotImplementedError("Pooling type must be maxpool, concat or None")
172
+ raise NotImplementedError(
173
+ "Pooling type must be max, average, concat or None"
174
+ )
170
175
 
171
176
  cnn_activation_out_channels = cnn_out_channels
172
177
 
@@ -211,7 +216,7 @@ class ViTEncoder(nn.Module):
211
216
  ]
212
217
  )
213
218
 
214
- elif pooling_type == "maxpool":
219
+ elif pooling_type == "max":
215
220
  self.pool = nn.Sequential(
216
221
  *[
217
222
  maxpoolxd(
@@ -225,6 +230,20 @@ class ViTEncoder(nn.Module):
225
230
  ]
226
231
  )
227
232
 
233
+ elif pooling_type == "average":
234
+ self.pool = nn.Sequential(
235
+ *[
236
+ avgpoolxd(
237
+ pooling_kernel_size,
238
+ stride=pooling_kernel_stride,
239
+ padding=pooling_padding,
240
+ ),
241
+ Rearrange(
242
+ f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
243
+ ), # for transformer
244
+ ]
245
+ )
246
+
228
247
  elif pooling_type == "concat":
229
248
 
230
249
  if transformer_activation_kwargs is not None:
@@ -303,7 +322,7 @@ class ViTEncoder(nn.Module):
303
322
  return self.encoder(x)
304
323
 
305
324
 
306
- class ViT(nn.Module):
325
+ class CCT(nn.Module):
307
326
  """
308
327
  Denoising convolutional transformer
309
328
  Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
@@ -325,7 +344,7 @@ class ViT(nn.Module):
325
344
  cnn_activation: nn.Module = ReLU,
326
345
  cnn_activation_kwargs: Optional[dict] = None,
327
346
  cnn_dropout=0.0,
328
- pooling_type="concat", # maxpool or concat
347
+ pooling_type="concat", # max, average or concat
329
348
  pooling_kernel_size=3,
330
349
  pooling_kernel_stride=2,
331
350
  pooling_padding=1,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.5.0
3
+ Version: 0.5.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
- broccoli/transformer.py,sha256=-b7XG51c0ZgDoXQRyEmrZH6IvBCH1etK0NTApwqNhpU,16302
11
+ broccoli/transformer.py,sha256=SwvutiYOiPlqLzbO_twye7Hna5DsJukVOzzAx9CTCyU,16417
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=7XJp53TEcGo2jaWL-XtEd_RTAvl2er3DmAcXawm1SkY,14627
14
- broccoli_ml-0.5.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.5.0.dist-info/METADATA,sha256=eN2lZ82O-87kkq6Kypo2MXzltUUKMvrulUDphy2-lNE,1256
16
- broccoli_ml-0.5.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.5.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=CSllHj_cI0dQveuLSoHD-Y95nOqVO1F-p-8SmlJffhM,15272
14
+ broccoli_ml-0.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.5.1.dist-info/METADATA,sha256=IFo3lnkE6Zti81jUvLuBONLlR2PUUAwj1AKqPCgFUdI,1256
16
+ broccoli_ml-0.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.5.1.dist-info/RECORD,,