broccoli-ml 0.1.36__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/vit.py +36 -7
- {broccoli_ml-0.1.36.dist-info → broccoli_ml-0.1.38.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.36.dist-info → broccoli_ml-0.1.38.dist-info}/RECORD +5 -5
- {broccoli_ml-0.1.36.dist-info → broccoli_ml-0.1.38.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.36.dist-info → broccoli_ml-0.1.38.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -3,9 +3,21 @@ from typing import Optional
|
|
3
3
|
|
4
4
|
from .transformer import TransformerEncoder
|
5
5
|
from .cnn import ConvLayer, ConcatPool
|
6
|
+
from .activation import ReLU, SquaredReLU, GELU, SwiGLU
|
6
7
|
from einops import einsum
|
7
8
|
from einops.layers.torch import Rearrange
|
8
9
|
import torch.nn as nn
|
10
|
+
import torch.nn.functional as F
|
11
|
+
|
12
|
+
|
13
|
+
class PadTensor(nn.Module):
|
14
|
+
def __init__(self, *args, **kwargs):
|
15
|
+
super().__init__()
|
16
|
+
self.args = args
|
17
|
+
self.kwargs = kwargs
|
18
|
+
|
19
|
+
def forward(self, x):
|
20
|
+
return F.pad(x, *self.args, **self.kwargs)
|
9
21
|
|
10
22
|
|
11
23
|
class SequencePool(nn.Module):
|
@@ -109,7 +121,7 @@ class CCTEncoder(nn.Module):
|
|
109
121
|
conv_out_channels = transformer_embedding_size
|
110
122
|
elif conv_pooling_type == "concat":
|
111
123
|
conv_out_channels = int(
|
112
|
-
|
124
|
+
math.floor(transformer_embedding_size / (conv_pooling_kernel_size**2))
|
113
125
|
)
|
114
126
|
|
115
127
|
# This if block rhymes:
|
@@ -143,11 +155,15 @@ class CCTEncoder(nn.Module):
|
|
143
155
|
)
|
144
156
|
|
145
157
|
elif conv_pooling_type == "concat":
|
146
|
-
|
158
|
+
concatpool_activation_output_channels = (
|
147
159
|
conv_pooling_kernel_size**2 * conv_out_channels
|
148
160
|
)
|
149
161
|
if cnn_activation.__name__.endswith("GLU"):
|
150
|
-
|
162
|
+
concatpool_activation_output_channels /= 2
|
163
|
+
|
164
|
+
concatpool_padding = (
|
165
|
+
transformer_embedding_size - concatpool_activation_output_channels
|
166
|
+
)
|
151
167
|
|
152
168
|
self.pool = nn.Sequential(
|
153
169
|
*[
|
@@ -161,10 +177,7 @@ class CCTEncoder(nn.Module):
|
|
161
177
|
),
|
162
178
|
self.cnn_activation,
|
163
179
|
Rearrange("N H W C -> N (H W) C"),
|
164
|
-
|
165
|
-
concatpool_activation_output_size, transformer_embedding_size
|
166
|
-
),
|
167
|
-
self.cnn_activation,
|
180
|
+
PadTensor((0, concatpool_padding)),
|
168
181
|
]
|
169
182
|
)
|
170
183
|
|
@@ -249,6 +262,22 @@ class CCT(nn.Module):
|
|
249
262
|
|
250
263
|
super().__init__()
|
251
264
|
|
265
|
+
if isinstance(cnn_activation, str):
|
266
|
+
cnn_activation = {
|
267
|
+
"ReLU": ReLU,
|
268
|
+
"SquaredReLU": SquaredReLU,
|
269
|
+
"GELU": GELU,
|
270
|
+
"SwiGLU": SwiGLU,
|
271
|
+
}[cnn_activation]
|
272
|
+
|
273
|
+
if isinstance(transformer_activation, str):
|
274
|
+
transformer_activation = {
|
275
|
+
"ReLU": ReLU,
|
276
|
+
"SquaredReLU": SquaredReLU,
|
277
|
+
"GELU": GELU,
|
278
|
+
"SwiGLU": SwiGLU,
|
279
|
+
}[transformer_activation]
|
280
|
+
|
252
281
|
self.encoder = CCTEncoder(
|
253
282
|
image_size=image_size,
|
254
283
|
conv_kernel_size=conv_kernel_size,
|
@@ -10,8 +10,8 @@ broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
11
|
broccoli/transformer.py,sha256=gFBIEowGFPSgQhM1RwsRtQlw_WzVJPY-LJyf1MLtPek,16277
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.1.
|
15
|
-
broccoli_ml-0.1.
|
16
|
-
broccoli_ml-0.1.
|
17
|
-
broccoli_ml-0.1.
|
13
|
+
broccoli/vit.py,sha256=CSv13ILKw12o1fNznFvgbfw1TR-gDW30h74yjW6HmLc,11692
|
14
|
+
broccoli_ml-0.1.38.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.1.38.dist-info/METADATA,sha256=JJpmHolP3y4Yz8oZ5eRSZswLlCBsqxu6Y5VfimaD5c8,1257
|
16
|
+
broccoli_ml-0.1.38.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.1.38.dist-info/RECORD,,
|
File without changes
|
File without changes
|