broccoli-ml 0.4.4__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -221,7 +221,7 @@ class MHAttention(nn.Module):
         return self.out_proj(output_without_heads)
 
 
-class DenoisingAutoEncoder(nn.Module):
+class FeedforwardLayer(nn.Module):
     """
     A denoising autoencoder, of the type used in transformer blocks.
     """
@@ -329,7 +329,7 @@ class TransformerBlock(nn.Module):
                 ("layer_norm", nn.LayerNorm(d_model)),
                 (
                     "denoising_autoencoder",
-                    DenoisingAutoEncoder(
+                    FeedforwardLayer(
                         d_model,
                         mlp_ratio,
                         d_model,
@@ -395,6 +395,7 @@ class TransformerEncoder(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         bos_tokens=0,
+        return_bos_tokens=False,
     ):
         if position_embedding_type == "relative":
             assert source_size is not None  # TODO: make this a proper exception
@@ -403,6 +404,7 @@ class TransformerEncoder(nn.Module):
         self.seq_len = seq_len
         self.n_heads = n_heads
         self._bos_tokens = bos_tokens
+        self.return_bos_tokens = return_bos_tokens
 
         # Initialise BOS tokens with normal init, like usual Pytorch embeddings
         if self._bos_tokens:
@@ -479,7 +481,7 @@ class TransformerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x)
 
-        if self._bos_tokens:
+        if self._bos_tokens and not self.return_bos_tokens:
             return x[:, self._bos_tokens :, :]
         else:
             return x
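
The return_bos_tokens flag added above only changes the final slice in TransformerEncoder.forward: with the default False, BOS positions are stripped from the output exactly as in 0.4.4; with True, they are kept. A minimal, self-contained sketch of that logic (not the library's code; batch size, sequence length, model width and BOS count are invented for illustration):

import torch

# Illustrative stand-in for the new conditional in TransformerEncoder.forward.
bos_tokens = 2                 # number of prepended BOS tokens (made up)
return_bos_tokens = False      # new 0.5.1 flag; False reproduces 0.4.4 behaviour
x = torch.randn(4, bos_tokens + 16, 256)  # (batch, BOS + sequence, d_model)

if bos_tokens and not return_bos_tokens:
    out = x[:, bos_tokens:, :]  # BOS positions stripped from the returned sequence
else:
    out = x                     # BOS positions kept in the output

print(out.shape)  # torch.Size([4, 16, 256]) when the flag is False
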
broccoli/vit.py CHANGED
@@ -1,7 +1,7 @@
 import math
 from typing import Optional
 
-from .transformer import TransformerEncoder, DenoisingAutoEncoder
+from .transformer import TransformerEncoder, FeedforwardLayer
 from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
 from .activation import ReLU, SquaredReLU, GELU, SwiGLU
 from einops import einsum
@@ -53,7 +53,7 @@ class SequencePool(nn.Module):
         return self.norm(projection) if self.batch_norm else projection
 
 
-class DCTEncoder(nn.Module):
+class ViTEncoder(nn.Module):
     """
     Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
     *''Escaping the Big Data Paradigm with Compact Transformers''*](
@@ -76,10 +76,11 @@ class DCTEncoder(nn.Module):
         cnn_activation: nn.Module = ReLU,
         cnn_activation_kwargs: Optional[dict] = None,
         cnn_dropout=0.0,
-        pooling_type="concat",  # maxpool or concat
+        pooling_type="concat",  # max, average or concat
         pooling_kernel_size=3,
         pooling_kernel_stride=2,
         pooling_padding=1,
+        intermediate_feedforward_layer=True,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -88,8 +89,8 @@ class DCTEncoder(nn.Module):
         transformer_bos_tokens=0,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
-        mlp_dropout=0.0,
-        msa_dropout=0.1,
+        transformer_mlp_dropout=0.0,
+        transformer_msa_dropout=0.1,
         stochastic_depth=0.1,
         linear_module=nn.Linear,
         initial_batch_norm=True,
@@ -113,16 +114,19 @@ class DCTEncoder(nn.Module):
 
         if self.spatial_dimensions == 1:
             maxpoolxd = nn.MaxPool1d
+            avgpoolxd = nn.AvgPool1d
             convxd = nn.Conv1d
             batchnormxd = nn.BatchNorm1d
             spatial_dim_names = "D1"
         elif self.spatial_dimensions == 2:
             maxpoolxd = nn.MaxPool2d
+            avgpoolxd = nn.AvgPool2d
             convxd = nn.Conv2d
             batchnormxd = nn.BatchNorm2d
             spatial_dim_names = "D1 D2"
         elif self.spatial_dimensions == 3:
             maxpoolxd = nn.MaxPool3d
+            avgpoolxd = nn.AvgPool3d
             convxd = nn.Conv3d
             batchnormxd = nn.BatchNorm3d
             spatial_dim_names = "D1 D2 D3"
@@ -157,7 +161,7 @@ class DCTEncoder(nn.Module):
             spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
         )
 
-        if pooling_type in ["maxpool", None]:
+        if pooling_type in ["max", "average", None]:
             cnn_out_channels = transformer_embedding_size
         elif pooling_type == "concat":
             cnn_out_channels = max(
@@ -165,7 +169,9 @@ class DCTEncoder(nn.Module):
                 minimum_cnn_out_channels,
             )
         else:
-            raise NotImplementedError("Pooling type must be maxpool, concat or None")
+            raise NotImplementedError(
+                "Pooling type must be max, average, concat or None"
+            )
 
         cnn_activation_out_channels = cnn_out_channels
 
@@ -210,7 +216,7 @@ class DCTEncoder(nn.Module):
                 ]
             )
 
-        elif pooling_type == "maxpool":
+        elif pooling_type == "max":
             self.pool = nn.Sequential(
                 *[
                     maxpoolxd(
@@ -224,6 +230,20 @@ class DCTEncoder(nn.Module):
                 ]
             )
 
+        elif pooling_type == "average":
+            self.pool = nn.Sequential(
+                *[
+                    avgpoolxd(
+                        pooling_kernel_size,
+                        stride=pooling_kernel_stride,
+                        padding=pooling_padding,
+                    ),
+                    Rearrange(
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
+                    ),  # for transformer
+                ]
+            )
+
         elif pooling_type == "concat":
 
             if transformer_activation_kwargs is not None:
@@ -248,14 +268,8 @@ class DCTEncoder(nn.Module):
                     Rearrange(  # for transformer
                         f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                     ),
-                    DenoisingAutoEncoder(
-                        concatpool_out_channels,
-                        transformer_mlp_ratio,
-                        transformer_embedding_size,
-                        activation=transformer_activation,
-                        activation_kwargs=transformer_activation_kwargs,
-                        dropout=0.0,
-                        linear_module=linear_module,
+                    PadTensor(
+                        (0, transformer_embedding_size - concatpool_out_channels)
                     ),
                 ]
             )
@@ -271,8 +285,8 @@ class DCTEncoder(nn.Module):
             mlp_ratio=transformer_mlp_ratio,
             activation=transformer_activation,
             activation_kwargs=transformer_activation_kwargs,
-            mlp_dropout=mlp_dropout,
-            msa_dropout=msa_dropout,
+            mlp_dropout=transformer_mlp_dropout,
+            msa_dropout=transformer_msa_dropout,
             stochastic_depth=stochastic_depth,
             causal=False,
             linear_module=linear_module,
@@ -287,6 +301,19 @@ class DCTEncoder(nn.Module):
                 self.cnn,
                 self.activate_and_dropout,
                 self.pool,
+                (
+                    FeedforwardLayer(
+                        transformer_embedding_size,
+                        transformer_mlp_ratio,
+                        transformer_embedding_size,
+                        activation=transformer_activation,
+                        activation_kwargs=transformer_activation_kwargs,
+                        dropout=transformer_mlp_dropout,
+                        linear_module=linear_module,
+                    )
+                    if intermediate_feedforward_layer
+                    else nn.Identity()
+                ),
                 self.transformer,
             ]
         )
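
In the concat-pooling branch, 0.5.1 replaces the in-pool DenoisingAutoEncoder projection with zero-padding of the pooled channels up to the transformer width; the learned projection now sits in an optional FeedforwardLayer applied after the pooling stage, gated by the new intermediate_feedforward_layer argument (default True). PadTensor's implementation is not shown in this diff; the sketch below only illustrates what a last-dimension pad spec of (0, transformer_embedding_size - concatpool_out_channels) does, assuming it behaves like torch.nn.functional.pad, with made-up sizes:

import torch
import torch.nn.functional as F

concatpool_out_channels = 192      # pooled channel count (made up)
transformer_embedding_size = 256   # target transformer width (0.5.1 default)

tokens = torch.randn(4, 49, concatpool_out_channels)  # (batch, tokens, channels)
padded = F.pad(tokens, (0, transformer_embedding_size - concatpool_out_channels))

print(padded.shape)  # torch.Size([4, 49, 256]): zero-padded up to the model width
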
@@ -295,7 +322,7 @@ class DCTEncoder(nn.Module):
         return self.encoder(x)
 
 
-class DCT(nn.Module):
+class CCT(nn.Module):
     """
     Denoising convolutional transformer
     Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
@@ -317,10 +344,11 @@ class DCT(nn.Module):
         cnn_activation: nn.Module = ReLU,
         cnn_activation_kwargs: Optional[dict] = None,
         cnn_dropout=0.0,
-        pooling_type="concat",  # maxpool or concat
+        pooling_type="concat",  # max, average or concat
         pooling_kernel_size=3,
         pooling_kernel_stride=2,
         pooling_padding=1,
+        intermediate_feedforward_layer=True,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -329,8 +357,8 @@ class DCT(nn.Module):
         transformer_bos_tokens=0,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
-        mlp_dropout=0.0,
-        msa_dropout=0.1,
+        transformer_mlp_dropout=0.0,
+        transformer_msa_dropout=0.1,
         stochastic_depth=0.1,
         batch_norm_outputs=True,
         initial_batch_norm=True,
@@ -356,7 +384,7 @@ class DCT(nn.Module):
             "SwiGLU": SwiGLU,
         }[transformer_activation]
 
-        self.encoder = DCTEncoder(
+        self.encoder = ViTEncoder(
            input_size=input_size,
            cnn_in_channels=cnn_in_channels,
            minimum_cnn_out_channels=minimum_cnn_out_channels,
@@ -372,6 +400,7 @@ class DCT(nn.Module):
            pooling_kernel_size=pooling_kernel_size,
            pooling_kernel_stride=pooling_kernel_stride,
            pooling_padding=pooling_padding,
+           intermediate_feedforward_layer=intermediate_feedforward_layer,
            transformer_position_embedding=transformer_position_embedding,
            transformer_embedding_size=transformer_embedding_size,
            transformer_layers=transformer_layers,
@@ -380,8 +409,8 @@ class DCT(nn.Module):
            transformer_bos_tokens=transformer_bos_tokens,
            transformer_activation=transformer_activation,
            transformer_activation_kwargs=transformer_activation_kwargs,
-           mlp_dropout=mlp_dropout,
-           msa_dropout=msa_dropout,
+           mlp_dropout=transformer_mlp_dropout,
+           msa_dropout=transformer_msa_dropout,
            stochastic_depth=stochastic_depth,
            linear_module=linear_module,
            initial_batch_norm=initial_batch_norm,
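
For code written against 0.4.4, the vit.py changes above amount to renamed symbols and keyword arguments plus two new options. A hedged migration sketch (illustrative only; both constructors take further arguments that this diff does not show in full, so only the changed names appear here):

# Renamed public symbols between 0.4.4 and 0.5.1:
from broccoli.vit import CCT, ViTEncoder             # were DCT, DCTEncoder
from broccoli.transformer import FeedforwardLayer    # was DenoisingAutoEncoder

# Renamed and added keyword arguments for CCT / ViTEncoder; the values shown
# are the 0.5.1 defaults visible in the diff.
updated_kwargs = dict(
    pooling_type="concat",                # "max" (was "maxpool"), "average" (new), "concat" or None
    intermediate_feedforward_layer=True,  # new in 0.5.1
    transformer_mlp_dropout=0.0,          # was mlp_dropout
    transformer_msa_dropout=0.1,          # was msa_dropout
)
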
broccoli_ml-0.4.4.dist-info/METADATA → broccoli_ml-0.5.1.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.4.4
+Version: 0.5.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-0.4.4.dist-info/RECORD → broccoli_ml-0.5.1.dist-info/RECORD RENAMED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
-broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
+broccoli/transformer.py,sha256=SwvutiYOiPlqLzbO_twye7Hna5DsJukVOzzAx9CTCyU,16417
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=FQbdwNpvdqHreTdh_rK98sqhIrVVtn4L77EmNSvxXK0,14086
-broccoli_ml-0.4.4.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.4.4.dist-info/METADATA,sha256=rsyvyeUbOLY_QLuC-klSHxaLVWZjfffKf2ugDuJXnR8,1256
-broccoli_ml-0.4.4.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.4.4.dist-info/RECORD,,
+broccoli/vit.py,sha256=CSllHj_cI0dQveuLSoHD-Y95nOqVO1F-p-8SmlJffhM,15272
+broccoli_ml-0.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.5.1.dist-info/METADATA,sha256=IFo3lnkE6Zti81jUvLuBONLlR2PUUAwj1AKqPCgFUdI,1256
+broccoli_ml-0.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.5.1.dist-info/RECORD,,