broccoli-ml 0.1.36__tar.gz → 0.1.38__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.1.36
3
+ Version: 0.1.38
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -3,9 +3,21 @@ from typing import Optional
3
3
 
4
4
  from .transformer import TransformerEncoder
5
5
  from .cnn import ConvLayer, ConcatPool
6
+ from .activation import ReLU, SquaredReLU, GELU, SwiGLU
6
7
  from einops import einsum
7
8
  from einops.layers.torch import Rearrange
8
9
  import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class PadTensor(nn.Module):
14
+ def __init__(self, *args, **kwargs):
15
+ super().__init__()
16
+ self.args = args
17
+ self.kwargs = kwargs
18
+
19
+ def forward(self, x):
20
+ return F.pad(x, *self.args, **self.kwargs)
9
21
 
10
22
 
11
23
  class SequencePool(nn.Module):
@@ -109,7 +121,7 @@ class CCTEncoder(nn.Module):
109
121
  conv_out_channels = transformer_embedding_size
110
122
  elif conv_pooling_type == "concat":
111
123
  conv_out_channels = int(
112
- round(transformer_embedding_size / (conv_pooling_kernel_size**2))
124
+ math.floor(transformer_embedding_size / (conv_pooling_kernel_size**2))
113
125
  )
114
126
 
115
127
  # This if block rhymes:
@@ -143,11 +155,15 @@ class CCTEncoder(nn.Module):
143
155
  )
144
156
 
145
157
  elif conv_pooling_type == "concat":
146
- concatpool_activation_output_size = (
158
+ concatpool_activation_output_channels = (
147
159
  conv_pooling_kernel_size**2 * conv_out_channels
148
160
  )
149
161
  if cnn_activation.__name__.endswith("GLU"):
150
- concatpool_activation_output_size /= 2
162
+ concatpool_activation_output_channels /= 2
163
+
164
+ concatpool_padding = (
165
+ transformer_embedding_size - concatpool_activation_output_channels
166
+ )
151
167
 
152
168
  self.pool = nn.Sequential(
153
169
  *[
@@ -161,10 +177,7 @@ class CCTEncoder(nn.Module):
161
177
  ),
162
178
  self.cnn_activation,
163
179
  Rearrange("N H W C -> N (H W) C"),
164
- nn.Linear(
165
- concatpool_activation_output_size, transformer_embedding_size
166
- ),
167
- self.cnn_activation,
180
+ PadTensor((0, concatpool_padding)),
168
181
  ]
169
182
  )
170
183
 
@@ -249,6 +262,22 @@ class CCT(nn.Module):
249
262
 
250
263
  super().__init__()
251
264
 
265
+ if isinstance(cnn_activation, str):
266
+ cnn_activation = {
267
+ "ReLU": ReLU,
268
+ "SquaredReLU": SquaredReLU,
269
+ "GELU": GELU,
270
+ "SwiGLU": SwiGLU,
271
+ }[cnn_activation]
272
+
273
+ if isinstance(transformer_activation, str):
274
+ transformer_activation = {
275
+ "ReLU": ReLU,
276
+ "SquaredReLU": SquaredReLU,
277
+ "GELU": GELU,
278
+ "SwiGLU": SwiGLU,
279
+ }[transformer_activation]
280
+
252
281
  self.encoder = CCTEncoder(
253
282
  image_size=image_size,
254
283
  conv_kernel_size=conv_kernel_size,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "0.1.36"
3
+ version = "0.1.38"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes