broccoli-ml 6.0.0__py3-none-any.whl → 9.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/linear.py +179 -0
- broccoli/transformer.py +81 -14
- broccoli/vit.py +36 -10
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/METADATA +1 -1
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/RECORD +7 -7
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/LICENSE +0 -0
- {broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED
@@ -1,4 +1,8 @@
 import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
 import torch
 from torch import nn
 from torch.nn import functional as F
@@ -136,3 +140,178 @@ class WeightNormedLinear(nn.Module):
             f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def forward(self, x):
+        multiplier = self._get_multiplier()
+        col_recycling_rate = self.column_recycling_rate * multiplier
+        row_recycling_rate = self.row_recycling_rate * multiplier
+
+        if self.training and self.optimisers:
+
+            if row_recycling_rate > 0:
+                probs = torch.rand(self.linear.out_features, device=x.device)
+                mask = probs < row_recycling_rate
+                if mask.any():
+                    # nonzero returns [N, 1], squeeze to get [N]
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.reset_rows(indices, self.optimisers)
+
+            if col_recycling_rate > 0:
+                probs = torch.rand(self.linear.in_features, device=x.device)
+                mask = probs < col_recycling_rate
+                if mask.any():
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.reset_columns(indices, self.optimisers)
+
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def reset_rows(
+        self,
+        indices: Iterable[int],
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        """
+        Update some of the weight rows to be equal to the mean of all weight rows.
+        """
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        device = self.linear.weight.device
+        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            # Calculate mean of all rows including the rows to be reset
+            mean_vector = self.linear.weight.data.mean(
+                dim=0, keepdim=True
+            )  # [1, in_features]
+            update_data = mean_vector.expand(idx_tensor.size(0), -1)
+            self.linear.weight.data[idx_tensor] = update_data
+
+            if self.linear.bias is not None:
+                self.linear.bias.data[idx_tensor] = 0.0
+
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=0)
+            if self.linear.bias is not None:
+                self._reset_optim_state(self.linear.bias, idx_tensor, optimisers, dim=0)
+
+    def reset_columns(
+        self,
+        indices: Iterable[int],
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        """
+        Update some of the weight columns to be random as though reinitialised.
+        """
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        device = self.linear.weight.device
+        idx_tensor = torch.as_tensor(list(indices), dtype=torch.long, device=device)
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            # 1. Generate Random Columns
+            # Shape: [out_features, N_indices]
+            weights = self.linear.weight.data
+            stdv = 1.0 / math.sqrt(weights.size(1))
+
+            # Generate [Rows, N] block
+            random_weights = torch.rand(
+                weights.size(0), idx_tensor.size(0), device=device
+            )
+            random_weights = (random_weights - 0.5) * 2.0 * stdv
+
+            # 2. Update Weights (One-shot)
+            # We assign into the columns specified by idx_tensor
+            self.linear.weight.data[:, idx_tensor] = random_weights
+
+            # 3. Update Optimizers
+            # Bias is untouched by column resets (bias is shape [Out], cols are [In])
+            self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=1)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
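A minimal usage sketch of the new RecyclingLinear layer, based only on the API visible in this diff; the layer sizes, data, and loss below are placeholders. Recycling only runs during training and only once an optimiser has been registered, and with adaptive=True the recycling rates are scaled by the current learning rate relative to the rate seen at registration.

```python
# Hedged sketch: assumes broccoli-ml 9.2.2 exposes RecyclingLinear as shown in this diff.
import torch
from broccoli.linear import RecyclingLinear

layer = RecyclingLinear(
    in_features=64,
    out_features=32,
    row_recycling_rate=0.01,     # per-step probability of resetting an output row to the row mean
    column_recycling_rate=0.01,  # per-step probability of reinitialising an input column
    adaptive=True,               # scale rates by current LR / LR at registration time
)

optimiser = torch.optim.AdamW(layer.parameters(), lr=1e-3)
layer.register_optimiser(optimiser)  # register before any scheduler can zero the LR

layer.train()
x = torch.randn(8, 64)
y = layer(x)                 # any recycling happens inside forward() during training
loss = y.pow(2).mean()       # placeholder loss
loss.backward()
optimiser.step()
optimiser.zero_grad()
```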
broccoli/transformer.py
CHANGED
@@ -13,6 +13,7 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 try:
     from flash_attn import flash_attn_func
 
+    print("Using flash-attn.")
     FLASH_ATTN = True
 except ImportError:
     pass
@@ -324,6 +325,8 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
+        inner_dropout=None,
+        outer_dropout=None,
         linear_module_up=nn.Linear,
         linear_module_down=nn.Linear,
         pre_norm=True,
@@ -337,6 +340,7 @@ class FeedforwardBlock(nn.Module):
         self.checkpoint = checkpoint
         self.residual_path = residual_path
         self.post_norm = post_norm
+        self.xglu = activation.__name__.endswith("GLU")
 
         if self.residual_path and (output_features < input_features):
             raise ValueError(
@@ -353,29 +357,75 @@ class FeedforwardBlock(nn.Module):
         else:
             self.activation = activation()
 
-        self.
+        self.inner_dropout = nn.Dropout(
+            inner_dropout if inner_dropout is not None else dropout
+        )
+        self.outer_dropout = nn.Dropout(
+            outer_dropout if outer_dropout is not None else dropout
+        )
 
         self.max_features = (
-            2 * ratio * output_features
-            if activation.__name__.endswith("GLU")
-            else ratio * output_features
+            2 * ratio * output_features if self.xglu else ratio * output_features
         )
 
+        self.linear_in = linear_module_up(input_features, self.max_features)
+        self.linear_out = linear_module_down(ratio * output_features, output_features)
+
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-
+                self.linear_in,
                 self.activation,
+                self.inner_dropout,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-
-                self.
+                self.linear_out,
+                self.outer_dropout,
             ]
         )
 
+        self.recycling_enabled = False
+        if hasattr(self.linear_in, "row_recycling_rate") and hasattr(
+            self.linear_out, "column_recycling_rate"
+        ):
+            self.recycling_enabled = True
+            self.master_recycling_rate = self.linear_in.row_recycling_rate
+            self.linear_in.row_recycling_rate = 0.0
+            self.linear_out.column_recycling_rate = 0.0
+            if (
+                hasattr(self.linear_in, "column_recycling_rate")
+                and self.linear_in.column_recycling_rate > 0
+            ) or (
+                hasattr(self.linear_out, "row_recycling_rate")
+                and self.linear_out.row_recycling_rate > 0
+            ):
+                raise NotImplementedError(
+                    "At the moment this layer can only support recycling linear "
+                    "layers if the in layer resets only rows and the out layer "
+                    "resets only columns."
+                )
+
         self.reset_parameters()
 
     def forward(self, x):
 
+        # Recycle weights if using recycling linear layers
+        if self.training and self.recycling_enabled:
+            multiplier = self.linear_in._get_multiplier()
+            rate = self.master_recycling_rate * multiplier
+            if rate > 0:
+                probs = torch.rand(self.linear_out.in_features, device=x.device)
+                mask = probs < rate
+                if mask.any():
+                    indices = torch.nonzero(mask).squeeze(-1)
+                    self.linear_out.reset_columns(indices, self.linear_out.optimisers)
+                    if self.xglu:
+                        indices_in = torch.cat(
+                            [indices, indices + self.linear_out.in_features]
+                        )
+                        self.linear_in.reset_rows(indices_in, self.linear_in.optimisers)
+                    else:
+                        self.linear_in.reset_rows(indices, self.linear_in.optimisers)
+
         if self.checkpoint:
             processed = checkpoint(self.process, x, use_reentrant=False)
         else:
@@ -421,7 +471,9 @@ class TransformerBlock(nn.Module):
         ff_linear_module_up=None,
         ff_linear_module_down=None,
         msa_scaling="d",
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
         causal=False,
@@ -483,7 +535,9 @@ class TransformerBlock(nn.Module):
             d_model,
             activation=activation,
             activation_kwargs=activation_kwargs,
-            dropout=
+            dropout=ff_dropout,
+            inner_dropout=ff_inner_dropout,
+            outer_dropout=ff_outer_dropout,
             linear_module_up=(
                 ff_linear_module_up
                 if ff_linear_module_up is not None
@@ -566,7 +620,9 @@ class TransformerEncoder(nn.Module):
         activation_kwargs: Optional[dict] = None,
         ff_linear_module_up=None,
         ff_linear_module_down=None,
-
+        ff_dropout=0.0,
+        ff_inner_dropout=0.0,
+        ff_outer_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
         causal=False,
@@ -590,7 +646,13 @@ class TransformerEncoder(nn.Module):
         if relative_position_embedding and (source_size is None):
             raise ValueError(
                 "`source_size` for TransformerEncoder cannot be None if"
-                " `
+                " `relative_position_embedding` is True"
+            )
+
+        if absolute_position_embedding and (seq_len is None):
+            raise ValueError(
+                "`seq_len` for TransformerEncoder cannot be None if"
+                " `absolute_position_embedding` is True"
             )
 
         super().__init__()
@@ -605,9 +667,12 @@ class TransformerEncoder(nn.Module):
                 torch.empty(self._utility_tokens, d_model)
             )
             nn.init.normal_(self._utility_token_embedding, mean=0.0, std=1.0)
-            self.full_sequence_length = self.seq_len + self._utility_tokens
         else:
             self._utility_token_embedding = None
+
+        if self._utility_tokens and (self.seq_len is not None):
+            self.full_sequence_length = self.seq_len + self._utility_tokens
+        else:
             self.full_sequence_length = self.seq_len
 
         self.d_model = d_model
@@ -619,7 +684,7 @@ class TransformerEncoder(nn.Module):
         else:
             self.absolute_position_embedding = None
 
-        self.mlp_dropout =
+        self.mlp_dropout = ff_dropout
         self.msa_dropout = msa_dropout
         self.stochastic_depth = stochastic_depth
 
@@ -648,7 +713,9 @@ class TransformerEncoder(nn.Module):
                 ff_linear_module_up=ff_linear_module_up,
                 ff_linear_module_down=ff_linear_module_down,
                 msa_scaling=msa_scaling,
-
+                ff_dropout=ff_dropout,
+                ff_inner_dropout=ff_inner_dropout,
+                ff_outer_dropout=ff_outer_dropout,
                 msa_dropout=msa_dropout,
                 identity_probability=self.stochastic_depth_probabilities[i],
                 causal=causal,
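The recycling_enabled branch above keys off duck-typed attributes (row_recycling_rate on the up-projection, column_recycling_rate on the down-projection), and FeedforwardBlock instantiates its projections via linear_module_up(input_features, max_features) and linear_module_down(ratio * output_features, output_features). A hedged sketch of opting in by passing RecyclingLinear partials; only that call pattern and the attribute check are taken from this diff, the feature sizes are illustrative.

```python
# Hedged sketch: wiring RecyclingLinear into a FeedforwardBlock-style module via partials.
from functools import partial

from broccoli.linear import RecyclingLinear

# Pre-bind the recycling behaviour; the block later calls these with
# (in_features, out_features) positionally, which RecyclingLinear accepts.
recycling_up = partial(RecyclingLinear, row_recycling_rate=0.01, adaptive=True)
recycling_down = partial(RecyclingLinear, adaptive=True)

linear_in = recycling_up(128, 512)     # up-projection: rows may be recycled
linear_out = recycling_down(512, 128)  # down-projection: columns may be recycled

# Mirror of the duck-typed check FeedforwardBlock performs before enabling
# coordinated recycling of the up-projection's rows and the down-projection's columns.
recycling_enabled = hasattr(linear_in, "row_recycling_rate") and hasattr(
    linear_out, "column_recycling_rate"
)
print(recycling_enabled)  # True
```

When such modules are passed as linear_module_up/linear_module_down, the block takes the up-projection's row_recycling_rate as its master rate, zeroes the individual rates, and drives the coordinated resets itself in forward() (doubling the row indices for GLU-style activations, whose up-projection has twice the width).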
broccoli/vit.py
CHANGED
@@ -161,7 +161,9 @@ class ViTEncoder(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -178,7 +180,9 @@ class ViTEncoder(nn.Module):
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -333,7 +337,9 @@ class ViTEncoder(nn.Module):
            ff_linear_module_up=transformer_ff_linear_module_up,
            ff_linear_module_down=transformer_ff_linear_module_down,
            msa_scaling=transformer_msa_scaling,
-
+            ff_dropout=transformer_ff_dropout,
+            ff_inner_dropout=transformer_ff_inner_dropout,
+            ff_outer_dropout=transformer_ff_outer_dropout,
            msa_dropout=transformer_msa_dropout,
            stochastic_depth=transformer_stochastic_depth,
            causal=False,
@@ -357,9 +363,21 @@ class ViTEncoder(nn.Module):
            activation_kwargs=transformer_activation_kwargs,
            dropout=(
                # First truthy assigned value
-
-                if
-                else
+                transformer_initial_ff_dropout
+                if transformer_initial_ff_dropout is not None
+                else transformer_ff_dropout
+            ),
+            inner_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_inner_dropout
+                if transformer_initial_ff_inner_dropout is not None
+                else transformer_ff_inner_dropout
+            ),
+            outer_dropout=(
+                # First truthy assigned value
+                transformer_initial_ff_outer_dropout
+                if transformer_initial_ff_outer_dropout is not None
+                else transformer_ff_outer_dropout
            ),
            linear_module_up=(
                # First truthy assigned value
@@ -441,7 +459,9 @@ class ViT(nn.Module):
         transformer_initial_ff_residual_path=True,
         transformer_initial_ff_linear_module_up=None,
         transformer_initial_ff_linear_module_down=None,
-
+        transformer_initial_ff_dropout=None,
+        transformer_initial_ff_inner_dropout=None,
+        transformer_initial_ff_outer_dropout=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
         transformer_post_norm=False,
@@ -458,7 +478,9 @@ class ViT(nn.Module):
         transformer_ff_linear_module_up=None,
         transformer_ff_linear_module_down=None,
         transformer_msa_scaling="d",
-
+        transformer_ff_dropout=0.0,
+        transformer_ff_inner_dropout=0.0,
+        transformer_ff_outer_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
         transformer_checkpoint_ff=True,
@@ -508,7 +530,9 @@ class ViT(nn.Module):
            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
-
+            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
+            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
+            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
            transformer_pre_norm=transformer_pre_norm,
            transformer_normformer=transformer_normformer,
            transformer_post_norm=transformer_post_norm,
@@ -525,7 +549,9 @@ class ViT(nn.Module):
            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
            transformer_ff_linear_module_down=transformer_ff_linear_module_down,
            transformer_msa_scaling=transformer_msa_scaling,
-
+            transformer_ff_dropout=transformer_ff_dropout,
+            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
+            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
            transformer_msa_dropout=transformer_msa_dropout,
            transformer_stochastic_depth=transformer_stochastic_depth,
            transformer_checkpoint_ff=transformer_checkpoint_ff,
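The ViTEncoder plumbing above resolves the initial feedforward block's dropout rates with a "use the initial-specific value unless it is None" rule; note the check is `is not None`, so an explicit 0.0 still overrides the shared rate. A minimal standalone sketch of that precedence, with illustrative values only:

```python
# Hedged sketch of the fallback rule used for the initial feedforward block's dropout.
transformer_ff_dropout = 0.1
transformer_ff_inner_dropout = 0.2
transformer_initial_ff_dropout = None        # not set -> inherit the shared rate
transformer_initial_ff_inner_dropout = 0.0   # explicitly set -> overrides, even though falsy

dropout = (
    transformer_initial_ff_dropout
    if transformer_initial_ff_dropout is not None
    else transformer_ff_dropout
)
inner_dropout = (
    transformer_initial_ff_inner_dropout
    if transformer_initial_ff_inner_dropout is not None
    else transformer_ff_inner_dropout
)

print(dropout, inner_dropout)  # 0.1 0.0
```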
{broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/RECORD
CHANGED

@@ -1,13 +1,13 @@
 broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
 broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
 broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
-broccoli/linear.py,sha256=
+broccoli/linear.py,sha256=IwvPAMbHqOYyz0g5WZyevPAhC1Pn0RTLniFM4E6lJoI,11511
 broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=r-ggAeNDW5QpBi9As1U9sIfxITBOx0WHk_K4zWpyTM8,26233
 broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
-broccoli/vit.py,sha256=
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
-broccoli_ml-
+broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
+broccoli_ml-9.2.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-9.2.2.dist-info/METADATA,sha256=8ySQYntl9czgYyEQN5nyPS31tjwYC8M8Mx_iYhvtbzg,1368
+broccoli_ml-9.2.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-9.2.2.dist-info/RECORD,,

{broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/LICENSE
File without changes

{broccoli_ml-6.0.0.dist-info → broccoli_ml-9.2.2.dist-info}/WHEEL
File without changes