broccoli-ml 9.0.0__py3-none-any.whl → 9.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/linear.py +44 -4
- broccoli/transformer.py +42 -5
- {broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/METADATA +1 -1
- {broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/RECORD +6 -6
- {broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED
@@ -1,5 +1,6 @@
 import math
 import random
+import warnings
 from typing import Union, List, Iterable

 import torch
@@ -149,34 +150,73 @@ class RecyclingLinear(nn.Module):
         bias: bool = True,
         row_recycling_rate: float = 0.0,
         column_recycling_rate: float = 0.0,
+        adaptive=False,
     ):
         super().__init__()
         self.linear = nn.Linear(in_features, out_features, bias=bias)
         self.row_recycling_rate = row_recycling_rate
         self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
         self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False

     def register_optimiser(self, optimiser: torch.optim.Optimizer):
         self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0

     def forward(self, x):
+        multiplier = self._get_multiplier()
+        col_recycling_rate = self.column_recycling_rate * multiplier
+        row_recycling_rate = self.row_recycling_rate * multiplier
+
         if self.training and self.optimisers:

-            if self.row_recycling_rate > 0:
+            if row_recycling_rate > 0:
                 probs = torch.rand(self.linear.out_features, device=x.device)
-                mask = probs < self.row_recycling_rate
+                mask = probs < row_recycling_rate
                 if mask.any():
                     # nonzero returns [N, 1], squeeze to get [N]
                     indices = torch.nonzero(mask).squeeze(-1)
                     self.reset_rows(indices, self.optimisers)

-            if self.column_recycling_rate > 0:
+            if col_recycling_rate > 0:
                 probs = torch.rand(self.linear.in_features, device=x.device)
-                mask = probs < self.column_recycling_rate
+                mask = probs < col_recycling_rate
                 if mask.any():
                     indices = torch.nonzero(mask).squeeze(-1)
                     self.reset_columns(indices, self.optimisers)

+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
         return self.linear(x)

     def reset_rows(
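The linear.py changes add an `adaptive` option to `RecyclingLinear`: the layer records each registered optimiser's learning rate and, during training, scales its row and column recycling rates by the smallest ratio of current to initial learning rate. For example, with two registered optimisers whose rates have decayed from 0.1 to 0.05 and from 0.1 to 0.02, the multiplier is min(0.5, 0.2) = 0.2, so a `row_recycling_rate` of 0.01 acts as 0.002 on that step; with `adaptive=False` (the default) the multiplier stays at 1.0 and the previous behaviour is unchanged. Below is a minimal usage sketch based only on the signatures visible in this diff; the optimiser, scheduler, shapes and rates are illustrative assumptions, not part of the package.

```python
# Sketch only: exercises the constructor arguments and register_optimiser()
# shown in the diff above. Hyperparameters and the training loop are made up.
import torch
from broccoli.linear import RecyclingLinear

layer = RecyclingLinear(
    64,
    128,
    row_recycling_rate=0.01,
    column_recycling_rate=0.01,
    adaptive=True,  # scale recycling by current_lr / initial_lr
)

optimiser = torch.optim.SGD(layer.parameters(), lr=0.1)
layer.register_optimiser(optimiser)  # register before creating a scheduler,
                                     # otherwise the stored initial LR may be 0.0
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=100)

layer.train()
x = torch.randn(32, 64)
for _ in range(100):
    out = layer(x)       # rows/columns may be recycled here during training
    out.sum().backward()
    optimiser.step()
    optimiser.zero_grad()
    scheduler.step()     # as the LR decays, the effective recycling rate decays too
```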
broccoli/transformer.py
CHANGED
@@ -340,6 +340,7 @@ class FeedforwardBlock(nn.Module):
         self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")

         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -364,27 +365,63 @@
         )

         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * ratio * output_features if self.xglu else ratio * output_features
         )

+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(ratio * output_features, output_features)
+
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                linear_module_up(input_features, self.max_features),
+                self.linear_in,
                 self.activation,
                 self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                linear_module_down(ratio * output_features, output_features),
+                self.linear_out,
                 self.outer_dropout,
             ]
         )

+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if hasattr(self.linear_in, "column_recycling_rate") or hasattr(
+                self.linear_out, "row_recycling_rate"
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()

     def forward(self, x):

+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            multiplier = self.linear_in._get_multiplier()
+            rate = self.master_recycling_rate * multiplier
+            if rate > 0:
+                probs = torch.rand(self.linear_out.in_features, device=x.device)
+                mask = probs < rate
+                if mask.any():
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.linear_out.reset_columns(indices, self.linear_out.optimisers)
+                    if self.xglu:
+                        indices_in = torch.cat(
+                            [indices, indices + self.linear_out.in_features]
+                        )
+                        self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
+                    else:
+                        self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
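`FeedforwardBlock` now drives recycling itself when both of its linear layers support it: during training it samples hidden units, resets the matching columns of `linear_out`, and resets the matching rows of `linear_in`. With a GLU-style activation, `linear_in` emits twice as many features as `linear_out` consumes (a value half and a gate half), so each sampled index is also mirrored into the second half via `torch.cat([indices, indices + self.linear_out.in_features])`, which recycles a unit's value and gate paths together. A self-contained sketch of that index arithmetic (the width and sampled indices are made-up values for illustration, not library code):

```python
import torch

H = 4                              # hidden width consumed by linear_out
indices = torch.tensor([1, 3])     # hidden units sampled for recycling
indices_in = torch.cat([indices, indices + H])
print(indices_in.tolist())         # [1, 3, 5, 7]: value half and gate half of each unit
```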
{broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
 broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
 broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=
+broccoli/linear.py,sha256=7uN7zVPJ6Ptec31O8a-GvWT5nZk56Wf1RLJRvUAT0yo,11406
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=ckwFpNTAeYB_V8F-_DwT0Z5--QUEepM4r6xlSxhCY68,26079
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
 broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
-broccoli_ml-9.
-broccoli_ml-9.
-broccoli_ml-9.
-broccoli_ml-9.
+broccoli_ml-9.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-9.2.0.dist-info/METADATA,sha256=MnrReG_teIm7YceDS9nfvXnelEXHGP4JFhoz9dhkJWo,1368
+broccoli_ml-9.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-9.2.0.dist-info/RECORD,,
{broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/LICENSE
File without changes
{broccoli_ml-9.0.0.dist-info → broccoli_ml-9.2.0.dist-info}/WHEEL
File without changes