broccoli-ml 9.4.1__tar.gz → 9.5.0__tar.gz

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.4.1
3
+ Version: 9.5.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -202,7 +202,9 @@ class RecyclingLinear(nn.Module):
202
202
  idx_tensor = indices
203
203
 
204
204
  if idx_tensor.size(0):
205
+ value_indices = indices
205
206
  centred_value_weights = self._mean_value_weights()
207
+ centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
206
208
  if self.xglu:
207
209
  gate_indices = indices
208
210
  value_indices = indices + (self.linear.out_features // 2)
@@ -226,11 +228,13 @@ class RecyclingLinear(nn.Module):
226
228
  idx_tensor = indices
227
229
 
228
230
  if idx_tensor.size(0):
229
- zeros = torch.zeros(
230
- (self.linear.weight.size(0), indices.size(0)),
231
- device=self.linear.weight.device,
231
+ random_weights = self._random_weights(
232
+ self.linear.weight.size(0), indices.size(0)
232
233
  )
233
- self._update_weights(indices, 1, zeros, self.optimisers) # dim
234
+ random_weights *= (
235
+ 0.01 # Make them quiet so they don't introduce loud noise!
236
+ )
237
+ self._update_weights(indices, 1, random_weights, self.optimisers) # dim
234
238
  else:
235
239
  return
236
240
 
@@ -283,7 +287,10 @@ class RecyclingLinear(nn.Module):
283
287
  """
284
288
  weights = self.linear.weight.data
285
289
  rows = weights.size(0)
286
- return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
290
+ if self.xglu:
291
+ return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
292
+ else:
293
+ return self.linear.weight.data.mean(dim=0, keepdim=True)
287
294
 
288
295
  def _mean_gate_weights(self):
289
296
  """
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.4.1"
3
+ version = "9.5.0"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes