broccoli-ml 9.6.0__tar.gz → 10.1.0__tar.gz

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 9.6.0
+ Version: 10.1.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -1,3 +1,4 @@
+ import warnings
  import math
  from typing import Optional, Tuple

@@ -20,6 +21,15 @@ except ImportError:
      FLASH_ATTN = False


+ class LayerScale(nn.Module):
+     def __init__(self, dim, init_values=1e-4):
+         super().__init__()
+         self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x):
+         return x * self.nondecay_scale
+
+
  def drop_path(
      x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  ):
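Note: the LayerScale module added above is a per-channel learnable gate in the style of CaiT; the "nondecay_" prefix presumably marks the parameter for exclusion from weight decay elsewhere in the library. A minimal standalone sketch of its behaviour, using only the definition shown in this hunk:

    import torch
    import torch.nn as nn

    class LayerScale(nn.Module):
        # Copied from the hunk above: per-channel scale, initialised to a small value.
        def __init__(self, dim, init_values=1e-4):
            super().__init__()
            self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))

        def forward(self, x):
            return x * self.nondecay_scale

    # At initialisation the scale is 1e-4, so anything wrapped in LayerScale
    # contributes almost nothing and a gated residual block starts near-identity.
    scale = LayerScale(dim=8)
    branch_output = torch.randn(2, 4, 8)     # (batch, tokens, channels)
    print(scale(branch_output).abs().max())  # roughly 1e-4 of the input magnitude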
@@ -78,6 +88,7 @@ class MHAttention(nn.Module):
          seq_len=None,
          linear_module: nn.Module = nn.Linear,
          utility_tokens=0,
+         talking_heads=False,
          rotary_embedding=None,
          source_size=None,
          scaling="d",
@@ -96,6 +107,15 @@ class MHAttention(nn.Module):
          if causal:
              assert seq_len is not None

+         self.talking_heads = talking_heads
+
+         if self.talking_heads:
+             self.head_projection = nn.Linear(n_heads, n_heads, bias=False)
+             self.sample_projection = nn.Linear(n_heads, n_heads, bias=False)
+         else:
+             self.head_projection = None
+             self.sample_projection = None
+
          self.embed_dim = embed_dim
          self.n_heads = n_heads
          assert embed_dim % n_heads == 0
@@ -243,7 +263,7 @@ class MHAttention(nn.Module):

          q, k, v = self.project_qkv(q, k, v)

-         if FLASH_ATTN:
+         if FLASH_ATTN and not self.talking_heads:
              # Divide Q/K/V into heads
              q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
              k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
@@ -271,12 +291,22 @@ class MHAttention(nn.Module):

          qk_scores *= self.scaling_factor

+         if self.talking_heads:
+             qk_scores = torch.einsum(
+                 "b h i j, o h -> b o i j", qk_scores, self.head_projection.weight
+             )
+
          # Apply mask if causal (must come before softmax)
          if self.causal:
              qk_scores.masked_fill_(self.mask, float("-inf"))

          qk_scores = F.softmax(qk_scores, dim=-1)

+         if self.talking_heads:
+             qk_scores = torch.einsum(
+                 "b h i j, o h -> b o i j", qk_scores, self.sample_projection.weight
+             )
+
          qk_scores = self.dropout(qk_scores)

          output_with_heads = qk_scores @ v
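Note: the two einsum calls added above implement talking-heads attention (cf. Shazeer et al., 2020): the per-head score maps are linearly mixed across the head dimension once before and once after the softmax, via learned n_heads x n_heads projections. A self-contained sketch of the mixing step with made-up shapes (not the package's forward pass):

    import torch
    import torch.nn as nn

    batch, heads, q_len, k_len = 2, 4, 5, 5
    qk_scores = torch.randn(batch, heads, q_len, k_len)

    # n_heads x n_heads mixing matrix, analogous to MHAttention.head_projection.
    head_projection = nn.Linear(heads, heads, bias=False)

    # "b h i j, o h -> b o i j": each output head o is a learned linear
    # combination of the input heads h, applied independently at every (i, j).
    mixed = torch.einsum("b h i j, o h -> b o i j", qk_scores, head_projection.weight)

    # Equivalent formulation: move heads to the last axis and apply the Linear directly.
    reference = head_projection(qk_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    print(torch.allclose(mixed, reference, atol=1e-6))  # True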
@@ -310,6 +340,10 @@ class MHAttention(nn.Module):
          self.k_proj.reset_parameters()
          self.v_proj.reset_parameters()
          self.out_proj.reset_parameters()
+         if self.talking_heads:
+             # Initialize close to identity
+             nn.init.eye_(self.head_projection.weight)
+             nn.init.eye_(self.sample_projection.weight)


  class FeedforwardBlock(nn.Module):
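Note: nn.init.eye_ sets both mixing matrices to the identity, so a freshly initialised talking-heads layer reproduces standard multi-head attention exactly and only diverges from it as the projections train away from identity. A quick check of that property (shapes are made up):

    import torch
    import torch.nn as nn

    heads = 4
    projection = nn.Linear(heads, heads, bias=False)
    nn.init.eye_(projection.weight)  # weight is now the 4x4 identity matrix

    scores = torch.randn(2, heads, 5, 5)
    mixed = torch.einsum("b h i j, o h -> b o i j", scores, projection.weight)
    print(torch.allclose(mixed, scores))  # True: identity mixing leaves the scores unchanged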
@@ -365,11 +399,15 @@ class FeedforwardBlock(nn.Module):
          )

          self.max_features = (
-             2 * ratio * output_features if self.xglu else ratio * output_features
+             2 * int(ratio * output_features)
+             if self.xglu
+             else int(ratio * output_features)
          )

          self.linear_in = linear_module_up(input_features, self.max_features)
-         self.linear_out = linear_module_down(ratio * output_features, output_features)
+         self.linear_out = linear_module_down(
+             int(ratio * output_features), output_features
+         )

          self.process = nn.Sequential(
              *[
@@ -377,7 +415,11 @@ class FeedforwardBlock(nn.Module):
                  self.linear_in,
                  self.activation,
                  self.inner_dropout,
-                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                 (
+                     nn.LayerNorm(int(ratio * output_features))
+                     if normformer
+                     else nn.Identity()
+                 ),
                  self.linear_out,
                  self.outer_dropout,
              ]
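Note: the int(...) casts in the two hunks above let mlp_ratio be fractional; previously a non-integer ratio would have handed a float feature count to nn.Linear / nn.LayerNorm, which does not accept float sizes. A worked example of the arithmetic with made-up numbers:

    ratio, output_features = 1.5, 256
    xglu = True

    hidden = int(ratio * output_features)          # 384
    max_features = 2 * hidden if xglu else hidden  # 768: doubled to feed the gated (xGLU) branch
    print(hidden, max_features)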
@@ -453,6 +495,7 @@ class TransformerBlock(nn.Module):
          relative_position_embedding=False,
          source_size=None,
          utility_tokens=0,
+         talking_heads=False,
          mlp_ratio=4,
          activation: nn.Module = nn.ReLU,
          activation_kwargs: Optional[dict] = None,
@@ -470,6 +513,7 @@
          post_norm=False,
          normformer=False,
          checkpoint_ff=True,
+         layerscale=True,
      ):
          """
          Args:
@@ -491,6 +535,13 @@
          self.layer_norm_2 = nn.LayerNorm(d_model)
          self.layer_norm_3 = nn.LayerNorm(d_model)

+         if layerscale:
+             self.layerscale1 = LayerScale(d_model)
+             self.layerscale2 = LayerScale(d_model)
+         else:
+             self.layerscale1 = nn.Identity()
+             self.layerscale2 = nn.Identity()
+
          if relative_position_embedding:
              max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
              if d_model < 16:
@@ -513,6 +564,7 @@
              rotary_embedding=self.rotary_embedding,
              source_size=source_size,
              utility_tokens=utility_tokens,
+             talking_heads=talking_heads,
              scaling=msa_scaling,
          )

@@ -553,19 +605,19 @@

          if self.pre_norm:
              x = self.layer_norm_1(x)
-             x = x + self.drop_path(self.attn(x, x, x))
+             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
              x = self.layer_norm_2(x)
-             x = x + self.drop_path(self.ff(x))
+             x = x + self.drop_path(self.layerscale2(self.ff(x)))
              if self.post_norm:  # i.e. in addition! Pre and post.
                  x = self.layer_norm_3(x)
          elif self.post_norm:  # i.e. only, not prenorm, just post
-             x = x + self.drop_path(self.attn(x, x, x))
+             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
              x = self.layer_norm_1(x)
-             x = x + self.drop_path(self.ff(x))
+             x = x + self.drop_path(self.layerscale2(self.ff(x)))
              x = self.layer_norm_2(x)
          else:  # Not pre or post norm. Stand well back.
-             x = x + self.drop_path(self.attn(x, x, x))
-             x = x + self.drop_path(self.ff(x))
+             x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+             x = x + self.drop_path(self.layerscale2(self.ff(x)))

          return x

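Note: with the change above, every residual update in TransformerBlock is now gated per channel before stochastic depth, i.e. x <- x + drop_path(layerscale(branch(x))). A minimal pre-norm sketch of that composition using stand-in linear layers instead of the package's MHAttention/FeedforwardBlock (and omitting drop_path):

    import torch
    import torch.nn as nn

    class ToyPreNormBlock(nn.Module):
        # Illustrative only: mirrors the pre_norm path of the hunk above.
        def __init__(self, d_model=32, init_values=1e-4):
            super().__init__()
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            self.branch1 = nn.Linear(d_model, d_model)  # stand-in for self-attention
            self.branch2 = nn.Linear(d_model, d_model)  # stand-in for the feedforward
            self.scale1 = nn.Parameter(init_values * torch.ones(d_model))
            self.scale2 = nn.Parameter(init_values * torch.ones(d_model))

        def forward(self, x):
            x = self.norm1(x)
            x = x + self.scale1 * self.branch1(x)  # LayerScale-gated residual update
            x = self.norm2(x)
            x = x + self.scale2 * self.branch2(x)
            return x

    block = ToyPreNormBlock()
    tokens = torch.randn(2, 10, 32)
    print(block(tokens).shape)  # torch.Size([2, 10, 32])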
@@ -616,12 +668,14 @@ class TransformerEncoder(nn.Module):
          causal=False,
          linear_module=nn.Linear,
          utility_tokens=0,
+         talking_heads=False,
          return_utility_tokens=False,
          pre_norm=True,
          post_norm=False,
          normformer=False,
          msa_scaling="d",
          checkpoint_ff=True,
+         layerscale=True,
      ):
          """
          Args:
@@ -644,11 +698,23 @@
          )

          super().__init__()
+
+         if FLASH_ATTN and talking_heads:
+             warnings.warn(
+                 "Using talking heads currently prevents using flash attention.",
+                 stacklevel=2,
+             )
+
          self.seq_len = seq_len
          self.n_heads = n_heads
          self._utility_tokens = utility_tokens
          self.return_utility_tokens = return_utility_tokens

+         if layerscale:
+             self.layerscale = LayerScale(d_model)
+         else:
+             self.layerscale = None
+
          # Initialise utility tokens with normal init, like usual Pytorch embeddings
          if self._utility_tokens:
              self._utility_token_embedding = nn.Parameter(
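Note: the new warning uses stacklevel=2 so that it is reported at the line that constructed the encoder rather than inside __init__, since the talking-heads path silently bypasses flash attention instead of failing. A standalone illustration of that warnings pattern, with a hypothetical helper function:

    import warnings

    def configure(talking_heads=True, flash_attn_available=True):
        # Mirrors the check added above: the combination quietly disables flash
        # attention, so the caller is warned rather than hit with an error.
        if flash_attn_available and talking_heads:
            warnings.warn(
                "Using talking heads currently prevents using flash attention.",
                stacklevel=2,  # attribute the warning to the caller of configure()
            )

    configure()  # the UserWarning points at this line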
@@ -695,6 +761,7 @@ class TransformerEncoder(nn.Module):
                      relative_position_embedding=relative_position_embedding,
                      source_size=source_size,
                      utility_tokens=utility_tokens,
+                     talking_heads=talking_heads,
                      mlp_ratio=mlp_ratio,
                      activation=activation,
                      activation_kwargs=activation_kwargs,
@@ -712,6 +779,7 @@
                      post_norm=post_norm,
                      normformer=normformer,
                      checkpoint_ff=checkpoint_ff,
+                     layerscale=layerscale,
                  )
                  for i in range(n_layers)
              ]
@@ -732,15 +800,17 @@
              x = x

          if self.absolute_position_embedding is not None:
-             x = x + self.absolute_position_embedding(
+             position_embedding = self.absolute_position_embedding(
                  torch.arange(
                      0, self.full_sequence_length, dtype=torch.long, device=x.device
                  ).unsqueeze(
                      0
                  )  # to shape (1, seq_len) to broadcast over batch
              )
+             if self.layerscale is not None:
+                 position_embedding = self.layerscale(position_embedding)

-         return x
+         return x + position_embedding

      def forward(self, x):

@@ -174,6 +174,7 @@ class ViTEncoder(nn.Module):
          transformer_heads=4,
          transformer_mlp_ratio=2,
          transformer_utility_tokens=0,
+         transformer_talking_heads=False,
          transformer_return_utility_tokens=False,
          transformer_activation: nn.Module = SquaredReLU,
          transformer_activation_kwargs: Optional[dict] = None,
@@ -186,6 +187,7 @@
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
+         transformer_layerscale=True,
          linear_module=nn.Linear,
      ):
          super().__init__()
@@ -345,11 +347,13 @@ class ViTEncoder(nn.Module):
                  causal=False,
                  linear_module=linear_module,
                  utility_tokens=transformer_utility_tokens,
+                 talking_heads=transformer_talking_heads,
                  return_utility_tokens=transformer_return_utility_tokens,
                  pre_norm=transformer_pre_norm,
                  normformer=transformer_normformer,
                  post_norm=transformer_post_norm,
                  checkpoint_ff=transformer_checkpoint_ff,
+                 layerscale=transformer_layerscale,
              )
          else:
              self.transformer = nn.Identity()
@@ -472,6 +476,7 @@ class ViT(nn.Module):
          transformer_heads=4,
          transformer_mlp_ratio=2,
          transformer_utility_tokens=0,
+         transformer_talking_heads=False,
          transformer_return_utility_tokens=False,
          transformer_activation: nn.Module = SquaredReLU,
          transformer_activation_kwargs: Optional[dict] = None,
@@ -484,6 +489,7 @@
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
+         transformer_layerscale=True,
          head=SequencePoolClassificationHead,
          batch_norm_logits=True,
          logit_projection_layer=nn.Linear,
@@ -543,6 +549,7 @@
              transformer_heads=transformer_heads,
              transformer_mlp_ratio=transformer_mlp_ratio,
              transformer_utility_tokens=transformer_utility_tokens,
+             transformer_talking_heads=transformer_talking_heads,
              transformer_return_utility_tokens=transformer_return_utility_tokens,
              transformer_activation=transformer_activation,
              transformer_activation_kwargs=transformer_activation_kwargs,
@@ -555,6 +562,7 @@
              transformer_msa_dropout=transformer_msa_dropout,
              transformer_stochastic_depth=transformer_stochastic_depth,
              transformer_checkpoint_ff=transformer_checkpoint_ff,
+             transformer_layerscale=transformer_layerscale,
              linear_module=linear_module,
          )

@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "9.6.0"
+ version = "10.1.0"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}
File without changes
File without changes