broccoli-ml 9.2.1__py3-none-any.whl → 9.3.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
broccoli/linear.py CHANGED
@@ -151,8 +151,13 @@ class RecyclingLinear(nn.Module):
151
151
  row_recycling_rate: float = 0.0,
152
152
  column_recycling_rate: float = 0.0,
153
153
  adaptive=False,
154
+ xglu=False,
154
155
  ):
155
156
  super().__init__()
157
+ self.in_features = in_features
158
+ self.out_features = out_features
159
+ self.bias = bias
160
+ self.xglu = xglu
156
161
  self.linear = nn.Linear(in_features, out_features, bias=bias)
157
162
  self.row_recycling_rate = row_recycling_rate
158
163
  self.column_recycling_rate = column_recycling_rate
@@ -188,28 +193,50 @@ class RecyclingLinear(nn.Module):
188
193
  multipliers = [a / b for a, b in pairs if b != 0.0]
189
194
  return min(multipliers) if multipliers else 0.0
190
195
 
191
- def forward(self, x):
192
- multiplier = self._get_multiplier()
193
- col_recycling_rate = self.column_recycling_rate * multiplier
194
- row_recycling_rate = self.row_recycling_rate * multiplier
196
+ def reset_rows(self, indices):
197
+ if not torch.is_tensor(indices):
198
+ idx_tensor = torch.as_tensor(
199
+ list(indices), dtype=torch.long, device=self.linear.weight.device
200
+ )
201
+ else:
202
+ idx_tensor = indices
195
203
 
