broccoli-ml 9.4.0__tar.gz → 9.4.1__tar.gz
This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/PKG-INFO +1 -1
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/activation.py +1 -4
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/linear.py +12 -4
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/pyproject.toml +1 -1
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/LICENSE +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/README.md +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/rope.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/transformer.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/utils.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.1}/broccoli/vit.py +0 -0
|
@@ -46,10 +46,7 @@ class GELU(nn.Module):
|
|
|
46
46
|
|
|
47
47
|
class Swish(nn.Module):
|
|
48
48
|
"""
|
|
49
|
-
Implementation of (beta)
|
|
50
|
-
(https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
|
|
51
|
-
|
|
52
|
-
Halves the incoming parameter count, which should be scaled up before input.
|
|
49
|
+
Implementation of (beta) Swish
|
|
53
50
|
"""
|
|
54
51
|
|
|
55
52
|
def __init__(self) -> None:
|
|
@@ -202,18 +202,18 @@ class RecyclingLinear(nn.Module):
|
|
|
202
202
|
idx_tensor = indices
|
|
203
203
|
|
|
204
204
|
if idx_tensor.size(0):
|
|
205
|
-
|
|
206
|
-
indices.size(0), self.linear.weight.size(1)
|
|
207
|
-
)
|
|
205
|
+
centred_value_weights = self._mean_value_weights()
|
|
208
206
|
if self.xglu:
|
|
209
207
|
gate_indices = indices
|
|
210
208
|
value_indices = indices + (self.linear.out_features // 2)
|
|
211
|
-
self._update_weights(value_indices, 0, random_weights, self.optimisers)
|
|
212
209
|
centred_gate_weights = self._mean_gate_weights()
|
|
213
210
|
centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
|
|
214
211
|
self._update_weights(
|
|
215
212
|
gate_indices, 0, centred_gate_weights, self.optimisers # dim
|
|
216
213
|
)
|
|
214
|
+
self._update_weights(
|
|
215
|
+
value_indices, 0, centred_value_weights, self.optimisers
|
|
216
|
+
)
|
|
217
217
|
else:
|
|
218
218
|
return
|
|
219
219
|
|
|
@@ -277,6 +277,14 @@ class RecyclingLinear(nn.Module):
|
|
|
277
277
|
random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
|
|
278
278
|
return random_weights
|
|
279
279
|
|
|
280
|
+
def _mean_value_weights(self):
|
|
281
|
+
"""
|
|
282
|
+
Only used when self.xglu
|
|
283
|
+
"""
|
|
284
|
+
weights = self.linear.weight.data
|
|
285
|
+
rows = weights.size(0)
|
|
286
|
+
return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
|
|
287
|
+
|
|
280
288
|
def _mean_gate_weights(self):
|
|
281
289
|
"""
|
|
282
290
|
Only used when self.xglu
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|