broccoli-ml 15.4.1__py3-none-any.whl → 15.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +6 -25
- broccoli/vit.py +2 -7
- {broccoli_ml-15.4.1.dist-info → broccoli_ml-15.5.0.dist-info}/METADATA +1 -1
- {broccoli_ml-15.4.1.dist-info → broccoli_ml-15.5.0.dist-info}/RECORD +6 -6
- {broccoli_ml-15.4.1.dist-info → broccoli_ml-15.5.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-15.4.1.dist-info → broccoli_ml-15.5.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -122,9 +122,6 @@ class MHAttention(nn.Module):
|
|
|
122
122
|
|
|
123
123
|
self.head_dim = self.embed_dim // self.n_heads
|
|
124
124
|
|
|
125
|
-
self.query_norm = nn.RMSNorm(self.head_dim)
|
|
126
|
-
self.key_norm = nn.RMSNorm(self.head_dim)
|
|
127
|
-
|
|
128
125
|
if self.scaling == "sqrtd":
|
|
129
126
|
self.scaling_factor = 1 / math.sqrt(self.head_dim)
|
|
130
127
|
elif self.scaling == "d":
|
|
@@ -229,8 +226,8 @@ class MHAttention(nn.Module):
|
|
|
229
226
|
freqs = self.rotary_embedding.get_axial_freqs(*self.source_size)
|
|
230
227
|
|
|
231
228
|
# norm Qs/Ks to protect axial rope, like https://arxiv.org/abs/2302.05442
|
|
232
|
-
q_img = apply_rotary_emb(freqs, self.query_norm(q_img))
|
|
233
|
-
k_img = apply_rotary_emb(freqs, self.key_norm(k_img))
|
|
229
|
+
q_img = apply_rotary_emb(freqs, q_img)
|
|
230
|
+
k_img = apply_rotary_emb(freqs, k_img)
|
|
234
231
|
|
|
235
232
|
q_img = rearrange(
|
|
236
233
|
q_img,
|
|
@@ -354,17 +351,9 @@ class MHAttention(nn.Module):
|
|
|
354
351
|
self.q_proj.reset_parameters()
|
|
355
352
|
self.k_proj.reset_parameters()
|
|
356
353
|
self.v_proj.reset_parameters()
|
|
357
|
-
scale_parameters(
|
|
358
|
-
self.v_proj,
|
|
359
|
-
math.sqrt(6)
|
|
360
|
-
* self.beta, # sqrt(6) to compensate for PyTorch tiny default init
|
|
361
|
-
)
|
|
354
|
+
scale_parameters(self.v_proj, self.beta)
|
|
362
355
|
self.out_proj.reset_parameters()
|
|
363
|
-
scale_parameters(
|
|
364
|
-
self.out_proj,
|
|
365
|
-
math.sqrt(6)
|
|
366
|
-
* self.beta, # sqrt(6) to compensate for PyTorch tiny default init
|
|
367
|
-
)
|
|
356
|
+
scale_parameters(self.out_proj, self.beta)
|
|
368
357
|
|
|
369
358
|
if self.talking_heads:
|
|
370
359
|
# Initialize close to identity
|
|
@@ -481,16 +470,8 @@ class FeedforwardBlock(nn.Module):
|
|
|
481
470
|
if hasattr(module, "reset_parameters"):
|
|
482
471
|
module.reset_parameters()
|
|
483
472
|
|
|
484
|
-
scale_parameters(
|
|
485
|
-
self.linear_in,
|
|
486
|
-
math.sqrt(6)
|
|
487
|
-
* self.beta, # sqrt(6) to compensate for PyTorch tiny default init
|
|
488
|
-
)
|
|
489
|
-
scale_parameters(
|
|
490
|
-
self.linear_out,
|
|
491
|
-
math.sqrt(6)
|
|
492
|
-
* self.beta, # sqrt(6) to compensate for PyTorch tiny default init
|
|
493
|
-
)
|
|
473
|
+
scale_parameters(self.linear_in, self.beta)
|
|
474
|
+
scale_parameters(self.linear_out, self.beta)
|
|
494
475
|
|
|
495
476
|
|
|
496
477
|
class EncoderBlock(nn.Module):
|
broccoli/vit.py
CHANGED
|
@@ -520,13 +520,8 @@ class ViT(nn.Module):
|
|
|
520
520
|
"SwiGLU": SwiGLU,
|
|
521
521
|
}[transformer_activation]
|
|
522
522
|
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
self.alpha = (2 * transformer_layers) ** 0.25
|
|
526
|
-
self.beta = 1.0 # beta is not needed as we norm the Q and K vectors in MSA!
|
|
527
|
-
else:
|
|
528
|
-
self.alpha = 1.0
|
|
529
|
-
self.beta = 1.0 # beta is not needed as we norm the Q and K vectors in MSA!
|
|
523
|
+
self.alpha = (2 * transformer_layers) ** 0.25
|
|
524
|
+
self.beta = (8 * transformer_layers) ** 0.25
|
|
530
525
|
|
|
531
526
|
self.encoder = ViTEncoder(
|
|
532
527
|
input_size=input_size,
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=xhMKGWgQqSMhCpN-cqM6Fv_MfyKU9-Gq1t9nGpUAmzE,27574
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
|
-
broccoli/vit.py,sha256=
|
|
10
|
-
broccoli_ml-15.4.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
-
broccoli_ml-15.
|
|
12
|
-
broccoli_ml-15.4.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
-
broccoli_ml-15.4.1.dist-info/RECORD,,
|
|
9
|
+
broccoli/vit.py,sha256=v3U_UVIZd2t3Nt60K6KGJcI5ci9t9S8h2ENwklnHg8M,22735
|
|
10
|
+
broccoli_ml-15.5.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-15.5.0.dist-info/METADATA,sha256=ANXSYDDts212i3b0rySkKT71_2ZSpcmpHSloNayfNns,1369
|
|
12
|
+
broccoli_ml-15.5.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-15.5.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|