broccoli-ml 0.1.33__tar.gz → 0.1.35__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/PKG-INFO +1 -1
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/cnn.py +2 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/transformer.py +1 -1
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/vit.py +59 -44
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/pyproject.toml +1 -1
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/LICENSE +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/README.md +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/activation.py +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/linear.py +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/rope.py +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/tensor.py +0 -0
- {broccoli_ml-0.1.33 → broccoli_ml-0.1.35}/broccoli/utils.py +0 -0
@@ -301,6 +301,8 @@ class ConcatPool(nn.Module):
|
|
301
301
|
them channel-wise.
|
302
302
|
"""
|
303
303
|
|
304
|
+
# TODO: change this to use nn.Fold instead of view, which is equivalent but more readable
|
305
|
+
|
304
306
|
def __init__(self, kernel_size, stride=1, padding=0, dilation=1):
|
305
307
|
super().__init__()
|
306
308
|
|
@@ -343,7 +343,7 @@ class TransformerBlock(nn.Module):
|
|
343
343
|
norm_process_x, norm_process_x, norm_process_x
|
344
344
|
)
|
345
345
|
process_x = process_x + self.ff_process(process_x)
|
346
|
-
x = torch.cat([
|
346
|
+
x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
|
347
347
|
|
348
348
|
return x
|
349
349
|
|
@@ -66,23 +66,33 @@ class CCTEncoder(nn.Module):
|
|
66
66
|
tranformer_d_model_scale=True,
|
67
67
|
tranformer_log_length_scale=True,
|
68
68
|
tranformer_quiet_attention=True,
|
69
|
-
|
70
|
-
|
69
|
+
cnn_activation: nn.Module = nn.ReLU,
|
70
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
71
|
+
transformer_activation: nn.Module = nn.GELU,
|
72
|
+
transformer_activation_kwargs: Optional[dict] = None,
|
71
73
|
mlp_dropout=0.0,
|
72
74
|
msa_dropout=0.1,
|
73
75
|
stochastic_depth=0.1,
|
74
76
|
linear_module=nn.Linear,
|
75
77
|
image_channels=3,
|
78
|
+
batch_norm=False,
|
76
79
|
):
|
77
80
|
if conv_pooling_type not in ["maxpool", "concat"]:
|
78
81
|
raise NotImplementedError("Pooling type must be maxpool or concat")
|
79
82
|
|
80
83
|
super().__init__()
|
81
84
|
|
82
|
-
if
|
83
|
-
self.
|
85
|
+
if cnn_activation_kwargs is not None:
|
86
|
+
self.cnn_activation = cnn_activation(**cnn_activation_kwargs)
|
84
87
|
else:
|
85
|
-
self.
|
88
|
+
self.cnn_activation = cnn_activation()
|
89
|
+
|
90
|
+
if transformer_activation_kwargs is not None:
|
91
|
+
self.transformer_activation = transformer_activation(
|
92
|
+
**transformer_activation_kwargs
|
93
|
+
)
|
94
|
+
else:
|
95
|
+
self.transformer_activation = transformer_activation()
|
86
96
|
|
87
97
|
self.image_size = image_size
|
88
98
|
|
@@ -96,43 +106,48 @@ class CCTEncoder(nn.Module):
|
|
96
106
|
self.sequence_length = output_size**2
|
97
107
|
|
98
108
|
if conv_pooling_type == "maxpool":
|
99
|
-
|
100
109
|
conv_out_channels = transformer_embedding_size
|
101
|
-
|
102
|
-
|
110
|
+
elif conv_pooling_type == "concat":
|
111
|
+
conv_out_channels = int(
|
112
|
+
round(transformer_embedding_size / (conv_pooling_kernel_size**2))
|
113
|
+
)
|
114
|
+
|
115
|
+
# This if block rhymes:
|
116
|
+
if cnn_activation.__name__.endswith("GLU"):
|
117
|
+
conv_out_channels *= 2
|
103
118
|
|
119
|
+
self.conv = ConvLayer(
|
120
|
+
image_channels,
|
121
|
+
conv_out_channels,
|
122
|
+
kernel_size=conv_kernel_size,
|
123
|
+
stride=1,
|
124
|
+
padding="same",
|
125
|
+
linear_module=linear_module,
|
126
|
+
)
|
127
|
+
|
128
|
+
if conv_pooling_type == "maxpool":
|
104
129
|
self.pool = nn.Sequential(
|
105
130
|
*[
|
106
131
|
Rearrange( # rearrange in case we're using XGLU activation
|
107
132
|
"N C H W -> N H W C"
|
108
133
|
),
|
109
|
-
self.
|
134
|
+
self.cnn_activation,
|
110
135
|
Rearrange("N H W C -> N C H W"),
|
111
136
|
nn.MaxPool2d(
|
112
137
|
conv_pooling_kernel_size,
|
113
138
|
stride=conv_pooling_kernel_stride,
|
114
139
|
padding=conv_pooling_kernel_padding,
|
115
140
|
),
|
116
|
-
Rearrange("N C H W -> N (H W) C"),
|
141
|
+
Rearrange("N C H W -> N (H W) C"), # for transformer
|
117
142
|
]
|
118
143
|
)
|
119
144
|
|
120
145
|
elif conv_pooling_type == "concat":
|
121
|
-
|
122
|
-
|
123
|
-
round(transformer_embedding_size / (conv_pooling_kernel_size**2))
|
124
|
-
)
|
125
|
-
pooling_out_channels = conv_pooling_kernel_size**2 * conv_out_channels
|
126
|
-
pooling_adapter_out_channels = transformer_embedding_size
|
127
|
-
if activation.__name__.endswith("GLU"):
|
128
|
-
pooling_adapter_out_channels *= 2
|
129
|
-
self.pooling_adapter = nn.Sequential(
|
130
|
-
*[
|
131
|
-
Rearrange("N C H W -> N (H W) C"),
|
132
|
-
nn.Linear(pooling_out_channels, pooling_adapter_out_channels),
|
133
|
-
self.activation,
|
134
|
-
]
|
146
|
+
concatpool_activation_output_size = (
|
147
|
+
conv_pooling_kernel_size**2 * conv_out_channels
|
135
148
|
)
|
149
|
+
if cnn_activation.__name__.endswith("GLU"):
|
150
|
+
concatpool_activation_output_size /= 2
|
136
151
|
|
137
152
|
self.pool = nn.Sequential(
|
138
153
|
*[
|
@@ -141,7 +156,15 @@ class CCTEncoder(nn.Module):
|
|
141
156
|
stride=conv_pooling_kernel_stride,
|
142
157
|
padding=conv_pooling_kernel_padding,
|
143
158
|
),
|
144
|
-
|
159
|
+
Rearrange( # rearrange in case we're using XGLU activation
|
160
|
+
"N C H W -> N H W C"
|
161
|
+
),
|
162
|
+
self.cnn_activation,
|
163
|
+
Rearrange("N H W C -> N (H W) C"),
|
164
|
+
nn.Linear(
|
165
|
+
concatpool_activation_output_size, transformer_embedding_size
|
166
|
+
),
|
167
|
+
self.cnn_activation,
|
145
168
|
]
|
146
169
|
)
|
147
170
|
|
@@ -154,8 +177,8 @@ class CCTEncoder(nn.Module):
|
|
154
177
|
position_embedding_type=transformer_position_embedding,
|
155
178
|
source_size=(output_size, output_size),
|
156
179
|
mlp_ratio=transformer_mlp_ratio,
|
157
|
-
activation=
|
158
|
-
activation_kwargs=
|
180
|
+
activation=transformer_activation,
|
181
|
+
activation_kwargs=transformer_activation_kwargs,
|
159
182
|
mlp_dropout=mlp_dropout,
|
160
183
|
msa_dropout=msa_dropout,
|
161
184
|
stochastic_depth=stochastic_depth,
|
@@ -171,21 +194,9 @@ class CCTEncoder(nn.Module):
|
|
171
194
|
else:
|
172
195
|
self.transformer = nn.Identity()
|
173
196
|
|
174
|
-
# This code block rhymes:
|
175
|
-
if activation.__name__.endswith("GLU"):
|
176
|
-
conv_out_channels *= 2
|
177
|
-
|
178
|
-
self.conv = ConvLayer(
|
179
|
-
image_channels,
|
180
|
-
conv_out_channels,
|
181
|
-
kernel_size=conv_kernel_size,
|
182
|
-
stride=1,
|
183
|
-
padding="same",
|
184
|
-
linear_module=linear_module,
|
185
|
-
)
|
186
|
-
|
187
197
|
self.encoder = nn.Sequential(
|
188
198
|
*[
|
199
|
+
nn.BatchNorm2d(image_channels) if batch_norm else nn.Identity(),
|
189
200
|
self.conv,
|
190
201
|
self.pool,
|
191
202
|
self.transformer,
|
@@ -223,8 +234,10 @@ class CCT(nn.Module):
|
|
223
234
|
tranformer_d_model_scale=True,
|
224
235
|
tranformer_log_length_scale=True,
|
225
236
|
tranformer_quiet_attention=True,
|
226
|
-
|
227
|
-
|
237
|
+
cnn_activation: nn.Module = nn.ReLU,
|
238
|
+
cnn_activation_kwargs: Optional[dict] = None,
|
239
|
+
transformer_activation: nn.Module = nn.GELU,
|
240
|
+
transformer_activation_kwargs: Optional[dict] = None,
|
228
241
|
mlp_dropout=0.0, # The original paper got best performance from mlp_dropout=0.
|
229
242
|
msa_dropout=0.1, # "" msa_dropout=0.1
|
230
243
|
stochastic_depth=0.1, # "" stochastic_depth=0.1
|
@@ -253,8 +266,10 @@ class CCT(nn.Module):
|
|
253
266
|
tranformer_d_model_scale=tranformer_d_model_scale,
|
254
267
|
tranformer_log_length_scale=tranformer_log_length_scale,
|
255
268
|
tranformer_quiet_attention=tranformer_quiet_attention,
|
256
|
-
|
257
|
-
|
269
|
+
cnn_activation=cnn_activation,
|
270
|
+
cnn_activation_kwargs=cnn_activation_kwargs,
|
271
|
+
transformer_activation=transformer_activation,
|
272
|
+
transformer_activation_kwargs=transformer_activation_kwargs,
|
258
273
|
mlp_dropout=mlp_dropout,
|
259
274
|
msa_dropout=msa_dropout,
|
260
275
|
stochastic_depth=stochastic_depth,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|