broccoli-ml 0.1.35__tar.gz → 0.1.37__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/PKG-INFO +1 -1
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/vit.py +19 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/pyproject.toml +1 -1
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/LICENSE +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/README.md +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/__init__.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/activation.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/cnn.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/eigenpatches.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/linear.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/rope.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/tensor.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/transformer.py +0 -0
- {broccoli_ml-0.1.35 → broccoli_ml-0.1.37}/broccoli/utils.py +0 -0
@@ -3,6 +3,7 @@ from typing import Optional
 
 from .transformer import TransformerEncoder
 from .cnn import ConvLayer, ConcatPool
+from .activation import ReLU, SquaredReLU, GELU, SwiGLU
 from einops import einsum
 from einops.layers.torch import Rearrange
 import torch.nn as nn
@@ -244,10 +245,27 @@ class CCT(nn.Module):
         image_classes=100,
         linear_module=nn.Linear,
         image_channels=3,
+        batch_norm=False,
     ):
 
         super().__init__()
 
+        if isinstance(cnn_activation, str):
+            cnn_activation = {
+                "ReLU": ReLU,
+                "SquaredReLU": SquaredReLU,
+                "GELU": GELU,
+                "SwiGLU": SwiGLU,
+            }[cnn_activation]
+
+        if isinstance(transformer_activation, str):
+            transformer_activation = {
+                "ReLU": ReLU,
+                "SquaredReLU": SquaredReLU,
+                "GELU": GELU,
+                "SwiGLU": SwiGLU,
+            }[transformer_activation]
+
         self.encoder = CCTEncoder(
             image_size=image_size,
             conv_kernel_size=conv_kernel_size,
@@ -275,6 +293,7 @@ class CCT(nn.Module):
             stochastic_depth=stochastic_depth,
             linear_module=linear_module,
             image_channels=image_channels,
+            batch_norm=batch_norm,
         )
         self.pool = SequencePool(
             transformer_embedding_size, linear_module, image_classes
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|