broccoli-ml 9.4.0__tar.gz → 9.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/PKG-INFO +1 -1
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/activation.py +1 -4
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/linear.py +13 -4
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/pyproject.toml +1 -1
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/LICENSE +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/README.md +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/rope.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/transformer.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/utils.py +0 -0
- {broccoli_ml-9.4.0 → broccoli_ml-9.4.2}/broccoli/vit.py +0 -0
|
@@ -46,10 +46,7 @@ class GELU(nn.Module):
|
|
|
46
46
|
|
|
47
47
|
class Swish(nn.Module):
|
|
48
48
|
"""
|
|
49
|
-
Implementation of (beta)
|
|
50
|
-
(https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
|
|
51
|
-
|
|
52
|
-
Halves the incoming parameter count, which should be scaled up before input.
|
|
49
|
+
Implementation of (beta) Swish
|
|
53
50
|
"""
|
|
54
51
|
|
|
55
52
|
def __init__(self) -> None:
|
|
@@ -202,18 +202,19 @@ class RecyclingLinear(nn.Module):
|
|
|
202
202
|
idx_tensor = indices
|
|
203
203
|
|
|
204
204
|
if idx_tensor.size(0):
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
)
|
|
205
|
+
centred_value_weights = self._mean_value_weights()
|
|
206
|
+
centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
|
|
208
207
|
if self.xglu:
|
|
209
208
|
gate_indices = indices
|
|
210
209
|
value_indices = indices + (self.linear.out_features // 2)
|
|
211
|
-
self._update_weights(value_indices, 0, random_weights, self.optimisers)
|
|
212
210
|
centred_gate_weights = self._mean_gate_weights()
|
|
213
211
|
centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
|
|
214
212
|
self._update_weights(
|
|
215
213
|
gate_indices, 0, centred_gate_weights, self.optimisers # dim
|
|
216
214
|
)
|
|
215
|
+
self._update_weights(
|
|
216
|
+
value_indices, 0, centred_value_weights, self.optimisers
|
|
217
|
+
)
|
|
217
218
|
else:
|
|
218
219
|
return
|
|
219
220
|
|
|
@@ -277,6 +278,14 @@ class RecyclingLinear(nn.Module):
|
|
|
277
278
|
random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
|
|
278
279
|
return random_weights
|
|
279
280
|
|
|
281
|
+
def _mean_value_weights(self):
|
|
282
|
+
"""
|
|
283
|
+
Only used when self.xglu
|
|
284
|
+
"""
|
|
285
|
+
weights = self.linear.weight.data
|
|
286
|
+
rows = weights.size(0)
|
|
287
|
+
return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
|
|
288
|
+
|
|
280
289
|
def _mean_gate_weights(self):
|
|
281
290
|
"""
|
|
282
291
|
Only used when self.xglu
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|