broccoli-ml 0.36.0-py3-none-any.whl → 0.37.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +98 -24
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/RECORD +5 -5
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -13,6 +13,45 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 from .linear import AnchoredLinear, SpectralNormLinear
 
 
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
 class MHAttention(nn.Module):
     """
     Multi-head self-attention using einops and optionally a custom linear layer.
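For context on the hunk above: `drop_path` implements stochastic depth, zeroing an entire residual branch for a random subset of samples during training and rescaling the survivors by 1 / keep_prob so the expected activation is unchanged. A minimal usage sketch (assumes torch is installed and the 0.37.0 module layout; shapes and values are illustrative):

import torch

from broccoli.transformer import DropPath  # added in this release

torch.manual_seed(0)

dp = DropPath(drop_prob=0.25, scale_by_keep=True)
dp.train()  # drop_path is a no-op in eval mode or when drop_prob == 0.0

x = torch.ones(8, 16, 64)  # (batch, tokens, d_model)
out = dp(x)

# Each sample is either zeroed entirely or scaled by 1 / keep_prob = 4/3,
# so E[out] == x; out[:, 0, 0] is a mix of 0.0000 and 1.3333 values.
print(out[:, 0, 0])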
@@ -279,10 +318,11 @@ class TransformerBlock(nn.Module):
         self.post_norm = post_norm
         self.normformer = normformer
 
-        self.
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -318,10 +358,10 @@ class TransformerBlock(nn.Module):
             dropout=mlp_dropout,
             linear_module_up=linear_module,
             linear_module_down=linear_module,
-            pre_norm=
+            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=
-            residual_path=
+            post_norm=False,  # Handled outside the block
+            residual_path=False,  # Handled outside the block
         )
 
     @property
@@ -329,36 +369,70 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
-
-
+
+        if self.pre_norm:
+            normx = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(normx, normx, normx))
+            normx = self.layer_norm_2(x)
+            x = x + self.drop_path(self.ff(normx))
+        elif self.post_norm:
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.ff(x))
+            x = self.layer_norm_2(x)
         else:
-
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.ff(x))
 
-
-
-        identity_count = int(dist.sample().item())
+        if self.pre_norm and self.post_norm:
+            x = self.layer_norm_3(x)
 
-
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]
+        return x
 
-
+        # if not self.training:
+        #     identity_probability = 0.0
+        # else:
+        #     identity_probability = self.identity_probability
 
-
-
+        # if random.random() < identity_probability:
+        #     return x
+        # else:
+        #     ...
 
-
+        # # perform the identity operation for some rows in the batch
+        # dist = torch.distributions.Binomial(x.size(0), identity_probability)
+        # identity_count = int(dist.sample().item())
 
-
-
+        # shuffle_indices = torch.randperm(x.size(0), device=x.device)
+        # unshuffle_indices = torch.argsort(shuffle_indices)
+        # shuffled = x[shuffle_indices, :, :]
+        # norm_shuffled = self.layer_norm_1(shuffled)
+        # identity_x = shuffled[:identity_count, :, :]
+        # process_x = shuffled[identity_count:, :, :]
+        # residual = process_x
 
-
+        # if self.pre_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
 
-
+        # process_x = residual + self.attn(process_x, process_x, process_x)
+        # residual = process_x
 
-
+        # shuffled = torch.cat([identity_x, process_x])
+        # norm_shuffled = self.layer_norm_2(shuffled)
+
+        # if self.pre_norm:
+        #     residual = process_x  # residual NOT normed
+        #     process_x = norm_shuffled[identity_count:, :, :]
+
+        # if self.post_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
+        #     residual = process_x  # residual normed
+
+        # process_x = residual + self.ff(process_x)  # handles residual connection
+
+        # x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+
+        # return x if not self.post_norm else self.layer_norm_3(x)
 
 
 class TransformerEncoder(nn.Module):
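The rewritten forward replaces the shuffle-and-split Binomial scheme (kept above as comments) with DropPath around each residual branch, so identity_probability now acts per sample and per branch: a sample whose branches are all dropped passes through the block unchanged. A self-contained sketch of the pre-norm arm, with plain torch modules standing in for broccoli's attention and feed-forward (the class name and submodule choices here are illustrative, not the package's API):

import torch
from torch import nn

from broccoli.transformer import DropPath  # added in this release


class PreNormSketch(nn.Module):
    def __init__(self, d_model: int = 64, identity_probability: float = 0.1):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)

    def forward(self, x):
        # Same residual pattern as the new TransformerBlock.forward pre-norm arm:
        # norm -> branch -> stochastic depth -> add.
        normx = self.layer_norm_1(x)
        x = x + self.drop_path(self.attn(normx, normx, normx, need_weights=False)[0])
        normx = self.layer_norm_2(x)
        return x + self.drop_path(self.ff(normx))


out = PreNormSketch()(torch.randn(8, 16, 64))  # (batch, tokens, d_model)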
{broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=hhembQe9tEVNZMRtgbdGEsHWaBXSl95h_RpDhFde030,18171
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
 broccoli/vit.py,sha256=05xqIw9xvE5easXcp4wIA1jQ0xUyRIq6h0ZDtbitXi4,17184
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
+broccoli_ml-0.37.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.37.0.dist-info/METADATA,sha256=jUDSeLfYphtaOGvJn64v3deZw1nmKn4VYc7PO69BSPk,1257
+broccoli_ml-0.37.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.37.0.dist-info/RECORD,,
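Each RECORD row has the form path,sha256=<digest>,size, where the digest is the file's SHA-256 hash, urlsafe-base64 encoded with the trailing '=' padding stripped (per the wheel spec). A sketch for recomputing a digest, e.g. to check the new transformer.py entry against an extracted wheel (the path is an assumption about where you unpacked it):

import base64
import hashlib


def record_hash(path: str) -> str:
    # SHA-256 digest, urlsafe-base64 encoded with '=' padding stripped,
    # matching the sha256= fields in RECORD.
    with open(path, "rb") as f:
        digest = hashlib.sha256(f.read()).digest()
    return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")


# Expected for the 0.37.0 file: hhembQe9tEVNZMRtgbdGEsHWaBXSl95h_RpDhFde030
print(record_hash("broccoli/transformer.py"))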
{broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/LICENSE
File without changes
{broccoli_ml-0.36.0.dist-info → broccoli_ml-0.37.0.dist-info}/WHEEL
File without changes