broccoli-ml 0.1.41__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
broccoli/vit.py CHANGED
@@ -1,8 +1,8 @@
1
1
  import math
2
2
  from typing import Optional
3
3
 
4
- from .transformer import TransformerEncoder
5
- from .cnn import ConvLayer, ConcatPool
4
+ from .transformer import TransformerEncoder, DenoisingAutoEncoder
5
+ from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
6
6
  from .activation import ReLU, SquaredReLU, GELU, SwiGLU
7
7
  from einops import einsum
8
8
  from einops.layers.torch import Rearrange
@@ -61,38 +61,35 @@ class CCTEncoder(nn.Module):
61
61
 
62
62
  def __init__(
63
63
  self,
64
- image_size=32,
65
- conv_kernel_size=3,
66
- conv_pooling_type="maxpool",
67
- conv_pooling_kernel_size=3,
68
- conv_pooling_kernel_stride=2,
69
- conv_pooling_kernel_padding=1,
70
- conv_dropout=0.0,
64
+ input_size=(32, 32),
65
+ cnn_in_channels=3,
66
+ minimum_cnn_out_channels=16,
67
+ cnn_kernel_size=3,
68
+ cnn_kernel_stride=1,
69
+ cnn_kernel_padding="same",
70
+ cnn_kernel_dilation=1,
71
+ cnn_kernel_groups=1,
72
+ cnn_activation: nn.Module = nn.ReLU,
73
+ cnn_activation_kwargs: Optional[dict] = None,
74
+ cnn_dropout=0.0,
75
+ pooling_type="maxpool",
76
+ pooling_kernel_size=3,
77
+ pooling_kernel_stride=2,
78
+ pooling_kernel_padding=1,
71
79
  transformer_position_embedding="absolute", # absolute or relative
72
80
  transformer_embedding_size=256,
73
81
  transformer_layers=7,
74
82
  transformer_heads=4,
75
83
  transformer_mlp_ratio=2,
76
84
  transformer_bos_tokens=4,
77
- tranformer_share_kv=True,
78
- tranformer_max_subtract=True,
79
- tranformer_d_model_scale=True,
80
- tranformer_log_length_scale=True,
81
- tranformer_quiet_attention=True,
82
- cnn_activation: nn.Module = nn.ReLU,
83
- cnn_activation_kwargs: Optional[dict] = None,
84
85
  transformer_activation: nn.Module = nn.GELU,
85
86
  transformer_activation_kwargs: Optional[dict] = None,
86
87
  mlp_dropout=0.0,
87
88
  msa_dropout=0.1,
88
89
  stochastic_depth=0.1,
89
90
  linear_module=nn.Linear,
90
- image_channels=3,
91
- batch_norm=False,
91
+ batch_norm=True,
92
92
  ):
93
- if conv_pooling_type not in ["maxpool", "concat"]:
94
- raise NotImplementedError("Pooling type must be maxpool or concat")
95
-
96
93
  super().__init__()
97
94
 
98
95
  if cnn_activation_kwargs is not None:
@@ -107,55 +104,123 @@ class CCTEncoder(nn.Module):
107
104
  else:
108
105
  self.transformer_activation = transformer_activation()
109
106
 
110
- self.image_size = image_size
107
+ self.input_size = input_size
108
+ self.spatial_dimensions = len(self.input_size)
109
+
110
+ if self.spatial_dimensions == 1:
111
+ maxpoolxd = nn.MaxPool1d
112
+ convxd = nn.Conv1d
113
+ batchnormxd = nn.BatchNorm1d
114
+ spatial_dim_names = "D1"
115
+ elif self.spatial_dimensions == 2:
116
+ maxpoolxd = nn.MaxPool2d
117
+ convxd = nn.Conv2d
118
+ batchnormxd = nn.BatchNorm2d
119
+ spatial_dim_names = "D1 D2"
120
+ elif self.spatial_dimensions == 3:
121
+ maxpoolxd = nn.MaxPool3d
122
+ convxd = nn.Conv3d
123
+ batchnormxd = nn.BatchNorm3d
124
+ spatial_dim_names = "D1 D2 D3"
125
+ else:
126
+ raise NotImplementedError(
127
+ "`input_size` must be a tuple of length 1, 2, or 3."
128
+ )
111
129
 
