broccoli-ml 0.35.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/tensor.py CHANGED
@@ -77,25 +77,25 @@ class AnchoredReparamTensor(nn.Module):
77
77
 
78
78
  super().__init__()
79
79
 
80
- self.nondecay_weight = nn.Parameter(init_tensor, requires_grad=True)
80
+ self.weight = nn.Parameter(init_tensor, requires_grad=True)
81
81
 
82
82
  with torch.no_grad():
83
- _, sigma, v_transpose = torch.linalg.svd(
84
- self.nondecay_weight, full_matrices=False
85
- )
83
+ _, sigma, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
86
84
 
87
85
  self.register_buffer("rayleigh_norm", sigma[:1])
88
86
  self.register_buffer("initial_right_singular", v_transpose[0])
89
- self.scale = nn.Parameter(sigma[:1].clone().detach(), requires_grad=True)
87
+ self.nondecay_scale = nn.Parameter(
88
+ sigma[:1].clone().detach(), requires_grad=True
89
+ )
90
90
 
91
91
  def _update_rayleigh_norm(self):
92
92
  with torch.no_grad():
93
- product = self.nondecay_weight.mv(self.initial_right_singular)
93
+ product = self.weight.mv(self.initial_right_singular)
94
94
  normed_product = F.normalize(product, dim=0)
95
95
  rayleigh_norm = torch.einsum(
96
96
  "m,mn,n->",
97
97
  normed_product,
98
- self.nondecay_weight,
98
+ self.weight,
99
99
  self.initial_right_singular,
100
100
  )
101
101
  self.rayleigh_norm.data.copy_(rayleigh_norm)
@@ -103,7 +103,7 @@ class AnchoredReparamTensor(nn.Module):
103
103
  def forward(self):
104
104
  if self.training:
105
105
  self._update_rayleigh_norm()
106
- return self.scale * (self.nondecay_weight / (self.rayleigh_norm + 1e-6))
106
+ return self.nondecay_scale * (self.weight / (self.rayleigh_norm + 1e-6))
107
107
 
108
108
 
109
109
  class NormReparamTensor(nn.Module):
@@ -118,11 +118,11 @@ class NormReparamTensor(nn.Module):
118
118
 
119
119
  # Use the gradboard convention of calling something nondecay_* if we should
120
120
  # exclude it from weight decay
121
- self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
122
- self.scale = nn.Parameter(
123
- torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
121
+ self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
122
+ self.nondecay_scale = nn.Parameter(
123
+ torch.linalg.norm(self.weight).clone().detach(), requires_grad=True
124
124
  )
125
125
 
126
126
  def forward(self) -> torch.Tensor:
127
- norm = torch.linalg.norm(self.nondecay_weight)
128
- return self.scale * (self.nondecay_weight / (norm + 1e-6))
127
+ norm = torch.linalg.norm(self.weight)
128
+ return self.nondecay_scale * (self.weight / (norm + 1e-6))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.35.0
3
+ Version: 0.35.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -7,11 +7,11 @@ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
7
7
  broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
- broccoli/tensor.py,sha256=ks2TRCdS10k2XvxEieh2sj_LzjTNRuiO6gekKFTtziI,4533
10
+ broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
11
11
  broccoli/transformer.py,sha256=t0gsADJC9UOlwjm7tDKdy0pAZ8l3clTcCnes86zvH-k,17203
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
13
  broccoli/vit.py,sha256=c-ZRHiLDOoQDJO9OJ51zD9HqaluG33flIwTXQQfms-g,17389
14
- broccoli_ml-0.35.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.35.0.dist-info/METADATA,sha256=v0JSpcubSGwxA5dFPbDwz2r2oGZWSeqYND1Mu8WOiJY,1257
16
- broccoli_ml-0.35.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.35.0.dist-info/RECORD,,
14
+ broccoli_ml-0.35.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.35.1.dist-info/METADATA,sha256=5pQA45ytAkkn0F5il2zuSN0vY7hFtVJvyUi9MXF-0EA,1257
16
+ broccoli_ml-0.35.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.35.1.dist-info/RECORD,,