broccoli-ml 9.7.0__tar.gz → 10.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 9.7.0
+Version: 10.0.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -21,6 +21,15 @@ except ImportError:
     FLASH_ATTN = False
 
 
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-4):
+        super().__init__()
+        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x * self.nondecay_scale
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
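The new LayerScale module applies a learnable per-channel gain to a residual branch, initialised to 1e-4 so that freshly added blocks start close to the identity mapping (the LayerScale technique popularised by deep vision transformers such as CaiT). The snippet below is an illustrative sketch of that behaviour, not code from the package: the Linear layer is a hypothetical stand-in for an attention or feedforward branch, and the "nondecay_" prefix presumably marks the parameter for exclusion from weight decay.

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Per-channel learnable gain, mirroring the class added above."""
    def __init__(self, dim, init_values=1e-4):
        super().__init__()
        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.nondecay_scale

x = torch.randn(2, 16, 64)        # (batch, tokens, d_model)
branch = nn.Linear(64, 64)        # hypothetical stand-in for an attn/ff branch
scale = LayerScale(64)

out = x + scale(branch(x))        # residual update damped by ~1e-4 at init
print((out - x).abs().max())      # small: the block starts near the identity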
@@ -390,11 +399,15 @@ class FeedforwardBlock(nn.Module):
         )
 
         self.max_features = (
-            2 * ratio * output_features if self.xglu else ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
         )
 
         self.linear_in = linear_module_up(input_features, self.max_features)
-        self.linear_out = linear_module_down(ratio * output_features, output_features)
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
+        )
 
         self.process = nn.Sequential(
             *[
@@ -402,7 +415,11 @@ class FeedforwardBlock(nn.Module):
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
-                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
                 self.linear_out,
                 self.outer_dropout,
             ]
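The int(...) casts matter when ratio is not an integer: the old expression ratio * output_features then evaluates to a float, which nn.Linear and nn.LayerNorm reject as a dimension. The sketch below illustrates the failure mode; the ratio value 8/3 is an assumed example, not one taken from the package.

import torch.nn as nn

ratio, output_features = 8 / 3, 96               # assumed example values
hidden = ratio * output_features                 # 256.00000000000003, a float
layer = nn.Linear(int(hidden), output_features)  # works after the cast
# nn.Linear(hidden, output_features)             # raises TypeError: size must be int
print(layer.in_features)                         # 256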
@@ -496,6 +513,7 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        layerscale=True,
     ):
         """
         Args:
@@ -517,6 +535,13 @@ class TransformerBlock(nn.Module):
         self.layer_norm_2 = nn.LayerNorm(d_model)
         self.layer_norm_3 = nn.LayerNorm(d_model)
 
+        if layerscale:
+            self.layerscale1 = LayerScale(d_model)
+            self.layerscale2 = LayerScale(d_model)
+        else:
+            self.layerscale1 = nn.Identity()
+            self.layerscale2 = nn.Identity()
+
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
@@ -580,19 +605,19 @@ class TransformerBlock(nn.Module):
 
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
             x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
             if self.post_norm:  # i.e. in addition! Pre and post.
                 x = self.layer_norm_3(x)
         elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
             x = self.layer_norm_2(x)
         else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
 
         return x
 
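Each residual branch now takes the form x + drop_path(layerscale(branch(x))): at initialisation the branch contributes only ~1e-4 of its output, and with layerscale=False both layerscale modules are nn.Identity, which recovers the previous behaviour exactly. The sketch below is illustrative only, with hypothetical Linear stand-ins for the attention and feedforward branches and drop_path omitted for brevity.

import torch
import torch.nn as nn

d_model = 64
x = torch.randn(2, 16, d_model)

attn = nn.Linear(d_model, d_model)          # stand-in for self.attn(x, x, x)
ff = nn.Linear(d_model, d_model)            # stand-in for self.ff(x)
norm1, norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
layerscale1 = layerscale2 = nn.Identity()   # the layerscale=False case

# Pre-norm path, mirroring the first branch of forward() above.
x = norm1(x)
x = x + layerscale1(attn(x))
x = norm2(x)
x = x + layerscale2(ff(x))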
@@ -650,6 +675,7 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        layerscale=True,
     ):
         """
         Args:
@@ -748,6 +774,7 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    layerscale=layerscale,
                 )
                 for i in range(n_layers)
             ]
@@ -187,6 +187,7 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
+        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -352,6 +353,7 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                layerscale=transformer_layerscale,
             )
         else:
             self.transformer = nn.Identity()
@@ -487,6 +489,7 @@ class ViT(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
+        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -559,6 +562,7 @@ class ViT(nn.Module):
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
+            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )
 
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "9.7.0"
+version = "10.0.1"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}