broccoli-ml 0.35.1__tar.gz → 0.37.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.35.1
+Version: 0.37.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -13,6 +13,45 @@ from .rope import RotaryEmbedding, apply_rotary_emb
 from .linear import AnchoredLinear, SpectralNormLinear
 
 
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """
+    From https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    Copyright 2019 Ross Wightman
+    See documentation and licence there.
+    """
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob, 3):0.3f}"
+
+
 class MHAttention(nn.Module):
     """
     Multi-head self-attention using einops and optionally a custom linear layer.
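
The drop_path/DropPath code added above is the standard stochastic-depth helper vendored from timm: during training it zeroes the whole residual branch for a random subset of samples and rescales the surviving samples by 1/keep_prob, and in eval mode (or with drop_prob=0.0) it is a no-op. A minimal usage sketch, assuming the DropPath class above is in scope (the linear layer and tensor shapes here are illustrative only, not from the package):

import torch
import torch.nn as nn
# DropPath as defined in the diff above must be importable; the module path is not shown here.

branch = nn.Linear(64, 64)        # stands in for an attention or feedforward branch
drop = DropPath(drop_prob=0.1)    # roughly 10% of samples skip the branch during training

x = torch.randn(8, 16, 64)        # (batch, sequence, d_model)

drop.train()
y_train = x + drop(branch(x))     # per sample: branch output kept (scaled by 1/0.9) or zeroed

drop.eval()
y_eval = x + drop(branch(x))      # DropPath is the identity at eval time
assert torch.allclose(y_eval, x + branch(x))
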
@@ -21,45 +60,6 @@ class MHAttention(nn.Module):
     are the same shape.
 
     Assumes bias=False and batch_first=True, as God intended.
-
-    Optionally adds various bells and whistles suggested in the
-    literature, including:
-
-    Noam Shazeer's scaled attention per "Attention is All You Need"
-    (https://arxiv.org/abs/1706.03762).
-
-    Max subtract softmax as discussed in "Attention As An RNN"
-    (https://arxiv.org/abs/2405.13956)
-
-    Log-length scaled softmax per "Overcoming a Theoretical Limitation of
-    Self-Attention" (https://arxiv.org/abs/2202.12172).
-
-    Quiet softmax per
-    https://www.evanmiller.org/attention-is-off-by-one.html
-
-    Args:
-        d_model: ...
-        n_heads: ...
-        dropout: ...
-        causal: should a causal mask be applied to the logits before attention
-            is applied? This is standard when using self-attention. Cannot be
-            True if inputs won't be square (e.g. if sequence length for
-            encoder and decoder are different)
-        sequence_length: ...
-        share_kv: ...
-        linear_module: ...
-        max_subtract: if True, the maximum logit value is subtracted from all
-            logits before performing the softmax operation to create a more
-            numerically stable softmax. This is discussed in "Attention As An
-            RNN" (https://arxiv.org/abs/2405.13956).
-        d_model_scale: ...
-        log_length_scale: if True, multiplies logits by the log length of
-            the decoder sequence before performing the softmax operation, as
-            proposed in "Overcoming a Theoretical Limitation of Self-Attention"
-            (https://arxiv.org/abs/2202.12172).
-        quiet: if True, adds 1 to the denominator of the softmax operation,
-            allowing some tokens to attend to no other tokens as described in
-            https://www.evanmiller.org/attention-is-off-by-one.html.
     """
 
     def __init__(
@@ -280,7 +280,7 @@ class FeedforwardBlock(nn.Module):
         elif self.residual_path:
             return x + self.process(x)
         else:
-            return x
+            return self.process(x)
 
 
 class TransformerBlock(nn.Module):
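
The FeedforwardBlock change above is a behaviour fix: with residual_path disabled, forward previously returned its input untouched and silently skipped the feedforward computation; it now returns the processed output. A minimal sketch of that final dispatch, using a hypothetical stand-in class (TinyFFBlock is not part of the package; only the two branches shown in the diff are taken from the source):

import torch
import torch.nn as nn

class TinyFFBlock(nn.Module):
    """Hypothetical stand-in for FeedforwardBlock's last two branches."""

    def __init__(self, residual_path: bool):
        super().__init__()
        self.residual_path = residual_path
        self.process = nn.Linear(4, 4)  # stands in for the up/down projection stack

    def forward(self, x):
        if self.residual_path:
            return x + self.process(x)
        # Fixed branch from the diff: previously `return x`, now the processed output.
        return self.process(x)

x = torch.randn(2, 4)
y = TinyFFBlock(residual_path=False)(x)
assert not torch.equal(y, x)  # the block no longer degenerates to the identity
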
@@ -318,10 +318,11 @@ class TransformerBlock(nn.Module):
         self.post_norm = post_norm
         self.normformer = normformer
 
-        self.identity_probability = identity_probability
+        self.drop_path = DropPath(drop_prob=identity_probability, scale_by_keep=True)
 
         self.layer_norm_1 = nn.LayerNorm(d_model)
         self.layer_norm_2 = nn.LayerNorm(d_model)
+        self.layer_norm_3 = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -357,10 +358,10 @@ class TransformerBlock(nn.Module):
             dropout=mlp_dropout,
             linear_module_up=linear_module,
             linear_module_down=linear_module,
-            pre_norm=pre_norm,
+            pre_norm=False,  # Handled outside the block
             normformer=normformer,
-            post_norm=post_norm,
-            residual_path=True,
+            post_norm=False,  # Handled outside the block
+            residual_path=False,  # Handled outside the block
         )
 
     @property
@@ -368,34 +369,70 @@ class TransformerBlock(nn.Module):
         return self.attn._kv_distance
 
     def forward(self, x):
-        if not self.training:
-            identity_probability = 0.0
+
+        if self.pre_norm:
+            normx = self.layer_norm_1(x)
+            x = x + self.drop_path(self.attn(normx, normx, normx))
+            normx = self.layer_norm_2(x)
+            x = x + self.drop_path(self.ff(normx))
+        elif self.post_norm:
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = self.layer_norm_1(x)
+            x = x + self.drop_path(self.ff(x))
+            x = self.layer_norm_2(x)
         else:
-            identity_probability = self.identity_probability
+            x = x + self.drop_path(self.attn(x, x, x))
+            x = x + self.drop_path(self.ff(x))
 
-        # perform the identity operation for some rows in the batch
-        identity_count = random.binomial(n=x.size(0), p=identity_probability)
-        shuffle_indices = torch.randperm(x.size(0), device=x.device)
-        unshuffle_indices = torch.argsort(shuffle_indices)
-        shuffled = x[shuffle_indices, :, :]
-        identity_x = shuffled[:identity_count, :, :]
-        process_x = shuffled[identity_count:, :, :]
+        if self.pre_norm and self.post_norm:
+            x = self.layer_norm_3(x)
 
-        residual_x = process_x
+        return x
 
-        if self.pre_norm:
-            process_x = self.layer_norm_1(process_x)
+        # if not self.training:
+        #     identity_probability = 0.0
+        # else:
+        #     identity_probability = self.identity_probability
 
-        process_x = residual_x + self.attn(process_x, process_x, process_x)
+        # if random.random() < identity_probability:
+        #     return x
+        # else:
+        #     ...
 
-        if self.post_norm:
-            process_x = self.layer_norm_2(process_x)
+        # # perform the identity operation for some rows in the batch
+        # dist = torch.distributions.Binomial(x.size(0), identity_probability)
+        # identity_count = int(dist.sample().item())
 
-        process_x = self.ff(process_x)
+        # shuffle_indices = torch.randperm(x.size(0), device=x.device)
+        # unshuffle_indices = torch.argsort(shuffle_indices)
+        # shuffled = x[shuffle_indices, :, :]
+        # norm_shuffled = self.layer_norm_1(shuffled)
+        # identity_x = shuffled[:identity_count, :, :]
+        # process_x = shuffled[identity_count:, :, :]
+        # residual = process_x
 
-        x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+        # if self.pre_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
 
-        return x
+        # process_x = residual + self.attn(process_x, process_x, process_x)
+        # residual = process_x
+
+        # shuffled = torch.cat([identity_x, process_x])
+        # norm_shuffled = self.layer_norm_2(shuffled)
+
+        # if self.pre_norm:
+        #     residual = process_x  # residual NOT normed
+        #     process_x = norm_shuffled[identity_count:, :, :]
+
+        # if self.post_norm:
+        #     process_x = norm_shuffled[identity_count:, :, :]
+        #     residual = process_x  # residual normed
+
+        # process_x = residual + self.ff(process_x)  # handles residual connection
+
+        # x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+
+        # return x if not self.post_norm else self.layer_norm_3(x)
 
 
 class TransformerEncoder(nn.Module):
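
One property of the rewritten forward worth noting: because drop_path rescales the surviving samples by 1/keep_prob, each residual branch keeps the same expected value in training as in eval, and whole samples, not individual features, are dropped. A quick numerical sketch of that behaviour, assuming only the drop_path function added earlier in this diff is in scope:

import torch
# drop_path as added in the diff above must be in scope.

torch.manual_seed(0)

x = torch.ones(10_000, 4)                        # many "samples" of a constant branch output
y = drop_path(x, drop_prob=0.2, training=True)   # ~20% of rows zeroed, the rest scaled by 1/0.8

dropped = (y == 0).all(dim=1).float().mean()
print(f"fraction of rows dropped: {dropped.item():.3f}")  # close to 0.2
print(f"overall mean preserved:   {y.mean().item():.3f}")  # close to 1.0 thanks to the rescaling

# With training=False the input passes through untouched.
assert torch.equal(drop_path(x, drop_prob=0.2, training=False), x)
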
@@ -236,13 +236,7 @@ class ViTEncoder(nn.Module):
 
         if pooling_type is None:
             pooling_out_channels = cnn_activation_out_channels
-            self.pool = nn.Sequential(
-                *[
-                    Rearrange(
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ),  # for transformer
-                ]
-            )
+            self.pool = nn.Identity()
 
         elif pooling_type == "max":
             pooling_out_channels = cnn_activation_out_channels
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.35.1"
+version = "0.37.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
File without changes
File without changes