112
- # XXX: We assume a square image here
113
- output_size = math.floor(
114
- (image_size + 2 * conv_pooling_kernel_padding - conv_pooling_kernel_size)
115
- / conv_pooling_kernel_stride
116
- + 1
117
- ) # output of pooling
130
+ cnn_output_size = calculate_output_spatial_size(
131
+ input_size,
132
+ kernel_size=cnn_kernel_size,
133
+ stride=cnn_kernel_stride,
134
+ padding=cnn_kernel_padding,
135
+ dilation=cnn_kernel_dilation,
136
+ )
118
137
 
119
- self.sequence_length = output_size**2
138
+ pooling_output_size = (
139
+ cnn_output_size
140
+ if pooling_type is None
141
+ else calculate_output_spatial_size(
142
+ cnn_output_size,
143
+ kernel_size=pooling_kernel_size,
144
+ stride=pooling_kernel_stride,
145
+ padding=pooling_kernel_padding,
146
+ dilation=1,
147
+ )
148
+ )
149
+
150
+ self.sequence_length = math.prod(pooling_output_size) # One token per voxel
151
+
152
+ pooling_kernel_voxels = math.prod(
153
+ spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
154
+ )
120
155
 
121
- if conv_pooling_type == "maxpool":
122
- conv_out_channels = transformer_embedding_size
123
- elif conv_pooling_type == "concat":
124
- conv_out_channels = int(
125
- math.floor(transformer_embedding_size / (conv_pooling_kernel_size**2))
156
+ if pooling_type in ["maxpool", None]:
157
+ cnn_out_channels = transformer_embedding_size
158
+ elif pooling_type == "concat":
159
+ cnn_out_channels = min(
160
+ math.floor(transformer_embedding_size / pooling_kernel_voxels),
161
+ minimum_cnn_out_channels,
126
162
  )
163
+ else:
164
+ raise NotImplementedError("Pooling type must be maxpool, concat or None")
165
+
166
+ cnn_activation_out_channels = cnn_out_channels
127
167
 
128
- # This if block rhymes:
168
+ # This block rhymes:
129
169
  if cnn_activation.__name__.endswith("GLU"):
130
- conv_out_channels *= 2
131
-
132
- self.conv = ConvLayer(
133
- image_channels,
134
- conv_out_channels,
135
- kernel_size=conv_kernel_size,
136
- stride=1,
137
- padding="same",
138
- linear_module=linear_module,
170
+ cnn_out_channels *= 2
171
+
172
+ self.cnn = convxd(
173
+ cnn_in_channels,
174
+ cnn_out_channels,
175
+ cnn_kernel_size,
176
+ stride=cnn_kernel_stride,
177
+ padding=cnn_kernel_padding,
178
+ dilation=cnn_kernel_dilation,
179
+ groups=cnn_kernel_groups,
180
+ bias=True,
181
+ padding_mode="zeros",
182
+ )
183
+
184
+ self.activate_and_dropout = nn.Sequential(
185
+ *[
186
+ Rearrange( # rearrange in case we're using XGLU activation
187
+ f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
188
+ ),
189
+ self.cnn_activation,
190
+ Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
191
+ nn.Dropout(cnn_dropout),
192
+ (
193
+ batchnormxd(cnn_activation_out_channels)
194
+ if batch_norm
195
+ else nn.Identity()
196
+ ),
197
+ ]
139
198
  )
140
199
 
141
- if conv_pooling_type == "maxpool":
200
+ if pooling_type is None:
142
201
  self.pool = nn.Sequential(
143
202
  *[
144
- Rearrange( # rearrange in case we're using XGLU activation
145
- "N C H W -> N H W C"
146
- ),
147
- self.cnn_activation,
148
- Rearrange("N H W C -> N C H W"),
149
- nn.MaxPool2d(
150
- conv_pooling_kernel_size,
151
- stride=conv_pooling_kernel_stride,
152
- padding=conv_pooling_kernel_padding,
203
+ Rearrange(
204
+ f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
205
+ ), # for transformer
206
+ ]
207
+ )
208
+
209
+ elif pooling_type == "maxpool":
210
+ self.pool = nn.Sequential(
211
+ *[
212
+ maxpoolxd(
213
+ pooling_kernel_size,
214
+ stride=pooling_kernel_stride,
215
+ padding=pooling_kernel_padding,
153
216
  ),
154
- Rearrange("N C H W -> N (H W) C"), # for transformer
217
+ Rearrange(
218
+ f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
219
+ ), # for transformer
155
220
  ]
156
221
  )
157
222
 
158
- elif conv_pooling_type == "concat":
223
+ elif pooling_type == "concat":
159
224
 
160
225
  if transformer_activation_kwargs is not None:
161
226
  self.concatpool_activation = transformer_activation(
@@ -164,44 +229,29 @@ class CCTEncoder(nn.Module):
164
229
  else:
165
230
  self.concatpool_activation = transformer_activation()
166
231
 
167
- concatpool_out_channels = conv_pooling_kernel_size**2 * conv_out_channels
168
-
169
- if cnn_activation.__name__.endswith("GLU"):
170
- cnn_activation_output_channels = concatpool_out_channels / 2
171
- else:
172
- cnn_activation_output_channels = concatpool_out_channels
232
+ concatpool_out_channels = (
233
+ pooling_kernel_voxels * cnn_activation_out_channels
234
+ )
173
235
 
174
236
  self.pool = nn.Sequential(
175
237
  *[
176
- ConcatPool(
177
- conv_pooling_kernel_size,
178
- stride=conv_pooling_kernel_stride,
179
- padding=conv_pooling_kernel_padding,
180
- ),
181
- Rearrange( # rearrange in case we're using XGLU activation
182
- "N C H W -> N H W C"
238
+ SpaceToDepth(
239
+ pooling_kernel_size,
240
+ stride=pooling_kernel_stride,
241
+ padding=pooling_kernel_padding,
242
+ spatial_dimensions=self.spatial_dimensions,
183
243
  ),
184
- self.cnn_activation,
185
- nn.Dropout(conv_dropout),
186
- Rearrange( # rearrange in case we're using XGLU activation
187
- "N H W C -> N C H W"
244
+ Rearrange( # for transformer
245
+ f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
188
246
  ),
189
- nn.BatchNorm2d(cnn_activation_output_channels),
190
- Rearrange( # rearrange in case we're using XGLU activation
191
- "N C H W -> N (H W) C"
192
- ),
193
- nn.Linear(
194
- cnn_activation_output_channels,
195
- (
196
- 2 * transformer_embedding_size * transformer_mlp_ratio
197
- if transformer_activation.__name__.endswith("GLU")
198
- else transformer_embedding_size * transformer_mlp_ratio
199
- ),
200
- ),
201
- self.concatpool_activation,
202
- nn.Linear(
203
- transformer_embedding_size * transformer_mlp_ratio,
247
+ DenoisingAutoEncoder(
248
+ concatpool_out_channels,
249
+ transformer_mlp_ratio,
204
250
  transformer_embedding_size,
251
+ activation=transformer_activation,
252
+ activation_kwargs=transformer_activation_kwargs,
253
+ dropout=0.0,
254
+ linear_module=linear_module,
205
255
  ),
206
256
  ]
207
257
  )
@@ -213,7 +263,7 @@ class CCTEncoder(nn.Module):
213
263
  transformer_layers,
214
264
  transformer_heads,
215
265
  position_embedding_type=transformer_position_embedding,
216
- source_size=(output_size, output_size),
266
+ source_size=pooling_output_size,
217
267
  mlp_ratio=transformer_mlp_ratio,
218
268
  activation=transformer_activation,
219
269
  activation_kwargs=transformer_activation_kwargs,
@@ -221,11 +271,6 @@ class CCTEncoder(nn.Module):
221
271
  msa_dropout=msa_dropout,
222
272
  stochastic_depth=stochastic_depth,
223
273
  causal=False,
224
- share_kv=tranformer_share_kv,
225
- max_subtract=tranformer_max_subtract,
226
- d_model_scale=tranformer_d_model_scale,
227
- log_length_scale=tranformer_log_length_scale,
228
- quiet_attention=tranformer_quiet_attention,
229
274
  linear_module=linear_module,
230
275
  bos_tokens=transformer_bos_tokens,
231
276
  )
@@ -234,8 +279,9 @@ class CCTEncoder(nn.Module):
234
279
 
235
280
  self.encoder = nn.Sequential(
236
281
  *[
237
- nn.BatchNorm2d(image_channels) if batch_norm else nn.Identity(),
238
- self.conv,
282
+ batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
283
+ self.cnn,
284
+ self.activate_and_dropout,
239
285
  self.pool,
240
286
  self.transformer,
241
287
  ]
@@ -255,8 +301,17 @@ class CCT(nn.Module):
255
301
 
256
302
  def __init__(
257
303
  self,
258
- image_size=32,
259
- conv_kernel_size=3, # Only 2 is supported for eigenvector initialisation
304
+ input_size=(32, 32),
305
+ cnn_in_channels=3,
306
+ minimum_cnn_out_channels=16,
307
+ cnn_kernel_size=3,
308
+ cnn_kernel_stride=1,
309
+ cnn_kernel_padding="same",
310
+ cnn_kernel_dilation=1,
311
+ cnn_kernel_groups=1,
312
+ cnn_activation: nn.Module = nn.ReLU,
313
+ cnn_activation_kwargs: Optional[dict] = None,
314
+ cnn_dropout=0.0,
260
315
  pooling_type="maxpool",
261
316
  pooling_kernel_size=3,
262
317
  pooling_kernel_stride=2,
@@ -267,22 +322,14 @@ class CCT(nn.Module):
267
322
  transformer_heads=4,
268
323
  transformer_mlp_ratio=2,
269
324
  transformer_bos_tokens=4,
270
- tranformer_share_kv=True,
271
- tranformer_max_subtract=True,
272
- tranformer_d_model_scale=True,
273
- tranformer_log_length_scale=True,
274
- tranformer_quiet_attention=True,
275
- cnn_activation: nn.Module = nn.ReLU,
276
- cnn_activation_kwargs: Optional[dict] = None,
277
325
  transformer_activation: nn.Module = nn.GELU,
278
326
  transformer_activation_kwargs: Optional[dict] = None,
279
- mlp_dropout=0.0, # The original paper got best performance from mlp_dropout=0.
280
- msa_dropout=0.1, # "" msa_dropout=0.1
281
- stochastic_depth=0.1, # "" stochastic_depth=0.1
282
- image_classes=100,
327
+ mlp_dropout=0.0,
328
+ msa_dropout=0.1,
329
+ stochastic_depth=0.1,
283
330
  linear_module=nn.Linear,
284
- image_channels=3,
285
- batch_norm=False,
331
+ batch_norm=True,
332
+ image_classes=100,
286
333
  ):
287
334
 
288
335
  super().__init__()
@@ -304,32 +351,33 @@ class CCT(nn.Module):
304
351
  }[transformer_activation]
305
352
 
306
353
  self.encoder = CCTEncoder(
307
- image_size=image_size,
308
- conv_kernel_size=conv_kernel_size,
309
- conv_pooling_type=pooling_type,
310
- conv_pooling_kernel_size=pooling_kernel_size,
311
- conv_pooling_kernel_stride=pooling_kernel_stride,
312
- conv_pooling_kernel_padding=pooling_kernel_padding,
354
+ input_size=input_size,
355
+ cnn_in_channels=cnn_in_channels,
356
+ minimum_cnn_out_channels=minimum_cnn_out_channels,
357
+ cnn_kernel_size=cnn_kernel_size,
358
+ cnn_kernel_stride=cnn_kernel_stride,
359
+ cnn_kernel_padding=cnn_kernel_padding,
360
+ cnn_kernel_dilation=cnn_kernel_dilation,
361
+ cnn_kernel_groups=cnn_kernel_groups,
362
+ cnn_activation=cnn_activation,
363
+ cnn_activation_kwargs=cnn_activation_kwargs,
364
+ cnn_dropout=cnn_dropout,
365
+ pooling_type=pooling_type,
366
+ pooling_kernel_size=pooling_kernel_size,
367
+ pooling_kernel_stride=pooling_kernel_stride,
368
+ pooling_kernel_padding=pooling_kernel_padding,
313
369
  transformer_position_embedding=transformer_position_embedding,
314
370
  transformer_embedding_size=transformer_embedding_size,
315
371
  transformer_layers=transformer_layers,
316
372
  transformer_heads=transformer_heads,
317
373
  transformer_mlp_ratio=transformer_mlp_ratio,
318
374
  transformer_bos_tokens=transformer_bos_tokens,
319
- tranformer_share_kv=tranformer_share_kv,
320
- tranformer_max_subtract=tranformer_max_subtract,
321
- tranformer_d_model_scale=tranformer_d_model_scale,
322
- tranformer_log_length_scale=tranformer_log_length_scale,
323
- tranformer_quiet_attention=tranformer_quiet_attention,
324
- cnn_activation=cnn_activation,
325
- cnn_activation_kwargs=cnn_activation_kwargs,
326
375
  transformer_activation=transformer_activation,
327
376
  transformer_activation_kwargs=transformer_activation_kwargs,
328
377
  mlp_dropout=mlp_dropout,
329
378
  msa_dropout=msa_dropout,
330
379
  stochastic_depth=stochastic_depth,
331
380
  linear_module=linear_module,
332
- image_channels=image_channels,
333
381
  batch_norm=batch_norm,
334
382
  )
335
383
  self.pool = SequencePool(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.1.41
3
+ Version: 0.3.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -3,15 +3,15 @@ broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
3
3
  broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
4
4
  broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
5
5
  broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
6
- broccoli/cnn.py,sha256=pv8ttV_-CmNRpYO1HINR-Z3WemaK5SBd2iojZ7E2QBA,14680
6
+ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
7
7
  broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
- broccoli/transformer.py,sha256=gFBIEowGFPSgQhM1RwsRtQlw_WzVJPY-LJyf1MLtPek,16277
11
+ broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=JC-NWM1Ys7JOrapH9Ka6ED8C4yViJ2Bv3d0SfFgDaZ8,12876
14
- broccoli_ml-0.1.41.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.1.41.dist-info/METADATA,sha256=dEBaKtK3p19LI1gW7bExrE_xHmUaT1lhp7GoMwI510s,1257
16
- broccoli_ml-0.1.41.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.1.41.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=NuHW2xcaUEv_IHAZbrrGHUWKu9D7JMR1iKDCCX07RQs,13787
14
+ broccoli_ml-0.3.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.3.0.dist-info/METADATA,sha256=sAbHQ0Q2yM5kaovkF22cTKCk4SU_z6vi6QtmOMMwJlQ,1256
16
+ broccoli_ml-0.3.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.3.0.dist-info/RECORD,,