broccoli-ml 9.2.2.tar.gz → 9.4.0.tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/PKG-INFO +1 -1
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/linear.py +96 -76
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/transformer.py +3 -15
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/pyproject.toml +1 -1
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/LICENSE +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/README.md +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/utils.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.4.0}/broccoli/vit.py +0 -0
@@ -151,11 +151,13 @@ class RecyclingLinear(nn.Module):
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
         adaptive=False,
+        xglu=False,
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.bias = bias
+        self.xglu = xglu
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
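Note on this hunk: `RecyclingLinear` gains an `xglu` flag, telling the layer that its output rows form a gate half and a value half (GLU-style), which changes how rows are recycled in the hunks below. A minimal construction sketch; attaching the optimiser through the `optimisers` attribute is an assumption here, since the registration API is not part of this diff:

```python
import torch
from broccoli.linear import RecyclingLinear

layer = RecyclingLinear(
    64,                      # in_features
    256,                     # out_features: gate half + value half when xglu=True
    row_recycling_rate=0.01,
    column_recycling_rate=0.01,
    xglu=True,               # new in 9.4.0
)
opt = torch.optim.Adam(layer.parameters(), lr=1e-3)
layer.optimisers = [opt]     # assumed registration; forward() warns when this is unset
```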
@@ -191,28 +193,51 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0
 
-    def
-
-
-
+    def reset_rows(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
-        if
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                indices.size(0), self.linear.weight.size(1)
+            )
+            if self.xglu:
+                gate_indices = indices
+                value_indices = indices + (self.linear.out_features // 2)
+                self._update_weights(value_indices, 0, random_weights, self.optimisers)
+                centred_gate_weights = self._mean_gate_weights()
+                centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
+                self._update_weights(
+                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim
+                )
+        else:
+            return
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
-
-
-
-
-
-
-
-
-        if col_recycling_rate > 0:
-            probs = torch.rand(self.linear.in_features, device=x.device)
-            mask = probs < col_recycling_rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.reset_columns(indices, self.optimisers)
+        if idx_tensor.size(0):
+            zeros = torch.zeros(
+                (self.linear.weight.size(0), indices.size(0)),
+                device=self.linear.weight.device,
+            )
+            self._update_weights(indices, 1, zeros, self.optimisers)  # dim
+        else:
+            return
 
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
         elif self.training and not self._warned_about_registration:
             warnings.warn(
                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
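With `xglu` enabled, `reset_rows` treats each sampled index as a gate/value pair: the value row at `i + out_features // 2` is re-randomised, while the gate row at `i` is set to the mean of all gate rows, presumably so a recycled unit starts from a roughly neutral gate rather than a random one. A standalone sketch of just that index arithmetic, with a hypothetical 8-row weight:

```python
import torch

out_features = 8                             # rows 0-3 gate, rows 4-7 value
indices = torch.tensor([0, 2])               # sampled from the gate half only
gate_indices = indices                       # rows 0 and 2: set to the mean gate row
value_indices = indices + out_features // 2  # rows 4 and 6: re-randomised
```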
@@ -222,82 +247,77 @@ class RecyclingLinear(nn.Module):
 
         return self.linear(x)
 
-    def
-        self
-
-
-
-
-
-
-
-
-
-
-        if not isinstance(optimisers, list):
-            optimisers = [optimisers]
-
-        device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
-
-        if idx_tensor.numel() == 0:
-            return
-
-        with torch.no_grad():
-            # Calculate mean of all rows including the rows to be reset
-            mean_vector = self.linear.weight.data.mean(
-                dim=0, keepdim=True
-            )  # [1, in_features]
-            update_data = mean_vector.expand(idx_tensor.size(0), -1)
-            self.linear.weight.data[idx_tensor] = update_data
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
 
-
-
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
 
-
-
-
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
 
-    def
+    def _update_weights(
         self,
         indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
         optimisers: Union[
             List[torch.optim.Optimizer], torch.optim.Optimizer, None
         ] = None,
     ):
-        """
-        Update some of the weight columns to be random as though reinitialised.
-        """
         if optimisers is None:
             optimisers = []
         if not isinstance(optimisers, list):
             optimisers = [optimisers]
 
-
-
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
         if idx_tensor.numel() == 0:
             return
 
         with torch.no_grad():
-
-
-
-
-
-
-
-
-            )
-            random_weights = (random_weights - 0.5) * 2.0 * stdv
-
-            # 2. Update Weights (One-shot)
-            # We assign into the columns specified by idx_tensor
-            self.linear.weight.data[:, idx_tensor] = random_weights
-
-            # 3. Update Optimizers
-            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
-            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
 
     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
         """
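`get_reset_indices` samples each row or column independently with probability `p`, so a step recycles about `p * sample_space` units on average, and `_random_weights` draws replacements from a uniform range of ±1/√fan_in, the same bound `nn.Linear` uses for its default weight initialisation. A quick standalone check of that range arithmetic:

```python
import math
import torch

fan_in = 64
stdv = 1.0 / math.sqrt(fan_in)
w = torch.rand(4, fan_in)     # U[0, 1)
w = (w - 0.5) * 2.0 * stdv    # U[-stdv, +stdv), as in _random_weights
assert w.abs().max() <= stdv
```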
@@ -410,21 +410,9 @@ class FeedforwardBlock(nn.Module):
 
         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
-
-
-
-            probs = torch.rand(self.linear_out.in_features, device=x.device)
-            mask = probs < rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.linear_out.reset_columns(indices, self.linear_out.optimisers)
-                if self.xglu:
-                    indices_in = torch.cat(
-                        [indices, indices + self.linear_out.in_features]
-                    )
-                    self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
-                else:
-                    self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices)
+            self.linear_out.reset_columns(indices)
 
         if self.checkpoint:
            processed = checkpoint(self.process, x, use_reentrant=False)
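The rewritten block delegates sampling and the xGLU index handling to the layers themselves: hidden unit i is produced by row i of `linear_in.weight` and consumed by column i of `linear_out.weight`, so a single index set drives both resets. A toy illustration of that row/column correspondence using plain `nn.Linear` (names and sizes are illustrative, not the package's API):

```python
import torch
import torch.nn as nn

d_model, d_hidden = 8, 16
linear_in = nn.Linear(d_model, d_hidden)    # weight shape [d_hidden, d_model]
linear_out = nn.Linear(d_hidden, d_model)   # weight shape [d_model, d_hidden]

i = 3  # "recycle" hidden unit 3
with torch.no_grad():
    linear_in.weight[i] = 0      # row i: what unit i computes from the input
    linear_out.weight[:, i] = 0  # column i: how unit i's output is consumed
```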