broccoli-ml 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli_ml-0.1.0/LICENSE +21 -0
- broccoli_ml-0.1.0/PKG-INFO +19 -0
- broccoli_ml-0.1.0/README.md +2 -0
- broccoli_ml-0.1.0/broccoli/__init__.py +0 -0
- broccoli_ml-0.1.0/broccoli/activation.py +24 -0
- broccoli_ml-0.1.0/broccoli/cct.py +249 -0
- broccoli_ml-0.1.0/broccoli/cnn.py +400 -0
- broccoli_ml-0.1.0/broccoli/eigenpatches.py +39 -0
- broccoli_ml-0.1.0/broccoli/linear.py +41 -0
- broccoli_ml-0.1.0/broccoli/tensor.py +49 -0
- broccoli_ml-0.1.0/broccoli/transformer.py +260 -0
- broccoli_ml-0.1.0/pyproject.toml +43 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 nicholasbailey87
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: broccoli-ml
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Some useful Pytorch models, circa 2025
|
|
5
|
+
License: MIT
|
|
6
|
+
Author: Nicholas Bailey
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
|
+
Requires-Dist: einops (>=0.8.1,<0.9.0)
|
|
13
|
+
Requires-Dist: numpy (>=2.3.1,<3.0.0)
|
|
14
|
+
Requires-Dist: torch (>=2.7.1,<3.0.0)
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
|
|
17
|
+
# broccoli
|
|
18
|
+
Some useful Pytorch models, circa 2025
|
|
19
|
+
|
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn import functional as F
|
|
4
|
+
from einops import rearrange
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class SwiGLU(nn.Module):
    """
    (Beta) SwiGLU activation from "GLU Variants Improve Transformer"
    (https://arxiv.org/abs/2002.05202v1), as popularised by LLaMa 2.0.

    The input's trailing dimension is split in two halves (gate, value), so
    the output has half the incoming width — scale the input up accordingly.
    """

    def __init__(self, linear_module: nn.Module = nn.Linear) -> None:
        super().__init__()
        # Learnable parameter is called "swiglu beta" so that it is easy to find
        # and exclude from weight decay
        self.swiglu_beta = nn.Parameter(torch.tensor([0.0]))

    def forward(self, x):
        # First half of the last dimension gates the second half.
        gate, value = x.chunk(2, dim=-1)
        # Swish with a learnable temperature applied to the gate.
        beta_swish = gate * F.sigmoid(self.swiglu_beta * gate)
        return beta_swish * value
