broccoli-ml 9.5.1__py3-none-any.whl → 9.6.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
broccoli/linear.py CHANGED
@@ -193,7 +193,12 @@ class RecyclingLinear(nn.Module):
193
193
  multipliers = [a / b for a, b in pairs if b != 0.0]
194
194
  return min(multipliers) if multipliers else 0.0
195
195
 
196
- def reset_rows(self, indices):
196
+ def reset_rows(self, indices, incoming_data=None):
197
+ """
198
+ Resets rows.
199
+ If incoming_data is provided, resets to the centroid (mean) of that data.
200
+ If not, resets to the mean of existing weights.
201
+ """
197
202
  if not torch.is_tensor(indices):
198
203
  idx_tensor = torch.as_tensor(
199
204
  list(indices), dtype=torch.long, device=self.linear.weight.device
@@ -201,24 +206,24 @@ class RecyclingLinear(nn.Module):
201
206
  else:
202
207
  idx_tensor = indices
203
208
 
204
- if idx_tensor.size(0):
205
- value_indices = indices
206
- centred_value_weights = self._mean_value_weights()
207
- centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
208
- if self.xglu:
209
- gate_indices = indices
210
- value_indices = indices + (self.linear.out_features // 2)
211
- centred_gate_weights = self._mean_gate_weights()
212
- centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
213
- self._update_weights(
214
- gate_indices, 0, centred_gate_weights, self.optimisers # dim
215
- )
216
- self._update_weights(
217
- value_indices, 0, centred_value_weights, self.optimisers
218
- )
219
- else:
209
+ if idx_tensor.numel() == 0:
220
210
  return
221
211
 
212
+ if incoming_data is not None:
213
+ target_center = self._mean_input_weights(incoming_data)
214
+ else:
215
+ target_center = self._mean_value_weights()
216
+
217
+ target_center = target_center.expand(idx_tensor.size(0), -1)
218
+
219
+ if self.xglu:
220
+ gate_indices = idx_tensor
221
+ value_indices = idx_tensor + (self.linear.out_features // 2)
222
+ self._update_weights(gate_indices, 0, target_center, self.optimisers)
223
+ self._update_weights(value_indices, 0, target_center, self.optimisers)
224
+ else:
225
+ self._update_weights(idx_tensor, 0, target_center, self.optimisers)
226
+
222
227
  def reset_columns(self, indices):
223
228
  if not torch.is_tensor(indices):
224
229
  idx_tensor = torch.as_tensor(
@@ -281,6 +286,17 @@ class RecyclingLinear(nn.Module):
281
286
  random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
282
287
  return random_weights
283
288
 
289
+ def _mean_input_weights(self, input):
290
+ reduce_dims = list(range(input.ndim - 1))
291
+ data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
292
+
293
+ weights = self.linear.weight.data
294
+ stdv = 1.0 / math.sqrt(weights.size(1))
295
+ data_norm = data_mean.std() + 1e-6
296
+ scale_factor = stdv / data_norm
297
+
298
+ return data_mean * scale_factor
299
+
284
300
  def _mean_value_weights(self):
285
301
  """
286
302
  Only used when self.xglu
broccoli/transformer.py CHANGED
@@ -411,7 +411,7 @@ class FeedforwardBlock(nn.Module):
411
411
  # Recycle weights if using recycling linear layers
412
412
  if self.training and self.recycling_enabled:
413
413
  indices = self.linear_out.get_reset_indices(1)
414
- self.linear_in.reset_rows(indices)
414
+ self.linear_in.reset_rows(indices, incoming_data=x)
415
415
  self.linear_out.reset_columns(indices)
416
416
 
417
417
  if self.checkpoint:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.5.1
3
+ Version: 9.6.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -1,13 +1,13 @@
1
1
  broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
2
2
  broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
3
3
  broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
4
- broccoli/linear.py,sha256=i4U7ZC4ZWEH82YpDasx0Qs1pc3gkyL-3ajuyKCbsGTM,12649
4
+ broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
5
5
  broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
6
6
  broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
7
- broccoli/transformer.py,sha256=ULk-QQX3hAI14-aCKhp9QSebzX4KUjlisEGup2Eycck,25565
7
+ broccoli/transformer.py,sha256=vmAgFuevFn9ekRqq2Fbiblzr5DOrucPvlHF87l0vUuM,25582
8
8
  broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
9
9
  broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
10
- broccoli_ml-9.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
- broccoli_ml-9.5.1.dist-info/METADATA,sha256=HXRWnuc_-Gs_g37_RP3-POTLmi7sZamlzYv5SJEun1Y,1368
12
- broccoli_ml-9.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
- broccoli_ml-9.5.1.dist-info/RECORD,,
10
+ broccoli_ml-9.6.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
11
+ broccoli_ml-9.6.0.dist-info/METADATA,sha256=zzGvnF60IZx7Htq3-R91wsQ0ey3VmSNde2XQZY0H11U,1368
12
+ broccoli_ml-9.6.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
13
+ broccoli_ml-9.6.0.dist-info/RECORD,,