broccoli-ml 0.25.0__py3-none-any.whl → 0.26.0__py3-none-any.whl
This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- broccoli/transformer.py +17 -7
- broccoli/vit.py +2 -1
- {broccoli_ml-0.25.0.dist-info → broccoli_ml-0.26.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.25.0.dist-info → broccoli_ml-0.26.0.dist-info}/RECORD +6 -6
- {broccoli_ml-0.25.0.dist-info → broccoli_ml-0.26.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.25.0.dist-info → broccoli_ml-0.26.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -299,9 +299,12 @@ class TransformerBlock(nn.Module):
|
|
299
299
|
):
|
300
300
|
super().__init__()
|
301
301
|
|
302
|
+
self.pre_norm = pre_norm
|
303
|
+
|
302
304
|
self.identity_probability = identity_probability
|
303
305
|
|
304
|
-
self.
|
306
|
+
self.layer_norm_1 = nn.LayerNorm(d_model)
|
307
|
+
self.layer_norm_2 = nn.LayerNorm(d_model)
|
305
308
|
|
306
309
|
if position_embedding_type == "relative":
|
307
310
|
max_freq = int(max(source_size) / 2) # Suggested by Gemini!
|
@@ -358,12 +361,19 @@ class TransformerBlock(nn.Module):
|
|
358
361
|
identity_x = shuffled[:identity_count, :, :]
|
359
362
|
process_x = shuffled[identity_count:, :, :]
|
360
363
|
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
364
|
+
if self.pre_norm:
|
365
|
+
norm_process_x = self.layer_norm_1(process_x)
|
366
|
+
process_x = process_x + self.attn(
|
367
|
+
norm_process_x, norm_process_x, norm_process_x
|
368
|
+
)
|
369
|
+
process_x = process_x + self.ff(process_x)
|
370
|
+
else: # post-norm
|
371
|
+
process_x = process_x + self.attn(process_x, process_x, process_x)
|
372
|
+
norm_process_x = self.layer_norm_1(process_x)
|
373
|
+
process_x = process_x + self.ff(process_x)
|
374
|
+
x = self.layer_norm_2(
|
375
|
+
torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
|
376
|
+
)
|
367
377
|
|
368
378
|
return x
|
369
379
|
|
broccoli/vit.py
CHANGED
@@ -306,7 +306,8 @@ class ViTEncoder(nn.Module):
|
|
306
306
|
activation_kwargs=transformer_activation_kwargs,
|
307
307
|
dropout=transformer_mlp_dropout,
|
308
308
|
linear_module=linear_module,
|
309
|
-
|
309
|
+
pre_norm=transformer_pre_norm,
|
310
|
+
normformer=transformer_normformer,
|
310
311
|
)
|
311
312
|
else:
|
312
313
|
self.initial_ff = nn.Identity()
|
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
|
8
8
|
broccoli/linear.py,sha256=7NkNvhtxzWAUoBJOuiPUIcr853HhI1cS71d8DwdMkJ0,4826
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=zhSOo9W24FEgN7U35wy3ZIJHnw3u4cepJO5heCw6vwU,4590
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=IU-w0xPaJ2D6JovlMSkoa8DF_N7sMoprPDqDjVuz_UA,16684
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=nuXX2JoKoBTtF1tAH-11mL2R5ISgMHYsBbbgvcluV1s,16072
|
14
|
+
broccoli_ml-0.26.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.26.0.dist-info/METADATA,sha256=W_CuQWiftkfIAyb2wRZqaP-KS2vA8b_V2GqCdGvEU9Q,1257
|
16
|
+
broccoli_ml-0.26.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.26.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|