broccoli-ml 7.0.0__py3-none-any.whl → 9.0.0__py3-none-any.whl
This diff shows changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- broccoli/linear.py +136 -0
- broccoli/transformer.py +23 -7
- broccoli/vit.py +36 -10
- {broccoli_ml-7.0.0.dist-info → broccoli_ml-9.0.0.dist-info}/METADATA +1 -1
- {broccoli_ml-7.0.0.dist-info → broccoli_ml-9.0.0.dist-info}/RECORD +7 -7
- {broccoli_ml-7.0.0.dist-info → broccoli_ml-9.0.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-7.0.0.dist-info → broccoli_ml-9.0.0.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED

```diff
@@ -1,4 +1,7 @@
 import math
+import random
+from typing import Union, List, Iterable
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -136,3 +139,136 @@ class WeightNormedLinear(nn.Module):
             f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+    ):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.optimisers = []
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+
+            if self.row_recycling_rate > 0:
+                probs = torch.rand(self.linear.out_features, device=x.device)
+                mask = probs < self.row_recycling_rate
+                if mask.any():
+                    # nonzero returns [N, 1], squeeze to get [N]
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.reset_rows(indices, self.optimisers)
+
+            if self.column_recycling_rate > 0:
+                probs = torch.rand(self.linear.in_features, device=x.device)
+                mask = probs < self.column_recycling_rate
+                if mask.any():
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.reset_columns(indices, self.optimisers)
+
+        return self.linear(x)
+
+    def reset_rows(
+        self,
+        indices: Iterable[int],
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        """
+        Update some of the weight rows to be equal to the mean of all weight rows.
+        """
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        device = self.linear.weight.device
+        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            # Calculate mean of all rows including the rows to be reset
+            mean_vector = self.linear.weight.data.mean(
+                dim=0, keepdim=True
+            )  # [1, in_features]
+            update_data = mean_vector.expand(idx_tensor.size(0), -1)
+            self.linear.weight.data[idx_tensor] = update_data
+
+            if self.linear.bias is not None:
+                self.linear.bias.data[idx_tensor] = 0.0
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
+            if self.linear.bias is not None:
+                self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
+
+    def reset_columns(
+        self,
+        indices: Iterable[int],
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        """
+        Update some of the weight columns to be random as though reinitialised.
+        """
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        device = self.linear.weight.device
+        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            # 1. Generate Random Columns
+            # Shape: [out_features, N_indices]
+            weights = self.linear.weight.data
+            stdv = 1.0 / math.sqrt(weights.size(1))
+
+            # Generate [Rows, N] block
+            random_weights = torch.rand(
+                weights.size(0), idx_tensor.size(0), device=device
+            )
+            random_weights = (random_weights - 0.5) * 2.0 * stdv
+
+            # 2. Update Weights (One-shot)
+            # We assign into the columns specified by idx_tensor
+            self.linear.weight.data[:, idx_tensor] = random_weights
+
+            # 3. Update Optimizers
+            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
```
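The new `RecyclingLinear` wraps a plain `nn.Linear` and, on each training-mode forward pass (once at least one optimiser has been registered), independently samples output rows and input columns to "recycle": rows are reset to the mean of all rows (with their bias entries zeroed), columns are re-drawn uniformly from the layer's default init range, and the matching slices of any registered optimiser's state buffers are zeroed so stale momentum does not immediately drag the reset weights back. A minimal usage sketch (the loop, data, and loss below are illustrative, not from the package):

```python
import torch
from broccoli.linear import RecyclingLinear

# Recycle roughly 1% of output rows and input columns per training step.
layer = RecyclingLinear(64, 32, row_recycling_rate=0.01, column_recycling_rate=0.01)
optimiser = torch.optim.AdamW(layer.parameters(), lr=1e-3)

# Give the layer a handle on the optimiser so it can zero the state
# (e.g. AdamW's exp_avg / exp_avg_sq buffers) for recycled rows and columns.
layer.register_optimiser(optimiser)

layer.train()
for _ in range(10):
    x = torch.randn(8, 64)           # illustrative batch
    loss = layer(x).pow(2).mean()    # illustrative loss
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

layer.eval()  # recycling only fires in training mode with optimisers registered
```

Note that recycling happens inside `forward`, so resets land before the current step's gradients are computed, and a `reset_rows` or `reset_columns` call with an empty index list is a no-op.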
broccoli/transformer.py
CHANGED

```diff
@@ -325,6 +325,8 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
         pre_norm=True,
@@ -354,7 +356,12 @@
         else:
             self.activation = activation()
 
-        self.
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
             2 * ratio * output_features
@@ -367,9 +374,10 @@
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 linear_module_up(input_features, self.max_features),
                 self.activation,
+                self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
                 linear_module_down(ratio * output_features, output_features),
-                self.
+                self.outer_dropout,
             ]
         )
 
@@ -422,7 +430,9 @@ class TransformerBlock(nn.Module):
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -484,7 +494,9 @@
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
             linear_module_up=(
                 ff_linear_module_up
                 if ff_linear_module_up is not None
@@ -567,7 +579,9 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
@@ -629,7 +643,7 @@
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -658,7 +672,9 @@
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
                 msa_scaling=msa_scaling,
-
+                ff_dropout=ff_dropout,
+                ff_inner_dropout=ff_inner_dropout,
+                ff_outer_dropout=ff_outer_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
                 causal=causal,
```
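In `FeedforwardBlock`, the single `dropout` rate is split in two: `inner_dropout` is applied right after the activation, between the up- and down-projections, and `outer_dropout` after the down-projection, where the old single dropout sat; each falls back to `dropout` when left as `None`. `TransformerBlock` and `TransformerEncoder` expose these as `ff_dropout`, `ff_inner_dropout`, and `ff_outer_dropout` and pass them through. Below is a self-contained sketch of the resulting layer order; the function name is hypothetical, and it assumes pre-norm, a plain ReLU path, and ordinary `nn.Linear` projections (the real block also supports a wider gated up-projection via `max_features` and an optional NormFormer layer norm):

```python
from typing import Optional

import torch
from torch import nn


def feedforward_stack(
    d_model: int,
    ratio: int = 4,
    dropout: float = 0.0,
    inner_dropout: Optional[float] = None,
    outer_dropout: Optional[float] = None,
) -> nn.Sequential:
    """Illustrative stand-in for FeedforwardBlock's layer list."""
    inner = inner_dropout if inner_dropout is not None else dropout
    outer = outer_dropout if outer_dropout is not None else dropout
    return nn.Sequential(
        nn.LayerNorm(d_model),                # pre_norm=True
        nn.Linear(d_model, ratio * d_model),  # linear_module_up
        nn.ReLU(),                            # activation
        nn.Dropout(inner),                    # new: inner_dropout slot
        nn.Identity(),                        # normformer=False
        nn.Linear(ratio * d_model, d_model),  # linear_module_down
        nn.Dropout(outer),                    # outer_dropout (the old position)
    )


x = torch.randn(2, 16, 64)
out = feedforward_stack(64, dropout=0.1, inner_dropout=0.05)(x)
assert out.shape == x.shape
```

Decoupling the two rates lets the wide intermediate representation be regularised independently of the dropout applied to what re-enters the residual stream.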
broccoli/vit.py
CHANGED

```diff
@@ -161,7 +161,9 @@ class ViTEncoder(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -178,7 +180,9 @@
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -333,7 +337,9 @@
             ff_linear_module_up=transformer_ff_linear_module_up,
             ff_linear_module_down=transformer_ff_linear_module_down,
             msa_scaling=transformer_msa_scaling,
-
+            ff_dropout=transformer_ff_dropout,
+            ff_inner_dropout=transformer_ff_inner_dropout,
+            ff_outer_dropout=transformer_ff_outer_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
             causal=False,
@@ -357,9 +363,21 @@
             activation_kwargs=transformer_activation_kwargs,
             dropout=(
                 # First truthy assigned value
-
-                if
-                else
+                transformer_initial_ff_dropout
+                if transformer_initial_ff_dropout is not None
+                else transformer_ff_dropout
+            ),
+            inner_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_inner_dropout
+                if transformer_initial_ff_inner_dropout is not None
+                else transformer_ff_inner_dropout
+            ),
+            outer_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_outer_dropout
+                if transformer_initial_ff_outer_dropout is not None
+                else transformer_ff_outer_dropout
             ),
             linear_module_up=(
                 # First truthy assigned value
@@ -441,7 +459,9 @@ class ViT(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -458,7 +478,9 @@
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -508,7 +530,9 @@
             transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
             transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
             transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-
+            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
             transformer_post_norm=transformer_post_norm,
@@ -525,7 +549,9 @@
             transformer_ff_linear_module_up=transformer_ff_linear_module_up,
             transformer_ff_linear_module_down=transformer_ff_linear_module_down,
             transformer_msa_scaling=transformer_msa_scaling,
-
+            transformer_ff_dropout=transformer_ff_dropout,
+            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
             transformer_checkpoint_ff=transformer_checkpoint_ff,
```
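`ViTEncoder` and `ViT` thread the same three `transformer_ff_*` rates through to the encoder, and additionally accept `transformer_initial_ff_*` overrides for the initial feedforward block. An override is used only when it is not `None`, so despite the "# First truthy assigned value" comments in the diff, an explicit `0.0` override is honoured. The selection logic in isolation (`resolve_dropout` is a hypothetical helper name, not from the package):

```python
from typing import Optional


def resolve_dropout(initial_override: Optional[float], encoder_wide: float) -> float:
    # Mirrors the diff's conditional expressions: prefer the initial-FF
    # override unless it was left as None.
    return initial_override if initial_override is not None else encoder_wide


assert resolve_dropout(None, 0.1) == 0.1  # falls back to transformer_ff_dropout
assert resolve_dropout(0.0, 0.1) == 0.0   # explicit 0.0 disables dropout there
assert resolve_dropout(0.3, 0.1) == 0.3
```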
{broccoli_ml-7.0.0.dist-info → broccoli_ml-9.0.0.dist-info}/RECORD
CHANGED

```diff
@@ -1,13 +1,13 @@
 broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
 broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
 broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=
+broccoli/linear.py,sha256=XaGHZguvK-7hvtIt07zo8uQZBQvS7oMD2K9nPvyYJLE,9769
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256
+broccoli/transformer.py,sha256=Rozh0hExHjwGvvKbMeZfLoB95dDKyDn3X6o1Ms26aAI,24241
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
+broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
+broccoli_ml-9.0.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-9.0.0.dist-info/METADATA,sha256=ecQ2BRxtzmNSO2CMAp2rcNRq9L37urE_pKdsPf-jJKs,1368
+broccoli_ml-9.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-9.0.0.dist-info/RECORD,,
```
File without changes
|
|
File without changes
|