broccoli-ml 0.24.3__tar.gz → 0.26.0__tar.gz
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/PKG-INFO +1 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/linear.py +44 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/tensor.py +21 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/transformer.py +30 -17
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/vit.py +10 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/pyproject.toml +1 -1
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/LICENSE +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/README.md +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/activation.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/cnn.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/rope.py +0 -0
- {broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/utils.py +0 -0

{broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/linear.py

@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from .tensor import SigmaReparamTensor, AnchoredReparamTensor
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
 
 
 class SpectralNormLinear(nn.Module):
@@ -93,3 +93,46 @@ class AnchoredLinear(nn.Module):
             f"AnchoredLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class ReparamLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = NormReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
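
The new `ReparamLinear` added above is a drop-in replacement for `nn.Linear` whose weight matrix is wrapped in the new `NormReparamTensor`. Below is a minimal usage sketch, assuming broccoli-ml 0.26.0 is installed and that the rest of `broccoli/linear.py` (not shown in this diff, e.g. its `math` import) resolves as expected; the shapes are made up for illustration.

```python
import torch
from broccoli.linear import ReparamLinear

layer = ReparamLinear(in_features=64, out_features=32, bias=True)

x = torch.randn(8, 64)   # a batch of 8 feature vectors (illustrative sizes)
y = layer(x)             # the weight matrix is re-normalised and re-scaled on every forward pass
print(y.shape)           # torch.Size([8, 32])
print(layer)             # note: this release reuses AnchoredLinear's repr string
```

The underlying parameter follows the `nondecay_*` naming convention flagged in `tensor.py`, signalling that it is meant to be excluded from weight decay.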

{broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/tensor.py

@@ -100,3 +100,24 @@ class AnchoredReparamTensor(nn.Module):
 
         # Return the reparameterized tensor.
         return self.scale * (self.nondecay_weight / (norm + 1e-6))
+
+
+class NormReparamTensor(nn.Module):
+    """
+    Reparameterise a tensor as a normalised tensor of weights multiplied by a
+    learnable scaling factor.
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+        super().__init__()
+
+        # Use the gradboard convention of calling something nondecay_* if we should
+        # exclude it from weight decay
+        self.nondecay_weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+        self.scale = nn.Parameter(
+            torch.linalg.norm(self.nondecay_weight).clone().detach(), requires_grad=True
+        )
+
+    def forward(self) -> torch.Tensor:
+        return self.scale * F.normalize(self.nondecay_weight)
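
For reference, here is a small standalone sketch of the arithmetic `NormReparamTensor` performs in `forward()`, written directly against torch rather than the class above; the 4×3 matrix is an arbitrary example.

```python
import torch
from torch.nn import functional as F

w = torch.randn(4, 3)                 # stand-in for the wrapped weight matrix
scale = torch.linalg.norm(w)          # the learnable scale starts at the Frobenius norm of w
out = scale * F.normalize(w)          # F.normalize defaults to dim=1, so each row is unit-normalised

print(torch.linalg.norm(out, dim=1))  # every row of the output has norm equal to `scale`
```

Note that `F.normalize` operates row-wise by default while the scale is the whole matrix's Frobenius norm, so at initialisation the reparameterised tensor is not numerically identical to the original weights.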

{broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/transformer.py

@@ -236,7 +236,8 @@ class FeedforwardBlock(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-
+        pre_norm=True,
+        normformer=False,
     ):
         super().__init__()
 
@@ -253,19 +254,13 @@ class FeedforwardBlock(nn.Module):
             else ratio * output_features
         )
 
-        if raw_input:
-            self.memory_type = AnchoredLinear
-
-        else:
-            self.memory_type = linear_module
-
         self.process = nn.Sequential(
             *[
-                nn.LayerNorm(input_features),
+                nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
                 linear_module(input_features, self.max_features),
                 self.activation,
-
-
+                nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
+                linear_module(ratio * output_features, output_features),
                 self.dropout,
             ]
         )
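
With the new flags, `FeedforwardBlock`'s `nn.Sequential` becomes: optional pre-LayerNorm, expansion linear, activation, optional NormFormer-style LayerNorm after the activation, projection linear, dropout. A rough reconstruction of that stack in plain PyTorch, using stock `nn.Linear`, made-up sizes, and taking `max_features` to be `ratio * output_features` as the surrounding context suggests:

```python
import torch
from torch import nn

input_features, output_features, ratio = 32, 32, 4
pre_norm, normformer = True, True

process = nn.Sequential(
    nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
    nn.Linear(input_features, ratio * output_features),   # expansion projection
    nn.GELU(),                                             # stand-in for the configured activation
    nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
    nn.Linear(ratio * output_features, output_features),  # output projection
    nn.Dropout(0.0),
)

y = process(torch.randn(8, input_features))
print(y.shape)  # torch.Size([8, 32])
```

Setting `pre_norm=False` drops the leading LayerNorm (for post-norm blocks the normalisation presumably happens in `TransformerBlock` instead, see below), while `normformer=True` adds the extra LayerNorm after the activation in the NormFormer style.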
@@ -299,12 +294,17 @@ class TransformerBlock(nn.Module):
         identity_probability=0.0,
         causal=False,
         linear_module=nn.Linear,
+        pre_norm=True,
+        normformer=False,
     ):
         super().__init__()
 
+        self.pre_norm = pre_norm
+
         self.identity_probability = identity_probability
 
-        self.
+        self.layer_norm_1 = nn.LayerNorm(d_model)
+        self.layer_norm_2 = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
             max_freq = int(max(source_size) / 2)  # Suggested by Gemini!
@@ -339,6 +339,8 @@ class TransformerBlock(nn.Module):
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
             linear_module=linear_module,
+            pre_norm=pre_norm,
+            normformer=normformer,
         )
 
     @property
@@ -359,12 +361,19 @@ class TransformerBlock(nn.Module):
         identity_x = shuffled[:identity_count, :, :]
         process_x = shuffled[identity_count:, :, :]
 
-
-
-
-
-
-
+        if self.pre_norm:
+            norm_process_x = self.layer_norm_1(process_x)
+            process_x = process_x + self.attn(
+                norm_process_x, norm_process_x, norm_process_x
+            )
+            process_x = process_x + self.ff(process_x)
+        else:  # post-norm
+            process_x = process_x + self.attn(process_x, process_x, process_x)
+            norm_process_x = self.layer_norm_1(process_x)
+            process_x = process_x + self.ff(process_x)
+        x = self.layer_norm_2(
+            torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
+        )
 
         return x
 
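
The rewritten `TransformerBlock.forward` now switches between pre-norm and post-norm residual ordering via `self.pre_norm`. As a generic illustration of that distinction (using stock `nn.MultiheadAttention` and a plain MLP rather than broccoli's own attention and `FeedforwardBlock`, and simplifying the exact LayerNorm placement shown in the hunk above):

```python
import torch
from torch import nn

d_model = 16
attn = nn.MultiheadAttention(d_model, num_heads=4)          # stand-in for self.attn
ff = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU(),
                   nn.Linear(d_model, d_model))             # stand-in for self.ff
ln1, ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

x = torch.randn(10, 2, d_model)  # (sequence, batch, d_model)

pre_norm = True
if pre_norm:
    # Pre-norm: normalise the attention input, keep the residual stream un-normalised.
    normed = ln1(x)
    x = x + attn(normed, normed, normed)[0]
    x = x + ff(x)
else:
    # Post-norm: attend on the raw stream, normalise after the residual addition.
    x = ln1(x + attn(x, x, x)[0])
    x = x + ff(x)
x = ln2(x)
```

In the diff itself, `layer_norm_2` is applied once after the identity (stochastic-depth) tokens and the processed tokens are re-assembled, so both branches end with a final LayerNorm over the full sequence.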
@@ -393,6 +402,8 @@ class TransformerEncoder(nn.Module):
         linear_module=nn.Linear,
         bos_tokens=0,
         return_bos_tokens=False,
+        pre_norm=True,
+        normformer=False,
     ):
         if position_embedding_type == "relative":
             assert source_size is not None  # TODO: make this a proper exception
@@ -451,6 +462,8 @@ class TransformerEncoder(nn.Module):
                 identity_probability=self.stochastic_depth_probabilities[i],
                 causal=causal,
                 linear_module=linear_module,
+                pre_norm=pre_norm,
+                normformer=normformer,
             )
             for i in range(n_layers)
         ]

{broccoli_ml-0.24.3 → broccoli_ml-0.26.0}/broccoli/vit.py

@@ -117,6 +117,8 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_pre_norm=True,
+        transformer_normformer=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -289,6 +291,8 @@ class ViTEncoder(nn.Module):
                 linear_module=linear_module,
                 bos_tokens=transformer_bos_tokens,
                 return_bos_tokens=transformer_return_bos_tokens,
+                pre_norm=transformer_pre_norm,
+                normformer=transformer_normformer,
             )
         else:
             self.transformer = nn.Identity()
@@ -302,7 +306,8 @@ class ViTEncoder(nn.Module):
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
                 linear_module=linear_module,
-
+                pre_norm=transformer_pre_norm,
+                normformer=transformer_normformer,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -356,6 +361,8 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_pre_norm=True,
+        transformer_normformer=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -410,6 +417,8 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
+            transformer_pre_norm=transformer_pre_norm,
+            transformer_normformer=transformer_normformer,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,