broccoli-ml 0.1.40__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/cnn.py +404 -322
- broccoli/transformer.py +96 -82
- broccoli/vit.py +170 -123
- {broccoli_ml-0.1.40.dist-info → broccoli_ml-0.2.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.40.dist-info → broccoli_ml-0.2.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.1.40.dist-info → broccoli_ml-0.2.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.40.dist-info → broccoli_ml-0.2.0.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
import math
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from .transformer import TransformerEncoder
|
5
|
-
from .cnn import
|
4
|
+
from .transformer import TransformerEncoder, DenoisingAutoEncoder
|
5
|
+
from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
|
6
6
|
from .activation import ReLU, SquaredReLU, GELU, SwiGLU
|
7
7
|
from einops import einsum
|
8
8
|
from einops.layers.torch import Rearrange
|
@@ -61,38 +61,34 @@ class CCTEncoder(nn.Module):
|
|
61
61
|
|
62
62
|
def __init__(
|
63
63
|
self,
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
64
|
+
input_size=(32, 32),
|
65
|
+
cnn_in_channels=3,
|
66
|
+
cnn_kernel_size=3,
|
67
|
+
cnn_kernel_stride=1,
|
68
|
+
cnn_kernel_padding="same",
|
69
|
+
cnn_kernel_dilation=1,
|
70
|
+
cnn_kernel_groups=1,
|
71
|
+
cnn_activation: nn.Module = nn.ReLU,
|
72
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
73
|
+
cnn_dropout=0.0,
|
74
|
+
pooling_type="maxpool",
|
75
|
+
pooling_kernel_size=3,
|
76
|
+
pooling_kernel_stride=2,
|
77
|
+
pooling_kernel_padding=1,
|
71
78
|
transformer_position_embedding="absolute", # absolute or relative
|
72
79
|
transformer_embedding_size=256,
|
73
80
|
transformer_layers=7,
|
74
81
|
transformer_heads=4,
|
75
82
|
transformer_mlp_ratio=2,
|
76
83
|
transformer_bos_tokens=4,
|
77
|
-
tranformer_share_kv=True,
|
78
|
-
tranformer_max_subtract=True,
|
79
|
-
tranformer_d_model_scale=True,
|
80
|
-
tranformer_log_length_scale=True,
|
81
|
-
tranformer_quiet_attention=True,
|
82
|
-
cnn_activation: nn.Module = nn.ReLU,
|
83
|
-
cnn_activation_kwargs: Optional[dict] = None,
|
84
84
|
transformer_activation: nn.Module = nn.GELU,
|
85
85
|
transformer_activation_kwargs: Optional[dict] = None,
|
86
86
|
mlp_dropout=0.0,
|
87
87
|
msa_dropout=0.1,
|
88
88
|
stochastic_depth=0.1,
|
89
89
|
linear_module=nn.Linear,
|
90
|
-
|
91
|
-
batch_norm=False,
|
90
|
+
batch_norm=True,
|
92
91
|
):
|
93
|
-
if conv_pooling_type not in ["maxpool", "concat"]:
|
94
|
-
raise NotImplementedError("Pooling type must be maxpool or concat")
|
95
|
-
|
96
92
|
super().__init__()
|
97
93
|
|
98
94
|
if cnn_activation_kwargs is not None:
|
@@ -107,55 +103,122 @@ class CCTEncoder(nn.Module):
|
|
107
103
|
else:
|
108
104
|
self.transformer_activation = transformer_activation()
|
109
105
|
|
110
|
-
self.
|
106
|
+
self.input_size = input_size
|
107
|
+
self.spatial_dimensions = len(self.input_size)
|
108
|
+
|
109
|
+
if self.spatial_dimensions == 1:
|
110
|
+
maxpoolxd = nn.MaxPool1d
|
111
|
+
convxd = nn.Conv1d
|
112
|
+
batchnormxd = nn.BatchNorm1d
|
113
|
+
spatial_dim_names = "D1"
|
114
|
+
elif self.spatial_dimensions == 2:
|
115
|
+
maxpoolxd = nn.MaxPool2d
|
116
|
+
convxd = nn.Conv2d
|
117
|
+
batchnormxd = nn.BatchNorm2d
|
118
|
+
spatial_dim_names = "D1 D2"
|
119
|
+
elif self.spatial_dimensions == 3:
|
120
|
+
maxpoolxd = nn.MaxPool3d
|
121
|
+
convxd = nn.Conv3d
|
122
|
+
batchnormxd = nn.BatchNorm3d
|
123
|
+
spatial_dim_names = "D1 D2 D3"
|
124
|
+
else:
|
125
|
+
raise NotImplementedError(
|
126
|
+
"`input_size` must be a tuple of length 1, 2, or 3."
|
127
|
+
)
|
128
|
+
|
129
|
+
cnn_output_size = calculate_output_spatial_size(
|
130
|
+
input_size,
|
131
|
+
kernel_size=cnn_kernel_size,
|
132
|
+
stride=cnn_kernel_stride,
|
133
|
+
padding=cnn_kernel_padding,
|
134
|
+
dilation=cnn_kernel_dilation,
|
135
|
+
)
|
136
|
+
|
137
|
+
pooling_output_size = (
|
138
|
+
cnn_output_size
|
139
|
+
if pooling_type is None
|
140
|
+
else calculate_output_spatial_size(
|
141
|
+
cnn_output_size,
|
142
|
+
kernel_size=pooling_kernel_size,
|
143
|
+
stride=pooling_kernel_stride,
|
144
|
+
padding=pooling_kernel_padding,
|
145
|
+
dilation=1,
|
146
|
+
)
|
147
|
+
)
|
111
148
|
|
112
|
-
|
113
|
-
output_size = math.floor(
|
114
|
-
(image_size + 2 * conv_pooling_kernel_padding - conv_pooling_kernel_size)
|
115
|
-
/ conv_pooling_kernel_stride
|
116
|
-
+ 1
|
117
|
-
) # output of pooling
|
149
|
+
self.sequence_length = math.prod(pooling_output_size) # One token per voxel
|
118
150
|
|
119
|
-
|
151
|
+
pooling_kernel_voxels = math.prod(
|
152
|
+
spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
|
153
|
+
)
|
120
154
|
|
121
|
-
if
|
122
|
-
|
123
|
-
elif
|
124
|
-
|
125
|
-
|
155
|
+
if pooling_type in ["maxpool", None]:
|
156
|
+
cnn_out_channels = transformer_embedding_size
|
157
|
+
elif pooling_type == "concat":
|
158
|
+
cnn_out_channels = math.floor(
|
159
|
+
transformer_embedding_size / pooling_kernel_voxels
|
126
160
|
)
|
161
|
+
else:
|
162
|
+
raise NotImplementedError("Pooling type must be maxpool, concat or None")
|
127
163
|
|
128
|
-
|
164
|
+
cnn_activation_out_channels = cnn_out_channels
|
165
|
+
|
166
|
+
# This block rhymes:
|
129
167
|
if cnn_activation.__name__.endswith("GLU"):
|
130
|
-
|
131
|
-
|
132
|
-
self.
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
stride=
|
137
|
-
padding=
|
138
|
-
|
168
|
+
cnn_out_channels *= 2
|
169
|
+
|
170
|
+
self.cnn = convxd(
|
171
|
+
cnn_in_channels,
|
172
|
+
cnn_out_channels,
|
173
|
+
cnn_kernel_size,
|
174
|
+
stride=cnn_kernel_stride,
|
175
|
+
padding=cnn_kernel_padding,
|
176
|
+
dilation=cnn_kernel_dilation,
|
177
|
+
groups=cnn_kernel_groups,
|
178
|
+
bias=True,
|
179
|
+
padding_mode="zeros",
|
180
|
+
)
|
181
|
+
|
182
|
+
self.activate_and_dropout = nn.Sequential(
|
183
|
+
*[
|
184
|
+
Rearrange( # rearrange in case we're using XGLU activation
|
185
|
+
f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
|
186
|
+
),
|
187
|
+
self.cnn_activation,
|
188
|
+
Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
|
189
|
+
nn.Dropout(cnn_dropout),
|
190
|
+
(
|
191
|
+
batchnormxd(cnn_activation_out_channels)
|
192
|
+
if batch_norm
|
193
|
+
else nn.Identity()
|
194
|
+
),
|
195
|
+
]
|
139
196
|
)
|
140
197
|
|
141
|
-
if
|
198
|
+
if pooling_type is None:
|
142
199
|
self.pool = nn.Sequential(
|
143
200
|
*[
|
144
|
-
Rearrange(
|
145
|
-
"N C
|
146
|
-
),
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
201
|
+
Rearrange(
|
202
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
203
|
+
), # for transformer
|
204
|
+
]
|
205
|
+
)
|
206
|
+
|
207
|
+
elif pooling_type == "maxpool":
|
208
|
+
self.pool = nn.Sequential(
|
209
|
+
*[
|
210
|
+
maxpoolxd(
|
211
|
+
pooling_kernel_size,
|
212
|
+
stride=pooling_kernel_stride,
|
213
|
+
padding=pooling_kernel_padding,
|
153
214
|
),
|
154
|
-
Rearrange(
|
215
|
+
Rearrange(
|
216
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
217
|
+
), # for transformer
|
155
218
|
]
|
156
219
|
)
|
157
220
|
|
158
|
-
elif
|
221
|
+
elif pooling_type == "concat":
|
159
222
|
|
160
223
|
if transformer_activation_kwargs is not None:
|
161
224
|
self.concatpool_activation = transformer_activation(
|
@@ -164,42 +227,30 @@ class CCTEncoder(nn.Module):
|
|
164
227
|
else:
|
165
228
|
self.concatpool_activation = transformer_activation()
|
166
229
|
|
167
|
-
concatpool_out_channels =
|
168
|
-
|
169
|
-
|
170
|
-
cnn_activation_output_channels = concatpool_out_channels / 2
|
171
|
-
else:
|
172
|
-
cnn_activation_output_channels = concatpool_out_channels
|
230
|
+
concatpool_out_channels = (
|
231
|
+
pooling_kernel_voxels * cnn_activation_out_channels
|
232
|
+
)
|
173
233
|
|
174
234
|
self.pool = nn.Sequential(
|
175
235
|
*[
|
176
|
-
|
177
|
-
|
178
|
-
stride=
|
179
|
-
padding=
|
236
|
+
SpaceToDepth(
|
237
|
+
pooling_kernel_size,
|
238
|
+
stride=pooling_kernel_stride,
|
239
|
+
padding=pooling_kernel_padding,
|
240
|
+
spatial_dimensions=self.spatial_dimensions,
|
180
241
|
),
|
181
|
-
Rearrange( #
|
182
|
-
"N C
|
242
|
+
Rearrange( # for transformer
|
243
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
183
244
|
),
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
245
|
+
DenoisingAutoEncoder(
|
246
|
+
concatpool_out_channels,
|
247
|
+
transformer_mlp_ratio,
|
248
|
+
transformer_embedding_size,
|
249
|
+
activation=transformer_activation,
|
250
|
+
activation_kwargs=transformer_activation_kwargs,
|
251
|
+
dropout=0.0,
|
252
|
+
linear_module=linear_module,
|
188
253
|
),
|
189
|
-
nn.BatchNorm2d(cnn_activation_output_channels),
|
190
|
-
Rearrange( # rearrange in case we're using XGLU activation
|
191
|
-
"N C H W -> N (H W) C"
|
192
|
-
),
|
193
|
-
nn.Linear(
|
194
|
-
cnn_activation_output_channels,
|
195
|
-
(
|
196
|
-
2 * transformer_embedding_size * transformer_mlp_ratio
|
197
|
-
if transformer_activation.__name__.endswith("GLU")
|
198
|
-
else transformer_embedding_size * transformer_mlp_ratio
|
199
|
-
),
|
200
|
-
),
|
201
|
-
self.concatpool_activation,
|
202
|
-
nn.Linear(transformer_embedding_size * transformer_mlp_ratio),
|
203
254
|
]
|
204
255
|
)
|
205
256
|
|
@@ -210,7 +261,7 @@ class CCTEncoder(nn.Module):
|
|
210
261
|
transformer_layers,
|
211
262
|
transformer_heads,
|
212
263
|
position_embedding_type=transformer_position_embedding,
|
213
|
-
source_size=
|
264
|
+
source_size=pooling_output_size,
|
214
265
|
mlp_ratio=transformer_mlp_ratio,
|
215
266
|
activation=transformer_activation,
|
216
267
|
activation_kwargs=transformer_activation_kwargs,
|
@@ -218,11 +269,6 @@ class CCTEncoder(nn.Module):
|
|
218
269
|
msa_dropout=msa_dropout,
|
219
270
|
stochastic_depth=stochastic_depth,
|
220
271
|
causal=False,
|
221
|
-
share_kv=tranformer_share_kv,
|
222
|
-
max_subtract=tranformer_max_subtract,
|
223
|
-
d_model_scale=tranformer_d_model_scale,
|
224
|
-
log_length_scale=tranformer_log_length_scale,
|
225
|
-
quiet_attention=tranformer_quiet_attention,
|
226
272
|
linear_module=linear_module,
|
227
273
|
bos_tokens=transformer_bos_tokens,
|
228
274
|
)
|
@@ -231,8 +277,9 @@ class CCTEncoder(nn.Module):
|
|
231
277
|
|
232
278
|
self.encoder = nn.Sequential(
|
233
279
|
*[
|
234
|
-
|
235
|
-
self.
|
280
|
+
batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
|
281
|
+
self.cnn,
|
282
|
+
self.activate_and_dropout,
|
236
283
|
self.pool,
|
237
284
|
self.transformer,
|
238
285
|
]
|
@@ -252,8 +299,16 @@ class CCT(nn.Module):
|
|
252
299
|
|
253
300
|
def __init__(
|
254
301
|
self,
|
255
|
-
|
256
|
-
|
302
|
+
input_size=(32, 32),
|
303
|
+
cnn_in_channels=3,
|
304
|
+
cnn_kernel_size=3,
|
305
|
+
cnn_kernel_stride=1,
|
306
|
+
cnn_kernel_padding="same",
|
307
|
+
cnn_kernel_dilation=1,
|
308
|
+
cnn_kernel_groups=1,
|
309
|
+
cnn_activation: nn.Module = nn.ReLU,
|
310
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
311
|
+
cnn_dropout=0.0,
|
257
312
|
pooling_type="maxpool",
|
258
313
|
pooling_kernel_size=3,
|
259
314
|
pooling_kernel_stride=2,
|
@@ -264,22 +319,14 @@ class CCT(nn.Module):
|
|
264
319
|
transformer_heads=4,
|
265
320
|
transformer_mlp_ratio=2,
|
266
321
|
transformer_bos_tokens=4,
|
267
|
-
tranformer_share_kv=True,
|
268
|
-
tranformer_max_subtract=True,
|
269
|
-
tranformer_d_model_scale=True,
|
270
|
-
tranformer_log_length_scale=True,
|
271
|
-
tranformer_quiet_attention=True,
|
272
|
-
cnn_activation: nn.Module = nn.ReLU,
|
273
|
-
cnn_activation_kwargs: Optional[dict] = None,
|
274
322
|
transformer_activation: nn.Module = nn.GELU,
|
275
323
|
transformer_activation_kwargs: Optional[dict] = None,
|
276
|
-
mlp_dropout=0.0,
|
277
|
-
msa_dropout=0.1,
|
278
|
-
stochastic_depth=0.1,
|
279
|
-
image_classes=100,
|
324
|
+
mlp_dropout=0.0,
|
325
|
+
msa_dropout=0.1,
|
326
|
+
stochastic_depth=0.1,
|
280
327
|
linear_module=nn.Linear,
|
281
|
-
|
282
|
-
|
328
|
+
batch_norm=True,
|
329
|
+
image_classes=100,
|
283
330
|
):
|
284
331
|
|
285
332
|
super().__init__()
|
@@ -301,32 +348,32 @@ class CCT(nn.Module):
|
|
301
348
|
}[transformer_activation]
|
302
349
|
|
303
350
|
self.encoder = CCTEncoder(
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
351
|
+
input_size=input_size,
|
352
|
+
cnn_in_channels=cnn_in_channels,
|
353
|
+
cnn_kernel_size=cnn_kernel_size,
|
354
|
+
cnn_kernel_stride=cnn_kernel_stride,
|
355
|
+
cnn_kernel_padding=cnn_kernel_padding,
|
356
|
+
cnn_kernel_dilation=cnn_kernel_dilation,
|
357
|
+
cnn_kernel_groups=cnn_kernel_groups,
|
358
|
+
cnn_activation=cnn_activation,
|
359
|
+
cnn_activation_kwargs=cnn_activation_kwargs,
|
360
|
+
cnn_dropout=cnn_dropout,
|
361
|
+
pooling_type=pooling_type,
|
362
|
+
pooling_kernel_size=pooling_kernel_size,
|
363
|
+
pooling_kernel_stride=pooling_kernel_stride,
|
364
|
+
pooling_kernel_padding=pooling_kernel_padding,
|
310
365
|
transformer_position_embedding=transformer_position_embedding,
|
311
366
|
transformer_embedding_size=transformer_embedding_size,
|
312
367
|
transformer_layers=transformer_layers,
|
313
368
|
transformer_heads=transformer_heads,
|
314
369
|
transformer_mlp_ratio=transformer_mlp_ratio,
|
315
370
|
transformer_bos_tokens=transformer_bos_tokens,
|
316
|
-
tranformer_share_kv=tranformer_share_kv,
|
317
|
-
tranformer_max_subtract=tranformer_max_subtract,
|
318
|
-
tranformer_d_model_scale=tranformer_d_model_scale,
|
319
|
-
tranformer_log_length_scale=tranformer_log_length_scale,
|
320
|
-
tranformer_quiet_attention=tranformer_quiet_attention,
|
321
|
-
cnn_activation=cnn_activation,
|
322
|
-
cnn_activation_kwargs=cnn_activation_kwargs,
|
323
371
|
transformer_activation=transformer_activation,
|
324
372
|
transformer_activation_kwargs=transformer_activation_kwargs,
|
325
373
|
mlp_dropout=mlp_dropout,
|
326
374
|
msa_dropout=msa_dropout,
|
327
375
|
stochastic_depth=stochastic_depth,
|
328
376
|
linear_module=linear_module,
|
329
|
-
image_channels=image_channels,
|
330
377
|
batch_norm=batch_norm,
|
331
378
|
)
|
332
379
|
self.pool = SequencePool(
|
@@ -3,15 +3,15 @@ broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
|
|
3
3
|
broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
|
4
4
|
broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
|
5
5
|
broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
|
6
|
-
broccoli/cnn.py,sha256=
|
6
|
+
broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
|
7
7
|
broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
8
8
|
broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=4BHh8ohcVMr_iGVD-FRnyRnKQaaMMjdgs4fixeBm90M,13602
|
14
|
+
broccoli_ml-0.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.2.0.dist-info/METADATA,sha256=pvawWlKwj4Ee9e0VWqmu4jdK9fTLuTU82_NP4tCOVaA,1256
|
16
|
+
broccoli_ml-0.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|