broccoli-ml 0.35.1__py3-none-any.whl → 0.37.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +100 -63
- broccoli/vit.py +1 -7
- {broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/RECORD +6 -6
- {broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -13,6 +13,45 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 from .linear import AnchoredLinear, SpectralNormLinear


+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
 class MHAttention(nn.Module):
     """
     Multi-head self-attention using einops and optionally a custom linear layer.
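Note: the `drop_path` / `DropPath` code vendored above (from timm) implements stochastic depth: during training, each sample in the batch is zeroed with probability `drop_prob` and the survivors are rescaled by `1 / keep_prob`; in eval mode it is a no-op. A minimal usage sketch (tensor shapes are illustrative, not taken from the package):

```python
import torch
from broccoli.transformer import DropPath  # class added in 0.37.0, shown above

dp = DropPath(drop_prob=0.25, scale_by_keep=True)
branch_output = torch.randn(8, 16, 64)  # (batch, tokens, d_model) -- assumed shapes

dp.train()
out = dp(branch_output)   # roughly 25% of the samples are zeroed out,
                          # the rest are scaled by 1 / 0.75

dp.eval()
assert torch.equal(dp(branch_output), branch_output)  # identity at inference
```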
@@ -21,45 +60,6 @@ class MHAttention(nn.Module):
     are the same shape.

     Assumes bias=False and batch_first=True, as God intended.
-
-    Optionally adds various bells and whistles suggested in the
-    literature, including:
-
-    Noam Shazeer's scaled attention per "Attention is All You Need"
-    (https://arxiv.org/abs/1706.03762).
-
-    Max subtract softmax as discussed in "Attention As An RNN"
-    (https://arxiv.org/abs/2405.13956)
-
-    Log-length scaled softmax per "Overcoming a Theoretical Limitation of
-    Self-Attention" (https://arxiv.org/abs/2202.12172).
-
-    Quiet softmax per
-    https://www.evanmiller.org/attention-is-off-by-one.html
-
-    Args:
-        d_model: ...
-        n_heads: ...
-        dropout: ...
-        causal: should a causal mask be applied to the logits before attention
-            is applied? This is standard when using self-attention. Cannot be
-            True if inputs won't be square (e.g. if sequence length for
-            encoder and decoder are different)
-        sequence_length: ...
-        share_kv: ...
-        linear_module: ...
-        max_subtract: if True, the maximum logit value is subtracted from all
-            logits before performing the softmax operation to create a more
-            numerically stable softmax. This is discussed in "Attention As An
-            RNN" (https://arxiv.org/abs/2405.13956).
-        d_model_scale: ...
-        log_length_scale: if True, multiplies logits by the log length of
-            the decoder sequence before performing the softmax operation, as
-            proposed in "Overcoming a Theoretical Limitation of Self-Attention"
-            (https://arxiv.org/abs/2202.12172).
-        quiet: if True, adds 1 to the denominator of the softmax operation,
-            allowing some tokens to attend to no other tokens as described in
-            https://www.evanmiller.org/attention-is-off-by-one.html.
     """

     def __init__(
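Note: the docstring removed above described optional softmax variants (max-subtract, log-length scaling, "quiet" softmax) applied to the attention logits. For reference, those variants can be sketched as a standalone function; this is an illustrative sketch, not the package's actual MHAttention implementation:

```python
import math
import torch

def softmax_variants(logits, max_subtract=False, log_length_scale=False, quiet=False):
    """Sketch of the attention-softmax tweaks described in the removed docstring.

    logits: (..., query_len, key_len) attention scores before softmax.
    """
    if log_length_scale:
        # Scale logits by the log of the key sequence length, per arXiv:2202.12172.
        logits = logits * math.log(logits.size(-1))
    if max_subtract:
        # Subtract the per-row maximum for numerical stability (arXiv:2405.13956).
        logits = logits - logits.max(dim=-1, keepdim=True).values
    exp = torch.exp(logits)
    denom = exp.sum(dim=-1, keepdim=True)
    if quiet:
        # "Quiet" softmax: +1 in the denominator lets a query attend to nothing
        # (https://www.evanmiller.org/attention-is-off-by-one.html).
        denom = denom + 1
    return exp / denom
```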
@@ -280,7 +280,7 @@ class FeedforwardBlock(nn.Module):
         elif self.residual_path:
             return x + self.process(x)
         else:
-            return x
+            return self.process(x)


 class TransformerBlock(nn.Module):
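Note: this one-line change fixes the fall-through branch of `FeedforwardBlock.forward`: with `residual_path=False` (and no pre/post-norm wiring) the block previously returned its input untouched, silently skipping the MLP; it now applies the projection without a residual connection. A hypothetical stand-in for the block's internal `process` MLP illustrates the behavioural difference:

```python
import torch
from torch import nn

# Hypothetical stand-in for FeedforwardBlock's `process` MLP, purely to
# illustrate the changed branch; not the package's own module.
process = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 16, 64)

out_0_35_1 = x           # old behaviour: `return x`, the MLP never ran
out_0_37_0 = process(x)  # new behaviour: `return self.process(x)`, no residual added
```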
@@ -318,10 +318,11 @@ class TransformerBlock(nn.Module):
         self.post_norm = post_norm
         self.normformer = normformer

-        self.
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)

         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2) # Suggested by Gemini!
@@ -357,10 +358,10 @@ class TransformerBlock(nn.Module):
             dropout=mlp_dropout,
             linear_module_up=linear_module,
             linear_module_down=linear_module,
-            pre_norm=
+            pre_norm=False, # Handled outside the block
             normformer=normformer,
-            post_norm=
-            residual_path=
+            post_norm=False, # Handled outside the block
+            residual_path=False, # Handled outside the block
         )

     @property
@@ -368,34 +369,70 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance

     def forward(self, x):
-
-
+
+        if self.pre_norm:
+            normx = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(normx, normx, normx))
+            normx = self.layer_norm_2(x)
+            x = x + self.drop_path(self.ff(normx))
+        elif self.post_norm:
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.ff(x))
+            x = self.layer_norm_2(x)
         else:
-
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.ff(x))

-
-
-        shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]
+        if self.pre_norm and self.post_norm:
+            x = self.layer_norm_3(x)

-
+        return x

-
-
+        # if not self.training:
+        #     identity_probability = 0.0
+        # else:
+        #     identity_probability = self.identity_probability

-
+        # if random.random() < identity_probability:
+        #     return x
+        # else:
+        #     ...

-
-
+        # # perform the identity operation for some rows in the batch
+        # dist = torch.distributions.Binomial(x.size(0), identity_probability)
+        # identity_count = int(dist.sample().item())

-
+        # shuffle_indices = torch.randperm(x.size(0), device=x.device)
+        # unshuffle_indices = torch.argsort(shuffle_indices)
+        # shuffled = x[shuffle_indices, :, :]
+        # norm_shuffled = self.layer_norm_1(shuffled)
+        # identity_x = shuffled[:identity_count, :, :]
+        # process_x = shuffled[identity_count:, :, :]
+        # residual = process_x

-
+        # if self.pre_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]

-
+        # process_x = residual + self.attn(process_x, process_x, process_x)
+        # residual = process_x
+
+        # shuffled = torch.cat([identity_x, process_x])
+        # norm_shuffled = self.layer_norm_2(shuffled)
+
+        # if self.pre_norm:
+        #     residual = process_x # residual NOT normed
+        #     process_x = norm_shuffled[identity_count:, :, :]
+
+        # if self.post_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
+        #     residual = process_x # residual normed
+
+        # process_x = residual + self.ff(process_x) # handles residual connection
+
+        # x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+
+        # return x if not self.post_norm else self.layer_norm_3(x)


 class TransformerEncoder(nn.Module):
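Note: taken together, these hunks move all residual, normalisation and stochastic-depth handling into `TransformerBlock.forward` (the inner `FeedforwardBlock` is now constructed with `pre_norm`, `post_norm` and `residual_path` disabled, and the old batch-shuffling identity trick is left commented out). Stripped of the class, the three arrangements in the new forward pass look roughly like this sketch, where `attn`, `ff`, the layer norms and `drop_path` stand for the block's submodules (the function name is illustrative):

```python
def transformer_block_forward(x, attn, ff, norm1, norm2, norm3, drop_path,
                              pre_norm=True, post_norm=False):
    # Mirrors the forward pass added in 0.37.0 (illustrative sketch only).
    if pre_norm:
        normx = norm1(x)
        x = x + drop_path(attn(normx, normx, normx))  # pre-norm attention branch
        normx = norm2(x)
        x = x + drop_path(ff(normx))                  # pre-norm feed-forward branch
    elif post_norm:
        x = x + drop_path(attn(x, x, x))              # post-norm attention branch
        x = norm1(x)
        x = x + drop_path(ff(x))                      # post-norm feed-forward branch
        x = norm2(x)
    else:
        x = x + drop_path(attn(x, x, x))              # plain residual branches
        x = x + drop_path(ff(x))

    if pre_norm and post_norm:
        x = norm3(x)  # extra final LayerNorm when both flags are set
    return x
```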
broccoli/vit.py
CHANGED
@@ -236,13 +236,7 @@ class ViTEncoder(nn.Module):

         if pooling_type is None:
             pooling_out_channels = cnn_activation_out_channels
-            self.pool = nn.
-                *[
-                    Rearrange(
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ), # for transformer
-                ]
-            )
+            self.pool = nn.Identity()

         elif pooling_type == "max":
             pooling_out_channels = cnn_activation_out_channels
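Note: with `pooling_type=None`, the pooling stage of `ViTEncoder` is now a plain pass-through: `nn.Identity()` returns its input unchanged, replacing a `Sequential` whose only member was an einops `Rearrange` that flattened the spatial dimensions for the transformer. For example (shapes illustrative):

```python
import torch
from torch import nn

pool = nn.Identity()
feature_map = torch.randn(2, 32, 14, 14)  # (N, C, H, W) -- assumed shapes
assert pool(feature_map) is feature_map   # nn.Identity is a pure pass-through
```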
{broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=hhembQe9tEVNZMRtgbdGEsHWaBXSl95h_RpDhFde030,18171
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
+broccoli/vit.py,sha256=05xqIw9xvE5easXcp4wIA1jQ0xUyRIq6h0ZDtbitXi4,17184
+broccoli_ml-0.37.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.37.0.dist-info/METADATA,sha256=jUDSeLfYphtaOGvJn64v3deZw1nmKn4VYc7PO69BSPk,1257
+broccoli_ml-0.37.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.37.0.dist-info/RECORD,,
{broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/LICENSE
File without changes
{broccoli_ml-0.35.1.dist-info → broccoli_ml-0.37.0.dist-info}/WHEEL
File without changes