broccoli-ml 0.1.41__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/cnn.py +404 -322
- broccoli/transformer.py +96 -82
- broccoli/vit.py +169 -125
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.2.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.2.0.dist-info}/RECORD +7 -7
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.2.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.41.dist-info → broccoli_ml-0.2.0.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
import math
|
2
2
|
from typing import Optional
|
3
3
|
|
4
|
-
from .transformer import TransformerEncoder
|
5
|
-
from .cnn import
|
4
|
+
from .transformer import TransformerEncoder, DenoisingAutoEncoder
|
5
|
+
from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
|
6
6
|
from .activation import ReLU, SquaredReLU, GELU, SwiGLU
|
7
7
|
from einops import einsum
|
8
8
|
from einops.layers.torch import Rearrange
|
@@ -61,38 +61,34 @@ class CCTEncoder(nn.Module):
|
|
61
61
|
|
62
62
|
def __init__(
|
63
63
|
self,
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
64
|
+
input_size=(32, 32),
|
65
|
+
cnn_in_channels=3,
|
66
|
+
cnn_kernel_size=3,
|
67
|
+
cnn_kernel_stride=1,
|
68
|
+
cnn_kernel_padding="same",
|
69
|
+
cnn_kernel_dilation=1,
|
70
|
+
cnn_kernel_groups=1,
|
71
|
+
cnn_activation: nn.Module = nn.ReLU,
|
72
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
73
|
+
cnn_dropout=0.0,
|
74
|
+
pooling_type="maxpool",
|
75
|
+
pooling_kernel_size=3,
|
76
|
+
pooling_kernel_stride=2,
|
77
|
+
pooling_kernel_padding=1,
|
71
78
|
transformer_position_embedding="absolute", # absolute or relative
|
72
79
|
transformer_embedding_size=256,
|
73
80
|
transformer_layers=7,
|
74
81
|
transformer_heads=4,
|
75
82
|
transformer_mlp_ratio=2,
|
76
83
|
transformer_bos_tokens=4,
|
77
|
-
tranformer_share_kv=True,
|
78
|
-
tranformer_max_subtract=True,
|
79
|
-
tranformer_d_model_scale=True,
|
80
|
-
tranformer_log_length_scale=True,
|
81
|
-
tranformer_quiet_attention=True,
|
82
|
-
cnn_activation: nn.Module = nn.ReLU,
|
83
|
-
cnn_activation_kwargs: Optional[dict] = None,
|
84
84
|
transformer_activation: nn.Module = nn.GELU,
|
85
85
|
transformer_activation_kwargs: Optional[dict] = None,
|
86
86
|
mlp_dropout=0.0,
|
87
87
|
msa_dropout=0.1,
|
88
88
|
stochastic_depth=0.1,
|
89
89
|
linear_module=nn.Linear,
|
90
|
-
|
91
|
-
batch_norm=False,
|
90
|
+
batch_norm=True,
|
92
91
|
):
|
93
|
-
if conv_pooling_type not in ["maxpool", "concat"]:
|
94
|
-
raise NotImplementedError("Pooling type must be maxpool or concat")
|
95
|
-
|
96
92
|
super().__init__()
|
97
93
|
|
98
94
|
if cnn_activation_kwargs is not None:
|
@@ -107,55 +103,122 @@ class CCTEncoder(nn.Module):
|
|
107
103
|
else:
|
108
104
|
self.transformer_activation = transformer_activation()
|
109
105
|
|
110
|
-
self.
|
106
|
+
self.input_size = input_size
|
107
|
+
self.spatial_dimensions = len(self.input_size)
|
108
|
+
|
109
|
+
if self.spatial_dimensions == 1:
|
110
|
+
maxpoolxd = nn.MaxPool1d
|
111
|
+
convxd = nn.Conv1d
|
112
|
+
batchnormxd = nn.BatchNorm1d
|
113
|
+
spatial_dim_names = "D1"
|
114
|
+
elif self.spatial_dimensions == 2:
|
115
|
+
maxpoolxd = nn.MaxPool2d
|
116
|
+
convxd = nn.Conv2d
|
117
|
+
batchnormxd = nn.BatchNorm2d
|
118
|
+
spatial_dim_names = "D1 D2"
|
119
|
+
elif self.spatial_dimensions == 3:
|
120
|
+
maxpoolxd = nn.MaxPool3d
|
121
|
+
convxd = nn.Conv3d
|
122
|
+
batchnormxd = nn.BatchNorm3d
|
123
|
+
spatial_dim_names = "D1 D2 D3"
|
124
|
+
else:
|
125
|
+
raise NotImplementedError(
|
126
|
+
"`input_size` must be a tuple of length 1, 2, or 3."
|
127
|
+
)
|
111
128
|
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
129
|
+
cnn_output_size = calculate_output_spatial_size(
|
130
|
+
input_size,
|
131
|
+
kernel_size=cnn_kernel_size,
|
132
|
+
stride=cnn_kernel_stride,
|
133
|
+
padding=cnn_kernel_padding,
|
134
|
+
dilation=cnn_kernel_dilation,
|
135
|
+
)
|
118
136
|
|
119
|
-
|
137
|
+
pooling_output_size = (
|
138
|
+
cnn_output_size
|
139
|
+
if pooling_type is None
|
140
|
+
else calculate_output_spatial_size(
|
141
|
+
cnn_output_size,
|
142
|
+
kernel_size=pooling_kernel_size,
|
143
|
+
stride=pooling_kernel_stride,
|
144
|
+
padding=pooling_kernel_padding,
|
145
|
+
dilation=1,
|
146
|
+
)
|
147
|
+
)
|
148
|
+
|
149
|
+
self.sequence_length = math.prod(pooling_output_size) # One token per voxel
|
150
|
+
|
151
|
+
pooling_kernel_voxels = math.prod(
|
152
|
+
spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
|
153
|
+
)
|
120
154
|
|
121
|
-
if
|
122
|
-
|
123
|
-
elif
|
124
|
-
|
125
|
-
|
155
|
+
if pooling_type in ["maxpool", None]:
|
156
|
+
cnn_out_channels = transformer_embedding_size
|
157
|
+
elif pooling_type == "concat":
|
158
|
+
cnn_out_channels = math.floor(
|
159
|
+
transformer_embedding_size / pooling_kernel_voxels
|
126
160
|
)
|
161
|
+
else:
|
162
|
+
raise NotImplementedError("Pooling type must be maxpool, concat or None")
|
163
|
+
|
164
|
+
cnn_activation_out_channels = cnn_out_channels
|
127
165
|
|
128
|
-
# This
|
166
|
+
# This block rhymes:
|
129
167
|
if cnn_activation.__name__.endswith("GLU"):
|
130
|
-
|
131
|
-
|
132
|
-
self.
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
stride=
|
137
|
-
padding=
|
138
|
-
|
168
|
+
cnn_out_channels *= 2
|
169
|
+
|
170
|
+
self.cnn = convxd(
|
171
|
+
cnn_in_channels,
|
172
|
+
cnn_out_channels,
|
173
|
+
cnn_kernel_size,
|
174
|
+
stride=cnn_kernel_stride,
|
175
|
+
padding=cnn_kernel_padding,
|
176
|
+
dilation=cnn_kernel_dilation,
|
177
|
+
groups=cnn_kernel_groups,
|
178
|
+
bias=True,
|
179
|
+
padding_mode="zeros",
|
180
|
+
)
|
181
|
+
|
182
|
+
self.activate_and_dropout = nn.Sequential(
|
183
|
+
*[
|
184
|
+
Rearrange( # rearrange in case we're using XGLU activation
|
185
|
+
f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
|
186
|
+
),
|
187
|
+
self.cnn_activation,
|
188
|
+
Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
|
189
|
+
nn.Dropout(cnn_dropout),
|
190
|
+
(
|
191
|
+
batchnormxd(cnn_activation_out_channels)
|
192
|
+
if batch_norm
|
193
|
+
else nn.Identity()
|
194
|
+
),
|
195
|
+
]
|
139
196
|
)
|
140
197
|
|
141
|
-
if
|
198
|
+
if pooling_type is None:
|
142
199
|
self.pool = nn.Sequential(
|
143
200
|
*[
|
144
|
-
Rearrange(
|
145
|
-
"N C
|
146
|
-
),
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
201
|
+
Rearrange(
|
202
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
203
|
+
), # for transformer
|
204
|
+
]
|
205
|
+
)
|
206
|
+
|
207
|
+
elif pooling_type == "maxpool":
|
208
|
+
self.pool = nn.Sequential(
|
209
|
+
*[
|
210
|
+
maxpoolxd(
|
211
|
+
pooling_kernel_size,
|
212
|
+
stride=pooling_kernel_stride,
|
213
|
+
padding=pooling_kernel_padding,
|
153
214
|
),
|
154
|
-
Rearrange(
|
215
|
+
Rearrange(
|
216
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
217
|
+
), # for transformer
|
155
218
|
]
|
156
219
|
)
|
157
220
|
|
158
|
-
elif
|
221
|
+
elif pooling_type == "concat":
|
159
222
|
|
160
223
|
if transformer_activation_kwargs is not None:
|
161
224
|
self.concatpool_activation = transformer_activation(
|
@@ -164,44 +227,29 @@ class CCTEncoder(nn.Module):
|
|
164
227
|
else:
|
165
228
|
self.concatpool_activation = transformer_activation()
|
166
229
|
|
167
|
-
concatpool_out_channels =
|
168
|
-
|
169
|
-
|
170
|
-
cnn_activation_output_channels = concatpool_out_channels / 2
|
171
|
-
else:
|
172
|
-
cnn_activation_output_channels = concatpool_out_channels
|
230
|
+
concatpool_out_channels = (
|
231
|
+
pooling_kernel_voxels * cnn_activation_out_channels
|
232
|
+
)
|
173
233
|
|
174
234
|
self.pool = nn.Sequential(
|
175
235
|
*[
|
176
|
-
|
177
|
-
|
178
|
-
stride=
|
179
|
-
padding=
|
180
|
-
|
181
|
-
Rearrange( # rearrange in case we're using XGLU activation
|
182
|
-
"N C H W -> N H W C"
|
236
|
+
SpaceToDepth(
|
237
|
+
pooling_kernel_size,
|
238
|
+
stride=pooling_kernel_stride,
|
239
|
+
padding=pooling_kernel_padding,
|
240
|
+
spatial_dimensions=self.spatial_dimensions,
|
183
241
|
),
|
184
|
-
|
185
|
-
|
186
|
-
Rearrange( # rearrange in case we're using XGLU activation
|
187
|
-
"N H W C -> N C H W"
|
242
|
+
Rearrange( # for transformer
|
243
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
188
244
|
),
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
),
|
193
|
-
nn.Linear(
|
194
|
-
cnn_activation_output_channels,
|
195
|
-
(
|
196
|
-
2 * transformer_embedding_size * transformer_mlp_ratio
|
197
|
-
if transformer_activation.__name__.endswith("GLU")
|
198
|
-
else transformer_embedding_size * transformer_mlp_ratio
|
199
|
-
),
|
200
|
-
),
|
201
|
-
self.concatpool_activation,
|
202
|
-
nn.Linear(
|
203
|
-
transformer_embedding_size * transformer_mlp_ratio,
|
245
|
+
DenoisingAutoEncoder(
|
246
|
+
concatpool_out_channels,
|
247
|
+
transformer_mlp_ratio,
|
204
248
|
transformer_embedding_size,
|
249
|
+
activation=transformer_activation,
|
250
|
+
activation_kwargs=transformer_activation_kwargs,
|
251
|
+
dropout=0.0,
|
252
|
+
linear_module=linear_module,
|
205
253
|
),
|
206
254
|
]
|
207
255
|
)
|
@@ -213,7 +261,7 @@ class CCTEncoder(nn.Module):
|
|
213
261
|
transformer_layers,
|
214
262
|
transformer_heads,
|
215
263
|
position_embedding_type=transformer_position_embedding,
|
216
|
-
source_size=
|
264
|
+
source_size=pooling_output_size,
|
217
265
|
mlp_ratio=transformer_mlp_ratio,
|
218
266
|
activation=transformer_activation,
|
219
267
|
activation_kwargs=transformer_activation_kwargs,
|
@@ -221,11 +269,6 @@ class CCTEncoder(nn.Module):
|
|
221
269
|
msa_dropout=msa_dropout,
|
222
270
|
stochastic_depth=stochastic_depth,
|
223
271
|
causal=False,
|
224
|
-
share_kv=tranformer_share_kv,
|
225
|
-
max_subtract=tranformer_max_subtract,
|
226
|
-
d_model_scale=tranformer_d_model_scale,
|
227
|
-
log_length_scale=tranformer_log_length_scale,
|
228
|
-
quiet_attention=tranformer_quiet_attention,
|
229
272
|
linear_module=linear_module,
|
230
273
|
bos_tokens=transformer_bos_tokens,
|
231
274
|
)
|
@@ -234,8 +277,9 @@ class CCTEncoder(nn.Module):
|
|
234
277
|
|
235
278
|
self.encoder = nn.Sequential(
|
236
279
|
*[
|
237
|
-
|
238
|
-
self.
|
280
|
+
batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
|
281
|
+
self.cnn,
|
282
|
+
self.activate_and_dropout,
|
239
283
|
self.pool,
|
240
284
|
self.transformer,
|
241
285
|
]
|
@@ -255,8 +299,16 @@ class CCT(nn.Module):
|
|
255
299
|
|
256
300
|
def __init__(
|
257
301
|
self,
|
258
|
-
|
259
|
-
|
302
|
+
input_size=(32, 32),
|
303
|
+
cnn_in_channels=3,
|
304
|
+
cnn_kernel_size=3,
|
305
|
+
cnn_kernel_stride=1,
|
306
|
+
cnn_kernel_padding="same",
|
307
|
+
cnn_kernel_dilation=1,
|
308
|
+
cnn_kernel_groups=1,
|
309
|
+
cnn_activation: nn.Module = nn.ReLU,
|
310
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
311
|
+
cnn_dropout=0.0,
|
260
312
|
pooling_type="maxpool",
|
261
313
|
pooling_kernel_size=3,
|
262
314
|
pooling_kernel_stride=2,
|
@@ -267,22 +319,14 @@ class CCT(nn.Module):
|
|
267
319
|
transformer_heads=4,
|
268
320
|
transformer_mlp_ratio=2,
|
269
321
|
transformer_bos_tokens=4,
|
270
|
-
tranformer_share_kv=True,
|
271
|
-
tranformer_max_subtract=True,
|
272
|
-
tranformer_d_model_scale=True,
|
273
|
-
tranformer_log_length_scale=True,
|
274
|
-
tranformer_quiet_attention=True,
|
275
|
-
cnn_activation: nn.Module = nn.ReLU,
|
276
|
-
cnn_activation_kwargs: Optional[dict] = None,
|
277
322
|
transformer_activation: nn.Module = nn.GELU,
|
278
323
|
transformer_activation_kwargs: Optional[dict] = None,
|
279
|
-
mlp_dropout=0.0,
|
280
|
-
msa_dropout=0.1,
|
281
|
-
stochastic_depth=0.1,
|
282
|
-
image_classes=100,
|
324
|
+
mlp_dropout=0.0,
|
325
|
+
msa_dropout=0.1,
|
326
|
+
stochastic_depth=0.1,
|
283
327
|
linear_module=nn.Linear,
|
284
|
-
|
285
|
-
|
328
|
+
batch_norm=True,
|
329
|
+
image_classes=100,
|
286
330
|
):
|
287
331
|
|
288
332
|
super().__init__()
|
@@ -304,32 +348,32 @@ class CCT(nn.Module):
|
|
304
348
|
}[transformer_activation]
|
305
349
|
|
306
350
|
self.encoder = CCTEncoder(
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
351
|
+
input_size=input_size,
|
352
|
+
cnn_in_channels=cnn_in_channels,
|
353
|
+
cnn_kernel_size=cnn_kernel_size,
|
354
|
+
cnn_kernel_stride=cnn_kernel_stride,
|
355
|
+
cnn_kernel_padding=cnn_kernel_padding,
|
356
|
+
cnn_kernel_dilation=cnn_kernel_dilation,
|
357
|
+
cnn_kernel_groups=cnn_kernel_groups,
|
358
|
+
cnn_activation=cnn_activation,
|
359
|
+
cnn_activation_kwargs=cnn_activation_kwargs,
|
360
|
+
cnn_dropout=cnn_dropout,
|
361
|
+
pooling_type=pooling_type,
|
362
|
+
pooling_kernel_size=pooling_kernel_size,
|
363
|
+
pooling_kernel_stride=pooling_kernel_stride,
|
364
|
+
pooling_kernel_padding=pooling_kernel_padding,
|
313
365
|
transformer_position_embedding=transformer_position_embedding,
|
314
366
|
transformer_embedding_size=transformer_embedding_size,
|
315
367
|
transformer_layers=transformer_layers,
|
316
368
|
transformer_heads=transformer_heads,
|
317
369
|
transformer_mlp_ratio=transformer_mlp_ratio,
|
318
370
|
transformer_bos_tokens=transformer_bos_tokens,
|
319
|
-
tranformer_share_kv=tranformer_share_kv,
|
320
|
-
tranformer_max_subtract=tranformer_max_subtract,
|
321
|
-
tranformer_d_model_scale=tranformer_d_model_scale,
|
322
|
-
tranformer_log_length_scale=tranformer_log_length_scale,
|
323
|
-
tranformer_quiet_attention=tranformer_quiet_attention,
|
324
|
-
cnn_activation=cnn_activation,
|
325
|
-
cnn_activation_kwargs=cnn_activation_kwargs,
|
326
371
|
transformer_activation=transformer_activation,
|
327
372
|
transformer_activation_kwargs=transformer_activation_kwargs,
|
328
373
|
mlp_dropout=mlp_dropout,
|
329
374
|
msa_dropout=msa_dropout,
|
330
375
|
stochastic_depth=stochastic_depth,
|
331
376
|
linear_module=linear_module,
|
332
|
-
image_channels=image_channels,
|
333
377
|
batch_norm=batch_norm,
|
334
378
|
)
|
335
379
|
self.pool = SequencePool(
|
@@ -3,15 +3,15 @@ broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
|
|
3
3
|
broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
|
4
4
|
broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
|
5
5
|
broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
|
6
|
-
broccoli/cnn.py,sha256=
|
6
|
+
broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
|
7
7
|
broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
8
8
|
broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
|
-
broccoli/transformer.py,sha256=
|
11
|
+
broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=4BHh8ohcVMr_iGVD-FRnyRnKQaaMMjdgs4fixeBm90M,13602
|
14
|
+
broccoli_ml-0.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.2.0.dist-info/METADATA,sha256=pvawWlKwj4Ee9e0VWqmu4jdK9fTLuTU82_NP4tCOVaA,1256
|
16
|
+
broccoli_ml-0.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.2.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|