broccoli-ml 0.40.0__tar.gz → 1.1.0__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.40.0
+Version: 1.1.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -73,7 +73,15 @@ class MHAttention(nn.Module):
         bos_tokens=0,
         rotary_embedding=None,
         source_size=None,
+        scaling="d",
     ):
+        """
+        Args:
+            scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
         super().__init__()
 
         if rotary_embedding is not None:
@@ -84,6 +92,7 @@ class MHAttention(nn.Module):
         self.embed_dim = embed_dim
         self.n_heads = n_heads
         assert embed_dim % n_heads == 0
+        self.scaling = scaling
 
         self.head_dim = self.embed_dim // self.n_heads
 
@@ -207,7 +216,12 @@ class MHAttention(nn.Module):
 
         qk_scores = q @ k.transpose(-1, -2)
 
-        qk_scores /= math.sqrt(self.head_dim)
+        if self.scaling == "sqrtd":
+            qk_scores /= math.sqrt(self.head_dim)
+        elif self.scaling == "d":
+            qk_scores /= self.head_dim
+        else:
+            raise ValueError('`scaling` argument to MHAttention must be "d" or "sqrtd"')
 
         # Apply mask if causal (must come before softmax)
         if self.causal:
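The two options differ only in the divisor applied to the raw attention logits: 1/sqrt(d) as in "Attention Is All You Need", or 1/d as advocated in "Tensor Programs V" (muP). A minimal, self-contained sketch of that logic, independent of the package (tensor shapes are arbitrary):

import math
import torch

# Dummy query/key tensors: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
head_dim = q.shape[-1]

qk_scores = q @ k.transpose(-1, -2)

scaling = "d"  # the new default; "sqrtd" reproduces the pre-1.1.0 behaviour
if scaling == "sqrtd":
    qk_scores = qk_scores / math.sqrt(head_dim)  # 1/sqrt(d), "Attention Is All You Need"
elif scaling == "d":
    qk_scores = qk_scores / head_dim  # 1/d, per "Tensor Programs V" (muP-style scaling)

attn_weights = torch.softmax(qk_scores, dim=-1)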
@@ -305,6 +319,7 @@ class TransformerBlock(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
+        msa_scaling="d",
         mlp_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
@@ -314,6 +329,14 @@ class TransformerBlock(nn.Module):
         post_norm=False,
         normformer=False,
     ):
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
         super().__init__()
 
         self.pre_norm = pre_norm
@@ -348,6 +371,7 @@ class TransformerBlock(nn.Module):
             rotary_embedding=self.rotary_embedding,
             source_size=source_size,
             bos_tokens=bos_tokens,
+            scaling=msa_scaling,
         )
 
         # Submodule for the feedforward process
@@ -429,9 +453,21 @@ class TransformerEncoder(nn.Module):
         pre_norm=True,
         post_norm=False,
         normformer=False,
+        msa_scaling="d",
     ):
-        if position_embedding_type == "relative":
-            assert source_size is not None  # TODO: make this a proper exception
+        """
+        Args:
+            msa_scaling: how should the attention logits be scaled? Can be "sqrtd"
+                to mimic the original Attention is All You Need approach of
+                dividing by the sqrt of the embedding Dimension or "d" per
+                "Tensor Programs V...". Default "d"
+        """
+
+        if (position_embedding_type == "relative") and (source_size is None):
+            raise ValueError(
+                "`source_size` for TransformerEncoder cannot be None if"
+                " `position_embedding_type` is relative"
+            )
 
         super().__init__()
         self.seq_len = seq_len
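This hunk also replaces a bare assert with an explicit ValueError, so a misconfigured encoder now fails with an actionable message even under python -O, which strips asserts. A standalone sketch of the check; the helper name and the example source_size value are illustrative, not part of the package:

def _check_relative_position_config(position_embedding_type, source_size):
    # Mirrors the new validation in TransformerEncoder.__init__ (sketch only).
    if (position_embedding_type == "relative") and (source_size is None):
        raise ValueError(
            "`source_size` for TransformerEncoder cannot be None if"
            " `position_embedding_type` is relative"
        )

_check_relative_position_config("relative", source_size=(14, 14))  # passes
# _check_relative_position_config("relative", source_size=None)    # raises ValueError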
@@ -461,7 +497,8 @@ class TransformerEncoder(nn.Module):
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
-        assert isinstance(n_layers, int)  # XXX: make this a proper Exception
+        assert isinstance(n_layers, int)
+
         if n_layers == 1:
             self.stochastic_depth_probabilities = [0.0]
         else:
@@ -484,6 +521,7 @@ class TransformerEncoder(nn.Module):
                 activation_kwargs=activation_kwargs,
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
+                msa_scaling=msa_scaling,
                 mlp_dropout=mlp_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
@@ -85,6 +85,14 @@ class SequencePoolClassificationHead(ClassificationHead):
     def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
         super().__init__(d_model, linear_module, out_dim, batch_norm=batch_norm)
         self.summarize = SequencePool(d_model, linear_module)
+        # Rebuild the classification process with the correct summary module:
+        self.classification_process = nn.Sequential(
+            *[
+                self.summarize,
+                self.projection,
+                self.batch_norm,
+            ]
+        )
 
 
 class ViTEncoder(nn.Module):
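This hunk fixes a subtle wiring issue: the parent ClassificationHead builds its classification_process Sequential around its own summary module, so reassigning self.summarize afterwards leaves the Sequential pointing at the old module. A minimal sketch of the pattern; the class and module names below are illustrative placeholders, not the package's classes:

import torch.nn as nn

class BaseHead(nn.Module):
    def __init__(self, d_model, out_dim):
        super().__init__()
        self.summarize = nn.Identity()  # placeholder summary step
        self.projection = nn.Linear(d_model, out_dim)
        self.batch_norm = nn.BatchNorm1d(out_dim)
        # The Sequential captures references to the modules above:
        self.classification_process = nn.Sequential(
            self.summarize, self.projection, self.batch_norm
        )

class PooledHead(BaseHead):
    def __init__(self, d_model, out_dim):
        super().__init__(d_model, out_dim)
        self.summarize = nn.Flatten(start_dim=1)  # stand-in for SequencePool
        # Without rebuilding, classification_process would still call the
        # parent's nn.Identity(); rebuild so it uses the new summary module:
        self.classification_process = nn.Sequential(
            self.summarize, self.projection, self.batch_norm
        )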
@@ -135,6 +143,7 @@ class ViTEncoder(nn.Module):
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
+        transformer_msa_scaling="d",
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -287,6 +296,7 @@ class ViTEncoder(nn.Module):
             activation_kwargs=transformer_activation_kwargs,
             ff_linear_module_up=transformer_ff_linear_module_up,
             ff_linear_module_down=transformer_ff_linear_module_down,
+            msa_scaling=transformer_msa_scaling,
             mlp_dropout=transformer_mlp_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
@@ -397,6 +407,7 @@ class ViT(nn.Module):
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
+        transformer_msa_scaling="d",
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -460,6 +471,7 @@ class ViT(nn.Module):
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
+            transformer_msa_scaling=transformer_msa_scaling,
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
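Taken together, these hunks thread the new option from the top-level vision models down to the attention module (ViT -> ViTEncoder -> TransformerEncoder -> TransformerBlock -> MHAttention), so the scaling is chosen once at construction time. Note that the default is "d", whereas 0.40.0 always divided by sqrt(head_dim); callers who want the old logits should pass the option explicitly. Hypothetical usage, with the remaining constructor arguments elided because they are not shown in this diff:

# model = ViT(..., transformer_msa_scaling="sqrtd")  # reproduce pre-1.1.0 scaling
# model = ViT(..., transformer_msa_scaling="d")      # 1.1.0 default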
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.40.0"
+version = "1.1.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}