broccoli-ml 9.3.0__tar.gz → 9.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/PKG-INFO +1 -1
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/activation.py +1 -4
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/linear.py +16 -7
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/pyproject.toml +1 -1
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/LICENSE +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/README.md +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/rope.py +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/transformer.py +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/utils.py +0 -0
- {broccoli_ml-9.3.0 → broccoli_ml-9.4.1}/broccoli/vit.py +0 -0
|
@@ -46,10 +46,7 @@ class GELU(nn.Module):
|
|
|
46
46
|
|
|
47
47
|
class Swish(nn.Module):
|
|
48
48
|
"""
|
|
49
|
-
Implementation of (beta)
|
|
50
|
-
(https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
|
|
51
|
-
|
|
52
|
-
Halves the incoming parameter count, which should be scaled up before input.
|
|
49
|
+
Implementation of (beta) Swish
|
|
53
50
|
"""
|
|
54
51
|
|
|
55
52
|
def __init__(self) -> None:
|
|
@@ -202,18 +202,18 @@ class RecyclingLinear(nn.Module):
|
|
|
202
202
|
idx_tensor = indices
|
|
203
203
|
|
|
204
204
|
if idx_tensor.size(0):
|
|
205
|
-
|
|
206
|
-
indices.size(0), self.linear.weight.size(1)
|
|
207
|
-
)
|
|
205
|
+
centred_value_weights = self._mean_value_weights()
|
|
208
206
|
if self.xglu:
|
|
209
207
|
gate_indices = indices
|
|
210
208
|
value_indices = indices + (self.linear.out_features // 2)
|
|
211
|
-
self._update_weights(value_indices, 0, random_weights, self.optimisers)
|
|
212
209
|
centred_gate_weights = self._mean_gate_weights()
|
|
213
210
|
centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
|
|
214
211
|
self._update_weights(
|
|
215
212
|
gate_indices, 0, centred_gate_weights, self.optimisers # dim
|
|
216
213
|
)
|
|
214
|
+
self._update_weights(
|
|
215
|
+
value_indices, 0, centred_value_weights, self.optimisers
|
|
216
|
+
)
|
|
217
217
|
else:
|
|
218
218
|
return
|
|
219
219
|
|
|
@@ -226,10 +226,11 @@ class RecyclingLinear(nn.Module):
|
|
|
226
226
|
idx_tensor = indices
|
|
227
227
|
|
|
228
228
|
if idx_tensor.size(0):
|
|
229
|
-
|
|
230
|
-
self.linear.weight.size(0), indices.size(0)
|
|
229
|
+
zeros = torch.zeros(
|
|
230
|
+
(self.linear.weight.size(0), indices.size(0)),
|
|
231
|
+
device=self.linear.weight.device,
|
|
231
232
|
)
|
|
232
|
-
self._update_weights(indices, 1,
|
|
233
|
+
self._update_weights(indices, 1, zeros, self.optimisers) # dim
|
|
233
234
|
else:
|
|
234
235
|
return
|
|
235
236
|
|
|
@@ -276,6 +277,14 @@ class RecyclingLinear(nn.Module):
|
|
|
276
277
|
random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
|
|
277
278
|
return random_weights
|
|
278
279
|
|
|
280
|
+
def _mean_value_weights(self):
|
|
281
|
+
"""
|
|
282
|
+
Only used when self.xglu
|
|
283
|
+
"""
|
|
284
|
+
weights = self.linear.weight.data
|
|
285
|
+
rows = weights.size(0)
|
|
286
|
+
return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
|
|
287
|
+
|
|
279
288
|
def _mean_gate_weights(self):
|
|
280
289
|
"""
|
|
281
290
|
Only used when self.xglu
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|