broccoli-ml 9.4.2__tar.gz → 9.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/PKG-INFO +1 -1
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/linear.py +11 -5
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/pyproject.toml +1 -1
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/LICENSE +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/README.md +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/transformer.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/utils.py +0 -0
- {broccoli_ml-9.4.2 → broccoli_ml-9.5.0}/broccoli/vit.py +0 -0
|
@@ -202,6 +202,7 @@ class RecyclingLinear(nn.Module):
|
|
|
202
202
|
idx_tensor = indices
|
|
203
203
|
|
|
204
204
|
if idx_tensor.size(0):
|
|
205
|
+
value_indices = indices
|
|
205
206
|
centred_value_weights = self._mean_value_weights()
|
|
206
207
|
centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
|
|
207
208
|
if self.xglu:
|
|
@@ -227,11 +228,13 @@ class RecyclingLinear(nn.Module):
|
|
|
227
228
|
idx_tensor = indices
|
|
228
229
|
|
|
229
230
|
if idx_tensor.size(0):
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
device=self.linear.weight.device,
|
|
231
|
+
random_weights = self._random_weights(
|
|
232
|
+
self.linear.weight.size(0), indices.size(0)
|
|
233
233
|
)
|
|
234
|
-
|
|
234
|
+
random_weights *= (
|
|
235
|
+
0.01 # Make them quiet so they don't introduce loud noise!
|
|
236
|
+
)
|
|
237
|
+
self._update_weights(indices, 1, random_weights, self.optimisers) # dim
|
|
235
238
|
else:
|
|
236
239
|
return
|
|
237
240
|
|
|
@@ -284,7 +287,10 @@ class RecyclingLinear(nn.Module):
|
|
|
284
287
|
"""
|
|
285
288
|
weights = self.linear.weight.data
|
|
286
289
|
rows = weights.size(0)
|
|
287
|
-
|
|
290
|
+
if self.xglu:
|
|
291
|
+
return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
|
|
292
|
+
else:
|
|
293
|
+
return self.linear.weight.data.mean(dim=0, keepdim=True)
|
|
288
294
|
|
|
289
295
|
def _mean_gate_weights(self):
|
|
290
296
|
"""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|