broccoli-ml 12.3.1__tar.gz → 13.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 12.3.1
+ Version: 13.0.1
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -21,6 +21,12 @@ except ImportError:
      FLASH_ATTN = False


+ def scale_parameters(torch_module: nn.Module, factor: float):
+     with torch.no_grad():
+         for param in torch_module.parameters():
+             param.mul_(factor)
+
+
  def drop_path(
      x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  ):
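
Note on the helper added above: `scale_parameters` multiplies every parameter of a module in place, which is how the DeepNet-style beta-scaling is applied after the usual resets later in this diff. A minimal, illustrative usage sketch (the module and factor below are hypothetical, not taken from the package):

```python
import torch
from torch import nn


def scale_parameters(torch_module: nn.Module, factor: float):
    # Multiply every parameter in place, without tracking gradients.
    with torch.no_grad():
        for param in torch_module.parameters():
            param.mul_(factor)


# Hypothetical usage: reset a projection, then shrink it by a beta < 1.
proj = nn.Linear(64, 64)
proj.reset_parameters()
scale_parameters(proj, 0.5)  # weights and bias now at half their initial scale
```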
@@ -83,6 +89,7 @@ class MHAttention(nn.Module):
          rotary_embedding=None,
          source_size=None,
          scaling="d",
+         beta=1.0,
      ):
          """
          Args:
@@ -111,6 +118,7 @@ class MHAttention(nn.Module):
          self.n_heads = n_heads
          assert embed_dim % n_heads == 0
          self.scaling = scaling
+         self.beta = beta

          self.head_dim = self.embed_dim // self.n_heads

@@ -192,17 +200,26 @@ class MHAttention(nn.Module):
                  "`source_size` must be a tuple of 1, 2 or 3 integers"
              )

-         q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
-         k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
+         q = rearrange(q, "b t (h d) -> b t h d", h=self.n_heads)
+         k = rearrange(k, "b t (h d) -> b t h d", h=self.n_heads)
+
+         q_util, q_img = (
+             q[:, : self.utility_tokens, :, :],
+             q[:, self.utility_tokens :, :, :],
+         )
+         k_util, k_img = (
+             k[:, : self.utility_tokens, :, :],
+             k[:, self.utility_tokens :, :, :],
+         )

          q_img = rearrange(
              q_img,
-             f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+             f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
              **spatial_dimension_values,
          )
          k_img = rearrange(
              k_img,
-             f"b ({spatial_dimension_names}) d -> b {spatial_dimension_names} d",
+             f"b ({spatial_dimension_names}) h d -> b {spatial_dimension_names} h d",
              **spatial_dimension_values,
          )

@@ -213,17 +230,20 @@ class MHAttention(nn.Module):

          q_img = rearrange(
              q_img,
-             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+             f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
          )
          k_img = rearrange(
              k_img,
-             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
+             f"b {spatial_dimension_names} h d -> b ({spatial_dimension_names}) h d",
          )

          # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
          q = torch.cat([q_util, q_img], dim=1)
          k = torch.cat([k_util, k_img], dim=1)

+         q = rearrange(q, "b t h d -> b t (h d)")
+         k = rearrange(k, "b t h d -> b t (h d)")
+
          return q, k

      def project_qkv(
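
For reference, the change above splits the packed projections into per-head tensors before rotary embeddings are applied to the spatial tokens, then packs the heads back afterwards. A small self-contained sketch of that `einops` pattern on a dummy tensor (shapes are illustrative, not the package's defaults):

```python
import torch
from einops import rearrange

b, t, h, d = 2, 16, 4, 8       # batch, tokens, heads, head_dim (illustrative)
q = torch.randn(b, t, h * d)   # packed projection output of shape (b, t, embed_dim)

# Split the heads so RoPE can rotate each head's sub-vector independently.
q = rearrange(q, "b t (h d) -> b t h d", h=h)
# ... rotary embedding would be applied to the spatial tokens here ...

# Pack the heads back before handing q to the attention kernel.
q = rearrange(q, "b t h d -> b t (h d)")
assert q.shape == (b, t, h * d)
```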
@@ -330,7 +350,10 @@ class MHAttention(nn.Module):
          self.q_proj.reset_parameters()
          self.k_proj.reset_parameters()
          self.v_proj.reset_parameters()
+         scale_parameters(self.v_proj, self.beta) # per Microsoft DeepNet
          self.out_proj.reset_parameters()
+         scale_parameters(self.out_proj, self.beta) # per Microsoft DeepNet
+
          if self.talking_heads:
              # Initialize close to identity
              nn.init.eye_(self.head_projection.weight)
@@ -354,17 +377,14 @@ class FeedforwardBlock(nn.Module):
          outer_dropout=None,
          linear_module_up=nn.Linear,
          linear_module_down=nn.Linear,
-         pre_norm=True,
          normformer=False,
-         post_norm=True,
-         residual_path=True,
          checkpoint=True,
+         beta=1.0,
      ):
          super().__init__()

          self.checkpoint = checkpoint
-         self.residual_path = residual_path
-         self.post_norm = post_norm
+         self.beta = beta
          self.xglu = activation.__name__.endswith("GLU")

          if self.residual_path and (output_features < input_features):
@@ -402,7 +422,6 @@ class FeedforwardBlock(nn.Module):

          self.process = nn.Sequential(
              *[
-                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                  self.linear_in,
                  self.activation,
                  self.inner_dropout,
@@ -452,12 +471,7 @@ class FeedforwardBlock(nn.Module):
          else:
              processed = self.process(x)

-         if self.residual_path and self.post_norm:
-             return self.layernorm(x + processed)
-         elif self.residual_path:
-             return x + processed
-         else:
-             return processed
+         return processed

      def reset_parameters(self):
          if self.post_norm:
@@ -468,8 +482,11 @@ class FeedforwardBlock(nn.Module):
              if hasattr(module, "reset_parameters"):
                  module.reset_parameters()

+         scale_parameters(self.linear_in, self.beta) # per Microsoft DeepNet
+         scale_parameters(self.linear_out, self.beta)

- class TransformerBlock(nn.Module):
+
+ class EncoderBlock(nn.Module):
      """
      Performs LayerNorms first (as in PyTorch Transformers when norm_first=True),
      which is also what is seen in e.g.
@@ -504,6 +521,8 @@ class TransformerBlock(nn.Module):
          post_norm=False,
          normformer=False,
          checkpoint_ff=True,
+         alpha=1.0,
+         beta=1.0,
      ):
          """
          Args:
@@ -515,10 +534,16 @@ class TransformerBlock(nn.Module):

          super().__init__()

+         if pre_norm and post_norm:
+             raise ValueError("A transformer cannot be both prenorm and postnorm.")
+
          self.pre_norm = pre_norm
          self.post_norm = post_norm
          self.normformer = normformer

+         self.alpha = alpha
+         self.beta = beta
+
          self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

          if self.pre_norm:
@@ -529,6 +554,7 @@ class TransformerBlock(nn.Module):
              self.normformer_norm = nn.LayerNorm(d_model)

          if self.post_norm:
+             self.input_norm = nn.LayerNorm(d_model)
              self.post_attention_norm = nn.LayerNorm(d_model)
              self.post_mlp_norm = nn.LayerNorm(d_model)

@@ -556,6 +582,7 @@ class TransformerBlock(nn.Module):
              utility_tokens=utility_tokens,
              talking_heads=talking_heads,
              scaling=msa_scaling,
+             beta=beta,
          )

          # Submodule for the feedforward process
@@ -578,11 +605,9 @@ class TransformerBlock(nn.Module):
                  if ff_linear_module_down is not None
                  else linear_module
              ),
-             pre_norm=False, # Handled outside the block
              normformer=normformer,
-             post_norm=False, # Handled outside the block
-             residual_path=False, # Handled outside the block
              checkpoint=checkpoint_ff,
+             beta=beta,
          )

          self.reset_parameters()
@@ -592,45 +617,36 @@ class TransformerBlock(nn.Module):
          return self.attn._kv_distance

      def forward(self, x):
-         if self.pre_norm:
-             x = self.pre_attention_norm(x)
-         x = x + self.drop_path(self.attn(x, x, x))
-         x = self.pre_mlp_norm(x)
-         x = x + self.drop_path(self.ff(x))
-         if self.post_norm: # i.e. in addition! Pre and post.
-             x = self.post_mlp_norm(x)
-         return x
-
-         # if self.pre_norm:
-         # process_x = self.pre_attention_norm(x)
-         # else:
-         # process_x = x
+         if self.post_norm:
+             x = self.input_norm(x)

-         # processed = self.drop_path(self.attn(process_x, process_x, process_x))
+         if self.pre_norm:
+             process_x = self.pre_attention_norm(x)
+         else:
+             process_x = x

-         # if self.normformer:
-         # processed = self.normformer_norm(processed)
+         processed = self.drop_path(self.attn(process_x, process_x, process_x))

-         # if self.residual_path:
-         # x = x + processed
+         if self.normformer:
+             processed = self.normformer_norm(processed)

-         # if self.post_norm:
-         # x = self.post_attention_norm(x)
+         x = self.alpha * x + processed

-         # if self.pre_norm:
-         # process_x = self.pre_mlp_norm(x)
-         # else:
-         # process_x = x
+         if self.post_norm:
+             x = self.post_attention_norm(x)
+         elif self.pre_norm:
+             process_x = self.pre_mlp_norm(x)
+         else:
+             process_x = x

-         # processed = self.drop_path(self.ff(process_x))
+         processed = self.drop_path(self.ff(process_x))

-         # if self.residual_path:
-         # x = x + processed
+         x = self.alpha * x + processed

-         # if self.post_norm:
-         # x = self.post_mlp_norm(x)
+         if self.post_norm:
+             x = self.post_mlp_norm(x)

-         # return x
+         return x

      def attention_logits(self, x):
          """
@@ -696,6 +712,8 @@ class TransformerEncoder(nn.Module):
          normformer=False,
          msa_scaling="d",
          checkpoint_ff=True,
+         alpha=1.0,
+         beta=1.0,
      ):
          """
          Args:
@@ -769,7 +787,7 @@ class TransformerEncoder(nn.Module):

          self.blocks = nn.ModuleList(
              [
-                 TransformerBlock(
+                 EncoderBlock(
                      self.full_sequence_length,
                      d_model,
                      n_heads,
@@ -794,6 +812,8 @@ class TransformerEncoder(nn.Module):
                      post_norm=post_norm,
                      normformer=normformer,
                      checkpoint_ff=checkpoint_ff,
+                     alpha=alpha,
+                     beta=beta,
                  )
                  for i in range(n_layers)
              ]
@@ -158,7 +158,6 @@ class ViTEncoder(nn.Module):
          pooling_kernel_stride=2,
          pooling_padding=1,
          transformer_feedforward_first=True,
-         transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
          transformer_initial_ff_dropout=None,
@@ -188,9 +187,14 @@ class ViTEncoder(nn.Module):
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
          linear_module=nn.Linear,
+         alpha=1.0,
+         beta=1.0,
      ):
          super().__init__()

+         self.alpha = alpha
+         self.beta = beta
+
          if cnn_activation_kwargs is not None:
              self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
          else:
@@ -352,6 +356,8 @@ class ViTEncoder(nn.Module):
                  normformer=transformer_normformer,
                  post_norm=transformer_post_norm,
                  checkpoint_ff=transformer_checkpoint_ff,
+                 alpha=self.alpha,
+                 beta=self.beta,
              )
          else:
              self.transformer = nn.Identity()
@@ -393,16 +399,14 @@ class ViTEncoder(nn.Module):
                      or transformer_ff_linear_module_down
                      or linear_module
                  ),
-                 pre_norm=transformer_pre_norm,
                  normformer=transformer_normformer,
-                 post_norm=transformer_post_norm,
-                 residual_path=transformer_initial_ff_residual_path,
                  checkpoint=transformer_checkpoint_ff,
+                 beta=self.beta,
              )
          else:
              self.initial_ff = nn.Identity()

-         self.encoder = nn.Sequential(
+         self.preprocess = nn.Sequential(
              *[
                  batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                  self.cnn,
@@ -412,19 +416,21 @@ class ViTEncoder(nn.Module):
                      f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                  ),
                  self.pooling_channels_padding,
-                 self.initial_ff,
-                 self.transformer,
+                 nn.LayerNorm(),
              ]
          )

          self.reset_parameters()

      def forward(self, x):
-         return self.encoder(x)
+         x = self.preprocess(x)
+         x = x + self.initial_ff(x)
+         return self.transformer(x)

      def attention_logits(self, x):
-         x = self.encoder[:-1](x)
-         return self.encoder[-1].attention_logits(x)
+         x = self.preprocess(x)
+         x = x + self.initial_ff(x)
+         return self.transformer.attention_logits(x)

      def reset_parameters(self):
          for module in self.encoder:
@@ -458,7 +464,6 @@ class ViT(nn.Module):
          pooling_kernel_stride=2,
          pooling_padding=1,
          transformer_feedforward_first=True,
-         transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
          transformer_initial_ff_dropout=None,
@@ -491,6 +496,8 @@ class ViT(nn.Module):
          batch_norm_logits=True,
          logit_projection_layer=nn.Linear,
          linear_module=nn.Linear,
+         alpha=1.0,
+         beta=1.0,
      ):

          super().__init__()
@@ -511,6 +518,9 @@ class ViT(nn.Module):
              "SwiGLU": SwiGLU,
          }[transformer_activation]

+         self.alpha = alpha
+         self.beta = beta
+
          self.encoder = ViTEncoder(
              input_size=input_size,
              initial_batch_norm=initial_batch_norm,
@@ -530,7 +540,6 @@ class ViT(nn.Module):
              pooling_kernel_stride=pooling_kernel_stride,
              pooling_padding=pooling_padding,
              transformer_feedforward_first=transformer_feedforward_first,
-             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
              transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
              transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
              transformer_initial_ff_dropout=transformer_initial_ff_dropout,
@@ -560,6 +569,8 @@ class ViT(nn.Module):
              transformer_stochastic_depth=transformer_stochastic_depth,
              transformer_checkpoint_ff=transformer_checkpoint_ff,
              linear_module=linear_module,
+             alpha=alpha,
+             beta=beta,
          )

          self.pool = head(
@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "12.3.1"
+ version = "13.0.1"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}
File without changes
File without changes