broccoli-ml 7.0.0__tar.gz → 9.5.1__tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 7.0.0
+ Version: 9.5.1
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
  class Swish(nn.Module):
      """
-     Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
-     (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-     Halves the incoming parameter count, which should be scaled up before input.
+     Implementation of (beta) Swish
      """
 
      def __init__(self) -> None:
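
The docstring above previously described this class as SwiGLU; the new version names it plain (beta) Swish. The class body is not shown in this hunk, so the following is only a minimal reference sketch of the standard beta-Swish formulation, x * sigmoid(beta * x), not the package's implementation:

    import torch

    def swish(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        # Standard (beta) Swish activation: x * sigmoid(beta * x).
        return x * torch.sigmoid(beta * x)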
@@ -0,0 +1,352 @@
+ import math
+ import random
+ import warnings
+ from typing import Union, List, Iterable
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
+
+
+ class SpectralNormLinear(nn.Module):
+     """
+     Inspired by Apple's Spectral Normed Linear Layers
+     (https://github.com/apple/ml-sigma-reparam)
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = SigmaReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         # Optional: A nice representation for printing the module.
+         return (
+             f"SpectralNormFeedForward(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class AnchoredLinear(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = AnchoredReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         # Optional: A nice representation for printing the module.
+         return (
+             f"AnchoredLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class WeightNormedLinear(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = NormReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         return (
+             f"WeightNormedLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class RecyclingLinear(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         row_recycling_rate: float = 0.0,
+         column_recycling_rate: float = 0.0,
+         adaptive=False,
+         xglu=False,
+     ):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.bias = bias
+         self.xglu = xglu
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         self.row_recycling_rate = row_recycling_rate
+         self.column_recycling_rate = column_recycling_rate
+         self.adaptive = adaptive
+         self.optimisers = []
+         self.initial_learning_rates = []
+         self._warned_about_registration = False
+
+     def register_optimiser(self, optimiser: torch.optim.Optimizer):
+         self.optimisers.append(optimiser)
+         self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+         if self.initial_learning_rates[-1] == 0.0:
+             warnings.warn(
+                 "Learning rate of registered optimiser was 0.0 - make sure "
+                 "you haven't initialised a scheduler before registering the "
+                 "optimiser",
+                 stacklevel=2,
+             )
+
+     def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+         for group in optimiser.param_groups:
+             for param in group["params"]:
+                 if param is self.linear.weight:
+                     return group["lr"]
+
+     def _get_multiplier(self):
+         if not self.adaptive or not self.optimisers:
+             return 1.0
+         else:
+             init = self.initial_learning_rates
+             current = [self._get_learning_rate(o) for o in self.optimisers]
+             pairs = zip(current, init, strict=True)
+             multipliers = [a / b for a, b in pairs if b != 0.0]
+             return min(multipliers) if multipliers else 0.0
+
+     def reset_rows(self, indices):
+         if not torch.is_tensor(indices):
+             idx_tensor = torch.as_tensor(
+                 list(indices), dtype=torch.long, device=self.linear.weight.device
+             )
+         else:
+             idx_tensor = indices
+
+         if idx_tensor.size(0):
+             value_indices = indices
+             centred_value_weights = self._mean_value_weights()
+             centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
+             if self.xglu:
+                 gate_indices = indices
+                 value_indices = indices + (self.linear.out_features // 2)
+                 centred_gate_weights = self._mean_gate_weights()
+                 centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
+                 self._update_weights(
+                     gate_indices, 0, centred_gate_weights, self.optimisers  # dim
+                 )
+             self._update_weights(
+                 value_indices, 0, centred_value_weights, self.optimisers
+             )
+         else:
+             return
+
+     def reset_columns(self, indices):
+         if not torch.is_tensor(indices):
+             idx_tensor = torch.as_tensor(
+                 list(indices), dtype=torch.long, device=self.linear.weight.device
+             )
+         else:
+             idx_tensor = indices
+
+         if idx_tensor.size(0):
+             random_weights = self._random_weights(
+                 self.linear.weight.size(0), indices.size(0)
+             )
+             # Make random col weights quiet so they don't introduce loud noise...
+             # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+             random_weights *= 0.1
+             self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+         else:
+             return
+
+     def forward(self, x):
+         if self.training and self.optimisers:
+             self.reset_rows(self.get_reset_indices(0))
+             self.reset_columns(self.get_reset_indices(1))
+         elif self.training and not self._warned_about_registration:
+             warnings.warn(
+                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                 stacklevel=2,
+             )
+             self._warned_about_registration = True
+
+         return self.linear(x)
+
+     def get_reset_indices(self, dim):
+         base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+         p = base_rate * self._get_multiplier()
+         if dim == 0:
+             if self.xglu:
+                 sample_space = self.linear.out_features // 2
+             else:
+                 sample_space = self.linear.out_features
+         elif dim == 1:
+             sample_space = self.linear.in_features
+         else:
+             raise ValueError("`dim` must be 0 or 1")
+
+         # Sample the indices
+         probs = torch.rand(sample_space, device=self.linear.weight.device)
+         mask = probs < p
+         if mask.any():
+             return torch.nonzero(mask).squeeze(-1)
+         else:
+             return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+     def _random_weights(self, rows, columns):
+         device = self.linear.weight.device
+         weights = self.linear.weight.data
+         stdv = 1.0 / math.sqrt(weights.size(1))
+         random_weights = torch.rand(rows, columns, device=device)
+         random_weights -= 0.5  # Range [-0.5, +0.5]
+         random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+         return random_weights
+
+     def _mean_value_weights(self):
+         """
+         Only used when self.xglu
+         """
+         weights = self.linear.weight.data
+         rows = weights.size(0)
+         if self.xglu:
+             return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+         else:
+             return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+     def _mean_gate_weights(self):
+         """
+         Only used when self.xglu
+         """
+         weights = self.linear.weight.data
+         rows = weights.size(0)
+         return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+     def _update_weights(
+         self,
+         indices: Iterable[int],
+         dim: int,
+         data: torch.Tensor,
+         optimisers: Union[
+             List[torch.optim.Optimizer], torch.optim.Optimizer, None
+         ] = None,
+     ):
+         if optimisers is None:
+             optimisers = []
+         if not isinstance(optimisers, list):
+             optimisers = [optimisers]
+
+         if not torch.is_tensor(indices):
+             idx_tensor = torch.as_tensor(
+                 list(indices), dtype=torch.long, device=self.linear.weight.device
+             )
+         else:
+             idx_tensor = indices
+
+         if idx_tensor.numel() == 0:
+             return
+
+         with torch.no_grad():
+             if dim == 0:
+                 self.linear.weight.data[idx_tensor] = data
+             elif dim == 1:
+                 self.linear.weight.data[:, idx_tensor] = data
+             else:
+                 raise ValueError("`dim` must be 0 or 1")
+
+         self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+         """
+         Zeroes out the optimizer state for the given indices in a single operation.
+         """
+         for optimiser in optimisers:
+             if param not in optimiser.state:
+                 continue
+             state = optimiser.state[param]
+
+             for _, buffer in state.items():
+                 if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                     # Vectorized zeroing
+                     if dim == 0:
+                         buffer[idx_tensor] = 0.0
+                     else:
+                         buffer[:, idx_tensor] = 0.0
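
The new module above only defines the layers, so a brief usage sketch may help. It relies only on the constructor signature and methods visible in the hunk; the import path is an assumption, since the file name is not shown in this diff:

    import torch

    # Hypothetical import path - the module's location in the package is not shown here.
    from broccoli_ml.linear import RecyclingLinear

    layer = RecyclingLinear(128, 512, row_recycling_rate=0.01, adaptive=True)
    optimiser = torch.optim.AdamW(layer.parameters(), lr=3e-4)

    # Recycling stays disabled (with a warning) until an optimiser is registered.
    # Register it before attaching any scheduler, or the recorded initial learning
    # rate will be wrong.
    layer.register_optimiser(optimiser)

    layer.train()
    y = layer(torch.randn(32, 128))  # rows/columns may be re-initialised on training forward passes

With adaptive=True, the recycling probability is scaled by the ratio of the current learning rate to the learning rate recorded at registration, so recycling fades out as a scheduler decays the learning rate.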
@@ -325,6 +325,8 @@ class FeedforwardBlock(nn.Module):
          activation=nn.ReLU,
          activation_kwargs=None,
          dropout=0.0,
+         inner_dropout=None,
+         outer_dropout=None,
          linear_module_up=nn.Linear,
          linear_module_down=nn.Linear,
          pre_norm=True,
@@ -338,6 +340,7 @@ class FeedforwardBlock(nn.Module):
          self.checkpoint = checkpoint
          self.residual_path = residual_path
          self.post_norm = post_norm
+         self.xglu = activation.__name__.endswith("GLU")
 
          if self.residual_path and (output_features < input_features):
              raise ValueError(
@@ -354,29 +357,63 @@ class FeedforwardBlock(nn.Module):
          else:
              self.activation = activation()
 
-         self.dropout = nn.Dropout(dropout)
+         self.inner_dropout = nn.Dropout(
+             inner_dropout if inner_dropout is not None else dropout
+         )
+         self.outer_dropout = nn.Dropout(
+             outer_dropout if outer_dropout is not None else dropout
+         )
 
          self.max_features = (
-             2 * ratio * output_features
-             if activation.__name__.endswith("GLU")
-             else ratio * output_features
+             2 * ratio * output_features if self.xglu else ratio * output_features
          )
 
+         self.linear_in = linear_module_up(input_features, self.max_features)
+         self.linear_out = linear_module_down(ratio * output_features, output_features)
+
          self.process = nn.Sequential(
              *[
                  nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                 linear_module_up(input_features, self.max_features),
+                 self.linear_in,
                  self.activation,
+                 self.inner_dropout,
                  nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                 linear_module_down(ratio * output_features, output_features),
-                 self.dropout,
+                 self.linear_out,
+                 self.outer_dropout,
              ]
          )
 
+         self.recycling_enabled = False
+         if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+             self.linear_out, "column_recycling_rate"
+         ):
+             self.recycling_enabled = True
+             self.master_recycling_rate = self.linear_in.row_recycling_rate
+             self.linear_in.row_recycling_rate = 0.0
+             self.linear_out.column_recycling_rate = 0.0
+             if (
+                 hasattr(self.linear_in, "column_recycling_rate")
+                 and self.linear_in.column_recycling_rate > 0
+             ) or (
+                 hasattr(self.linear_out, "row_recycling_rate")
+                 and self.linear_out.row_recycling_rate > 0
+             ):
+                 raise NotImplementedError(
+                     "At the moment this layer can only support recycling linear "
+                     "layers if the in layer resets only rows and the out layer "
+                     "resets only columns."
+                 )
+
          self.reset_parameters()
 
      def forward(self, x):
 
+         # Recycle weights if using recycling linear layers
+         if self.training and self.recycling_enabled:
+             indices = self.linear_out.get_reset_indices(1)
+             self.linear_in.reset_rows(indices)
+             self.linear_out.reset_columns(indices)
+
          if self.checkpoint:
              processed = checkpoint(self.process, x, use_reentrant=False)
          else:
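
Two things change for FeedforwardBlock here: dropout is split into an inner dropout (after the activation) and an outer dropout (after the down-projection), each falling back to the old dropout argument when left as None, and the block now detects a recycling up/down projection pair, takes over their recycling rates, and resets rows of the up projection together with the matching columns of the down projection. A minimal construction sketch, assuming the keyword names visible in the hunks; other constructor arguments and the import locations of FeedforwardBlock and RecyclingLinear are not shown in this diff:

    from functools import partial

    block = FeedforwardBlock(
        input_features=256,
        output_features=256,
        ratio=4,
        inner_dropout=0.1,  # applied between the activation and the down-projection
        outer_dropout=0.0,  # applied after the down-projection (old `dropout` behaviour)
        linear_module_up=partial(RecyclingLinear, row_recycling_rate=0.01),
        linear_module_down=partial(RecyclingLinear, column_recycling_rate=0.01),
    )
    # The up projection's row_recycling_rate becomes the block's master_recycling_rate;
    # how that rate is subsequently applied is handled by the block and not fully
    # shown in this hunk.

Note that the block only supports the combination where the up projection recycles rows and the down projection recycles columns; any other combination raises NotImplementedError, as the hunk shows.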
@@ -422,7 +459,9 @@ class TransformerBlock(nn.Module):
          ff_linear_module_up=None,
          ff_linear_module_down=None,
          msa_scaling="d",
-         mlp_dropout=0.0,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          identity_probability=0.0,
          causal=False,
@@ -484,7 +523,9 @@ class TransformerBlock(nn.Module):
              d_model,
              activation=activation,
              activation_kwargs=activation_kwargs,
-             dropout=mlp_dropout,
+             dropout=ff_dropout,
+             inner_dropout=ff_inner_dropout,
+             outer_dropout=ff_outer_dropout,
              linear_module_up=(
                  ff_linear_module_up
                  if ff_linear_module_up is not None
@@ -567,7 +608,9 @@ class TransformerEncoder(nn.Module):
          activation_kwargs: Optional[dict] = None,
          ff_linear_module_up=None,
          ff_linear_module_down=None,
-         mlp_dropout=0.0,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          stochastic_depth=0.0,
          causal=False,
@@ -629,7 +672,7 @@ class TransformerEncoder(nn.Module):
          else:
              self.absolute_position_embedding = None
 
-         self.mlp_dropout = mlp_dropout
+         self.mlp_dropout = ff_dropout
          self.msa_dropout = msa_dropout
          self.stochastic_depth = stochastic_depth
 
@@ -658,7 +701,9 @@ class TransformerEncoder(nn.Module):
                  ff_linear_module_up=ff_linear_module_up,
                  ff_linear_module_down=ff_linear_module_down,
                  msa_scaling=msa_scaling,
-                 mlp_dropout=mlp_dropout,
+                 ff_dropout=ff_dropout,
+                 ff_inner_dropout=ff_inner_dropout,
+                 ff_outer_dropout=ff_outer_dropout,
                  msa_dropout=msa_dropout,
                  identity_probability=self.stochastic_depth_probabilities[i],
                  causal=causal,
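
For downstream users, the practical effect of these hunks is a keyword rename: TransformerBlock and TransformerEncoder no longer accept mlp_dropout; the equivalent setting is now ff_dropout, with ff_inner_dropout and ff_outer_dropout giving finer control over where dropout is applied inside the feed-forward block. A hedged before/after sketch; required arguments such as d_model are not shown in these hunks and are omitted:

    # broccoli-ml 7.0.0 (old keyword)
    encoder_kwargs = dict(mlp_dropout=0.1, msa_dropout=0.1)

    # broccoli-ml 9.5.1 (renamed, with finer-grained options)
    encoder_kwargs = dict(
        ff_dropout=0.1,        # replaces mlp_dropout
        ff_inner_dropout=0.0,  # dropout after the feed-forward activation
        ff_outer_dropout=0.0,  # dropout after the feed-forward down-projection
        msa_dropout=0.1,
    )
    # encoder = TransformerEncoder(..., **encoder_kwargs)  # other required arguments omitted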
@@ -161,7 +161,9 @@ class ViTEncoder(nn.Module):
          transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
-         transformer_initial_ff_mlp_dropout=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
          transformer_post_norm=False,
@@ -178,7 +180,9 @@ class ViTEncoder(nn.Module):
          transformer_ff_linear_module_up=None,
          transformer_ff_linear_module_down=None,
          transformer_msa_scaling="d",
-         transformer_mlp_dropout=0.0,
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
@@ -333,7 +337,9 @@ class ViTEncoder(nn.Module):
              ff_linear_module_up=transformer_ff_linear_module_up,
              ff_linear_module_down=transformer_ff_linear_module_down,
              msa_scaling=transformer_msa_scaling,
-             mlp_dropout=transformer_mlp_dropout,
+             ff_dropout=transformer_ff_dropout,
+             ff_inner_dropout=transformer_ff_inner_dropout,
+             ff_outer_dropout=transformer_ff_outer_dropout,
              msa_dropout=transformer_msa_dropout,
              stochastic_depth=transformer_stochastic_depth,
              causal=False,
@@ -357,9 +363,21 @@ class ViTEncoder(nn.Module):
              activation_kwargs=transformer_activation_kwargs,
              dropout=(
                  # First truthy assigned value
-                 transformer_initial_ff_mlp_dropout
-                 if transformer_initial_ff_mlp_dropout is not None
-                 else transformer_mlp_dropout
+                 transformer_initial_ff_dropout
+                 if transformer_initial_ff_dropout is not None
+                 else transformer_ff_dropout
+             ),
+             inner_dropout=(
+                 # First truthy assigned value
+                 transformer_initial_ff_inner_dropout
+                 if transformer_initial_ff_inner_dropout is not None
+                 else transformer_ff_inner_dropout
+             ),
+             outer_dropout=(
+                 # First truthy assigned value
+                 transformer_initial_ff_outer_dropout
+                 if transformer_initial_ff_outer_dropout is not None
+                 else transformer_ff_outer_dropout
              ),
              linear_module_up=(
                  # First truthy assigned value
@@ -441,7 +459,9 @@ class ViT(nn.Module):
          transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
-         transformer_initial_ff_mlp_dropout=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
          transformer_post_norm=False,
@@ -458,7 +478,9 @@ class ViT(nn.Module):
          transformer_ff_linear_module_up=None,
          transformer_ff_linear_module_down=None,
          transformer_msa_scaling="d",
-         transformer_mlp_dropout=0.0,
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
@@ -508,7 +530,9 @@ class ViT(nn.Module):
              transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
              transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
              transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-             transformer_initial_ff_mlp_dropout=transformer_initial_ff_mlp_dropout,
+             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+             transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+             transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
              transformer_pre_norm=transformer_pre_norm,
              transformer_normformer=transformer_normformer,
              transformer_post_norm=transformer_post_norm,
@@ -525,7 +549,9 @@ class ViT(nn.Module):
              transformer_ff_linear_module_up=transformer_ff_linear_module_up,
              transformer_ff_linear_module_down=transformer_ff_linear_module_down,
              transformer_msa_scaling=transformer_msa_scaling,
-             transformer_mlp_dropout=transformer_mlp_dropout,
+             transformer_ff_dropout=transformer_ff_dropout,
+             transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+             transformer_ff_outer_dropout=transformer_ff_outer_dropout,
              transformer_msa_dropout=transformer_msa_dropout,
              transformer_stochastic_depth=transformer_stochastic_depth,
              transformer_checkpoint_ff=transformer_checkpoint_ff,
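
The same rename propagates through the vision models: ViTEncoder and ViT expose it via their transformer_-prefixed keywords, including separate overrides for the initial feed-forward block. A short sketch of the affected keywords only (other constructor arguments are not shown in these hunks and are omitted):

    vit_kwargs = dict(
        transformer_ff_dropout=0.1,           # was transformer_mlp_dropout
        transformer_ff_inner_dropout=0.0,
        transformer_ff_outer_dropout=0.0,
        transformer_initial_ff_dropout=None,  # was transformer_initial_ff_mlp_dropout;
                                              # None falls back to transformer_ff_dropout
    )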
@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "7.0.0"
+ version = "9.5.1"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}
@@ -1,138 +0,0 @@
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
-
-
- class SpectralNormLinear(nn.Module):
-     """
-     Inspired by Apple's Spectral Normed Linear Layers
-     (https://github.com/apple/ml-sigma-reparam)
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True):
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.use_bias = bias
-
-         self.weights = None
-
-         # Define the bias vector as a learnable parameter if required.
-         if self.use_bias:
-             self.bias = nn.Parameter(torch.empty(out_features))
-         else:
-             # If no bias, register it as None.
-             # This is important so that PyTorch doesn't complain when saving/loading the model.
-             self.register_parameter("bias", None)
-
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         weights = torch.empty(self.out_features, self.in_features)
-         stdv = 1.0 / math.sqrt(self.in_features)
-         nn.init.uniform_(weights, a=-stdv, b=stdv)
-         if self.use_bias:
-             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-         self.weights = SigmaReparamTensor(weights)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.linear(x, self.weights(), self.bias)
-
-     def __repr__(self) -> str:
-         # Optional: A nice representation for printing the module.
-         return (
-             f"SpectralNormFeedForward(in_features={self.in_features},"
-             f"out_features={self.out_features}, bias={self.use_bias})"
-         )
-
-
- class AnchoredLinear(nn.Module):
-     """
-     ...
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True):
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.use_bias = bias
-
-         self.weights = None
-
-         # Define the bias vector as a learnable parameter if required.
-         if self.use_bias:
-             self.bias = nn.Parameter(torch.empty(out_features))
-         else:
-             # If no bias, register it as None.
-             # This is important so that PyTorch doesn't complain when saving/loading the model.
-             self.register_parameter("bias", None)
-
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         weights = torch.empty(self.out_features, self.in_features)
-         stdv = 1.0 / math.sqrt(self.in_features)
-         nn.init.uniform_(weights, a=-stdv, b=stdv)
-         if self.use_bias:
-             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-         self.weights = AnchoredReparamTensor(weights)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.linear(x, self.weights(), self.bias)
-
-     def __repr__(self) -> str:
-         # Optional: A nice representation for printing the module.
-         return (
-             f"AnchoredLinear(in_features={self.in_features},"
-             f"out_features={self.out_features}, bias={self.use_bias})"
-         )
-
-
- class WeightNormedLinear(nn.Module):
-     """
-     ...
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True):
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.use_bias = bias
-
-         self.weights = None
-
-         # Define the bias vector as a learnable parameter if required.
-         if self.use_bias:
-             self.bias = nn.Parameter(torch.empty(out_features))
-         else:
-             # If no bias, register it as None.
-             # This is important so that PyTorch doesn't complain when saving/loading the model.
-             self.register_parameter("bias", None)
-
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         weights = torch.empty(self.out_features, self.in_features)
-         stdv = 1.0 / math.sqrt(self.in_features)
-         nn.init.uniform_(weights, a=-stdv, b=stdv)
-         if self.use_bias:
-             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-         self.weights = NormReparamTensor(weights)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.linear(x, self.weights(), self.bias)
-
-     def __repr__(self) -> str:
-         return (
-             f"WeightNormedLinear(in_features={self.in_features},"
-             f"out_features={self.out_features}, bias={self.use_bias})"
-         )