broccoli-ml: broccoli_ml-4.0.1-py3-none-any.whl → broccoli_ml-5.0.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -285,9 +285,11 @@ class FeedforwardBlock(nn.Module):
285
285
  normformer=False,
286
286
  post_norm=True,
287
287
  residual_path=True,
288
+ checkpoint=True,
288
289
  ):
289
290
  super().__init__()
290
291
 
292
+ self.checkpoint = checkpoint
291
293
  self.residual_path = residual_path
292
294
  self.post_norm = post_norm
293
295
 
@@ -326,12 +328,18 @@ class FeedforwardBlock(nn.Module):
326
328
  )
327
329
 
328
330
  def forward(self, x):
331
+
332
+ if self.checkpoint:
333
+ processed = checkpoint(self.process, x, use_reentrant=False)
334
+ else:
335
+ processed = self.process(x)
336
+
329
337
  if self.residual_path and self.post_norm:
330
- return self.layernorm(x + self.process(x))
338
+ return self.layernorm(x + processed)
331
339
  elif self.residual_path:
332
- return x + self.process(x)
340
+ return x + processed
333
341
  else:
334
- return self.process(x)
342
+ return processed
335
343
 
336
344
 
337
345
  class TransformerBlock(nn.Module):
@@ -365,6 +373,7 @@ class TransformerBlock(nn.Module):
365
373
  pre_norm=True,
366
374
  post_norm=False,
367
375
  normformer=False,
376
+ checkpoint_ff=True,
368
377
  ):
369
378
  """
370
379
  Args:
@@ -433,6 +442,7 @@ class TransformerBlock(nn.Module):
433
442
  normformer=normformer,
434
443
  post_norm=False, # Handled outside the block
435
444
  residual_path=False, # Handled outside the block
445
+ checkpoint=checkpoint_ff,
436
446
  )
437
447
 
438
448
  @property
@@ -445,17 +455,17 @@ class TransformerBlock(nn.Module):
445
455
  x = self.layer_norm_1(x)
446
456
  x = x + self.drop_path(self.attn(x, x, x))
447
457
  x = self.layer_norm_2(x)
448
- x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
458
+ x = x + self.drop_path(self.ff(x))
449
459
  if self.post_norm: # i.e. in addition! Pre and post.
450
460
  x = self.layer_norm_3(x)
451
461
  elif self.post_norm: # i.e. only, not prenorm, just post
452
462
  x = x + self.drop_path(self.attn(x, x, x))
453
463
  x = self.layer_norm_1(x)
454
- x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
464
+ x = x + self.drop_path(self.ff(x))
455
465
  x = self.layer_norm_2(x)
456
466
  else: # Not pre or post norm. Stand well back.
457
467
  x = x + self.drop_path(self.attn(x, x, x))
458
- x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
468
+ x = x + self.drop_path(self.ff(x))
459
469
 
460
470
  return x
461
471
 
@@ -491,6 +501,7 @@ class TransformerEncoder(nn.Module):
491
501
  post_norm=False,
492
502
  normformer=False,
493
503
  msa_scaling="d",
504
+ checkpoint_ff=True,
494
505
  ):
495
506
  """
496
507
  Args:
@@ -567,6 +578,7 @@ class TransformerEncoder(nn.Module):
567
578
  pre_norm=pre_norm,
568
579
  post_norm=post_norm,
569
580
  normformer=normformer,
581
+ checkpoint_ff=checkpoint_ff,
570
582
  )
571
583
  for i in range(n_layers)
572
584
  ]
broccoli/vit.py CHANGED
@@ -160,6 +160,7 @@ class ViTEncoder(nn.Module):
160
160
  transformer_mlp_dropout=0.0,
161
161
  transformer_msa_dropout=0.1,
162
162
  transformer_stochastic_depth=0.1,
163
+ transformer_checkpoint_ff=True,
163
164
  linear_module=nn.Linear,
164
165
  ):
165
166
  super().__init__()
@@ -321,6 +322,7 @@ class ViTEncoder(nn.Module):
321
322
  pre_norm=transformer_pre_norm,
322
323
  normformer=transformer_normformer,
323
324
  post_norm=transformer_post_norm,
325
+ checkpoint_ff=transformer_checkpoint_ff,
324
326
  )
325
327
  else:
326
328
  self.transformer = nn.Identity()
@@ -354,6 +356,7 @@ class ViTEncoder(nn.Module):
354
356
  normformer=transformer_normformer,
355
357
  post_norm=transformer_post_norm,
356
358
  residual_path=transformer_initial_ff_residual_path,
359
+ checkpoint=transformer_checkpoint_ff,
357
360
  )
358
361
  else:
359
362
  self.initial_ff = nn.Identity()
@@ -426,6 +429,7 @@ class ViT(nn.Module):
426
429
  transformer_mlp_dropout=0.0,
427
430
  transformer_msa_dropout=0.1,
428
431
  transformer_stochastic_depth=0.1,
432
+ transformer_checkpoint_ff=True,
429
433
  head=SequencePoolClassificationHead,
430
434
  batch_norm_logits=True,
431
435
  logit_projection_layer=nn.Linear,
@@ -492,6 +496,7 @@ class ViT(nn.Module):
492
496
  transformer_mlp_dropout=transformer_mlp_dropout,
493
497
  transformer_msa_dropout=transformer_msa_dropout,
494
498
  transformer_stochastic_depth=transformer_stochastic_depth,
499
+ transformer_checkpoint_ff=transformer_checkpoint_ff,
495
500
  linear_module=linear_module,
496
501
  )
497
502
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 4.0.1
3
+ Version: 5.0.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
4
  broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=uqSf8q30MF7Ds7LfqW8Pr206NXpSlf7o6770KZu81Ew,19883
7
+ broccoli/transformer.py,sha256=eSBRF-HYJ-BxfisJCueUYCIYtHXgj1ewG5RxEfcmu-E,20128
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
- broccoli/vit.py,sha256=_5uLcklmJ1Uoj7V7TkzF0UqroVnl8NCHun5B0mORmOg,18651
10
- broccoli_ml-4.0.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-4.0.1.dist-info/METADATA,sha256=vH_utDdo0-e2q8ReDrRHQ1d6fOzG4nzb9EWlqgyl4XY,1368
12
- broccoli_ml-4.0.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-4.0.1.dist-info/RECORD,,
9
+ broccoli/vit.py,sha256=BrNLOx4_gTY6xTwAn8xT-HOgUnSFtU6_m1CpJXuQiKY,18907
10
+ broccoli_ml-5.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-5.0.0.dist-info/METADATA,sha256=NMkRLZfqhMZdBIn4BMHjF_-jyv3yYUlOrKeHRTt2rnE,1368
12
+ broccoli_ml-5.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-5.0.0.dist-info/RECORD,,