broccoli-ml 12.3.1__tar.gz → 13.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/PKG-INFO +1 -1
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/transformer.py +73 -53
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/vit.py +23 -12
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/pyproject.toml +1 -1
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/LICENSE +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/README.md +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/__init__.py +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/activation.py +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/cnn.py +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/linear.py +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/rope.py +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/tensor.py +0 -0
- {broccoli_ml-12.3.1 → broccoli_ml-13.0.1}/broccoli/utils.py +0 -0
broccoli/transformer.py

@@ -21,6 +21,12 @@ except ImportError:
     FLASH_ATTN = False


+def scale_parameters(torch_module: nn.Module, factor: float):
+    with torch.no_grad():
+        for param in torch_module.parameters():
+            param.mul_(factor)
+
+
 def drop_path(
     x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
 ):
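The new `scale_parameters` helper above rescales every parameter of a module in place; later hunks use it to shrink the value/output projections and feedforward weights at initialisation, per Microsoft DeepNet. A minimal standalone sketch of its effect (the `nn.Linear` layer and the factor 0.5 below are illustrative, not taken from the package):

```python
import torch
import torch.nn as nn


def scale_parameters(torch_module: nn.Module, factor: float):
    # Same logic as the helper added in the diff: in-place multiply of all parameters.
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)


proj = nn.Linear(64, 64)                  # hypothetical projection layer
std_before = proj.weight.std().item()
scale_parameters(proj, 0.5)               # beta = 0.5, chosen only for illustration
std_after = proj.weight.std().item()
print(f"weight std: {std_before:.4f} -> {std_after:.4f}")  # roughly halved
```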
@@ -83,6 +89,7 @@ class MHAttention(nn.Module):
         rotary_embedding=None,
         source_size=None,
         scaling="d",
+        beta=1.0,
     ):
         """
         Args:
@@ -111,6 +118,7 @@ class MHAttention(nn.Module):
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
         self.scaling = scaling
+        self.beta = beta

         self.head_dim = self.embed_dim // self.n_heads

@@ -192,17 +200,26 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )

-
-
+        q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+        k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+        q_util, q_img = (
+            q[:, : self.utility_tokens, :, :],
+            q[:, self.utility_tokens :, :, :],
+        )
+        k_util, k_img = (
+            k[:, : self.utility_tokens, :, :],
+            k[:, self.utility_tokens :, :, :],
+        )

         q_img = rearrange(
             q_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )
         k_img = rearrange(
             k_img,
-            f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+            f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
             **spatial_dimension_values,
         )

@@ -213,17 +230,20 @@ class MHAttention(nn.Module):

         q_img = rearrange(
             q_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )
         k_img = rearrange(
             k_img,
-            f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+            f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
         )

         # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
         q = torch.cat([q_util, q_img], dim=1)
         k = torch.cat([k_util, k_img], dim=1)

+        q = rearrange(q, "b t h d -> b t (h d)")
+        k = rearrange(k, "b t h d -> b t (h d)")
+
         return q, k

     def project_qkv(
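The rearrange changes in the two hunks above split the query and key tensors into per-head slices (`b t h d`) before the rotary embedding is applied to the image tokens, and the new lines at the end flatten them back to `b t (h d)`. A small standalone einops round trip with arbitrary shapes, just to illustrate the reshape being introduced:

```python
import torch
from einops import rearrange

n_heads, head_dim = 4, 8
q = torch.randn(2, 16, n_heads * head_dim)          # (batch, tokens, embed_dim)

# Split the embedding dimension into heads, as the new code does before RoPE...
q_heads = rearrange(q, "b t (h d) -> b t h d", h=n_heads)
print(q_heads.shape)                                 # torch.Size([2, 16, 4, 8])

# ...and flatten back afterwards; the round trip is lossless.
q_back = rearrange(q_heads, "b t h d -> b t (h d)")
assert torch.equal(q, q_back)
```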
@@ -330,7 +350,10 @@ class MHAttention(nn.Module):
         self.q_proj.reset_parameters()
         self.k_proj.reset_parameters()
         self.v_proj.reset_parameters()
+        scale_parameters(self.v_proj, self.beta)  # per Microsoft DeepNet
         self.out_proj.reset_parameters()
+        scale_parameters(self.out_proj, self.beta)  # per Microsoft DeepNet
+
         if self.talking_heads:
             # Initialize close to identity
             nn.init.eye_(self.head_projection.weight)
@@ -354,17 +377,14 @@ class FeedforwardBlock(nn.Module):
         outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
-        pre_norm=True,
         normformer=False,
-        post_norm=True,
-        residual_path=True,
         checkpoint=True,
+        beta=1.0,
     ):
         super().__init__()

         self.checkpoint = checkpoint
-        self.
-        self.post_norm = post_norm
+        self.beta = beta
         self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
@@ -402,7 +422,6 @@ class FeedforwardBlock(nn.Module):

         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 self.linear_in,
                 self.activation,
                 self.inner_dropout,
@@ -452,12 +471,7 @@ class FeedforwardBlock(nn.Module):
         else:
             processed = self.process(x)

-
-            return self.layernorm(x + processed)
-        elif self.residual_path:
-            return x + processed
-        else:
-            return processed
+        return processed

     def reset_parameters(self):
         if self.post_norm:
@@ -468,8 +482,11 @@ class FeedforwardBlock(nn.Module):
             if hasattr(module, "reset_parameters"):
                 module.reset_parameters()

+        scale_parameters(self.linear_in, self.beta)  # per Microsoft DeepNet
+        scale_parameters(self.linear_out, self.beta)

-
+
+class EncoderBlock(nn.Module):
     """
     Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
     which is also what is seen in e.g.
@@ -504,6 +521,8 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -515,10 +534,16 @@ class TransformerBlock(nn.Module):

         super().__init__()

+        if pre_norm and post_norm:
+            raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
         self.pre_norm = pre_norm
         self.post_norm = post_norm
         self.normformer = normformer

+        self.alpha = alpha
+        self.beta = beta
+
         self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

         if self.pre_norm:
@@ -529,6 +554,7 @@ class TransformerBlock(nn.Module):
             self.normformer_norm = nn.LayerNorm(d_model)

         if self.post_norm:
+            self.input_norm = nn.LayerNorm(d_model)
             self.post_attention_norm = nn.LayerNorm(d_model)
             self.post_mlp_norm = nn.LayerNorm(d_model)

@@ -556,6 +582,7 @@ class TransformerBlock(nn.Module):
             utility_tokens=utility_tokens,
             talking_heads=talking_heads,
             scaling=msa_scaling,
+            beta=beta,
         )

         # Submodule for the feedforward process
@@ -578,11 +605,9 @@ class TransformerBlock(nn.Module):
                 if ff_linear_module_down is not None
                 else linear_module
             ),
-            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=False,  # Handled outside the block
-            residual_path=False,  # Handled outside the block
             checkpoint=checkpoint_ff,
+            beta=beta,
         )

         self.reset_parameters()
@@ -592,45 +617,36 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
-        if self.
-            x = self.
-        x = x + self.drop_path(self.attn(x, x, x))
-        x = self.pre_mlp_norm(x)
-        x = x + self.drop_path(self.ff(x))
-        if self.post_norm:  # i.e. in addition! Pre and post.
-            x = self.post_mlp_norm(x)
-        return x
-
-        # if self.pre_norm:
-        #     process_x = self.pre_attention_norm(x)
-        # else:
-        #     process_x = x
+        if self.post_norm:
+            x = self.input_norm(x)

-
+        if self.pre_norm:
+            process_x = self.pre_attention_norm(x)
+        else:
+            process_x = x

-
-        # processed = self.normformer_norm(processed)
+        processed = self.drop_path(self.attn(process_x, process_x, process_x))

-
-
+        if self.normformer:
+            processed = self.normformer_norm(processed)

-
-        # x = self.post_attention_norm(x)
+        x = self.alpha * x + processed

-
-
-
-
+        if self.post_norm:
+            x = self.post_attention_norm(x)
+        elif self.pre_norm:
+            process_x = self.pre_mlp_norm(x)
+        else:
+            process_x = x

-
+        processed = self.drop_path(self.ff(process_x))

-
-        # x = x + processed
+        x = self.alpha * x + processed

-
-
+        if self.post_norm:
+            x = self.post_mlp_norm(x)

-
+        return x

     def attention_logits(self, x):
         """
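The rewritten `forward` above replaces the commented-out experiments with a single pre-norm/post-norm path in which the residual branch is weighted by `alpha` (`x = alpha * x + sublayer(x)`), while `beta` shrinks the value/output and feedforward weights at initialisation via `scale_parameters`. This follows the DeepNet recipe (Wang et al., 2022), which for an N-layer encoder-only model suggests roughly alpha = (2N)^(1/4) and beta = (8N)^(-1/4). A minimal standalone sketch of the scaled residual update (the class and sizes below are illustrative, not the package's `EncoderBlock`):

```python
import torch
import torch.nn as nn


class ScaledResidualFF(nn.Module):
    """Toy sublayer showing the DeepNorm-style update x <- LN(alpha * x + F(x))."""

    def __init__(self, d_model: int, alpha: float = 1.0, beta: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.norm = nn.LayerNorm(d_model)
        # beta-scale the sublayer weights at init, as scale_parameters does in the diff
        with torch.no_grad():
            for p in self.ff.parameters():
                p.mul_(beta)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.alpha * x + self.ff(x))


n_layers = 12                                    # illustrative depth
alpha, beta = (2 * n_layers) ** 0.25, (8 * n_layers) ** -0.25
block = ScaledResidualFF(d_model=32, alpha=alpha, beta=beta)
print(block(torch.randn(4, 10, 32)).shape)       # torch.Size([4, 10, 32])
```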
@@ -696,6 +712,8 @@ class TransformerEncoder(nn.Module):
         normformer=False,
         msa_scaling="d",
         checkpoint_ff=True,
+        alpha=1.0,
+        beta=1.0,
     ):
         """
         Args:
@@ -769,7 +787,7 @@ class TransformerEncoder(nn.Module):

         self.blocks = nn.ModuleList(
             [
-
+                EncoderBlock(
                     self.full_sequence_length,
                     d_model,
                     n_heads,
@@ -794,6 +812,8 @@ class TransformerEncoder(nn.Module):
                     post_norm=post_norm,
                     normformer=normformer,
                     checkpoint_ff=checkpoint_ff,
+                    alpha=alpha,
+                    beta=beta,
                 )
                 for i in range(n_layers)
             ]
broccoli/vit.py

@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -188,9 +187,14 @@ class ViTEncoder(nn.Module):
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):
         super().__init__()

+        self.alpha = alpha
+        self.beta = beta
+
         if cnn_activation_kwargs is not None:
             self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
         else:
@@ -352,6 +356,8 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 checkpoint_ff=transformer_checkpoint_ff,
+                alpha=self.alpha,
+                beta=self.beta,
             )
         else:
             self.transformer = nn.Identity()
@@ -393,16 +399,14 @@ class ViTEncoder(nn.Module):
                     or transformer_ff_linear_module_down
                     or linear_module
                 ),
-                pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                post_norm=transformer_post_norm,
-                residual_path=transformer_initial_ff_residual_path,
                 checkpoint=transformer_checkpoint_ff,
+                beta=self.beta,
             )
         else:
             self.initial_ff = nn.Identity()

-        self.
+        self.preprocess = nn.Sequential(
             *[
                 batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
@@ -412,19 +416,21 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-
-                self.transformer,
+                nn.LayerNorm(),
             ]
         )

         self.reset_parameters()

     def forward(self, x):
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer(x)

     def attention_logits(self, x):
-        x = self.
-
+        x = self.preprocess(x)
+        x = x + self.initial_ff(x)
+        return self.transformer.attention_logits(x)

     def reset_parameters(self):
         for module in self.encoder:
@@ -458,7 +464,6 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
-        transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
         transformer_initial_ff_dropout=None,
@@ -491,6 +496,8 @@ class ViT(nn.Module):
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
         linear_module=nn.Linear,
+        alpha=1.0,
+        beta=1.0,
     ):

         super().__init__()
@@ -511,6 +518,9 @@ class ViT(nn.Module):
             "SwiGLU": SwiGLU,
         }[transformer_activation]

+        self.alpha = alpha
+        self.beta = beta
+
         self.encoder = ViTEncoder(
             input_size=input_size,
             initial_batch_norm=initial_batch_norm,
@@ -530,7 +540,6 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
-            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -560,6 +569,8 @@ class ViT(nn.Module):
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
+            alpha=alpha,
+            beta=beta,
         )

         self.pool = head(