broccoli-ml 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/vit.py CHANGED
@@ -27,7 +27,7 @@ class SequencePool(nn.Module):
27
27
  as a generalisation of average pooling.
28
28
  """
29
29
 
30
- def __init__(self, d_model, linear_module, out_dim):
30
+ def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
31
31
  super().__init__()
32
32
  self.d_model = d_model
33
33
  self.attention = nn.Sequential(
@@ -38,7 +38,11 @@ class SequencePool(nn.Module):
38
38
  ]
39
39
  )
40
40
  self.projection = nn.Linear(d_model, out_dim)
41
- self.norm = nn.BatchNorm1d(out_dim, affine=False)
41
+ self.batch_norm = batch_norm
42
+ if batch_norm:
43
+ self.norm = nn.BatchNorm1d(out_dim, affine=False)
44
+ else:
45
+ self.norm = None
42
46
 
43
47
  def forward(self, x):
44
48
  weights = self.attention(x)
@@ -46,7 +50,7 @@ class SequencePool(nn.Module):
46
50
  weights, x, "batch seq, batch seq d_model -> batch d_model"
47
51
  )
48
52
  projection = self.projection(weighted_embedding)
49
- return self.norm(projection)
53
+ return self.norm(projection) if self.batch_norm else projection
50
54
 
51
55
 
52
56
  class DCTEncoder(nn.Module):
@@ -81,14 +85,14 @@ class DCTEncoder(nn.Module):
81
85
  transformer_layers=7,
82
86
  transformer_heads=4,
83
87
  transformer_mlp_ratio=2,
84
- transformer_bos_tokens=4,
88
+ transformer_bos_tokens=0,
85
89
  transformer_activation: nn.Module = SquaredReLU,
86
90
  transformer_activation_kwargs: Optional[dict] = None,
87
91
  mlp_dropout=0.0,
88
92
  msa_dropout=0.1,
89
93
  stochastic_depth=0.1,
90
94
  linear_module=nn.Linear,
91
- batch_norm=True,
95
+ initial_batch_norm=True,
92
96
  ):
93
97
  super().__init__()
94
98
 
@@ -191,7 +195,7 @@ class DCTEncoder(nn.Module):
191
195
  nn.Dropout(cnn_dropout),
192
196
  (
193
197
  batchnormxd(cnn_activation_out_channels)
194
- if batch_norm
198
+ if initial_batch_norm
195
199
  else nn.Identity()
196
200
  ),
197
201
  ]
@@ -279,7 +283,7 @@ class DCTEncoder(nn.Module):
279
283
 
280
284
  self.encoder = nn.Sequential(
281
285
  *[
282
- batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
286
+ batchnormxd(cnn_in_channels) if initial_batch_norm else nn.Identity(),
283
287
  self.cnn,
284
288
  self.activate_and_dropout,
285
289
  self.pool,
@@ -322,14 +326,15 @@ class DCT(nn.Module):
322
326
  transformer_layers=7,
323
327
  transformer_heads=4,
324
328
  transformer_mlp_ratio=2,
325
- transformer_bos_tokens=4,
329
+ transformer_bos_tokens=0,
326
330
  transformer_activation: nn.Module = SquaredReLU,
327
331
  transformer_activation_kwargs: Optional[dict] = None,
328
332
  mlp_dropout=0.0,
329
333
  msa_dropout=0.1,
330
334
  stochastic_depth=0.1,
335
+ batch_norm_outputs=True,
331
336
  linear_module=nn.Linear,
332
- batch_norm=True,
337
+ initial_batch_norm=True,
333
338
  image_classes=100,
334
339
  ):
335
340
 
@@ -379,10 +384,13 @@ class DCT(nn.Module):
379
384
  msa_dropout=msa_dropout,
380
385
  stochastic_depth=stochastic_depth,
381
386
  linear_module=linear_module,
382
- batch_norm=batch_norm,
387
+ initial_batch_norm=initial_batch_norm,
383
388
  )
384
389
  self.pool = SequencePool(
385
- transformer_embedding_size, linear_module, image_classes
390
+ transformer_embedding_size,
391
+ linear_module,
392
+ image_classes,
393
+ batch_norm=batch_norm_outputs,
386
394
  )
387
395
 
388
396
  @property
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.4.1
3
+ Version: 0.4.3
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -10,8 +10,8 @@ broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
11
  broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=d9nKhohlxpFbu3wzhNi53bYBNMUuShPfF6NXUAyDVA0,13778
14
- broccoli_ml-0.4.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.4.1.dist-info/METADATA,sha256=qwzzby85q__wYdKBEYBgyW5D7q3GMPoRwzeHPY6Mf6s,1256
16
- broccoli_ml-0.4.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.4.1.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=s1QwZac3S-QjAFEujf0vDDMacYwr_aWE_1mFvFD18-4,14086
14
+ broccoli_ml-0.4.3.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.4.3.dist-info/METADATA,sha256=KzHrCtuxGAdNbi5ORjQg9KJSdgAK9OtcP6ubw3owZ9Q,1256
16
+ broccoli_ml-0.4.3.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.4.3.dist-info/RECORD,,