broccoli-ml 4.0.0__tar.gz → 5.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/PKG-INFO +1 -1
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/transformer.py +21 -7
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/vit.py +5 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/pyproject.toml +1 -1
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/LICENSE +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/README.md +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/activation.py +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/linear.py +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/rope.py +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-4.0.0 → broccoli_ml-5.0.0}/broccoli/utils.py +0 -0
broccoli/transformer.py

@@ -233,7 +233,7 @@ class MHAttention(nn.Module):
             q,
             k,
             v,
-            dropout_p=self.dropout if self.training else 0.0,
+            dropout_p=self.dropout.p if self.training else 0.0,
             softmax_scale=scaling_factor,
             causal=self.causal,
         )
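This change fixes a type bug in the fused attention path: `self.dropout` is an `nn.Dropout` module (it is called as one in the next hunk), while `dropout_p` in a flash-attention-style call expects a float probability, which the module stores as `.p`. A minimal sketch of the distinction, using `torch.nn.functional.scaled_dot_product_attention` as a stand-in for the fused call above; the tensors and shapes are illustrative, not the package's own:

import torch
import torch.nn as nn
import torch.nn.functional as F

dropout = nn.Dropout(0.1)        # stores its drop probability as dropout.p

q = torch.randn(2, 4, 16, 8)     # (batch, heads, tokens, head_dim)
k = torch.randn(2, 4, 16, 8)
v = torch.randn(2, 4, 16, 8)

# Passing the module itself would hand the kernel an nn.Module where a float
# is expected; dropout.p is the stored probability (0.1 here), and dropout is
# disabled outside training, as in the diff above.
out = F.scaled_dot_product_attention(
    q, k, v,
    dropout_p=dropout.p if dropout.training else 0.0,
)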
@@ -257,6 +257,8 @@ class MHAttention(nn.Module):
 
         qk_scores = F.softmax(qk_scores, dim=-1)
 
+        qk_scores = self.dropout(qk_scores)
+
         output_with_heads = qk_scores @ v
 
         output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
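This hunk applies the same dropout in the manual (non-fused) attention path, to the softmaxed attention weights before they are multiplied into `v`. A minimal sketch of that path, assuming `self.dropout` is an `nn.Dropout`; the shapes and the 1/sqrt(d) scaling are illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

b, h, t, d = 2, 4, 16, 8
q, k, v = (torch.randn(b, h, t, d) for _ in range(3))
dropout = nn.Dropout(0.1)

qk_scores = (q @ k.transpose(-2, -1)) / d**0.5   # (b, h, t, t) attention logits
qk_scores = F.softmax(qk_scores, dim=-1)
qk_scores = dropout(qk_scores)                   # randomly zero attention weights

output_with_heads = qk_scores @ v                # (b, h, t, d)
output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")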
@@ -283,9 +285,11 @@ class FeedforwardBlock(nn.Module):
         normformer=False,
         post_norm=True,
         residual_path=True,
+        checkpoint=True,
     ):
         super().__init__()
 
+        self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
 
@@ -324,12 +328,18 @@ class FeedforwardBlock(nn.Module):
         )
 
     def forward(self, x):
+
+        if self.checkpoint:
+            processed = checkpoint(self.process, x, use_reentrant=False)
+        else:
+            processed = self.process(x)
+
         if self.residual_path and self.post_norm:
-            return self.layernorm(x +
+            return self.layernorm(x + processed)
         elif self.residual_path:
-            return x +
+            return x + processed
         else:
-            return
+            return processed
 
 
 class TransformerBlock(nn.Module):
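The new `checkpoint` option wraps the feed-forward computation in `torch.utils.checkpoint.checkpoint` with `use_reentrant=False`, so the activations inside `self.process` are discarded during the forward pass and recomputed during backward, trading extra compute for lower peak memory. A minimal sketch of the pattern with a hypothetical block; the layers inside `process` are illustrative, not the package's actual feed-forward stack:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyFFBlock(nn.Module):
    def __init__(self, dim, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.process = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        if self.use_checkpoint:
            # Intermediate activations are recomputed on backward instead of
            # being kept alive for the whole forward pass.
            return checkpoint(self.process, x, use_reentrant=False)
        return self.process(x)

block = TinyFFBlock(64)
x = torch.randn(8, 16, 64, requires_grad=True)
block(x).sum().backward()

Checkpointing only pays off when gradients are being computed; under `torch.no_grad()` the wrapped call simply runs the function once, so leaving the flag on at inference costs little.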
@@ -363,6 +373,7 @@ class TransformerBlock(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        checkpoint_ff=True,
     ):
         """
         Args:
@@ -431,6 +442,7 @@ class TransformerBlock(nn.Module):
             normformer=normformer,
             post_norm=False,  # Handled outside the block
             residual_path=False,  # Handled outside the block
+            checkpoint=checkpoint_ff,
         )
 
     @property
@@ -443,17 +455,17 @@ class TransformerBlock(nn.Module):
             x = self.layer_norm_1(x)
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_2(x)
-            x = x + self.drop_path(
+            x = x + self.drop_path(self.ff(x))
             if self.post_norm:  # i.e. in addition! Pre and post.
                 x = self.layer_norm_3(x)
         elif self.post_norm:  # i.e. only, not prenorm, just post
             x = x + self.drop_path(self.attn(x, x, x))
             x = self.layer_norm_1(x)
-            x = x + self.drop_path(
+            x = x + self.drop_path(self.ff(x))
             x = self.layer_norm_2(x)
         else:  # Not pre or post norm. Stand well back.
             x = x + self.drop_path(self.attn(x, x, x))
-            x = x + self.drop_path(
+            x = x + self.drop_path(self.ff(x))
 
         return x
 
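Each restored feed-forward term is wrapped in `self.drop_path(...)`. Given the `transformer_stochastic_depth` parameter elsewhere in this diff, `drop_path` appears to be stochastic depth, which drops the whole residual branch for a random subset of samples during training. A minimal sketch of such a wrapper, offered as an assumption about its behaviour rather than the package's actual implementation:

import torch
import torch.nn as nn

class DropPath(nn.Module):
    """Stochastic depth: zero a residual branch per sample, rescale the rest."""
    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over the remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = torch.empty(shape, dtype=x.dtype, device=x.device).bernoulli_(keep_prob)
        return x * mask / keep_prob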
@@ -489,6 +501,7 @@ class TransformerEncoder(nn.Module):
         post_norm=False,
         normformer=False,
         msa_scaling="d",
+        checkpoint_ff=True,
     ):
         """
         Args:
@@ -565,6 +578,7 @@ class TransformerEncoder(nn.Module):
                 pre_norm=pre_norm,
                 post_norm=post_norm,
                 normformer=normformer,
+                checkpoint_ff=checkpoint_ff,
             )
             for i in range(n_layers)
         ]
broccoli/vit.py

@@ -160,6 +160,7 @@ class ViTEncoder(nn.Module):
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
+        transformer_checkpoint_ff=True,
         linear_module=nn.Linear,
     ):
         super().__init__()
@@ -321,6 +322,7 @@ class ViTEncoder(nn.Module):
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
+                checkpoint_ff=transformer_checkpoint_ff,
             )
         else:
             self.transformer = nn.Identity()
@@ -354,6 +356,7 @@ class ViTEncoder(nn.Module):
                 normformer=transformer_normformer,
                 post_norm=transformer_post_norm,
                 residual_path=transformer_initial_ff_residual_path,
+                checkpoint=transformer_checkpoint_ff,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -426,6 +429,7 @@ class ViT(nn.Module):
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
+        transformer_checkpoint_ff=True,
         head=SequencePoolClassificationHead,
         batch_norm_logits=True,
         logit_projection_layer=nn.Linear,
@@ -492,6 +496,7 @@ class ViT(nn.Module):
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
+            transformer_checkpoint_ff=transformer_checkpoint_ff,
             linear_module=linear_module,
         )
 