broccoli-ml 0.1.38__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/vit.py +27 -9
- {broccoli_ml-0.1.38.dist-info → broccoli_ml-0.1.39.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.38.dist-info → broccoli_ml-0.1.39.dist-info}/RECORD +5 -5
- {broccoli_ml-0.1.38.dist-info → broccoli_ml-0.1.39.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.38.dist-info → broccoli_ml-0.1.39.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -67,6 +67,7 @@ class CCTEncoder(nn.Module):
|
|
67
67
|
conv_pooling_kernel_size=3,
|
68
68
|
conv_pooling_kernel_stride=2,
|
69
69
|
conv_pooling_kernel_padding=1,
|
70
|
+
conv_dropout=0.0,
|
70
71
|
transformer_position_embedding="absolute", # absolute or relative
|
71
72
|
transformer_embedding_size=256,
|
72
73
|
transformer_layers=7,
|
@@ -155,15 +156,16 @@ class CCTEncoder(nn.Module):
|
|
155
156
|
)
|
156
157
|
|
157
158
|
elif conv_pooling_type == "concat":
|
158
|
-
|
159
|
-
|
159
|
+
self.concatpool_activation = transformer_activation(
|
160
|
+
**transformer_activation_kwargs
|
160
161
|
)
|
161
|
-
if cnn_activation.__name__.endswith("GLU"):
|
162
|
-
concatpool_activation_output_channels /= 2
|
163
162
|
|
164
|
-
|
165
|
-
|
166
|
-
)
|
163
|
+
concatpool_out_channels = conv_pooling_kernel_size**2 * conv_out_channels
|
164
|
+
|
165
|
+
if cnn_activation.__name__.endswith("GLU"):
|
166
|
+
cnn_activation_output_channels = concatpool_out_channels / 2
|
167
|
+
else:
|
168
|
+
cnn_activation_output_channels = concatpool_out_channels
|
167
169
|
|
168
170
|
self.pool = nn.Sequential(
|
169
171
|
*[
|
@@ -176,8 +178,24 @@ class CCTEncoder(nn.Module):
|
|
176
178
|
"N C H W -> N H W C"
|
177
179
|
),
|
178
180
|
self.cnn_activation,
|
179
|
-
|
180
|
-
|
181
|
+
nn.Dropout(conv_dropout),
|
182
|
+
Rearrange( # rearrange in case we're using XGLU activation
|
183
|
+
"N H W C -> N C H W"
|
184
|
+
),
|
185
|
+
nn.BatchNorm2d(cnn_activation_output_channels),
|
186
|
+
Rearrange( # rearrange in case we're using XGLU activation
|
187
|
+
"N C H W -> N (H W) C"
|
188
|
+
),
|
189
|
+
nn.Linear(
|
190
|
+
cnn_activation_output_channels,
|
191
|
+
(
|
192
|
+
2 * transformer_embedding_size * transformer_mlp_ratio
|
193
|
+
if transformer_activation.__name__.endswith("GLU")
|
194
|
+
else transformer_embedding_size * transformer_mlp_ratio
|
195
|
+
),
|
196
|
+
),
|
197
|
+
self.concatpool_activation,
|
198
|
+
nn.Linear(transformer_embedding_size * transformer_mlp_ratio),
|
181
199
|
]
|
182
200
|
)
|
183
201
|
|
@@ -10,8 +10,8 @@ broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
11
|
broccoli/transformer.py,sha256=gFBIEowGFPSgQhM1RwsRtQlw_WzVJPY-LJyf1MLtPek,16277
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.1.38.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
-
broccoli_ml-0.1.
|
16
|
-
broccoli_ml-0.1.38.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
-
broccoli_ml-0.1.38.dist-info/RECORD,,
|
13
|
+
broccoli/vit.py,sha256=hN9m24HkgxFMQPEFmlv865ejHs7JujMRQfzoplJKu78,12618
|
14
|
+
broccoli_ml-0.1.39.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.1.39.dist-info/METADATA,sha256=MAYq4HTN1PVZIbWYaqnoU7EnY6-vVFlbcAdASzuoetE,1257
|
16
|
+
broccoli_ml-0.1.39.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.1.39.dist-info/RECORD,,
|
File without changes
|
File without changes
|