broccoli-ml 1.0.0__tar.gz → 1.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 1.0.0
+ Version: 1.2.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -73,7 +73,15 @@ class MHAttention(nn.Module):
  bos_tokens=0,
  rotary_embedding=None,
  source_size=None,
+ scaling="d",
  ):
+ """
+ Args:
+ scaling: how should the attention logits be scaled? Can be "sqrtd"
+ to mimic the original Attention is All You Need approach of
+ dividing by the sqrt of the embedding dimension or "d" per
+ "Tensor Programs V...". Default "d"
+ """
  super().__init__()

  if rotary_embedding is not None:
@@ -84,6 +92,7 @@ class MHAttention(nn.Module):
  self.embed_dim = embed_dim
  self.n_heads = n_heads
  assert embed_dim % n_heads == 0
+ self.scaling = scaling

  self.head_dim = self.embed_dim // self.n_heads

@@ -207,7 +216,12 @@ class MHAttention(nn.Module):

  qk_scores = q @ k.transpose(-1, -2)

- qk_scores /= math.sqrt(self.head_dim)
+ if self.scaling == "sqrtd":
+ qk_scores /= math.sqrt(self.head_dim)
+ elif self.scaling == "d":
+ qk_scores /= self.head_dim
+ else:
+ raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')

  # Apply mask if causal (must come before softmax)
  if self.causal:
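
As an illustration of the behavioural change above, here is a minimal sketch (not taken from the package) of how the two scaling modes differ. The head dimension and tensor shapes are arbitrary example values; only the division step mirrors the diff.

    import math
    import torch

    head_dim = 64  # illustrative; broccoli-ml computes this as embed_dim // n_heads
    q = torch.randn(1, 1, 4, head_dim)  # (batch, heads, seq, head_dim), shapes chosen arbitrarily
    k = torch.randn(1, 1, 4, head_dim)

    qk_scores = q @ k.transpose(-1, -2)

    # scaling="sqrtd": the original "Attention is All You Need" scale, logits / sqrt(d)
    sqrtd_scaled = qk_scores / math.sqrt(head_dim)  # divides by 8 when head_dim == 64

    # scaling="d": the new default, per "Tensor Programs V", logits / d
    d_scaled = qk_scores / head_dim                 # divides by 64 when head_dim == 64
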
@@ -305,6 +319,7 @@ class TransformerBlock(nn.Module):
  activation_kwargs: Optional[dict] = None,
  ff_linear_module_up=None,
  ff_linear_module_down=None,
+ msa_scaling="d",
  mlp_dropout=0.0,
  msa_dropout=0.0,
  identity_probability=0.0,
@@ -314,6 +329,14 @@ class TransformerBlock(nn.Module):
  post_norm=False,
  normformer=False,
  ):
+ """
+ Args:
+ msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+ to mimic the original Attention is All You Need approach of
+ dividing by the sqrt of the embedding dimension or "d" per
+ "Tensor Programs V...". Default "d"
+ """
+
  super().__init__()

  self.pre_norm = pre_norm
@@ -348,6 +371,7 @@ class TransformerBlock(nn.Module):
  rotary_embedding=self.rotary_embedding,
  source_size=source_size,
  bos_tokens=bos_tokens,
+ scaling=msa_scaling,
  )

  # Submodule for the feedforward process
@@ -429,9 +453,21 @@ class TransformerEncoder(nn.Module):
  pre_norm=True,
  post_norm=False,
  normformer=False,
+ msa_scaling="d",
  ):
- if position_embedding_type == "relative":
- assert source_size is not None # TODO: make this a proper exception
+ """
+ Args:
+ msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+ to mimic the original Attention is All You Need approach of
+ dividing by the sqrt of the embedding dimension or "d" per
+ "Tensor Programs V...". Default "d"
+ """
+
+ if (position_embedding_type == "relative") and (source_size is None):
+ raise ValueError(
+ "`source_size` for TransformerEncoder cannot be None if"
+ " `position_embedding_type` is relative"
+ )

  super().__init__()
  self.seq_len = seq_len
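
The hunk above also replaces a bare assert with a proper exception. Below is a standalone sketch of the new validation behaviour, using a hypothetical helper function rather than the package's class so it can run on its own; the check and message are copied from the diff.

    def check_relative_embedding(position_embedding_type, source_size):
        # Mirrors the check added to TransformerEncoder.__init__ in the hunk above.
        if (position_embedding_type == "relative") and (source_size is None):
            raise ValueError(
                "`source_size` for TransformerEncoder cannot be None if"
                " `position_embedding_type` is relative"
            )

    check_relative_embedding("relative", source_size=(14, 14))  # passes silently
    try:
        check_relative_embedding("relative", source_size=None)
    except ValueError as err:
        print(err)  # prints the message constructed above
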
@@ -461,7 +497,8 @@ class TransformerEncoder(nn.Module):
  self.msa_dropout = msa_dropout
  self.stochastic_depth = stochastic_depth

- assert isinstance(n_layers, int) # XXX: make this a proper Exception
+ assert isinstance(n_layers, int)
+
  if n_layers == 1:
  self.stochastic_depth_probabilities = [0.0]
  else:
@@ -484,6 +521,7 @@ class TransformerEncoder(nn.Module):
  activation_kwargs=activation_kwargs,
  ff_linear_module_up=ff_linear_module_up,
  ff_linear_module_down=ff_linear_module_down,
+ msa_scaling=msa_scaling,
  mlp_dropout=mlp_dropout,
  msa_dropout=msa_dropout,
  identity_probability=self.stochastic_depth_probabilities[i],
@@ -53,12 +53,20 @@ class ClassificationHead(nn.Module):
  A general classification head for a ViT
  """

- def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
+ def __init__(
+ self,
+ d_model,
+ n_classes,
+ linear_module=nn.Linear,
+ logit_projection_layer=nn.Linear,
+ batch_norm_logits=True,
+ ):
  super().__init__()
  self.d_model = d_model
  self.summarize = GetCLSToken()
- self.projection = linear_module(d_model, n_classes)
- if batch_norm:
+ self.projection = logit_projection_layer(d_model, n_classes)
+
+ if batch_norm_logits:
  self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
  else:
  self.batch_norm = nn.Identity()
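
For callers of ClassificationHead, here is a hedged construction example under the new signature. The import path is a guess and the dimensions are arbitrary, but the keyword names follow this hunk: n_classes is now the second positional argument, the logit projection is built from logit_projection_layer rather than linear_module, and batch_norm was renamed batch_norm_logits.

    import torch.nn as nn
    from broccoli.vit import ClassificationHead  # module path assumed, not shown in this diff

    head = ClassificationHead(
        d_model=256,                       # embedding width; arbitrary example value
        n_classes=10,                      # now the second positional argument
        logit_projection_layer=nn.Linear,  # builds the logit projection (was linear_module in 1.0.0)
        batch_norm_logits=True,            # renamed from batch_norm
    )
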
@@ -83,7 +91,7 @@ class SequencePoolClassificationHead(ClassificationHead):
  """

  def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
- super().__init__(d_model, linear_module, out_dim, batch_norm=batch_norm)
+ super().__init__(d_model, linear_module, out_dim, batch_norm_logits=batch_norm)
  self.summarize = SequencePool(d_model, linear_module)
  # Rebuild the classification process with the correct summary module:
  self.classification_process = nn.Sequential(
@@ -143,6 +151,7 @@ class ViTEncoder(nn.Module):
  transformer_activation_kwargs: Optional[dict] = None,
  transformer_ff_linear_module_up=None,
  transformer_ff_linear_module_down=None,
+ transformer_msa_scaling="d",
  transformer_mlp_dropout=0.0,
  transformer_msa_dropout=0.1,
  transformer_stochastic_depth=0.1,
@@ -295,6 +304,7 @@ class ViTEncoder(nn.Module):
  activation_kwargs=transformer_activation_kwargs,
  ff_linear_module_up=transformer_ff_linear_module_up,
  ff_linear_module_down=transformer_ff_linear_module_down,
+ msa_scaling=transformer_msa_scaling,
  mlp_dropout=transformer_mlp_dropout,
  msa_dropout=transformer_msa_dropout,
  stochastic_depth=transformer_stochastic_depth,
@@ -405,12 +415,14 @@ class ViT(nn.Module):
  transformer_activation_kwargs: Optional[dict] = None,
  transformer_ff_linear_module_up=None,
  transformer_ff_linear_module_down=None,
+ transformer_msa_scaling="d",
  transformer_mlp_dropout=0.0,
  transformer_msa_dropout=0.1,
  transformer_stochastic_depth=0.1,
- batch_norm_outputs=True,
- linear_module=SpectralNormLinear,
  head=SequencePoolClassificationHead,
+ batch_norm_logits=True,
+ logit_projection_layer=nn.Linear,
+ linear_module=nn.Linear,
  ):

  super().__init__()
@@ -468,6 +480,7 @@ class ViT(nn.Module):
  transformer_activation_kwargs=transformer_activation_kwargs,
  transformer_ff_linear_module_up=transformer_ff_linear_module_up,
  transformer_ff_linear_module_down=transformer_ff_linear_module_down,
+ transformer_msa_scaling=transformer_msa_scaling,
  transformer_mlp_dropout=transformer_mlp_dropout,
  transformer_msa_dropout=transformer_msa_dropout,
  transformer_stochastic_depth=transformer_stochastic_depth,
@@ -476,9 +489,10 @@ class ViT(nn.Module):

  self.pool = head(
  transformer_embedding_size,
- linear_module,
  image_classes,
- batch_norm=batch_norm_outputs,
+ linear_module=linear_module,
+ logit_projection_layer=logit_projection_layer,
+ batch_norm=batch_norm_logits,
  )

  @property
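
Putting the ViT changes in one place for anyone upgrading from 1.0.0: the keyword names below are exactly those shown in the hunks above, but the many unchanged constructor arguments are omitted, so treat this as a migration aid rather than a complete call.

    import torch.nn as nn

    # 1.0.0: batch_norm_outputs=True, linear_module=SpectralNormLinear
    # 1.2.0 equivalents and additions:
    vit_kwargs_1_2_0 = dict(
        batch_norm_logits=True,            # replaces batch_norm_outputs
        logit_projection_layer=nn.Linear,  # new: layer used for the classification logits
        linear_module=nn.Linear,           # default changed from SpectralNormLinear
        transformer_msa_scaling="d",       # new: attention-logit scaling, "d" or "sqrtd"
    )
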
@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "1.0.0"
+ version = "1.2.0"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
  {name = "Nicholas Bailey"}
3 files without changes