broccoli-ml 9.5.1__py3-none-any.whl → 9.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/linear.py +33 -17
- broccoli/transformer.py +1 -1
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.6.0.dist-info}/METADATA +1 -1
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.6.0.dist-info}/RECORD +6 -6
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.6.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-9.5.1.dist-info → broccoli_ml-9.6.0.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED
|
@@ -193,7 +193,12 @@ class RecyclingLinear(nn.Module):
|
|
|
193
193
|
multipliers = [a / b for a, b in pairs if b != 0.0]
|
|
194
194
|
return min(multipliers) if multipliers else 0.0
|
|
195
195
|
|
|
196
|
-
def reset_rows(self, indices):
|
|
196
|
+
def reset_rows(self, indices, incoming_data=None):
|
|
197
|
+
"""
|
|
198
|
+
Resets rows.
|
|
199
|
+
If incoming_data is provided, resets to the centroid (mean) of that data.
|
|
200
|
+
If not, resets to the mean of existing weights.
|
|
201
|
+
"""
|
|
197
202
|
if not torch.is_tensor(indices):
|
|
198
203
|
idx_tensor = torch.as_tensor(
|
|
199
204
|
list(indices), dtype=torch.long, device=self.linear.weight.device
|
|
@@ -201,24 +206,24 @@ class RecyclingLinear(nn.Module):
|
|
|
201
206
|
else:
|
|
202
207
|
idx_tensor = indices
|
|
203
208
|
|
|
204
|
-
if idx_tensor.
|
|
205
|
-
value_indices = indices
|
|
206
|
-
centred_value_weights = self._mean_value_weights()
|
|
207
|
-
centred_value_weights = centred_value_weights.expand(indices.size(0), -1)
|
|
208
|
-
if self.xglu:
|
|
209
|
-
gate_indices = indices
|
|
210
|
-
value_indices = indices + (self.linear.out_features // 2)
|
|
211
|
-
centred_gate_weights = self._mean_gate_weights()
|
|
212
|
-
centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
|
|
213
|
-
self._update_weights(
|
|
214
|
-
gate_indices, 0, centred_gate_weights, self.optimisers # dim
|
|
215
|
-
)
|
|
216
|
-
self._update_weights(
|
|
217
|
-
value_indices, 0, centred_value_weights, self.optimisers
|
|
218
|
-
)
|
|
219
|
-
else:
|
|
209
|
+
if idx_tensor.numel() == 0:
|
|
220
210
|
return
|
|
221
211
|
|
|
212
|
+
if incoming_data is not None:
|
|
213
|
+
target_center = self._mean_input_weights(incoming_data)
|
|
214
|
+
else:
|
|
215
|
+
target_center = self._mean_value_weights()
|
|
216
|
+
|
|
217
|
+
target_center = target_center.expand(idx_tensor.size(0), -1)
|
|
218
|
+
|
|
219
|
+
if self.xglu:
|
|
220
|
+
gate_indices = idx_tensor
|
|
221
|
+
value_indices = idx_tensor + (self.linear.out_features // 2)
|
|
222
|
+
self._update_weights(gate_indices, 0, target_center, self.optimisers)
|
|
223
|
+
self._update_weights(value_indices, 0, target_center, self.optimisers)
|
|
224
|
+
else:
|
|
225
|
+
self._update_weights(idx_tensor, 0, target_center, self.optimisers)
|
|
226
|
+
|
|
222
227
|
def reset_columns(self, indices):
|
|
223
228
|
if not torch.is_tensor(indices):
|
|
224
229
|
idx_tensor = torch.as_tensor(
|
|
@@ -281,6 +286,17 @@ class RecyclingLinear(nn.Module):
|
|
|
281
286
|
random_weights *= 2.0 * stdv # Range [-stdv, +stdv]
|
|
282
287
|
return random_weights
|
|
283
288
|
|
|
289
|
+
def _mean_input_weights(self, input):
|
|
290
|
+
reduce_dims = list(range(input.ndim - 1))
|
|
291
|
+
data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
|
|
292
|
+
|
|
293
|
+
weights = self.linear.weight.data
|
|
294
|
+
stdv = 1.0 / math.sqrt(weights.size(1))
|
|
295
|
+
data_norm = data_mean.std() + 1e-6
|
|
296
|
+
scale_factor = stdv / data_norm
|
|
297
|
+
|
|
298
|
+
return data_mean * scale_factor
|
|
299
|
+
|
|
284
300
|
def _mean_value_weights(self):
|
|
285
301
|
"""
|
|
286
302
|
Only used when self.xglu
|
broccoli/transformer.py
CHANGED
|
@@ -411,7 +411,7 @@ class FeedforwardBlock(nn.Module):
|
|
|
411
411
|
# Recycle weights if using recycling linear layers
|
|
412
412
|
if self.training and self.recycling_enabled:
|
|
413
413
|
indices = self.linear_out.get_reset_indices(1)
|
|
414
|
-
self.linear_in.reset_rows(indices)
|
|
414
|
+
self.linear_in.reset_rows(indices, incoming_data=x)
|
|
415
415
|
self.linear_out.reset_columns(indices)
|
|
416
416
|
|
|
417
417
|
if self.checkpoint:
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
|
|
2
2
|
broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
|
|
3
3
|
broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
4
|
-
broccoli/linear.py,sha256=
|
|
4
|
+
broccoli/linear.py,sha256=W-3aNpBjd_0xRyzbCKkmg4H1qmslQOIQhB-WDDay2nM,13125
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=vmAgFuevFn9ekRqq2Fbiblzr5DOrucPvlHF87l0vUuM,25582
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
9
|
broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
|
|
10
|
-
broccoli_ml-9.
|
|
11
|
-
broccoli_ml-9.
|
|
12
|
-
broccoli_ml-9.
|
|
13
|
-
broccoli_ml-9.
|
|
10
|
+
broccoli_ml-9.6.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-9.6.0.dist-info/METADATA,sha256=zzGvnF60IZx7Htq3-R91wsQ0ey3VmSNde2XQZY0H11U,1368
|
|
12
|
+
broccoli_ml-9.6.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-9.6.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|