broccoli-ml: broccoli_ml-0.37.0-py3-none-any.whl → broccoli_ml-0.38.0-py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- broccoli/linear.py +1 -2
- broccoli/transformer.py +16 -47
- broccoli/vit.py +14 -4
- {broccoli_ml-0.37.0.dist-info → broccoli_ml-0.38.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.37.0.dist-info → broccoli_ml-0.38.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.37.0.dist-info → broccoli_ml-0.38.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.37.0.dist-info → broccoli_ml-0.38.0.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED
@@ -131,8 +131,7 @@ class WeightNormedLinear(nn.Module):
|
|
131
131
|
return F.linear(x, self.weights(), self.bias)
|
132
132
|
|
133
133
|
def __repr__(self) -> str:
|
134
|
-
# Optional: A nice representation for printing the module.
|
135
134
|
return (
|
136
|
-
f"
|
135
|
+
f"WeightNormedLinear(in_features={self.in_features},"
|
137
136
|
f"out_features={self.out_features}, bias={self.use_bias})"
|
138
137
|
)
|
broccoli/transformer.py
CHANGED
@@ -303,6 +303,8 @@ class TransformerBlock(nn.Module):
|
|
303
303
|
mlp_ratio=4,
|
304
304
|
activation: nn.Module = nn.ReLU,
|
305
305
|
activation_kwargs: Optional[dict] = None,
|
306
|
+
ff_linear_module_up=None,
|
307
|
+
ff_linear_module_down=None,
|
306
308
|
mlp_dropout=0.0,
|
307
309
|
msa_dropout=0.0,
|
308
310
|
identity_probability=0.0,
|
@@ -356,8 +358,16 @@ class TransformerBlock(nn.Module):
|
|
356
358
|
activation=activation,
|
357
359
|
activation_kwargs=activation_kwargs,
|
358
360
|
dropout=mlp_dropout,
|
359
|
-
linear_module_up=
|
360
|
-
|
361
|
+
linear_module_up=(
|
362
|
+
ff_linear_module_up
|
363
|
+
if ff_linear_module_up is not None
|
364
|
+
else linear_module
|
365
|
+
),
|
366
|
+
linear_module_down=(
|
367
|
+
ff_linear_module_down
|
368
|
+
if ff_linear_module_down is not None
|
369
|
+
else linear_module
|
370
|
+
),
|
361
371
|
pre_norm=False, # Handled outside the block
|
362
372
|
normformer=normformer,
|
363
373
|
post_norm=False, # Handled outside the block
|
@@ -389,51 +399,6 @@ class TransformerBlock(nn.Module):
|
|
389
399
|
|
390
400
|
return x
|
391
401
|
|
392
|
-
# if not self.training:
|
393
|
-
# identity_probability = 0.0
|
394
|
-
# else:
|
395
|
-
# identity_probability = self.identity_probability
|
396
|
-
|
397
|
-
# if random.random() < identity_probability:
|
398
|
-
# return x
|
399
|
-
# else:
|
400
|
-
# ...
|
401
|
-
|
402
|
-
# # perform the identity operation for some rows in the batch
|
403
|
-
# dist = torch.distributions.Binomial(x.size(0), identity_probability)
|
404
|
-
# identity_count = int(dist.sample().item())
|
405
|
-
|
406
|
-
# shuffle_indices = torch.randperm(x.size(0), device=x.device)
|
407
|
-
# unshuffle_indices = torch.argsort(shuffle_indices)
|
408
|
-
# shuffled = x[shuffle_indices, :, :]
|
409
|
-
# norm_shuffled = self.layer_norm_1(shuffled)
|
410
|
-
# identity_x = shuffled[:identity_count, :, :]
|
411
|
-
# process_x = shuffled[identity_count:, :, :]
|
412
|
-
# residual = process_x
|
413
|
-
|
414
|
-
# if self.pre_norm:
|
415
|
-
# process_x = norm_shuffled[identity_count:, :, :]
|
416
|
-
|
417
|
-
# process_x = residual + self.attn(process_x, process_x, process_x)
|
418
|
-
# residual = process_x
|
419
|
-
|
420
|
-
# shuffled = torch.cat([identity_x, process_x])
|
421
|
-
# norm_shuffled = self.layer_norm_2(shuffled)
|
422
|
-
|
423
|
-
# if self.pre_norm:
|
424
|
-
# residual = process_x # residual NOT normed
|
425
|
-
# process_x = norm_shuffled[identity_count:, :, :]
|
426
|
-
|
427
|
-
# if self.post_norm:
|
428
|
-
# process_x = norm_shuffled[identity_count:, :, :]
|
429
|
-
# residual = process_x # residual normed
|
430
|
-
|
431
|
-
# process_x = residual + self.ff(process_x) # handles residual connection
|
432
|
-
|
433
|
-
# x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
|
434
|
-
|
435
|
-
# return x if not self.post_norm else self.layer_norm_3(x)
|
436
|
-
|
437
402
|
|
438
403
|
class TransformerEncoder(nn.Module):
|
439
404
|
"""
|
@@ -452,6 +417,8 @@ class TransformerEncoder(nn.Module):
|
|
452
417
|
mlp_ratio=4,
|
453
418
|
activation: nn.Module = nn.ReLU,
|
454
419
|
activation_kwargs: Optional[dict] = None,
|
420
|
+
ff_linear_module_up=None,
|
421
|
+
ff_linear_module_down=None,
|
455
422
|
mlp_dropout=0.0,
|
456
423
|
msa_dropout=0.0,
|
457
424
|
stochastic_depth=0.0,
|
@@ -515,6 +482,8 @@ class TransformerEncoder(nn.Module):
|
|
515
482
|
mlp_ratio=mlp_ratio,
|
516
483
|
activation=activation,
|
517
484
|
activation_kwargs=activation_kwargs,
|
485
|
+
ff_linear_module_up=ff_linear_module_up,
|
486
|
+
ff_linear_module_down=ff_linear_module_down,
|
518
487
|
mlp_dropout=mlp_dropout,
|
519
488
|
msa_dropout=msa_dropout,
|
520
489
|
identity_probability=self.stochastic_depth_probabilities[i],
|
broccoli/vit.py
CHANGED
@@ -132,6 +132,8 @@ class ViTEncoder(nn.Module):
|
|
132
132
|
transformer_return_bos_tokens=False,
|
133
133
|
transformer_activation: nn.Module = SquaredReLU,
|
134
134
|
transformer_activation_kwargs: Optional[dict] = None,
|
135
|
+
transformer_ff_linear_module_up=None,
|
136
|
+
transformer_ff_linear_module_down=None,
|
135
137
|
transformer_mlp_dropout=0.0,
|
136
138
|
transformer_msa_dropout=0.1,
|
137
139
|
transformer_stochastic_depth=0.1,
|
@@ -282,6 +284,8 @@ class ViTEncoder(nn.Module):
|
|
282
284
|
mlp_ratio=transformer_mlp_ratio,
|
283
285
|
activation=transformer_activation,
|
284
286
|
activation_kwargs=transformer_activation_kwargs,
|
287
|
+
ff_linear_module_up=transformer_ff_linear_module_up,
|
288
|
+
ff_linear_module_down=transformer_ff_linear_module_down,
|
285
289
|
mlp_dropout=transformer_mlp_dropout,
|
286
290
|
msa_dropout=transformer_msa_dropout,
|
287
291
|
stochastic_depth=transformer_stochastic_depth,
|
@@ -305,14 +309,16 @@ class ViTEncoder(nn.Module):
|
|
305
309
|
activation_kwargs=transformer_activation_kwargs,
|
306
310
|
dropout=transformer_mlp_dropout,
|
307
311
|
linear_module_up=(
|
312
|
+
# First truthy assigned value
|
308
313
|
transformer_initial_ff_linear_module_up
|
309
|
-
|
310
|
-
|
314
|
+
or transformer_ff_linear_module_up
|
315
|
+
or linear_module
|
311
316
|
),
|
312
317
|
linear_module_down=(
|
318
|
+
# First truthy assigned value
|
313
319
|
transformer_initial_ff_linear_module_down
|
314
|
-
|
315
|
-
|
320
|
+
or transformer_ff_linear_module_down
|
321
|
+
or linear_module
|
316
322
|
),
|
317
323
|
pre_norm=transformer_pre_norm,
|
318
324
|
normformer=transformer_normformer,
|
@@ -386,6 +392,8 @@ class ViT(nn.Module):
|
|
386
392
|
transformer_return_bos_tokens=False,
|
387
393
|
transformer_activation: nn.Module = SquaredReLU,
|
388
394
|
transformer_activation_kwargs: Optional[dict] = None,
|
395
|
+
transformer_ff_linear_module_up=None,
|
396
|
+
transformer_ff_linear_module_down=None,
|
389
397
|
transformer_mlp_dropout=0.0,
|
390
398
|
transformer_msa_dropout=0.1,
|
391
399
|
transformer_stochastic_depth=0.1,
|
@@ -446,6 +454,8 @@ class ViT(nn.Module):
|
|
446
454
|
transformer_return_bos_tokens=transformer_return_bos_tokens,
|
447
455
|
transformer_activation=transformer_activation,
|
448
456
|
transformer_activation_kwargs=transformer_activation_kwargs,
|
457
|
+
transformer_ff_linear_module_up=transformer_ff_linear_module_up,
|
458
|
+
transformer_ff_linear_module_down=transformer_ff_linear_module_down,
|
449
459
|
transformer_mlp_dropout=transformer_mlp_dropout,
|
450
460
|
transformer_msa_dropout=transformer_msa_dropout,
|
451
461
|
transformer_stochastic_depth=transformer_stochastic_depth,
|
@@ -5,13 +5,13 @@ broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYV
|
|
5
5
|
broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
|
6
6
|
broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
|
7
7
|
broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
8
|
-
broccoli/linear.py,sha256=
|
8
|
+
broccoli/linear.py,sha256=zlFDij9TngqDzTTpUlZtX0PXAQgxyWWnTIXpiO1rBk0,4768
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=Xw1oBLsvVeHmMqgurhorRa49nrjfooLev5uBPFeK9og,17004
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=RaSJZh2ogqNAvAkQDuZpNqdtRTWaW_8ug4BsBCBK_f4,17728
|
14
|
+
broccoli_ml-0.38.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.38.0.dist-info/METADATA,sha256=8lcwJvmlPT_-SkJZ5Qd6nFc4Njz7mQD-JefeUxhk8Xw,1257
|
16
|
+
broccoli_ml-0.38.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.38.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|