broccoli-ml 15.0.1__py3-none-any.whl → 15.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +14 -16
- broccoli/vit.py +4 -3
- {broccoli_ml-15.0.1.dist-info → broccoli_ml-15.2.0.dist-info}/METADATA +1 -1
- {broccoli_ml-15.0.1.dist-info → broccoli_ml-15.2.0.dist-info}/RECORD +6 -6
- {broccoli_ml-15.0.1.dist-info → broccoli_ml-15.2.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-15.0.1.dist-info → broccoli_ml-15.2.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -122,6 +122,9 @@ class MHAttention(nn.Module):
|
|
|
122
122
|
|
|
123
123
|
self.head_dim = self.embed_dim // self.n_heads
|
|
124
124
|
|
|
125
|
+
self.query_norm = nn.RMSNorm(self.head_dim)
|
|
126
|
+
self.key_norm = nn.RMSNorm(self.head_dim)
|
|
127
|
+
|
|
125
128
|
if self.scaling == "sqrtd":
|
|
126
129
|
self.scaling_factor = 1 / math.sqrt(self.head_dim)
|
|
127
130
|
elif self.scaling == "d":
|
|
@@ -225,8 +228,9 @@ class MHAttention(nn.Module):
|
|
|
225
228
|
|
|
226
229
|
freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
|
|
227
230
|
|
|
228
|
-
|
|
229
|
-
|
|
231
|
+
# norm Qs/Ks to protect axial rope, like https://arxiv.org/abs/2302.05442
|
|
232
|
+
q_img = apply_rotary_emb(freqs, self.query_norm(q_img))
|
|
233
|
+
k_img = apply_rotary_emb(freqs, self.key_norm(k_img))
|
|
230
234
|
|
|
231
235
|
q_img = rearrange(
|
|
232
236
|
q_img,
|
|
@@ -416,7 +420,7 @@ class FeedforwardBlock(nn.Module):
|
|
|
416
420
|
self.activation,
|
|
417
421
|
self.inner_dropout,
|
|
418
422
|
(
|
|
419
|
-
nn.LayerNorm(int(ratio * output_features))
|
|
423
|
+
nn.RMSNorm(int(ratio * output_features))
|
|
420
424
|
if normformer
|
|
421
425
|
else nn.Identity()
|
|
422
426
|
),
|
|
@@ -474,13 +478,7 @@ class FeedforwardBlock(nn.Module):
|
|
|
474
478
|
|
|
475
479
|
|
|
476
480
|
class EncoderBlock(nn.Module):
|
|
477
|
-
"""
|
|
478
|
-
Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
|
|
479
|
-
which is also what is seen in e.g.
|
|
480
|
-
https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
|
|
481
|
-
and is recommended by https://arxiv.org/abs/2002.04745
|
|
482
|
-
|
|
483
|
-
"""
|
|
481
|
+
""" """
|
|
484
482
|
|
|
485
483
|
def __init__(
|
|
486
484
|
self,
|
|
@@ -534,16 +532,16 @@ class EncoderBlock(nn.Module):
|
|
|
534
532
|
self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
|
|
535
533
|
|
|
536
534
|
if self.pre_norm:
|
|
537
|
-
self.pre_attention_norm = nn.LayerNorm(d_model)
|
|
538
|
-
self.pre_mlp_norm = nn.LayerNorm(d_model)
|
|
535
|
+
self.pre_attention_norm = nn.RMSNorm(d_model)
|
|
536
|
+
self.pre_mlp_norm = nn.RMSNorm(d_model)
|
|
539
537
|
|
|
540
538
|
if normformer:
|
|
541
|
-
self.normformer_norm = nn.LayerNorm(d_model)
|
|
539
|
+
self.normformer_norm = nn.RMSNorm(d_model)
|
|
542
540
|
|
|
543
541
|
if self.post_norm:
|
|
544
|
-
self.input_norm = nn.LayerNorm(d_model)
|
|
545
|
-
self.post_attention_norm = nn.LayerNorm(d_model)
|
|
546
|
-
self.post_mlp_norm = nn.LayerNorm(d_model)
|
|
542
|
+
self.input_norm = nn.RMSNorm(d_model)
|
|
543
|
+
self.post_attention_norm = nn.RMSNorm(d_model)
|
|
544
|
+
self.post_mlp_norm = nn.RMSNorm(d_model)
|
|
547
545
|
|
|
548
546
|
if relative_position_embedding:
|
|
549
547
|
max_freq = int(max(source_size) / 2) # Suggested by Gemini!
|
broccoli/vit.py
CHANGED
|
@@ -403,7 +403,7 @@ class ViTEncoder(nn.Module):
|
|
|
403
403
|
checkpoint=transformer_checkpoint_ff,
|
|
404
404
|
beta=self.beta,
|
|
405
405
|
)
|
|
406
|
-
self.layer_norm = nn.LayerNorm(transformer_embedding_size)
|
|
406
|
+
self.layer_norm = nn.RMSNorm(transformer_embedding_size)
|
|
407
407
|
else:
|
|
408
408
|
self.initial_ff = None
|
|
409
409
|
|
|
@@ -417,7 +417,7 @@ class ViTEncoder(nn.Module):
|
|
|
417
417
|
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
|
418
418
|
),
|
|
419
419
|
self.pooling_channels_padding,
|
|
420
|
-
nn.LayerNorm(transformer_embedding_size),
|
|
420
|
+
nn.RMSNorm(transformer_embedding_size),
|
|
421
421
|
]
|
|
422
422
|
)
|
|
423
423
|
|
|
@@ -522,7 +522,8 @@ class ViT(nn.Module):
|
|
|
522
522
|
|
|
523
523
|
# Set alpha and beta according to Microsoft's DeepNorm
|
|
524
524
|
self.alpha = (2 * transformer_layers) ** 0.25
|
|
525
|
-
|
|
525
|
+
# beta is only needed for very deep models
|
|
526
|
+
self.beta = 1 if transformer_layers < 50 else (8 * transformer_layers) ** -0.25
|
|
526
527
|
|
|
527
528
|
self.encoder = ViTEncoder(
|
|
528
529
|
input_size=input_size,
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=Uuz_jCMZRg6GY2DmW3-Tn47gV9a-xkGVN3xQ5BYFM5w,27784
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
|
-
broccoli/vit.py,sha256=
|
|
10
|
-
broccoli_ml-15.0.
|
|
11
|
-
broccoli_ml-15.0.
|
|
12
|
-
broccoli_ml-15.0.
|
|
13
|
-
broccoli_ml-15.0.
|
|
9
|
+
broccoli/vit.py,sha256=fMpdUDUARGs8nPXR6l-SLSlVMNyeSrHIH8Fst4-e5wU,22884
|
|
10
|
+
broccoli_ml-15.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-15.2.0.dist-info/METADATA,sha256=4Pqm4a8K-40gZxm0mMRsvBAUkapsBPXJZRZpzqQAoyY,1369
|
|
12
|
+
broccoli_ml-15.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-15.2.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|