broccoli-ml 4.0.0__py3-none-any.whl → 5.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -233,7 +233,7 @@ class MHAttention(nn.Module):
233
233
  q,
234
234
  k,
235
235
  v,
236
- dropout_p=self.dropout if self.training else 0.0,
236
+ dropout_p=self.dropout.p if self.training else 0.0,
237
237
  softmax_scale=scaling_factor,
238
238
  causal=self.causal,
239
239
  )
@@ -257,6 +257,8 @@ class MHAttention(nn.Module):
257
257
 
258
258
  qk_scores = F.softmax(qk_scores, dim=-1)
259
259
 
260
+ qk_scores = self.dropout(qk_scores)
261
+
260
262
  output_with_heads = qk_scores @ v
261
263
 
262
264
  output_without_heads = rearrange(output_with_heads, "b h t d -> b t (h d)")
@@ -283,9 +285,11 @@ class FeedforwardBlock(nn.Module):
283
285
  normformer=False,
284
286
  post_norm=True,
285
287
  residual_path=True,
288
+ checkpoint=True,
286
289
  ):
287
290
  super().__init__()
288
291
 
292
+ self.checkpoint = checkpoint
289
293
  self.residual_path = residual_path
290
294
  self.post_norm = post_norm
291
295
 
@@ -324,12 +328,18 @@ class FeedforwardBlock(nn.Module):
324
328
  )
325
329
 
326
330
  def forward(self, x):
331
+
332
+ if self.checkpoint:
333
+ processed = checkpoint(self.process, x, use_reentrant=False)
334
+ else:
335
+ processed = self.process(x)
336
+
327
337
  if self.residual_path and self.post_norm:
328
- return self.layernorm(x + self.process(x))
338
+ return self.layernorm(x + processed)
329
339
  elif self.residual_path:
330
- return x + self.process(x)
340
+ return x + processed
331
341
  else:
332
- return self.process(x)
342
+ return processed
333
343
 
334
344
 
335
345
  class TransformerBlock(nn.Module):
@@ -363,6 +373,7 @@ class TransformerBlock(nn.Module):
363
373
  pre_norm=True,
364
374
  post_norm=False,
365
375
  normformer=False,
376
+ checkpoint_ff=True,
366
377
  ):
367
378
  """
368
379
  Args:
@@ -431,6 +442,7 @@ class TransformerBlock(nn.Module):
431
442
  normformer=normformer,
432
443
  post_norm=False, # Handled outside the block
433
444
  residual_path=False, # Handled outside the block
445
+ checkpoint=checkpoint_ff,
434
446
  )
435
447
 
436
448
  @property
@@ -443,17 +455,17 @@ class TransformerBlock(nn.Module):
443
455
  x = self.layer_norm_1(x)
444
456
  x = x + self.drop_path(self.attn(x, x, x))
445
457
  x = self.layer_norm_2(x)
446
- x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
458
+ x = x + self.drop_path(self.ff(x))
447
459
  if self.post_norm: # i.e. in addition! Pre and post.
448
460
  x = self.layer_norm_3(x)
449
461
  elif self.post_norm: # i.e. only, not prenorm, just post
450
462
  x = x + self.drop_path(self.attn(x, x, x))
451
463
  x = self.layer_norm_1(x)
452
- x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
464
+ x = x + self.drop_path(self.ff(x))
453
465
  x = self.layer_norm_2(x)
454
466
  else: # Not pre or post norm. Stand well back.
455
467
  x = x + self.drop_path(self.attn(x, x, x))
456
- x = x + self.drop_path(checkpoint(self.ff, x, use_reentrant=False))
468
+ x = x + self.drop_path(self.ff(x))
457
469
 
458
470
  return x
459
471
 
@@ -489,6 +501,7 @@ class TransformerEncoder(nn.Module):
489
501
  post_norm=False,
490
502
  normformer=False,
491
503
  msa_scaling="d",
504
+ checkpoint_ff=True,
492
505
  ):
493
506
  """
