broccoli-ml 0.35.0__py3-none-any.whl → 0.36.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/tensor.py CHANGED
@@ -77,25 +77,25 @@ class AnchoredReparamTensor(nn.Module):
 
         super().__init__()
 
-        self.nondecay_weight = nn.Parameter(init_tensor, requires_grad=True)
+        self.weight = nn.Parameter(init_tensor, requires_grad=True)
 
         with torch.no_grad():
-            _, sigma, v_transpose = torch.linalg.svd(
-                self.nondecay_weight, full_matrices=False
-            )
+            _, sigma, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
 
         self.register_buffer("rayleigh_norm", sigma[:1])
         self.register_buffer("initial_right_singular", v_transpose[0])
-        self.scale = nn.Parameter(sigma[:1].clone().detach(), requires_grad=True)
+        self.nondecay_scale = nn.Parameter(
+            sigma[:1].clone().detach(), requires_grad=True
+        )
 
     def _update_rayleigh_norm(self):
         with torch.no_grad():
-            product = self.nondecay_weight.mv(self.initial_right_singular)
+            product = self.weight.mv(self.initial_right_singular)
             normed_product = F.normalize(product, dim=0)
             rayleigh_norm = torch.einsum(
                 "m,mn,n->",
                 normed_product,
-                self.nondecay_weight,
+                self.weight,
                 self.initial_right_singular,
             )
             self.rayleigh_norm.data.copy_(rayleigh_norm)
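For reference, the einsum being renamed here contracts to a scalar: with v the stored initial right singular vector and u = Wv / ||Wv||, it computes uᵀWv = ||Wv||, which equals the top singular value of W at initialization. A minimal sketch checking that identity (the shapes are illustrative, not taken from the package):

    import torch
    import torch.nn.functional as F

    W = torch.randn(8, 4)  # hypothetical weight shape
    _, sigma, v_t = torch.linalg.svd(W, full_matrices=False)
    v = v_t[0]  # top right singular vector, as stored in initial_right_singular

    # The einsum from _update_rayleigh_norm: u^T W v with u = Wv / ||Wv||
    u = F.normalize(W.mv(v), dim=0)
    rayleigh = torch.einsum("m,mn,n->", u, W, v)

    assert torch.allclose(rayleigh, W.mv(v).norm(), atol=1e-5)  # closed form: ||W v||
    assert torch.allclose(rayleigh, sigma[0], atol=1e-5)        # sigma_1 before W drifts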
@@ -103,7 +103,7 @@ class AnchoredReparamTensor(nn.Module):
     def forward(self):
         if self.training:
             self._update_rayleigh_norm()
-        return self.scale * (self.nondecay_weight / (self.rayleigh_norm + 1e-6))
+        return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
 
 
 class NormReparamTensor(nn.Module):
@@ -118,11 +118,11 @@ class NormReparamTensor(nn.Module):
 
         # Use the gradboard convention of calling something nondecay_* if we should
         # exclude it from weight decay
-        self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
-        self.scale = nn.Parameter(
-            torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
+        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.nondecay_scale = nn.Parameter(
+            torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
         )
 
     def forward(self) -> torch.Tensor:
-        norm = torch.linalg.norm(self.nondecay_weight)
-        return self.scale * (self.nondecay_weight / (norm + 1e-6))
+        norm = torch.linalg.norm(self.weight)
+        return self.nondecay_scale * (self.weight / (norm + 1e-6))
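Net effect of the renames in both classes: the direction tensor loses its nondecay_ prefix (so it now receives weight decay) while the learned scale gains it (so it does not). A sketch of how a nondecay_* naming convention like this is typically consumed when building optimizer parameter groups; the helper below is illustrative, not gradboard's actual API:

    import torch

    def split_decay_groups(model: torch.nn.Module, weight_decay: float):
        # Illustrative: route any parameter whose name mentions "nondecay"
        # into a group with weight decay disabled.
        decay, nondecay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            (nondecay if "nondecay" in name else decay).append(param)
        return [
            {"params": decay, "weight_decay": weight_decay},
            {"params": nondecay, "weight_decay": 0.0},
        ]

    # e.g. optimizer = torch.optim.AdamW(split_decay_groups(model, 0.01), lr=3e-4)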
broccoli/transformer.py CHANGED
@@ -21,45 +21,6 @@ class MHAttention(nn.Module):
     are the same shape.
 
     Assumes bias=False and batch_first=True, as God intended.
-
-    Optionally adds various bells and whistles suggested in the
-    literature, including:
-
-    Noam Shazeer's scaled attention per "Attention is All You Need"
-    (https://arxiv.org/abs/1706.03762).
-
-    Max subtract softmax as discussed in "Attention As An RNN"
-    (https://arxiv.org/abs/2405.13956)
-
-    Log-length scaled softmax per "Overcoming a Theoretical Limitation of
-    Self-Attention" (https://arxiv.org/abs/2202.12172).
-
-    Quiet softmax per
-    https://www.evanmiller.org/attention-is-off-by-one.html
-
-    Args:
-        d_model: ...
-        n_heads: ...
-        dropout: ...
-        causal: should a causal mask be applied to the logits before attention
-            is applied? This is standard when using self-attention. Cannot be
-            True if inputs won't be square (e.g. if sequence length for
-            encoder and decoder are different)
-        sequence_length: ...
-        share_kv: ...
-        linear_module: ...
-        max_subtract: if True, the maximum logit value is subtracted from all
-            logits before performing the softmax operation to create a more
-            numerically stable softmax. This is discussed in "Attention As An
-            RNN" (https://arxiv.org/abs/2405.13956).
-        d_model_scale: ...
-        log_length_scale: if True, multiplies logits by the log length of
-            the decoder sequence before performing the softmax operation, as
-            proposed in "Overcoming a Theoretical Limitation of Self-Attention"
-            (https://arxiv.org/abs/2202.12172).
-        quiet: if True, adds 1 to the denominator of the softmax operation,
-            allowing some tokens to attend to no other tokens as described in
-            https://www.evanmiller.org/attention-is-off-by-one.html.
     """
 
     def __init__(
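Although the hunk above deletes the docstring's description of the optional softmax variants, the formulas it documented are simple; for instance, the "quiet" variant adds 1 to the softmax denominator so a query can attend to nothing. A standalone sketch of that variant, not the class's own implementation:

    import torch

    def quiet_softmax(logits: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # Max-subtract for numerical stability; after shifting by m,
        # the "+1" in the denominator becomes exp(-m).
        m = logits.amax(dim=dim, keepdim=True)
        exp = (logits - m).exp()
        return exp / (exp.sum(dim=dim, keepdim=True) + (-m).exp())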
@@ -280,7 +241,7 @@ class FeedforwardBlock(nn.Module):
         elif self.residual_path:
             return x + self.process(x)
         else:
-            return x
+            return self.process(x)
 
 
 class TransformerBlock(nn.Module):
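The one-line change above is a bug fix: with residual_path disabled, FeedforwardBlock.forward fell through to return x unmodified, so the block silently contributed nothing. A condensed paraphrase of the two visible branches (the leading if branch is not shown in the hunk, and the process stack here is a placeholder):

    import torch.nn as nn

    class FeedforwardSketch(nn.Module):
        # Hypothetical reduction of FeedforwardBlock to the visible branches.
        def __init__(self, d_model: int, residual_path: bool = True):
            super().__init__()
            self.residual_path = residual_path
            self.process = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU())

        def forward(self, x):
            if self.residual_path:
                return x + self.process(x)
            return self.process(x)  # before 0.36.0 this branch returned x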
@@ -374,7 +335,9 @@ class TransformerBlock(nn.Module):
         identity_probability = self.identity_probability
 
         # perform the identity operation for some rows in the batch
-        identity_count = random.binomial(n=x.size(0), p=identity_probability)
+        dist = torch.distributions.Binomial(x.size(0), identity_probability)
+        identity_count = int(dist.sample().item())
+
         shuffle_indices = torch.randperm(x.size(0), device=x.device)
         unshuffle_indices = torch.argsort(shuffle_indices)
         shuffled = x[shuffle_indices, :, :]
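The sampling change above moves the binomial draw from what appears to be NumPy's random.binomial into torch.distributions, keeping the randomness under torch's own seeding. A standalone sketch of the new selection logic, with assumed toy values:

    import torch

    batch_size, identity_probability = 32, 0.1

    # Sample how many rows of the batch get the identity treatment this step.
    dist = torch.distributions.Binomial(batch_size, identity_probability)
    identity_count = int(dist.sample().item())

    # Random permutation so the identity rows are chosen uniformly.
    shuffle_indices = torch.randperm(batch_size)
    unshuffle_indices = torch.argsort(shuffle_indices)
    # (illustrative) the first identity_count shuffled rows would pass through unchanged
    identity_rows = shuffle_indices[:identity_count]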
broccoli/vit.py CHANGED
@@ -236,13 +236,7 @@ class ViTEncoder(nn.Module):
 
         if pooling_type is None:
             pooling_out_channels = cnn_activation_out_channels
-            self.pool = nn.Sequential(
-                *[
-                    Rearrange(
-                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
-                    ),  # for transformer
-                ]
-            )
+            self.pool = nn.Identity()
 
         elif pooling_type == "max":
             pooling_out_channels = cnn_activation_out_channels
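For context, the replaced Sequential held a single einops Rearrange that flattened the spatial dimensions into a token axis for the transformer; nn.Identity() instead leaves the activation in its convolutional layout at this stage (the hunk does not show where, or whether, that flattening now happens). A shape-level illustration assuming 2-D inputs, where spatial_dim_names would be "H W":

    import torch
    from einops.layers.torch import Rearrange

    x = torch.randn(2, 16, 8, 8)                  # N C H W

    old_pool = Rearrange("N C H W -> N (H W) C")  # removed behaviour
    print(old_pool(x).shape)                      # torch.Size([2, 64, 16])

    new_pool = torch.nn.Identity()                # 0.36.0 behaviour
    print(new_pool(x).shape)                      # torch.Size([2, 16, 8, 8])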
broccoli_ml-0.35.0.dist-info/METADATA → broccoli_ml-0.36.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.35.0
+Version: 0.36.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-0.35.0.dist-info/RECORD → broccoli_ml-0.36.0.dist-info/RECORD CHANGED
@@ -7,11 +7,11 @@ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
-broccoli/tensor.py,sha256=ks2TRCdS10k2XvxEieh2sj_LzjTNRuiO6gekKFTtziI,4533
-broccoli/transformer.py,sha256=t0gsADJC9UOlwjm7tDKdy0pAZ8l3clTcCnes86zvH-k,17203
+broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
+broccoli/transformer.py,sha256=NH94U6lxHzmDGDHTTtJV2kUs7IcS2iNmFJl44_6KtQ0,15456
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=c-ZRHiLDOoQDJO9OJ51zD9HqaluG33flIwTXQQfms-g,17389
-broccoli_ml-0.35.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.35.0.dist-info/METADATA,sha256=v0JSpcubSGwxA5dFPbDwz2r2oGZWSeqYND1Mu8WOiJY,1257
-broccoli_ml-0.35.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.35.0.dist-info/RECORD,,
+broccoli/vit.py,sha256=05xqIw9xvE5easXcp4wIA1jQ0xUyRIq6h0ZDtbitXi4,17184
+broccoli_ml-0.36.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.36.0.dist-info/METADATA,sha256=csog4ZG1PGeRuFO5QnHdVPgmDYXsGQQJ621JgU0D83w,1257
+broccoli_ml-0.36.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.36.0.dist-info/RECORD,,