broccoli-ml 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/vit.py +6 -2
- {broccoli_ml-0.2.0.dist-info → broccoli_ml-0.3.0.dist-info}/METADATA +1 -1
- {broccoli_ml-0.2.0.dist-info → broccoli_ml-0.3.0.dist-info}/RECORD +5 -5
- {broccoli_ml-0.2.0.dist-info → broccoli_ml-0.3.0.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.2.0.dist-info → broccoli_ml-0.3.0.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -63,6 +63,7 @@ class CCTEncoder(nn.Module):
|
|
63
63
|
self,
|
64
64
|
input_size=(32, 32),
|
65
65
|
cnn_in_channels=3,
|
66
|
+
minimum_cnn_out_channels=16,
|
66
67
|
cnn_kernel_size=3,
|
67
68
|
cnn_kernel_stride=1,
|
68
69
|
cnn_kernel_padding="same",
|
@@ -155,8 +156,9 @@ class CCTEncoder(nn.Module):
|
|
155
156
|
if pooling_type in ["maxpool", None]:
|
156
157
|
cnn_out_channels = transformer_embedding_size
|
157
158
|
elif pooling_type == "concat":
|
158
|
-
cnn_out_channels =
|
159
|
-
transformer_embedding_size / pooling_kernel_voxels
|
159
|
+
cnn_out_channels = min(
|
160
|
+
math.floor(transformer_embedding_size / pooling_kernel_voxels),
|
161
|
+
minimum_cnn_out_channels,
|
160
162
|
)
|
161
163
|
else:
|
162
164
|
raise NotImplementedError("Pooling type must be maxpool, concat or None")
|
@@ -301,6 +303,7 @@ class CCT(nn.Module):
|
|
301
303
|
self,
|
302
304
|
input_size=(32, 32),
|
303
305
|
cnn_in_channels=3,
|
306
|
+
minimum_cnn_out_channels=16,
|
304
307
|
cnn_kernel_size=3,
|
305
308
|
cnn_kernel_stride=1,
|
306
309
|
cnn_kernel_padding="same",
|
@@ -350,6 +353,7 @@ class CCT(nn.Module):
|
|
350
353
|
self.encoder = CCTEncoder(
|
351
354
|
input_size=input_size,
|
352
355
|
cnn_in_channels=cnn_in_channels,
|
356
|
+
minimum_cnn_out_channels=minimum_cnn_out_channels,
|
353
357
|
cnn_kernel_size=cnn_kernel_size,
|
354
358
|
cnn_kernel_stride=cnn_kernel_stride,
|
355
359
|
cnn_kernel_padding=cnn_kernel_padding,
|
@@ -10,8 +10,8 @@ broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
11
|
broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.
|
15
|
-
broccoli_ml-0.
|
16
|
-
broccoli_ml-0.
|
17
|
-
broccoli_ml-0.
|
13
|
+
broccoli/vit.py,sha256=NuHW2xcaUEv_IHAZbrrGHUWKu9D7JMR1iKDCCX07RQs,13787
|
14
|
+
broccoli_ml-0.3.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.3.0.dist-info/METADATA,sha256=sAbHQ0Q2yM5kaovkF22cTKCk4SU_z6vi6QtmOMMwJlQ,1256
|
16
|
+
broccoli_ml-0.3.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.3.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|