|
|
@@ -0,0 +1,249 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from transformer import TransformerEncoder
|
|
4
|
+
from cnn import ConvLayer, ConcatPool, WhiteningConv
|
|
5
|
+
from einops import einsum, rearrange
|
|
6
|
+
from einops.layers.torch import Rearrange
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SequencePool(nn.Module):
    """
    Attention-weighted pooling over a sequence, as described in [Hasani et al.
    (2021) *''Escaping the Big Data Paradigm with Compact Transformers''*](
    https://arxiv.org/abs/2104.05704). It can be viewed as a generalisation
    of average pooling: each position gets a learned softmax weight.
    """

    def __init__(self, d_model, linear_module, out_dim):
        super().__init__()
        self.d_model = d_model
        # Scores every position with a linear head, then normalises the
        # scores across the sequence dimension.
        self.attention = nn.Sequential(
            linear_module(d_model, 1),
            Rearrange("batch seq 1 -> batch seq"),
            nn.Softmax(dim=-1),
        )
        self.projection = nn.Linear(d_model, out_dim)
        self.norm = nn.BatchNorm1d(out_dim, affine=False)

    def forward(self, x):
        # (batch, seq) softmax weights over sequence positions.
        weights = self.attention(x)
        # Weighted sum over the sequence: (batch, d_model).
        pooled = (weights.unsqueeze(-1) * x).sum(dim=1)
        return self.norm(self.projection(pooled))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class CCTEncoder(nn.Module):
    """
    Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
    *''Escaping the Big Data Paradigm with Compact Transformers''*](
    https://arxiv.org/abs/2104.05704). It's basically a convolutional neural
    network leading into a transformer encoder. To make it like the full CCT
    we would finish it off with a sequence pooling layer but we won't always
    want to do that.
    """

    def __init__(
        self,
        image_size=32,
        conv_kernel_size=3,  # Only 2 is supported for eigenvector initialisation
        whitening=False,
        pooling_type="maxpool",
        pooling_kernel_size=3,
        pooling_kernel_stride=2,
        pooling_kernel_padding=1,
        transformer_embedding_size=256,
        transformer_layers=7,
        transformer_heads=4,
        transformer_mlp_ratio=2,
        transformer_bos_tokens=1,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        mlp_dropout=0.0,
        msa_dropout=0.1,
        stochastic_depth=0.1,
        linear_module=nn.Linear,
        image_channels=3,
    ):
        # Validate inputs with real exceptions (was print + assert; asserts
        # are stripped under `python -O`).
        if whitening and (conv_kernel_size != 2):
            raise ValueError("We currently only support whitening for kernel size 2!")
        if pooling_type not in ("maxpool", "concat"):
            raise ValueError("Pooling type must be maxpool or concat")

        # nn.Module.__init__ must run before anything is assigned to `self`;
        # previously `self.activation` was assigned first, which raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()

        if activation_kwargs is not None:
            self.activation = activation(**activation_kwargs)
        else:
            self.activation = activation()

        self.image_size = image_size

        # Spatial size after the pooling stage (standard conv arithmetic);
        # the convolution stage itself is 'same'-padded with stride 1.
        output_size = (
            (image_size + 2 * pooling_kernel_padding - pooling_kernel_size)
            / pooling_kernel_stride
        ) + 1

        self.sequence_length = int(output_size) ** 2

        if pooling_type == "maxpool":

            conv_out_channels = transformer_embedding_size

            self.pool = nn.Sequential(
                # Channels-last in case we're using an XGLU activation,
                # which splits the trailing dimension.
                Rearrange("N C H W -> N H W C"),
                self.activation,
                Rearrange("N H W C -> N C H W"),
                nn.MaxPool2d(
                    pooling_kernel_size,
                    stride=pooling_kernel_stride,
                    padding=pooling_kernel_padding,
                ),
                Rearrange("N C H W -> N (H W) C"),
            )
        elif pooling_type == "concat":

            conv_out_channels = int(
                round(transformer_embedding_size / (pooling_kernel_size**2))
            )

            # Bind the optional arguments by keyword: they were previously
            # passed positionally into ConcatPool's (activation_kwargs,
            # activation) slots, mis-binding both.
            self.pool = ConcatPool(
                conv_out_channels,
                pooling_kernel_size,
                pooling_kernel_stride,
                pooling_kernel_padding,
                transformer_embedding_size,
                activation=activation,
                activation_kwargs=activation_kwargs,
                linear_module=linear_module,
            )

        if transformer_layers > 0:
            self.transformer = TransformerEncoder(
                self.sequence_length,
                transformer_embedding_size,
                transformer_layers,
                transformer_heads,
                mlp_ratio=transformer_mlp_ratio,
                activation=activation,
                activation_kwargs=activation_kwargs,
                mlp_dropout=mlp_dropout,
                msa_dropout=msa_dropout,
                stochastic_depth=stochastic_depth,
                causal=False,
                linear_module=linear_module,
                bos_tokens=transformer_bos_tokens,
            )
        else:
            self.transformer = nn.Identity()

        # GLU-style activations halve the trailing dimension, so the conv
        # must emit twice as many channels. (This code block rhymes.)
        if activation.__name__.endswith("GLU"):
            conv_out_channels *= 2

        if whitening:
            whitening_conv_out_channels = conv_kernel_size**2 * image_channels * 2
            # NOTE(review): WhiteningConv requires an `eigenvectors` tensor
            # which is not supplied here, so this branch raises TypeError as
            # written — confirm how the eigenvectors should be threaded in.
            self.conv = nn.Sequential(
                WhiteningConv(
                    in_channels=image_channels,
                    kernel_size=conv_kernel_size,
                    linear_module=linear_module,
                ),
                Rearrange("N C H W -> N H W C"),
                nn.Linear(whitening_conv_out_channels, conv_out_channels),
                Rearrange("N H W C -> N C H W"),
            )
        else:
            self.conv = ConvLayer(
                image_channels,
                conv_out_channels,
                kernel_size=conv_kernel_size,
                stride=1,
                padding="same",
                linear_module=linear_module,
            )

        self.encoder = nn.Sequential(self.conv, self.pool, self.transformer)

    def forward(self, x):
        return self.encoder(x)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class CCT(nn.Module):
    """
    Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
    *''Escaping the Big Data Paradigm with Compact Transformers''*](
    https://arxiv.org/abs/2104.05704). It's a convolutional neural network
    leading into a transformer encoder, followed by a sequence pooling layer
    that maps the sequence to `image_classes` outputs.
    """

    def __init__(
        self,
        image_size=32,
        conv_kernel_size=3,  # Only 2 is supported for eigenvector initialisation
        whitening=False,
        pooling_type="maxpool",
        pooling_kernel_size=3,
        pooling_kernel_stride=2,
        pooling_kernel_padding=1,
        transformer_embedding_size=256,
        transformer_layers=7,
        transformer_heads=4,
        transformer_mlp_ratio=2,
        transformer_bos_tokens=1,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        mlp_dropout=0.0,  # The original paper got best performance from mlp_dropout=0.
        msa_dropout=0.1,  # "" msa_dropout=0.1
        stochastic_depth=0.1,  # "" stochastic_depth=0.1
        image_classes=100,
        linear_module=nn.Linear,
        image_channels=3,
    ):
        # Raise a real exception instead of a bare assert, which disappears
        # under `python -O`.
        if whitening and (conv_kernel_size != 2):
            raise ValueError("We currently only support whitening for kernel size 2!")

        super().__init__()

        self.encoder = CCTEncoder(
            image_size,
            conv_kernel_size,
            whitening,
            pooling_type,
            pooling_kernel_size,
            pooling_kernel_stride,
            pooling_kernel_padding,
            transformer_embedding_size,
            transformer_layers,
            transformer_heads,
            transformer_mlp_ratio,
            transformer_bos_tokens,
            activation=activation,
            activation_kwargs=activation_kwargs,
            mlp_dropout=mlp_dropout,
            msa_dropout=msa_dropout,
            stochastic_depth=stochastic_depth,
            linear_module=linear_module,
            image_channels=image_channels,
        )
        self.pool = SequencePool(
            transformer_embedding_size, linear_module, image_classes
        )

    @property
    def sequence_length(self):
        # Delegate: the encoder knows the post-pooling sequence length.
        return self.encoder.sequence_length

    def forward(self, x):
        return self.pool(self.encoder(x))
