broccoli-ml 9.1.0__py3-none-any.whl → 9.2.1__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
- broccoli/transformer.py +46 -5
- {broccoli_ml-9.1.0.dist-info → broccoli_ml-9.2.1.dist-info}/METADATA +1 -1
- {broccoli_ml-9.1.0.dist-info → broccoli_ml-9.2.1.dist-info}/RECORD +5 -5
- {broccoli_ml-9.1.0.dist-info → broccoli_ml-9.2.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-9.1.0.dist-info → broccoli_ml-9.2.1.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
|
@@ -340,6 +340,7 @@ class FeedforwardBlock(nn.Module):
|
|
|
340
340
|
self.checkpoint = checkpoint
|
|
341
341
|
self.residual_path = residual_path
|
|
342
342
|
self.post_norm = post_norm
|
|
343
|
+
self.xglu = activation.__name__.endswith("GLU")
|
|
343
344
|
|
|
344
345
|
if self.residual_path and (output_features < input_features):
|
|
345
346
|
raise ValueError(
|
|
@@ -364,27 +365,67 @@ class FeedforwardBlock(nn.Module):
|
|
|
364
365
|
)
|
|
365
366
|
|
|
366
367
|
self.max_features = (
|
|
367
|
-
2 * ratio * output_features
|
|
368
|
-
if activation.__name__.endswith("GLU")
|
|
369
|
-
else ratio * output_features
|
|
368
|
+
2 * ratio * output_features if self.xglu else ratio * output_features
|
|
370
369
|
)
|
|
371
370
|
|
|
371
|
+
self.linear_in = linear_module_up(input_features, self.max_features)
|
|
372
|
+
self.linear_out = linear_module_down(ratio * output_features, output_features)
|
|
373
|
+
|
|
372
374
|
self.process = nn.Sequential(
|
|
373
375
|
*[
|
|
374
376
|
nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
|
|
375
|
-
|
|
377
|
+
self.linear_in,
|
|
376
378
|
self.activation,
|
|
377
379
|
self.inner_dropout,
|
|
378
380
|
nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
|
|
379
|
-
|
|
381
|
+
self.linear_out,
|
|
380
382
|
self.outer_dropout,
|
|
381
383
|
]
|
|
382
384
|
)
|
|
383
385
|
|
|
386
|
+
self.recycling_enabled = False
|
|
387
|
+
if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
|
|
388
|
+
self.linear_out, "column_recycling_rate"
|
|
389
|
+
):
|
|
390
|
+
self.recycling_enabled = True
|
|
391
|
+
self.master_recycling_rate = self.linear_in.row_recycling_rate
|
|
392
|
+
self.linear_in.row_recycling_rate = 0.0
|
|
393
|
+
self.linear_out.column_recycling_rate = 0.0
|
|
394
|
+
if (
|
|
395
|
+
hasattr(self.linear_in, "column_recycling_rate")
|
|
396
|
+
and self.linear_in.column_recycling_rate > 0
|
|
397
|
+
) or (
|
|
398
|
+
hasattr(self.linear_out, "row_recycling_rate")
|
|
399
|
+
and self.linear_out.row_recycling_rate > 0
|
|
400
|
+
):
|
|
401
|
+
raise NotImplementedError(
|
|
402
|
+
"At the moment this layer can only support recycling linear "
|
|
403
|
+
"layers if the in layer resets only rows and the out layer "
|
|
404
|
+
"resets only columns."
|
|
405
|
+
)
|
|
406
|
+
|
|
384
407
|
self.reset_parameters()
|
|
385
408
|
|
|
386
409
|
def forward(self, x):
|
|
387
410
|
|
|
411
|
+
# Recycle weights if using recycling linear layers
|
|
412
|
+
if self.training and self.recycling_enabled:
|
|
413
|
+
multiplier = self.linear_in._get_multiplier()
|
|
414
|
+
rate = self.master_recycling_rate * multiplier
|
|
415
|
+
if rate > 0:
|
|
416
|
+
probs = torch.rand(self.linear_out.in_features, device=x.device)
|
|
417
|
+
mask = probs < rate
|
|
418
|
+
if mask.any():
|
|
419
|
+
indices = torch.nonzero(mask).squeeze(-1)
|
|
420
|
+
self.linear_out.reset_columns(indices, self.linear_out.optimisers)
|
|
421
|
+
if self.xglu:
|
|
422
|
+
indices_in = torch.cat(
|
|
423
|
+
[indices, indices + self.linear_out.in_features]
|
|
424
|
+
)
|
|
425
|
+
self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
|
|
426
|
+
else:
|
|
427
|
+
self.linear_in.reset_rows(indices, self.linear_in.optimisers)
|
|
428
|
+
|
|
388
429
|
if self.checkpoint:
|
|
389
430
|
processed = checkpoint(self.process, x, use_reentrant=False)
|
|
390
431
|
else:
|
|
@@ -4,10 +4,10 @@ broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
|
4
4
|
broccoli/linear.py,sha256=7uN7zVPJ6Ptec31O8a-GvWT5nZk56Wf1RLJRvUAT0yo,11406
|
|
5
5
|
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
6
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
-
broccoli/transformer.py,sha256=
|
|
7
|
+
broccoli/transformer.py,sha256=r-ggAeNDW5QpBi9As1U9sIfxITBOx0WHk_K4zWpyTM8,26233
|
|
8
8
|
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
9
|
broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
|
|
10
|
-
broccoli_ml-9.1.
|
|
11
|
-
broccoli_ml-9.1.
|
|
12
|
-
broccoli_ml-9.1.
|
|
13
|
-
broccoli_ml-9.1.
|
|
10
|
+
broccoli_ml-9.2.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-9.2.1.dist-info/METADATA,sha256=Nj7WnXKxlvSlrK8rQp9wizgPGs7ZMnhCi-KY5O6W-wc,1368
|
|
12
|
+
broccoli_ml-9.2.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-9.2.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|