broccoli-ml 0.36.0__py3-none-any.whl → 0.38.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/linear.py +1 -2
- broccoli/transformer.py +74 -31
- broccoli/vit.py +14 -4
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/WHEEL +0 -0
broccoli/linear.py
CHANGED
@@ -131,8 +131,7 @@ class WeightNormedLinear(nn.Module):
         return F.linear(x, self.weights(), self.bias)

     def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
         return (
-            f"
+            f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
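The practical effect of the rewritten __repr__ is easiest to see with concrete values substituted into the same f-strings. A quick standalone sketch (the attribute values are illustrative, not taken from the package):

in_features, out_features, use_bias = 64, 128, True
print(
    f"WeightNormedLinear(in_features={in_features},"
    f"out_features={out_features}, bias={use_bias})"
)
# -> WeightNormedLinear(in_features=64,out_features=128, bias=True)
# Note: the adjacent f-string literals concatenate directly, so there is no
# space after the first comma in the printed repr.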
broccoli/transformer.py
CHANGED
@@ -13,6 +13,45 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 from .linear import AnchoredLinear, SpectralNormLinear


+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
 class MHAttention(nn.Module):
     """
     Multi-head self-attention using einops and optionally a custom linear layer.
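The new drop_path / DropPath pair (vendored from timm) implements per-sample stochastic depth: during training each sample's branch output is zeroed with probability drop_prob and the survivors are rescaled by 1/keep_prob so the expectation is unchanged; in eval mode, or with drop_prob=0, the call is a no-op. A minimal sketch of the effect, with an illustrative batch and drop probability:

import torch

x = torch.ones(4, 3, 8)                              # (batch, tokens, channels)
keep_prob = 0.5                                      # i.e. drop_prob = 0.5
# One Bernoulli draw per sample, broadcast over the remaining dims,
# mirroring the shape logic in drop_path above.
mask = x.new_empty((4, 1, 1)).bernoulli_(keep_prob)
out = x * mask.div(keep_prob)
# Dropped samples in `out` are all zeros; surviving ones are scaled to 2.0,
# so the per-element expectation is still 1.0.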
@@ -264,6 +303,8 @@ class TransformerBlock(nn.Module):
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
         mlp_dropout=0.0,
         msa_dropout=0.0,
         identity_probability=0.0,
@@ -279,10 +320,11 @@ class TransformerBlock(nn.Module):
         self.post_norm = post_norm
         self.normformer = normformer

-        self.
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)

         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -316,12 +358,20 @@ class TransformerBlock(nn.Module):
             activation=activation,
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
-            linear_module_up=
-
-
+            linear_module_up=(
+                ff_linear_module_up
+                if ff_linear_module_up is not None
+                else linear_module
+            ),
+            linear_module_down=(
+                ff_linear_module_down
+                if ff_linear_module_down is not None
+                else linear_module
+            ),
+            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=
-            residual_path=
+            post_norm=False,  # Handled outside the block
+            residual_path=False,  # Handled outside the block
         )

     @property
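With ff_linear_module_up / ff_linear_module_down, the feed-forward sub-block can use different linear layers for its up- and down-projections, falling back to the block-wide linear_module when an argument is left as None; its own pre_norm / post_norm / residual handling is switched off because the surrounding TransformerBlock now owns those. A standalone sketch of the fallback rule used above (the helper name is illustrative, not part of the package):

from typing import Optional, Type

import torch.nn as nn

def resolve_ff_module(
    override: Optional[Type[nn.Module]], default: Type[nn.Module]
) -> Type[nn.Module]:
    # Explicit `is not None` check: only a missing override falls back to
    # the default, matching the conditional expressions in the hunk above.
    return override if override is not None else default

assert resolve_ff_module(None, nn.Linear) is nn.Linear
assert resolve_ff_module(nn.LazyLinear, nn.Linear) is nn.LazyLinear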
@@ -329,34 +379,23 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
-        if not self.training:
-            identity_probability = 0.0
-        else:
-            identity_probability = self.identity_probability
-
-        # perform the identity operation for some rows in the batch
-        dist = torch.distributions.Binomial(x.size(0), identity_probability)
-        identity_count = int(dist.sample().item())
-
-        shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]
-
-        residual_x = process_x

         if self.pre_norm:
-
-
-
-
-
-
-
-
+            normx = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(normx, normx, normx))
+            normx = self.layer_norm_2(x)
+            x = x + self.drop_path(self.ff(normx))
+        elif self.post_norm:
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.ff(x))
+            x = self.layer_norm_2(x)
+        else:
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.ff(x))

-
+        if self.pre_norm and self.post_norm:
+            x = self.layer_norm_3(x)

         return x

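The rewritten forward drops the shuffle-and-split identity bookkeeping in favour of wrapping each residual branch in DropPath, and selects between pre-norm, post-norm, or norm-free residual arrangements (layer_norm_3 is applied once at the end only when both flags are set). The key equivalence is that a sample whose DropPath mask is zero passes through the residual add unchanged, which is the per-sample identity the removed code emulated by routing rows around the sub-block. A toy illustration with a stand-in branch module:

import torch
import torch.nn as nn

branch = nn.Linear(8, 8)   # stand-in for self.attn(normx, normx, normx) or self.ff(normx)
x = torch.randn(4, 3, 8)

# An illustrative per-sample mask of the kind drop_path produces.
mask = torch.tensor([0.0, 1.0, 1.0, 0.0]).view(4, 1, 1)
out = x + branch(x) * mask
# Samples 0 and 3 were "dropped", so the residual update is the identity for them.
assert torch.allclose(out[0], x[0]) and torch.allclose(out[3], x[3])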
@@ -378,6 +417,8 @@ class TransformerEncoder(nn.Module):
         mlp_ratio=4,
         activation: nn.Module = nn.ReLU,
         activation_kwargs: Optional[dict] = None,
+        ff_linear_module_up=None,
+        ff_linear_module_down=None,
         mlp_dropout=0.0,
         msa_dropout=0.0,
         stochastic_depth=0.0,
@@ -441,6 +482,8 @@ class TransformerEncoder(nn.Module):
             mlp_ratio=mlp_ratio,
             activation=activation,
             activation_kwargs=activation_kwargs,
+            ff_linear_module_up=ff_linear_module_up,
+            ff_linear_module_down=ff_linear_module_down,
             mlp_dropout=mlp_dropout,
             msa_dropout=msa_dropout,
             identity_probability=self.stochastic_depth_probabilities[i],
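Each block receives its own identity_probability from self.stochastic_depth_probabilities[i]. How that list is built is not part of this diff; a common choice in stochastic-depth implementations, shown here purely as an assumption, is a linear ramp from 0 at the first block up to the configured stochastic_depth at the last:

# Assumption only: a linear stochastic-depth schedule; the actual computation
# inside TransformerEncoder is not shown in this diff.
def linear_stochastic_depth_schedule(depth: int, stochastic_depth: float) -> list:
    if depth <= 1:
        return [0.0] * depth
    return [stochastic_depth * i / (depth - 1) for i in range(depth)]

# linear_stochastic_depth_schedule(4, 0.1) -> [0.0, 0.0333..., 0.0666..., 0.1]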
broccoli/vit.py
CHANGED
@@ -132,6 +132,8 @@ class ViTEncoder(nn.Module):
         transformer_return_bos_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
+        transformer_ff_linear_module_up=None,
+        transformer_ff_linear_module_down=None,
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -282,6 +284,8 @@ class ViTEncoder(nn.Module):
             mlp_ratio=transformer_mlp_ratio,
             activation=transformer_activation,
             activation_kwargs=transformer_activation_kwargs,
+            ff_linear_module_up=transformer_ff_linear_module_up,
+            ff_linear_module_down=transformer_ff_linear_module_down,
             mlp_dropout=transformer_mlp_dropout,
             msa_dropout=transformer_msa_dropout,
             stochastic_depth=transformer_stochastic_depth,
@@ -305,14 +309,16 @@ class ViTEncoder(nn.Module):
             activation_kwargs=transformer_activation_kwargs,
             dropout=transformer_mlp_dropout,
             linear_module_up=(
+                # First truthy assigned value
                 transformer_initial_ff_linear_module_up
-
-
+                or transformer_ff_linear_module_up
+                or linear_module
             ),
             linear_module_down=(
+                # First truthy assigned value
                 transformer_initial_ff_linear_module_down
-
-
+                or transformer_ff_linear_module_down
+                or linear_module
             ),
             pre_norm=transformer_pre_norm,
             normformer=transformer_normformer,
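For the initial feed-forward block the fallback chain is written with or, so the first truthy value among the initial override, the transformer-wide override, and linear_module wins. This differs subtly from the `is not None` checks inside TransformerBlock, which only skip None. A small sketch of the or-chain (variable names are illustrative):

import torch.nn as nn

initial, shared, default = None, None, nn.Linear

chosen = initial or shared or default        # first truthy value wins
assert chosen is nn.Linear

shared = nn.LazyLinear                       # a transformer-wide override is set
assert (initial or shared or default) is nn.LazyLinear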
@@ -386,6 +392,8 @@ class ViT(nn.Module):
         transformer_return_bos_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
+        transformer_ff_linear_module_up=None,
+        transformer_ff_linear_module_down=None,
         transformer_mlp_dropout=0.0,
         transformer_msa_dropout=0.1,
         transformer_stochastic_depth=0.1,
@@ -446,6 +454,8 @@ class ViT(nn.Module):
             transformer_return_bos_tokens=transformer_return_bos_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
+            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
+            transformer_ff_linear_module_down=transformer_ff_linear_module_down,
             transformer_mlp_dropout=transformer_mlp_dropout,
             transformer_msa_dropout=transformer_msa_dropout,
             transformer_stochastic_depth=transformer_stochastic_depth,
{broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/RECORD
CHANGED
@@ -5,13 +5,13 @@ broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYV
 broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
 broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
-broccoli/linear.py,sha256=
+broccoli/linear.py,sha256=zlFDij9TngqDzTTpUlZtX0PXAQgxyWWnTIXpiO1rBk0,4768
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=Xw1oBLsvVeHmMqgurhorRa49nrjfooLev5uBPFeK9og,17004
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
+broccoli/vit.py,sha256=RaSJZh2ogqNAvAkQDuZpNqdtRTWaW_8ug4BsBCBK_f4,17728
+broccoli_ml-0.38.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.38.0.dist-info/METADATA,sha256=8lcwJvmlPT_-SkJZ5Qd6nFc4Njz7mQD-JefeUxhk8Xw,1257
+broccoli_ml-0.38.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.38.0.dist-info/RECORD,,
{broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/LICENSE
File without changes
{broccoli_ml-0.36.0.dist-info → broccoli_ml-0.38.0.dist-info}/WHEEL
File without changes