broccoli-ml 9.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/__init__.py +6 -0
- broccoli/activation.py +118 -0
- broccoli/cnn.py +157 -0
- broccoli/linear.py +352 -0
- broccoli/rope.py +407 -0
- broccoli/tensor.py +128 -0
- broccoli/transformer.py +779 -0
- broccoli/utils.py +15 -0
- broccoli/vit.py +600 -0
- broccoli_ml-9.5.1.dist-info/LICENSE +21 -0
- broccoli_ml-9.5.1.dist-info/METADATA +43 -0
- broccoli_ml-9.5.1.dist-info/RECORD +13 -0
- broccoli_ml-9.5.1.dist-info/WHEEL +4 -0
broccoli/utils.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class PadTensor(nn.Module):
    """
    Module wrapper around ``torch.nn.functional.pad``.

    Every positional and keyword argument given at construction time is
    forwarded verbatim to ``F.pad`` (after the input tensor) on each call.
    The padding spec may be supplied either as the first positional argument
    or as the ``pad`` keyword. When every entry of the padding spec is zero the
    input tensor itself is returned and ``F.pad`` is not called.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        # The padding spec is F.pad's ``pad`` parameter: either the first
        # stored positional argument or the ``pad`` keyword.
        if self.args:
            pad = self.args[0]
        else:
            pad = self.kwargs.get("pad")

        # Only skip F.pad when *every* entry is zero. Summing is not enough:
        # negative entries crop, so e.g. (1, -1) sums to zero but still
        # shifts the tensor. A missing spec falls through so F.pad reports it.
        if pad is not None and all(p == 0 for p in pad):
            return x
        return F.pad(x, *self.args, **self.kwargs)
|
broccoli/vit.py
ADDED
|
@@ -0,0 +1,600 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from .transformer import TransformerEncoder, FeedforwardBlock
|
|
5
|
+
from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
|
|
6
|
+
from .activation import ReLU, SquaredReLU, GELU, SwiGLU
|
|
7
|
+
from .utils import PadTensor
|
|
8
|
+
|
|
9
|
+
from einops import einsum
|
|
10
|
+
from einops.layers.torch import Rearrange
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
import torch.nn as nn
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GetCLSToken(nn.Module):
    """
    Select the token at sequence position 0 of every example.

    Indexes the input as ``x[:, 0, :]``, so a ``(batch, seq, features)``
    tensor becomes ``(batch, features)``. As the class name suggests, this is
    intended to pick out a [CLS]-style summary token placed first in the
    sequence.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        cls_position = 0
        first_tokens = x[:, cls_position, :]
        return first_tokens
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SequencePool(nn.Module):
    """
    Learned attention pooling over the sequence axis (the SeqPool layer of
    Hasani et al., 2021).

    Each token is scored by a single-output linear layer. The scores are
    softmax-normalised over the sequence, and the tokens are summed with those
    weights, mapping ``(batch, seq, d_model)`` to ``(batch, d_model)``.

    Args:
        d_model: size of each token's feature vector.
        linear_module: factory for the scoring layer, called as
            ``linear_module(d_model, 1)``.
    """

    def __init__(self, d_model, linear_module=nn.Linear):
        super().__init__()
        score_layer = linear_module(d_model, 1)
        # (batch, seq, 1) -> (batch, seq) so the softmax runs over seq.
        drop_unit_axis = Rearrange("batch seq 1 -> batch seq")
        normalise_over_seq = nn.Softmax(dim=-1)
        self.attention = nn.Sequential(
            score_layer, drop_unit_axis, normalise_over_seq
        )

        self.reset_parameters()

    def forward(self, x):
        token_weights = self.attention(x)  # (batch, seq), sums to 1 over seq
        pooled = einsum(
            token_weights, x, "batch seq, batch seq d_model -> batch d_model"
        )
        return pooled

    def attention_scores(self, x):
        """Return the softmax pooling weights for ``x``, shape ``(batch, seq)``."""
        return self.attention(x)

    def reset_parameters(self):
        """Re-initialise every layer of ``self.attention`` that supports it."""
        resettable_layers = [
            layer for layer in self.attention if hasattr(layer, "reset_parameters")
        ]
        for layer in resettable_layers:
            layer.reset_parameters()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class ClassificationHead(nn.Module):
    """
    A general classification head for a ViT.

    The head applies three stages in order:

    1. ``summarize``: reduce the token sequence to one vector per example by
       taking the first token (``GetCLSToken``);
    2. ``projection``: map ``d_model`` features to ``n_classes`` logits. This
       is ``nn.Identity`` when ``d_model == n_classes``;
    3. ``batch_norm``: an optional non-affine ``BatchNorm1d`` over the logits.

    Args:
        d_model: feature size of the incoming tokens.
        n_classes: number of output logits.
        logit_projection_layer: factory called as
            ``logit_projection_layer(d_model, n_classes)`` when a projection is
            needed.
        batch_norm_logits: if True, apply ``BatchNorm1d(n_classes, affine=False)``
            to the logits; otherwise leave them untouched.
    """

    def __init__(
        self,
        d_model,
        n_classes,
        logit_projection_layer=nn.Linear,
        batch_norm_logits=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.summarize = GetCLSToken()

        needs_projection = d_model != n_classes
        self.projection = (
            logit_projection_layer(d_model, n_classes)
            if needs_projection
            else nn.Identity()
        )

        self.batch_norm = (
            nn.BatchNorm1d(n_classes, affine=False)
            if batch_norm_logits
            else nn.Identity()
        )

        self.classification_process = nn.Sequential(
            self.summarize, self.projection, self.batch_norm
        )

        self.reset_parameters()

    def forward(self, x):
        return self.classification_process(x)

    def reset_parameters(self):
        """Re-initialise every stage of the head that supports it."""
        stages = [
            stage
            for stage in self.classification_process
            if hasattr(stage, "reset_parameters")
        ]
        for stage in stages:
            stage.reset_parameters()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class SequencePoolClassificationHead(ClassificationHead):
    """
    As described in [Hasani et al. (2021) *''Escaping the Big Data Paradigm with
    Compact Transformers''*](https://arxiv.org/abs/2104.05704). It can be viewed
    as a generalisation of average pooling.

    Identical to ``ClassificationHead`` except that the first-token summary is
    replaced by a learned ``SequencePool``, whose scoring layer is built with
    ``logit_projection_layer``.
    """

    def __init__(
        self,
        d_model,
        n_classes,
        logit_projection_layer=nn.Linear,
        batch_norm_logits=True,
    ):
        super().__init__(
            d_model,
            n_classes,
            logit_projection_layer=logit_projection_layer,
            batch_norm_logits=batch_norm_logits,
        )

        # Swap the parent's first-token summary for learned sequence pooling.
        self.summarize = SequencePool(d_model, logit_projection_layer)

        # The pipeline built by the parent still holds the old summary module,
        # so assemble it again around the new one.
        pipeline = [self.summarize, self.projection, self.batch_norm]
        self.classification_process = nn.Sequential(*pipeline)

        self.reset_parameters()
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class ViTEncoder(nn.Module):
    """
    Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
    *''Escaping the Big Data Paradigm with Compact Transformers''*](
    https://arxiv.org/abs/2104.05704). It's basically a convolutional neural
    network leading into a transformer encoder. To make it like the full CCT
    we would finish it off with a sequence pooling layer but we won't always
    want to do that.

    Inputs are laid out as ``(N, C, *input_size)``, where ``input_size`` has
    length 1, 2 or 3. The length selects 1D, 2D or 3D conv / pool / batch-norm
    layers. The forward pass applies, in order:

    1. a batch norm over the raw input channels (if ``initial_batch_norm``);
    2. a convolution, activation, dropout and batch norm (if ``cnn``);
    3. spatial pooling: ``"max"``, ``"average"``, ``"concat"``
       (space-to-depth) or ``None`` for no pooling;
    4. flattening of the spatial dimensions into a token sequence,
       ``(N, tokens, channels)``, with one token per remaining voxel;
    5. zero padding of the channel dimension up to
       ``transformer_embedding_size`` when there are fewer channels;
    6. an initial feedforward block (if ``transformer_feedforward_first``);
    7. a ``TransformerEncoder``, or ``nn.Identity`` when
       ``transformer_layers == 0``.

    ``self.sequence_length`` is the number of tokens produced by step 4.
    """

    def __init__(
        self,
        input_size=(32, 32),
        in_channels=3,
        initial_batch_norm=True,
        cnn=True,
        cnn_out_channels=16,
        cnn_kernel_size=3,
        cnn_kernel_stride=1,
        cnn_padding="same",
        cnn_kernel_dilation=1,
        cnn_kernel_groups=1,
        cnn_activation: nn.Module = ReLU,
        cnn_activation_kwargs: Optional[dict] = None,
        cnn_dropout=0.0,
        pooling_type="concat",  # max, average, concat or None
        pooling_kernel_size=3,
        pooling_kernel_stride=2,
        pooling_padding=1,
        transformer_feedforward_first=True,
        transformer_initial_ff_residual_path=True,
        transformer_initial_ff_linear_module_up=None,
        transformer_initial_ff_linear_module_down=None,
        transformer_initial_ff_dropout=None,
        transformer_initial_ff_inner_dropout=None,
        transformer_initial_ff_outer_dropout=None,
        transformer_pre_norm=True,
        transformer_normformer=False,
        transformer_post_norm=False,
        transformer_absolute_position_embedding=False,
        transformer_relative_position_embedding=True,
        transformer_embedding_size=256,
        transformer_layers=7,
        transformer_heads=4,
        transformer_mlp_ratio=2,
        transformer_utility_tokens=0,
        transformer_return_utility_tokens=False,
        transformer_activation: nn.Module = SquaredReLU,
        transformer_activation_kwargs: Optional[dict] = None,
        transformer_ff_linear_module_up=None,
        transformer_ff_linear_module_down=None,
        transformer_msa_scaling="d",
        transformer_ff_dropout=0.0,
        transformer_ff_inner_dropout=0.0,
        transformer_ff_outer_dropout=0.0,
        transformer_msa_dropout=0.1,
        transformer_stochastic_depth=0.1,
        transformer_checkpoint_ff=True,
        linear_module=nn.Linear,
    ):
        super().__init__()

        # Instantiate the activation modules; the kwargs dicts are optional.
        if cnn_activation_kwargs is not None:
            self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
        else:
            self.cnn_activation = cnn_activation()

        if transformer_activation_kwargs is not None:
            self.transformer_activation = transformer_activation(
                **transformer_activation_kwargs
            )
        else:
            self.transformer_activation = transformer_activation()

        self.input_size = input_size
        self.spatial_dimensions = len(self.input_size)

        # Choose 1D / 2D / 3D layer types to match the spatial rank.
        if self.spatial_dimensions == 1:
            maxpoolxd = nn.MaxPool1d
            avgpoolxd = nn.AvgPool1d
            convxd = nn.Conv1d
            batchnormxd = nn.BatchNorm1d
            spatial_dim_names = "D1"
        elif self.spatial_dimensions == 2:
            maxpoolxd = nn.MaxPool2d
            avgpoolxd = nn.AvgPool2d
            convxd = nn.Conv2d
            batchnormxd = nn.BatchNorm2d
            spatial_dim_names = "D1 D2"
        elif self.spatial_dimensions == 3:
            maxpoolxd = nn.MaxPool3d
            avgpoolxd = nn.AvgPool3d
            convxd = nn.Conv3d
            batchnormxd = nn.BatchNorm3d
            spatial_dim_names = "D1 D2 D3"
        else:
            raise NotImplementedError(
                "`input_size` must be a tuple of length 1, 2, or 3."
            )

        if cnn:
            # This block rhymes:
            if cnn_activation.__name__.endswith("GLU"):
                # Gated (*GLU) activations get a doubled conv width.
                cnn_out_channels *= 2
            cnn_output_size = calculate_output_spatial_size(
                input_size,
                kernel_size=cnn_kernel_size,
                stride=cnn_kernel_stride,
                padding=cnn_padding,
                dilation=cnn_kernel_dilation,
            )
            self.cnn = convxd(
                in_channels,
                cnn_out_channels,
                cnn_kernel_size,
                stride=cnn_kernel_stride,
                padding=cnn_padding,
                dilation=cnn_kernel_dilation,
                groups=cnn_kernel_groups,
                bias=True,
                padding_mode="zeros",
            )
            # NOTE(review): for *GLU activations the conv width was doubled
            # above, and the doubled width is still used here for the batch
            # norm and the pooling channel count. If the activation halves
            # its channel dimension (as gated units commonly do), this should
            # be ``cnn_out_channels // 2`` -- confirm against activation.py.
            cnn_activation_out_channels = cnn_out_channels
            self.activate_and_dropout = nn.Sequential(
                *[
                    Rearrange(  # rearrange in case we're using XGLU activation
                        f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
                    ),
                    self.cnn_activation,
                    Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
                    nn.Dropout(cnn_dropout),
                    batchnormxd(cnn_activation_out_channels),
                ]
            )
        else:
            self.cnn = nn.Identity()
            self.activate_and_dropout = nn.Identity()
            cnn_output_size = input_size
            cnn_out_channels = in_channels
            cnn_activation_out_channels = in_channels

        # Number of voxels in one pooling window (used by "concat" pooling,
        # which stacks a whole window into the channel dimension).
        pooling_kernel_voxels = math.prod(
            spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
        )

        pooling_output_size = (
            cnn_output_size
            if pooling_type is None
            else calculate_output_spatial_size(
                cnn_output_size,
                kernel_size=pooling_kernel_size,
                stride=pooling_kernel_stride,
                padding=pooling_padding,
                dilation=1,
            )
        )

        if pooling_type is None:
            pooling_out_channels = cnn_activation_out_channels
            self.pool = nn.Identity()

        elif pooling_type == "max":
            pooling_out_channels = cnn_activation_out_channels
            self.pool = maxpoolxd(
                pooling_kernel_size,
                stride=pooling_kernel_stride,
                padding=pooling_padding,
            )
        elif pooling_type == "average":
            pooling_out_channels = cnn_activation_out_channels
            self.pool = avgpoolxd(
                pooling_kernel_size,
                stride=pooling_kernel_stride,
                padding=pooling_padding,
            )
        elif pooling_type == "concat":
            pooling_out_channels = pooling_kernel_voxels * cnn_activation_out_channels
            self.pool = SpaceToDepth(
                pooling_kernel_size,
                stride=pooling_kernel_stride,
                padding=pooling_padding,
                spatial_dimensions=self.spatial_dimensions,
            )
        else:
            raise NotImplementedError(
                "Pooling type must be max, average, concat or None"
            )

        # Zero-pad the channel (last) dimension of the token sequence up to the
        # transformer width; padding by 0 makes this a no-op when there are
        # already at least that many channels.
        self.pooling_channels_padding = PadTensor(
            (0, max(0, transformer_embedding_size - pooling_out_channels))
        )

        self.sequence_length = math.prod(pooling_output_size)  # One token per voxel

        if transformer_layers > 0:
            self.transformer = TransformerEncoder(
                self.sequence_length,
                transformer_embedding_size,
                transformer_layers,
                transformer_heads,
                absolute_position_embedding=transformer_absolute_position_embedding,
                relative_position_embedding=transformer_relative_position_embedding,
                source_size=pooling_output_size,
                mlp_ratio=transformer_mlp_ratio,
                activation=transformer_activation,
                activation_kwargs=transformer_activation_kwargs,
                ff_linear_module_up=transformer_ff_linear_module_up,
                ff_linear_module_down=transformer_ff_linear_module_down,
                msa_scaling=transformer_msa_scaling,
                ff_dropout=transformer_ff_dropout,
                ff_inner_dropout=transformer_ff_inner_dropout,
                ff_outer_dropout=transformer_ff_outer_dropout,
                msa_dropout=transformer_msa_dropout,
                stochastic_depth=transformer_stochastic_depth,
                causal=False,
                linear_module=linear_module,
                utility_tokens=transformer_utility_tokens,
                return_utility_tokens=transformer_return_utility_tokens,
                pre_norm=transformer_pre_norm,
                normformer=transformer_normformer,
                post_norm=transformer_post_norm,
                checkpoint_ff=transformer_checkpoint_ff,
            )
        else:
            self.transformer = nn.Identity()

        if transformer_feedforward_first:
            self.initial_ff = FeedforwardBlock(
                # After channel padding the token width is the larger of the
                # two; this block maps it down to the transformer width.
                max(transformer_embedding_size, pooling_out_channels),
                transformer_mlp_ratio,
                transformer_embedding_size,
                activation=transformer_activation,
                activation_kwargs=transformer_activation_kwargs,
                dropout=(
                    # First value that is not None
                    transformer_initial_ff_dropout
                    if transformer_initial_ff_dropout is not None
                    else transformer_ff_dropout
                ),
                inner_dropout=(
                    # First value that is not None
                    transformer_initial_ff_inner_dropout
                    if transformer_initial_ff_inner_dropout is not None
                    else transformer_ff_inner_dropout
                ),
                outer_dropout=(
                    # First value that is not None
                    transformer_initial_ff_outer_dropout
                    if transformer_initial_ff_outer_dropout is not None
                    else transformer_ff_outer_dropout
                ),
                linear_module_up=(
                    # First truthy assigned value
                    transformer_initial_ff_linear_module_up
                    or transformer_ff_linear_module_up
                    or linear_module
                ),
                linear_module_down=(
                    # First truthy assigned value
                    transformer_initial_ff_linear_module_down
                    or transformer_ff_linear_module_down
                    or linear_module
                ),
                pre_norm=transformer_pre_norm,
                normformer=transformer_normformer,
                post_norm=transformer_post_norm,
                residual_path=transformer_initial_ff_residual_path,
                checkpoint=transformer_checkpoint_ff,
            )
        else:
            self.initial_ff = nn.Identity()

        self.encoder = nn.Sequential(
            *[
                batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                self.cnn,
                self.activate_and_dropout,
                self.pool,
                Rearrange(  # for transformer
                    f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                ),
                self.pooling_channels_padding,
                self.initial_ff,
                self.transformer,
            ]
        )

        self.reset_parameters()

    def forward(self, x):
        return self.encoder(x)

    def attention_logits(self, x):
        """
        Return the transformer's attention logits for input ``x``.

        Runs every encoder stage except the final transformer, then delegates
        to the transformer's ``attention_logits``.

        Raises:
            AttributeError: if the encoder has no transformer (it was built with
                ``transformer_layers=0``, so the final stage is ``nn.Identity``).
        """
        transformer = self.encoder[-1]
        if not hasattr(transformer, "attention_logits"):
            raise AttributeError(
                "`attention_logits` requires a transformer; this encoder was"
                " built with transformer_layers=0"
            )
        x = self.encoder[:-1](x)
        return transformer.attention_logits(x)

    def reset_parameters(self):
        """
        Re-initialise every encoder stage that supports ``reset_parameters``.

        Plain ``nn.Sequential`` stages (e.g. ``activate_and_dropout``) have no
        ``reset_parameters`` of their own, so their direct children are reset
        instead. Without this, the post-activation batch norm would never be
        reset.
        """
        for module in self.encoder:
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()
            elif isinstance(module, nn.Sequential):
                for child in module:
                    if hasattr(child, "reset_parameters"):
                        child.reset_parameters()
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
class ViT(nn.Module):
    """
    A complete ViT classifier: a ``ViTEncoder`` followed by a classification
    head.

    Every argument except the head-specific ones below is forwarded unchanged
    to ``ViTEncoder``. ``cnn_activation`` and ``transformer_activation`` may be
    given either as a class or by name: ``"ReLU"``, ``"SquaredReLU"``,
    ``"GELU"`` or ``"SwiGLU"``.

    Head-specific arguments:
        image_classes: number of output logits.
        head: head class, called as ``head(transformer_embedding_size,
            image_classes, logit_projection_layer=..., batch_norm_logits=...)``.
            Defaults to ``SequencePoolClassificationHead``.
        batch_norm_logits: passed through to ``head``.
        logit_projection_layer: passed through to ``head``.

    Raises:
        KeyError: if an activation is given as an unknown name.
    """

    # Activations that can be selected by name.
    _ACTIVATIONS_BY_NAME = {
        "ReLU": ReLU,
        "SquaredReLU": SquaredReLU,
        "GELU": GELU,
        "SwiGLU": SwiGLU,
    }

    def __init__(
        self,
        input_size=(32, 32),
        image_classes=100,
        in_channels=3,
        initial_batch_norm=True,
        cnn=True,
        cnn_out_channels=16,
        cnn_kernel_size=3,
        cnn_kernel_stride=1,
        cnn_padding="same",
        cnn_kernel_dilation=1,
        cnn_kernel_groups=1,
        cnn_activation: nn.Module = ReLU,
        cnn_activation_kwargs: Optional[dict] = None,
        cnn_dropout=0.0,
        pooling_type="concat",  # max, average, concat or None
        pooling_kernel_size=3,
        pooling_kernel_stride=2,
        pooling_padding=1,
        transformer_feedforward_first=True,
        transformer_initial_ff_residual_path=True,
        transformer_initial_ff_linear_module_up=None,
        transformer_initial_ff_linear_module_down=None,
        transformer_initial_ff_dropout=None,
        transformer_initial_ff_inner_dropout=None,
        transformer_initial_ff_outer_dropout=None,
        transformer_pre_norm=True,
        transformer_normformer=False,
        transformer_post_norm=False,
        transformer_absolute_position_embedding=False,
        transformer_relative_position_embedding=True,
        transformer_embedding_size=256,
        transformer_layers=7,
        transformer_heads=4,
        transformer_mlp_ratio=2,
        transformer_utility_tokens=0,
        transformer_return_utility_tokens=False,
        transformer_activation: nn.Module = SquaredReLU,
        transformer_activation_kwargs: Optional[dict] = None,
        transformer_ff_linear_module_up=None,
        transformer_ff_linear_module_down=None,
        transformer_msa_scaling="d",
        transformer_ff_dropout=0.0,
        transformer_ff_inner_dropout=0.0,
        transformer_ff_outer_dropout=0.0,
        transformer_msa_dropout=0.1,
        transformer_stochastic_depth=0.1,
        transformer_checkpoint_ff=True,
        head=SequencePoolClassificationHead,
        batch_norm_logits=True,
        logit_projection_layer=nn.Linear,
        linear_module=nn.Linear,
    ):

        super().__init__()

        cnn_activation = self._resolve_activation(cnn_activation)
        transformer_activation = self._resolve_activation(transformer_activation)

        self.encoder = ViTEncoder(
            input_size=input_size,
            initial_batch_norm=initial_batch_norm,
            in_channels=in_channels,
            cnn=cnn,
            cnn_out_channels=cnn_out_channels,
            cnn_kernel_size=cnn_kernel_size,
            cnn_kernel_stride=cnn_kernel_stride,
            cnn_padding=cnn_padding,
            cnn_kernel_dilation=cnn_kernel_dilation,
            cnn_kernel_groups=cnn_kernel_groups,
            cnn_activation=cnn_activation,
            cnn_activation_kwargs=cnn_activation_kwargs,
            cnn_dropout=cnn_dropout,
            pooling_type=pooling_type,
            pooling_kernel_size=pooling_kernel_size,
            pooling_kernel_stride=pooling_kernel_stride,
            pooling_padding=pooling_padding,
            transformer_feedforward_first=transformer_feedforward_first,
            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
            transformer_initial_ff_dropout=transformer_initial_ff_dropout,
            transformer_initial_ff_inner_dropout=transformer_initial_ff_inner_dropout,
            transformer_initial_ff_outer_dropout=transformer_initial_ff_outer_dropout,
            transformer_pre_norm=transformer_pre_norm,
            transformer_normformer=transformer_normformer,
            transformer_post_norm=transformer_post_norm,
            transformer_absolute_position_embedding=transformer_absolute_position_embedding,
            transformer_relative_position_embedding=transformer_relative_position_embedding,
            transformer_embedding_size=transformer_embedding_size,
            transformer_layers=transformer_layers,
            transformer_heads=transformer_heads,
            transformer_mlp_ratio=transformer_mlp_ratio,
            transformer_utility_tokens=transformer_utility_tokens,
            transformer_return_utility_tokens=transformer_return_utility_tokens,
            transformer_activation=transformer_activation,
            transformer_activation_kwargs=transformer_activation_kwargs,
            transformer_ff_linear_module_up=transformer_ff_linear_module_up,
            transformer_ff_linear_module_down=transformer_ff_linear_module_down,
            transformer_msa_scaling=transformer_msa_scaling,
            transformer_ff_dropout=transformer_ff_dropout,
            transformer_ff_inner_dropout=transformer_ff_inner_dropout,
            transformer_ff_outer_dropout=transformer_ff_outer_dropout,
            transformer_msa_dropout=transformer_msa_dropout,
            transformer_stochastic_depth=transformer_stochastic_depth,
            transformer_checkpoint_ff=transformer_checkpoint_ff,
            linear_module=linear_module,
        )

        self.pool = head(
            transformer_embedding_size,
            image_classes,
            logit_projection_layer=logit_projection_layer,
            batch_norm_logits=batch_norm_logits,
        )

        self.reset_parameters()

    @classmethod
    def _resolve_activation(cls, activation):
        """
        Map an activation name to its class; non-strings are returned as-is.

        Raises:
            KeyError: if ``activation`` is a string that names no known
                activation.
        """
        if not isinstance(activation, str):
            return activation
        try:
            return cls._ACTIVATIONS_BY_NAME[activation]
        except KeyError:
            raise KeyError(
                f"Unknown activation {activation!r}; expected one of "
                f"{sorted(cls._ACTIVATIONS_BY_NAME)}"
            ) from None

    @property
    def sequence_length(self):
        """Number of tokens produced by the encoder's flattening step."""
        return self.encoder.sequence_length

    def forward(self, x):
        return self.pool(self.encoder(x))

    def attention_logits(self, x):
        """Return the encoder transformer's attention logits for ``x``."""
        return self.encoder.attention_logits(x)

    def pool_attention(self, x):
        """
        Return the head's sequence-pooling weights for ``x``.

        Raises:
            NotImplementedError: if the head's ``summarize`` module has no
                ``attention`` (i.e. it is not sequence pooling).
        """
        if hasattr(self.pool.summarize, "attention"):
            return self.pool.summarize.attention(self.encoder(x))
        else:
            raise NotImplementedError(
                "`pool_attention` is currently only implemented where"
                " head class is SequencePoolClassificationHead"
            )

    def head_to_utility_token_attention_logits(self, x):
        """
        Summarise the attention logits at the utility-token positions.

        Averages the attention logits for ``x`` over dim 0 (the batch), then
        over the last dim. It then keeps the first ``n_utility_tokens`` entries
        of the resulting last axis.

        NOTE(review): assumes ``attention_logits`` returns
        ``(batch, layer, head, query, key)`` and that utility tokens occupy the
        first sequence positions, giving ``(layer, head, utility_tokens)``.
        Confirm this against TransformerEncoder.
        """
        all_attention = self.attention_logits(x)
        batch_averages = torch.mean(all_attention, dim=0, keepdim=False)
        sequence_averages = torch.mean(batch_averages, dim=-1, keepdim=False)
        n_utility_tokens = self.encoder.encoder[-1]._utility_tokens
        return sequence_averages[
            :, :, :n_utility_tokens
        ]  # (layer, head, utility_tokens)

    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.pool.reset_parameters()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 nicholasbailey87
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: broccoli-ml
|
|
3
|
+
Version: 9.5.1
|
|
4
|
+
Summary: Some useful Pytorch models, circa 2025
|
|
5
|
+
License: MIT
|
|
6
|
+
Author: Nicholas Bailey
|
|
7
|
+
Requires-Python: >=3.8
|
|
8
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Requires-Dist: einops (>=0.8.1,<0.9.0)
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
|
|
19
|
+
# broccoli
|
|
20
|
+
|
|
21
|
+
Some useful PyTorch models, circa 2025.
|
|
22
|
+
|
|
23
|
+

|
|
24
|
+
|
|
25
|
+
# Getting started
|
|
26
|
+
|
|
27
|
+
You can install broccoli with
|
|
28
|
+
|
|
29
|
+
```
|
|
30
|
+
pip install broccoli-ml
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
PyTorch is a peer dependency of `broccoli`, which means
|
|
34
|
+
* You will need to make sure you have PyTorch installed in order to use `broccoli`
|
|
35
|
+
* PyTorch will **not** be installed automatically when you install `broccoli`
|
|
36
|
+
|
|
37
|
+
We take this approach because PyTorch versioning is environment-specific and we don't know where you will want to use `broccoli`. If we automatically install PyTorch for you, there's a good chance we would get it wrong!
|
|
38
|
+
|
|
39
|
+
Therefore, please also make sure you install PyTorch.
|
|
40
|
+
|
|
41
|
+
# Usage examples
|
|
42
|
+
|
|
43
|
+
...
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
|
|
2
|
+
broccoli/activation.py,sha256=nrpTOrpg9k23_E4AJWy7VlXXAJCtCJCOR-TonEWJr04,3218
|
|
3
|
+
broccoli/cnn.py,sha256=WjoPDSpe3ttwxCBNfCVRdaCHvbeZ7G-a5_i8fUsK_d8,4889
|
|
4
|
+
broccoli/linear.py,sha256=i4U7ZC4ZWEH82YpDasx0Qs1pc3gkyL-3ajuyKCbsGTM,12649
|
|
5
|
+
broccoli/rope.py,sha256=GRqApBNmYCFaDak0WL1xE_BC5CTTYKQU_PBdeTcQcjc,12557
|
|
6
|
+
broccoli/tensor.py,sha256=um8mrxkYbvNDo-QvHlmJm8Aw6qcngOlUZPoAk_PMReA,4480
|
|
7
|
+
broccoli/transformer.py,sha256=ULk-QQX3hAI14-aCKhp9QSebzX4KUjlisEGup2Eycck,25565
|
|
8
|
+
broccoli/utils.py,sha256=oOWzn6dJ5nC_9r4zq0emmfmaYACJXJNFS48AOpW2jqc,358
|
|
9
|
+
broccoli/vit.py,sha256=sC6K3FK3a8ojOgvNWSWhuZHBtnFrrTQbsDdlagcKJH4,22224
|
|
10
|
+
broccoli_ml-9.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
|
11
|
+
broccoli_ml-9.5.1.dist-info/METADATA,sha256=HXRWnuc_-Gs_g37_RP3-POTLmi7sZamlzYv5SJEun1Y,1368
|
|
12
|
+
broccoli_ml-9.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
13
|
+
broccoli_ml-9.5.1.dist-info/RECORD,,
|