broccoli-ml 0.36.0__py3-none-any.whl → 0.37.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/transformer.py CHANGED
@@ -13,6 +13,45 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 from .linear import AnchoredLinear, SpectralNormLinear
 
 
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
 class MHAttention(nn.Module):
     """
     Multi-head self-attention using einops and optionally a custom linear layer.
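
Note: drop_path implements stochastic depth at the batch level. During training, each sample's residual branch is zeroed with probability drop_prob, and the surviving samples are rescaled by 1/keep_prob so the expected output matches the input; at inference it is the identity. A minimal usage sketch (the shapes are illustrative, not taken from the package):

    import torch
    from broccoli.transformer import DropPath  # added in 0.37.0

    dp = DropPath(drop_prob=0.25, scale_by_keep=True)
    x = torch.ones(8, 16, 64)  # (batch, seq, d_model), illustrative

    dp.train()
    out = dp(x)
    # Each batch row is either all zeros (dropped) or scaled by 1/0.75,
    # so the expectation E[out] equals x despite whole samples being dropped.

    dp.eval()
    assert torch.equal(dp(x), x)  # identity at inference time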
@@ -279,10 +318,11 @@ class TransformerBlock(nn.Module):
         self.post_norm = post_norm
         self.normformer = normformer
 
-        self.identity_probability = identity_probability
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
            max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -318,10 +358,10 @@ class TransformerBlock(nn.Module):
             dropout=mlp_dropout,
             linear_module_up=linear_module,
             linear_module_down=linear_module,
-            pre_norm=pre_norm,
+            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=post_norm,
-            residual_path=True,
+            post_norm=False,  # Handled outside the block
+            residual_path=False,  # Handled outside the block
         )
 
     @property
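
Note: with pre_norm, post_norm, and residual_path all disabled on the feedforward sub-block, normalization, the residual connection, and the new stochastic-depth mask are applied by TransformerBlock.forward around the whole branch (see the next hunk). Schematically, for the pre-norm case, with norm and ff as hypothetical stand-ins for the block's sub-modules and DropPath imported as in the sketch above:

    import torch.nn as nn

    d_model = 64
    norm = nn.LayerNorm(d_model)
    ff = nn.Sequential(  # stand-in for the block's MLP
        nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
    )
    drop_path = DropPath(drop_prob=0.1)

    def ff_branch(x):
        # pre-norm ordering: norm -> MLP -> stochastic depth -> residual add
        return x + drop_path(ff(norm(x)))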
@@ -329,36 +369,70 @@
         return self.attn._kv_distance
 
     def forward(self, x):
-        if not self.training:
-            identity_probability = 0.0
+
+        if self.pre_norm:
+            normx = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(normx, normx, normx))
+            normx = self.layer_norm_2(x)
+            x = x + self.drop_path(self.ff(normx))
+        elif self.post_norm:
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.ff(x))
+            x = self.layer_norm_2(x)
         else:
-            identity_probability = self.identity_probability
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.ff(x))
 
-        # perform the identity operation for some rows in the batch
-        dist = torch.distributions.Binomial(x.size(0), identity_probability)
-        identity_count = int(dist.sample().item())
+        if self.pre_norm and self.post_norm:
+            x = self.layer_norm_3(x)
 
-        shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]
+        return x
 
-        residual_x = process_x
+        # if not self.training:
+        #     identity_probability = 0.0
+        # else:
+        #     identity_probability = self.identity_probability
 
-        if self.pre_norm:
-            process_x = self.layer_norm_1(process_x)
+        # if random.random() < identity_probability:
+        #     return x
+        # else:
+        #     ...
 
-        process_x = residual_x + self.attn(process_x, process_x, process_x)
+        # # perform the identity operation for some rows in the batch
+        # dist = torch.distributions.Binomial(x.size(0), identity_probability)
+        # identity_count = int(dist.sample().item())
 
-        if self.post_norm:
-            process_x = self.layer_norm_2(process_x)
+        # shuffle_indices = torch.randperm(x.size(0), device=x.device)
+        # unshuffle_indices = torch.argsort(shuffle_indices)
+        # shuffled = x[shuffle_indices, :, :]
+        # norm_shuffled = self.layer_norm_1(shuffled)
+        # identity_x = shuffled[:identity_count, :, :]
+        # process_x = shuffled[identity_count:, :, :]
+        # residual = process_x
 
-        process_x = self.ff(process_x)
+        # if self.pre_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
 
-        x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+        # process_x = residual + self.attn(process_x, process_x, process_x)
+        # residual = process_x
 
-        return x
+        # shuffled = torch.cat([identity_x, process_x])
+        # norm_shuffled = self.layer_norm_2(shuffled)
+
+        # if self.pre_norm:
+        #     residual = process_x  # residual NOT normed
+        #     process_x = norm_shuffled[identity_count:, :, :]
+
+        # if self.post_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
+        #     residual = process_x  # residual normed
+
+        # process_x = residual + self.ff(process_x)  # handles residual connection
+
+        # x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+
+        # return x if not self.post_norm else self.layer_norm_3(x)
 
 
 class TransformerEncoder(nn.Module):
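
Note: the rewritten forward covers three configurations: pre-norm (x + f(LN(x)), with the residual stream left un-normed), post-norm (LN applied after each residual add), and no norm at all; when both flags are set, the new layer_norm_3 supplies a final normalization. The old per-batch identity shuffling is retained only as commented-out code. A standalone sketch of the same control flow, using torch.nn stand-ins (nn.MultiheadAttention and a plain MLP) rather than broccoli's own sub-modules, and the DropPath import from the first sketch:

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        # Mirrors the new TransformerBlock.forward; attn/ff are stand-ins.
        def __init__(self, d_model=64, drop_prob=0.1, pre_norm=True, post_norm=False):
            super().__init__()
            self.pre_norm, self.post_norm = pre_norm, post_norm
            self.attn = nn.MultiheadAttention(d_model, 4, batch_first=True)
            self.ff = nn.Sequential(
                nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
            )
            self.drop_path = DropPath(drop_prob)
            self.ln1, self.ln2, self.ln3 = (nn.LayerNorm(d_model) for _ in range(3))

        def forward(self, x):
            if self.pre_norm:  # norm feeds each branch; residual stays un-normed
                n = self.ln1(x)
                x = x + self.drop_path(self.attn(n, n, n, need_weights=False)[0])
                n = self.ln2(x)
                x = x + self.drop_path(self.ff(n))
            elif self.post_norm:  # residual add first, then norm
                x = self.ln1(x + self.drop_path(self.attn(x, x, x, need_weights=False)[0]))
                x = self.ln2(x + self.drop_path(self.ff(x)))
            else:  # bare residual stream
                x = x + self.drop_path(self.attn(x, x, x, need_weights=False)[0])
                x = x + self.drop_path(self.ff(x))
            if self.pre_norm and self.post_norm:  # final norm via the third LayerNorm
                x = self.ln3(x)
            return x

    out = TinyBlock()(torch.randn(2, 10, 64))  # -> shape (2, 10, 64)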
broccoli_ml-0.36.0.dist-info/METADATA → broccoli_ml-0.37.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.36.0
+Version: 0.37.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-0.36.0.dist-info/RECORD → broccoli_ml-0.37.0.dist-info/RECORD CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
-broccoli/transformer.py,sha256=NH94U6lxHzmDGDHTTtJV2kUs7IcS2iNmFJl44_6KtQ0,15456
+broccoli/transformer.py,sha256=hhembQe9tEVNZMRtgbdGEsHWaBXSl95h_RpDhFde030,18171
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
 broccoli/vit.py,sha256=05xqIw9xvE5easXcp4wIA1jQ0xUyRIq6h0ZDtbitXi4,17184
-broccoli_ml-0.36.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.36.0.dist-info/METADATA,sha256=csog4ZG1PGeRuFO5QnHdVPgmDYXsGQQJ621JgU0D83w,1257
-broccoli_ml-0.36.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.36.0.dist-info/RECORD,,
+broccoli_ml-0.37.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.37.0.dist-info/METADATA,sha256=jUDSeLfYphtaOGvJn64v3deZw1nmKn4VYc7PO69BSPk,1257
+broccoli_ml-0.37.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.37.0.dist-info/RECORD,,