494
507
  Args:
@@ -565,6 +578,7 @@ class TransformerEncoder(nn.Module):
565
578
  pre_norm=pre_norm,
566
579
  post_norm=post_norm,
567
580
  normformer=normformer,
581
+ checkpoint_ff=checkpoint_ff,
568
582
  )
569
583
  for i in range(n_layers)
570
584
  ]
broccoli/vit.py CHANGED
@@ -160,6 +160,7 @@ class ViTEncoder(nn.Module):
160
160
  transformer_mlp_dropout=0.0,
161
161
  transformer_msa_dropout=0.1,
162
162
  transformer_stochastic_depth=0.1,
163
+ transformer_checkpoint_ff=True,
163
164
  linear_module=nn.Linear,
164
165
  ):
165
166
  super().__init__()
@@ -321,6 +322,7 @@ class ViTEncoder(nn.Module):
321
322
  pre_norm=transformer_pre_norm,
322
323
  normformer=transformer_normformer,
323
324
  post_norm=transformer_post_norm,
325
+ checkpoint_ff=transformer_checkpoint_ff,
324
326
  )
325
327
  else:
326
328
  self.transformer = nn.Identity()
@@ -354,6 +356,7 @@ class ViTEncoder(nn.Module):
354
356
  normformer=transformer_normformer,
355
357
  post_norm=transformer_post_norm,
356
358
  residual_path=transformer_initial_ff_residual_path,
359
+ checkpoint=transformer_checkpoint_ff,
357
360
  )
358
361
  else:
359
362
  self.initial_ff = nn.Identity()
@@ -426,6 +429,7 @@ class ViT(nn.Module):
426
429
  transformer_mlp_dropout=0.0,
427
430
  transformer_msa_dropout=0.1,
428
431
  transformer_stochastic_depth=0.1,
432
+ transformer_checkpoint_ff=True,
429
433
  head=SequencePoolClassificationHead,
430
434
  batch_norm_logits=True,
431
435
  logit_projection_layer=nn.Linear,
@@ -492,6 +496,7 @@ class ViT(nn.Module):
492
496
  transformer_mlp_dropout=transformer_mlp_dropout,
493
497
  transformer_msa_dropout=transformer_msa_dropout,
494
498
  transformer_stochastic_depth=transformer_stochastic_depth,
499
+ transformer_checkpoint_ff=transformer_checkpoint_ff,
495
500
  linear_module=linear_module,
496
501
  )
497
502
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 4.0.0
3
+ Version: 5.0.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
4
  broccoli/linear.py,sha256=Y7s-DzcwsOipRboNHc4HTScw4mJRalNoVFsNcxOB6a4,4872
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=PrPhLLz2IhTWThtccxDVsgyjn3eqZle-3iohJXCECac,19832
7
+ broccoli/transformer.py,sha256=eSBRF-HYJ-BxfisJCueUYCIYtHXgj1ewG5RxEfcmu-E,20128
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
- broccoli/vit.py,sha256=_5uLcklmJ1Uoj7V7TkzF0UqroVnl8NCHun5B0mORmOg,18651
10
- broccoli_ml-4.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-4.0.0.dist-info/METADATA,sha256=yH9OlQRUppZx60UrhBcVv9CBixvwgXmWlvkc7p6AZ9k,1368
12
- broccoli_ml-4.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-4.0.0.dist-info/RECORD,,
9
+ broccoli/vit.py,sha256=BrNLOx4_gTY6xTwAn8xT-HOgUnSFtU6_m1CpJXuQiKY,18907
10
+ broccoli_ml-5.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-5.0.0.dist-info/METADATA,sha256=NMkRLZfqhMZdBIn4BMHjF_-jyv3yYUlOrKeHRTt2rnE,1368
12
+ broccoli_ml-5.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-5.0.0.dist-info/RECORD,,