broccoli-ml 9.4.1__tar.gz → 9.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/PKG-INFO +1 -1
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/linear.py +12 -5
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/pyproject.toml +1 -1
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/LICENSE +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/README.md +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/transformer.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/utils.py +0 -0
- {broccoli_ml-9.4.1 → broccoli_ml-9.5.0}/broccoli/vit.py +0 -0
|
@@ -202,7 +202,9 @@ class RecyclingLinear(nn.Module):
|
|
|
202
202
|
idx_tensor = indices
|
|
203
203
|
|
|
204
204
|
if idx_tensor.size(0):
|
|
205
|
+
value_indices = indices
|
|
205
206
|
centred_value_weights = self._mean_value_weights()
|
|
207
|
+
centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
|
|
206
208
|
if self.xglu:
|
|
207
209
|
gate_indices = indices
|
|
208
210
|
value_indices = indices + (self.linear.out_features // 2)
|
|
@@ -226,11 +228,13 @@ class RecyclingLinear(nn.Module):
|
|
|
226
228
|
idx_tensor = indices
|
|
227
229
|
|
|
228
230
|
if idx_tensor.size(0):
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
device=self.linear.weight.device,
|
|
231
|
+
random_weights = self._random_weights(
|
|
232
|
+
self.linear.weight.size(0), indices.size(0)
|
|
232
233
|
)
|
|
233
|
-
|
|
234
|
+
random_weights *= (
|
|
235
|
+
0.01 # Make them quiet so they don't introduce loud noise!
|
|
236
|
+
)
|
|
237
|
+
self._update_weights(indices, 1, random_weights, self.optimisers) # dim
|
|
234
238
|
else:
|
|
235
239
|
return
|
|
236
240
|
|
|
@@ -283,7 +287,10 @@ class RecyclingLinear(nn.Module):
|
|
|
283
287
|
"""
|
|
284
288
|
weights = self.linear.weight.data
|
|
285
289
|
rows = weights.size(0)
|
|
286
|
-
|
|
290
|
+
if self.xglu:
|
|
291
|
+
return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
|
|
292
|
+
else:
|
|
293
|
+
return self.linear.weight.data.mean(dim=0, keepdim=True)
|
|
287
294
|
|
|
288
295
|
def _mean_gate_weights(self):
|
|
289
296
|
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|