broccoli-ml 5.1.1__tar.gz → 9.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 5.1.1
+Version: 9.5.1
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 class Swish(nn.Module):
     """
-    Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
     """
 
     def __init__(self) -> None:
@@ -0,0 +1,352 @@
+import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
+
+
+class SpectralNormLinear(nn.Module):
+    """
+    Inspired by Apple's Spectral Normed Linear Layers
+    (https://github.com/apple/ml-sigma-reparam)
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = SigmaReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"SpectralNormFeedForward(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class AnchoredLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = AnchoredReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class WeightNormedLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = NormReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        return (
+            f"WeightNormedLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+        xglu=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def reset_rows(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            value_indices = indices
+            centred_value_weights = self._mean_value_weights()
+            centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
+            if self.xglu:
+                gate_indices = indices
+                value_indices = indices + (self.linear.out_features // 2)
+                centred_gate_weights = self._mean_gate_weights()
+                centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
+                self._update_weights(
+                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim
+                )
+            self._update_weights(
+                value_indices, 0, centred_value_weights, self.optimisers
+            )
+        else:
+            return
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
+
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+    def _update_weights(
+        self,
+        indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
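
As a rough orientation for the module added above: SpectralNormLinear, AnchoredLinear and WeightNormedLinear are drop-in stand-ins for nn.Linear whose weights are wrapped in reparameterised tensors, while RecyclingLinear only starts recycling rows/columns once an optimiser has been registered. A minimal usage sketch, assuming the module is importable as broccoli.linear (the import path and the AdamW choice are assumptions, not confirmed by the diff):

    import torch
    from broccoli.linear import SpectralNormLinear, RecyclingLinear  # assumed import path

    spectral = SpectralNormLinear(64, 128)  # behaves like nn.Linear(64, 128)
    recycling = RecyclingLinear(
        64, 128, row_recycling_rate=0.01, column_recycling_rate=0.01, adaptive=True
    )

    optimiser = torch.optim.AdamW(recycling.parameters(), lr=1e-3)
    # Register the optimiser before attaching an LR scheduler; a 0.0 learning
    # rate at registration time triggers the warning defined above.
    recycling.register_optimiser(optimiser)

    x = torch.randn(8, 64)
    y = spectral(x) + recycling(x)  # row/column recycling runs inside forward() in training mode
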
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
@@ -76,7 +77,7 @@ class MHAttention(nn.Module):
         causal=False,
         seq_len=None,
         linear_module: nn.Module = nn.Linear,
-        bos_tokens=0,
+        utility_tokens=0,
         rotary_embedding=None,
         source_size=None,
         scaling="d",
@@ -129,7 +130,7 @@ class MHAttention(nn.Module):
         )
         self.rotary_embedding = rotary_embedding
         self.source_size = source_size
-        self.bos_tokens = bos_tokens
+        self.utility_tokens = utility_tokens
 
         self.reset_parameters()
 
@@ -156,7 +157,7 @@ class MHAttention(nn.Module):
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Apply Axial RoPE to all tokens except BOS tokens
+        Apply Axial RoPE to all tokens except utility tokens
        """
 
        if len(self.source_size) == 1:
@@ -180,8 +181,8 @@ class MHAttention(nn.Module):
                 "`source_size` must be a tuple of 1, 2 or 3 integers"
             )
 
-        q_bos, q_img = q[:, : self.bos_tokens, :], q[:, self.bos_tokens :, :]
-        k_bos, k_img = k[:, : self.bos_tokens, :], k[:, self.bos_tokens :, :]
+        q_util, q_img = q[:, : self.utility_tokens, :], q[:, self.utility_tokens :, :]
+        k_util, k_img = k[:, : self.utility_tokens, :], k[:, self.utility_tokens :, :]
 
         q_img = rearrange(
             q_img,
@@ -208,9 +209,9 @@ class MHAttention(nn.Module):
             f"b {spatial_dimension_names} d -> b ({spatial_dimension_names}) d",
         )
 
-        # Re-combine the BOS tokens and the RoPE-enhanced image tokens
-        q = torch.cat([q_bos, q_img], dim=1)
-        k = torch.cat([k_bos, k_img], dim=1)
+        # Re-combine the utility tokens and the RoPE-enhanced sequence tokens
+        q = torch.cat([q_util, q_img], dim=1)
+        k = torch.cat([k_util, k_img], dim=1)
 
         return q, k
 
@@ -284,7 +285,7 @@ class MHAttention(nn.Module):
 
         return self.out_proj(output_without_heads)
 
-    def attention_scores(self, q, k, v):
+    def attention_logits(self, q, k, v):
 
         q, k, v = self.project_qkv(q, k, v)
 
@@ -301,8 +302,6 @@
         if self.causal:
             qk_scores.masked_fill_(self.mask, float("-inf"))
 
-        qk_scores = F.softmax(qk_scores, dim=-1)
-
         return qk_scores  # (batch, head, seq_len, seq_len)
 
     def reset_parameters(self):
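
Note the behavioural change in the hunk above: the method is renamed from attention_scores to attention_logits and the F.softmax call is removed, so it now returns unnormalised logits. Callers that want the old probability view need to normalise the result themselves; a minimal sketch of that post-processing (the random tensor below is only a stand-in for the method's output):

    import torch
    import torch.nn.functional as F

    # Stand-in for the (batch, head, seq_len, seq_len) tensor returned by attention_logits()
    logits = torch.randn(2, 4, 16, 16)
    attn_probs = F.softmax(logits, dim=-1)  # recover normalised attention weights
    assert torch.allclose(attn_probs.sum(dim=-1), torch.ones(2, 4, 16))
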
@@ -326,6 +325,8 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
         pre_norm=True,
@@ -339,6 +340,7 @@
         self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -355,29 +357,63 @@
         else:
             self.activation = activation()
 
-        self.dropout = nn.Dropout(dropout)
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * ratio * output_features if self.xglu else ratio * output_features
        )
 
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(ratio * output_features, output_features)
+
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                linear_module_up(input_features, self.max_features),
+                self.linear_in,
                 self.activation,
+                self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                linear_module_down(ratio * output_features, output_features),
-                self.dropout,
+                self.linear_out,
+                self.outer_dropout,
             ]
         )
 
+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()
 
     def forward(self, x):
 
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices)
+            self.linear_out.reset_columns(indices)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
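
The hunk above splits the block's single dropout into an inner dropout (applied to the expanded hidden features, right after the activation) and an outer dropout (applied to the block's output, after the down-projection), each falling back to the old dropout value when left as None. A standalone sketch of that ordering, using plain nn.Linear and nn.GELU in place of the configurable modules (dimensions and rates are illustrative, not taken from the package):

    import torch
    from torch import nn

    d_model, ratio, p_inner, p_outer = 64, 4, 0.1, 0.1
    ff = nn.Sequential(
        nn.LayerNorm(d_model),                # pre-norm
        nn.Linear(d_model, ratio * d_model),  # up-projection
        nn.GELU(),
        nn.Dropout(p_inner),                  # inner dropout on the hidden features
        nn.Linear(ratio * d_model, d_model),  # down-projection
        nn.Dropout(p_outer),                  # outer dropout on the block output
    )
    out = ff(torch.randn(8, 16, d_model))     # (batch, seq, d_model)
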
@@ -416,14 +452,16 @@ class TransformerBlock(nn.Module):
         n_heads,
         relative_position_embedding=False,
         source_size=None,
-        bos_tokens=0,
+        utility_tokens=0,
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-        mlp_dropout=0.0,
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -474,7 +512,7 @@
            linear_module=linear_module,
            rotary_embedding=self.rotary_embedding,
            source_size=source_size,
-           bos_tokens=bos_tokens,
+           utility_tokens=utility_tokens,
            scaling=msa_scaling,
        )
 
@@ -485,7 +523,9 @@
            d_model,
            activation=activation,
            activation_kwargs=activation_kwargs,
-           dropout=mlp_dropout,
+           dropout=ff_dropout,
+           inner_dropout=ff_inner_dropout,
+           outer_dropout=ff_outer_dropout,
            linear_module_up=(
                ff_linear_module_up
                if ff_linear_module_up is not None
@@ -529,15 +569,15 @@
 
         return x
 
-    def attention_scores(self, x):
+    def attention_logits(self, x):
         """
         Give back the attention scores used in this layer.
         """
         if self.pre_norm:
             x = self.layer_norm_1(x)
-            return self.attn.attention_scores(x, x, x)
+            return self.attn.attention_logits(x, x, x)
         else:
-            return self.attn.attention_scores(x, x, x)
+            return self.attn.attention_logits(x, x, x)
 
     def reset_parameters(self):
         self.layer_norm_1.reset_parameters()
@@ -568,13 +608,15 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-        mlp_dropout=0.0,
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
         linear_module=nn.Linear,
-        bos_tokens=0,
-        return_bos_tokens=False,
+        utility_tokens=0,
+        return_utility_tokens=False,
         pre_norm=True,
         post_norm=False,
         normformer=False,
@@ -592,22 +634,33 @@
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `position_embedding_type` is relative"
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )
 
         super().__init__()
         self.seq_len = seq_len
         self.n_heads = n_heads
-        self._bos_tokens = bos_tokens
-        self.return_bos_tokens = return_bos_tokens
-
-        # Initialise BOS tokens with normal init, like usual Pytorch embeddings
-        if self._bos_tokens:
-            self._bos_embedding = nn.Parameter(torch.empty(self._bos_tokens, d_model))
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._bos_tokens
+        self._utility_tokens = utility_tokens
+        self.return_utility_tokens = return_utility_tokens
+
+        # Initialise utility tokens with normal init, like usual Pytorch embeddings
+        if self._utility_tokens:
+            self._utility_token_embedding = nn.Parameter(
+                torch.empty(self._utility_tokens, d_model)
+            )
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
+        else:
+            self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
-            self._bos_embedding = None
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
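
For context on the rename running through these hunks: the learnable prefix embeddings formerly called BOS tokens are now utility tokens. They are prepended to the input sequence, skipped by RoPE, and stripped from the output unless return_utility_tokens is set. A standalone illustration of the prepend/strip pattern, independent of the TransformerEncoder API (all sizes here are illustrative):

    import torch
    from torch import nn

    batch, seq_len, d_model, n_utility = 8, 196, 64, 4
    utility_embedding = nn.Parameter(torch.randn(n_utility, d_model))

    x = torch.randn(batch, seq_len, d_model)
    x = torch.cat([utility_embedding.expand(batch, -1, -1), x], dim=1)  # prepend
    # ... transformer blocks would process the full sequence here ...
    x = x[:, n_utility:, :]  # strip the utility tokens from the output
    assert x.shape == (batch, seq_len, d_model)
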
@@ -619,7 +672,7 @@
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout = mlp_dropout
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -641,14 +694,16 @@
                    n_heads,
                    relative_position_embedding=relative_position_embedding,
                    source_size=source_size,
-                   bos_tokens=bos_tokens,
+                   utility_tokens=utility_tokens,
                    mlp_ratio=mlp_ratio,
                    activation=activation,
                    activation_kwargs=activation_kwargs,
                    ff_linear_module_up=ff_linear_module_up,
                    ff_linear_module_down=ff_linear_module_down,
                    msa_scaling=msa_scaling,
-                   mlp_dropout=mlp_dropout,
+                   ff_dropout=ff_dropout,
+                   ff_inner_dropout=ff_inner_dropout,
+                   ff_outer_dropout=ff_outer_dropout,
                    msa_dropout=msa_dropout,
                    identity_probability=self.stochastic_depth_probabilities[i],
                    causal=causal,
@@ -669,8 +724,10 @@
         return ",".join([str(block._kv_distance) for block in self.blocks])
 
     def preprocess(self, x):
-        if self._bos_tokens:
-            x = torch.cat([self._bos_embedding.expand(x.size(0), -1, -1), x], dim=1)
+        if self._utility_tokens:
+            x = torch.cat(
+                [self._utility_token_embedding.expand(x.size(0), -1, -1), x], dim=1
+            )
         else:
             x = x
 
@@ -692,12 +749,12 @@
         for block in self.blocks:
             x = block(x)
 
-        if self._bos_tokens and not self.return_bos_tokens:
-            return x[:, self._bos_tokens :, :]
+        if self._utility_tokens and not self.return_utility_tokens:
+            return x[:, self._utility_tokens :, :]
         else:
             return x
 
-    def attention_scores(self, x):
+    def attention_logits(self, x):
 
         x = self.preprocess(x)
 
@@ -705,15 +762,15 @@
 
         for block in self.blocks:
             # Get attention scores with shape (batch, 1, head, seq_len, seq_len)
-            layer_attention_scores = block.attention_scores(x).unsqueeze(1)
-            layer_scores.append(layer_attention_scores)
+            layer_attention_logits = block.attention_logits(x).unsqueeze(1)
+            layer_scores.append(layer_attention_logits)
             x = block(x)
 
         return torch.cat(layer_scores, dim=1)  # (batch, layer, head, seq_len, seq_len)
 
     def reset_parameters(self):
-        if self._bos_embedding is not None:
-            nn.init.normal_(self._bos_embedding, mean=0.0, std=1.0)
+        if self._utility_token_embedding is not None:
+            nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
 
         if self.absolute_position_embedding is not None:
             self.absolute_position_embedding.reset_parameters()
@@ -11,7 +11,6 @@ from einops.layers.torch import Rearrange
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 
 class GetCLSToken(nn.Module):
@@ -39,6 +38,9 @@ class SequencePool(nn.Module):
         weights = self.attention(x)
         return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
 
+    def attention_scores(self, x):
+        return self.attention(x)
+
     def reset_parameters(self):
         # Iterate over modules in the sequential block
         for module in self.attention:
@@ -159,7 +161,9 @@ class ViTEncoder(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-        transformer_initial_ff_mlp_dropout=None,
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -169,14 +173,16 @@
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-        transformer_mlp_dropout=0.0,
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -331,13 +337,15 @@
                ff_linear_module_up=transformer_ff_linear_module_up,
                ff_linear_module_down=transformer_ff_linear_module_down,
                msa_scaling=transformer_msa_scaling,
-               mlp_dropout=transformer_mlp_dropout,
+               ff_dropout=transformer_ff_dropout,
+               ff_inner_dropout=transformer_ff_inner_dropout,
+               ff_outer_dropout=transformer_ff_outer_dropout,
                msa_dropout=transformer_msa_dropout,
                stochastic_depth=transformer_stochastic_depth,
                causal=False,
                linear_module=linear_module,
-               bos_tokens=transformer_bos_tokens,
-               return_bos_tokens=transformer_return_bos_tokens,
+               utility_tokens=transformer_utility_tokens,
+               return_utility_tokens=transformer_return_utility_tokens,
                pre_norm=transformer_pre_norm,
                normformer=transformer_normformer,
                post_norm=transformer_post_norm,
@@ -355,9 +363,21 @@
                activation_kwargs=transformer_activation_kwargs,
                dropout=(
                    # First truthy assigned value
-                   transformer_initial_ff_mlp_dropout
-                   if transformer_initial_ff_mlp_dropout is not None
-                   else transformer_mlp_dropout
+                   transformer_initial_ff_dropout
+                   if transformer_initial_ff_dropout is not None
+                   else transformer_ff_dropout
+               ),
+               inner_dropout=(
+                   # First truthy assigned value
+                   transformer_initial_ff_inner_dropout
+                   if transformer_initial_ff_inner_dropout is not None
+                   else transformer_ff_inner_dropout
+               ),
+               outer_dropout=(
+                   # First truthy assigned value
+                   transformer_initial_ff_outer_dropout
+                   if transformer_initial_ff_outer_dropout is not None
+                   else transformer_ff_outer_dropout
                ),
                linear_module_up=(
                    # First truthy assigned value
@@ -400,9 +420,9 @@
     def forward(self, x):
         return self.encoder(x)
 
-    def attention_scores(self, x):
+    def attention_logits(self, x):
         x = self.encoder[:-1](x)
-        return self.encoder[-1].attention_scores(x)
+        return self.encoder[-1].attention_logits(x)
 
     def reset_parameters(self):
         for module in self.encoder:
@@ -439,7 +459,9 @@ class ViT(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-        transformer_initial_ff_mlp_dropout=None,
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -449,14 +471,16 @@
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
-        transformer_bos_tokens=0,
-        transformer_return_bos_tokens=False,
+        transformer_utility_tokens=0,
+        transformer_return_utility_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-        transformer_mlp_dropout=0.0,
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -506,7 +530,9 @@
            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-           transformer_initial_ff_mlp_dropout=transformer_initial_ff_mlp_dropout,
+           transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+           transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+           transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
            transformer_pre_norm=transformer_pre_norm,
            transformer_normformer=transformer_normformer,
            transformer_post_norm=transformer_post_norm,
@@ -516,14 +542,16 @@
            transformer_layers=transformer_layers,
            transformer_heads=transformer_heads,
            transformer_mlp_ratio=transformer_mlp_ratio,
-           transformer_bos_tokens=transformer_bos_tokens,
-           transformer_return_bos_tokens=transformer_return_bos_tokens,
+           transformer_utility_tokens=transformer_utility_tokens,
+           transformer_return_utility_tokens=transformer_return_utility_tokens,
            transformer_activation=transformer_activation,
            transformer_activation_kwargs=transformer_activation_kwargs,
            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
            transformer_ff_linear_module_down=transformer_ff_linear_module_down,
            transformer_msa_scaling=transformer_msa_scaling,
-           transformer_mlp_dropout=transformer_mlp_dropout,
+           transformer_ff_dropout=transformer_ff_dropout,
+           transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+           transformer_ff_outer_dropout=transformer_ff_outer_dropout,
            transformer_msa_dropout=transformer_msa_dropout,
            transformer_stochastic_depth=transformer_stochastic_depth,
            transformer_checkpoint_ff=transformer_checkpoint_ff,
@@ -546,16 +574,26 @@
     def forward(self, x):
         return self.pool(self.encoder(x))
 
-    def attention_scores(self, x):
-        return self.encoder.attention_scores(x)
+    def attention_logits(self, x):
+        return self.encoder.attention_logits(x)
+
+    def pool_attention(self, x):
+        if hasattr(self.pool.summarize, "attention"):
+            return self.pool.summarize.attention(self.encoder(x))
+        else:
+            raise NotImplementedError(
+                "`pool_attention` is currently only implemented where"
+                " head class is SequencePoolClassificationHead"
+            )
 
-    def head_to_bos_token_attention(self, x):
-        all_attention = self.attention_scores(x)
+    def head_to_utility_token_attention_logits(self, x):
+        all_attention = self.attention_logits(x)
         batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
         sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
-        n_bos_tokens = self.encoder.encoder._bos_tokens
-        just_bos = sequence_averages[:, :, :n_bos_tokens]
-        return F.softmax(just_bos, dim=-1)  # (layer, head, bos_token)
+        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
+        return sequence_averages[
+            :, :, :n_utility_tokens
+        ]  # (layer, head, utility_tokens)
 
     def reset_parameters(self):
         self.encoder.reset_parameters()
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "5.1.1"
+version = "9.5.1"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
@@ -1,138 +0,0 @@
-import math
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
-
-
-class SpectralNormLinear(nn.Module):
-    """
-    Inspired by Apple's Spectral Normed Linear Layers
-    (https://github.com/apple/ml-sigma-reparam)
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = SigmaReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"SpectralNormFeedForward(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )
-
-
-class AnchoredLinear(nn.Module):
-    """
-    ...
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = AnchoredReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"AnchoredLinear(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )
-
-
-class WeightNormedLinear(nn.Module):
-    """
-    ...
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        weights = torch.empty(self.out_features, self.in_features)
-        stdv = 1.0 / math.sqrt(self.in_features)
-        nn.init.uniform_(weights, a=-stdv, b=stdv)
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = NormReparamTensor(weights)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        return (
-            f"WeightNormedLinear(in_features={self.in_features},"
-            f"out_features={self.out_features}, bias={self.use_bias})"
-        )