broccoli-ml 0.23.1__py3-none-any.whl → 0.24.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/activation.py +69 -33
- broccoli/linear.py +47 -41
- broccoli/tensor.py +40 -0
- broccoli/transformer.py +5 -5
- broccoli/vit.py +2 -2
- {broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/METADATA +1 -1
- {broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/RECORD +9 -9
- {broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/WHEEL +0 -0
broccoli/activation.py
CHANGED
@@ -1,10 +1,50 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from einops import rearrange
 
 
-class
+class ReLU(nn.Module):
+    """
+    A ReLU activation function with optional clamp and leakiness.
+    """
+
+    def __init__(
+        self, clamp=True, leaky=True, negative_slope=0.01, clamp_max=6.0
+    ) -> None:
+        super().__init__()
+        self.clamp = clamp
+        self.leaky = leaky
+        self.negative_slope = negative_slope
+        self.clamp_max = clamp_max
+
+    def forward(self, x):
+        if self.leaky:
+            relu = F.leaky_relu(x, negative_slope=self.negative_slope)
+        else:
+            relu = F.relu(x)
+        if self.clamp:
+            relu = torch.clamp(relu, max=self.clamp_max)
+        return relu
+
+
+class GELU(nn.Module):
+    """
+    A GELU activation function with optional clamp.
+    """
+
+    def __init__(self, clamp=True) -> None:
+        super().__init__()
+        self.clamp = clamp
+        self.gelu = nn.GELU()
+
+    def forward(self, x):
+        gelu = self.gelu(x)
+        if self.clamp:
+            gelu = torch.clamp(gelu, max=6)
+        return gelu
+
+
+class Swish(nn.Module):
     """
     Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
     (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
@@ -16,12 +56,10 @@ class SwiGLU(nn.Module):
         super().__init__()
         # Learnable parameter is called "swiglu beta" so that it is easy to find
         # and exclude from weight decay
-        self.
+        self.swish_beta = nn.Parameter(torch.tensor([1.0]))
 
     def forward(self, x):
-
-        beta_swish = gate * F.sigmoid(self.swiglu_beta * gate)
-        return beta_swish * value
+        return x * F.sigmoid(self.swish_beta * x)
 
 
 class SquaredReLU(nn.Module):
@@ -32,54 +70,52 @@ class SquaredReLU(nn.Module):
     https://azizbelaweid.substack.com/p/what-is-swiglu-how-to-implement-it
     """
 
-    def __init__(
+    def __init__(
+        self, clamp=True, leaky=True, negative_slope: float = 0.01, clamp_max=6
+    ) -> None:
         super().__init__()
         self.clamp = clamp
         self.leaky = leaky
+        self.negative_slope = negative_slope
+        self.clamp_max = clamp_max
 
     def forward(self, x):
         if self.leaky:
-            relu = F.leaky_relu(x)
+            relu = F.leaky_relu(x, negative_slope=self.negative_slope)
         else:
             relu = F.relu(x)
         relu_squared = relu**2
         if self.clamp:
-            relu_squared = torch.clamp(relu_squared, max=
+            relu_squared = torch.clamp(relu_squared, max=self.clamp_max)
         return relu_squared
 
 
-class
+class XGLU(nn.Module):
     """
-
+    Generic Gated Linear Unit
     """
 
-    def __init__(self,
+    def __init__(self, activation_module: nn.Module) -> None:
         super().__init__()
-        self.
-        self.leaky = leaky
+        self.activation = activation_module
 
-    def forward(self, x):
-
-
-        else:
-            relu = F.relu(x)
-        if self.clamp:
-            relu = torch.clamp(relu, max=6)
-        return relu
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate, value = x.chunk(2, dim=-1)
+        return self.activation(gate) * value
 
 
-
+def SquaredReGLU(clamp=True, leaky=True, negative_slope=0.01, clamp_max=6.0) -> XGLU:
     """
-
+    Factory function that creates a GLU with a SquaredReLU activation.
     """
+    activation_module = SquaredReLU(
+        clamp=clamp, leaky=leaky, negative_slope=negative_slope, clamp_max=clamp_max
+    )
+    return XGLU(activation_module)
 
-    def __init__(self, clamp=True) -> None:
-        super().__init__()
-        self.clamp = clamp
-        self.gelu = nn.GELU()
 
-
-
-
-
-
+def SwiGLU() -> XGLU:
+    """
+    Factory function that creates a GLU with a Swish activation.
+    """
+    return XGLU(Swish())
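
The net effect of this change is that SwiGLU is no longer a single nn.Module: gating now lives in the generic XGLU wrapper, and SwiGLU/SquaredReGLU become factories that wrap an elementwise activation. A minimal usage sketch, assuming the names are importable from broccoli.activation exactly as shown in the diff above (shapes are illustrative only):

```python
import torch
from broccoli.activation import XGLU, SquaredReGLU, SwiGLU

x = torch.randn(8, 128)       # XGLU splits the last dimension into gate/value halves

swiglu = SwiGLU()             # XGLU wrapping a Swish (learnable-beta sigmoid) gate
print(swiglu(x).shape)        # torch.Size([8, 64]) -> activation(gate) * value

reglu = SquaredReGLU(clamp=True, clamp_max=6.0)
print(reglu(x).shape)         # torch.Size([8, 64])

# Any elementwise activation module can serve as the gate nonlinearity.
gelu_glu = XGLU(torch.nn.GELU())
print(gelu_glu(x).shape)      # torch.Size([8, 64])
```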
broccoli/linear.py
CHANGED
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.nn import functional as F
 
-from .tensor import SigmaReparamTensor
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor
 
 
 class SpectralNormLinear(nn.Module):
@@ -22,8 +22,6 @@ class SpectralNormLinear(nn.Module):
 
         self.weights = None
 
-        self.weight_init = nn.Parameter(torch.empty(out_features, in_features))
-
         # Define the bias vector as a learnable parameter if required.
         if self.use_bias:
             self.bias = nn.Parameter(torch.empty(out_features))
@@ -35,12 +33,13 @@ class SpectralNormLinear(nn.Module):
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
-
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
         if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
             nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = SigmaReparamTensor(
+        self.weights = SigmaReparamTensor(weights)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.linear(x, self.weights(), self.bias)
@@ -48,42 +47,49 @@ class SpectralNormLinear(nn.Module):
     def __repr__(self) -> str:
         # Optional: A nice representation for printing the module.
         return (
-            f"SpectralNormFeedForward(in_features={self.in_features}"
-            f"out_features={self.out_features}, bias={self.use_bias})"
+            f"SpectralNormFeedForward(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
         )
 
 
-class
-    """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        if not self.training:
-            return F.linear(inputs, self.weight)
+class AnchoredLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
         else:
-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = AnchoredReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
        )
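
The new AnchoredLinear keeps the constructor and call signature of nn.Linear while routing its weight through an AnchoredReparamTensor. A minimal sketch, assuming broccoli.linear exposes the class as in the diff above:

```python
import torch
from broccoli.linear import AnchoredLinear

layer = AnchoredLinear(in_features=16, out_features=32, bias=True)
x = torch.randn(4, 16)
y = layer(x)      # same calling convention as nn.Linear
print(y.shape)    # torch.Size([4, 32])

# The reparameterised weight lives in the `weights` submodule (an
# AnchoredReparamTensor), not in a plain `layer.weight` attribute.
print(layer)      # AnchoredLinear(in_features=16,out_features=32, bias=True)
```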
broccoli/tensor.py
CHANGED
@@ -54,3 +54,43 @@ class SigmaReparamTensor(nn.Module):
         return self.sigma_reparam_scale * (
             self.sigma_reparam_tensor / self.approx_spectral_norm
         )
+
+
+class AnchoredReparamTensor(nn.Module):
+    """
+    Reparameterise a tensor as a normalised tensor of weights multiplied by a
+    learnable scaling factor.
+
+    The tensor of weights is also reparameterised as the product of a learnable
+    weight tensor with the (fixed) dominant right-singular vector of the
+    weight tensor as it was initialised.
+
+    i.e this module represents a tensor reparameterised as:
+
+    W_reparam = scale * (W / ||W @ v_0||_2)
+
+    where v_0 is the dominant right-singular vector of the initial tensor W_init.
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+        super().__init__()
+
+        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+
+        # At initialization, compute the dominant right-singular vector (v_0)
+        # and store it in a non-trainable buffer.
+        with torch.no_grad():
+            _, _, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
+            # v_transpose[0] is the first row of V^T, which is the first right-singular vector.
+            self.register_buffer("anchor_vector", v_transpose[0])
+
+        initial_norm = torch.linalg.vector_norm(self.weight.mv(self.anchor_vector))
+        self.scale = nn.Parameter(initial_norm.clone().detach(), requires_grad=True)
+
+    def forward(self) -> torch.Tensor:
+        # Calculate the L2 norm of the matrix-vector product W @ v_0
+        norm = torch.linalg.vector_norm(self.weight.mv(self.anchor_vector))
+
+        # Return the reparameterized tensor.
+        return self.scale * (self.weight / (norm + 1e-6))
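
The docstring above defines the reparameterisation W_reparam = scale * (W / ||W @ v_0||_2), with scale initialised to ||W_init @ v_0||_2, so immediately after construction the module reproduces its initial tensor (up to the 1e-6 stabiliser). A small sketch of that property, assuming the class is importable as shown in the diff:

```python
import torch
from broccoli.tensor import AnchoredReparamTensor

w_init = torch.randn(32, 16)
reparam = AnchoredReparamTensor(w_init)

# At initialisation scale == ||W_init @ v_0||_2, so forward() returns ~W_init.
print(torch.allclose(reparam(), w_init, atol=1e-4))   # True

# weight and scale are trainable; the anchor vector v_0 is a fixed buffer.
print([n for n, _ in reparam.named_parameters()])      # ['weight', 'scale']
print([n for n, _ in reparam.named_buffers()])         # ['anchor_vector']
```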
broccoli/transformer.py
CHANGED
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from .rope import RotaryEmbedding, apply_rotary_emb
-from .linear import
+from .linear import AnchoredLinear
 
 
 class MHAttention(nn.Module):
@@ -236,7 +236,7 @@ class FeedforwardBlock(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-
+        reparam=False,
     ):
         super().__init__()
 
@@ -253,8 +253,8 @@ class FeedforwardBlock(nn.Module):
             else ratio * output_features
         )
 
-        if
-            self.memory_type =
+        if reparam:
+            self.memory_type = AnchoredLinear
         else:
             self.memory_type = linear_module
 
@@ -263,7 +263,7 @@ class FeedforwardBlock(nn.Module):
             nn.LayerNorm(input_features),
             linear_module(input_features, self.max_features),
             self.activation,
-            nn.LayerNorm(ratio * output_features),
+            # nn.LayerNorm(ratio * output_features),
             self.memory_type(ratio * output_features, output_features),
             self.dropout,
         ]
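
In FeedforwardBlock, the new reparam flag only changes which class builds the second ("memory") projection. A sketch that mirrors the selection logic in the diff above; the helper name select_memory_type is illustrative and not part of the package:

```python
from torch import nn
from broccoli.linear import AnchoredLinear

def select_memory_type(reparam: bool, linear_module=nn.Linear):
    # Mirrors the diff: AnchoredLinear when reparam=True, otherwise the
    # caller-supplied linear_module (nn.Linear by default).
    return AnchoredLinear if reparam else linear_module

print(select_memory_type(True) is AnchoredLinear)   # True
print(select_memory_type(False) is nn.Linear)       # True
```

The vit.py change in the next file passes reparam=not cnn to this block.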
broccoli/vit.py
CHANGED
@@ -295,14 +295,14 @@ class ViTEncoder(nn.Module):
 
         if transformer_feedforward_first:
             self.initial_ff = FeedforwardBlock(
-                transformer_embedding_size,
+                max(transformer_embedding_size, pooling_out_channels),
                 transformer_mlp_ratio,
                 transformer_embedding_size,
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
                 linear_module=linear_module,
-
+                reparam=not cnn,
             )
         else:
             self.initial_ff = nn.Identity()
{broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
 broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
-broccoli/activation.py,sha256
+broccoli/activation.py,sha256=-Jf30C6iGqWCorC9HEGn2oduWwjeaCAxGLUUYIy1zX8,3438
 broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
 broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
 broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
 broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
-broccoli/linear.py,sha256=
+broccoli/linear.py,sha256=4bxVDsO8E1d5-RZ23u160ZntazrT7Vt4AYTdAdCQU-w,3300
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
-broccoli/tensor.py,sha256=
-broccoli/transformer.py,sha256=
+broccoli/tensor.py,sha256=_YJP9tSFRkoKrR7cfnROSpWqfMyJLjgPmtFxEWRwgz8,3606
+broccoli/transformer.py,sha256=L1bVQZLUbtFtOy30yPVkjnqyELGhQoHJ_lFP_WPfYUA,16073
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
+broccoli/vit.py,sha256=qGCx4cnpAkPpVHFrz6bFHdnPJXPaCxtTxKlI9YQJZWg,15649
+broccoli_ml-0.24.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.24.1.dist-info/METADATA,sha256=HOchT-ECPmQWjc0nQN7ohhOiKUbOqBVO_yKJLh_k9b8,1257
+broccoli_ml-0.24.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.24.1.dist-info/RECORD,,
{broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/LICENSE
File without changes
{broccoli_ml-0.23.1.dist-info → broccoli_ml-0.24.1.dist-info}/WHEEL
File without changes