broccoli-ml 9.4.0__tar.gz → 9.4.1__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.4.0
3
+ Version: 9.4.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -46,10 +46,7 @@ class GELU(nn.Module):
46
46
 
47
47
  class Swish(nn.Module):
48
48
  """
49
- Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
50
- (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
51
-
52
- Halves the incoming parameter count, which should be scaled up before input.
49
+ Implementation of (beta) Swish
53
50
  """
54
51
 
55
52
  def __init__(self) -> None:
@@ -202,18 +202,18 @@ class RecyclingLinear(nn.Module):
202
202
  idx_tensor = indices
203
203
 
204
204
  if idx_tensor.size(0):
205
- random_weights = self._random_weights(
206
- indices.size(0), self.linear.weight.size(1)
207
- )
205
+ centred_value_weights = self._mean_value_weights()
208
206
  if self.xglu:
209
207
  gate_indices = indices
210
208
  value_indices = indices + (self.linear.out_features // 2)
211
- self._update_weights(value_indices, 0, random_weights, self.optimisers)
212
209
  centred_gate_weights = self._mean_gate_weights()
213
210
  centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
214
211
  self._update_weights(
215
212
  gate_indices, 0, centred_gate_weights, self.optimisers # dim
216
213
  )
214
+ self._update_weights(
215
+ value_indices, 0, centred_value_weights, self.optimisers
216
+ )
217
217
  else:
218
218
  return
219
219
 
@@ -277,6 +277,14 @@ class RecyclingLinear(nn.Module):
277
277
  random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
278
278
  return random_weights
279
279
 
280
+ def _mean_value_weights(self):
281
+ """
282
+ Only used when self.xglu
283
+ """
284
+ weights = self.linear.weight.data
285
+ rows = weights.size(0)
286
+ return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
287
+
280
288
  def _mean_gate_weights(self):
281
289
  """
282
290
  Only used when self.xglu
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.4.0"
3
+ version = "9.4.1"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes