broccoli-ml 12.3.0__py3-none-any.whl → 13.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +73 -52
- broccoli/vit.py +11 -12
- {broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/RECORD +6 -6
- {broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -21,6 +21,12 @@ except ImportError:
     FLASH_ATTN = False


+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
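
The new `scale_parameters` helper is the building block for the DeepNet-style initialisation used further down: it multiplies every parameter of a module by a constant, in place and without tracking gradients. A minimal usage sketch (the `nn.Linear` layer and the factor 0.5 are illustration values, not anything the package prescribes):

```python
import torch
from torch import nn

def scale_parameters(torch_module: nn.Module, factor: float):
    # Multiply every parameter in place; no_grad avoids autograd tracking.
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)

layer = nn.Linear(16, 16)
reference = layer.weight.detach().clone()
scale_parameters(layer, 0.5)  # shrink freshly initialised weights and biases
assert torch.allclose(layer.weight, 0.5 * reference)
```
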
@@ -83,6 +89,7 @@ class MHAttention(nn.Module):
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -111,6 +118,7 @@ class MHAttention(nn.Module):
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta

         self.head_dim = self.embed_dim // self.n_heads

@@ -192,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-
-
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )

         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )

@@ -213,17 +230,20 @@ class MHAttention(nn.Module):

         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )

         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)

+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k

     def project_qkv(
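
The two hunks above thread an explicit head axis through the rotary-embedding path: q and k are split into heads before the utility/image-token split, the spatial rearranges carry the `h` axis along, and the heads are only flattened back after the utility tokens are re-attached. A minimal sketch of the same split/merge round trip with einops (the batch size, head count, grid size and single utility token are arbitrary illustration values):

```python
import torch
from einops import rearrange

b, h, d, n_util, height, width = 2, 4, 8, 1, 6, 6
q = torch.randn(b, n_util + height * width, h * d)  # (batch, tokens, embed)

q = rearrange(q, "b t (h d) -> b t h d", h=h)        # expose the head axis
q_util, q_img = q[:, :n_util], q[:, n_util:]         # utility vs. image tokens

# Spatial layout with the head axis preserved, as in the new format strings.
q_img = rearrange(q_img, "b (x y) h d -> b x y h d", x=height, y=width)
# ... a rotary embedding over (x, y) would be applied here, per head ...
q_img = rearrange(q_img, "b x y h d -> b (x y) h d")

q = torch.cat([q_util, q_img], dim=1)                # re-attach utility tokens
q = rearrange(q, "b t h d -> b t (h d)")             # collapse heads again
assert q.shape == (b, n_util + height * width, h * d)
```
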
@@ -330,7 +350,10 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
         if self.talking_heads:
             # Initialize close to identity
             nn.init.eye_(self.head_projection.weight)
@@ -354,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()

         self.checkpoint = checkpoint
-        self.
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
@@ -402,7 +422,6 @@ class FeedforwardBlock(nn.Module):

         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
@@ -452,12 +471,7 @@ class FeedforwardBlock(nn.Module):
         else:
             processed = self.process(x)

-
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed

     def reset_parameters(self):
         if self.post_norm:
@@ -468,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()

+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)
+

-class TransformerBlock(nn.Module):
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -504,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -515,10 +534,16 @@ class TransformerBlock(nn.Module):

         super().__init__()

+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer

+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

         if self.pre_norm:
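
The constructor now rejects the ambiguous "pre-norm and post-norm at the same time" configuration up front instead of silently applying both. The guard is a plain mutual-exclusion check, as in this standalone sketch (not the block's real signature):

```python
def check_norm_flags(pre_norm: bool, post_norm: bool) -> None:
    # Mirrors the new constructor guard: the two placements are exclusive.
    if pre_norm and post_norm:
        raise ValueError("A transformer cannot be both prenorm and postnorm.")

check_norm_flags(pre_norm=True, post_norm=False)  # fine
try:
    check_norm_flags(pre_norm=True, post_norm=True)
except ValueError as err:
    print(err)  # A transformer cannot be both prenorm and postnorm.
```
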
@@ -529,6 +554,7 @@ class TransformerBlock(nn.Module):
             self.normformer_norm = nn.LayerNorm(d_model)

         if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
             self.post_attention_norm = nn.LayerNorm(d_model)
             self.post_mlp_norm = nn.LayerNorm(d_model)

@@ -556,6 +582,7 @@ class TransformerBlock(nn.Module):
             utility_tokens=utility_tokens,
             talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )

         # Submodule for the feedforward process
@@ -578,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )

         self.reset_parameters()
@@ -592,44 +617,36 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
-        if self.
-            x = self.
-        x = x + self.drop_path(self.attn(x, x, x))
-        x = self.pre_mlp_norm(x)
-        x = x + self.drop_path(self.ff(x))
-        if self.post_norm:  # i.e. in addition! Pre and post.
-            x = self.post_mlp_norm(x)
-
-        # if self.pre_norm:
-        # process_x = self.pre_attention_norm(x)
-        # else:
-        # process_x = x
+        if self.post_norm:
+            x = self.input_norm(x)

-
+        if self.pre_norm:
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x

-
-        # processed = self.normformer_norm(processed)
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))

-
-
+        if self.normformer:
+            processed = self.normformer_norm(processed)

-
-        # x = self.post_attention_norm(x)
+        x = self.alpha * x + processed

-
-
-
-
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x

-
+        processed = self.drop_path(self.ff(process_x))

-
-        # x = x + processed
+        x = self.alpha * x + processed

-
-
+        if self.post_norm:
+            x = self.post_mlp_norm(x)

-
+        return x

     def attention_logits(self, x):
         """
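
The rewritten `forward` makes the residual update explicit: each sub-layer output is added onto `self.alpha * x` rather than `x`, which is the DeepNorm-style residual scaling, and the extra input/post LayerNorms only apply in the post-norm configuration. A condensed, self-contained sketch of the same control flow, using stand-in PyTorch sub-layers rather than the package's own attention and feed-forward modules (illustrative, not the `EncoderBlock` implementation):

```python
import torch
from torch import nn

class ResidualBlockSketch(nn.Module):
    """Illustrative only: pre/post-norm branching plus alpha-scaled residuals."""

    def __init__(self, d_model=64, n_heads=4, pre_norm=True, post_norm=False, alpha=1.0):
        super().__init__()
        if pre_norm and post_norm:
            raise ValueError("A transformer cannot be both prenorm and postnorm.")
        self.pre_norm, self.post_norm, self.alpha = pre_norm, post_norm, alpha
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.pre_attention_norm = nn.LayerNorm(d_model)
        self.pre_mlp_norm = nn.LayerNorm(d_model)
        self.input_norm = nn.LayerNorm(d_model)
        self.post_attention_norm = nn.LayerNorm(d_model)
        self.post_mlp_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        if self.post_norm:
            x = self.input_norm(x)

        process_x = self.pre_attention_norm(x) if self.pre_norm else x
        attn_out, _ = self.attn(process_x, process_x, process_x)
        x = self.alpha * x + attn_out            # residual scaled by alpha
        if self.post_norm:
            x = self.post_attention_norm(x)

        process_x = self.pre_mlp_norm(x) if self.pre_norm else x
        x = self.alpha * x + self.ff(process_x)  # residual scaled by alpha
        if self.post_norm:
            x = self.post_mlp_norm(x)
        return x

block = ResidualBlockSketch(pre_norm=False, post_norm=True, alpha=2.0)
print(block(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```

With `alpha=1.0` and `post_norm=False` this reduces to the familiar pre-norm block, which matches the new defaults.
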
@@ -695,6 +712,8 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -768,7 +787,7 @@ class TransformerEncoder(nn.Module):

         self.blocks = nn.ModuleList(
             [
-                TransformerBlock(
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
@@ -793,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
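
Both new knobs default to 1.0, which reproduces the previous behaviour; nothing in this diff derives them automatically. For reference, the DeepNet paper that the "per Microsoft DeepNet" comments refer to (Wang et al., 2022) prescribes, for an N-layer encoder-only model, a residual weight of alpha = (2N)^(1/4) and an initialisation gain of beta = (8N)^(-1/4), so a caller following that recipe might compute them like this (a hypothetical caller-side calculation, not something the library does for you):

```python
# DeepNet encoder-only recipe, assuming it applies to this encoder as-is.
n_layers = 24
alpha = (2 * n_layers) ** 0.25   # scales each residual branch: x = alpha * x + f(x)
beta = (8 * n_layers) ** -0.25   # shrinks v/out and feed-forward projections at init

print(round(alpha, 3), round(beta, 3))  # 2.632 0.269
```
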
broccoli/vit.py
CHANGED
@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -352,6 +351,8 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -393,16 +394,14 @@ class ViTEncoder(nn.Module):
                    or transformer_ff_linear_module_down
                    or linear_module
                ),
-                pre_norm=transformer_pre_norm,
                normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
            )
        else:
            self.initial_ff = nn.Identity()

-        self.
+        self.preprocess = nn.Sequential(
            *[
                batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                self.cnn,
@@ -412,19 +411,21 @@ class ViTEncoder(nn.Module):
                    f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                ),
                self.pooling_channels_padding,
-
-                self.transformer,
+                nn.LayerNorm(),
            ]
        )

        self.reset_parameters()

    def forward(self, x):
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)

    def attention_logits(self, x):
-        x = self.
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)

    def reset_parameters(self):
        for module in self.encoder:
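
The pipeline is now built as `self.preprocess` (ending in a LayerNorm) and no longer includes `self.transformer`; `forward` and `attention_logits` compose the stages explicitly, adding a residual connection around `initial_ff` that previously came from the feed-forward block's own, now removed, `residual_path`. A minimal sketch of the new composition, with stand-in stage modules rather than the real CNN/pooling, feed-forward and transformer stages:

```python
import torch
from torch import nn

class CompositionSketch(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        # Stand-ins for self.preprocess, self.initial_ff and self.transformer.
        self.preprocess = nn.Linear(dim, dim)
        self.initial_ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.transformer = nn.Identity()

    def forward(self, x):
        x = self.preprocess(x)
        x = x + self.initial_ff(x)  # residual now lives here, not inside the FF block
        return self.transformer(x)

print(CompositionSketch()(torch.randn(2, 10, 32)).shape)  # torch.Size([2, 10, 32])
```
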
@@ -458,7 +459,6 @@ class ViT(nn.Module):
        pooling_kernel_stride=2,
        pooling_padding=1,
        transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
        transformer_initial_ff_linear_module_up=None,
        transformer_initial_ff_linear_module_down=None,
        transformer_initial_ff_dropout=None,
@@ -530,7 +530,6 @@ class ViT(nn.Module):
            pooling_kernel_stride=pooling_kernel_stride,
            pooling_padding=pooling_padding,
            transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
            transformer_initial_ff_dropout=transformer_initial_ff_dropout,

{broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/RECORD
CHANGED
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
 broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=3vAQQ75SAyr4-m3e7vSru8M-RpUy2Enp5cVUafaVYMU,28410
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
+broccoli/vit.py,sha256=ZOrbfORhl29HXRDWvLt2A2WbEANJlkNdiiucFB-1CmU,22244
+broccoli_ml-13.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-13.0.0.dist-info/METADATA,sha256=aiwPYSEkNHgJU89jT3tYoNaF_WsaFHoRC2DhgUN7IUE,1369
+broccoli_ml-13.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-13.0.0.dist-info/RECORD,,

{broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/LICENSE
File without changes

{broccoli_ml-12.3.0.dist-info → broccoli_ml-13.0.0.dist-info}/WHEEL
File without changes