broccoli-ml 9.2.2__tar.gz → 9.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 9.2.2
+Version: 9.4.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -151,11 +151,13 @@ class RecyclingLinear(nn.Module):
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
         adaptive=False,
+        xglu=False,
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.bias = bias
+        self.xglu = xglu
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
@@ -191,28 +193,51 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0

-    def forward(self, x):
-        multiplier = self._get_multiplier()
-        col_recycling_rate = self.column_recycling_rate * multiplier
-        row_recycling_rate = self.row_recycling_rate * multiplier
+    def reset_rows(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

-        if self.training and self.optimisers:
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                idx_tensor.size(0), self.linear.weight.size(1)
+            )
+            if self.xglu:
+                gate_indices = idx_tensor
+                value_indices = idx_tensor + (self.linear.out_features // 2)
+                self._update_weights(value_indices, 0, random_weights, self.optimisers)
+                centred_gate_weights = self._mean_gate_weights()
+                centred_gate_weights = centred_gate_weights.expand(idx_tensor.size(0), -1)
+                self._update_weights(
+                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim=0
+                )
+            else:
+                self._update_weights(idx_tensor, 0, random_weights, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

-            if row_recycling_rate > 0:
-                probs = torch.rand(self.linear.out_features, device=x.device)
-                mask = probs < row_recycling_rate
-                if mask.any():
-                    # nonzero returns [N, 1], squeeze to get [N]
-                    indices = torch.nonzero(mask).squeeze(-1)
-                    self.reset_rows(indices, self.optimisers)
-
-            if col_recycling_rate > 0:
-                probs = torch.rand(self.linear.in_features, device=x.device)
-                mask = probs < col_recycling_rate
-                if mask.any():
-                    indices = torch.nonzero(mask).squeeze(-1)
-                    self.reset_columns(indices, self.optimisers)
+        if idx_tensor.size(0):
+            zeros = torch.zeros(
+                (self.linear.weight.size(0), idx_tensor.size(0)),
+                device=self.linear.weight.device,
+            )
+            self._update_weights(idx_tensor, 1, zeros, self.optimisers)  # dim=1
+        else:
+            return

+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
         elif self.training and not self._warned_about_registration:
             warnings.warn(
                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
@@ -222,82 +247,77 @@ class RecyclingLinear(nn.Module):

         return self.linear(x)

-    def reset_rows(
-        self,
-        indices: Iterable[int],
-        optimisers: Union[
-            List[torch.optim.Optimizer], torch.optim.Optimizer, None
-        ] = None,
-    ):
-        """
-        Update some of the weight rows to be equal to the mean of all weight rows.
-        """
-        if optimisers is None:
-            optimisers = []
-        if not isinstance(optimisers, list):
-            optimisers = [optimisers]
-
-        device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
-
-        if idx_tensor.numel() == 0:
-            return
-
-        with torch.no_grad():
-            # Calculate mean of all rows including the rows to be reset
-            mean_vector = self.linear.weight.data.mean(
-                dim=0, keepdim=True
-            )  # [1, in_features]
-            update_data = mean_vector.expand(idx_tensor.size(0), -1)
-            self.linear.weight.data[idx_tensor] = update_data
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")

-            if self.linear.bias is not None:
-                self.linear.bias.data[idx_tensor] = 0.0
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)

-        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
-        if self.linear.bias is not None:
-            self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)

-    def reset_columns(
+    def _update_weights(
         self,
         indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
         optimisers: Union[
             List[torch.optim.Optimizer], torch.optim.Optimizer, None
         ] = None,
     ):
-        """
-        Update some of the weight columns to be random as though reinitialised.
-        """
         if optimisers is None:
             optimisers = []
         if not isinstance(optimisers, list):
             optimisers = [optimisers]

-        device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

         if idx_tensor.numel() == 0:
             return

         with torch.no_grad():
-            # 1. Generate Random Columns
-            # Shape: [out_features, N_indices]
-            weights = self.linear.weight.data
-            stdv = 1.0 / math.sqrt(weights.size(1))
-
-            # Generate [Rows, N] block
-            random_weights = torch.rand(
-                weights.size(0), idx_tensor.size(0), device=device
-            )
-            random_weights = (random_weights - 0.5) * 2.0 * stdv
-
-            # 2. Update Weights (One-shot)
-            # We assign into the columns specified by idx_tensor
-            self.linear.weight.data[:, idx_tensor] = random_weights
-
-            # 3. Update Optimizers
-            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
-            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)

     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
         """
@@ -410,21 +410,9 @@ class FeedforwardBlock(nn.Module):

         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
-            multiplier = self.linear_in._get_multiplier()
-            rate = self.master_recycling_rate * multiplier
-            if rate > 0:
-                probs = torch.rand(self.linear_out.in_features, device=x.device)
-                mask = probs < rate
-                if mask.any():
-                    indices = torch.nonzero(mask).squeeze(-1)
-                    self.linear_out.reset_columns(indices, self.linear_out.optimisers)
-                    if self.xglu:
-                        indices_in = torch.cat(
-                            [indices, indices + self.linear_out.in_features]
-                        )
-                        self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
-                    else:
-                        self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices)
+            self.linear_out.reset_columns(indices)

         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
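
Taken together, the Python changes above split recycling into two steps: `get_reset_indices(dim)` samples indices with probability `rate * multiplier`, and `reset_rows`/`reset_columns` apply the reset through `_update_weights` and `_reset_optim_state`. In `FeedforwardBlock`, a single draw now recycles a whole hidden unit, because row i of `linear_in` and column i of `linear_out` feed the same unit, and the xGLU gate/value pairing is handled inside `reset_rows` rather than by the caller. A rough sketch of the index arithmetic, with illustrative sizes (`d_model` and `d_hidden` are hypothetical names, not taken from the package):

import torch

d_model, d_hidden = 64, 256

# Hypothetical xGLU feedforward shapes:
#   linear_in : d_model  -> 2 * d_hidden  (gate half first, value half second)
#   linear_out: d_hidden -> d_model
indices = torch.tensor([2, 7])  # units drawn by linear_out.get_reset_indices(1)

# Inside linear_in.reset_rows(indices) with xglu=True:
gate_rows = indices              # rows 2 and 7 -> set to the mean gate row
value_rows = indices + d_hidden  # rows 258 and 263 -> re-randomised in [-stdv, +stdv]

# Inside linear_out.reset_columns(indices): columns 2 and 7 are zeroed,
# so the recycled units start with no influence on the block's output.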
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "9.2.2"
+version = "9.4.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
4 files without changes
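
For orientation, a minimal sketch of how the 9.4.0 recycling flow might be exercised directly. The constructor arguments match the diff above, but the import path and the way optimisers are attached are assumptions: the diff shows the layer reading `self.optimisers` and warning when none are registered, not how registration happens.

import torch
from broccoli import RecyclingLinear  # import path assumed, not shown in this diff

layer = RecyclingLinear(
    64,                          # in_features
    256,                         # out_features (even, so the xGLU halves match)
    row_recycling_rate=0.01,     # per-step chance each output row is recycled
    column_recycling_rate=0.01,  # per-step chance each input column is recycled
    xglu=True,                   # rows are gate/value pairs; gate half comes first
)
optimiser = torch.optim.Adam(layer.parameters(), lr=1e-3)
layer.optimisers = [optimiser]   # assumption: registration mechanism isn't in this diff

layer.train()
x = torch.randn(32, 64)
out = layer(x)  # forward() samples indices, resets weights and optimiser state,
                # then applies the underlying nn.Linear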