broccoli-ml 0.25.0__tar.gz → 0.27.0__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.25.0
3
+ Version: 0.27.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -299,9 +299,12 @@ class TransformerBlock(nn.Module):
299
299
  ):
300
300
  super().__init__()
301
301
 
302
+ self.pre_norm = pre_norm
303
+
302
304
  self.identity_probability = identity_probability
303
305
 
304
- self.layer_norm = nn.LayerNorm(d_model)
306
+ self.layer_norm_1 = nn.LayerNorm(d_model)
307
+ self.layer_norm_2 = nn.LayerNorm(d_model)
305
308
 
306
309
  if position_embedding_type == "relative":
307
310
  max_freq = int(max(source_size) / 2) # Suggested by Gemini!
@@ -358,12 +361,21 @@ class TransformerBlock(nn.Module):
358
361
  identity_x = shuffled[:identity_count, :, :]
359
362
  process_x = shuffled[identity_count:, :, :]
360
363
 
361
- norm_process_x = self.layer_norm(process_x)
362
- process_x = process_x + self.attn(
363
- norm_process_x, norm_process_x, norm_process_x
364
+ if self.pre_norm:
365
+ norm_process_x = self.layer_norm_1(process_x)
366
+ process_x = process_x + self.attn(
367
+ norm_process_x, norm_process_x, norm_process_x
368
+ )
369
+ process_x = process_x + self.ff(process_x)
370
+ else: # post-norm
371
+ process_x = process_x + self.attn(process_x, process_x, process_x)
372
+ norm_process_x = self.layer_norm_1(process_x)
373
+ process_x = process_x + self.ff(process_x)
374
+
375
+ # Always post norm as eventually we reach the classification head!
376
+ x = self.layer_norm_2(
377
+ torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
364
378
  )
365
- process_x = process_x + self.ff(process_x)
366
- x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
367
379
 
368
380
  return x
369
381
 
@@ -306,7 +306,8 @@ class ViTEncoder(nn.Module):
306
306
  activation_kwargs=transformer_activation_kwargs,
307
307
  dropout=transformer_mlp_dropout,
308
308
  linear_module=linear_module,
309
- raw_input=not cnn,
309
+ pre_norm=transformer_pre_norm,
310
+ normformer=transformer_normformer,
310
311
  )
311
312
  else:
312
313
  self.initial_ff = nn.Identity()
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "0.25.0"
3
+ version = "0.27.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes