broccoli_ml-0.4.3-py3-none-any.whl → broccoli_ml-0.5.0-py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -221,7 +221,7 @@ class MHAttention(nn.Module):
221
221
  return self.out_proj(output_without_heads)
222
222
 
223
223
 
224
- class DenoisingAutoEncoder(nn.Module):
224
+ class FeedforwardLayer(nn.Module):
225
225
  """
226
226
  A denoising autoencoder, of the type used in transformer blocks.
227
227
  """
@@ -329,7 +329,7 @@ class TransformerBlock(nn.Module):
329
329
  ("layer_norm", nn.LayerNorm(d_model)),
330
330
  (
331
331
  "denoising_autoencoder",
332
- DenoisingAutoEncoder(
332
+ FeedforwardLayer(
333
333
  d_model,
334
334
  mlp_ratio,
335
335
  d_model,
broccoli/vit.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  from typing import Optional
3
3
 
4
- from .transformer import TransformerEncoder, DenoisingAutoEncoder
4
+ from .transformer import TransformerEncoder, FeedforwardLayer
5
5
  from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
6
6
  from .activation import ReLU, SquaredReLU, GELU, SwiGLU
7
7
  from einops import einsum
@@ -53,7 +53,7 @@ class SequencePool(nn.Module):
53
53
  return self.norm(projection) if self.batch_norm else projection
54
54
 
55
55
 
56
- class DCTEncoder(nn.Module):
56
+ class ViTEncoder(nn.Module):
57
57
  """
58
58
  Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
59
59
  *''Escaping the Big Data Paradigm with Compact Transformers''*](
@@ -80,6 +80,7 @@ class DCTEncoder(nn.Module):
80
80
  pooling_kernel_size=3,
81
81
  pooling_kernel_stride=2,
82
82
  pooling_padding=1,
83
+ intermediate_feedforward_layer=True,
83
84
  transformer_position_embedding="relative", # absolute or relative
84
85
  transformer_embedding_size=256,
85
86
  transformer_layers=7,
@@ -88,8 +89,8 @@ class DCTEncoder(nn.Module):
88
89
  transformer_bos_tokens=0,
89
90
  transformer_activation: nn.Module = SquaredReLU,
90
91
  transformer_activation_kwargs: Optional[dict] = None,
91
- mlp_dropout=0.0,
92
- msa_dropout=0.1,
92
+ transformer_mlp_dropout=0.0,
93
+ transformer_msa_dropout=0.1,
93
94
  stochastic_depth=0.1,
94
95
  linear_module=nn.Linear,
95
96
  initial_batch_norm=True,
@@ -160,7 +161,7 @@ class DCTEncoder(nn.Module):
160
161
  if pooling_type in ["maxpool", None]:
161
162
  cnn_out_channels = transformer_embedding_size
162
163
  elif pooling_type == "concat":
163
- cnn_out_channels = min(
164
+ cnn_out_channels = max(
164
165
  math.floor(transformer_embedding_size / pooling_kernel_voxels),
165
166
  minimum_cnn_out_channels,
166
167
  )
@@ -248,14 +249,8 @@ class DCTEncoder(nn.Module):
248
249
  Rearrange( # for transformer
249
250
  f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
250
251
  ),
251
- DenoisingAutoEncoder(
252
- concatpool_out_channels,
253
- transformer_mlp_ratio,
254
- transformer_embedding_size,
255
- activation=transformer_activation,
256
- activation_kwargs=transformer_activation_kwargs,
257
- dropout=0.0,
258
- linear_module=linear_module,
252
+ PadTensor(
253
+ (0, transformer_embedding_size - concatpool_out_channels)
259
254
  ),
260
255
  ]
261
256
  )
@@ -271,8 +266,8 @@ class DCTEncoder(nn.Module):
271
266
  mlp_ratio=transformer_mlp_ratio,
272
267
  activation=transformer_activation,
273
268
  activation_kwargs=transformer_activation_kwargs,
274
- mlp_dropout=mlp_dropout,
275
- msa_dropout=msa_dropout,
269
+ mlp_dropout=transformer_mlp_dropout,
270
+ msa_dropout=transformer_msa_dropout,
276
271
  stochastic_depth=stochastic_depth,
277
272
  causal=False,
278
273
  linear_module=linear_module,
@@ -287,6 +282,19 @@ class DCTEncoder(nn.Module):
287
282
  self.cnn,
288
283
  self.activate_and_dropout,
289
284
  self.pool,
285
+ (
286
+ FeedforwardLayer(
287
+ transformer_embedding_size,
288
+ transformer_mlp_ratio,
289
+ transformer_embedding_size,
290
+ activation=transformer_activation,
291
+ activation_kwargs=transformer_activation_kwargs,
292
+ dropout=transformer_mlp_dropout,
293
+ linear_module=linear_module,
294
+ )
295
+ if intermediate_feedforward_layer
296
+ else nn.Identity()
297
+ ),
290
298
  self.transformer,
291
299
  ]
292
300
  )
@@ -295,7 +303,7 @@ class DCTEncoder(nn.Module):
295
303
  return self.encoder(x)
296
304
 
297
305
 
298
- class DCT(nn.Module):
306
+ class ViT(nn.Module):
299
307
  """
300
308
  Denoising convolutional transformer
301
309
  Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
@@ -321,6 +329,7 @@ class DCT(nn.Module):
321
329
  pooling_kernel_size=3,
322
330
  pooling_kernel_stride=2,
323
331
  pooling_padding=1,
332
+ intermediate_feedforward_layer=True,
324
333
  transformer_position_embedding="relative", # absolute or relative
325
334
  transformer_embedding_size=256,
326
335
  transformer_layers=7,
@@ -329,12 +338,12 @@ class DCT(nn.Module):
329
338
  transformer_bos_tokens=0,
330
339
  transformer_activation: nn.Module = SquaredReLU,
331
340
  transformer_activation_kwargs: Optional[dict] = None,
332
- mlp_dropout=0.0,
333
- msa_dropout=0.1,
341
+ transformer_mlp_dropout=0.0,
342
+ transformer_msa_dropout=0.1,
334
343
  stochastic_depth=0.1,
335
344
  batch_norm_outputs=True,
336
- linear_module=nn.Linear,
337
345
  initial_batch_norm=True,
346
+ linear_module=nn.Linear,
338
347
  image_classes=100,
339
348
  ):
340
349
 
@@ -356,7 +365,7 @@ class DCT(nn.Module):
356
365
  "SwiGLU": SwiGLU,
357
366
  }[transformer_activation]
358
367
 
359
- self.encoder = DCTEncoder(
368
+ self.encoder = ViTEncoder(
360
369
  input_size=input_size,
361
370
  cnn_in_channels=cnn_in_channels,
362
371
  minimum_cnn_out_channels=minimum_cnn_out_channels,
@@ -372,6 +381,7 @@ class DCT(nn.Module):
372
381
  pooling_kernel_size=pooling_kernel_size,
373
382
  pooling_kernel_stride=pooling_kernel_stride,
374
383
  pooling_padding=pooling_padding,
384
+ intermediate_feedforward_layer=intermediate_feedforward_layer,
375
385
  transformer_position_embedding=transformer_position_embedding,
376
386
  transformer_embedding_size=transformer_embedding_size,
377
387
  transformer_layers=transformer_layers,
@@ -380,8 +390,8 @@ class DCT(nn.Module):
380
390
  transformer_bos_tokens=transformer_bos_tokens,
381
391
  transformer_activation=transformer_activation,
382
392
  transformer_activation_kwargs=transformer_activation_kwargs,
383
- mlp_dropout=mlp_dropout,
384
- msa_dropout=msa_dropout,
393
+ mlp_dropout=transformer_mlp_dropout,
394
+ msa_dropout=transformer_msa_dropout,
385
395
  stochastic_depth=stochastic_depth,
386
396
  linear_module=linear_module,
387
397
  initial_batch_norm=initial_batch_norm,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.4.3
3
+ Version: 0.5.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
- broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
11
+ broccoli/transformer.py,sha256=-b7XG51c0ZgDoXQRyEmrZH6IvBCH1etK0NTApwqNhpU,16302
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=s1QwZac3S-QjAFEujf0vDDMacYwr_aWE_1mFvFD18-4,14086
14
- broccoli_ml-0.4.3.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.4.3.dist-info/METADATA,sha256=KzHrCtuxGAdNbi5ORjQg9KJSdgAK9OtcP6ubw3owZ9Q,1256
16
- broccoli_ml-0.4.3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.4.3.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=7XJp53TEcGo2jaWL-XtEd_RTAvl2er3DmAcXawm1SkY,14627
14
+ broccoli_ml-0.5.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.5.0.dist-info/METADATA,sha256=eN2lZ82O-87kkq6Kypo2MXzltUUKMvrulUDphy2-lNE,1256
16
+ broccoli_ml-0.5.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.5.0.dist-info/RECORD,,