broccoli-ml 9.0.0__tar.gz → 9.1.0__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.0.0
3
+ Version: 9.1.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -1,5 +1,6 @@
1
1
  import math
2
2
  import random
3
+ import warnings
3
4
  from typing import Union, List, Iterable
4
5
 
5
6
  import torch
@@ -149,34 +150,73 @@ class RecyclingLinear(nn.Module):
149
150
  bias: bool = True,
150
151
  row_recycling_rate: float = 0.0,
151
152
  column_recycling_rate: float = 0.0,
153
+ adaptive=False,
152
154
  ):
153
155
  super().__init__()
154
156
  self.linear = nn.Linear(in_features, out_features, bias=bias)
155
157
  self.row_recycling_rate = row_recycling_rate
156
158
  self.column_recycling_rate = column_recycling_rate
159
+ self.adaptive = adaptive
157
160
  self.optimisers = []
161
+ self.initial_learning_rates = []
162
+ self._warned_about_registration = False
158
163
 
159
164
  def register_optimiser(self, optimiser: torch.optim.Optimizer):
160
165
  self.optimisers.append(optimiser)
166
+ self.initial_learning_rates.append(self._get_learning_rate(optimiser))
167
+ if self.initial_learning_rates[-1] == 0.0:
168
+ warnings.warn(
169
+ "Learning rate of registered optimiser was 0.0 - make sure "
170
+ "you haven't initialised a scheduler before registering the "
171
+ "optimiser",
172
+ stacklevel=2,
173
+ )
174
+
175
+ def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
176
+ for group in optimiser.param_groups:
177
+ for param in group["params"]:
178
+ if param is self.linear.weight:
179
+ return group["lr"]
180
+
181
+ def _get_multiplier(self):
182
+ if not self.adaptive or not self.optimisers:
183
+ return 1.0
184
+ else:
185
+ init = self.initial_learning_rates
186
+ current = [self._get_learning_rate(o) for o in self.optimisers]
187
+ pairs = zip(current, init, strict=True)
188
+ multipliers = [a / b for a, b in pairs if b != 0.0]
189
+ return min(multipliers) if multipliers else 0.0
161
190
 
162
191
  def forward(self, x):
192
+ multiplier = self._get_multiplier()
193
+ col_recycling_rate = self.column_recycling_rate * multiplier
194
+ row_recycling_rate = self.row_recycling_rate * multiplier
195
+
163
196
  if self.training and self.optimisers:
164
197
 
165
- if self.row_recycling_rate > 0:
198
+ if row_recycling_rate > 0:
166
199
  probs = torch.rand(self.linear.out_features, device=x.device)
167
- mask = probs < self.row_recycling_rate
200
+ mask = probs < row_recycling_rate
168
201
  if mask.any():
169
202
  # nonzero returns [N, 1], squeeze to get [N]
170
203
  indices = torch.nonzero(mask).squeeze(-1)
171
204
  self.reset_rows(indices, self.optimisers)
172
205
 
173
- if self.column_recycling_rate > 0:
206
+ if col_recycling_rate > 0:
174
207
  probs = torch.rand(self.linear.in_features, device=x.device)
175
- mask = probs < self.column_recycling_rate
208
+ mask = probs < col_recycling_rate
176
209
  if mask.any():
177
210
  indices = torch.nonzero(mask).squeeze(-1)
178
211
  self.reset_columns(indices, self.optimisers)
179
212
 
213
+ elif self.training and not self._warned_about_registration:
214
+ warnings.warn(
215
+ "RecyclingLinear: No optimiser registered. Recycling disabled.",
216
+ stacklevel=2,
217
+ )
218
+ self._warned_about_registration = True
219
+
180
220
  return self.linear(x)
181
221
 
182
222
  def reset_rows(
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.0.0"
3
+ version = "9.1.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes