broccoli-ml 9.2.1__tar.gz → 9.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/PKG-INFO +1 -1
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/linear.py +98 -76
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/transformer.py +3 -15
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/pyproject.toml +1 -1
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/LICENSE +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/README.md +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/activation.py +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/rope.py +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/tensor.py +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/utils.py +0 -0
- {broccoli_ml-9.2.1 → broccoli_ml-9.3.0}/broccoli/vit.py +0 -0
@@ -151,8 +151,13 @@ class RecyclingLinear(nn.Module):
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
         adaptive=False,
+        xglu=False,
     ):
         super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
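In 9.3.0 the constructor gains an `xglu` flag and caches `in_features`, `out_features`, and `bias` on the module. A minimal usage sketch, assuming the module path from the file list above and the argument order implied by the attribute assignments (neither verified against the package's documentation):

import torch
from broccoli.linear import RecyclingLinear  # path assumed from broccoli/linear.py above

layer = RecyclingLinear(
    512,   # in_features (assumed positional order)
    1024,  # out_features; with xglu=True, gate rows occupy the first half
    row_recycling_rate=0.01,
    column_recycling_rate=0.01,
    xglu=True,  # new in 9.3.0
)
layer.eval()  # in training mode without a registered optimiser, forward() only warns once
y = layer(torch.randn(8, 512))
print(y.shape)  # torch.Size([8, 1024]); forward() is a plain nn.Linear apply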
@@ -188,28 +193,50 @@ class RecyclingLinear(nn.Module):
         multipliers = [a / b for a, b in pairs if b != 0.0]
         return min(multipliers) if multipliers else 0.0

-    def …
-        …
+    def reset_rows(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

-        if …
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                indices.size(0), self.linear.weight.size(1)
+            )
+            if self.xglu:
+                gate_indices = indices
+                value_indices = indices + (self.linear.out_features // 2)
+                self._update_weights(value_indices, 0, random_weights, self.optimisers)
+                centred_gate_weights = self._mean_gate_weights()
+                centred_gate_weights = centred_gate_weights.expand(indices.size(0), -1)
+                self._update_weights(
+                    gate_indices, 0, centred_gate_weights, self.optimisers  # dim
+                )
+        else:
+            return

-        …
-        if col_recycling_rate > 0:
-            probs = torch.rand(self.linear.in_features, device=x.device)
-            mask = probs < col_recycling_rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.reset_columns(indices, self.optimisers)
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
         elif self.training and not self._warned_about_registration:
             warnings.warn(
                 "RecyclingLinear: No optimiser registered. Recycling disabled.",
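`reset_rows` now normalises its `indices` argument to a long tensor and, under `xglu`, treats the first half of the output rows as gates and the second half as values: a recycled value row gets fresh random weights, while its paired gate row is set to the mean of all gate rows. A standalone sketch of that pairing (illustrative shapes, not package code):

import torch

out_features, in_features = 8, 4
weight = torch.randn(out_features, in_features)

indices = torch.tensor([1, 3])                # rows sampled for recycling
gate_indices = indices                        # first half of the rows: gates
value_indices = indices + out_features // 2   # second half: paired values

# Values are reinitialised randomly; gates are pulled to the gate mean,
# mirroring _random_weights() and _mean_gate_weights() in the hunk above.
weight[value_indices] = torch.randn(indices.size(0), in_features)
weight[gate_indices] = weight[: out_features // 2].mean(dim=0, keepdim=True)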
@@ -219,82 +246,77 @@ class RecyclingLinear(nn.Module):

         return self.linear(x)

-    def …
-        self…
-        …
-        if not isinstance(optimisers, list):
-            optimisers = [optimisers]
-
-        device = self.linear.weight.device
-        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
-
-        if idx_tensor.numel() == 0:
-            return
-
-        with torch.no_grad():
-            # Calculate mean of all rows including the rows to be reset
-            mean_vector = self.linear.weight.data.mean(
-                dim=0, keepdim=True
-            )  # [1, in_features]
-            update_data = mean_vector.expand(idx_tensor.size(0), -1)
-            self.linear.weight.data[idx_tensor] = update_data
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")

-        …
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)

-        …
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)

-    def …
+    def _update_weights(
         self,
         indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
         optimisers: Union[
             List[torch.optim.Optimizer], torch.optim.Optimizer, None
         ] = None,
     ):
-        """
-        Update some of the weight columns to be random as though reinitialised.
-        """
         if optimisers is None:
             optimisers = []
         if not isinstance(optimisers, list):
             optimisers = [optimisers]

-        …
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices

         if idx_tensor.numel() == 0:
             return

         with torch.no_grad():
-            …
-            )
-            random_weights = (random_weights - 0.5) * 2.0 * stdv
-
-            # 2. Update Weights (One-shot)
-            # We assign into the columns specified by idx_tensor
-            self.linear.weight.data[:, idx_tensor] = random_weights
-
-            # 3. Update Optimizers
-            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
-            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)

     def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
         """
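`get_reset_indices` centralises the sampling that previously lived inline in `forward`: each candidate row or column is recycled independently with probability p = base_rate * multiplier, so the expected number of resets per step is p * sample_space. The +-1/sqrt(fan_in) range in `_random_weights` matches the effective uniform bound of PyTorch's default `nn.Linear` initialisation. A quick sketch of the sampling (standalone, hypothetical numbers):

import torch

p, sample_space = 0.01, 1024
probs = torch.rand(sample_space)                # one Bernoulli(p) draw per unit
indices = torch.nonzero(probs < p).squeeze(-1)  # same masking as in the hunk above
print(indices.numel())                          # ~10 on average (p * sample_space)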
@@ -410,21 +410,9 @@ class FeedforwardBlock(nn.Module):

         # Recycle weights if using recycling linear layers
         if self.training and self.recycling_enabled:
-            …
-            probs = torch.rand(self.linear_out.in_features, device=x.device)
-            mask = probs < rate
-            if mask.any():
-                indices = torch.nonzero(mask).squeeze(-1)
-                self.linear_out.reset_columns(indices, self.linear_out.optimisers)
-                if self.xglu:
-                    indices_in = torch.cat(
-                        [indices, indices + self.linear_out.in_features]
-                    )
-                    self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
-                else:
-                    self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+            indices = self.linear_out.get_reset_indices(1)
+            self.linear_in.reset_rows(indices)
+            self.linear_out.reset_columns(indices)

         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
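The block-level logic shrinks to three lines because one set of hidden-unit indices must drive both layers: hidden unit i is produced by row i of `linear_in.weight` and consumed through column i of `linear_out.weight`, so recycling it means resetting both. The gate/value `torch.cat` branch the old block code handled here now lives inside `RecyclingLinear.reset_rows`. A shapes-only sketch of the row/column correspondence (hypothetical dimensions, not package code):

import torch.nn as nn

d_model, d_hidden = 512, 2048
linear_in = nn.Linear(d_model, d_hidden)   # weight shape [d_hidden, d_model]
linear_out = nn.Linear(d_hidden, d_model)  # weight shape [d_model, d_hidden]

i = 7  # an arbitrary hidden unit
assert linear_in.weight[i].shape == (d_model,)      # row i writes hidden unit i
assert linear_out.weight[:, i].shape == (d_model,)  # column i reads hidden unit i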