broccoli-ml 0.24.1__tar.gz → 0.24.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/PKG-INFO +1 -1
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/tensor.py +11 -5
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/transformer.py +4 -3
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/vit.py +1 -1
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/pyproject.toml +1 -1
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/LICENSE +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/README.md +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/activation.py +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/cnn.py +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/linear.py +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/rope.py +0 -0
- {broccoli_ml-0.24.1 → broccoli_ml-0.24.3}/broccoli/utils.py +0 -0
@@ -76,21 +76,27 @@ class AnchoredReparamTensor(nn.Module):
|
|
76
76
|
assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
|
77
77
|
super().__init__()
|
78
78
|
|
79
|
-
|
79
|
+
# Use the gradboard convention of calling something nondecay_* if we should
|
80
|
+
# exclude it from weight decay
|
81
|
+
self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
|
80
82
|
|
81
83
|
# At initialization, compute the dominant right-singular vector (v_0)
|
82
84
|
# and store it in a non-trainable buffer.
|
83
85
|
with torch.no_grad():
|
84
|
-
_, _, v_transpose = torch.linalg.svd(
|
86
|
+
_, _, v_transpose = torch.linalg.svd(
|
87
|
+
self.nondecay_weight, full_matrices=False
|
88
|
+
)
|
85
89
|
# v_transpose[0] is the first row of V^T, which is the first right-singular vector.
|
86
90
|
self.register_buffer("anchor_vector", v_transpose[0])
|
87
91
|
|
88
|
-
initial_norm = torch.linalg.vector_norm(
|
92
|
+
initial_norm = torch.linalg.vector_norm(
|
93
|
+
self.nondecay_weight.mv(self.anchor_vector)
|
94
|
+
)
|
89
95
|
self.scale = nn.Parameter(initial_norm.clone().detach(), requires_grad=True)
|
90
96
|
|
91
97
|
def forward(self) -> torch.Tensor:
|
92
98
|
# Calculate the L2 norm of the matrix-vector product W @ v_0
|
93
|
-
norm = torch.linalg.vector_norm(self.
|
99
|
+
norm = torch.linalg.vector_norm(self.nondecay_weight.mv(self.anchor_vector))
|
94
100
|
|
95
101
|
# Return the reparameterized tensor.
|
96
|
-
return self.scale * (self.
|
102
|
+
return self.scale * (self.nondecay_weight / (norm + 1e-6))
|
@@ -236,7 +236,7 @@ class FeedforwardBlock(nn.Module):
|
|
236
236
|
activation_kwargs=None,
|
237
237
|
dropout=0.0,
|
238
238
|
linear_module=nn.Linear,
|
239
|
-
|
239
|
+
raw_input=False,
|
240
240
|
):
|
241
241
|
super().__init__()
|
242
242
|
|
@@ -253,8 +253,9 @@ class FeedforwardBlock(nn.Module):
|
|
253
253
|
else ratio * output_features
|
254
254
|
)
|
255
255
|
|
256
|
-
if
|
256
|
+
if raw_input:
|
257
257
|
self.memory_type = AnchoredLinear
|
258
|
+
|
258
259
|
else:
|
259
260
|
self.memory_type = linear_module
|
260
261
|
|
@@ -263,7 +264,7 @@ class FeedforwardBlock(nn.Module):
|
|
263
264
|
nn.LayerNorm(input_features),
|
264
265
|
linear_module(input_features, self.max_features),
|
265
266
|
self.activation,
|
266
|
-
# nn.LayerNorm(ratio * output_features),
|
267
|
+
# nn.LayerNorm(ratio * output_features) if raw_input else nn.Identity(),
|
267
268
|
self.memory_type(ratio * output_features, output_features),
|
268
269
|
self.dropout,
|
269
270
|
]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|