broccoli-ml 0.31.2__tar.gz → 0.32.0__tar.gz
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/PKG-INFO +1 -1
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/tensor.py +32 -27
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/vit.py +5 -2
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/pyproject.toml +1 -1
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/LICENSE +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/README.md +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/activation.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/linear.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/rope.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/transformer.py +0 -0
- {broccoli_ml-0.31.2 → broccoli_ml-0.32.0}/broccoli/utils.py +0 -0
@@ -58,48 +58,52 @@ class SigmaReparamTensor(nn.Module):
|
|
58
58
|
|
59
59
|
class AnchoredReparamTensor(nn.Module):
|
60
60
|
"""
|
61
|
-
|
62
|
-
learnable scaling factor.
|
61
|
+
Reparameterises a tensor by decoupling its magnitude and direction.
|
63
62
|
|
64
|
-
The
|
65
|
-
|
66
|
-
|
63
|
+
The direction is represented by a learnable weight tensor, normalised by the
|
64
|
+
Rayleigh quotient with respect to its initial dominant right-singular vector.
|
65
|
+
The magnitude is a separate learnable scalar.
|
67
66
|
|
68
|
-
|
67
|
+
The reparameterization is:
|
69
68
|
|
70
|
-
W_reparam = scale * (W /
|
69
|
+
W_reparam = scale * (W / norm)
|
71
70
|
|
72
|
-
|
71
|
+
where the norm is the Rayleigh quotient uᵀWv₀, with v₀ being the dominant
|
72
|
+
right-singular vector of the initial tensor and u = normalize(Wv₀).
|
73
73
|
"""
|
74
74
|
|
75
75
|
def __init__(self, init_tensor: torch.Tensor):
|
76
|
-
assert init_tensor.ndim == 2
|
76
|
+
assert init_tensor.ndim == 2
|
77
|
+
|
77
78
|
super().__init__()
|
78
79
|
|
79
|
-
|
80
|
-
# exclude it from weight decay
|
81
|
-
self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
|
80
|
+
self.nondecay_weight = nn.Parameter(init_tensor, requires_grad=True)
|
82
81
|
|
83
|
-
# At initialization, compute the dominant right-singular vector (v_0)
|
84
|
-
# and store it in a non-trainable buffer.
|
85
82
|
with torch.no_grad():
|
86
|
-
_,
|
83
|
+
_, sigma, v_transpose = torch.linalg.svd(
|
87
84
|
self.nondecay_weight, full_matrices=False
|
88
85
|
)
|
89
|
-
# v_transpose[0] is the first row of V^T, which is the first right-singular vector.
|
90
|
-
self.register_buffer("anchor_vector", v_transpose[0])
|
91
86
|
|
92
|
-
|
93
|
-
|
94
|
-
)
|
95
|
-
self.scale = nn.Parameter(initial_norm.clone().detach(), requires_grad=True)
|
87
|
+
self.register_buffer("rayleigh_norm", sigma[:1])
|
88
|
+
self.register_buffer("initial_right_singular", v_transpose[0])
|
89
|
+
self.scale = nn.Parameter(sigma[:1].clone().detach(), requires_grad=True)
|
96
90
|
|
97
|
-
def
|
98
|
-
|
99
|
-
|
91
|
+
def _update_rayleigh_norm(self):
|
92
|
+
with torch.no_grad():
|
93
|
+
product = self.nondecay_weight.mv(self.initial_right_singular)
|
94
|
+
normed_product = F.normalize(product, dim=0)
|
95
|
+
rayleigh_norm = torch.einsum(
|
96
|
+
"m,mn,n->",
|
97
|
+
normed_product,
|
98
|
+
self.nondecay_weight,
|
99
|
+
self.initial_right_singular,
|
100
|
+
)
|
101
|
+
self.rayleigh.data.copy_(rayleigh_norm)
|
100
102
|
|
101
|
-
|
102
|
-
|
103
|
+
def forward(self):
|
104
|
+
if self.training:
|
105
|
+
self._update_rayleigh_norm()
|
106
|
+
return self.scale * (self.nondecay_weight / (self.rayleigh_norm + 1e-6))
|
103
107
|
|
104
108
|
|
105
109
|
class NormReparamTensor(nn.Module):
|
@@ -120,4 +124,5 @@ class NormReparamTensor(nn.Module):
|
|
120
124
|
)
|
121
125
|
|
122
126
|
def forward(self) -> torch.Tensor:
|
123
|
-
|
127
|
+
norm = torch.linalg.norm(self.nondecay_weight)
|
128
|
+
return self.scale * (self.nondecay_weight / (norm + 1e-6))
|
@@ -53,11 +53,13 @@ class ClassificationHead(nn.Module):
|
|
53
53
|
A general classification head for a ViT
|
54
54
|
"""
|
55
55
|
|
56
|
-
def __init__(
|
56
|
+
def __init__(
|
57
|
+
self, d_model, linear_module, n_classes, layer_norm=True, batch_norm=True
|
58
|
+
):
|
57
59
|
super().__init__()
|
58
60
|
self.d_model = d_model
|
59
61
|
self.summarize = GetCLSToken()
|
60
|
-
self.projection =
|
62
|
+
self.projection = linear_module(d_model, n_classes)
|
61
63
|
if batch_norm:
|
62
64
|
self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
|
63
65
|
else:
|
@@ -65,6 +67,7 @@ class ClassificationHead(nn.Module):
|
|
65
67
|
|
66
68
|
self.classification_process = nn.Sequential(
|
67
69
|
*[
|
70
|
+
nn.LayerNorm if layer_norm else nn.Identity(),
|
68
71
|
self.summarize,
|
69
72
|
self.projection,
|
70
73
|
self.batch_norm,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|