broccoli-ml 7.0.0__tar.gz → 9.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 7.0.0
+ Version: 9.1.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -0,0 +1,314 @@
+ import math
+ import random
+ import warnings
+ from typing import Union, List, Iterable
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
+
+
+ class SpectralNormLinear(nn.Module):
+     """
+     Inspired by Apple's Spectral Normed Linear Layers
+     (https://github.com/apple/ml-sigma-reparam)
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = SigmaReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         # Optional: A nice representation for printing the module.
+         return (
+             f"SpectralNormLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class AnchoredLinear(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = AnchoredReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         # Optional: A nice representation for printing the module.
+         return (
+             f"AnchoredLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class WeightNormedLinear(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = NormReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         return (
+             f"WeightNormedLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class RecyclingLinear(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         bias: bool = True,
+         row_recycling_rate: float = 0.0,
+         column_recycling_rate: float = 0.0,
+         adaptive=False,
+     ):
+         super().__init__()
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         self.row_recycling_rate = row_recycling_rate
+         self.column_recycling_rate = column_recycling_rate
+         self.adaptive = adaptive
+         self.optimisers = []
+         self.initial_learning_rates = []
+         self._warned_about_registration = False
+
+     def register_optimiser(self, optimiser: torch.optim.Optimizer):
+         self.optimisers.append(optimiser)
+         self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+         if self.initial_learning_rates[-1] == 0.0:
+             warnings.warn(
+                 "Learning rate of registered optimiser was 0.0 - make sure "
+                 "you haven't initialised a scheduler before registering the "
+                 "optimiser",
+                 stacklevel=2,
+             )
+
+     def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+         for group in optimiser.param_groups:
+             for param in group["params"]:
+                 if param is self.linear.weight:
+                     return group["lr"]
+
+     def _get_multiplier(self):
+         if not self.adaptive or not self.optimisers:
+             return 1.0
+         else:
+             init = self.initial_learning_rates
+             current = [self._get_learning_rate(o) for o in self.optimisers]
+             pairs = zip(current, init, strict=True)
+             multipliers = [a / b for a, b in pairs if b != 0.0]
+             return min(multipliers) if multipliers else 0.0
+
+     def forward(self, x):
+         multiplier = self._get_multiplier()
+         col_recycling_rate = self.column_recycling_rate * multiplier
+         row_recycling_rate = self.row_recycling_rate * multiplier
+
+         if self.training and self.optimisers:
+
+             if row_recycling_rate > 0:
+                 probs = torch.rand(self.linear.out_features, device=x.device)
+                 mask = probs < row_recycling_rate
+                 if mask.any():
+                     # nonzero returns [N, 1], squeeze to get [N]
+                     indices = torch.nonzero(mask).squeeze(-1)
+                     self.reset_rows(indices, self.optimisers)
+
+             if col_recycling_rate > 0:
+                 probs = torch.rand(self.linear.in_features, device=x.device)
+                 mask = probs < col_recycling_rate
+                 if mask.any():
+                     indices = torch.nonzero(mask).squeeze(-1)
+                     self.reset_columns(indices, self.optimisers)
+
+         elif self.training and not self._warned_about_registration:
+             warnings.warn(
+                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                 stacklevel=2,
+             )
+             self._warned_about_registration = True
+
+         return self.linear(x)
+
+     def reset_rows(
+         self,
+         indices: Iterable[int],
+         optimisers: Union[
+             List[torch.optim.Optimizer], torch.optim.Optimizer, None
+         ] = None,
+     ):
+         """
+         Update some of the weight rows to be equal to the mean of all weight rows.
+         """
+         if optimisers is None:
+             optimisers = []
+         if not isinstance(optimisers, list):
+             optimisers = [optimisers]
+
+         device = self.linear.weight.device
+         idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+         if idx_tensor.numel() == 0:
+             return
+
+         with torch.no_grad():
+             # Calculate mean of all rows including the rows to be reset
+             mean_vector = self.linear.weight.data.mean(
+                 dim=0, keepdim=True
+             ) # [1, in_features]
+             update_data = mean_vector.expand(idx_tensor.size(0), -1)
+             self.linear.weight.data[idx_tensor] = update_data
+
+             if self.linear.bias is not None:
+                 self.linear.bias.data[idx_tensor] = 0.0
+
+         self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
+         if self.linear.bias is not None:
+             self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
+
+     def reset_columns(
+         self,
+         indices: Iterable[int],
+         optimisers: Union[
+             List[torch.optim.Optimizer], torch.optim.Optimizer, None
+         ] = None,
+     ):
+         """
+         Update some of the weight columns to be random as though reinitialised.
+         """
+         if optimisers is None:
+             optimisers = []
+         if not isinstance(optimisers, list):
+             optimisers = [optimisers]
+
+         device = self.linear.weight.device
+         idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+         if idx_tensor.numel() == 0:
+             return
+
+         with torch.no_grad():
+             # 1. Generate Random Columns
+             # Shape: [out_features, N_indices]
+             weights = self.linear.weight.data
+             stdv = 1.0 / math.sqrt(weights.size(1))
+
+             # Generate [Rows, N] block
+             random_weights = torch.rand(
+                 weights.size(0), idx_tensor.size(0), device=device
+             )
+             random_weights = (random_weights - 0.5) * 2.0 * stdv
+
+             # 2. Update Weights (One-shot)
+             # We assign into the columns specified by idx_tensor
+             self.linear.weight.data[:, idx_tensor] = random_weights
+
+         # 3. Update Optimizers
+         # Bias is untouched by column resets (bias is shape [Out], cols are [In])
+         self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+
+     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+         """
+         Zeroes out the optimizer state for the given indices in a single operation.
+         """
+         for optimiser in optimisers:
+             if param not in optimiser.state:
+                 continue
+             state = optimiser.state[param]
+
+             for _, buffer in state.items():
+                 if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                     # Vectorized zeroing
+                     if dim == 0:
+                         buffer[idx_tensor] = 0.0
+                     else:
+                         buffer[:, idx_tensor] = 0.0
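Editorial note, not part of the package diff: a minimal usage sketch for the new RecyclingLinear layer, assuming it is imported from wherever this module lives in the package (the file path is not shown in this diff). Recycling only runs while the layer is in training mode and after at least one optimiser has been registered, and register_optimiser should be called before any learning-rate scheduler is attached (the class warns if the registered learning rate is already 0.0).

    import torch

    # Hypothetical import; the diff does not show this module's path within the package.
    # from broccoli_ml.<module> import RecyclingLinear

    layer = RecyclingLinear(64, 32, row_recycling_rate=0.01, column_recycling_rate=0.01)
    optimiser = torch.optim.Adam(layer.parameters(), lr=1e-3)
    layer.register_optimiser(optimiser)  # register before attaching a scheduler

    layer.train()
    x = torch.randn(8, 64)
    loss = layer(x).pow(2).mean()  # row/column recycling happens inside forward() while training
    loss.backward()
    optimiser.step()
    optimiser.zero_grad()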
@@ -325,6 +325,8 @@ class FeedforwardBlock(nn.Module):
          activation=nn.ReLU,
          activation_kwargs=None,
          dropout=0.0,
+         inner_dropout=None,
+         outer_dropout=None,
          linear_module_up=nn.Linear,
          linear_module_down=nn.Linear,
          pre_norm=True,
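Editorial note, not part of the package diff: as the later hunks show, inner_dropout is applied after the activation and outer_dropout after the down-projection, and each falls back to the shared dropout value when left as None. A minimal sketch of that fallback rule, using a hypothetical helper name:

    def resolve_dropout(dropout=0.0, inner_dropout=None, outer_dropout=None):
        # Each specific rate falls back to the shared `dropout` when left unset (None).
        inner = inner_dropout if inner_dropout is not None else dropout
        outer = outer_dropout if outer_dropout is not None else dropout
        return inner, outer

    assert resolve_dropout(0.1) == (0.1, 0.1)
    assert resolve_dropout(0.1, inner_dropout=0.0) == (0.0, 0.1)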
@@ -354,7 +356,12 @@ class FeedforwardBlock(nn.Module):
          else:
              self.activation = activation()

-         self.dropout = nn.Dropout(dropout)
+         self.inner_dropout = nn.Dropout(
+             inner_dropout if inner_dropout is not None else dropout
+         )
+         self.outer_dropout = nn.Dropout(
+             outer_dropout if outer_dropout is not None else dropout
+         )

          self.max_features = (
              2 * ratio * output_features
@@ -367,9 +374,10 @@ class FeedforwardBlock(nn.Module):
                  nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                  linear_module_up(input_features, self.max_features),
                  self.activation,
+                 self.inner_dropout,
                  nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
                  linear_module_down(ratio * output_features, output_features),
-                 self.dropout,
+                 self.outer_dropout,
              ]
          )

@@ -422,7 +430,9 @@ class TransformerBlock(nn.Module):
          ff_linear_module_up=None,
          ff_linear_module_down=None,
          msa_scaling="d",
-         mlp_dropout=0.0,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          identity_probability=0.0,
          causal=False,
@@ -484,7 +494,9 @@ class TransformerBlock(nn.Module):
              d_model,
              activation=activation,
              activation_kwargs=activation_kwargs,
-             dropout=mlp_dropout,
+             dropout=ff_dropout,
+             inner_dropout=ff_inner_dropout,
+             outer_dropout=ff_outer_dropout,
              linear_module_up=(
                  ff_linear_module_up
                  if ff_linear_module_up is not None
@@ -567,7 +579,9 @@ class TransformerEncoder(nn.Module):
          activation_kwargs: Optional[dict] = None,
          ff_linear_module_up=None,
          ff_linear_module_down=None,
-         mlp_dropout=0.0,
+         ff_dropout=0.0,
+         ff_inner_dropout=0.0,
+         ff_outer_dropout=0.0,
          msa_dropout=0.0,
          stochastic_depth=0.0,
          causal=False,
@@ -629,7 +643,7 @@ class TransformerEncoder(nn.Module):
          else:
              self.absolute_position_embedding = None

-         self.mlp_dropout = mlp_dropout
+         self.mlp_dropout = ff_dropout
          self.msa_dropout = msa_dropout
          self.stochastic_depth = stochastic_depth

@@ -658,7 +672,9 @@ class TransformerEncoder(nn.Module):
                    ff_linear_module_up=ff_linear_module_up,
                    ff_linear_module_down=ff_linear_module_down,
                    msa_scaling=msa_scaling,
-                   mlp_dropout=mlp_dropout,
+                   ff_dropout=ff_dropout,
+                   ff_inner_dropout=ff_inner_dropout,
+                   ff_outer_dropout=ff_outer_dropout,
                    msa_dropout=msa_dropout,
                    identity_probability=self.stochastic_depth_probabilities[i],
                    causal=causal,
@@ -161,7 +161,9 @@ class ViTEncoder(nn.Module):
          transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
-         transformer_initial_ff_mlp_dropout=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
          transformer_post_norm=False,
@@ -178,7 +180,9 @@ class ViTEncoder(nn.Module):
          transformer_ff_linear_module_up=None,
          transformer_ff_linear_module_down=None,
          transformer_msa_scaling="d",
-         transformer_mlp_dropout=0.0,
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
@@ -333,7 +337,9 @@ class ViTEncoder(nn.Module):
              ff_linear_module_up=transformer_ff_linear_module_up,
              ff_linear_module_down=transformer_ff_linear_module_down,
              msa_scaling=transformer_msa_scaling,
-             mlp_dropout=transformer_mlp_dropout,
+             ff_dropout=transformer_ff_dropout,
+             ff_inner_dropout=transformer_ff_inner_dropout,
+             ff_outer_dropout=transformer_ff_outer_dropout,
              msa_dropout=transformer_msa_dropout,
              stochastic_depth=transformer_stochastic_depth,
              causal=False,
@@ -357,9 +363,21 @@ class ViTEncoder(nn.Module):
              activation_kwargs=transformer_activation_kwargs,
              dropout=(
                  # First truthy assigned value
-                 transformer_initial_ff_mlp_dropout
-                 if transformer_initial_ff_mlp_dropout is not None
-                 else transformer_mlp_dropout
+                 transformer_initial_ff_dropout
+                 if transformer_initial_ff_dropout is not None
+                 else transformer_ff_dropout
+             ),
+             inner_dropout=(
+                 # First truthy assigned value
+                 transformer_initial_ff_inner_dropout
+                 if transformer_initial_ff_inner_dropout is not None
+                 else transformer_ff_inner_dropout
+             ),
+             outer_dropout=(
+                 # First truthy assigned value
+                 transformer_initial_ff_outer_dropout
+                 if transformer_initial_ff_outer_dropout is not None
+                 else transformer_ff_outer_dropout
              ),
              linear_module_up=(
                  # First truthy assigned value
@@ -441,7 +459,9 @@ class ViT(nn.Module):
          transformer_initial_ff_residual_path=True,
          transformer_initial_ff_linear_module_up=None,
          transformer_initial_ff_linear_module_down=None,
-         transformer_initial_ff_mlp_dropout=None,
+         transformer_initial_ff_dropout=None,
+         transformer_initial_ff_inner_dropout=None,
+         transformer_initial_ff_outer_dropout=None,
          transformer_pre_norm=True,
          transformer_normformer=False,
          transformer_post_norm=False,
@@ -458,7 +478,9 @@ class ViT(nn.Module):
          transformer_ff_linear_module_up=None,
          transformer_ff_linear_module_down=None,
          transformer_msa_scaling="d",
-         transformer_mlp_dropout=0.0,
+         transformer_ff_dropout=0.0,
+         transformer_ff_inner_dropout=0.0,
+         transformer_ff_outer_dropout=0.0,
          transformer_msa_dropout=0.1,
          transformer_stochastic_depth=0.1,
          transformer_checkpoint_ff=True,
@@ -508,7 +530,9 @@ class ViT(nn.Module):
              transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
              transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
              transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-             transformer_initial_ff_mlp_dropout=transformer_initial_ff_mlp_dropout,
+             transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+             transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+             transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
              transformer_pre_norm=transformer_pre_norm,
              transformer_normformer=transformer_normformer,
              transformer_post_norm=transformer_post_norm,
@@ -525,7 +549,9 @@ class ViT(nn.Module):
              transformer_ff_linear_module_up=transformer_ff_linear_module_up,
              transformer_ff_linear_module_down=transformer_ff_linear_module_down,
              transformer_msa_scaling=transformer_msa_scaling,
-             transformer_mlp_dropout=transformer_mlp_dropout,
+             transformer_ff_dropout=transformer_ff_dropout,
+             transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+             transformer_ff_outer_dropout=transformer_ff_outer_dropout,
              transformer_msa_dropout=transformer_msa_dropout,
              transformer_stochastic_depth=transformer_stochastic_depth,
              transformer_checkpoint_ff=transformer_checkpoint_ff,
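Editorial note, not part of the package diff: the hunks above rename the mlp_dropout family of keywords to ff_dropout (and likewise transformer_mlp_dropout and transformer_initial_ff_mlp_dropout), so configuration dictionaries written against 7.0.0 need their keys updated. A sketch using a hypothetical helper, covering only the renames visible in this diff:

    # Hypothetical helper, not part of the package.
    RENAMED_KWARGS = {
        "mlp_dropout": "ff_dropout",
        "transformer_mlp_dropout": "transformer_ff_dropout",
        "transformer_initial_ff_mlp_dropout": "transformer_initial_ff_dropout",
    }

    def migrate_config(config: dict) -> dict:
        # Rename removed keywords; leave everything else untouched.
        return {RENAMED_KWARGS.get(key, key): value for key, value in config.items()}

    assert migrate_config({"transformer_mlp_dropout": 0.1}) == {"transformer_ff_dropout": 0.1}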
@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "7.0.0"
+ version = "9.1.0"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}
@@ -1,138 +0,0 @@
- import math
- import torch
- from torch import nn
- from torch.nn import functional as F
-
- from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
-
-
- class SpectralNormLinear(nn.Module):
-     """
-     Inspired by Apple's Spectral Normed Linear Layers
-     (https://github.com/apple/ml-sigma-reparam)
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True):
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.use_bias = bias
-
-         self.weights = None
-
-         # Define the bias vector as a learnable parameter if required.
-         if self.use_bias:
-             self.bias = nn.Parameter(torch.empty(out_features))
-         else:
-             # If no bias, register it as None.
-             # This is important so that PyTorch doesn't complain when saving/loading the model.
-             self.register_parameter("bias", None)
-
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         weights = torch.empty(self.out_features, self.in_features)
-         stdv = 1.0 / math.sqrt(self.in_features)
-         nn.init.uniform_(weights, a=-stdv, b=stdv)
-         if self.use_bias:
-             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-         self.weights = SigmaReparamTensor(weights)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.linear(x, self.weights(), self.bias)
-
-     def __repr__(self) -> str:
-         # Optional: A nice representation for printing the module.
-         return (
-             f"SpectralNormFeedForward(in_features={self.in_features},"
-             f"out_features={self.out_features}, bias={self.use_bias})"
-         )
-
-
- class AnchoredLinear(nn.Module):
-     """
-     ...
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True):
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.use_bias = bias
-
-         self.weights = None
-
-         # Define the bias vector as a learnable parameter if required.
-         if self.use_bias:
-             self.bias = nn.Parameter(torch.empty(out_features))
-         else:
-             # If no bias, register it as None.
-             # This is important so that PyTorch doesn't complain when saving/loading the model.
-             self.register_parameter("bias", None)
-
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         weights = torch.empty(self.out_features, self.in_features)
-         stdv = 1.0 / math.sqrt(self.in_features)
-         nn.init.uniform_(weights, a=-stdv, b=stdv)
-         if self.use_bias:
-             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-         self.weights = AnchoredReparamTensor(weights)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.linear(x, self.weights(), self.bias)
-
-     def __repr__(self) -> str:
-         # Optional: A nice representation for printing the module.
-         return (
-             f"AnchoredLinear(in_features={self.in_features},"
-             f"out_features={self.out_features}, bias={self.use_bias})"
-         )
-
-
- class WeightNormedLinear(nn.Module):
-     """
-     ...
-     """
-
-     def __init__(self, in_features: int, out_features: int, bias: bool = True):
-         super().__init__()
-         self.in_features = in_features
-         self.out_features = out_features
-         self.use_bias = bias
-
-         self.weights = None
-
-         # Define the bias vector as a learnable parameter if required.
-         if self.use_bias:
-             self.bias = nn.Parameter(torch.empty(out_features))
-         else:
-             # If no bias, register it as None.
-             # This is important so that PyTorch doesn't complain when saving/loading the model.
-             self.register_parameter("bias", None)
-
-         self.reset_parameters()
-
-     def reset_parameters(self) -> None:
-         weights = torch.empty(self.out_features, self.in_features)
-         stdv = 1.0 / math.sqrt(self.in_features)
-         nn.init.uniform_(weights, a=-stdv, b=stdv)
-         if self.use_bias:
-             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
-             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-             nn.init.uniform_(self.bias, -bound, bound)
-         self.weights = NormReparamTensor(weights)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return F.linear(x, self.weights(), self.bias)
-
-     def __repr__(self) -> str:
-         return (
-             f"WeightNormedLinear(in_features={self.in_features},"
-             f"out_features={self.out_features}, bias={self.use_bias})"
-         )