broccoli-ml 0.1.40__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/vit.py CHANGED
@@ -1,8 +1,8 @@
 import math
 from typing import Optional
 
-from .transformer import TransformerEncoder
-from .cnn import ConvLayer, ConcatPool
+from .transformer import TransformerEncoder, DenoisingAutoEncoder
+from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
 from .activation import ReLU, SquaredReLU, GELU, SwiGLU
 from einops import einsum
 from einops.layers.torch import Rearrange
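The new imports swap the 2-D-only `ConvLayer` and `ConcatPool` for dimension-agnostic helpers. Judging from the inline formula removed further down, `calculate_output_spatial_size` presumably applies the standard convolution arithmetic to each spatial dimension. A minimal sketch of that arithmetic (an illustration only, not the actual helper in `broccoli/cnn.py`, which also handles `padding="same"` and per-dimension tuples):

```python
import math

def output_spatial_size_1d(size, kernel_size, stride=1, padding=0, dilation=1):
    # Standard PyTorch conv/pool output-size arithmetic for one dimension.
    return math.floor(
        (size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    )

# With the defaults in this diff (32-pixel side, "same"-padded conv, then
# pooling with kernel 3, stride 2, padding 1), each side shrinks to 16,
# so sequence_length = 16 * 16 = 256 tokens.
assert output_spatial_size_1d(32, kernel_size=3, stride=2, padding=1) == 16
```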
@@ -61,38 +61,34 @@ class CCTEncoder(nn.Module):
 
     def __init__(
         self,
-        image_size=32,
-        conv_kernel_size=3,
-        conv_pooling_type="maxpool",
-        conv_pooling_kernel_size=3,
-        conv_pooling_kernel_stride=2,
-        conv_pooling_kernel_padding=1,
-        conv_dropout=0.0,
+        input_size=(32, 32),
+        cnn_in_channels=3,
+        cnn_kernel_size=3,
+        cnn_kernel_stride=1,
+        cnn_kernel_padding="same",
+        cnn_kernel_dilation=1,
+        cnn_kernel_groups=1,
+        cnn_activation: nn.Module = nn.ReLU,
+        cnn_activation_kwargs: Optional[dict] = None,
+        cnn_dropout=0.0,
+        pooling_type="maxpool",
+        pooling_kernel_size=3,
+        pooling_kernel_stride=2,
+        pooling_kernel_padding=1,
         transformer_position_embedding="absolute",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_bos_tokens=4,
-        tranformer_share_kv=True,
-        tranformer_max_subtract=True,
-        tranformer_d_model_scale=True,
-        tranformer_log_length_scale=True,
-        tranformer_quiet_attention=True,
-        cnn_activation: nn.Module = nn.ReLU,
-        cnn_activation_kwargs: Optional[dict] = None,
         transformer_activation: nn.Module = nn.GELU,
         transformer_activation_kwargs: Optional[dict] = None,
         mlp_dropout=0.0,
         msa_dropout=0.1,
         stochastic_depth=0.1,
         linear_module=nn.Linear,
-        image_channels=3,
-        batch_norm=False,
+        batch_norm=True,
     ):
-        if conv_pooling_type not in ["maxpool", "concat"]:
-            raise NotImplementedError("Pooling type must be maxpool or concat")
-
         super().__init__()
 
         if cnn_activation_kwargs is not None:
@@ -107,55 +103,122 @@ class CCTEncoder(nn.Module):
         else:
             self.transformer_activation = transformer_activation()
 
-        self.image_size = image_size
+        self.input_size = input_size
+        self.spatial_dimensions = len(self.input_size)
+
+        if self.spatial_dimensions == 1:
+            maxpoolxd = nn.MaxPool1d
+            convxd = nn.Conv1d
+            batchnormxd = nn.BatchNorm1d
+            spatial_dim_names = "D1"
+        elif self.spatial_dimensions == 2:
+            maxpoolxd = nn.MaxPool2d
+            convxd = nn.Conv2d
+            batchnormxd = nn.BatchNorm2d
+            spatial_dim_names = "D1 D2"
+        elif self.spatial_dimensions == 3:
+            maxpoolxd = nn.MaxPool3d
+            convxd = nn.Conv3d
+            batchnormxd = nn.BatchNorm3d
+            spatial_dim_names = "D1 D2 D3"
+        else:
+            raise NotImplementedError(
+                "`input_size` must be a tuple of length 1, 2, or 3."
+            )
+
+        cnn_output_size = calculate_output_spatial_size(
+            input_size,
+            kernel_size=cnn_kernel_size,
+            stride=cnn_kernel_stride,
+            padding=cnn_kernel_padding,
+            dilation=cnn_kernel_dilation,
+        )
+
+        pooling_output_size = (
+            cnn_output_size
+            if pooling_type is None
+            else calculate_output_spatial_size(
+                cnn_output_size,
+                kernel_size=pooling_kernel_size,
+                stride=pooling_kernel_stride,
+                padding=pooling_kernel_padding,
+                dilation=1,
+            )
+        )
 
-        # XXX: We assume a square image here
-        output_size = math.floor(
-            (image_size + 2 * conv_pooling_kernel_padding - conv_pooling_kernel_size)
-            / conv_pooling_kernel_stride
-            + 1
-        )  # output of pooling
+        self.sequence_length = math.prod(pooling_output_size)  # One token per voxel
 
-        self.sequence_length = output_size**2
+        pooling_kernel_voxels = math.prod(
+            spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
+        )
 
-        if conv_pooling_type == "maxpool":
-            conv_out_channels = transformer_embedding_size
-        elif conv_pooling_type == "concat":
-            conv_out_channels = int(
-                math.floor(transformer_embedding_size / (conv_pooling_kernel_size**2))
+        if pooling_type in ["maxpool", None]:
+            cnn_out_channels = transformer_embedding_size
+        elif pooling_type == "concat":
+            cnn_out_channels = math.floor(
+                transformer_embedding_size / pooling_kernel_voxels
             )
+        else:
+            raise NotImplementedError("Pooling type must be maxpool, concat or None")
 
-        # This if block rhymes:
+        cnn_activation_out_channels = cnn_out_channels
+
+        # This block rhymes:
         if cnn_activation.__name__.endswith("GLU"):
-            conv_out_channels *= 2
-
-        self.conv = ConvLayer(
-            image_channels,
-            conv_out_channels,
-            kernel_size=conv_kernel_size,
-            stride=1,
-            padding="same",
-            linear_module=linear_module,
+            cnn_out_channels *= 2
+
+        self.cnn = convxd(
+            cnn_in_channels,
+            cnn_out_channels,
+            cnn_kernel_size,
+            stride=cnn_kernel_stride,
+            padding=cnn_kernel_padding,
+            dilation=cnn_kernel_dilation,
+            groups=cnn_kernel_groups,
+            bias=True,
+            padding_mode="zeros",
+        )
+
+        self.activate_and_dropout = nn.Sequential(
+            *[
+                Rearrange(  # rearrange in case we're using XGLU activation
+                    f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
+                ),
+                self.cnn_activation,
+                Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
+                nn.Dropout(cnn_dropout),
+                (
+                    batchnormxd(cnn_activation_out_channels)
+                    if batch_norm
+                    else nn.Identity()
+                ),
+            ]
         )
 
-        if conv_pooling_type == "maxpool":
+        if pooling_type is None:
             self.pool = nn.Sequential(
                 *[
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N C H W -> N H W C"
-                    ),
-                    self.cnn_activation,
-                    Rearrange("N H W C -> N C H W"),
-                    nn.MaxPool2d(
-                        conv_pooling_kernel_size,
-                        stride=conv_pooling_kernel_stride,
-                        padding=conv_pooling_kernel_padding,
+                    Rearrange(
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
+                    ),  # for transformer
+                ]
+            )
+
+        elif pooling_type == "maxpool":
+            self.pool = nn.Sequential(
+                *[
+                    maxpoolxd(
+                        pooling_kernel_size,
+                        stride=pooling_kernel_stride,
+                        padding=pooling_kernel_padding,
                     ),
-                    Rearrange("N C H W -> N (H W) C"),  # for transformer
+                    Rearrange(
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
+                    ),  # for transformer
                 ]
             )
 
-        elif conv_pooling_type == "concat":
+        elif pooling_type == "concat":
 
             if transformer_activation_kwargs is not None:
                 self.concatpool_activation = transformer_activation(
@@ -164,42 +227,30 @@ class CCTEncoder(nn.Module):
             else:
                 self.concatpool_activation = transformer_activation()
 
-            concatpool_out_channels = conv_pooling_kernel_size**2 * conv_out_channels
-
-            if cnn_activation.__name__.endswith("GLU"):
-                cnn_activation_output_channels = concatpool_out_channels / 2
-            else:
-                cnn_activation_output_channels = concatpool_out_channels
+            concatpool_out_channels = (
+                pooling_kernel_voxels * cnn_activation_out_channels
+            )
 
             self.pool = nn.Sequential(
                 *[
-                    ConcatPool(
-                        conv_pooling_kernel_size,
-                        stride=conv_pooling_kernel_stride,
-                        padding=conv_pooling_kernel_padding,
+                    SpaceToDepth(
+                        pooling_kernel_size,
+                        stride=pooling_kernel_stride,
+                        padding=pooling_kernel_padding,
+                        spatial_dimensions=self.spatial_dimensions,
                     ),
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N C H W -> N H W C"
+                    Rearrange(  # for transformer
+                        f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
                     ),
-                    self.cnn_activation,
-                    nn.Dropout(conv_dropout),
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N H W C -> N C H W"
+                    DenoisingAutoEncoder(
+                        concatpool_out_channels,
+                        transformer_mlp_ratio,
+                        transformer_embedding_size,
+                        activation=transformer_activation,
+                        activation_kwargs=transformer_activation_kwargs,
+                        dropout=0.0,
+                        linear_module=linear_module,
                     ),
-                    nn.BatchNorm2d(cnn_activation_output_channels),
-                    Rearrange(  # rearrange in case we're using XGLU activation
-                        "N C H W -> N (H W) C"
-                    ),
-                    nn.Linear(
-                        cnn_activation_output_channels,
-                        (
-                            2 * transformer_embedding_size * transformer_mlp_ratio
-                            if transformer_activation.__name__.endswith("GLU")
-                            else transformer_embedding_size * transformer_mlp_ratio
-                        ),
-                    ),
-                    self.concatpool_activation,
-                    nn.Linear(transformer_embedding_size * transformer_mlp_ratio),
                 ]
             )
 
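In the `concat` branch, `SpaceToDepth` replaces `ConcatPool`, and the hand-rolled BatchNorm/Rearrange/Linear stack is collapsed into a single `DenoisingAutoEncoder` projection down to the transformer width. For intuition, space-to-depth folds each pooling window into the channel axis. A generic sketch of the idea (not broccoli's `SpaceToDepth`, which also takes `stride`, `padding`, and `spatial_dimensions`):

```python
import torch
from einops import rearrange

x = torch.randn(1, 8, 32, 32)  # N C H W
# Fold non-overlapping 2x2 windows into channels: 8 -> 32 channels, 32x32 -> 16x16.
y = rearrange(x, "n c (h k1) (w k2) -> n (c k1 k2) h w", k1=2, k2=2)
assert y.shape == (1, 32, 16, 16)
```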
@@ -210,7 +261,7 @@ class CCTEncoder(nn.Module):
             transformer_layers,
             transformer_heads,
             position_embedding_type=transformer_position_embedding,
-            source_size=(output_size, output_size),
+            source_size=pooling_output_size,
             mlp_ratio=transformer_mlp_ratio,
             activation=transformer_activation,
             activation_kwargs=transformer_activation_kwargs,
@@ -218,11 +269,6 @@ class CCTEncoder(nn.Module):
             msa_dropout=msa_dropout,
             stochastic_depth=stochastic_depth,
             causal=False,
-            share_kv=tranformer_share_kv,
-            max_subtract=tranformer_max_subtract,
-            d_model_scale=tranformer_d_model_scale,
-            log_length_scale=tranformer_log_length_scale,
-            quiet_attention=tranformer_quiet_attention,
             linear_module=linear_module,
             bos_tokens=transformer_bos_tokens,
         )
@@ -231,8 +277,9 @@ class CCTEncoder(nn.Module):
 
         self.encoder = nn.Sequential(
             *[
-                nn.BatchNorm2d(image_channels) if batch_norm else nn.Identity(),
-                self.conv,
+                batchnormxd(cnn_in_channels) if batch_norm else nn.Identity(),
+                self.cnn,
+                self.activate_and_dropout,
                 self.pool,
                 self.transformer,
             ]
@@ -252,8 +299,16 @@ class CCT(nn.Module):
 
     def __init__(
         self,
-        image_size=32,
-        conv_kernel_size=3,  # Only 2 is supported for eigenvector initialisation
+        input_size=(32, 32),
+        cnn_in_channels=3,
+        cnn_kernel_size=3,
+        cnn_kernel_stride=1,
+        cnn_kernel_padding="same",
+        cnn_kernel_dilation=1,
+        cnn_kernel_groups=1,
+        cnn_activation: nn.Module = nn.ReLU,
+        cnn_activation_kwargs: Optional[dict] = None,
+        cnn_dropout=0.0,
         pooling_type="maxpool",
         pooling_kernel_size=3,
         pooling_kernel_stride=2,
@@ -264,22 +319,14 @@ class CCT(nn.Module):
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_bos_tokens=4,
-        tranformer_share_kv=True,
-        tranformer_max_subtract=True,
-        tranformer_d_model_scale=True,
-        tranformer_log_length_scale=True,
-        tranformer_quiet_attention=True,
-        cnn_activation: nn.Module = nn.ReLU,
-        cnn_activation_kwargs: Optional[dict] = None,
         transformer_activation: nn.Module = nn.GELU,
         transformer_activation_kwargs: Optional[dict] = None,
-        mlp_dropout=0.0,  # The original paper got best performance from mlp_dropout=0.
-        msa_dropout=0.1,  # "" msa_dropout=0.1
-        stochastic_depth=0.1,  # "" stochastic_depth=0.1
-        image_classes=100,
+        mlp_dropout=0.0,
+        msa_dropout=0.1,
+        stochastic_depth=0.1,
         linear_module=nn.Linear,
-        image_channels=3,
-        batch_norm=False,
+        batch_norm=True,
+        image_classes=100,
     ):
 
         super().__init__()
@@ -301,32 +348,32 @@ class CCT(nn.Module):
         }[transformer_activation]
 
         self.encoder = CCTEncoder(
-            image_size=image_size,
-            conv_kernel_size=conv_kernel_size,
-            conv_pooling_type=pooling_type,
-            conv_pooling_kernel_size=pooling_kernel_size,
-            conv_pooling_kernel_stride=pooling_kernel_stride,
-            conv_pooling_kernel_padding=pooling_kernel_padding,
+            input_size=input_size,
+            cnn_in_channels=cnn_in_channels,
+            cnn_kernel_size=cnn_kernel_size,
+            cnn_kernel_stride=cnn_kernel_stride,
+            cnn_kernel_padding=cnn_kernel_padding,
+            cnn_kernel_dilation=cnn_kernel_dilation,
+            cnn_kernel_groups=cnn_kernel_groups,
+            cnn_activation=cnn_activation,
+            cnn_activation_kwargs=cnn_activation_kwargs,
+            cnn_dropout=cnn_dropout,
+            pooling_type=pooling_type,
+            pooling_kernel_size=pooling_kernel_size,
+            pooling_kernel_stride=pooling_kernel_stride,
+            pooling_kernel_padding=pooling_kernel_padding,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_bos_tokens=transformer_bos_tokens,
-            tranformer_share_kv=tranformer_share_kv,
-            tranformer_max_subtract=tranformer_max_subtract,
-            tranformer_d_model_scale=tranformer_d_model_scale,
-            tranformer_log_length_scale=tranformer_log_length_scale,
-            tranformer_quiet_attention=tranformer_quiet_attention,
-            cnn_activation=cnn_activation,
-            cnn_activation_kwargs=cnn_activation_kwargs,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             mlp_dropout=mlp_dropout,
             msa_dropout=msa_dropout,
             stochastic_depth=stochastic_depth,
             linear_module=linear_module,
-            image_channels=image_channels,
             batch_norm=batch_norm,
         )
         self.pool = SequencePool(
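Taken together, the changes rename the public constructor surface of `CCT` (and `CCTEncoder`): `image_size`/`image_channels` become `input_size` (a 1-, 2-, or 3-tuple) and `cnn_in_channels`, the `conv_*` arguments become `cnn_*`/`pooling_*`, the misspelled `tranformer_*` attention flags are gone, `batch_norm` now defaults to `True`, and pooling can be disabled with `pooling_type=None`. A hedged migration sketch, based only on the signatures visible in this diff:

```python
from broccoli.vit import CCT

# broccoli-ml 0.1.40:
# model = CCT(image_size=32, image_channels=3, conv_kernel_size=3,
#             pooling_type="maxpool", image_classes=100)

# broccoli-ml 0.2.0:
model = CCT(
    input_size=(32, 32),     # was image_size=32 (square images assumed)
    cnn_in_channels=3,       # was image_channels=3
    cnn_kernel_size=3,       # was conv_kernel_size=3
    pooling_type="maxpool",  # "maxpool", "concat", or None
    image_classes=100,
    batch_norm=True,         # new default; was False
)
```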
{broccoli_ml-0.1.40.dist-info → broccoli_ml-0.2.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.1.40
+Version: 0.2.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
{broccoli_ml-0.1.40.dist-info → broccoli_ml-0.2.0.dist-info}/RECORD CHANGED
@@ -3,15 +3,15 @@ broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
 broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
 broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
 broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
-broccoli/cnn.py,sha256=pv8ttV_-CmNRpYO1HINR-Z3WemaK5SBd2iojZ7E2QBA,14680
+broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
 broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=0XYCi3ckTEKwAgBOMUSJP2HsnrroOH8eyrhRdpANG2w,1298
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
-broccoli/transformer.py,sha256=gFBIEowGFPSgQhM1RwsRtQlw_WzVJPY-LJyf1MLtPek,16277
+broccoli/transformer.py,sha256=23R58t3TLZMb9ulhCtQ3gXu0mPlfyPvLM8TaGOpaz58,16310
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=rB1hfEwqRDfSFnFgXzZtzIeqFLnDfiyWe6hDJ7OcH8Q,12777
-broccoli_ml-0.1.40.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.1.40.dist-info/METADATA,sha256=eTdLGu8jKvslYTs2_1qQd-GdV5vSOLJUuYDFHD3CEk8,1257
-broccoli_ml-0.1.40.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.1.40.dist-info/RECORD,,
+broccoli/vit.py,sha256=4BHh8ohcVMr_iGVD-FRnyRnKQaaMMjdgs4fixeBm90M,13602
+broccoli_ml-0.2.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.2.0.dist-info/METADATA,sha256=pvawWlKwj4Ee9e0VWqmu4jdK9fTLuTU82_NP4tCOVaA,1256
+broccoli_ml-0.2.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.2.0.dist-info/RECORD,,