196
- if self.training and self.optimisers:
204
+ if idx_tensor.size(0):
205
+ random_weights = self._random_weights(
206
+ indices.size(0), self.linear.weight.size(1)
207
+ )
208
+ if self.xglu:
209
+ gate_indices = indices
210
+ value_indices = indices + (self.linear.out_features // 2)
211
+ self._update_weights(value_indices, 0, random_weights, self.optimisers)
212
+ centred_gate_weights = self._mean_gate_weights()
213
+ centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
214
+ self._update_weights(
215
+ gate_indices, 0, centred_gate_weights, self.optimisers # dim
216
+ )
217
+ else:
218
+ return
197
219
 
198
- if row_recycling_rate > 0:
199
- probs = torch.rand(self.linear.out_features, device=x.device)
200
- mask = probs < row_recycling_rate
201
- if mask.any():
202
- # nonzero returns [N, 1], squeeze to get [N]
203
- indices = torch.nonzero(mask).squeeze(-1)
204
- self.reset_rows(indices, self.optimisers)
205
-
206
- if col_recycling_rate > 0:
207
- probs = torch.rand(self.linear.in_features, device=x.device)
208
- mask = probs < col_recycling_rate
209
- if mask.any():
210
- indices = torch.nonzero(mask).squeeze(-1)
211
- self.reset_columns(indices, self.optimisers)
220
+ def reset_columns(self, indices):
221
+ if not torch.is_tensor(indices):
222
+ idx_tensor = torch.as_tensor(
223
+ list(indices), dtype=torch.long, device=self.linear.weight.device
224
+ )
225
+ else:
226
+ idx_tensor = indices
212
227
 
228
+ if idx_tensor.size(0):
229
+ random_weights = self._random_weights(
230
+ self.linear.weight.size(0), indices.size(0)
231
+ )
232
+ self._update_weights(indices, 1, random_weights, self.optimisers) # dim
233
+ else:
234
+ return
235
+
236
+ def forward(self, x):
237
+ if self.training and self.optimisers:
238
+ self.reset_rows(self.get_reset_indices(0))
239
+ self.reset_columns(self.get_reset_indices(1))
213
240
  elif self.training and not self._warned_about_registration:
214
241
  warnings.warn(
215
242
  "RecyclingLinear: No optimiser registered. Recycling disabled.",
@@ -219,82 +246,77 @@ class RecyclingLinear(nn.Module):
219
246
 
220
247
  return self.linear(x)
221
248
 
222
- def reset_rows(
223
- self,
224
- indices: Iterable[int],
225
- optimisers: Union[
226
- List[torch.optim.Optimizer], torch.optim.Optimizer, None
227
- ] = None,
228
- ):
229
- """
230
- Update some of the weight rows to be equal to the mean of all weight rows.
231
- """
232
- if optimisers is None:
233
- optimisers = []
234
- if not isinstance(optimisers, list):
235
- optimisers = [optimisers]
236
-
237
- device = self.linear.weight.device
238
- idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
239
-
240
- if idx_tensor.numel() == 0:
241
- return
242
-
243
- with torch.no_grad():
244
- # Calculate mean of all rows including the rows to be reset
245
- mean_vector = self.linear.weight.data.mean(
246
- dim=0, keepdim=True
247
- ) # [1, in_features]
248
- update_data = mean_vector.expand(idx_tensor.size(0), -1)
249
- self.linear.weight.data[idx_tensor] = update_data
249
+ def get_reset_indices(self, dim):
250
+ base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
251
+ p = base_rate * self._get_multiplier()
252
+ if dim == 0:
253
+ if self.xglu:
254
+ sample_space = self.linear.out_features // 2
255
+ else:
256
+ sample_space = self.linear.out_features
257
+ elif dim == 1:
258
+ sample_space = self.linear.in_features
259
+ else:
260
+ raise ValueError("`dim` must be 0 or 1")
250
261
 
251
- if self.linear.bias is not None:
252
- self.linear.bias.data[idx_tensor] = 0.0
262
+ # Sample the indices
263
+ probs = torch.rand(sample_space, device=self.linear.weight.device)
264
+ mask = probs < p
265
+ if mask.any():
266
+ return torch.nonzero(mask).squeeze(-1)
267
+ else:
268
+ return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
253
269
 
254
- self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
255
- if self.linear.bias is not None:
256
- self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
270
+ def _random_weights(self, rows, columns):
271
+ device = self.linear.weight.device
272
+ weights = self.linear.weight.data
273
+ stdv = 1.0 / math.sqrt(weights.size(1))
274
+ random_weights = torch.rand(rows, columns, device=device)
275
+ random_weights -= 0.5 # Range [-0.5, +0.5]
276
+ random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
277
+ return random_weights
278
+
279
+ def _mean_gate_weights(self):
280
+ """
281
+ Only used when self.xglu
282
+ """
283
+ weights = self.linear.weight.data
284
+ rows = weights.size(0)
285
+ return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
257
286
 
258
- def reset_columns(
287
+ def _update_weights(
259
288
  self,
260
289
  indices: Iterable[int],
290
+ dim: int,
291
+ data: torch.Tensor,
261
292
  optimisers: Union[
262
293
  List[torch.optim.Optimizer], torch.optim.Optimizer, None
263
294
  ] = None,
264
295
  ):
265
- """
266
- Update some of the weight columns to be random as though reinitialised.
267
- """
268
296
  if optimisers is None:
269
297
  optimisers = []
270
298
  if not isinstance(optimisers, list):
271
299
  optimisers = [optimisers]
272
300
 
273
- device = self.linear.weight.device
274
- idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
301
+ if not torch.is_tensor(indices):
302
+ idx_tensor = torch.as_tensor(
303
+ list(indices), dtype=torch.long, device=self.linear.weight.device
304
+ )
305
+ else:
306
+ idx_tensor = indices
275
307
 
276
308
  if idx_tensor.numel() == 0:
277
309
  return
278
310
 
279
311
  with torch.no_grad():
280
- # 1. Generate Random Columns
281
- # Shape: [out_features, N_indices]
282
- weights = self.linear.weight.data
283
- stdv = 1.0 / math.sqrt(weights.size(1))
284
-
285
- # Generate [Rows, N] block
286
- random_weights = torch.rand(
287
- weights.size(0), idx_tensor.size(0), device=device
288
- )
289
- random_weights = (random_weights - 0.5) * 2.0 * stdv
290
-
291
- # 2. Update Weights (One-shot)
292
- # We assign into the columns specified by idx_tensor
293
- self.linear.weight.data[:, idx_tensor] = random_weights
294
-
295
- # 3. Update Optimizers
296
- # Bias is untouched by column resets (bias is shape [Out], cols are [In])
297
- self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
312
+ if dim == 0:
313
+ self.linear.weight.data[idx_tensor] = data
314
+ elif dim == 1:
315
+ self.linear.weight.data[:, idx_tensor] = data
316
+ else:
317
+ raise ValueError("`dim` must be 0 or 1")
318
+
319
+ self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
298
320
 
299
321
  def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
300
322
  """
broccoli/transformer.py CHANGED
@@ -410,21 +410,9 @@ class FeedforwardBlock(nn.Module):
410
410
 
411
411
  # Recycle weights if using recycling linear layers
412
412
  if self.training and self.recycling_enabled:
413
- multiplier = self.linear_in._get_multiplier()
414
- rate = self.master_recycling_rate * multiplier
415
- if rate > 0:
416
- probs = torch.rand(self.linear_out.in_features, device=x.device)
417
- mask = probs < rate
418
- if mask.any():
419
- indices = torch.nonzero(mask).squeeze(-1)
420
- self.linear_out.reset_columns(indices, self.linear_out.optimisers)
421
- if self.xglu:
422
- indices_in = torch.cat(
423
- [indices, indices + self.linear_out.in_features]
424
- )
425
- self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
426
- else:
427
- self.linear_in.reset_rows(indices, self.linear_in.optimisers)
413
+ indices = self.linear_out.get_reset_indices(1)
414
+ self.linear_in.reset_rows(indices)
415
+ self.linear_out.reset_columns(indices)
428
416
 
429
417
  if self.checkpoint:
430
418
  processed = checkpoint(self.process, x, use_reentrant=False)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.2.1
3
+ Version: 9.3.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -1,13 +1,13 @@
1
1
  broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
2
2
  broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
3
3
  broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
- broccoli/linear.py,sha256=7uN7zVPJ6Ptec31O8a-GvWT5nZk56Wf1RLJRvUAT0yo,11406
4
+ broccoli/linear.py,sha256=Fn3eqgv1X2M5iXZmtP6jBzfUYuWMkiLlgkBDryv6Ho8,11999
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=r-ggAeNDW5QpBi9As1U9sIfxITBOx0WHk_K4zWpyTM8,26233
7
+ broccoli/transformer.py,sha256=ULk-QQX3hAI14-aCKhp9QSebzX4KUjlisEGup2Eycck,25565
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
9
  broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
10
- broccoli_ml-9.2.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-9.2.1.dist-info/METADATA,sha256=Nj7WnXKxlvSlrK8rQp9wizgPGs7ZMnhCi-KY5O6W-wc,1368
12
- broccoli_ml-9.2.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-9.2.1.dist-info/RECORD,,
10
+ broccoli_ml-9.3.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-9.3.0.dist-info/METADATA,sha256=avjuGvDLh6q6v-7E3dCq0jCNC17-vag52vweC2W26QU,1368
12
+ broccoli_ml-9.3.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-9.3.0.dist-info/RECORD,,