broccoli-ml 9.6.0__tar.gz → 13.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 9.6.0
+ Version: 13.0.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -1,3 +1,4 @@
+ import warnings
  import math
  from typing import Optional, Tuple
 
@@ -20,6 +21,12 @@ except ImportError:
      FLASH_ATTN = False
 
 
+ def scale_parameters(torch_module: nn.Module, factor: float):
+     with torch.no_grad():
+         for param in torch_module.parameters():
+             param.mul_(factor)
+
+
  def drop_path(
      x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  ):
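The new `scale_parameters` helper above multiplies every parameter of a module in place under `torch.no_grad()`; later hunks use it to apply a DeepNet-style `beta` gain after the usual reset. A minimal sketch of its effect (the layer and the 0.5 factor are illustrative, not package defaults):

```python
# Illustrative only: scale_parameters halves every weight and bias of this layer in place.
import torch
import torch.nn as nn

def scale_parameters(torch_module: nn.Module, factor: float):
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)

layer = nn.Linear(16, 16)
before = layer.weight.detach().clone()
scale_parameters(layer, 0.5)
assert torch.allclose(layer.weight, 0.5 * before)
```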
@@ -78,9 +85,11 @@ class MHAttention(nn.Module):
          seq_len=None,
          linear_module: nn.Module = nn.Linear,
          utility_tokens=0,
+         talking_heads=False,
          rotary_embedding=None,
          source_size=None,
          scaling="d",
+         beta=1.0,
      ):
          """
          Args:
@@ -96,10 +105,20 @@ class MHAttention(nn.Module):
          if causal:
              assert seq_len is not None
 
+         self.talking_heads = talking_heads
+
+         if self.talking_heads:
+             self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+             self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+         else:
+             self.head_projection = None
+             self.sample_projection = None
+
          self.embed_dim = embed_dim
          self.n_heads = n_heads
          assert embed_dim % n_heads == 0
          self.scaling = scaling
+         self.beta = beta
 
          self.head_dim = self.embed_dim // self.n_heads
 
@@ -181,17 +200,26 @@ class MHAttention(nn.Module):
                  "`source_size` must be a tuple of 1, 2 or 3 integers"
              )
 
-         q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
-         k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+         q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+         k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+         q_util, q_img = (
+             q[:, : self.utility_tokens, :, :],
+             q[:, self.utility_tokens :, :, :],
+         )
+         k_util, k_img = (
+             k[:, : self.utility_tokens, :, :],
+             k[:, self.utility_tokens :, :, :],
+         )
 
          q_img = rearrange(
              q_img,
-             f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+             f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
              **spatial_dimension_values,
          )
          k_img = rearrange(
              k_img,
-             f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+             f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
              **spatial_dimension_values,
          )
 
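For reference, the hunk above now splits the packed projections into per-head tensors with `einops.rearrange` before separating utility tokens and restoring the spatial grid, so the rotary embedding can act on each head; the matching hunk below packs the heads back. A minimal round-trip sketch with toy shapes (not package code):

```python
# Toy shapes, illustrative only: split heads out of the packed dimension and pack them back.
import torch
from einops import rearrange

b, t, h, d = 2, 6, 4, 8               # batch, tokens, heads, head_dim
q = torch.randn(b, t, h * d)          # packed projection output

q_heads = rearrange(q, "b t (h d) -> b t h d", h=h)    # expose the head axis
# ... per-head processing (e.g. rotary embedding on spatial tokens) would go here ...
q_packed = rearrange(q_heads, "b t h d -> b t (h d)")  # pack heads back

assert torch.equal(q, q_packed)       # the split/merge round trip is lossless
```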
@@ -202,17 +230,20 @@ class MHAttention(nn.Module):
 
          q_img = rearrange(
              q_img,
-             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+             f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
          )
          k_img = rearrange(
              k_img,
-             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+             f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
          )
 
          # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
          q = torch.cat([q_util, q_img], dim=1)
          k = torch.cat([k_util, k_img], dim=1)
 
+         q = rearrange(q, "b t h d -> b t (h d)")
+         k = rearrange(k, "b t h d -> b t (h d)")
+
          return q, k
 
      def project_qkv(
@@ -243,7 +274,7 @@ class MHAttention(nn.Module):
 
          q, k, v = self.project_qkv(q, k, v)
 
-         if FLASH_ATTN:
+         if FLASH_ATTN and not self.talking_heads:
              # Divide Q/K/V into heads
              q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
              k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +302,22 @@ class MHAttention(nn.Module):
 
              qk_scores *= self.scaling_factor
 
+             if self.talking_heads:
+                 qk_scores = torch.einsum(
+                     "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+                 )
+
              # Apply mask if causal (must come before softmax)
              if self.causal:
                  qk_scores.masked_fill_(self.mask, float("-inf"))
 
              qk_scores = F.softmax(qk_scores, dim=-1)
 
+             if self.talking_heads:
+                 qk_scores = torch.einsum(
+                     "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+                 )
+
              qk_scores = self.dropout(qk_scores)
 
              output_with_heads = qk_scores @ v
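The `talking_heads` branch added above mixes information across attention heads with learned `n_heads × n_heads` projections applied before and after the softmax, in the spirit of Shazeer et al.'s talking-heads attention. A self-contained sketch of the two einsum mixes on a toy score tensor (the layer names mirror the diff; shapes and values are illustrative):

```python
# Illustrative only: pre- and post-softmax head mixing on (batch, heads, query, key) scores.
import torch
import torch.nn as nn
import torch.nn.functional as F

b, h, i, j = 2, 4, 5, 5
qk_scores = torch.randn(b, h, i, j)

head_projection = nn.Linear(h, h, bias=False)     # mixes attention logits across heads
sample_projection = nn.Linear(h, h, bias=False)   # mixes attention weights across heads
nn.init.eye_(head_projection.weight)              # near-identity start, as in reset_parameters
nn.init.eye_(sample_projection.weight)

qk_scores = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection.weight)
weights = F.softmax(qk_scores, dim=-1)
weights = torch.einsum("b h i j, o h -> b o i j", weights, sample_projection.weight)
print(weights.shape)  # torch.Size([2, 4, 5, 5])
```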
@@ -309,7 +350,14 @@ class MHAttention(nn.Module):
          self.q_proj.reset_parameters()
          self.k_proj.reset_parameters()
          self.v_proj.reset_parameters()
+         scale_parameters(self.v_proj, self.beta) # per Microsoft DeepNet
          self.out_proj.reset_parameters()
+         scale_parameters(self.out_proj, self.beta) # per Microsoft DeepNet
+
+         if self.talking_heads:
+             # Initialize close to identity
+             nn.init.eye_(self.head_projection.weight)
+             nn.init.eye_(self.sample_projection.weight)
 
 
  class FeedforwardBlock(nn.Module):
@@ -329,17 +377,14 @@ class FeedforwardBlock(nn.Module):
          outer_dropout=None,
          linear_module_up=nn.Linear,
          linear_module_down=nn.Linear,
-         pre_norm=True,
          normformer=False,
-         post_norm=True,
-         residual_path=True,
          checkpoint=True,
+         beta=1.0,
      ):
          super().__init__()
 
          self.checkpoint = checkpoint
-         self.residual_path = residual_path
-         self.post_norm = post_norm
+         self.beta = beta
          self.xglu = activation.__name__.endswith("GLU")
 
          if self.residual_path and (output_features < input_features):
@@ -365,19 +410,26 @@ class FeedforwardBlock(nn.Module):
          )
 
          self.max_features = (
-             2 * ratio * output_features if self.xglu else ratio * output_features
+             2 * int(ratio * output_features)
+             if self.xglu
+             else int(ratio * output_features)
          )
 
          self.linear_in = linear_module_up(input_features, self.max_features)
-         self.linear_out = linear_module_down(ratio * output_features, output_features)
+         self.linear_out = linear_module_down(
+             int(ratio * output_features), output_features
+         )
 
          self.process = nn.Sequential(
              *[
-                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                  self.linear_in,
                  self.activation,
                  self.inner_dropout,
-                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                 (
+                     nn.LayerNorm(int(ratio * output_features))
+                     if normformer
+                     else nn.Identity()
+                 ),
                  self.linear_out,
                  self.outer_dropout,
              ]
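Wrapping `ratio * output_features` in `int()` lets the feedforward ratio be fractional without producing non-integer layer widths; GLU-style activations still get twice the hidden features before gating. A quick arithmetic check with illustrative numbers:

```python
# Illustrative numbers only: hidden widths implied by a fractional ratio.
output_features, ratio = 96, 2.5

hidden = int(ratio * output_features)            # 240: what linear_out consumes
glu_hidden = 2 * int(ratio * output_features)    # 480: linear_in output for a *GLU activation
print(hidden, glu_hidden)                        # 240 480
```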
@@ -419,12 +471,7 @@ class FeedforwardBlock(nn.Module):
          else:
              processed = self.process(x)
 
-         if self.residual_path and self.post_norm:
-             return self.layernorm(x + processed)
-         elif self.residual_path:
-             return x + processed
-         else:
-             return processed
+         return processed
 
      def reset_parameters(self):
          if self.post_norm:
@@ -435,8 +482,11 @@ class FeedforwardBlock(nn.Module):
              if hasattr(module, "reset_parameters"):
                  module.reset_parameters()
 
+         scale_parameters(self.linear_in, self.beta) # per Microsoft DeepNet
+         scale_parameters(self.linear_out, self.beta)
+
 
- class TransformerBlock(nn.Module):
+ class EncoderBlock(nn.Module):
      """
      Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
      which is also what is seen in e.g.
@@ -453,6 +503,7 @@ class TransformerBlock(nn.Module):
          relative_position_embedding=False,
          source_size=None,
          utility_tokens=0,
+         talking_heads=False,
          mlp_ratio=4,
          activation: nn.Module = nn.ReLU,
          activation_kwargs: Optional[dict] = None,
@@ -470,6 +521,8 @@ class TransformerBlock(nn.Module):
          post_norm=False,
          normformer=False,
          checkpoint_ff=True,
+         alpha=1.0,
+         beta=1.0,
      ):
          """
          Args:
@@ -481,15 +534,29 @@ class TransformerBlock(nn.Module):
 
          super().__init__()
 
+         if pre_norm and post_norm:
+             raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
          self.pre_norm = pre_norm
          self.post_norm = post_norm
          self.normformer = normformer
 
+         self.alpha = alpha
+         self.beta = beta
+
          self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
-         self.layer_norm_1 = nn.LayerNorm(d_model)
-         self.layer_norm_2 = nn.LayerNorm(d_model)
-         self.layer_norm_3 = nn.LayerNorm(d_model)
+         if self.pre_norm:
+             self.pre_attention_norm = nn.LayerNorm(d_model)
+             self.pre_mlp_norm = nn.LayerNorm(d_model)
+
+         if normformer:
+             self.normformer_norm = nn.LayerNorm(d_model)
+
+         if self.post_norm:
+             self.input_norm = nn.LayerNorm(d_model)
+             self.post_attention_norm = nn.LayerNorm(d_model)
+             self.post_mlp_norm = nn.LayerNorm(d_model)
 
          if relative_position_embedding:
              max_freq = int(max(source_size) / 2) # Suggested by Gemini!
@@ -513,7 +580,9 @@ class TransformerBlock(nn.Module):
              rotary_embedding=self.rotary_embedding,
              source_size=source_size,
              utility_tokens=utility_tokens,
+             talking_heads=talking_heads,
              scaling=msa_scaling,
+             beta=beta,
          )
 
          # Submodule for the feedforward process
@@ -536,11 +605,9 @@ class TransformerBlock(nn.Module):
                  if ff_linear_module_down is not None
                  else linear_module
              ),
-             pre_norm=False, # Handled outside the block
              normformer=normformer,
-             post_norm=False, # Handled outside the block
-             residual_path=False, # Handled outside the block
              checkpoint=checkpoint_ff,
+             beta=beta,
          )
 
          self.reset_parameters()
@@ -550,22 +617,34 @@ class TransformerBlock(nn.Module):
          return self.attn._kv_distance
 
      def forward(self, x):
+         if self.post_norm:
+             x = self.input_norm(x)
 
          if self.pre_norm:
-             x = self.layer_norm_1(x)
-             x = x + self.drop_path(self.attn(x, x, x))
-             x = self.layer_norm_2(x)
-             x = x + self.drop_path(self.ff(x))
-             if self.post_norm: # i.e. in addition! Pre and post.
-                 x = self.layer_norm_3(x)
-         elif self.post_norm: # i.e. only, not prenorm, just post
-             x = x + self.drop_path(self.attn(x, x, x))
-             x = self.layer_norm_1(x)
-             x = x + self.drop_path(self.ff(x))
-             x = self.layer_norm_2(x)
-         else: # Not pre or post norm. Stand well back.
-             x = x + self.drop_path(self.attn(x, x, x))
-             x = x + self.drop_path(self.ff(x))
+             process_x = self.pre_attention_norm(x)
+         else:
+             process_x = x
+
+         processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+         if self.normformer:
+             processed = self.normformer_norm(processed)
+
+         x = self.alpha * x + processed
+
+         if self.post_norm:
+             x = self.post_attention_norm(x)
+         elif self.pre_norm:
+             process_x = self.pre_mlp_norm(x)
+         else:
+             process_x = x
+
+         processed = self.drop_path(self.ff(process_x))
+
+         x = self.alpha * x + processed
+
+         if self.post_norm:
+             x = self.post_mlp_norm(x)
 
          return x
 
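The rewritten forward pass scales the residual stream by `alpha` before adding each sublayer's output (`x = self.alpha * x + processed`), pairing with the `beta`-scaled initialisation applied in the `reset_parameters` hunks. This follows the DeepNet recipe (Wang et al., 2022), where an encoder-only stack of N layers uses alpha = (2N)^(1/4) and beta = (8N)^(-1/4). A hedged sketch of those values; `deepnet_encoder_gains` is a hypothetical helper, not part of broccoli-ml:

```python
# Hypothetical helper, not package code: DeepNet gains for an encoder-only stack.
def deepnet_encoder_gains(n_layers: int) -> tuple[float, float]:
    alpha = (2 * n_layers) ** 0.25    # residual scale used as x = alpha * x + sublayer(...)
    beta = (8 * n_layers) ** -0.25    # init scale applied to value/output/feedforward projections
    return alpha, beta

alpha, beta = deepnet_encoder_gains(12)
print(round(alpha, 3), round(beta, 3))  # 2.213 0.319
```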
@@ -573,16 +652,26 @@ class TransformerBlock(nn.Module):
          """
          Give back the attention scores used in this layer.
          """
+         # Fix: Use the correct attribute name 'pre_attention_norm'
          if self.pre_norm:
-             x = self.layer_norm_1(x)
+             # We must normalize the input before measuring attention logits
+             # to match what the model actually sees during forward()
+             x = self.pre_attention_norm(x)
              return self.attn.attention_logits(x, x, x)
          else:
              return self.attn.attention_logits(x, x, x)
 
      def reset_parameters(self):
-         self.layer_norm_1.reset_parameters()
-         self.layer_norm_2.reset_parameters()
-         self.layer_norm_3.reset_parameters()
+         if self.pre_norm:
+             self.pre_attention_norm.reset_parameters()
+             self.pre_mlp_norm.reset_parameters()
+
+         if self.post_norm:
+             self.post_attention_norm.reset_parameters()
+             self.post_mlp_norm.reset_parameters()
+
+         if self.normformer:
+             self.normformer_norm.reset_parameters()
 
          self.attn.reset_parameters()
          self.ff.reset_parameters()
@@ -616,12 +705,15 @@ class TransformerEncoder(nn.Module):
          causal=False,
          linear_module=nn.Linear,
          utility_tokens=0,
+         talking_heads=False,
          return_utility_tokens=False,
          pre_norm=True,
          post_norm=False,
          normformer=False,
          msa_scaling="d",
          checkpoint_ff=True,
+         alpha=1.0,
+         beta=1.0,
      ):
          """
          Args:
@@ -644,6 +736,13 @@ class TransformerEncoder(nn.Module):
              )
 
          super().__init__()
+
+         if FLASH_ATTN and talking_heads:
+             warnings.warn(
+                 "Using talking heads currently prevents using flash attention.",
+                 stacklevel=2,
+             )
+
          self.seq_len = seq_len
          self.n_heads = n_heads
          self._utility_tokens = utility_tokens
@@ -688,13 +787,14 @@ class TransformerEncoder(nn.Module):
 
          self.blocks = nn.ModuleList(
              [
-                 TransformerBlock(
+                 EncoderBlock(
                      self.full_sequence_length,
                      d_model,
                      n_heads,
                      relative_position_embedding=relative_position_embedding,
                      source_size=source_size,
                      utility_tokens=utility_tokens,
+                     talking_heads=talking_heads,
                      mlp_ratio=mlp_ratio,
                      activation=activation,
                      activation_kwargs=activation_kwargs,
@@ -712,6 +812,8 @@ class TransformerEncoder(nn.Module):
                      post_norm=post_norm,
                      normformer=normformer,
                      checkpoint_ff=checkpoint_ff,
+                     alpha=alpha,
+                     beta=beta,
                  )
                  for i in range(n_layers)
              ]
@@ -732,13 +834,14 @@ class TransformerEncoder(nn.Module):
          x = x
 
          if self.absolute_position_embedding is not None:
-             x = x + self.absolute_position_embedding(
+             position_embedding = self.absolute_position_embedding(
                  torch.arange(
                      0, self.full_sequence_length, dtype=torch.long, device=x.device
                  ).unsqueeze(
                      0
                  ) # to shape (1, seq_len) to broadcast over batch
              )
+             x += position_embedding
 
          return x
 
@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
          pooling_kernel_stride=2,
          pooling_padding=1,
          transformer_feedforward_first=True,
-         transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
          transformer_initial_ff_dropout=None,
@@ -174,6 +173,7 @@ class ViTEncoder(nn.Module):
          transformer_heads=4,
          transformer_mlp_ratio=2,
          transformer_utility_tokens=0,
+         transformer_talking_heads=False,
          transformer_return_utility_tokens=False,
          transformer_activation: nn.Module = SquaredReLU,
          transformer_activation_kwargs: Optional[dict] = None,
@@ -345,11 +345,14 @@ class ViTEncoder(nn.Module):
                  causal=False,
                  linear_module=linear_module,
                  utility_tokens=transformer_utility_tokens,
+                 talking_heads=transformer_talking_heads,
                  return_utility_tokens=transformer_return_utility_tokens,
                  pre_norm=transformer_pre_norm,
                  normformer=transformer_normformer,
                  post_norm=transformer_post_norm,
                  checkpoint_ff=transformer_checkpoint_ff,
+                 alpha=self.alpha,
+                 beta=self.beta,
              )
          else:
              self.transformer = nn.Identity()
@@ -391,16 +394,14 @@ class ViTEncoder(nn.Module):
                      or transformer_ff_linear_module_down
                      or linear_module
                  ),
-                 pre_norm=transformer_pre_norm,
                  normformer=transformer_normformer,
-                 post_norm=transformer_post_norm,
-                 residual_path=transformer_initial_ff_residual_path,
                  checkpoint=transformer_checkpoint_ff,
+                 beta=self.beta,
              )
          else:
              self.initial_ff = nn.Identity()
 
-         self.encoder = nn.Sequential(
+         self.preprocess = nn.Sequential(
              *[
                  batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                  self.cnn,
@@ -410,19 +411,21 @@ class ViTEncoder(nn.Module):
                      f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                  ),
                  self.pooling_channels_padding,
-                 self.initial_ff,
-                 self.transformer,
+                 nn.LayerNorm(),
              ]
          )
 
          self.reset_parameters()
 
      def forward(self, x):
-         return self.encoder(x)
+         x = self.preprocess(x)
+         x = x + self.initial_ff(x)
+         return self.transformer(x)
 
      def attention_logits(self, x):
-         x = self.encoder[:-1](x)
-         return self.encoder[-1].attention_logits(x)
+         x = self.preprocess(x)
+         x = x + self.initial_ff(x)
+         return self.transformer.attention_logits(x)
 
      def reset_parameters(self):
          for module in self.encoder:
@@ -456,7 +459,6 @@ class ViT(nn.Module):
          pooling_kernel_stride=2,
          pooling_padding=1,
          transformer_feedforward_first=True,
-         transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
          transformer_initial_ff_dropout=None,
@@ -472,6 +474,7 @@ class ViT(nn.Module):
          transformer_heads=4,
          transformer_mlp_ratio=2,
          transformer_utility_tokens=0,
+         transformer_talking_heads=False,
          transformer_return_utility_tokens=False,
          transformer_activation: nn.Module = SquaredReLU,
          transformer_activation_kwargs: Optional[dict] = None,
@@ -527,7 +530,6 @@ class ViT(nn.Module):
              pooling_kernel_stride=pooling_kernel_stride,
              pooling_padding=pooling_padding,
              transformer_feedforward_first=transformer_feedforward_first,
-             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
              transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
              transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
              transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -543,6 +545,7 @@ class ViT(nn.Module):
              transformer_heads=transformer_heads,
              transformer_mlp_ratio=transformer_mlp_ratio,
              transformer_utility_tokens=transformer_utility_tokens,
+             transformer_talking_heads=transformer_talking_heads,
              transformer_return_utility_tokens=transformer_return_utility_tokens,
              transformer_activation=transformer_activation,
              transformer_activation_kwargs=transformer_activation_kwargs,
@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "9.6.0"
+ version = "13.0.0"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}
File without changes
File without changes