broccoli-ml 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +3 -1
- broccoli/vit.py +25 -6
- {broccoli_ml-0.5.0.dist-info → broccoli_ml-0.5.1.dist-info}/METADATA +1 -1
- {broccoli_ml-0.5.0.dist-info → broccoli_ml-0.5.1.dist-info}/RECORD +6 -6
- {broccoli_ml-0.5.0.dist-info → broccoli_ml-0.5.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.5.0.dist-info → broccoli_ml-0.5.1.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -395,6 +395,7 @@ class TransformerEncoder(nn.Module):
|
|
395
395
|
causal=False,
|
396
396
|
linear_module=nn.Linear,
|
397
397
|
bos_tokens=0,
|
398
|
+
return_bos_tokens=False,
|
398
399
|
):
|
399
400
|
if position_embedding_type == "relative":
|
400
401
|
assert source_size is not None # TODO: make this a proper exception
|
@@ -403,6 +404,7 @@ class TransformerEncoder(nn.Module):
|
|
403
404
|
self.seq_len = seq_len
|
404
405
|
self.n_heads = n_heads
|
405
406
|
self._bos_tokens = bos_tokens
|
407
|
+
self.return_bos_tokens = return_bos_tokens
|
406
408
|
|
407
409
|
# Initialise BOS tokens with normal init, like usual Pytorch embeddings
|
408
410
|
if self._bos_tokens:
|
@@ -479,7 +481,7 @@ class TransformerEncoder(nn.Module):
|
|
479
481
|
for block in self.blocks:
|
480
482
|
x = block(x)
|
481
483
|
|
482
|
-
if self._bos_tokens:
|
484
|
+
if self._bos_tokens and not self.return_bos_tokens:
|
483
485
|
return x[:, self._bos_tokens :, :]
|
484
486
|
else:
|
485
487
|
return x
|
broccoli/vit.py
CHANGED
@@ -76,7 +76,7 @@ class ViTEncoder(nn.Module):
|
|
76
76
|
cnn_activation: nn.Module = ReLU,
|
77
77
|
cnn_activation_kwargs: Optional[dict] = None,
|
78
78
|
cnn_dropout=0.0,
|
79
|
-
pooling_type="concat", #
|
79
|
+
pooling_type="concat", # max, average or concat
|
80
80
|
pooling_kernel_size=3,
|
81
81
|
pooling_kernel_stride=2,
|
82
82
|
pooling_padding=1,
|
@@ -114,16 +114,19 @@ class ViTEncoder(nn.Module):
|
|
114
114
|
|
115
115
|
if self.spatial_dimensions == 1:
|
116
116
|
maxpoolxd = nn.MaxPool1d
|
117
|
+
avgpoolxd = nn.AvgPool1d
|
117
118
|
convxd = nn.Conv1d
|
118
119
|
batchnormxd = nn.BatchNorm1d
|
119
120
|
spatial_dim_names = "D1"
|
120
121
|
elif self.spatial_dimensions == 2:
|
121
122
|
maxpoolxd = nn.MaxPool2d
|
123
|
+
avgpoolxd = nn.AvgPool2d
|
122
124
|
convxd = nn.Conv2d
|
123
125
|
batchnormxd = nn.BatchNorm2d
|
124
126
|
spatial_dim_names = "D1 D2"
|
125
127
|
elif self.spatial_dimensions == 3:
|
126
128
|
maxpoolxd = nn.MaxPool3d
|
129
|
+
avgpoolxd = nn.AvgPool3d
|
127
130
|
convxd = nn.Conv3d
|
128
131
|
batchnormxd = nn.BatchNorm3d
|
129
132
|
spatial_dim_names = "D1 D2 D3"
|
@@ -158,7 +161,7 @@ class ViTEncoder(nn.Module):
|
|
158
161
|
spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
|
159
162
|
)
|
160
163
|
|
161
|
-
if pooling_type in ["
|
164
|
+
if pooling_type in ["max", "average", None]:
|
162
165
|
cnn_out_channels = transformer_embedding_size
|
163
166
|
elif pooling_type == "concat":
|
164
167
|
cnn_out_channels = max(
|
@@ -166,7 +169,9 @@ class ViTEncoder(nn.Module):
|
|
166
169
|
minimum_cnn_out_channels,
|
167
170
|
)
|
168
171
|
else:
|
169
|
-
raise NotImplementedError(
|
172
|
+
raise NotImplementedError(
|
173
|
+
"Pooling type must be max, average, concat or None"
|
174
|
+
)
|
170
175
|
|
171
176
|
cnn_activation_out_channels = cnn_out_channels
|
172
177
|
|
@@ -211,7 +216,7 @@ class ViTEncoder(nn.Module):
|
|
211
216
|
]
|
212
217
|
)
|
213
218
|
|
214
|
-
elif pooling_type == "
|
219
|
+
elif pooling_type == "max":
|
215
220
|
self.pool = nn.Sequential(
|
216
221
|
*[
|
217
222
|
maxpoolxd(
|
@@ -225,6 +230,20 @@ class ViTEncoder(nn.Module):
|
|
225
230
|
]
|
226
231
|
)
|
227
232
|
|
233
|
+
elif pooling_type == "average":
|
234
|
+
self.pool = nn.Sequential(
|
235
|
+
*[
|
236
|
+
avgpoolxd(
|
237
|
+
pooling_kernel_size,
|
238
|
+
stride=pooling_kernel_stride,
|
239
|
+
padding=pooling_padding,
|
240
|
+
),
|
241
|
+
Rearrange(
|
242
|
+
f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
|
243
|
+
), # for transformer
|
244
|
+
]
|
245
|
+
)
|
246
|
+
|
228
247
|
elif pooling_type == "concat":
|
229
248
|
|
230
249
|
if transformer_activation_kwargs is not None:
|
@@ -303,7 +322,7 @@ class ViTEncoder(nn.Module):
|
|
303
322
|
return self.encoder(x)
|
304
323
|
|
305
324
|
|
306
|
-
class
|
325
|
+
class CCT(nn.Module):
|
307
326
|
"""
|
308
327
|
Denoising convolutional transformer
|
309
328
|
Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
|
@@ -325,7 +344,7 @@ class ViT(nn.Module):
|
|
325
344
|
cnn_activation: nn.Module = ReLU,
|
326
345
|
cnn_activation_kwargs: Optional[dict] = None,
|
327
346
|
cnn_dropout=0.0,
|
328
|
-
pooling_type="concat", #
|
347
|
+
pooling_type="concat", # max, average or concat
|
329
348
|
pooling_kernel_size=3,
|
330
349
|
pooling_kernel_stride=2,
|
331
350
|
pooling_padding=1,
|
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
|
|
8
8
|
broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
|
9
9
|
broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
|
-
broccoli/transformer.py,sha256
|
11
|
+
broccoli/transformer.py,sha256=SwvutiYOiPlqLzbO_twye7Hna5DsJukVOzzAx9CTCyU,16417
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.5.
|
15
|
-
broccoli_ml-0.5.
|
16
|
-
broccoli_ml-0.5.
|
17
|
-
broccoli_ml-0.5.
|
13
|
+
broccoli/vit.py,sha256=CSllHj_cI0dQveuLSoHD-Y95nOqVO1F-p-8SmlJffhM,15272
|
14
|
+
broccoli_ml-0.5.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.5.1.dist-info/METADATA,sha256=IFo3lnkE6Zti81jUvLuBONLlR2PUUAwj1AKqPCgFUdI,1256
|
16
|
+
broccoli_ml-0.5.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.5.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|