broccoli-ml 0.31.2__tar.gz → 0.32.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: broccoli-ml
- Version: 0.31.2
+ Version: 0.32.0
  Summary: Some useful Pytorch models, circa 2025
  License: MIT
  Author: Nicholas Bailey
@@ -58,48 +58,52 @@ class SigmaReparamTensor(nn.Module):
  
  class AnchoredReparamTensor(nn.Module):
      """
-     Reparameterise a tensor as a normalised tensor of weights multiplied by a
-     learnable scaling factor.
+     Reparameterises a tensor by decoupling its magnitude and direction.
  
-     The tensor of weights is also reparameterised as the product of a learnable
-     weight tensor with the (fixed) dominant right-singular vector of the
-     weight tensor as it was initialised.
+     The direction is represented by a learnable weight tensor, normalised by the
+     Rayleigh quotient with respect to its initial dominant right-singular vector.
+     The magnitude is a separate learnable scalar.
  
-     i.e this module represents a tensor reparameterised as:
+     The reparameterisation is:
  
-         W_reparam = scale * (W / ||W @ v_0||_2)
+         W_reparam = scale * (W / norm)
  
-     where v_0 is the dominant right-singular vector of the initial tensor W_init.
+     where the norm is the Rayleigh quotient uᵀWv₀, with v₀ being the dominant
+     right-singular vector of the initial tensor and u = normalize(Wv₀).
      """
  
      def __init__(self, init_tensor: torch.Tensor):
-         assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+         assert init_tensor.ndim == 2
+
          super().__init__()
  
-         # Use the gradboard convention of calling something nondecay_* if we should
-         # exclude it from weight decay
-         self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+         self.nondecay_weight = nn.Parameter(init_tensor, requires_grad=True)
  
-         # At initialization, compute the dominant right-singular vector (v_0)
-         # and store it in a non-trainable buffer.
          with torch.no_grad():
-             _, _, v_transpose = torch.linalg.svd(
+             _, sigma, v_transpose = torch.linalg.svd(
                  self.nondecay_weight, full_matrices=False
              )
-             # v_transpose[0] is the first row of V^T, which is the first right-singular vector.
-             self.register_buffer("anchor_vector", v_transpose[0])
  
-         initial_norm = torch.linalg.vector_norm(
-             self.nondecay_weight.mv(self.anchor_vector)
-         )
-         self.scale = nn.Parameter(initial_norm.clone().detach(), requires_grad=True)
+         self.register_buffer("rayleigh_norm", sigma[:1])
+         self.register_buffer("initial_right_singular", v_transpose[0])
+         self.scale = nn.Parameter(sigma[:1].clone().detach(), requires_grad=True)
  
-     def forward(self) -> torch.Tensor:
-         # Calculate the L2 norm of the matrix-vector product W @ v_0
-         norm = torch.linalg.vector_norm(self.nondecay_weight.mv(self.anchor_vector))
+     def _update_rayleigh_norm(self):
+         with torch.no_grad():
+             product = self.nondecay_weight.mv(self.initial_right_singular)
+             normed_product = F.normalize(product, dim=0)
+             rayleigh_norm = torch.einsum(
+                 "m,mn,n->",
+                 normed_product,
+                 self.nondecay_weight,
+                 self.initial_right_singular,
+             )
+             self.rayleigh_norm.data.copy_(rayleigh_norm)
  
-         # Return the reparameterized tensor.
-         return self.scale * (self.nondecay_weight / (norm + 1e-6))
+     def forward(self):
+         if self.training:
+             self._update_rayleigh_norm()
+         return self.scale * (self.nondecay_weight / (self.rayleigh_norm + 1e-6))
  
  
  class NormReparamTensor(nn.Module):
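A side note on the new norm: at initialisation the Rayleigh quotient uᵀWv₀ equals the top singular value σ₀, since Wv₀ = σ₀u₀ and u = normalize(Wv₀) = u₀. The scale parameter also starts at σ₀, so forward() initially reproduces W (up to the 1e-6 epsilon). A minimal standalone sketch verifying this (not part of the package; shape and seed are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(64, 32)  # stand-in for the initial weight tensor

# Dominant right-singular vector of W, as in __init__ above
_, sigma, v_transpose = torch.linalg.svd(W, full_matrices=False)
v0 = v_transpose[0]

# Rayleigh quotient uT W v0 with u = normalize(W v0), as in _update_rayleigh_norm
u = F.normalize(W.mv(v0), dim=0)
rayleigh = torch.einsum("m,mn,n->", u, W, v0)

print(torch.allclose(rayleigh, sigma[0]))  # True: the norm starts at sigma_0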
@@ -120,4 +124,5 @@ class NormReparamTensor(nn.Module):
          )
  
      def forward(self) -> torch.Tensor:
-         return self.scale * F.normalize(self.nondecay_weight)
+         norm = torch.linalg.norm(self.nondecay_weight)
+         return self.scale * (self.nondecay_weight / (norm + 1e-6))
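This is a behavioural change rather than just an epsilon tweak: F.normalize defaults to dim=1, rescaling each row of the weight tensor to unit norm, whereas torch.linalg.norm with no dim argument returns a single Frobenius norm, so 0.32.0 rescales the whole tensor uniformly. A quick standalone illustration (arbitrary shape, not package code):

import torch
import torch.nn.functional as F

W = torch.randn(4, 8)

per_row = F.normalize(W)                    # 0.31.2: each row rescaled to unit norm
whole = W / (torch.linalg.norm(W) + 1e-6)   # 0.32.0: one global scale factor

print(per_row.norm(dim=1))         # tensor of four ones, one per row
print(torch.linalg.norm(whole))    # ~1.0 for the whole matrix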
@@ -53,11 +53,13 @@ class ClassificationHead(nn.Module):
      A general classification head for a ViT
      """
  
-     def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
+     def __init__(
+         self, d_model, linear_module, n_classes, layer_norm=True, batch_norm=True
+     ):
          super().__init__()
          self.d_model = d_model
          self.summarize = GetCLSToken()
-         self.projection = nn.Linear(d_model, n_classes)
+         self.projection = linear_module(d_model, n_classes)
          if batch_norm:
              self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
          else:
@@ -65,6 +67,7 @@ class ClassificationHead(nn.Module):
  
          self.classification_process = nn.Sequential(
              *[
+                 nn.LayerNorm(d_model) if layer_norm else nn.Identity(),
                  self.summarize,
                  self.projection,
                  self.batch_norm,
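For reference, a hypothetical construction of the updated head; the import path below is an assumption, and nn.Linear merely stands in for whatever factory the caller passes as linear_module:

import torch.nn as nn

from broccoli_ml import ClassificationHead  # import path assumed, not verified

head = ClassificationHead(
    d_model=256,               # example width
    linear_module=nn.Linear,   # any callable taking (in_features, out_features)
    n_classes=10,
    layer_norm=True,           # new in 0.32.0: LayerNorm before the CLS-token summary
    batch_norm=True,
)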
@@ -1,6 +1,6 @@
  [project]
  name = "broccoli-ml"
- version = "0.31.2"
+ version = "0.32.0"
  description = "Some useful Pytorch models, circa 2025"
  authors = [
      {name = "Nicholas Bailey"}