broccoli-ml 1.0.0__tar.gz → 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 1.0.0
+Version: 1.1.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -73,7 +73,15 @@ class MHAttention(nn.Module):
         bos_tokens=0,
         rotary_embedding=None,
         source_size=None,
+        scaling="d",
     ):
+        """
+        Args:
+            scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
         super().__init__()

         if rotary_embedding is not None:
@@ -84,6 +92,7 @@ class MHAttention(nn.Module):
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+        self.scaling = scaling

         self.head_dim = self.embed_dim // self.n_heads

@@ -207,7 +216,12 @@ class MHAttention(nn.Module):

         qk_scores = q @ k.transpose(-1, -2)

-        qk_scores /= math.sqrt(self.head_dim)
+        if self.scaling == "sqrtd":
+            qk_scores /= math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            qk_scores /= self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')

         # Apply mask if causal (must come before softmax)
         if self.causal:
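Note on the hunk above: the behavioural change in 1.1.0 is the denominator applied to the query-key logits. The new default divides by the head dimension itself (the 1/d scaling argued for in "Tensor Programs V", i.e. μP-style attention), while `scaling="sqrtd"` restores the 1/sqrt(d) rule from "Attention Is All You Need". Below is a minimal standalone PyTorch sketch of the arithmetic difference; it does not import broccoli-ml, and the tensor shapes are arbitrary:

```python
import math

import torch

head_dim = 64
q = torch.randn(2, 8, 16, head_dim)  # (batch, heads, tokens, head_dim)
k = torch.randn(2, 8, 16, head_dim)

qk_scores = q @ k.transpose(-1, -2)

# 1.0.0 behaviour, kept in 1.1.0 as scaling="sqrtd": divide by sqrt(head_dim)
scores_sqrtd = qk_scores / math.sqrt(head_dim)

# New 1.1.0 default, scaling="d": divide by head_dim itself
scores_d = qk_scores / head_dim

# At random initialisation the "d" variant is simply smaller by a further
# factor of sqrt(head_dim); the μP argument concerns training, where q and k
# become correlated and the raw dot products grow like head_dim.
print(scores_sqrtd.std(), scores_d.std())
```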
@@ -305,6 +319,7 @@ class TransformerBlock(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
+        msa_scaling="d",
         mlp_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
@@ -314,6 +329,14 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
     ):
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
         super().__init__()

         self.pre_norm = pre_norm
@@ -348,6 +371,7 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             bos_tokens=bos_tokens,
+            scaling=msa_scaling,
         )

         # Submodule for the feedforward process
@@ -429,9 +453,21 @@ class TransformerEncoder(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        msa_scaling="d",
     ):
-        if position_embedding_type == "relative":
-            assert source_size is not None  # TODO: make this a proper exception
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
+        if (position_embedding_type == "relative") and (source_size is None):
+            raise ValueError(
+                "`source_size` for TransformerEncoder cannot be None if"
+                " `position_embedding_type` is relative"
+            )

         super().__init__()
         self.seq_len = seq_len
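Alongside the new `msa_scaling` argument, this hunk also replaces a bare assert on `source_size` with a `ValueError`. A standalone sketch of the new check, using a made-up helper name purely for illustration (it is not part of broccoli-ml's API):

```python
def check_relative_embedding_args(position_embedding_type, source_size):
    # Mirrors the validation added to TransformerEncoder.__init__ (illustrative only).
    if (position_embedding_type == "relative") and (source_size is None):
        raise ValueError(
            "`source_size` for TransformerEncoder cannot be None if"
            " `position_embedding_type` is relative"
        )


# In 1.1.0 this condition raises ValueError; in 1.0.0 it failed an assert.
try:
    check_relative_embedding_args("relative", None)
except ValueError as err:
    print(err)
```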
@@ -461,7 +497,8 @@ class TransformerEncoder(nn.Module):
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth

-        assert isinstance(n_layers, int)  # XXX: make this a proper Exception
+        assert isinstance(n_layers, int)
+
         if n_layers == 1:
             self.stochastic_depth_probabilities = [0.0]
         else:
@@ -484,6 +521,7 @@ class TransformerEncoder(nn.Module):
                 activation_kwargs=activation_kwargs,
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
+                msa_scaling=msa_scaling,
                 mlp_dropout=mlp_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
@@ -143,6 +143,7 @@ class ViTEncoder(nn.Module):
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
+        transformer_msa_scaling="d",
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -295,6 +296,7 @@ class ViTEncoder(nn.Module):
             activation_kwargs=transformer_activation_kwargs,
             ff_linear_module_up=transformer_ff_linear_module_up,
             ff_linear_module_down=transformer_ff_linear_module_down,
+            msa_scaling=transformer_msa_scaling,
             mlp_dropout=transformer_mlp_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
@@ -405,6 +407,7 @@ class ViT(nn.Module):
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
+        transformer_msa_scaling="d",
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -468,6 +471,7 @@ class ViT(nn.Module):
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
+            transformer_msa_scaling=transformer_msa_scaling,
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
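The same switch is exposed one level up under a prefixed name at each layer of the stack: ViT and ViTEncoder accept `transformer_msa_scaling`, which reaches TransformerEncoder and TransformerBlock as `msa_scaling` and MHAttention as `scaling`. A simplified sketch of that forwarding chain, using stand-in classes rather than the real broccoli-ml constructors (whose full signatures are not shown in this diff):

```python
# Stand-in classes, illustrative only: they model the keyword renaming added in
# 1.1.0, not the actual broccoli-ml modules.

class MHAttentionSketch:
    def __init__(self, scaling="d"):
        self.scaling = scaling


class TransformerBlockSketch:
    def __init__(self, msa_scaling="d"):
        self.attention = MHAttentionSketch(scaling=msa_scaling)


class ViTSketch:
    def __init__(self, transformer_msa_scaling="d"):
        self.block = TransformerBlockSketch(msa_scaling=transformer_msa_scaling)


vit = ViTSketch(transformer_msa_scaling="sqrtd")  # opt back into pre-1.1.0 scaling
print(vit.block.attention.scaling)  # -> "sqrtd"
```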
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "1.0.0"
+version = "1.1.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}