broccoli-ml 0.1.37__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- broccoli/vit.py +19 -7
- {broccoli_ml-0.1.37.dist-info → broccoli_ml-0.1.38.dist-info}/METADATA +1 -1
- {broccoli_ml-0.1.37.dist-info → broccoli_ml-0.1.38.dist-info}/RECORD +5 -5
- {broccoli_ml-0.1.37.dist-info → broccoli_ml-0.1.38.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.1.37.dist-info → broccoli_ml-0.1.38.dist-info}/WHEEL +0 -0
broccoli/vit.py
CHANGED
@@ -7,6 +7,17 @@ from .activation import ReLU, SquaredReLU, GELU, SwiGLU
|
|
7
7
|
from einops import einsum
|
8
8
|
from einops.layers.torch import Rearrange
|
9
9
|
import torch.nn as nn
|
10
|
+
import torch.nn.functional as F
|
11
|
+
|
12
|
+
|
13
|
+
class PadTensor(nn.Module):
|
14
|
+
def __init__(self, *args, **kwargs):
|
15
|
+
super().__init__()
|
16
|
+
self.args = args
|
17
|
+
self.kwargs = kwargs
|
18
|
+
|
19
|
+
def forward(self, x):
|
20
|
+
return F.pad(x, *self.args, **self.kwargs)
|
10
21
|
|
11
22
|
|
12
23
|
class SequencePool(nn.Module):
|
@@ -110,7 +121,7 @@ class CCTEncoder(nn.Module):
|
|
110
121
|
conv_out_channels = transformer_embedding_size
|
111
122
|
elif conv_pooling_type == "concat":
|
112
123
|
conv_out_channels = int(
|
113
|
-
|
124
|
+
math.floor(transformer_embedding_size / (conv_pooling_kernel_size**2))
|
114
125
|
)
|
115
126
|
|
116
127
|
# This if block rhymes:
|
@@ -144,11 +155,15 @@ class CCTEncoder(nn.Module):
|
|
144
155
|
)
|
145
156
|
|
146
157
|
elif conv_pooling_type == "concat":
|
147
|
-
concatpool_activation_output_size = (
|
158
|
+
concatpool_activation_output_channels = (
|
148
159
|
conv_pooling_kernel_size**2 * conv_out_channels
|
149
160
|
)
|
150
161
|
if cnn_activation.__name__.endswith("GLU"):
|
151
|
-
|
162
|
+
concatpool_activation_output_channels /= 2
|
163
|
+
|
164
|
+
concatpool_padding = (
|
165
|
+
transformer_embedding_size - concatpool_activation_output_channels
|
166
|
+
)
|
152
167
|
|
153
168
|
self.pool = nn.Sequential(
|
154
169
|
*[
|
@@ -162,10 +177,7 @@ class CCTEncoder(nn.Module):
|
|
162
177
|
),
|
163
178
|
self.cnn_activation,
|
164
179
|
Rearrange("N H W C -> N (H W) C"),
|
165
|
-
|
166
|
-
concatpool_activation_output_size, transformer_embedding_size
|
167
|
-
),
|
168
|
-
self.cnn_activation,
|
180
|
+
PadTensor((0, concatpool_padding)),
|
169
181
|
]
|
170
182
|
)
|
171
183
|
|
@@ -10,8 +10,8 @@ broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
|
|
10
10
|
broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
|
11
11
|
broccoli/transformer.py,sha256=gFBIEowGFPSgQhM1RwsRtQlw_WzVJPY-LJyf1MLtPek,16277
|
12
12
|
broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
|
13
|
-
broccoli/vit.py,sha256=
|
14
|
-
broccoli_ml-0.1.37.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
-
broccoli_ml-0.1.
|
16
|
-
broccoli_ml-0.1.37.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
-
broccoli_ml-0.1.37.dist-info/RECORD,,
|
13
|
+
broccoli/vit.py,sha256=CSv13ILKw12o1fNznFvgbfw1TR-gDW30h74yjW6HmLc,11692
|
14
|
+
broccoli_ml-0.1.38.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
|
15
|
+
broccoli_ml-0.1.38.dist-info/METADATA,sha256=JJpmHolP3y4Yz8oZ5eRSZswLlCBsqxu6Y5VfimaD5c8,1257
|
16
|
+
broccoli_ml-0.1.38.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
17
|
+
broccoli_ml-0.1.38.dist-info/RECORD,,
|
File without changes
|
File without changes
|