broccoli-ml 0.1.41__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
broccoli/vit.py CHANGED
@@ -1,8 +1,8 @@
 import math
 from typing import Optional
 
-from .transformer import TransformerEncoder
-from .cnn import ConvLayer, ConcatPool
+from .transformer import TransformerEncoder, DenoisingAutoEncoder
+from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
 from .activation import ReLU, SquaredReLU, GELU, SwiGLU
 from einops import einsum
 from einops.layers.torch import Rearrange
@@ -61,38 +61,34 @@ class CCTEncoder(nn.Module):
 
     def __init__(
         self,
-        image_size=32,
-        conv_kernel_size=3,
-        conv_pooling_type="maxpool",
-        conv_pooling_kernel_size=3,
-        conv_pooling_kernel_stride=2,
-        conv_pooling_kernel_padding=1,
-        conv_dropout=0.0,
+        input_size=(32, 32),
+        cnn_in_channels=3,
+        cnn_kernel_size=3,
+        cnn_kernel_stride=1,
+        cnn_kernel_padding="same",
+        cnn_kernel_dilation=1,
+        cnn_kernel_groups=1,
+        cnn_activation: nn.Module = nn.ReLU,
+        cnn_activation_kwargs: Optional[dict] = None,
+        cnn_dropout=0.0,
+        pooling_type="maxpool",
+        pooling_kernel_size=3,
+        pooling_kernel_stride=2,
+        pooling_kernel_padding=1,
         transformer_position_embedding="absolute",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_bos_tokens=4,
-        tranformer_share_kv=True,
-        tranformer_max_subtract=True,
-        tranformer_d_model_scale=True,
-        tranformer_log_length_scale=True,
-        tranformer_quiet_attention=True,
-        cnn_activation: nn.Module = nn.ReLU,
-        cnn_activation_kwargs: Optional[dict] = None,
         transformer_activation: nn.Module = nn.GELU,
         transformer_activation_kwargs: Optional[dict] = None,
         mlp_dropout=0.0,
         msa_dropout=0.1,
         stochastic_depth=0.1,
         linear_module=nn.Linear,
-        image_channels=3,
-        batch_norm=False,
+        batch_norm=True,
     ):
-        if conv_pooling_type not in ["maxpool", "concat"]:
-            raise NotImplementedError("Pooling type must be maxpool or concat")
-
         super().__init__()
 
         if cnn_activation_kwargs is not None:
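
For anyone upgrading from 0.1.41, the signature change above is mostly a rename and regrouping of keyword arguments: conv_* becomes cnn_*, the pooling arguments lose their conv_ prefix, image_size becomes the tuple input_size, image_channels becomes cnn_in_channels, and the misspelled tranformer_* attention flags disappear with no replacement visible in this diff. A hedged migration sketch, using only names taken from the two signatures:

from broccoli.vit import CCTEncoder

# 0.1.41-style call being migrated:
#   CCTEncoder(image_size=32, image_channels=3, conv_kernel_size=3,
#              conv_pooling_type="maxpool", conv_pooling_kernel_size=3,
#              conv_pooling_kernel_stride=2, conv_pooling_kernel_padding=1,
#              conv_dropout=0.0, batch_norm=False)
encoder = CCTEncoder(
    input_size=(32, 32),       # was image_size=32; now a 1-, 2- or 3-tuple
    cnn_in_channels=3,         # was image_channels
    cnn_kernel_size=3,         # was conv_kernel_size
    cnn_dropout=0.0,           # was conv_dropout
    pooling_type="maxpool",    # was conv_pooling_type; None is now also accepted
    pooling_kernel_size=3,     # was conv_pooling_kernel_size
    pooling_kernel_stride=2,   # was conv_pooling_kernel_stride
    pooling_kernel_padding=1,  # was conv_pooling_kernel_padding
    batch_norm=True,           # note: the default flipped from False to True
)
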
@@ -107,55 +103,122 @@ class CCTEncoder(nn.Module):
         else:
             self.transformer_activation = transformer_activation()
 
-        self.image_size = image_size
+        self.input_size = input_size
+        self.spatial_dimensions = len(self.input_size)
+
+        if self.spatial_dimensions == 1:
+            maxpoolxd = nn.MaxPool1d
+            convxd = nn.Conv1d
+            batchnormxd = nn.BatchNorm1d
+            spatial_dim_names = "D1"
+        elif self.spatial_dimensions == 2:
+            maxpoolxd = nn.MaxPool2d
+            convxd = nn.Conv2d
+            batchnormxd = nn.BatchNorm2d
+            spatial_dim_names = "D1 D2"
+        elif self.spatial_dimensions == 3:
+            maxpoolxd = nn.MaxPool3d
+            convxd = nn.Conv3d
+            batchnormxd = nn.BatchNorm3d
+            spatial_dim_names = "D1 D2 D3"
+        else:
+            raise NotImplementedError(
+                "`input_size` must be a tuple of length 1, 2, or 3."
+            )
 
-        # XXX: We assume a square image here
-        output_size = math.floor(
-            (image_size + 2 * conv_pooling_kernel_padding - conv_pooling_kernel_size)
-            / conv_pooling_kernel_stride
-            + 1
-        )  # output of pooling
+        cnn_output_size = calculate_output_spatial_size(
+            input_size,
+            kernel_size=cnn_kernel_size,
+            stride=cnn_kernel_stride,
+            padding=cnn_kernel_padding,
+            dilation=cnn_kernel_dilation,
+        )
 
-        self.sequence_length = output_size**2
+        pooling_output_size = (
+            cnn_output_size
+            if pooling_type is None
+            else calculate_output_spatial_size(
+                cnn_output_size,
+                kernel_size=pooling_kernel_size,
+                stride=pooling_kernel_stride,
+                padding=pooling_kernel_padding,
+                dilation=1,
+            )
+        )
+
+        self.sequence_length = math.prod(pooling_output_size)  # One token per voxel
+
+        pooling_kernel_voxels = math.prod(
+            spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
+        )
 
-        if conv_pooling_type == "maxpool":
-            conv_out_channels = transformer_embedding_size
-        elif conv_pooling_type == "concat":
-            conv_out_channels = int(
-                math.floor(transformer_embedding_size / (conv_pooling_kernel_size**2))
+        if pooling_type in ["maxpool", None]:
+            cnn_out_channels = transformer_embedding_size
+        elif pooling_type == "concat":
+            cnn_out_channels = math.floor(
+                transformer_embedding_size / pooling_kernel_voxels
             )
+        else:
+            raise NotImplementedError("Pooling type must be maxpool, concat or None")
+
+        cnn_activation_out_channels = cnn_out_channels
 
-        # This if block rhymes:
+        # This block rhymes:
         if cnn_activation.__name__.endswith("GLU"):
-            conv_out_channels *= 2
-
-        self.conv = ConvLayer(
-            image_channels,
-            conv_out_channels,
-            kernel_size=conv_kernel_size,
-            stride=1,
-            padding="same",
-            linear_module=linear_module,
+            cnn_out_channels *= 2
+
+        self.cnn = convxd(
+            cnn_in_channels,
+            cnn_out_channels,
+            cnn_kernel_size,
+            stride=cnn_kernel_stride,
+            padding=cnn_kernel_padding,
+            dilation=cnn_kernel_dilation,
+            groups=cnn_kernel_groups,
+            bias=True,
+            padding_mode="zeros",
+        )
+
+        self.activate_and_dropout = nn.Sequential(
+            *[
+                Rearrange(  # rearrange in case we're using XGLU activation
+                    f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
+                ),
+                self.cnn_activation,
+                Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
+                nn.Dropout(cnn_dropout),
+                (
+                    batchnormxd(cnn_activation_out_channels)
+                    if batch_norm
+                    else nn.Identity()
+                ),
+            ]
         )
 
-        if conv_pooling_type == "maxpool":
+        if pooling_type is None:
             self.pool = nn.Sequential(
                 *[
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N C H W -> N H W C"
-                    ),
-                    self.cnn_activation,
-                    Rearrange("N H W C -> N C H W"),
-                    nn.MaxPool2d(
-                        conv_pooling_kernel_size,
-                        stride=conv_pooling_kernel_stride,
-                        padding=conv_pooling_kernel_padding,
+                    Rearrange(
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
+                    ),  # for transformer
+                ]
+            )
+
+        elif pooling_type == "maxpool":
+            self.pool = nn.Sequential(
+                *[
+                    maxpoolxd(
+                        pooling_kernel_size,
+                        stride=pooling_kernel_stride,
+                        padding=pooling_kernel_padding,
                     ),
-                    Rearrange("N C H W -> N (H W) C"),  # for transformer
+                    Rearrange(
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
+                    ),  # for transformer
                 ]
             )
 
-        elif conv_pooling_type == "concat":
+        elif pooling_type == "concat":
 
             if transformer_activation_kwargs is not None:
                 self.concatpool_activation = transformer_activation(
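
The removed inline arithmetic (which hard-coded a square image) is replaced by calculate_output_spatial_size, now applied per stage and per dimension. Its implementation lives in broccoli/cnn.py and is not shown in this diff; as a minimal sketch, it presumably applies the standard convolution output-size formula to each spatial dimension, something like:

import math

def output_spatial_size(size, kernel_size, stride, padding, dilation=1):
    # Sketch only: the real calculate_output_spatial_size in broccoli/cnn.py
    # also accepts tuples (one entry per dimension) and, per the call above,
    # string padding such as "same".
    if padding == "same":  # "same" preserves the spatial size (stride 1)
        return size
    return math.floor(
        (size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    )

# The removed inline computation is the dilation=1 special case:
#   floor((image_size + 2*pad - kernel) / stride + 1)
assert output_spatial_size(32, kernel_size=3, stride=2, padding=1) == 16
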
@@ -164,44 +227,29 @@ class CCTEncoder(nn.Module):
             else:
                 self.concatpool_activation = transformer_activation()
 
-            concatpool_out_channels = conv_pooling_kernel_size**2 * conv_out_channels
-
-            if cnn_activation.__name__.endswith("GLU"):
-                cnn_activation_output_channels = concatpool_out_channels / 2
-            else:
-                cnn_activation_output_channels = concatpool_out_channels
+            concatpool_out_channels = (
+                pooling_kernel_voxels * cnn_activation_out_channels
+            )
 
             self.pool = nn.Sequential(
                 *[
-                    ConcatPool(
-                        conv_pooling_kernel_size,
-                        stride=conv_pooling_kernel_stride,
-                        padding=conv_pooling_kernel_padding,
-                    ),
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N C H W -> N H W C"
+                    SpaceToDepth(
+                        pooling_kernel_size,
+                        stride=pooling_kernel_stride,
+                        padding=pooling_kernel_padding,
+                        spatial_dimensions=self.spatial_dimensions,
                     ),
-                    self.cnn_activation,
-                    nn.Dropout(conv_dropout),
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N H W C -> N C H W"
+                    Rearrange(  # for transformer
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                     ),
-                    nn.BatchNorm2d(cnn_activation_output_channels),
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N C H W -> N (H W) C"
-                    ),
-                    nn.Linear(
-                        cnn_activation_output_channels,
-                        (
-                            2 * transformer_embedding_size * transformer_mlp_ratio
-                            if transformer_activation.__name__.endswith("GLU")
-                            else transformer_embedding_size * transformer_mlp_ratio
-                        ),
-                    ),
-                    self.concatpool_activation,
-                    nn.Linear(
-                        transformer_embedding_size * transformer_mlp_ratio,
+                    DenoisingAutoEncoder(
+                        concatpool_out_channels,
+                        transformer_mlp_ratio,
                         transformer_embedding_size,
+                        activation=transformer_activation,
+                        activation_kwargs=transformer_activation_kwargs,
+                        dropout=0.0,
+                        linear_module=linear_module,
                     ),
                 ]
             )
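
ConcatPool gives way to SpaceToDepth, which, judging by its name and arguments, stacks every voxel of a pooling window along the channel axis rather than reducing the window; the DenoisingAutoEncoder then maps those pooling_kernel_voxels * cnn_activation_out_channels channels down to transformer_embedding_size, replacing the removed hand-rolled Linear/activation/Linear stack. The real implementations are in broccoli/cnn.py and broccoli/transformer.py, not this file; a rough 2D-only sketch of the space-to-depth step using torch.nn.functional.unfold:

import torch
import torch.nn.functional as F

def space_to_depth_2d(x, kernel_size=3, stride=2, padding=1):
    # Each kernel_size x kernel_size window becomes one spatial position with
    # C * kernel_size**2 channels. The real SpaceToDepth also supports 1D and
    # 3D inputs via its spatial_dimensions argument.
    n, c, h, w = x.shape
    cols = F.unfold(x, kernel_size, stride=stride, padding=padding)  # N, C*k*k, L
    out_h = (h + 2 * padding - kernel_size) // stride + 1
    out_w = (w + 2 * padding - kernel_size) // stride + 1
    return cols.reshape(n, c * kernel_size**2, out_h, out_w)

x = torch.randn(2, 8, 32, 32)
print(space_to_depth_2d(x).shape)  # torch.Size([2, 72, 16, 16])
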
@@ -213,7 +261,7 @@ class CCTEncoder(nn.Module):
             transformer_layers,
             transformer_heads,
             position_embedding_type=transformer_position_embedding,
-            source_size=(output_size, output_size),
+            source_size=pooling_output_size,
             mlp_ratio=transformer_mlp_ratio,
             activation=transformer_activation,
             activation_kwargs=transformer_activation_kwargs,
@@ -221,11 +269,6 @@ class CCTEncoder(nn.Module):
             msa_dropout=msa_dropout,
             stochastic_depth=stochastic_depth,
             causal=False,
-            share_kv=tranformer_share_kv,
-            max_subtract=tranformer_max_subtract,
-            d_model_scale=tranformer_d_model_scale,
-            log_length_scale=tranformer_log_length_scale,
-            quiet_attention=tranformer_quiet_attention,
             linear_module=linear_module,
             bos_tokens=transformer_bos_tokens,
         )
@@ -234,8 +277,9 @@ class CCTEncoder(nn.Module):
 
         self.encoder = nn.Sequential(
             *[
-                nn.BatchNorm2d(image_channels) if batch_norm else nn.Identity(),
-                self.conv,
+                batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
+                self.cnn,
+                self.activate_and_dropout,
                 self.pool,
                 self.transformer,
             ]
@@ -255,8 +299,16 @@ class CCT(nn.Module):
 
     def __init__(
         self,
-        image_size=32,
-        conv_kernel_size=3,  # Only 2 is supported for eigenvector initialisation
+        input_size=(32, 32),
+        cnn_in_channels=3,
+        cnn_kernel_size=3,
+        cnn_kernel_stride=1,
+        cnn_kernel_padding="same",
+        cnn_kernel_dilation=1,
+        cnn_kernel_groups=1,
+        cnn_activation: nn.Module = nn.ReLU,
+        cnn_activation_kwargs: Optional[dict] = None,
+        cnn_dropout=0.0,
         pooling_type="maxpool",
         pooling_kernel_size=3,
         pooling_kernel_stride=2,
@@ -267,22 +319,14 @@ class CCT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_bos_tokens=4,
-        tranformer_share_kv=True,
-        tranformer_max_subtract=True,
-        tranformer_d_model_scale=True,
-        tranformer_log_length_scale=True,
-        tranformer_quiet_attention=True,
-        cnn_activation: nn.Module = nn.ReLU,
-        cnn_activation_kwargs: Optional[dict] = None,
         transformer_activation: nn.Module = nn.GELU,
         transformer_activation_kwargs: Optional[dict] = None,
-        mlp_dropout=0.0,  # The original paper got best performance from mlp_dropout=0.
-        msa_dropout=0.1,  # "" msa_dropout=0.1
-        stochastic_depth=0.1,  # "" stochastic_depth=0.1
-        image_classes=100,
+        mlp_dropout=0.0,
+        msa_dropout=0.1,
+        stochastic_depth=0.1,
         linear_module=nn.Linear,
-        image_channels=3,
-        batch_norm=False,
+        batch_norm=True,
+        image_classes=100,
     ):
 
         super().__init__()
@@ -304,32 +348,32 @@ class CCT(nn.Module):
         }[transformer_activation]
 
         self.encoder = CCTEncoder(
-            image_size=image_size,
-            conv_kernel_size=conv_kernel_size,
-            conv_pooling_type=pooling_type,
-            conv_pooling_kernel_size=pooling_kernel_size,
-            conv_pooling_kernel_stride=pooling_kernel_stride,
-            conv_pooling_kernel_padding=pooling_kernel_padding,
+            input_size=input_size,
+            cnn_in_channels=cnn_in_channels,
+            cnn_kernel_size=cnn_kernel_size,
+            cnn_kernel_stride=cnn_kernel_stride,
+            cnn_kernel_padding=cnn_kernel_padding,
+            cnn_kernel_dilation=cnn_kernel_dilation,
+            cnn_kernel_groups=cnn_kernel_groups,
+            cnn_activation=cnn_activation,
+            cnn_activation_kwargs=cnn_activation_kwargs,
+            cnn_dropout=cnn_dropout,
+            pooling_type=pooling_type,
+            pooling_kernel_size=pooling_kernel_size,
+            pooling_kernel_stride=pooling_kernel_stride,
+            pooling_kernel_padding=pooling_kernel_padding,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_bos_tokens=transformer_bos_tokens,
-            tranformer_share_kv=tranformer_share_kv,
-            tranformer_max_subtract=tranformer_max_subtract,
-            tranformer_d_model_scale=tranformer_d_model_scale,
-            tranformer_log_length_scale=tranformer_log_length_scale,
-            tranformer_quiet_attention=tranformer_quiet_attention,
-            cnn_activation=cnn_activation,
-            cnn_activation_kwargs=cnn_activation_kwargs,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             mlp_dropout=mlp_dropout,
             msa_dropout=msa_dropout,
             stochastic_depth=stochastic_depth,
             linear_module=linear_module,
-            image_channels=image_channels,
             batch_norm=batch_norm,
         )
         self.pool = SequencePool(
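
Taken together, the CCT changes make the wrapper dimension-generic rather than image-only. A hedged usage sketch of the 0.2.0 signature; the forward method is not part of this diff, so the channels-first input layout below is an assumption:

import torch
from broccoli.vit import CCT

# 2D images, much as before (image_classes default is still 100)
model_2d = CCT(input_size=(32, 32), cnn_in_channels=3, image_classes=100)
logits = model_2d(torch.randn(8, 3, 32, 32))  # assumed N C H W layout

# 3D volumes are now expressible simply by passing a 3-tuple
model_3d = CCT(input_size=(16, 16, 16), cnn_in_channels=1, image_classes=10)
logits_3d = model_3d(torch.randn(4, 1, 16, 16, 16))  # assumed N C D1 D2 D3
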
{broccoli_ml-0.1.41.dist-info → broccoli_ml-0.2.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.1.41
+Version: 0.2.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
{broccoli_ml-0.1.41.dist-info → broccoli_ml-0.2.0.dist-info}/RECORD RENAMED
@@ -3,15 +3,15 @@ broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
 broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
 broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
 broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
-broccoli/cnn.py,sha256=pv8ttV_-CmNRpYO1HINR-Z3WemaK5SBd2iojZ7E2QBA,14680
+broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
-broccoli/transformer.py,sha256=gFBIEowGFPSgQhM1RwsRtQlw_WzVJPY-LJyf1MLtPek,16277
+broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=JC-NWM1Ys7JOrapH9Ka6ED8C4yViJ2Bv3d0SfFgDaZ8,12876
-broccoli_ml-0.1.41.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.1.41.dist-info/METADATA,sha256=dEBaKtK3p19LI1gW7bExrE_xHmUaT1lhp7GoMwI510s,1257
-broccoli_ml-0.1.41.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.1.41.dist-info/RECORD,,
+broccoli/vit.py,sha256=4BHh8ohcVMr_iGVD-FRnyRnKQaaMMjdgs4fixeBm90M,13602
+broccoli_ml-0.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.2.0.dist-info/METADATA,sha256=pvawWlKwj4Ee9e0VWqmu4jdK9fTLuTU82_NP4tCOVaA,1256
+broccoli_ml-0.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.2.0.dist-info/RECORD,,