broccoli-ml 0.22.0__py3-none-any.whl → 0.23.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/transformer.py +4 -8
- broccoli/vit.py +24 -32
- {broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/RECORD +6 -6
- {broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/WHEEL +0 -0
broccoli/transformer.py
CHANGED
@@ -236,7 +236,7 @@ class FeedforwardBlock(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-
+        sigma_reparam=False,
     ):
         super().__init__()
 
@@ -253,21 +253,17 @@ class FeedforwardBlock(nn.Module):
             else ratio * output_features
         )
 
-        if
+        if sigma_reparam:
            self.memory_type = SpectralNormLinear
         else:
-            self.memory_type =
+            self.memory_type = linear_module
 
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features),
                 linear_module(input_features, self.max_features),
                 self.activation,
-
-                # nn.LayerNorm(input_features)
-                # if not regularise_values
-                # else nn.Identity()
-                # ),
+                nn.LayerNorm(ratio * output_features),
                 self.memory_type(ratio * output_features, output_features),
                 self.dropout,
             ]
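The transformer.py hunks thread a new `sigma_reparam` flag through `FeedforwardBlock`: when it is set, the second ("memory") projection is built from `SpectralNormLinear` rather than the configurable `linear_module`, and a live `nn.LayerNorm` on the hidden features replaces the previously commented-out block. The naming suggests a spectral-norm reparameterisation (σReparam-style) of that projection. Below is a minimal sketch of the new wiring, not the package's actual code: `SpectralNormLinear` is stood in by PyTorch's built-in `spectral_norm` parametrization, and the activation is fixed to GELU where the real block takes `activation` and `activation_kwargs`.

import torch
import torch.nn as nn

def SpectralNormLinear(in_features, out_features):
    # Stand-in (assumption): broccoli.linear.SpectralNormLinear may differ.
    return nn.utils.parametrizations.spectral_norm(
        nn.Linear(in_features, out_features)
    )

class FeedforwardBlock(nn.Module):
    # Simplified sketch of the 0.23.0 layout.
    def __init__(self, input_features, ratio, output_features,
                 dropout=0.0, linear_module=nn.Linear, sigma_reparam=False):
        super().__init__()
        hidden = ratio * output_features
        # New flag: choose the type of the output ("memory") projection.
        memory_type = SpectralNormLinear if sigma_reparam else linear_module
        self.process = nn.Sequential(
            nn.LayerNorm(input_features),
            linear_module(input_features, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),  # newly enabled norm on the hidden features
            memory_type(hidden, output_features),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.process(x)

x = torch.randn(8, 256)
print(FeedforwardBlock(256, 4, 256, sigma_reparam=True)(x).shape)  # torch.Size([8, 256])

Spectral normalisation caps the largest singular value of the projection's weight at 1, which is the usual motivation for a flag like this in transformer feedforward blocks.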
broccoli/vit.py
CHANGED
@@ -18,7 +18,10 @@ class PadTensor(nn.Module):
         self.kwargs = kwargs
 
     def forward(self, x):
-
+        if sum(self.args[0]) == 0:
+            return x
+        else:
+            return F.pad(x, *self.args, **self.kwargs)
 
 
 class GetCLSToken(nn.Module):
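The hunk above changes `PadTensor.forward` so that zero-width padding is a true no-op; later hunks rely on this by always installing a `pooling_channels_padding` of width `max(0, transformer_embedding_size - pooling_out_channels)`, which simply returns its input when the channel counts already match. A runnable reading, assuming `self.args[0]` is the usual `F.pad` width tuple:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PadTensor(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, x):
        if sum(self.args[0]) == 0:
            # All pad widths are zero: skip F.pad and return the input as-is.
            return x
        else:
            return F.pad(x, *self.args, **self.kwargs)

x = torch.randn(2, 4)
assert PadTensor((0, 0))(x) is x              # identity when widths are zero
assert PadTensor((0, 3))(x).shape == (2, 7)   # pads the last dimension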
@@ -97,9 +100,9 @@ class ViTEncoder(nn.Module):
     def __init__(
         self,
         input_size=(32, 32),
+        in_channels=3,
         initial_batch_norm=True,
         cnn=True,
-        cnn_in_channels=3,
         cnn_out_channels=16,
         cnn_kernel_size=3,
         cnn_kernel_stride=1,
@@ -113,7 +116,7 @@ class ViTEncoder(nn.Module):
         pooling_kernel_size=3,
         pooling_kernel_stride=2,
         pooling_padding=1,
-
+        transformer_feedforward_first=True,
         transformer_position_embedding="relative", # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -180,7 +183,7 @@ class ViTEncoder(nn.Module):
             dilation=cnn_kernel_dilation,
         )
         self.cnn = convxd(
-
+            in_channels,
             cnn_out_channels,
             cnn_kernel_size,
             stride=cnn_kernel_stride,
@@ -206,8 +209,8 @@ class ViTEncoder(nn.Module):
             self.cnn = nn.Identity()
             self.activate_and_dropout = nn.Identity()
             cnn_output_size = input_size
-            cnn_out_channels =
-            cnn_activation_out_channels =
+            cnn_out_channels = in_channels
+            cnn_activation_out_channels = in_channels
 
         pooling_kernel_voxels = math.prod(
             spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
@@ -262,6 +265,10 @@ class ViTEncoder(nn.Module):
                 "Pooling type must be max, average, concat or None"
             )
 
+        self.pooling_channels_padding = PadTensor(
+            (0, max(0, transformer_embedding_size - pooling_out_channels))
+        )
+
         self.sequence_length = math.prod(pooling_output_size) # One token per voxel
 
         if transformer_layers > 0:
@@ -286,38 +293,23 @@ class ViTEncoder(nn.Module):
         else:
             self.transformer = nn.Identity()
 
-        if
-            self.
-
-                pooling_out_channels,
+        if transformer_feedforward_first:
+            self.initial_ff = FeedforwardBlock(
+                transformer_embedding_size,
                 transformer_mlp_ratio,
                 transformer_embedding_size,
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
                 linear_module=linear_module,
-
-        elif pooling_out_channels == transformer_embedding_size:
-            self.intermediate_feedforward_layer = nn.Identity()
-            self.pooling_channels_padding = nn.Identity()
-        elif pooling_out_channels < transformer_embedding_size:
-            self.intermediate_feedforward_layer = nn.Identity()
-            self.pooling_channels_padding = PadTensor(
-                (0, transformer_embedding_size - pooling_out_channels)
+                sigma_reparam=True,
             )
         else:
-
-                "In a situation where the choice/parameters of the pooling and the"
-                + " `cnn_out_channels` (or the number of `input_channels` if"
-                + " `cnn`=False) means that the pooling will result"
-                + " in more channels per pixel/voxel than the size of the"
-                + " intended transformer embedding,"
-                + " `intermediate_feedforward_layer` must be set to True"
-            )
+            self.initial_ff = nn.Identity()
 
         self.encoder = nn.Sequential(
             *[
-                batchnormxd(
+                batchnormxd(in_channels) if initial_batch_norm else nn.Identity(),
                 self.cnn,
                 self.activate_and_dropout,
                 self.pool,
@@ -325,7 +317,7 @@ class ViTEncoder(nn.Module):
                     f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                 ),
                 self.pooling_channels_padding,
-                self.
+                self.initial_ff,
                 self.transformer,
             ]
         )
@@ -347,9 +339,9 @@ class ViT(nn.Module):
         self,
         input_size=(32, 32),
         image_classes=100,
+        in_channels=3,
         initial_batch_norm=True,
         cnn=True,
-        cnn_in_channels=3,
         cnn_out_channels=16,
         cnn_kernel_size=3,
         cnn_kernel_stride=1,
@@ -363,7 +355,7 @@ class ViT(nn.Module):
         pooling_kernel_size=3,
         pooling_kernel_stride=2,
         pooling_padding=1,
-
+        transformer_feedforward_first=True,
         transformer_position_embedding="relative", # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -402,8 +394,8 @@ class ViT(nn.Module):
         self.encoder = ViTEncoder(
             input_size=input_size,
             initial_batch_norm=initial_batch_norm,
+            in_channels=in_channels,
             cnn=cnn,
-            cnn_in_channels=cnn_in_channels,
             cnn_out_channels=cnn_out_channels,
             cnn_kernel_size=cnn_kernel_size,
             cnn_kernel_stride=cnn_kernel_stride,
@@ -417,7 +409,7 @@ class ViT(nn.Module):
             pooling_kernel_size=pooling_kernel_size,
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
-
+            transformer_feedforward_first=transformer_feedforward_first,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
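Taken together, the vit.py hunks make three coordinated API changes: `cnn_in_channels` becomes `in_channels` (now also used for the initial batch norm and as the fall-through channel count when `cnn=False`), the old conditional `intermediate_feedforward_layer` logic and its error branch give way to the unconditional `pooling_channels_padding` plus an optional initial `FeedforwardBlock` (built with `sigma_reparam=True`) gated by the new `transformer_feedforward_first` flag, and `ViT` forwards the new arguments to `ViTEncoder`. A hedged usage sketch against the 0.23.0 constructor surface; the import path and this combination of defaults are assumptions, not tested against the published wheel:

from broccoli.vit import ViT

model = ViT(
    input_size=(32, 32),
    image_classes=100,
    in_channels=3,                        # 0.22.0: cnn_in_channels=3
    transformer_feedforward_first=True,   # replaces intermediate_feedforward_layer
    transformer_embedding_size=256,
    transformer_layers=7,
)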
{broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/RECORD
CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=g8YrxNl6g_WcHrWVmbaBHJU5hv6daFS0r4TxAoPJ9UE,3012
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=MUvXtwD2f1sPTBym4FB0x_ZfsJUBNLgULUlN8btV8GI,1943
-broccoli/transformer.py,sha256=
+broccoli/transformer.py,sha256=NxOHP-XQRCtoiiTh7WJWNvSjpZzamiqQU966nQh5vhQ,16091
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
-broccoli_ml-0.
+broccoli/vit.py,sha256=LRofeDyz0wMwClVBeuxnaPmuDpuu8UAVddAOxIOYK8Y,15625
+broccoli_ml-0.23.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.23.0.dist-info/METADATA,sha256=GfhIQiwV5g480TcVdGdzL3AsgJbQmV89Rxc_oJOrwJo,1257
+broccoli_ml-0.23.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.23.0.dist-info/RECORD,,
{broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/LICENSE
File without changes
{broccoli_ml-0.22.0.dist-info → broccoli_ml-0.23.0.dist-info}/WHEEL
File without changes