broccoli-ml 9.0.0__tar.gz → 9.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 9.0.0
+Version: 9.2.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -1,5 +1,6 @@
 import math
 import random
+import warnings
 from typing import Union, List, Iterable

 import torch
@@ -149,34 +150,73 @@ class RecyclingLinear(nn.Module):
         bias: bool = True,
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
+        adaptive=False,
     ):
         super().__init__()
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
         self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False

     def register_optimiser(self, optimiser: torch.optim.Optimizer):
         self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0

     def forward(self, x):
+        multiplier = self._get_multiplier()
+        col_recycling_rate = self.column_recycling_rate * multiplier
+        row_recycling_rate = self.row_recycling_rate * multiplier
+
         if self.training and self.optimisers:

-            if self.row_recycling_rate > 0:
+            if row_recycling_rate > 0:
                 probs = torch.rand(self.linear.out_features, device=x.device)
-                mask = probs < self.row_recycling_rate
+                mask = probs < row_recycling_rate
                 if mask.any():
                     # nonzero returns [N, 1], squeeze to get [N]
                     indices = torch.nonzero(mask).squeeze(-1)
                     self.reset_rows(indices, self.optimisers)

-            if self.column_recycling_rate > 0:
+            if col_recycling_rate > 0:
                 probs = torch.rand(self.linear.in_features, device=x.device)
-                mask = probs < self.column_recycling_rate
+                mask = probs < col_recycling_rate
                 if mask.any():
                     indices = torch.nonzero(mask).squeeze(-1)
                     self.reset_columns(indices, self.optimisers)

+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
         return self.linear(x)

     def reset_rows(
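
Note on the new adaptive flag: _get_multiplier() returns the smallest ratio of current to initial learning rate across the registered optimisers, and both recycling rates are scaled by it on every forward pass, so weight recycling decays in step with the learning-rate schedule. A minimal usage sketch (the import path is an assumption; adjust to wherever RecyclingLinear lives in broccoli_ml):

    import torch
    from broccoli_ml import RecyclingLinear  # assumed import path

    # Recycle ~1% of output rows per training step, scaled by LR decay.
    layer = RecyclingLinear(64, 32, row_recycling_rate=0.01, adaptive=True)
    optimiser = torch.optim.SGD(layer.parameters(), lr=0.1)

    # Register before attaching any warmup scheduler, so the recorded
    # initial learning rate is non-zero (the layer warns otherwise).
    layer.register_optimiser(optimiser)

    layer.train()
    _ = layer(torch.randn(8, 64))  # effective row rate: 0.01 * (0.1 / 0.1) = 0.01

    optimiser.param_groups[0]["lr"] = 0.05  # e.g. a scheduler has halved the LR
    _ = layer(torch.randn(8, 64))  # effective row rate: 0.01 * (0.05 / 0.1) = 0.005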
@@ -340,6 +340,7 @@ class FeedforwardBlock(nn.Module):
         self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
             raise ValueError(
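
The cached self.xglu flag relies on the class-name convention that gated-linear-unit activations end in "GLU", in which case the up-projection must produce twice the hidden width. A quick illustration with standard PyTorch activations:

    from torch import nn

    print(nn.GLU.__name__.endswith("GLU"))   # True  - gated, doubles linear_in
    print(nn.GELU.__name__.endswith("GLU"))  # False - "GELU" ends in "ELU"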
@@ -364,27 +365,63 @@ class FeedforwardBlock(nn.Module):
         )

         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * ratio * output_features if self.xglu else ratio * output_features
         )

+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(ratio * output_features, output_features)
+
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                linear_module_up(input_features, self.max_features),
+                self.linear_in,
                 self.activation,
                 self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                linear_module_down(ratio * output_features, output_features),
+                self.linear_out,
                 self.outer_dropout,
             ]
         )

+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if hasattr(self.linear_in, "column_recycling_rate") or hasattr(
+                self.linear_out, "row_recycling_rate"
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()

     def forward(self, x):

+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            multiplier = self.linear_in._get_multiplier()
+            rate = self.master_recycling_rate * multiplier
+            if rate > 0:
+                probs = torch.rand(self.linear_out.in_features, device=x.device)
+                mask = probs < rate
+                if mask.any():
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.linear_out.reset_columns(indices, self.linear_out.optimisers)
+                    if self.xglu:
+                        indices_in = torch.cat(
+                            [indices, indices + self.linear_out.in_features]
+                        )
+                        self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
+                    else:
+                        self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
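
The torch.cat index arithmetic above handles the gated case: with a GLU-family activation, linear_in emits 2*H features that the GLU splits into value and gate halves of width H, so hidden unit i is produced by rows i and i + H of linear_in. Recycling hidden unit i therefore means resetting column i of linear_out plus both of those rows. A worked example of the index mapping (plain torch, no project code needed):

    import torch

    H = 4                             # linear_out.in_features (width after the GLU)
    indices = torch.tensor([1, 3])    # hidden units selected for recycling

    # Value half lives in rows [0, H), gate half in rows [H, 2H) of linear_in.
    indices_in = torch.cat([indices, indices + H])
    print(indices_in)                 # tensor([1, 3, 5, 7])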
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "9.0.0"
+version = "9.2.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}