broccoli-ml 9.2.2.tar.gz → 9.3.0.tar.gz
This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/PKG-INFO +1 -1
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/linear.py +95 -76
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/transformer.py +3 -15
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/pyproject.toml +1 -1
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/LICENSE +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/README.md +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/utils.py +0 -0
- {broccoli_ml-9.2.2 → broccoli_ml-9.3.0}/broccoli/vit.py +0 -0
broccoli/linear.py

```diff
@@ -151,11 +151,13 @@ class RecyclingLinear(nn.Module):
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
         adaptive=False,
+        xglu=False,
     ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.bias = bias
+        self.xglu = xglu
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
```
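The new `xglu` flag tells `RecyclingLinear` that its output is consumed as a GLU-style projection (a gate half plus a value half), which changes how recycled rows are sampled and reinitialised in the hunks below. A minimal construction sketch, assuming the keyword arguments shown above; the full constructor signature is not part of this diff:

```python
from broccoli.linear import RecyclingLinear

# Hypothetical sizes for illustration. With xglu=True, out_features is
# assumed to span both the gate half and the value half of the projection.
layer = RecyclingLinear(
    in_features=64,
    out_features=128,
    row_recycling_rate=0.01,
    column_recycling_rate=0.01,
    xglu=True,
)
```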
```diff
@@ -191,28 +193,50 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0
 
-    def
-
-
-
+    def reset_rows(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
-        if
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                indices.size(0), self.linear.weight.size(1)
+            )
+            if self.xglu:
+                gate_indices = indices
+                value_indices = indices + (self.linear.out_features // 2)
+                self._update_weights(value_indices, 0, random_weights, self.optimisers)
+                centred_gate_weights = self._mean_gate_weights()
+                centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
+                self._update_weights(
+                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim
+                )
+        else:
+            return
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
-
-
-
-
-
-
-
-
-        if col_recycling_rate > 0:
-            probs = torch.rand(self.linear.in_features, device=x.device)
-            mask = probs < col_recycling_rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.reset_columns(indices, self.optimisers)
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
 
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
         elif self.training and not self._warned_about_registration:
             warnings.warn(
                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
```
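The offset `indices + (self.linear.out_features // 2)` in `reset_rows` implies a GLU-style layout in which the first `out_features // 2` output rows act as gates and the second half as the matching values: recycled value rows get fresh random weights, while the paired gate rows are set to the mean of the existing gate rows. A standalone sketch of that pairing; the split layout is inferred from the index arithmetic, and the gating function here is only an assumption:

```python
import torch

out_features, in_features = 12, 8
weight = torch.randn(out_features, in_features)  # stand-in for linear.weight
x = torch.randn(2, in_features)

# GLU-style split: the first half of the rows produce gates, the second half values.
gate, value = (x @ weight.T).chunk(2, dim=-1)
out = torch.sigmoid(gate) * value

# Gate row i therefore pairs with value row i + out_features // 2,
# mirroring gate_indices / value_indices in reset_rows above.
indices = torch.tensor([1, 4])
value_indices = indices + out_features // 2
```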
```diff
@@ -222,82 +246,77 @@ class RecyclingLinear(nn.Module):
 
         return self.linear(x)
 
-    def
-        self
-
-
-
-
-
-
-
-
-
-
-        if not isinstance(optimisers, list):
-            optimisers = [optimisers]
-
-        device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
-
-        if idx_tensor.numel() == 0:
-            return
-
-        with torch.no_grad():
-            # Calculate mean of all rows including the rows to be reset
-            mean_vector = self.linear.weight.data.mean(
-                dim=0, keepdim=True
-            )  # [1, in_features]
-            update_data = mean_vector.expand(idx_tensor.size(0), -1)
-            self.linear.weight.data[idx_tensor] = update_data
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
 
-
-
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
 
-
-
-
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
 
-    def
+    def _update_weights(
         self,
         indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
         optimisers: Union[
             List[torch.optim.Optimizer], torch.optim.Optimizer, None
         ] = None,
     ):
-        """
-        Update some of the weight columns to be random as though reinitialised.
-        """
         if optimisers is None:
             optimisers = []
         if not isinstance(optimisers, list):
             optimisers = [optimisers]
 
-
-
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
 
         if idx_tensor.numel() == 0:
             return
 
         with torch.no_grad():
-
-
-
-
-
-
-
-
-            )
-            random_weights = (random_weights - 0.5) * 2.0 * stdv
-
-            # 2. Update Weights (One-shot)
-            # We assign into the columns specified by idx_tensor
-            self.linear.weight.data[:, idx_tensor] = random_weights
-
-            # 3. Update Optimizers
-            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
-            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
 
     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
         """
```
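`_random_weights` draws from U(-stdv, +stdv) with stdv = 1 / sqrt(fan_in), which matches the range of `nn.Linear`'s default weight initialisation, so a recycled slice looks statistically like a freshly initialised one. A standalone check of the range arithmetic:

```python
import math

import torch

fan_in = 64
stdv = 1.0 / math.sqrt(fan_in)

w = torch.rand(8, fan_in)  # U[0, 1)
w -= 0.5                   # shift to U[-0.5, 0.5)
w *= 2.0 * stdv            # scale to U[-stdv, +stdv)

assert w.abs().max() <= stdv
```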
broccoli/transformer.py

```diff
@@ -410,21 +410,9 @@ class FeedforwardBlock(nn.Module):
 
         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
-
-
-
-            probs = torch.rand(self.linear_out.in_features, device=x.device)
-            mask = probs < rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.linear_out.reset_columns(indices, self.linear_out.optimisers)
-                if self.xglu:
-                    indices_in = torch.cat(
-                        [indices, indices + self.linear_out.in_features]
-                    )
-                    self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
-                else:
-                    self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices)
+            self.linear_out.reset_columns(indices)
 
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
```
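The block-level logic collapses to three lines because the Bernoulli sampling (now `get_reset_indices`) and the xGLU index doubling (now handled inside `reset_rows` via the `xglu` flag) both moved into `RecyclingLinear`. A single index set can serve both layers because column j of `linear_out.weight` consumes the activation produced by row j of `linear_in.weight`; a standalone illustration of that correspondence, with zeroing as a stand-in for the actual resets:

```python
import torch

d_model, d_hidden = 8, 16
linear_in = torch.nn.Linear(d_model, d_hidden, bias=False)
linear_out = torch.nn.Linear(d_hidden, d_model, bias=False)

indices = torch.tensor([2, 5])  # e.g. what get_reset_indices(1) might return
with torch.no_grad():
    linear_in.weight[indices] = 0.0      # stand-in for reset_rows
    linear_out.weight[:, indices] = 0.0  # stand-in for reset_columns

# Hidden unit j is produced by row j of linear_in and read by column j of
# linear_out, so resetting both with the same indices recycles the whole
# unit coherently.
h = torch.relu(linear_in(torch.randn(3, d_model)))
assert torch.equal(h[:, indices], torch.zeros(3, 2))
```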