broccoli-ml 9.2.2__tar.gz → 9.3.0__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.2.2
3
+ Version: 9.3.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -151,11 +151,13 @@ class RecyclingLinear(nn.Module):
151
151
  row_recycling_rate: float = 0.0,
152
152
  column_recycling_rate: float = 0.0,
153
153
  adaptive=False,
154
+ xglu=False,
154
155
  ):
155
156
  super().__init__()
156
157
  self.in_features = in_features
157
158
  self.out_features = out_features
158
159
  self.bias = bias
160
+ self.xglu = xglu
159
161
  self.linear = nn.Linear(in_features, out_features, bias=bias)
160
162
  self.row_recycling_rate = row_recycling_rate
161
163
  self.column_recycling_rate = column_recycling_rate
@@ -191,28 +193,50 @@ class RecyclingLinear(nn.Module):
191
193
  multipliers = [a / b for a, b in pairs if b != 0.0]
192
194
  return min(multipliers) if multipliers else 0.0
193
195
 
194
- def forward(self, x):
195
- multiplier = self._get_multiplier()
196
- col_recycling_rate = self.column_recycling_rate * multiplier
197
- row_recycling_rate = self.row_recycling_rate * multiplier
196
+ def reset_rows(self, indices):
197
+ if not torch.is_tensor(indices):
198
+ idx_tensor = torch.as_tensor(
199
+ list(indices), dtype=torch.long, device=self.linear.weight.device
200
+ )
201
+ else:
202
+ idx_tensor = indices
198
203
 
