broccoli-ml 9.4.0__tar.gz → 9.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 9.4.0
3
+ Version: 9.4.2
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -46,10 +46,7 @@ class GELU(nn.Module):
46
46
 
47
47
  class Swish(nn.Module):
48
48
  """
49
- Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
50
- (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
51
-
52
- Halves the incoming parameter count, which should be scaled up before input.
49
+ Implementation of (beta) Swish
53
50
  """
54
51
 
55
52
  def __init__(self) -> None:
@@ -202,18 +202,19 @@ class RecyclingLinear(nn.Module):
202
202
  idx_tensor = indices
203
203
 
204
204
  if idx_tensor.size(0):
205
- random_weights = self._random_weights(
206
- indices.size(0), self.linear.weight.size(1)
207
- )
205
+ centred_value_weights = self._mean_value_weights()
206
+ centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
208
207
  if self.xglu:
209
208
  gate_indices = indices
210
209
  value_indices = indices + (self.linear.out_features // 2)
211
- self._update_weights(value_indices, 0, random_weights, self.optimisers)
212
210
  centred_gate_weights = self._mean_gate_weights()
213
211
  centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
214
212
  self._update_weights(
215
213
  gate_indices, 0, centred_gate_weights, self.optimisers # dim
216
214
  )
215
+ self._update_weights(
216
+ value_indices, 0, centred_value_weights, self.optimisers
217
+ )
217
218
  else:
218
219
  return
219
220
 
@@ -277,6 +278,14 @@ class RecyclingLinear(nn.Module):
277
278
  random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
278
279
  return random_weights
279
280
 
281
+ def _mean_value_weights(self):
282
+ """
283
+ Only used when self.xglu
284
+ """
285
+ weights = self.linear.weight.data
286
+ rows = weights.size(0)
287
+ return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
288
+
280
289
  def _mean_gate_weights(self):
281
290
  """
282
291
  Only used when self.xglu
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "9.4.0"
3
+ version = "9.4.2"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes
File without changes