broccoli-ml 0.24.3__tar.gz → 0.25.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/PKG-INFO +1 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/linear.py +44 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/tensor.py +21 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/transformer.py +13 -10
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/vit.py +8 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/pyproject.toml +1 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/LICENSE +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/README.md +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/activation.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/rope.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.25.0}/broccoli/utils.py +0 -0
@@ -5,7 +5,7 @@ import torch
|
|
5
5
|
from torch import nn
|
6
6
|
from torch.nn import functional as F
|
7
7
|
|
8
|
-
from .tensor import SigmaReparamTensor, AnchoredReparamTensor
|
8
|
+
from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
|
9
9
|
|
10
10
|
|
11
11
|
class SpectralNormLinear(nn.Module):
|
@@ -93,3 +93,46 @@ class AnchoredLinear(nn.Module):
|
|
93
93
|
f"AnchoredLinear(in_features={self.in_features},"
|
94
94
|
f"out_features={self.out_features}, bias={self.use_bias})"
|
95
95
|
)
|
96
|
+
|
97
|
+
|
98
|
+
class ReparamLinear(nn.Module):
    """
    Linear layer whose weight matrix is produced by a ``NormReparamTensor``
    (a normalised weight tensor multiplied by a learnable scaling factor).

    The bias, if enabled, is an ordinary learnable parameter.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If ``True``, add a learnable bias of shape ``(out_features,)``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = bias

        # Placeholder; reset_parameters() replaces it with a NormReparamTensor
        # submodule that produces the (out_features, in_features) weight matrix.
        self.weights = None

        # Define the bias vector as a learnable parameter if required.
        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # If no bias, register it as None so state_dict save/load and
            # F.linear both see an explicit `None` bias.
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        (Re)initialise weights and bias with the same scheme ``nn.Linear`` uses:
        Kaiming-uniform weights (``a=sqrt(5)``) and, if enabled, a bias drawn
        uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``.

        Note: this builds a brand-new ``NormReparamTensor`` (new Parameter
        objects) from a tensor created with ``torch.empty`` defaults, so call it
        before constructing an optimiser or moving the module to a device.
        """
        weights = torch.empty(self.out_features, self.in_features)
        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
        if self.use_bias:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
        self.weights = NormReparamTensor(weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ W.T + b`` with ``W`` taken from the reparameterised tensor."""
        return F.linear(x, self.weights(), self.bias)

    def __repr__(self) -> str:
        # Report this class's own name (previously copy-pasted as
        # "AnchoredLinear") and keep a space between the fields.
        return (
            f"{type(self).__name__}(in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={self.use_bias})"
        )
|
@@ -100,3 +100,24 @@ class AnchoredReparamTensor(nn.Module):
|
|
100
100
|
|
101
101
|
# Return the reparameterized tensor.
|
102
102
|
return self.scale * (self.nondecay_weight / (norm + 1e-6))
|
103
|
+
|
104
|
+
|
105
|
+
class NormReparamTensor(nn.Module):
    """
    Reparameterise a tensor as a normalised tensor of weights multiplied by a
    learnable scaling factor.

    The stored weight matrix is divided by its Frobenius norm and multiplied by
    a single learnable scalar ``scale``. ``scale`` is initialised to the
    Frobenius norm of ``init_tensor``, so at initialisation ``forward()``
    returns the original tensor (up to ``eps``).

    Args:
        init_tensor: 2D tensor used to initialise the weights.
        eps: Small constant added to the norm to avoid division by zero.
    """

    def __init__(self, init_tensor: torch.Tensor, eps: float = 1e-6):
        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
        super().__init__()
        self.eps = eps

        # Use the gradboard convention of calling something nondecay_* if we should
        # exclude it from weight decay. Detach so the parameter does not carry
        # autograd history from whatever produced init_tensor.
        self.nondecay_weight = nn.Parameter(
            init_tensor.detach().clone(), requires_grad=True
        )
        self.scale = nn.Parameter(
            torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
        )

    def forward(self) -> torch.Tensor:
        # Normalise by the Frobenius norm of the whole matrix: the same quantity
        # `scale` is initialised with. Row-wise normalisation (F.normalize's
        # default dim=1) would give every row norm `scale` and so would not
        # reproduce init_tensor at initialisation.
        norm = torch.linalg.norm(self.nondecay_weight)
        return self.scale * (self.nondecay_weight / (norm + self.eps))
|
@@ -236,7 +236,8 @@ class FeedforwardBlock(nn.Module):
|
|
236
236
|
activation_kwargs=None,
|
237
237
|
dropout=0.0,
|
238
238
|
linear_module=nn.Linear,
|
239
|
-
|
239
|
+
pre_norm=True,
|
240
|
+
normformer=False,
|
240
241
|
):
|
241
242
|
super().__init__()
|
242
243
|
|
@@ -253,19 +254,13 @@ class FeedforwardBlock(nn.Module):
|
|
253
254
|
else ratio * output_features
|
254
255
|
)
|
255
256
|
|
256
|
-
if raw_input:
|
257
|
-
self.memory_type = AnchoredLinear
|
258
|
-
|
259
|
-
else:
|
260
|
-
self.memory_type = linear_module
|
261
|
-
|
262
257
|
self.process = nn.Sequential(
|
263
258
|
*[
|
264
|
-
nn.LayerNorm(input_features),
|
259
|
+
nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
|
265
260
|
linear_module(input_features, self.max_features),
|
266
261
|
self.activation,
|
267
|
-
|
268
|
-
|
262
|
+
nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
|
263
|
+
linear_module(ratio * output_features, output_features),
|
269
264
|
self.dropout,
|
270
265
|
]
|
271
266
|
)
|
@@ -299,6 +294,8 @@ class TransformerBlock(nn.Module):
|
|
299
294
|
identity_probability=0.0,
|
300
295
|
causal=False,
|
301
296
|
linear_module=nn.Linear,
|
297
|
+
pre_norm=True,
|
298
|
+
normformer=False,
|
302
299
|
):
|
303
300
|
super().__init__()
|
304
301
|
|
@@ -339,6 +336,8 @@ class TransformerBlock(nn.Module):
|
|
339
336
|
activation_kwargs=activation_kwargs,
|
340
337
|
dropout=mlp_dropout,
|
341
338
|
linear_module=linear_module,
|
339
|
+
pre_norm=pre_norm,
|
340
|
+
normformer=normformer,
|
342
341
|
)
|
343
342
|
|
344
343
|
@property
|
@@ -393,6 +392,8 @@ class TransformerEncoder(nn.Module):
|
|
393
392
|
linear_module=nn.Linear,
|
394
393
|
bos_tokens=0,
|
395
394
|
return_bos_tokens=False,
|
395
|
+
pre_norm=True,
|
396
|
+
normformer=False,
|
396
397
|
):
|
397
398
|
if position_embedding_type == "relative":
|
398
399
|
assert source_size is not None # TODO: make this a proper exception
|
@@ -451,6 +452,8 @@ class TransformerEncoder(nn.Module):
|
|
451
452
|
identity_probability=self.stochastic_depth_probabilities[i],
|
452
453
|
causal=causal,
|
453
454
|
linear_module=linear_module,
|
455
|
+
pre_norm=pre_norm,
|
456
|
+
normformer=normformer,
|
454
457
|
)
|
455
458
|
for i in range(n_layers)
|
456
459
|
]
|
@@ -117,6 +117,8 @@ class ViTEncoder(nn.Module):
|
|
117
117
|
pooling_kernel_stride=2,
|
118
118
|
pooling_padding=1,
|
119
119
|
transformer_feedforward_first=True,
|
120
|
+
transformer_pre_norm=True,
|
121
|
+
transformer_normformer=False,
|
120
122
|
transformer_position_embedding="relative", # absolute or relative
|
121
123
|
transformer_embedding_size=256,
|
122
124
|
transformer_layers=7,
|
@@ -289,6 +291,8 @@ class ViTEncoder(nn.Module):
|
|
289
291
|
linear_module=linear_module,
|
290
292
|
bos_tokens=transformer_bos_tokens,
|
291
293
|
return_bos_tokens=transformer_return_bos_tokens,
|
294
|
+
pre_norm=transformer_pre_norm,
|
295
|
+
normformer=transformer_normformer,
|
292
296
|
)
|
293
297
|
else:
|
294
298
|
self.transformer = nn.Identity()
|
@@ -356,6 +360,8 @@ class ViT(nn.Module):
|
|
356
360
|
pooling_kernel_stride=2,
|
357
361
|
pooling_padding=1,
|
358
362
|
transformer_feedforward_first=True,
|
363
|
+
transformer_pre_norm=True,
|
364
|
+
transformer_normformer=False,
|
359
365
|
transformer_position_embedding="relative", # absolute or relative
|
360
366
|
transformer_embedding_size=256,
|
361
367
|
transformer_layers=7,
|
@@ -410,6 +416,8 @@ class ViT(nn.Module):
|
|
410
416
|
pooling_kernel_stride=pooling_kernel_stride,
|
411
417
|
pooling_padding=pooling_padding,
|
412
418
|
transformer_feedforward_first=transformer_feedforward_first,
|
419
|
+
transformer_pre_norm=transformer_pre_norm,
|
420
|
+
transformer_normformer=transformer_normformer,
|
413
421
|
transformer_position_embedding=transformer_position_embedding,
|
414
422
|
transformer_embedding_size=transformer_embedding_size,
|
415
423
|
transformer_layers=transformer_layers,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|