broccoli-ml 10.1.0__tar.gz → 13.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/PKG-INFO +1 -1
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/transformer.py +101 -68
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/vit.py +11 -16
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/pyproject.toml +1 -1
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/LICENSE +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/README.md +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/activation.py +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/linear.py +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/rope.py +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-10.1.0 → broccoli_ml-13.0.0}/broccoli/utils.py +0 -0
--- broccoli_ml-10.1.0/broccoli/transformer.py
+++ broccoli_ml-13.0.0/broccoli/transformer.py
@@ -21,13 +21,10 @@ except ImportError:
     FLASH_ATTN = False


-
-
-
-
-
-    def forward(self, x):
-        return x * self.nondecay_scale
+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)


 def drop_path(
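The new `scale_parameters` helper replaces the removed `LayerScale`-style module: it multiplies every parameter of a module in place, and the `beta` factor threaded through the constructors later in this diff is applied with it at reset time, in the style of Microsoft DeepNet. A minimal, self-contained sketch of how such a helper behaves (the `nn.Linear` toy module and the factor 0.5 are illustrative, not from the package):

```python
import torch
from torch import nn


def scale_parameters(torch_module: nn.Module, factor: float):
    # In-place, gradient-free rescaling of every parameter tensor.
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)


# Illustrative usage: down-weight a freshly initialised projection.
toy = nn.Linear(16, 16)
before = toy.weight.detach().clone()
scale_parameters(toy, 0.5)
assert torch.allclose(toy.weight, 0.5 * before)
```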
@@ -92,6 +89,7 @@ class MHAttention(nn.Module):
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -120,6 +118,7 @@ class MHAttention(nn.Module):
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta

         self.head_dim = self.embed_dim // self.n_heads

@@ -201,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-
-
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )

         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )

@@ -222,17 +230,20 @@ class MHAttention(nn.Module):

         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )

         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)

+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k

     def project_qkv(
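These two hunks change the rotary-embedding path so that queries and keys are split into explicit heads before RoPE is applied to the spatial tokens, and packed back afterwards. A small einops sketch of the shape round-trip (the tensor sizes are illustrative):

```python
import torch
from einops import rearrange

b, t, h, d = 2, 10, 4, 8          # batch, tokens, heads, head_dim
q = torch.randn(b, t, h * d)      # packed projection output

q = rearrange(q, "b t (h d) -> b t h d", h=h)   # expose the head axis
# ... rotary embeddings would be applied per head on the last axis here ...
q = rearrange(q, "b t h d -> b t (h d)")        # pack the heads back

assert q.shape == (b, t, h * d)
```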
@@ -339,7 +350,10 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
         if self.talking_heads:
             # Initialize close to identity
             nn.init.eye_(self.head_projection.weight)
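`reset_parameters` now shrinks the value and output projections by `beta`, following the DeepNet recipe (Wang et al., 2022), which scales residual branches by `alpha` and selected initialisations by `beta`. For an encoder-only stack of `N` layers the paper's suggestion is roughly `alpha = (2N) ** 0.25` and `beta = (8N) ** -0.25`; a hedged sketch of how a caller might derive the values passed into these constructors (the function name below is illustrative, not part of the package):

```python
def deepnet_encoder_constants(n_layers: int) -> tuple[float, float]:
    # DeepNet-style constants for an encoder-only stack:
    # alpha scales the residual path, beta scales selected weight inits.
    alpha = (2 * n_layers) ** 0.25
    beta = (8 * n_layers) ** -0.25
    return alpha, beta


alpha, beta = deepnet_encoder_constants(12)  # e.g. a 12-layer encoder
```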
@@ -363,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()

         self.checkpoint = checkpoint
-        self.
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
@@ -411,7 +422,6 @@ class FeedforwardBlock(nn.Module):

         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
@@ -461,12 +471,7 @@ class FeedforwardBlock(nn.Module):
         else:
             processed = self.process(x)

-
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed

     def reset_parameters(self):
         if self.post_norm:
@@ -477,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()

+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+

-class
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -513,7 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
-
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -525,22 +534,29 @@ class TransformerBlock(nn.Module):

         super().__init__()

+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer

+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm = nn.LayerNorm(d_model)
+            self.pre_mlp_norm = nn.LayerNorm(d_model)

-        if
-        self.
-
-
-        self.
-        self.
+        if normformer:
+            self.normformer_norm = nn.LayerNorm(d_model)
+
+        if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
+            self.post_attention_norm = nn.LayerNorm(d_model)
+            self.post_mlp_norm = nn.LayerNorm(d_model)

         if relative_position_embedding:
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -566,6 +582,7 @@ class TransformerBlock(nn.Module):
             utility_tokens=utility_tokens,
             talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )

         # Submodule for the feedforward process
@@ -588,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
            normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
            checkpoint=checkpoint_ff,
+            beta=beta,
        )

        self.reset_parameters()
@@ -602,22 +617,34 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
+        if self.post_norm:
+            x = self.input_norm(x)

         if self.pre_norm:
-
-
-
-
-
-
-
-
-
-
-
-
-            x =
-
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))
+
+        if self.normformer:
+            processed = self.normformer_norm(processed)
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x
+
+        processed = self.drop_path(self.ff(process_x))
+
+        x = self.alpha * x + processed
+
+        if self.post_norm:
+            x = self.post_mlp_norm(x)

         return x

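The rewritten `forward` applies the residual update `x = alpha * x + sublayer(x)` twice, once around attention and once around the feedforward block, with LayerNorm placement controlled by `pre_norm`, `post_norm` and `normformer`. A condensed sketch of the same control flow for a single pre-norm sublayer, stripped of drop-path and normformer details (the class below is illustrative, not the package's):

```python
import torch
from torch import nn


class ResidualSketch(nn.Module):
    """Pre-norm residual wrapper with a DeepNet-style alpha on the skip path."""

    def __init__(self, d_model: int, sublayer: nn.Module, alpha: float = 1.0):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise the branch input, then scale the skip connection by alpha.
        return self.alpha * x + self.sublayer(self.norm(x))
```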
@@ -625,16 +652,26 @@ class TransformerBlock(nn.Module):
         """
         Give back the attention scores used in this layer.
         """
+        # Fix: Use the correct attribute name 'pre_attention_norm'
         if self.pre_norm:
-
+            # We must normalize the input before measuring attention logits
+            # to match what the model actually sees during forward()
+            x = self.pre_attention_norm(x)
             return self.attn.attention_logits(x, x, x)
         else:
             return self.attn.attention_logits(x, x, x)

     def reset_parameters(self):
-        self.
-
-
+        if self.pre_norm:
+            self.pre_attention_norm.reset_parameters()
+            self.pre_mlp_norm.reset_parameters()
+
+        if self.post_norm:
+            self.post_attention_norm.reset_parameters()
+            self.post_mlp_norm.reset_parameters()
+
+        if self.normformer:
+            self.normformer_norm.reset_parameters()

         self.attn.reset_parameters()
         self.ff.reset_parameters()
@@ -675,7 +712,8 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
-
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -710,11 +748,6 @@ class TransformerEncoder(nn.Module):
         self._utility_tokens = utility_tokens
         self.return_utility_tokens = return_utility_tokens

-        if layerscale:
-            self.layerscale = LayerScale(d_model)
-        else:
-            self.layerscale = None
-
         # Initialise utility tokens with normal init, like usual Pytorch embeddings
         if self._utility_tokens:
             self._utility_token_embedding = nn.Parameter(
@@ -754,7 +787,7 @@ class TransformerEncoder(nn.Module):

         self.blocks = nn.ModuleList(
             [
-
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
@@ -779,7 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
-
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
@@ -807,10 +841,9 @@ class TransformerEncoder(nn.Module):
                 0
             )  # to shape (1, seq_len) to broadcast over batch
         )
-
-        position_embedding = self.layerscale(position_embedding)
+        x += position_embedding

-        return x
+        return x

     def forward(self, x):

--- broccoli_ml-10.1.0/broccoli/vit.py
+++ broccoli_ml-13.0.0/broccoli/vit.py
@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -187,7 +186,6 @@ class ViTEncoder(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -353,7 +351,8 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
-
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -395,16 +394,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()

-        self.
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -414,19 +411,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )

         self.reset_parameters()

     def forward(self, x):
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)

     def attention_logits(self, x):
-        x = self.
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)

     def reset_parameters(self):
         for module in self.encoder:
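Here the encoder pipeline is split in two: `self.preprocess` covers the CNN stem, flattening, padding and normalisation, while `forward` wraps the optional initial feedforward block in an explicit residual before handing off to the transformer. A sketch of the resulting call order (purely illustrative wiring, not the package's class):

```python
import torch
from torch import nn


class ForwardOrderSketch(nn.Module):
    """Illustrative wiring mirroring the new ViTEncoder.forward order."""

    def __init__(self, preprocess: nn.Module, initial_ff: nn.Module, transformer: nn.Module):
        super().__init__()
        self.preprocess = preprocess
        self.initial_ff = initial_ff
        self.transformer = transformer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.preprocess(x)       # stem: convolutions, reshape, norm
        x = x + self.initial_ff(x)   # residual around the optional initial FF
        return self.transformer(x)   # transformer encoder (or nn.Identity)
```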
@@ -460,7 +459,6 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -489,7 +487,6 @@ class ViT(nn.Module):
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
-        transformer_layerscale=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -533,7 +530,6 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -562,7 +558,6 @@ class ViT(nn.Module):
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
-            transformer_layerscale=transformer_layerscale,
             linear_module=linear_module,
         )
