broccoli-ml 0.1.41__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/cnn.py +404 -322
- broccoli/transformer.py +96 -82
- broccoli/vit.py +173 -125
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.3.0.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
import math
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from .transformer import TransformerEncoder
|
5
|
-
from .cnn import
|
4
|
+
from .transformer import TransformerEncoder, DenoisingAutoEncoder
|
5
|
+
from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
|
6
6
|
from .activation import ReLU, SquaredReLU, GELU, SwiGLU
|
7
7
|
from einops import einsum
|
8
8
|
from einops.layers.torch import Rearrange
|
@@ -61,38 +61,35 @@ class CCTEncoder(nn.Module):
|
|
61
61
|
|
62
62
|
def __init__(
|
63
63
|
self,
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
64
|
+
input_size=(32, 32),
|
65
|
+
cnn_in_channels=3,
|
66
|
+
minimum_cnn_out_channels=16,
|
67
|
+
cnn_kernel_size=3,
|
68
|
+
cnn_kernel_stride=1,
|
69
|
+
cnn_kernel_padding="same",
|
70
|
+
cnn_kernel_dilation=1,
|
71
|
+
cnn_kernel_groups=1,
|
72
|
+
cnn_activation: nn.Module = nn.ReLU,
|
73
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
74
|
+
cnn_dropout=0.0,
|
75
|
+
pooling_type="maxpool",
|
76
|
+
pooling_kernel_size=3,
|
77
|
+
pooling_kernel_stride=2,
|
78
|
+
pooling_kernel_padding=1,
|
71
79
|
transformer_position_embedding="absolute", # absolute or relative
|
72
80
|
transformer_embedding_size=256,
|
73
81
|
transformer_layers=7,
|
74
82
|
transformer_heads=4,
|
75
83
|
transformer_mlp_ratio=2,
|
76
84
|
transformer_bos_tokens=4,
|
77
|
-
tranformer_share_kv=True,
|
78
|
-
tranformer_max_subtract=True,
|
79
|
-
tranformer_d_model_scale=True,
|
80
|
-
tranformer_log_length_scale=True,
|
81
|
-
tranformer_quiet_attention=True,
|
82
|
-
cnn_activation: nn.Module = nn.ReLU,
|
83
|
-
cnn_activation_kwargs: Optional[dict] = None,
|
84
85
|
transformer_activation: nn.Module = nn.GELU,
|
85
86
|
transformer_activation_kwargs: Optional[dict] = None,
|
86
87
|
mlp_dropout=0.0,
|
87
88
|
msa_dropout=0.1,
|
88
89
|
stochastic_depth=0.1,
|
89
90
|
linear_module=nn.Linear,
|
90
|
-
|
91
|
-
batch_norm=False,
|
91
|
+
batch_norm=True,
|
92
92
|
):
|
93
|
-
if conv_pooling_type not in ["maxpool", "concat"]:
|
94
|
-
raise NotImplementedError("Pooling type must be maxpool or concat")
|
95
|
-
|
96
93
|
super().__init__()
|
97
94
|
|
98
95
|
if cnn_activation_kwargs is not None:
|
@@ -107,55 +104,123 @@ class CCTEncoder(nn.Module):
|
|
107
104
|
else:
|
108
105
|
self.transformer_activation = transformer_activation()
|
109
106
|
|
110
|
-
self.
|
107
|
+
self.input_size = input_size
|
108
|
+
self.spatial_dimensions = len(self.input_size)
|
109
|
+
|
110
|
+
if self.spatial_dimensions == 1:
|
111
|
+
maxpoolxd = nn.MaxPool1d
|
112
|
+
convxd = nn.Conv1d
|
113
|
+
batchnormxd = nn.BatchNorm1d
|
114
|
+
spatial_dim_names = "D1"
|
115
|
+
elif self.spatial_dimensions == 2:
|
116
|
+
maxpoolxd = nn.MaxPool2d
|
117
|
+
convxd = nn.Conv2d
|
118
|
+
batchnormxd = nn.BatchNorm2d
|
119
|
+
spatial_dim_names = "D1 D2"
|
120
|
+
elif self.spatial_dimensions == 3:
|
121
|
+
maxpoolxd = nn.MaxPool3d
|
122
|
+
convxd = nn.Conv3d
|
123
|
+
batchnormxd = nn.BatchNorm3d
|
124
|
+
spatial_dim_names = "D1 D2 D3"
|
125
|
+
else:
|
126
|
+
raise NotImplementedError(
|
127
|
+
"`input_size` must be a tuple of length 1, 2, or 3."
|
128
|
+
)
|
111
129
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
130
|
+
cnn_output_size = calculate_output_spatial_size(
|
131
|
+
input_size,
|
132
|
+
kernel_size=cnn_kernel_size,
|
133
|
+
stride=cnn_kernel_stride,
|
134
|
+
padding=cnn_kernel_padding,
|
135
|
+
dilation=cnn_kernel_dilation,
|
136
|
+
)
|
118
137
|
|
119
|
-
|
138
|
+
pooling_output_size = (
|
139
|
+
cnn_output_size
|
140
|
+
if pooling_type is None
|
141
|
+
else calculate_output_spatial_size(
|
142
|
+
cnn_output_size,
|
143
|
+
kernel_size=pooling_kernel_size,
|
144
|
+
stride=pooling_kernel_stride,
|
145
|
+
padding=pooling_kernel_padding,
|
146
|
+
dilation=1,
|
147
|
+
)
|
148
|
+
)
|
149
|
+
|
150
|
+
self.sequence_length = math.prod(pooling_output_size) # One token per voxel
|
151
|
+
|
152
|
+
pooling_kernel_voxels = math.prod(
|
153
|
+
spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
|
154
|
+
)
|
120
155
|
|
121
|
-
if
|
122
|
-
|
123
|
-
elif
|
124
|
-
|
125
|
-
math.floor(transformer_embedding_size /
|
156
|
+
if pooling_type in ["maxpool", None]:
|
157
|
+
cnn_out_channels = transformer_embedding_size
|
158
|
+
elif pooling_type == "concat":
|
159
|
+
cnn_out_channels = min(
|
160
|
+
math.floor(transformer_embedding_size / pooling_kernel_voxels),
|
161
|
+
minimum_cnn_out_channels,
|
126
162
|
)
|
163
|
+
else:
|
164
|
+
raise NotImplementedError("Pooling type must be maxpool, concat or None")
|
165
|
+
|
166
|
+
cnn_activation_out_channels = cnn_out_channels
|
127
167
|
|
128
|
-
# This
|
168
|
+
# This block rhymes:
|
129
169
|
if cnn_activation.__name__.endswith("GLU"):
|
130
|
-
|
131
|
-
|
132
|
-
self.
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
stride=
|
137
|
-
padding=
|
138
|
-
|
170
|
+
cnn_out_channels *= 2
|
171
|
+
|
172
|
+
self.cnn = convxd(
|
173
|
+
cnn_in_channels,
|
174
|
+
cnn_out_channels,
|
175
|
+
cnn_kernel_size,
|
176
|
+
stride=cnn_kernel_stride,
|
177
|
+
padding=cnn_kernel_padding,
|
178
|
+
dilation=cnn_kernel_dilation,
|
179
|
+
groups=cnn_kernel_groups,
|
180
|
+
bias=True,
|
181
|
+
padding_mode="zeros",
|
182
|
+
)
|
183
|
+
|
184
|
+
self.activate_and_dropout = nn.Sequential(
|
185
|
+
*[
|
186
|
+
Rearrange( # rearrange in case we're using XGLU activation
|
187
|
+
f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
|
188
|
+
),
|
189
|
+
self.cnn_activation,
|
190
|
+
Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
|
191
|
+
nn.Dropout(cnn_dropout),
|
192
|
+
(
|
193
|
+
batchnormxd(cnn_activation_out_channels)
|
194
|
+
if batch_norm
|
195
|
+
else nn.Identity()
|
196
|
+
),
|
197
|
+
]
|
139
198
|
)
|
140
199
|
|
141
|
-
if
|
200
|
+
if pooling_type is None:
|
142
201
|
self.pool = nn.Sequential(
|
143
202
|
*[
|
144
|
-
Rearrange(
|
145
|
-
"N C
|
146
|
-
),
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
203
|
+
Rearrange(
|
204
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
205
|
+
), # for transformer
|
206
|
+
]
|
207
|
+
)
|
208
|
+
|
209
|
+
elif pooling_type == "maxpool":
|
210
|
+
self.pool = nn.Sequential(
|
211
|
+
*[
|
212
|
+
maxpoolxd(
|
213
|
+
pooling_kernel_size,
|
214
|
+
stride=pooling_kernel_stride,
|
215
|
+
padding=pooling_kernel_padding,
|
153
216
|
),
|
154
|
-
Rearrange(
|
217
|
+
Rearrange(
|
218
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
219
|
+
), # for transformer
|
155
220
|
]
|
156
221
|
)
|
157
222
|
|
158
|
-
elif
|
223
|
+
elif pooling_type == "concat":
|
159
224
|
|
160
225
|
if transformer_activation_kwargs is not None:
|
161
226
|
self.concatpool_activation = transformer_activation(
|
@@ -164,44 +229,29 @@ class CCTEncoder(nn.Module):
|
|
164
229
|
else:
|
165
230
|
self.concatpool_activation = transformer_activation()
|
166
231
|
|
167
|
-
concatpool_out_channels =
|
168
|
-
|
169
|
-
|
170
|
-
cnn_activation_output_channels = concatpool_out_channels / 2
|
171
|
-
else:
|
172
|
-
cnn_activation_output_channels = concatpool_out_channels
|
232
|
+
concatpool_out_channels = (
|
233
|
+
pooling_kernel_voxels * cnn_activation_out_channels
|
234
|
+
)
|
173
235
|
|
174
236
|
self.pool = nn.Sequential(
|
175
237
|
*[
|
176
|
-
|
177
|
-
|
178
|
-
stride=
|
179
|
-
padding=
|
180
|
-
|
181
|
-
Rearrange( # rearrange in case we're using XGLU activation
|
182
|
-
"N C H W -> N H W C"
|
238
|
+
SpaceToDepth(
|
239
|
+
pooling_kernel_size,
|
240
|
+
stride=pooling_kernel_stride,
|
241
|
+
padding=pooling_kernel_padding,
|
242
|
+
spatial_dimensions=self.spatial_dimensions,
|
183
243
|
),
|
184
|
-
|
185
|
-
|
186
|
-
Rearrange( # rearrange in case we're using XGLU activation
|
187
|
-
"N H W C -> N C H W"
|
244
|
+
Rearrange( # for transformer
|
245
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
188
246
|
),
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
),
|
193
|
-
nn.Linear(
|
194
|
-
cnn_activation_output_channels,
|
195
|
-
(
|
196
|
-
2 * transformer_embedding_size * transformer_mlp_ratio
|
197
|
-
if transformer_activation.__name__.endswith("GLU")
|
198
|
-
else transformer_embedding_size * transformer_mlp_ratio
|
199
|
-
),
|
200
|
-
),
|
201
|
-
self.concatpool_activation,
|
202
|
-
nn.Linear(
|
203
|
-
transformer_embedding_size * transformer_mlp_ratio,
|
247
|
+
DenoisingAutoEncoder(
|
248
|
+
concatpool_out_channels,
|
249
|
+
transformer_mlp_ratio,
|
204
250
|
transformer_embedding_size,
|
251
|
+
activation=transformer_activation,
|
252
|
+
activation_kwargs=transformer_activation_kwargs,
|
253
|
+
dropout=0.0,
|
254
|
+
linear_module=linear_module,
|
205
255
|
),
|
206
256
|
]
|
207
257
|
)
|
@@ -213,7 +263,7 @@ class CCTEncoder(nn.Module):
|
|
213
263
|
transformer_layers,
|
214
264
|
transformer_heads,
|
215
265
|
position_embedding_type=transformer_position_embedding,
|
216
|
-
source_size=
|
266
|
+
source_size=pooling_output_size,
|
217
267
|
mlp_ratio=transformer_mlp_ratio,
|
218
268
|
activation=transformer_activation,
|
219
269
|
activation_kwargs=transformer_activation_kwargs,
|
@@ -221,11 +271,6 @@ class CCTEncoder(nn.Module):
|
|
221
271
|
msa_dropout=msa_dropout,
|
222
272
|
stochastic_depth=stochastic_depth,
|
223
273
|
causal=False,
|
224
|
-
share_kv=tranformer_share_kv,
|
225
|
-
max_subtract=tranformer_max_subtract,
|
226
|
-
d_model_scale=tranformer_d_model_scale,
|
227
|
-
log_length_scale=tranformer_log_length_scale,
|
228
|
-
quiet_attention=tranformer_quiet_attention,
|
229
274
|
linear_module=linear_module,
|
230
275
|
bos_tokens=transformer_bos_tokens,
|
231
276
|
)
|
@@ -234,8 +279,9 @@ class CCTEncoder(nn.Module):
|
|
234
279
|
|
235
280
|
self.encoder = nn.Sequential(
|
236
281
|
*[
|
237
|
-
|
238
|
-
self.
|
282
|
+
batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
|
283
|
+
self.cnn,
|
284
|
+
self.activate_and_dropout,
|
239
285
|
self.pool,
|
240
286
|
self.transformer,
|
241
287
|
]
|
@@ -255,8 +301,17 @@ class CCT(nn.Module):
|
|
255
301
|
|
256
302
|
def __init__(
|
257
303
|
self,
|
258
|
-
|
259
|
-
|
304
|
+
input_size=(32, 32),
|
305
|
+
cnn_in_channels=3,
|
306
|
+
minimum_cnn_out_channels=16,
|
307
|
+
cnn_kernel_size=3,
|
308
|
+
cnn_kernel_stride=1,
|
309
|
+
cnn_kernel_padding="same",
|
310
|
+
cnn_kernel_dilation=1,
|
311
|
+
cnn_kernel_groups=1,
|
312
|
+
cnn_activation: nn.Module = nn.ReLU,
|
313
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
314
|
+
cnn_dropout=0.0,
|
260
315
|
pooling_type="maxpool",
|
261
316
|
pooling_kernel_size=3,
|
262
317
|
pooling_kernel_stride=2,
|
@@ -267,22 +322,14 @@ class CCT(nn.Module):
|
|
267
322
|
transformer_heads=4,
|
268
323
|
transformer_mlp_ratio=2,
|
269
324
|
transformer_bos_tokens=4,
|
270
|
-
tranformer_share_kv=True,
|
271
|
-
tranformer_max_subtract=True,
|
272
|
-
tranformer_d_model_scale=True,
|
273
|
-
tranformer_log_length_scale=True,
|
274
|
-
tranformer_quiet_attention=True,
|
275
|
-
cnn_activation: nn.Module = nn.ReLU,
|
276
|
-
cnn_activation_kwargs: Optional[dict] = None,
|
277
325
|
transformer_activation: nn.Module = nn.GELU,
|
278
326
|
transformer_activation_kwargs: Optional[dict] = None,
|
279
|
-
mlp_dropout=0.0,
|
280
|
-
msa_dropout=0.1,
|
281
|
-
stochastic_depth=0.1,
|
282
|
-
image_classes=100,
|
327
|
+
mlp_dropout=0.0,
|
328
|
+
msa_dropout=0.1,
|
329
|
+
stochastic_depth=0.1,
|
283
330
|
linear_module=nn.Linear,
|
284
|
-
|
285
|
-
|
331
|
+
batch_norm=True,
|
332
|
+
image_classes=100,
|
286
333
|
):
|
287
334
|
|
288
335
|
super().__init__()
|
@@ -304,32 +351,33 @@ class CCT(nn.Module):
|
|
304
351
|
}[transformer_activation]
|
305
352
|
|
306
353
|
self.encoder = CCTEncoder(
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
354
|
+
input_size=input_size,
|
355
|
+
cnn_in_channels=cnn_in_channels,
|
356
|
+
minimum_cnn_out_channels=minimum_cnn_out_channels,
|
357
|
+
cnn_kernel_size=cnn_kernel_size,
|
358
|
+
cnn_kernel_stride=cnn_kernel_stride,
|
359
|
+
cnn_kernel_padding=cnn_kernel_padding,
|
360
|
+
cnn_kernel_dilation=cnn_kernel_dilation,
|
361
|
+
cnn_kernel_groups=cnn_kernel_groups,
|
362
|
+
cnn_activation=cnn_activation,
|
363
|
+
cnn_activation_kwargs=cnn_activation_kwargs,
|
364
|
+
cnn_dropout=cnn_dropout,
|
365
|
+
pooling_type=pooling_type,
|
366
|
+
pooling_kernel_size=pooling_kernel_size,
|
367
|
+
pooling_kernel_stride=pooling_kernel_stride,
|
368
|
+
pooling_kernel_padding=pooling_kernel_padding,
|
313
369
|
transformer_position_embedding=transformer_position_embedding,
|
314
370
|
transformer_embedding_size=transformer_embedding_size,
|
315
371
|
transformer_layers=transformer_layers,
|
316
372
|
transformer_heads=transformer_heads,
|
317
373
|
transformer_mlp_ratio=transformer_mlp_ratio,
|
318
374
|
transformer_bos_tokens=transformer_bos_tokens,
|
319
|
-
tranformer_share_kv=tranformer_share_kv,
|
320
|
-
tranformer_max_subtract=tranformer_max_subtract,
|
321
|
-
tranformer_d_model_scale=tranformer_d_model_scale,
|
322
|
-
tranformer_log_length_scale=tranformer_log_length_scale,
|
323
|
-
tranformer_quiet_attention=tranformer_quiet_attention,
|
324
|
-
cnn_activation=cnn_activation,
|
325
|
-
cnn_activation_kwargs=cnn_activation_kwargs,
|
326
375
|
transformer_activation=transformer_activation,
|
327
376
|
transformer_activation_kwargs=transformer_activation_kwargs,
|
328
377
|
mlp_dropout=mlp_dropout,
|
329
378
|
msa_dropout=msa_dropout,
|
330
379
|
stochastic_depth=stochastic_depth,
|
331
380
|
linear_module=linear_module,
|
332
|
-
image_channels=image_channels,
|
333
381
|
batch_norm=batch_norm,
|
334
382
|
)
|
335
383
|
self.pool = SequencePool(
|
@@ -3,15 +3,15 @@ broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
|
|
3
3
|
broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
|
4
4
|
broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
|
5
5
|
broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
|
6
|
-
broccoli/cnn.py,sha256=
|
6
|
+
broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
|
7
7
|
broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
8
8
|
broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=NuHW2xcaUEv_IHAZbrrGHUWKu9D7JMR1iKDCCX07RQs,13787
|
14
|
+
broccoli_ml-0.3.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.3.0.dist-info/METADATA,sha256=sAbHQ0Q2yM5kaovkF22cTKCk4SU_z6vi6QtmOMMwJlQ,1256
|
16
|
+
broccoli_ml-0.3.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.3.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|