broccoli-ml 0.25.0__py3-none-any.whl → 0.27.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -299,9 +299,12 @@ class TransformerBlock(nn.Module):
299
299
  ):
300
300
  super().__init__()
301
301
 
302
+ self.pre_norm = pre_norm
303
+
302
304
  self.identity_probability = identity_probability
303
305
 
304
- self.layer_norm = nn.LayerNorm(d_model)
306
+ self.layer_norm_1 = nn.LayerNorm(d_model)
307
+ self.layer_norm_2 = nn.LayerNorm(d_model)
305
308
 
306
309
  if position_embedding_type == "relative":
307
310
  max_freq = int(max(source_size) / 2) # Suggested by Gemini!
@@ -358,12 +361,21 @@ class TransformerBlock(nn.Module):
358
361
  identity_x = shuffled[:identity_count, :, :]
359
362
  process_x = shuffled[identity_count:, :, :]
360
363
 
361
- norm_process_x = self.layer_norm(process_x)
362
- process_x = process_x + self.attn(
363
- norm_process_x, norm_process_x, norm_process_x
364
+ if self.pre_norm:
365
+ norm_process_x = self.layer_norm_1(process_x)
366
+ process_x = process_x + self.attn(
367
+ norm_process_x, norm_process_x, norm_process_x
368
+ )
369
+ process_x = process_x + self.ff(process_x)
370
+ else: # post-norm
371
+ process_x = process_x + self.attn(process_x, process_x, process_x)
372
+ norm_process_x = self.layer_norm_1(process_x)
373
+ process_x = process_x + self.ff(process_x)
374
+
375
+ # Always post norm as eventually we reach the classification head!
376
+ x = self.layer_norm_2(
377
+ torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
364
378
  )
365
- process_x = process_x + self.ff(process_x)
366
- x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
367
379
 
368
380
  return x
369
381
 
broccoli/vit.py CHANGED
@@ -306,7 +306,8 @@ class ViTEncoder(nn.Module):
306
306
  activation_kwargs=transformer_activation_kwargs,
307
307
  dropout=transformer_mlp_dropout,
308
308
  linear_module=linear_module,
309
- raw_input=not cnn,
309
+ pre_norm=transformer_pre_norm,
310
+ normformer=transformer_normformer,
310
311
  )
311
312
  else:
312
313
  self.initial_ff = nn.Identity()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.25.0
3
+ Version: 0.27.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=7NkNvhtxzWAUoBJOuiPUIcr853HhI1cS71d8DwdMkJ0,4826
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=zhSOo9W24FEgN7U35wy3ZIJHnw3u4cepJO5heCw6vwU,4590
11
- broccoli/transformer.py,sha256=BBnbmGwmvTmpBL4LDykUTeOHXGzIKNYdNIm6X9LEcGA,16278
11
+ broccoli/transformer.py,sha256=barmcq4Y5X6iM6STFt2t52XlXhuZphQ98wX6UvUGMFU,16748
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=GNbwld8NtnDR2D-bBDrhuMEUWwBbH2KVq29nwmYdJto,16009
14
- broccoli_ml-0.25.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.25.0.dist-info/METADATA,sha256=qkm7knCzbtTK3Cdc6bT2l4tTlcmm2oIJ61n8AmZMLMM,1257
16
- broccoli_ml-0.25.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.25.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=nuXX2JoKoBTtF1tAH-11mL2R5ISgMHYsBbbgvcluV1s,16072
14
+ broccoli_ml-0.27.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.27.0.dist-info/METADATA,sha256=dzd2nejaPCG976rKlMnDTi5RzpWAeh87Yf2LsPb2Tzw,1257
16
+ broccoli_ml-0.27.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.27.0.dist-info/RECORD,,