broccoli-ml 10.1.0__tar.gz → 13.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 10.1.0
+Version: 13.0.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -21,13 +21,10 @@ except ImportError:
     FLASH_ATTN = False


-class LayerScale(nn.Module):
-    def __init__(self, dim, init_values=1e-4):
-        super().__init__()
-        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
-
-    def forward(self, x):
-        return x * self.nondecay_scale
+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)


 def drop_path(
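
For context on the change above: 13.0.0 drops the learned per-channel LayerScale gain and instead rescales a module's weights once, in place. A minimal usage sketch of the new helper (the nn.Linear here is purely illustrative and not part of the package):

    import torch
    from torch import nn

    def scale_parameters(torch_module: nn.Module, factor: float):
        # As in the diff above: multiply every parameter by `factor` in place.
        with torch.no_grad():
            for param in torch_module.parameters():
                param.mul_(factor)

    layer = nn.Linear(64, 64)
    before = layer.weight.detach().clone()
    scale_parameters(layer, 0.5)
    assert torch.allclose(layer.weight, 0.5 * before)  # weight (and bias) halved in place

Unlike LayerScale, this is not a trainable parameter; it only changes the starting point of the affected weights.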
@@ -92,6 +89,7 @@ class MHAttention(nn.Module):
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -120,6 +118,7 @@ class MHAttention(nn.Module):
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta

         self.head_dim = self.embed_dim // self.n_heads

@@ -201,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
-        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )

         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )

@@ -222,17 +230,20 @@ class MHAttention(nn.Module):

         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )

         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)

+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k

     def project_qkv(
@@ -339,7 +350,10 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
         if self.talking_heads:
             # Initialize close to identity
             nn.init.eye_(self.head_projection.weight)
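
The "per Microsoft DeepNet" comments refer to DeepNet (Wang et al., 2022), which stabilises very deep transformers by scaling residual branches by a factor alpha and downscaling the value, output and feedforward weights at initialisation by a factor beta. The package leaves alpha and beta as plain constructor arguments defaulting to 1.0; a hedged sketch of how a caller might derive them, assuming the encoder-only formulas from the paper (this helper is not part of broccoli-ml):

    def deepnet_encoder_factors(n_layers: int) -> tuple[float, float]:
        alpha = (2 * n_layers) ** 0.25   # residual gain, used as x = alpha * x + sublayer(x)
        beta = (8 * n_layers) ** -0.25   # init-time downscale for v/out projections and FF weights
        return alpha, beta

    alpha, beta = deepnet_encoder_factors(12)  # e.g. a 12-layer encoder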
@@ -363,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()

         self.checkpoint = checkpoint
-        self.residual_path = residual_path
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
@@ -411,7 +422,6 @@ class FeedforwardBlock(nn.Module):

         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
@@ -461,12 +471,7 @@ class FeedforwardBlock(nn.Module):
         else:
             processed = self.process(x)

-        if self.residual_path and self.post_norm:
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed

     def reset_parameters(self):
         if self.post_norm:
@@ -477,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()

+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+

-class TransformerBlock(nn.Module):
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -513,7 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
-        layerscale=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -525,22 +534,29 @@ class TransformerBlock(nn.Module):

         super().__init__()

+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer

+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

-        self.layer_norm_1 = nn.LayerNorm(d_model)
-        self.layer_norm_2 = nn.LayerNorm(d_model)
-        self.layer_norm_3 = nn.LayerNorm(d_model)
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)

-        if layerscale:
-            self.layerscale1 = LayerScale(d_model)
-            self.layerscale2 = LayerScale(d_model)
-        else:
-            self.layerscale1 = nn.Identity()
-            self.layerscale2 = nn.Identity()
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)

         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -566,6 +582,7 @@ class TransformerBlock(nn.Module):
             utility_tokens=utility_tokens,
             talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )

         # Submodule for the feedforward process
@@ -588,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )

         self.reset_parameters()
@@ -602,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)

         if self.pre_norm:
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-            x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.layerscale2(self.ff(x)))
-            if self.post_norm:  # i.e. in addition! Pre and post.
-                x = self.layer_norm_3(x)
-        elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-            x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.layerscale2(self.ff(x)))
-            x = self.layer_norm_2(x)
-        else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
-            x = x + self.drop_path(self.layerscale2(self.ff(x)))
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)

         return x

@@ -625,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-            x = self.layer_norm_1(x)
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)

     def reset_parameters(self):
-        self.layer_norm_1.reset_parameters()
-        self.layer_norm_2.reset_parameters()
-        self.layer_norm_3.reset_parameters()
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()

         self.attn.reset_parameters()
         self.ff.reset_parameters()
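
In the rewritten EncoderBlock.forward above, the old three-way pre/post-norm branching (with LayerScale) collapses into the same update applied twice, once for attention and once for the feedforward block: optionally normalise the branch input, run the sub-layer through drop_path, add it to the skip path scaled by alpha, then optionally normalise the sum. A hypothetical stand-alone condensation of that pattern (this helper does not exist in the package, and the NormFormer norm on the attention output is omitted):

    import torch
    from torch import nn

    def residual_sublayer(x, sublayer, alpha=1.0, drop_path=nn.Identity(), pre_norm=None, post_norm=None):
        h = pre_norm(x) if pre_norm is not None else x        # pre-norm: normalise the branch input
        x = alpha * x + drop_path(sublayer(h))                # DeepNet-style: skip path scaled by alpha
        return post_norm(x) if post_norm is not None else x   # post-norm: normalise the residual sum

    # Example: a pre-norm feedforward sub-block with alpha = 1.0
    d_model = 32
    ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
    y = residual_sublayer(torch.randn(2, 16, d_model), ff, pre_norm=nn.LayerNorm(d_model))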
@@ -675,7 +712,8 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
-        layerscale=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -710,11 +748,6 @@ class TransformerEncoder(nn.Module):
         self._utility_tokens = utility_tokens
         self.return_utility_tokens = return_utility_tokens

-        if layerscale:
-            self.layerscale = LayerScale(d_model)
-        else:
-            self.layerscale = None
-
         # Initialise utility tokens with normal init, like usual Pytorch embeddings
         if self._utility_tokens:
             self._utility_token_embedding = nn.Parameter(
@@ -754,7 +787,7 @@ class TransformerEncoder(nn.Module):

         self.blocks = nn.ModuleList(
             [
-                TransformerBlock(
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
@@ -779,7 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
-                    layerscale=layerscale,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -807,10 +841,9 @@ class TransformerEncoder(nn.Module):
                     0
                 )  # to shape (1, seq_len) to broadcast over batch
             )
-            if self.layerscale is not None:
-                position_embedding = self.layerscale(position_embedding)
+            x += position_embedding

-        return x + position_embedding
+        return x

     def forward(self, x):

@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -187,7 +186,6 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -353,7 +351,8 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
-                layerscale=transformer_layerscale,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -395,16 +394,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()

-        self.encoder = nn.Sequential(
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -414,19 +411,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-                self.initial_ff,
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )

         self.reset_parameters()

     def forward(self, x):
-        return self.encoder(x)
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)

     def attention_logits(self, x):
-        x = self.encoder[:-1](x)
-        return self.encoder[-1].attention_logits(x)
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)

     def reset_parameters(self):
         for module in self.encoder:
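
With the hunk above, ViTEncoder no longer chains everything through a single nn.Sequential: the CNN stem, flattening, channel padding and a final LayerNorm now live in `preprocess`, the initial feedforward block is applied as an explicit residual, and the transformer runs on the result. A rough sketch of the new data flow with stand-in modules (the real ones are constructed in ViTEncoder.__init__):

    import torch
    from torch import nn

    d_model = 32
    preprocess = nn.Sequential(nn.Linear(48, d_model), nn.LayerNorm(d_model))  # stand-in for CNN stem + flatten + norm
    initial_ff = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))
    transformer = nn.Identity()                                                # stand-in for the TransformerEncoder

    tokens = torch.randn(2, 196, 48)     # (batch, flattened spatial positions, channels)
    x = preprocess(tokens)
    x = x + initial_ff(x)                # the initial FF is now an explicit residual, outside the Sequential
    out = transformer(x)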
@@ -460,7 +459,6 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -489,7 +487,6 @@ class ViT(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -533,7 +530,6 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -562,7 +558,6 @@ class ViT(nn.Module):
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
-            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )

@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "10.1.0"
+version = "13.0.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}