|
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from torch.nn.modules.utils import _pair
|
|
5
|
+
from einops import rearrange, repeat
|
|
6
|
+
import math
|
|
7
|
+
from typing import Type, Union, Tuple, Optional, Literal
|
|
8
|
+
|
|
9
|
+
from einops.layers.torch import Rearrange
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Helper function to calculate padding for 'same' mode
|
|
13
|
+
# Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
|
|
14
|
+
def _calculate_same_padding(
|
|
15
|
+
input_size: Tuple[int, int],
|
|
16
|
+
kernel_size: Tuple[int, int],
|
|
17
|
+
stride: Tuple[int, int],
|
|
18
|
+
dilation: Tuple[int, int],
|
|
19
|
+
) -> Tuple[int, int, int, int]:
|
|
20
|
+
"""Calculates padding for 'same' output shape."""
|
|
21
|
+
ih, iw = input_size
|
|
22
|
+
kh, kw = kernel_size
|
|
23
|
+
sh, sw = stride
|
|
24
|
+
dh, dw = dilation
|
|
25
|
+
|
|
26
|
+
# Effective kernel size
|
|
27
|
+
eff_kh = (kh - 1) * dh + 1
|
|
28
|
+
eff_kw = (kw - 1) * dw + 1
|
|
29
|
+
|
|
30
|
+
# Calculate required total padding
|
|
31
|
+
out_h = (ih + sh - 1) // sh
|
|
32
|
+
out_w = (iw + sw - 1) // sw
|
|
33
|
+
pad_h = max((out_h - 1) * sh + eff_kh - ih, 0)
|
|
34
|
+
pad_w = max((out_w - 1) * sw + eff_kw - iw, 0)
|
|
35
|
+
|
|
36
|
+
# Distribute padding (similar to TensorFlow 'SAME' behavior)
|
|
37
|
+
pad_top = pad_h // 2
|
|
38
|
+
pad_bottom = pad_h - pad_top
|
|
39
|
+
pad_left = pad_w // 2
|
|
40
|
+
pad_right = pad_w - pad_left
|
|
41
|
+
return (pad_left, pad_right, pad_top, pad_bottom)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Custom Convolution Layer
|
|
45
|
+
class ConvLayer(nn.Module):
    """
    A 2D Convolution layer implemented using torch.nn.Unfold and a custom linear layer.

    This layer mimics the behavior of torch.nn.Conv2d but allows injecting
    a different linear layer implementation for processing the unfolded patches.

    Args:
        in_channels (int): Number of channels in the input image.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int or tuple): Size of the convolving kernel.
        stride (int or tuple, optional): Stride of the convolution. Default: 1.
        padding (int, tuple or str, optional): Padding added to all four sides
            of the input. Can be an int, a tuple of two ints (padH, padW),
            a tuple of four ints (padLeft, padRight, padTop, padBottom),
            or the strings 'valid' (no padding) or 'same' (padding for same
            output spatial dims as input). Default: 0 ('valid').
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1.
        bias (bool, optional): If True, adds a learnable bias to the output.
            The bias is handled by the underlying linear layer. Default: True.
        linear_module (Type[nn.Module], optional): The class of the linear layer
            to use for the kernel operation. Must accept (in_features, out_features, bias)
            in its constructor. Defaults to torch.nn.Linear.

    Raises:
        ValueError: if `padding` is an unrecognised string.
        TypeError: if `padding` is neither a string, an int, nor a tuple of
            2 or 4 ints.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[
            int, Tuple[int, int], Tuple[int, int, int, int], Literal["valid", "same"]
        ] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        bias: bool = True,
        linear_module: Type[nn.Module] = nn.Linear,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.dilation = _pair(dilation)
        self.bias = bias
        self.linear_module = linear_module
        self.padding_mode = (
            padding  # Store the original padding mode ('same', 'valid', int, or tuple)
        )

        # Number of input features for the linear layer: channel count times
        # kernel area, matching nn.Unfold's per-patch layout.
        self.linear_in_features = (
            in_channels * self.kernel_size[0] * self.kernel_size[1]
        )

        # Instantiate the linear layer (kernel)
        self.kernel = self.linear_module(
            self.linear_in_features, out_channels, bias=bias
        )

        # We will use F.pad for manual padding, so unfold padding is 0
        self.unfold = nn.Unfold(
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=0,  # Manual padding handled in forward
            stride=self.stride,
        )

        # Determine numeric padding values for F.pad, in F.pad's
        # (left, right, top, bottom) order. None means 'same', which depends
        # on the input size and is resolved in forward().
        if isinstance(padding, str):
            if padding not in ["valid", "same"]:
                raise ValueError("padding must be 'valid', 'same', an int, or a tuple")
            self._padding_val = (
                (0, 0, 0, 0) if padding == "valid" else None
            )  # None indicates 'same'
        elif isinstance(padding, int):
            self._padding_val = (padding,) * 4
        elif isinstance(padding, tuple) and len(padding) == 2:
            # (padH, padW) -> (padW_left, padW_right, padH_top, padH_bottom)
            self._padding_val = (padding[1], padding[1], padding[0], padding[0])
        elif isinstance(padding, tuple) and len(padding) == 4:
            # (padLeft, padRight, padTop, padBottom) - already in F.pad format
            self._padding_val = padding
        else:
            raise TypeError(
                "padding must be 'valid', 'same', an int, or a tuple of 2 or 4 ints"
            )

    def _calculate_output_shape(self, h_in: int, w_in: int) -> Tuple[int, int]:
        """Calculate the output height and width from the *unpadded* input size."""
        if self._padding_val is None:  # 'same' padding
            # 'same' output size is ceil(input_size / stride);
            # _calculate_same_padding guarantees this in forward().
            oh = math.ceil(h_in / self.stride[0])
            ow = math.ceil(w_in / self.stride[1])
            return oh, ow
        else:
            # Standard convolution arithmetic with the numeric padding.
            pad_h = self._padding_val[2] + self._padding_val[3]  # top + bottom
            pad_w = self._padding_val[0] + self._padding_val[1]  # left + right
            kh, kw = self.kernel_size
            sh, sw = self.stride
            dh, dw = self.dilation

            # Effective (dilated) kernel extents.
            eff_kh = (kh - 1) * dh + 1
            eff_kw = (kw - 1) * dw + 1

            oh = math.floor((h_in + pad_h - eff_kh) / sh + 1)
            ow = math.floor((w_in + pad_w - eff_kw) / sw + 1)
            return oh, ow

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Performs the forward pass.

        Args:
            x (torch.Tensor): Input tensor of shape (N, C_in, H_in, W_in).

        Returns:
            torch.Tensor: Output tensor of shape (N, C_out, H_out, W_out).

        Raises:
            ValueError: if the input channel count differs from `in_channels`.
        """
        N, C, H, W = x.shape
        if C != self.in_channels:
            raise ValueError(
                f"Input channels {C} does not match expected {self.in_channels}"
            )

        # 1. Calculate and apply padding.
        if self._padding_val is None:  # 'same' padding mode
            pad_l, pad_r, pad_t, pad_b = _calculate_same_padding(
                (H, W), self.kernel_size, self.stride, self.dilation
            )
            padded_x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
        elif self._padding_val != (0, 0, 0, 0):
            padded_x = F.pad(x, self._padding_val)
        else:  # No padding ('valid' or explicit 0)
            padded_x = x

        # 2. Unfold to extract patches:
        # (N, C_in, H_pad, W_pad) -> (N, C_in * K_h * K_w, L),
        # where L is the number of patches (H_out * W_out).
        patches = self.unfold(padded_x)
        num_patches = patches.shape[-1]  # L

        # 3. Move the patch dimension forward so the linear layer applies
        # patch-wise: (N, L, C_in * K_h * K_w).
        patches_transposed = patches.transpose(1, 2)

        # 4. Apply the linear layer (kernel) to each patch:
        # (N, L, linear_in_features) -> (N, L, out_channels).
        linear_output = self.kernel(patches_transposed)

        # 5. Back to image format: (N, out_channels, L) first.
        output_transposed = linear_output.transpose(1, 2)

        # Output spatial dimensions from the original (unpadded) H, W.
        out_h, out_w = self._calculate_output_shape(H, W)

        if num_patches != out_h * out_w:
            # Should not happen: nn.Unfold and _calculate_output_shape use
            # the same arithmetic. Warn and let the reshape below raise if
            # the shapes genuinely disagree.
            print(
                f"Warning: Mismatch in calculated patches. "
                f"Expected L={out_h * out_w}, got {num_patches}. "
                f"Using unfolded L={num_patches} to determine output shape."
            )

        # (N, C_out, L) -> (N, C_out, H_out, W_out). reshape (rather than
        # view) because the transpose above makes the tensor non-contiguous.
        output = output_transposed.reshape(N, self.out_channels, out_h, out_w)

        return output

    def extra_repr(self) -> str:
        s = (
            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
            ", stride={stride}"
        )
        if self.padding_mode != 0 and self.padding_mode != "valid":
            s += ", padding={padding_mode}"
        if self.dilation != (1,) * len(self.dilation):
            s += ", dilation={dilation}"
        # if self.groups != 1: # Not implemented
        #     s += ', groups={groups}'
        if self.bias is False:
            s += ", bias=False"
        if self.linear_module != nn.Linear:
            # Bug fix: this read `self.linear.__name__`, but the attribute
            # is `linear_module` — it raised AttributeError for any custom
            # linear module.
            s += f", linear={self.linear_module.__name__}"
        return s.format(**self.__dict__)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class WhiteningConv(ConvLayer):
    def __init__(
        self,
        in_channels: int,
        kernel_size: int,
        eigenvectors: torch.Tensor,
        bias: bool = True,
        linear_module: Type[nn.Module] = nn.Linear,
    ):
        """
        A 'same'-padded convolution whose kernel is initialised from patch
        eigenvectors (patch whitening).

        The kernel is seeded with the eigenvectors stacked with their
        negations: activations such as ReLU clip negative responses, so a
        strong negative match against one eigenvector would otherwise be
        lost, while the negated copy turns it into a strong positive one
        that survives. Assuming a square kernel, out channels is thus

            (kernel_size ** 2) * in_channels * 2

        where the trailing "* 2" accounts for the negated copies doubling
        the stacked eigenvector tensor.
        """
        out_channels = 2 * in_channels * kernel_size**2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            padding="same",
            bias=bias,
            linear_module=linear_module,
        )
        self.eigenvectors = torch.cat([eigenvectors, -eigenvectors], dim=0)
        # The weight is only *initialised* from the eigenvectors and stays
        # trainable. Per Jordan (2024) https://arxiv.org/abs/2404.00498 one
        # could instead freeze it (requires_grad = False) and let only the
        # bias update — that line is deliberately left disabled:
        # self.kernel.weight.requires_grad = False
        with torch.no_grad():
            self.kernel.weight.copy_(self.eigenvectors)
        assert self.kernel.weight.requires_grad
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class ConcatPool(nn.Module):
    """
    A layer that concatenates nearby squares of an image: every
    pooling_kernel_size x pooling_kernel_size neighbourhood is unfolded into
    one long feature vector, projected to d_model, and passed through the
    activation. Output layout is (N, L, d_model), L = number of patches.

    Args:
        in_channels: channels of the incoming feature map.
        pooling_kernel_size / pooling_kernel_stride / pooling_kernel_padding:
            parameters forwarded to nn.Unfold.
        d_model: per-patch output embedding size.
        activation: activation class; a class name ending in "GLU" is assumed
            to halve its input, so the projection width is doubled for it.
        activation_kwargs: optional kwargs for the activation constructor.
        linear_module: linear layer class used for the projection.
    """

    def __init__(
        self,
        in_channels: int,
        pooling_kernel_size: int,
        pooling_kernel_stride: int,
        pooling_kernel_padding: int,
        d_model: int,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        linear_module: Type[nn.Module] = nn.Linear,
    ):
        # Bug fix: `activation_kwargs` was declared *before* `activation`,
        # the opposite of every other constructor in this package. The
        # in-package callers pass these positionally as (activation,
        # activation_kwargs), so the activation class landed in the kwargs
        # slot and construction crashed. The order now matches the package
        # convention; keyword callers are unaffected.
        super().__init__()
        self.in_channels = in_channels
        self.pooling_kernel_size = pooling_kernel_size
        self.pooling_kernel_stride = pooling_kernel_stride
        self.pooling_kernel_padding = pooling_kernel_padding
        self.d_model = d_model
        # Each unfolded patch carries all channels of every pixel it covers.
        self.pooling_output_size = (pooling_kernel_size**2) * self.in_channels
        if activation_kwargs is not None:
            self.activation = activation(**activation_kwargs)
        else:
            self.activation = activation()
        self.process = nn.Sequential(
            nn.Unfold(
                pooling_kernel_size,
                stride=pooling_kernel_stride,
                padding=pooling_kernel_padding,
            ),
            Rearrange("N Block L -> N L Block"),
            # GLU-style activations halve their input, so give them 2x width.
            linear_module(
                self.pooling_output_size,
                2 * d_model if activation.__name__.endswith("GLU") else d_model,
            ),
            self.activation,
        )

    def forward(self, x):
        return self.process(x)
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class WhitenConcatProject(nn.Module):
    """
    A layer that whitens patches of the input image, then concatenates nearby
    patch embeddings before rearranging the outputs into the format expected
    by a Transformer block.

    Args:
        in_channels: channels of the input image.
        convolution_kernel_size: patch size used by the whitening convolution.
        pooling_kernel_size / pooling_kernel_stride / pooling_kernel_padding:
            parameters for the concatenating pool.
        d_model: per-patch output embedding size.
        dropout: dropout applied to the final embeddings.
        activation / activation_kwargs: activation class and its kwargs.
        eigenvectors: patch-covariance eigenvectors seeding the whitening
            convolution; required.
        bias: whether the whitening convolution carries a bias.
        linear_module: linear layer class used throughout.

    Raises:
        ValueError: if `eigenvectors` is not provided.
    """

    def __init__(
        self,
        in_channels: int,
        convolution_kernel_size: int,
        pooling_kernel_size: int,
        pooling_kernel_stride: int,
        pooling_kernel_padding: int,
        d_model: int,
        dropout=0.0,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        eigenvectors: torch.Tensor = None,
        bias: bool = True,
        linear_module: Type[nn.Module] = nn.Linear,
    ):
        super().__init__()
        # Raise a real exception instead of a bare assert (asserts are
        # stripped under `python -O`).
        if eigenvectors is None:
            raise ValueError("eigenvectors must be provided")
        self.in_channels = in_channels
        self.convolution_kernel_size = convolution_kernel_size
        self.pooling_kernel_size = pooling_kernel_size
        # WhiteningConv emits every patch eigenvector plus its negation.
        self.whitening_out_channels = convolution_kernel_size**2 * in_channels * 2
        self.pooling_output_size = (
            pooling_kernel_size**2
        ) * self.whitening_out_channels
        self.process = nn.Sequential(
            WhiteningConv(
                in_channels,
                convolution_kernel_size,
                eigenvectors,
                bias=bias,
                linear_module=linear_module,
            ),
            # Bug fix: `activation` and `activation_kwargs` were passed
            # positionally into ConcatPool's (activation_kwargs, activation)
            # slots; binding by keyword is order-proof.
            ConcatPool(
                self.whitening_out_channels,
                pooling_kernel_size,
                pooling_kernel_stride,
                pooling_kernel_padding,
                d_model,
                activation=activation,
                activation_kwargs=activation_kwargs,
                linear_module=linear_module,
            ),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.process(x)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from torch.utils.data import DataLoader
|
|
4
|
+
from einops import rearrange
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def eigenvectors(images: torch.Tensor, patch_size: int = 2, eps=5e-4) -> torch.Tensor:
    """
    Eigenvectors of the covariance matrix of patch_size x patch_size image
    patches, for seeding patch-whitening convolutions.

    Adapted from
    https://github.com/KellerJordan/cifar10-airbench/blob/master/airbench96_faster.py
    using https://datascienceplus.com/understanding-the-covariance-matrix/

    Args:
        images: (N, C, H, W) batch of images.
        patch_size: side length of the square patches.
        eps: Tikhonov regularisation added to the covariance diagonal so the
            eigendecomposition is well-conditioned even for rank-deficient
            patch data. (Previously accepted but unused — now applied. A
            uniform diagonal shift changes eigenvalues, not eigenvectors.)

    Returns:
        (D, D) matrix whose columns are eigenvectors in ascending-eigenvalue
        order (torch.linalg.eigh convention), with D = C * patch_size**2.
    """
    with torch.no_grad():
        unfolder = nn.Unfold(kernel_size=patch_size, stride=1)
        patches = unfolder(images)  # (N, patch_elements, patches_per_image)
        # Flatten to one patch per row: ((N * patches), elements).
        patches = patches.transpose(1, 2).reshape(-1, patches.size(1))
        n = patches.size(0)
        # NOTE(review): this centres each *patch* (row) on its own mean,
        # removing the per-patch DC component, rather than subtracting the
        # per-element mean over the dataset (dim=0) that a textbook
        # covariance would use — confirm this is intended.
        centred = patches - patches.mean(dim=1, keepdim=True)
        covariance_matrix = (
            centred.T @ centred
        ) / n  # https://datascienceplus.com/understanding-the-covariance-matrix/
        # Regularise the diagonal before the eigendecomposition.
        covariance_matrix = covariance_matrix + eps * torch.eye(
            covariance_matrix.size(0),
            dtype=covariance_matrix.dtype,
            device=covariance_matrix.device,
        )
        _, eigenvectors = torch.linalg.eigh(covariance_matrix)
        return eigenvectors
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_eigenpatches(data: DataLoader):
    """
    Running estimate of the patch eigenvectors over a whole DataLoader.

    The first batch initialises the estimate; each subsequent batch is folded
    in with an exponential moving average (decay 0.99).
    """
    running = None
    for images, _ in data:
        batch_eigenpatches = eigenvectors(images)
        if running is None:
            running = batch_eigenpatches
        else:
            running = 0.99 * running + 0.01 * batch_eigenpatches
    return running
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def plottable_bandpass_filters(eigenpatches: torch.Tensor, h, w, c):
    """
    Reshape flat eigenpatches of shape (N, c * h * w) into channels-last
    arrays of shape (N, h, w, c) suitable for plotting (e.g. with matplotlib).

    BUGFIX: the original ignored its `h`, `w` and `c` parameters and
    hard-coded C=3, H=2, W=2; the parameters are now honoured (the default
    call sites using 2x2 RGB patches behave identically).

    Args:
        eigenpatches: flat filters, one per row, laid out channel-major
            (channel, then height, then width), as produced by `eigenvectors`.
        h: patch height.
        w: patch width.
        c: number of channels.

    Returns:
        A numpy array of shape (N, h, w, c).
    """
    n = eigenpatches.size(0)
    filters = eigenpatches.reshape(n, c, h, w).permute(0, 2, 3, 1)
    # .cpu() so this also works for tensors living on an accelerator
    return filters.contiguous().detach().cpu().numpy()
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# UNDER CONSTRUCTION
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
from torch.nn import functional as F
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RandomLinear(nn.Linear):
    """
    A bias-free linear layer that perturbs its weights with fresh
    truncated-normal noise (scaled by `beta`) on every training-mode forward
    pass. In eval mode it behaves as a plain bias-free linear layer.

    If `forward_looks_random` is True, the deterministic part of the output
    is computed on detached inputs (a reparameterisation-style trick), so
    gradients w.r.t. the input flow only through the random term.

    NOTE(review): the `bias` argument is accepted but ignored — the layer is
    always constructed without a bias (so `self.bias` is None). TODO: explain.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,  # <---- TODO: explain this
        beta=0.1,
        forward_looks_random=True,
    ):
        super().__init__(in_features, out_features, bias=False)
        self.beta = beta
        self.forward_looks_random = forward_looks_random

    def forward(self, inputs: torch.Tensor):
        if self.training:
            # Draw fresh noise weights each call and scale them down by beta.
            noise = torch.empty_like(self.weight)
            nn.init.trunc_normal_(noise)
            noise *= self.beta

            if self.forward_looks_random:
                # Reparameterisation trick: detach inputs for the
                # deterministic term so only the random term carries
                # input gradients.
                deterministic = F.linear(inputs.detach(), self.weight, self.bias)
            else:
                # Plain sum: (W_actual @ input + W_random @ input) + bias
                deterministic = F.linear(inputs, self.weight, self.bias)
            stochastic = F.linear(inputs, noise, bias=None)
            return deterministic + stochastic

        return F.linear(inputs, self.weight)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import nn
|
|
3
|
+
from torch.nn import functional as F
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class SigmaReparamTensor(nn.Module):
    """
    Spectrally-normalised tensor reparameterisation.

    Inspired by Apple's Spectral Normed Linear Layers
    (https://github.com/apple/ml-sigma-reparam)

    Holds a trainable 2-D tensor and returns, on forward,
    ``scale * tensor / approx_spectral_norm``, where the spectral norm is
    tracked by one power-iteration step per training-mode forward pass.
    """

    def __init__(self, init_tensor: torch.Tensor):
        """
        Args:
            init_tensor: the 2-D tensor to reparameterise (becomes a Parameter).
        """
        assert init_tensor.ndim == 2

        super().__init__()

        self.tensor = nn.Parameter(init_tensor, requires_grad=True)

        # Initialise the power-iteration state from an exact SVD so the
        # first spectral-norm estimate is correct.
        with torch.no_grad():
            _, sigma, v_transpose = torch.linalg.svd(self.tensor, full_matrices=False)

        self.register_buffer("approx_spectral_norm", sigma[:1])
        self.register_buffer("right_singular", v_transpose[0])
        self.scale = nn.Parameter(
            self.approx_spectral_norm.clone().detach(), requires_grad=True
        )

    def power_iteration(self):
        """One power-iteration step: refresh `right_singular` and `approx_spectral_norm`."""
        with torch.no_grad():
            # u ≈ top left singular vector: normalise(W @ v)
            left_singular = F.normalize(self.tensor.mv(self.right_singular), dim=0)
            # v ≈ top right singular vector: normalise(W.T @ u).
            # BUGFIX: normalise the freshly computed vector. The original
            # normalised `self.right_singular` instead, discarding the
            # update, so the iteration state never advanced past its SVD
            # initialisation.
            updated_right_singular = F.normalize(
                self.tensor.T.mv(left_singular), dim=0
            )
            self.right_singular.data.copy_(updated_right_singular)
            # Rayleigh quotient u.T @ W @ v estimates the top singular value.
            rayleigh_quotient = torch.einsum(
                "m,mn,n->",
                left_singular,
                self.tensor,
                updated_right_singular,
            )
            self.approx_spectral_norm.data.copy_(rayleigh_quotient)

    def forward(self):
        # Only advance the power iteration while training; eval reuses the
        # last estimate.
        if self.training:
            self.power_iteration()
        return self.scale * (self.tensor / self.approx_spectral_norm)
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from collections import OrderedDict
|
|
3
|
+
from typing import Optional
|
|
4
|
+
from numpy import random
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
from einops import rearrange, einsum, reduce, repeat
|
|
11
|
+
from einops.layers.torch import Rearrange
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MHAttention(nn.Module):
    """
    Multi-head self-attention using a pluggable linear layer.

    The forward method assumes q, k and v share an embedding size and that
    k and v have the same shape. With `share_kv=True` the key projection is
    reused for the values (the `v` argument's projection is k's output).

    Assumes bias=False and batch_first=True, as God intended.
    """

    def __init__(
        self,
        embed_dim,
        n_heads,
        dropout=0.0,
        causal=False,
        sequence_length=None,
        share_kv=True,
        linear_module: nn.Module = nn.Linear,
    ):
        super().__init__()
        if causal:
            # A fixed sequence length is needed to pre-build the causal mask.
            assert sequence_length is not None
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        assert embed_dim % n_heads == 0
        self.head_dim = self.embed_dim // self.n_heads
        self.share_kv = share_kv
        self.q_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
        self.k_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
        # Share the key projection for values when requested.
        self.v_proj = (
            self.k_proj
            if self.share_kv
            else linear_module(self.embed_dim, self.embed_dim, bias=False)
        )
        self.out_proj = linear_module(self.embed_dim, self.embed_dim, bias=False)
        self.causal = causal
        self.sequence_length = sequence_length
        self.dropout = nn.Dropout(dropout)
        if self.causal:
            # Boolean upper-triangular mask, shaped (1, 1, S, S) to broadcast
            # over batch and heads.
            upper = torch.triu(
                torch.ones(sequence_length, sequence_length), diagonal=1
            )
            self.register_buffer("mask", (upper == 1)[None, None, :, :])

    def forward(self, q, k, v):
        q_batch, q_tokens, q_features = q.size()
        k_batch, k_tokens, k_features = k.size()

        assert k.size() == v.size()
        assert q_features == k_features
        # Batch sizes must match, or the query must be broadcastable.
        assert (q_batch == k_batch) or q_batch == 1

        if self.causal:
            assert q_tokens == k_tokens
            assert q_tokens == self.sequence_length

        def split_heads(t):
            # (b, t, h*d) -> (b, h, t, d)
            b, n, _ = t.size()
            return t.view(b, n, self.n_heads, self.head_dim).transpose(1, 2)

        # Project q, k, v and divide into heads.
        q = split_heads(self.q_proj(q))
        k = split_heads(self.k_proj(k))
        v = k if self.share_kv else split_heads(self.v_proj(v))

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Causal masking must precede the softmax.
        if self.causal:
            scores = scores.masked_fill(self.mask, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout(weights)  # dropout must come after softmax!

        # (b, h, t, d) -> (b, t, h*d)
        attended = torch.matmul(weights, v)
        merged = attended.transpose(1, 2).flatten(-2)

        return self.out_proj(merged)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class TransformerBlock(nn.Module):
    """
    Pre-norm Transformer block: LayerNorms come first (as in PyTorch
    Transformers with norm_first=True), matching e.g.
    https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
    and the recommendation of https://arxiv.org/abs/2002.04745
    """

    def __init__(
        self,
        d_model,
        n_heads,
        mlp_ratio=4,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        mlp_dropout=0.0,
        msa_dropout=0.0,
        causal=False,
        linear_module=nn.Linear,
    ):
        super().__init__()

        kwargs = activation_kwargs if activation_kwargs is not None else {}
        self.activation = activation(**kwargs)

        # Attention sub-layer (MHAttention handles the QKV projections).
        self.layer_norm = nn.LayerNorm(d_model)
        self.attn = MHAttention(
            d_model,
            n_heads,
            dropout=msa_dropout,
            causal=causal,
            linear_module=linear_module,
        )

        hidden = mlp_ratio * d_model
        # GLU-family activations halve their input width, so the
        # up-projection must be twice as wide for them.
        up_width = 2 * hidden if activation.__name__.endswith("GLU") else hidden

        # Feedforward sub-layer.
        self.ff_process = nn.Sequential(
            OrderedDict(
                [
                    ("layer_norm", nn.LayerNorm(d_model)),
                    ("up_projection", linear_module(d_model, up_width)),
                    ("activation", self.activation),
                    ("down_projection", linear_module(hidden, d_model)),
                    ("dropout", nn.Dropout(mlp_dropout)),
                ]
            )
        )

    def forward(self, x):
        pre_norm = self.layer_norm(x)
        x = x + self.attn(pre_norm, pre_norm, pre_norm)
        return x + self.ff_process(x)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class TransformerEncoder(nn.Module):
    """
    This assumes we already get a sequence of embeddings (e.g. word or image
    patch embeddings). It uses learned positional embeddings.

    Optionally prepends `bos_tokens` learned beginning-of-sequence tokens
    (stripped from the output again), and supports batch-wise stochastic
    depth during training.
    """

    def __init__(
        self,
        seq_len,
        d_model,
        n_layers,
        n_heads,
        mlp_ratio=4,
        activation: nn.Module = nn.ReLU,
        activation_kwargs: Optional[dict] = None,
        mlp_dropout=0.0,
        msa_dropout=0.0,
        stochastic_depth=0.0,
        causal=False,
        linear_module=nn.Linear,
        bos_tokens=True,
    ):
        """
        Args:
            seq_len: length of the input sequence (excluding BOS tokens).
            d_model: embedding dimension.
            n_layers: number of TransformerBlocks.
            n_heads: attention heads per block.
            mlp_ratio: feedforward width multiplier.
            activation: activation class for the feedforward sub-layers.
            activation_kwargs: optional kwargs for the activation class.
            mlp_dropout: dropout in the feedforward sub-layers.
            msa_dropout: dropout on the attention weights.
            stochastic_depth: probability of a sample skipping each block
                during training (0.0 disables stochastic depth).
            causal: whether blocks use causal attention.
            linear_module: linear layer class used throughout.
            bos_tokens: number of learned BOS tokens to prepend; note that
                the default True is coerced to 1 by torch.empty, and 0/False
                disables BOS tokens entirely.
        """
        super().__init__()
        self.seq_len = seq_len
        self.n_heads = n_heads

        # Initialise BOS tokens with Xavier uniform init per
        # https://docs.pytorch.org/docs/stable/nn.init.html#torch.nn.init.xavier_uniform_
        # NOTE(review): despite the comment above, the code actually uses a
        # standard normal init (mean 0, std 1) — confirm which is intended.
        if bos_tokens:
            self.bos_tokens = nn.Parameter(torch.empty(bos_tokens, d_model))
            nn.init.normal_(self.bos_tokens, mean=0.0, std=1.0)
            self.full_sequence_length = self.seq_len + bos_tokens
        else:
            self.bos_tokens = None
            self.full_sequence_length = self.seq_len

        self.d_model = d_model
        # Learned positional embeddings over the full (BOS-extended) length.
        self.positional_embedding = nn.Embedding(self.full_sequence_length, d_model)
        self.mlp_dropout = mlp_dropout
        self.msa_dropout = msa_dropout
        self.stochastic_depth = stochastic_depth
        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    d_model,
                    n_heads,
                    mlp_ratio=mlp_ratio,
                    activation=activation,
                    activation_kwargs=activation_kwargs,
                    mlp_dropout=mlp_dropout,
                    msa_dropout=msa_dropout,
                    causal=causal,
                    linear_module=linear_module,
                )
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        """
        Args:
            x: assumed shape (batch, seq_len, d_model) — consistent with the
               (batch, tokens, features) handling in MHAttention.

        Returns:
            Tensor of the same shape as `x` (BOS positions are stripped).
        """
        # Prepend the learned BOS tokens, broadcast over the batch.
        if self.bos_tokens is not None:
            x = torch.cat([self.bos_tokens.expand(x.size(0), -1, -1), x], dim=1)
        else:
            x = x
        x = x + self.positional_embedding(
            torch.arange(
                0, self.full_sequence_length, dtype=torch.long, device=x.device
            ).unsqueeze(
                0
            )  # to shape (1, seq_len) to broadcast over batch
        )
        for block in self.blocks:
            if (not self.training) or self.stochastic_depth == 0.0:
                x = block(x)
            else:  # drop out some rows from the next Transformer block operation
                # Sample how many batch rows pass through this block
                # (numpy RNG, not torch — so not covered by torch.manual_seed).
                binomial = random.binomial(n=x.size(0), p=1 - self.stochastic_depth)
                # Shuffle rows, run the first `binomial` of them through the
                # block, pass the rest through unchanged, then unshuffle so
                # row order is restored.
                shuffle_indices = torch.randperm(x.size(0), device=x.device)
                unshuffle_indices = torch.argsort(shuffle_indices)  # , device=x.device)
                shuffled = x[shuffle_indices, :, :]
                include = shuffled[:binomial, :, :]
                exclude = shuffled[binomial:, :, :]
                x = torch.cat([block(include), exclude])[
                    unshuffle_indices, :, :
                ].contiguous()

        # Strip the BOS positions so the output length matches the input.
        if self.bos_tokens is not None:
            return x[:, (self.full_sequence_length - self.seq_len) :, :]
        else:
            return x
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "broccoli-ml"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Some useful Pytorch models, circa 2025"
|
|
5
|
+
authors = [
|
|
6
|
+
{name = "Nicholas Bailey"}
|
|
7
|
+
]
|
|
8
|
+
license = {text = "MIT"}
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
dependencies = [
|
|
12
|
+
"torch (>=2.7.1,<3.0.0)",
|
|
13
|
+
"numpy (>=2.3.1,<3.0.0)",
|
|
14
|
+
"einops (>=0.8.1,<0.9.0)"
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[tool.poetry]
|
|
18
|
+
packages = [
|
|
19
|
+
{ include = "broccoli" } # This tells Poetry that the importable package is named 'broccoli'
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[tool.poetry.group.dev.dependencies]
|
|
23
|
+
black = "^25.1.0"
|
|
24
|
+
flake8 = "^7.3.0"
|
|
25
|
+
pytest = "^8.4.1"
|
|
26
|
+
pytest-cov = "^6.2.1"
|
|
27
|
+
|
|
28
|
+
[build-system]
|
|
29
|
+
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
|
30
|
+
build-backend = "poetry.core.masonry.api"
|
|
31
|
+
|
|
32
|
+
[tool.black]
|
|
33
|
+
line-length = 88
|
|
34
|
+
target-version = ['py312']
|
|
35
|
+
include = '\.pyi?$'
|
|
36
|
+
extend-exclude = '''
|
|
37
|
+
# A regex preceded with ^/ will apply only to files and directories
|
|
38
|
+
# in the root of the project.
|
|
39
|
+
(
|
|
40
|
+
^/foo.py # exclude a file named foo.py in the root of the project
|
|
41
|
+
| .*_pb2.py # exclude autogenerated Protocol Buffer files anywhere in the project
|
|
42
|
+
)
|
|
43
|
+
'''
|