199
- if self.training and self.optimisers:
204
+ if idx_tensor.size(0):
205
+ random_weights = self._random_weights(
206
+ indices.size(0), self.linear.weight.size(1)
207
+ )
208
+ if self.xglu:
209
+ gate_indices = indices
210
+ value_indices = indices + (self.linear.out_features // 2)
211
+ self._update_weights(value_indices, 0, random_weights, self.optimisers)
212
+ centred_gate_weights = self._mean_gate_weights()
213
+ centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
214
+ self._update_weights(
215
+ gate_indices, 0, centred_gate_weights, self.optimisers # dim
216
+ )
217
+ else:
218
+ return
219
+
220
+ def reset_columns(self, indices):
221
+ if not torch.is_tensor(indices):
222
+ idx_tensor = torch.as_tensor(
223
+ list(indices), dtype=torch.long, device=self.linear.weight.device
224
+ )
225
+ else:
226
+ idx_tensor = indices
200
227
 
201
- if row_recycling_rate > 0:
202
- probs = torch.rand(self.linear.out_features, device=x.device)
203
- mask = probs < row_recycling_rate
204
- if mask.any():
205
- # nonzero returns [N, 1], squeeze to get [N]
206
- indices = torch.nonzero(mask).squeeze(-1)
207
- self.reset_rows(indices, self.optimisers)
208
-
209
- if col_recycling_rate > 0:
210
- probs = torch.rand(self.linear.in_features, device=x.device)
211
- mask = probs < col_recycling_rate
212
- if mask.any():
213
- indices = torch.nonzero(mask).squeeze(-1)
214
- self.reset_columns(indices, self.optimisers)
228
+ if idx_tensor.size(0):
229
+ random_weights = self._random_weights(
230
+ self.linear.weight.size(0), indices.size(0)
231
+ )
232
+ self._update_weights(indices, 1, random_weights, self.optimisers) # dim
233
+ else:
234
+ return
215
235
 
236
+ def forward(self, x):
237
+ if self.training and self.optimisers:
238
+ self.reset_rows(self.get_reset_indices(0))
239
+ self.reset_columns(self.get_reset_indices(1))
216
240
  elif self.training and not self._warned_about_registration:
217
241
  warnings.warn(
218
242
  "RecyclingLinear: No optimiser registered. Recycling disabled.",
@@ -222,82 +246,77 @@ class RecyclingLinear(nn.Module):
222
246
 
223
247
  return self.linear(x)
224
248
 
225
- def reset_rows(
226
- self,
227
- indices: Iterable[int],
228
- optimisers: Union[
229
- List[torch.optim.Optimizer], torch.optim.Optimizer, None
230
- ] = None,
231
- ):
232
- """
233
- Update some of the weight rows to be equal to the mean of all weight rows.
234
- """
235
- if optimisers is None:
236
- optimisers = []
237
- if not isinstance(optimisers, list):
238
- optimisers = [optimisers]
239
-
240
- device = self.linear.weight.device
241
- idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
242
-
243
- if idx_tensor.numel() == 0:
244
- return
245
-
246
- with torch.no_grad():
247
- # Calculate mean of all rows including the rows to be reset
248
- mean_vector = self.linear.weight.data.mean(
249
- dim=0, keepdim=True
250
- ) # [1, in_features]
251
- update_data = mean_vector.expand(idx_tensor.size(0), -1)
252
- self.linear.weight.data[idx_tensor] = update_data
249
+ def get_reset_indices(self, dim):
250
+ base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
251
+ p = base_rate * self._get_multiplier()
252
+ if dim == 0:
253
+ if self.xglu:
254
+ sample_space = self.linear.out_features // 2
255
+ else:
256
+ sample_space = self.linear.out_features
257
+ elif dim == 1:
258
+ sample_space = self.linear.in_features
259
+ else:
260
+ raise ValueError("`dim` must be 0 or 1")
253
261
 
254
- if self.linear.bias is not None:
255
- self.linear.bias.data[idx_tensor] = 0.0
262
+ # Sample the indices
263
+ probs = torch.rand(sample_space, device=self.linear.weight.device)
264
+ mask = probs < p
265
+ if mask.any():
266
+ return torch.nonzero(mask).squeeze(-1)
267
+ else:
268
+ return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
256
269
 
257
- self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
258
- if self.linear.bias is not None:
259
- self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
270
+ def _random_weights(self, rows, columns):
271
+ device = self.linear.weight.device
272
+ weights = self.linear.weight.data
273
+ stdv = 1.0 / math.sqrt(weights.size(1))
274
+ random_weights = torch.rand(rows, columns, device=device)
275
+ random_weights -= 0.5 # Range [-0.5, +0.5]
276
+ random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
277
+ return random_weights
278
+
279
+ def _mean_gate_weights(self):
280
+ """
281
+ Only used when self.xglu
282
+ """
283
+ weights = self.linear.weight.data
284
+ rows = weights.size(0)
285
+ return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
260
286
 
261
- def reset_columns(
287
+ def _update_weights(
262
288
  self,
263
289
  indices: Iterable[int],
290
+ dim: int,
291
+ data: torch.Tensor,
264
292
  optimisers: Union[
265
293
  List[torch.optim.Optimizer], torch.optim.Optimizer, None
266
294
  ] = None,
267
295
  ):
268
- """
269
- Update some of the weight columns to be random as though reinitialised.
270
- """
271
296
  if optimisers is None:
272
297
  optimisers = []
273
298
  if not isinstance(optimisers, list):
274
299
  optimisers = [optimisers]
275
300
 
276
- device = self.linear.weight.device
277
- idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
301
+ if not torch.is_tensor(indices):
302
+ idx_tensor = torch.as_tensor(
303
+ list(indices), dtype=torch.long, device=self.linear.weight.device
304
+ )
305
+ else:
306
+ idx_tensor = indices
278
307
 
279
308
  if idx_tensor.numel() == 0:
280
309
  return
281
310
 
282
311
  with torch.no_grad():
283
- # 1. Generate Random Columns
284
- # Shape: [out_features, N_indices]
285
- weights = self.linear.weight.data
286
- stdv = 1.0 / math.sqrt(weights.size(1))
287
-
288
- # Generate [Rows, N] block
289
- random_weights = torch.rand(
290
- weights.size(0), idx_tensor.size(0), device=device
291
- )
292
- random_weights = (random_weights - 0.5) * 2.0 * stdv
293
-
294
- # 2. Update Weights (One-shot)
295
- # We assign into the columns specified by idx_tensor
296
- self.linear.weight.data[:, idx_tensor] = random_weights
297
-
298
- # 3. Update Optimizers
299
- # Bias is untouched by column resets (bias is shape [Out], cols are [In])
300
- self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
312
+ if dim == 0:
313
+ self.linear.weight.data[idx_tensor] = data
314
+ elif dim == 1:
315
+ self.linear.weight.data[:, idx_tensor] = data
316
+ else:
317
+ raise ValueError("`dim` must be 0 or 1")
318
+
319
+ self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
301
320
 
302
321
  def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
303
322
  """
@@ -410,21 +410,9 @@ class FeedforwardBlock(nn.Module):
410
410
 
411
411
  # Recycle weights if using recycling linear layers
412
412
  if self.training and self.recycling_enabled:
413
- multiplier = self.linear_in._get_multiplier()
414
- rate = self.master_recycling_rate * multiplier
415
- if rate > 0:
416
- probs = torch.rand(self.linear_out.in_features, device=x.device)
417
- mask = probs < rate
418
- if mask.any():
419
- indices = torch.nonzero(mask).squeeze(-1)
420
- self.linear_out.reset_columns(indices, self.linear_out.optimisers)
421
- if self.xglu:
422
- indices_in = torch.cat(
423
- [indices, indices + self.linear_out.in_features]
424
- )
425
- self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
426
- else:
427
- self.linear_in.reset_rows(indices, self.linear_in.optimisers)
413
+ indices = self.linear_out.get_reset_indices(1)
414
+ self.linear_in.reset_rows(indices)
415
+ self.linear_out.reset_columns(indices)
428
416
 
429
417
  if self.checkpoint:
430
418
  processed = checkpoint(self.process, x, use_reentrant=False)
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.2.2"
3
+ version = "9.3.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes