broccoli-ml 4.0.1__tar.gz → 5.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/PKG-INFO +1 -1
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/transformer.py +18 -6
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/vit.py +5 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/pyproject.toml +1 -1
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/LICENSE +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/README.md +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/activation.py +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/linear.py +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/rope.py +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-4.0.1 → broccoli_ml-5.0.0}/broccoli/utils.py +0 -0
|
@@ -285,9 +285,11 @@ class FeedforwardBlock(nn.Module):
|
|
|
285
285
|
normformer=False,
|
|
286
286
|
post_norm=True,
|
|
287
287
|
residual_path=True,
|
|
288
|
+
checkpoint=True,
|
|
288
289
|
):
|
|
289
290
|
super().__init__()
|
|
290
291
|
|
|
292
|
+
self.checkpoint = checkpoint
|
|
291
293
|
self.residual_path = residual_path
|
|
292
294
|
self.post_norm = post_norm
|
|
293
295
|
|
|
@@ -326,12 +328,18 @@ class FeedforwardBlock(nn.Module):
|
|
|
326
328
|
)
|
|
327
329
|
|
|
328
330
|
def forward(self, x):
|
|
331
|
+
|
|
332
|
+
if self.checkpoint:
|
|
333
|
+
processed = checkpoint(self.process, x, use_reentrant=False)
|
|
334
|
+
else:
|
|
335
|
+
processed = self.process(x)
|
|
336
|
+
|
|
329
337
|
if self.residual_path and self.post_norm:
|
|
330
|
-
return self.layernorm(x +
|
|
338
|
+
return self.layernorm(x + processed)
|
|
331
339
|
elif self.residual_path:
|
|
332
|
-
return x +
|
|
340
|
+
return x + processed
|
|
333
341
|
else:
|
|
334
|
-
return
|
|
342
|
+
return processed
|
|
335
343
|
|
|
336
344
|
|
|
337
345
|
class TransformerBlock(nn.Module):
|
|
@@ -365,6 +373,7 @@ class TransformerBlock(nn.Module):
|
|
|
365
373
|
pre_norm=True,
|
|
366
374
|
post_norm=False,
|
|
367
375
|
normformer=False,
|
|
376
|
+
checkpoint_ff=True,
|
|
368
377
|
):
|
|
369
378
|
"""
|
|
370
379
|
Args:
|
|
@@ -433,6 +442,7 @@ class TransformerBlock(nn.Module):
|
|
|
433
442
|
normformer=normformer,
|
|
434
443
|
post_norm=False, # Handled outside the block
|
|
435
444
|
residual_path=False, # Handled outside the block
|
|
445
|
+
checkpoint=checkpoint_ff,
|
|
436
446
|
)
|
|
437
447
|
|
|
438
448
|
@property
|
|
@@ -445,17 +455,17 @@ class TransformerBlock(nn.Module):
|
|
|
445
455
|
x = self.layer_norm_1(x)
|
|
446
456
|
x = x + self.drop_path(self.attn(x, x, x))
|
|
447
457
|
x = self.layer_norm_2(x)
|
|
448
|
-
x = x + self.drop_path(
|
|
458
|
+
x = x + self.drop_path(self.ff(x))
|
|
449
459
|
if self.post_norm: # i.e. in addition! Pre and post.
|
|
450
460
|
x = self.layer_norm_3(x)
|
|
451
461
|
elif self.post_norm: # i.e. only, not prenorm, just post
|
|
452
462
|
x = x + self.drop_path(self.attn(x, x, x))
|
|
453
463
|
x = self.layer_norm_1(x)
|
|
454
|
-
x = x + self.drop_path(
|
|
464
|
+
x = x + self.drop_path(self.ff(x))
|
|
455
465
|
x = self.layer_norm_2(x)
|
|
456
466
|
else: # Not pre or post norm. Stand well back.
|
|
457
467
|
x = x + self.drop_path(self.attn(x, x, x))
|
|
458
|
-
x = x + self.drop_path(
|
|
468
|
+
x = x + self.drop_path(self.ff(x))
|
|
459
469
|
|
|
460
470
|
return x
|
|
461
471
|
|
|
@@ -491,6 +501,7 @@ class TransformerEncoder(nn.Module):
|
|
|
491
501
|
post_norm=False,
|
|
492
502
|
normformer=False,
|
|
493
503
|
msa_scaling="d",
|
|
504
|
+
checkpoint_ff=True,
|
|
494
505
|
):
|
|
495
506
|
"""
|
|
496
507
|
Args:
|
|
@@ -567,6 +578,7 @@ class TransformerEncoder(nn.Module):
|
|
|
567
578
|
pre_norm=pre_norm,
|
|
568
579
|
post_norm=post_norm,
|
|
569
580
|
normformer=normformer,
|
|
581
|
+
checkpoint_ff=checkpoint_ff,
|
|
570
582
|
)
|
|
571
583
|
for i in range(n_layers)
|
|
572
584
|
]
|
|
@@ -160,6 +160,7 @@ class ViTEncoder(nn.Module):
|
|
|
160
160
|
transformer_mlp_dropout=0.0,
|
|
161
161
|
transformer_msa_dropout=0.1,
|
|
162
162
|
transformer_stochastic_depth=0.1,
|
|
163
|
+
transformer_checkpoint_ff=True,
|
|
163
164
|
linear_module=nn.Linear,
|
|
164
165
|
):
|
|
165
166
|
super().__init__()
|
|
@@ -321,6 +322,7 @@ class ViTEncoder(nn.Module):
|
|
|
321
322
|
pre_norm=transformer_pre_norm,
|
|
322
323
|
normformer=transformer_normformer,
|
|
323
324
|
post_norm=transformer_post_norm,
|
|
325
|
+
checkpoint_ff=transformer_checkpoint_ff,
|
|
324
326
|
)
|
|
325
327
|
else:
|
|
326
328
|
self.transformer = nn.Identity()
|
|
@@ -354,6 +356,7 @@ class ViTEncoder(nn.Module):
|
|
|
354
356
|
normformer=transformer_normformer,
|
|
355
357
|
post_norm=transformer_post_norm,
|
|
356
358
|
residual_path=transformer_initial_ff_residual_path,
|
|
359
|
+
checkpoint=transformer_checkpoint_ff,
|
|
357
360
|
)
|
|
358
361
|
else:
|
|
359
362
|
self.initial_ff = nn.Identity()
|
|
@@ -426,6 +429,7 @@ class ViT(nn.Module):
|
|
|
426
429
|
transformer_mlp_dropout=0.0,
|
|
427
430
|
transformer_msa_dropout=0.1,
|
|
428
431
|
transformer_stochastic_depth=0.1,
|
|
432
|
+
transformer_checkpoint_ff=True,
|
|
429
433
|
head=SequencePoolClassificationHead,
|
|
430
434
|
batch_norm_logits=True,
|
|
431
435
|
logit_projection_layer=nn.Linear,
|
|
@@ -492,6 +496,7 @@ class ViT(nn.Module):
|
|
|
492
496
|
transformer_mlp_dropout=transformer_mlp_dropout,
|
|
493
497
|
transformer_msa_dropout=transformer_msa_dropout,
|
|
494
498
|
transformer_stochastic_depth=transformer_stochastic_depth,
|
|
499
|
+
transformer_checkpoint_ff=transformer_checkpoint_ff,
|
|
495
500
|
linear_module=linear_module,
|
|
496
501
|
)
|
|
497
502
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|