broccoli-ml 9.7.0__tar.gz → 10.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/PKG-INFO +1 -1
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/transformer.py +36 -9
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/vit.py +4 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/pyproject.toml +1 -1
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/LICENSE +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/README.md +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/activation.py +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/linear.py +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/rope.py +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.7.0 → broccoli_ml-10.0.1}/broccoli/utils.py +0 -0
broccoli/transformer.py

@@ -21,6 +21,15 @@ except ImportError:
     FLASH_ATTN = False
 
 
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-4):
+        super().__init__()
+        self.nondecay_scale = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x * self.nondecay_scale
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
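The new LayerScale module implements the CaiT-style trick of scaling each residual branch by a small learnable per-channel vector (initialised to 1e-4 here); the nondecay_ prefix on the parameter name presumably marks it for exclusion from weight decay. A minimal standalone sketch of the idea, independent of broccoli and using made-up names:

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Learnable per-channel scale, initialised small so the residual
    # branch starts out contributing almost nothing to the stream.
    def __init__(self, dim, init_values=1e-4):
        super().__init__()
        self.scale = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.scale

# Illustrative residual update: x = x + scale(branch(x))
dim = 64
branch = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
scale = LayerScale(dim)
x = torch.randn(2, 10, dim)
out = x + scale(branch(x))
print(out.shape)  # torch.Size([2, 10, 64])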
@@ -390,11 +399,15 @@ class FeedforwardBlock(nn.Module):
         )
 
         self.max_features = (
-            2 * ratio * output_features
+            2 * int(ratio * output_features)
+            if self.xglu
+            else int(ratio * output_features)
         )
 
         self.linear_in = linear_module_up(input_features, self.max_features)
-        self.linear_out = linear_module_down(
+        self.linear_out = linear_module_down(
+            int(ratio * output_features), output_features
+        )
 
         self.process = nn.Sequential(
             *[
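The reworked self.max_features doubles the up-projection width only when self.xglu is set, which matches how GLU-style activations split the hidden features into a value half and a gate half, leaving int(ratio * output_features) features for linear_out. A rough sketch of that shape bookkeeping, with hypothetical names outside the package:

import torch
import torch.nn as nn
import torch.nn.functional as F

class GegluFeedforward(nn.Module):
    # GEGLU-style feedforward: the up-projection emits 2 * hidden features,
    # the gated activation halves them back to `hidden`, and the
    # down-projection maps `hidden` back to the output width.
    def __init__(self, d_in, d_out, ratio=4, xglu=True):
        super().__init__()
        hidden = int(ratio * d_out)
        self.xglu = xglu
        max_features = 2 * hidden if xglu else hidden
        self.linear_in = nn.Linear(d_in, max_features)
        self.linear_out = nn.Linear(hidden, d_out)

    def forward(self, x):
        h = self.linear_in(x)
        if self.xglu:
            value, gate = h.chunk(2, dim=-1)
            h = value * F.gelu(gate)
        else:
            h = F.gelu(h)
        return self.linear_out(h)

ff = GegluFeedforward(d_in=64, d_out=64)
print(ff(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])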
@@ -402,7 +415,11 @@ class FeedforwardBlock(nn.Module):
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
-
+                (
+                    nn.LayerNorm(int(ratio * output_features))
+                    if normformer
+                    else nn.Identity()
+                ),
                 self.linear_out,
                 self.outer_dropout,
             ]
@@ -496,6 +513,7 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        layerscale=True,
     ):
         """
         Args:
@@ -517,6 +535,13 @@ class TransformerBlock(nn.Module):
         self.layer_norm_2 = nn.LayerNorm(d_model)
         self.layer_norm_3 = nn.LayerNorm(d_model)
 
+        if layerscale:
+            self.layerscale1 = LayerScale(d_model)
+            self.layerscale2 = LayerScale(d_model)
+        else:
+            self.layerscale1 = nn.Identity()
+            self.layerscale2 = nn.Identity()
+
         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
             if d_model < 16:
@@ -580,19 +605,19 @@ class TransformerBlock(nn.Module):
 
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
             x = self.layer_norm_2(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
             if self.post_norm:  # i.e. in addition! Pre and post.
                 x = self.layer_norm_3(x)
         elif self.post_norm:  # i.e. only, not prenorm, just post
-            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
             x = self.layer_norm_2(x)
         else:  # Not pre or post norm. Stand well back.
-            x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(self.ff(x))
+            x = x + self.drop_path(self.layerscale1(self.attn(x, x, x)))
+            x = x + self.drop_path(self.layerscale2(self.ff(x)))
 
         return x
 
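Every branch of forward now wraps the sub-layer output in its LayerScale before stochastic depth, so the residual update becomes x = x + drop_path(layerscale(sublayer(x))). A stripped-down pre-norm sketch of that ordering (generic module names, not the package's actual block):

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Generic pre-norm residual block: normalise, run the sub-layer,
    # damp its output with a learnable per-channel scale, then optionally
    # drop the whole branch (stochastic depth) before adding it back.
    def __init__(self, d_model, sublayer, drop_prob=0.1, layerscale=True):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.scale = nn.Parameter(1e-4 * torch.ones(d_model)) if layerscale else None
        self.drop_prob = drop_prob

    def forward(self, x):
        branch = self.sublayer(self.norm(x))
        if self.scale is not None:
            branch = branch * self.scale
        if self.training and self.drop_prob > 0:
            # Per-sample branch dropping, rescaled to keep expectations equal.
            keep = (torch.rand(x.shape[0], 1, 1, device=x.device) >= self.drop_prob).float()
            branch = branch * keep / (1.0 - self.drop_prob)
        return x + branch

block = PreNormBlock(64, nn.Linear(64, 64))
print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])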
@@ -650,6 +675,7 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        layerscale=True,
     ):
         """
         Args:
@@ -748,6 +774,7 @@ class TransformerEncoder(nn.Module):
                 post_norm=post_norm,
                 normformer=normformer,
                 checkpoint_ff=checkpoint_ff,
+                layerscale=layerscale,
             )
             for i in range(n_layers)
         ]
broccoli/vit.py

@@ -187,6 +187,7 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
+        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -352,6 +353,7 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                layerscale=transformer_layerscale,
             )
         else:
             self.transformer = nn.Identity()
@@ -487,6 +489,7 @@ class ViT(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
+        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -559,6 +562,7 @@ class ViT(nn.Module):
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
+            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )
 