broccoli-ml 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/vit.py CHANGED
@@ -66,8 +66,9 @@ class ViTEncoder(nn.Module):
66
66
  def __init__(
67
67
  self,
68
68
  input_size=(32, 32),
69
+ cnn=True,
69
70
  cnn_in_channels=3,
70
- minimum_cnn_out_channels=16,
71
+ cnn_out_channels=16,
71
72
  cnn_kernel_size=3,
72
73
  cnn_kernel_stride=1,
73
74
  cnn_padding="same",
@@ -135,12 +136,49 @@ class ViTEncoder(nn.Module):
135
136
  "`input_size` must be a tuple of length 1, 2, or 3."
136
137
  )
137
138
 
138
- cnn_output_size = calculate_output_spatial_size(
139
- input_size,
140
- kernel_size=cnn_kernel_size,
141
- stride=cnn_kernel_stride,
142
- padding=cnn_padding,
143
- dilation=cnn_kernel_dilation,
139
+ if cnn:
140
+ cnn_output_size = calculate_output_spatial_size(
141
+ input_size,
142
+ kernel_size=cnn_kernel_size,
143
+ stride=cnn_kernel_stride,
144
+ padding=cnn_padding,
145
+ dilation=cnn_kernel_dilation,
146
+ )
147
+ self.cnn = convxd(
148
+ cnn_in_channels,
149
+ cnn_out_channels,
150
+ cnn_kernel_size,
151
+ stride=cnn_kernel_stride,
152
+ padding=cnn_padding,
153
+ dilation=cnn_kernel_dilation,
154
+ groups=cnn_kernel_groups,
155
+ bias=True,
156
+ padding_mode="zeros",
157
+ )
158
+ cnn_activation_out_channels = cnn_out_channels
159
+ self.activate_and_dropout = nn.Sequential(
160
+ *[
161
+ Rearrange( # rearrange in case we're using XGLU activation
162
+ f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
163
+ ),
164
+ self.cnn_activation,
165
+ Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
166
+ nn.Dropout(cnn_dropout),
167
+ batchnormxd(cnn_activation_out_channels),
168
+ ]
169
+ )
170
+ # This block rhymes:
171
+ if cnn and cnn_activation.__name__.endswith("GLU"):
172
+ cnn_out_channels *= 2
173
+ else:
174
+ self.cnn = nn.Identity()
175
+ self.activate_and_dropout = nn.Identity()
176
+ cnn_output_size = input_size
177
+ cnn_out_channels = cnn_in_channels
178
+ cnn_activation_out_channels = cnn_in_channels
179
+
180
+ pooling_kernel_voxels = math.prod(
181
+ spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
144
182
  )
145
183
 
146
184
  pooling_output_size = (
@@ -155,59 +193,8 @@ class ViTEncoder(nn.Module):
155
193
  )
156
194
  )
157
195
 
158
- self.sequence_length = math.prod(pooling_output_size) # One token per voxel
159
-
160
- pooling_kernel_voxels = math.prod(
161
- spatial_tuple(pooling_kernel_size, self.spatial_dimensions)
162
- )
163
-
164
- if pooling_type in ["max", "average", None]:
165
- cnn_out_channels = transformer_embedding_size
166
- elif pooling_type == "concat":
167
- cnn_out_channels = max(
168
- math.floor(transformer_embedding_size / pooling_kernel_voxels),
169
- minimum_cnn_out_channels,
170
- )
171
- else:
172
- raise NotImplementedError(
173
- "Pooling type must be max, average, concat or None"
174
- )
175
-
176
- cnn_activation_out_channels = cnn_out_channels
177
-
178
- # This block rhymes:
179
- if cnn_activation.__name__.endswith("GLU"):
180
- cnn_out_channels *= 2
181
-
182
- self.cnn = convxd(
183
- cnn_in_channels,
184
- cnn_out_channels,
185
- cnn_kernel_size,
186
- stride=cnn_kernel_stride,
187
- padding=cnn_padding,
188
- dilation=cnn_kernel_dilation,
189
- groups=cnn_kernel_groups,
190
- bias=True,
191
- padding_mode="zeros",
192
- )
193
-
194
- self.activate_and_dropout = nn.Sequential(
195
- *[
196
- Rearrange( # rearrange in case we're using XGLU activation
197
- f"N C {spatial_dim_names} -> N {spatial_dim_names} C"
198
- ),
199
- self.cnn_activation,
200
- Rearrange(f"N {spatial_dim_names} C -> N C {spatial_dim_names}"),
201
- nn.Dropout(cnn_dropout),
202
- (
203
- batchnormxd(cnn_activation_out_channels)
204
- if initial_batch_norm
205
- else nn.Identity()
206
- ),
207
- ]
208
- )
209
-
210
196
  if pooling_type is None:
197
+ pooling_out_channels = cnn_activation_out_channels
211
198
  self.pool = nn.Sequential(
212
199
  *[
213
200
  Rearrange(
@@ -215,70 +202,36 @@ class ViTEncoder(nn.Module):
215
202
  ), # for transformer
216
203
  ]
217
204
  )
218
- pooling_out_channels = transformer_embedding_size
219
205
 
220
206
  elif pooling_type == "max":
221
- self.pool = nn.Sequential(
222
- *[
223
- maxpoolxd(
224
- pooling_kernel_size,
225
- stride=pooling_kernel_stride,
226
- padding=pooling_padding,
227
- ),
228
- Rearrange(
229
- f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
230
- ), # for transformer
231
- ]
207
+ pooling_out_channels = cnn_activation_out_channels
208
+ self.pool = maxpoolxd(
209
+ pooling_kernel_size,
210
+ stride=pooling_kernel_stride,
211
+ padding=pooling_padding,
232
212
  )
233
- pooling_out_channels = transformer_embedding_size
234
-
235
213
  elif pooling_type == "average":
236
- self.pool = nn.Sequential(
237
- *[
238
- avgpoolxd(
239
- pooling_kernel_size,
240
- stride=pooling_kernel_stride,
241
- padding=pooling_padding,
242
- ),
243
- Rearrange(
244
- f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
245
- ), # for transformer
246
- ]
214
+ pooling_out_channels = cnn_activation_out_channels
215
+ self.pool = avgpoolxd(
216
+ pooling_kernel_size,
217
+ stride=pooling_kernel_stride,
218
+ padding=pooling_padding,
247
219
  )
248
- pooling_out_channels = transformer_embedding_size
249
-
250
220
  elif pooling_type == "concat":
251
-
252
- if transformer_activation_kwargs is not None:
253
- self.concatpool_activation = transformer_activation(
254
- **transformer_activation_kwargs
255
- )
256
- else:
257
- self.concatpool_activation = transformer_activation()
258
-
259
221
  pooling_out_channels = pooling_kernel_voxels * cnn_activation_out_channels
260
-
261
- self.pool = nn.Sequential(
262
- *[
263
- SpaceToDepth(
264
- pooling_kernel_size,
265
- stride=pooling_kernel_stride,
266
- padding=pooling_padding,
267
- spatial_dimensions=self.spatial_dimensions,
268
- ),
269
- Rearrange( # for transformer
270
- f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
271
- ),
272
- (
273
- PadTensor(
274
- (0, transformer_embedding_size - pooling_out_channels)
275
- )
276
- if not intermediate_feedforward_layer
277
- else nn.Identity()
278
- ),
279
- ]
222
+ self.pool = SpaceToDepth(
223
+ pooling_kernel_size,
224
+ stride=pooling_kernel_stride,
225
+ padding=pooling_padding,
226
+ spatial_dimensions=self.spatial_dimensions,
227
+ )
228
+ else:
229
+ raise NotImplementedError(
230
+ "Pooling type must be max, average, concat or None"
280
231
  )
281
232
 
233
+ self.sequence_length = math.prod(pooling_output_size) # One token per voxel
234
+
282
235
  if transformer_layers > 0:
283
236
  self.transformer = TransformerEncoder(
284
237
  self.sequence_length,
@@ -300,25 +253,43 @@ class ViTEncoder(nn.Module):
300
253
  else:
301
254
  self.transformer = nn.Identity()
302
255
 
256
+ if intermediate_feedforward_layer:
257
+ self.pooling_channels_padding = nn.Identity()
258
+ self.intermediate_feedforward_layer = FeedforwardLayer(
259
+ pooling_out_channels,
260
+ transformer_mlp_ratio,
261
+ transformer_embedding_size,
262
+ activation=transformer_activation,
263
+ activation_kwargs=transformer_activation_kwargs,
264
+ dropout=transformer_mlp_dropout,
265
+ linear_module=linear_module,
266
+ )
267
+ elif pooling_out_channels < transformer_embedding_size:
268
+ self.intermediate_feedforward_layer = nn.Identity()
269
+ self.pooling_channels_padding = PadTensor(
270
+ (0, transformer_embedding_size - pooling_out_channels)
271
+ )
272
+ else:
273
+ raise NotImplementedError(
274
+ "In a situation where the choice/parameters of the pooling and the"
275
+ + " `cnn_out_channels` (or the number of `input_channels` if"
276
+ + " `cnn`=False) means that the pooling will result"
277
+ + " in more channels per pixel/voxel than the size of the"
278
+ + " intended transformer embedding,"
279
+ + " `intermediate_feedforward_layer` must be set to True"
280
+ )
281
+
303
282
  self.encoder = nn.Sequential(
304
283
  *[
305
284
  batchnormxd(cnn_in_channels) if initial_batch_norm else nn.Identity(),
306
285
  self.cnn,
307
286
  self.activate_and_dropout,
308
287
  self.pool,
309
- (
310
- FeedforwardLayer(
311
- pooling_out_channels,
312
- transformer_mlp_ratio,
313
- transformer_embedding_size,
314
- activation=transformer_activation,
315
- activation_kwargs=transformer_activation_kwargs,
316
- dropout=transformer_mlp_dropout,
317
- linear_module=linear_module,
318
- )
319
- if intermediate_feedforward_layer
320
- else nn.Identity()
288
+ Rearrange( # for transformer
289
+ f"N C {spatial_dim_names} -> N ({spatial_dim_names}) C"
321
290
  ),
291
+ self.pooling_channels_padding,
292
+ self.intermediate_feedforward_layer,
322
293
  self.transformer,
323
294
  ]
324
295
  )
@@ -339,8 +310,9 @@ class CCT(nn.Module):
339
310
  def __init__(
340
311
  self,
341
312
  input_size=(32, 32),
313
+ cnn=True,
342
314
  cnn_in_channels=3,
343
- minimum_cnn_out_channels=16,
315
+ cnn_out_channels=16,
344
316
  cnn_kernel_size=3,
345
317
  cnn_kernel_stride=1,
346
318
  cnn_padding="same",
@@ -391,8 +363,9 @@ class CCT(nn.Module):
391
363
 
392
364
  self.encoder = ViTEncoder(
393
365
  input_size=input_size,
366
+ cnn=cnn,
394
367
  cnn_in_channels=cnn_in_channels,
395
- minimum_cnn_out_channels=minimum_cnn_out_channels,
368
+ cnn_out_channels=cnn_out_channels,
396
369
  cnn_kernel_size=cnn_kernel_size,
397
370
  cnn_kernel_stride=cnn_kernel_stride,
398
371
  cnn_padding=cnn_padding,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.7.0
3
+ Version: 0.8.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -10,8 +10,8 @@ broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
11
  broccoli/transformer.py,sha256=RSZpbHs_K4ts5os6lWxcGDI3p0zreRwQNnk6mV8HJnk,15930
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=_oL0NRUJakyIke2g8WK5eWaiEh06gAhI67l6Wl7k1oM,15659
14
- broccoli_ml-0.7.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.7.0.dist-info/METADATA,sha256=1QUwYpIruYYiYcMHSgj5lCf-i-FaiipD_5KSAJZeb2s,1256
16
- broccoli_ml-0.7.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.7.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=uXqMIvAVY4PuA-Fv1YxU7L3_74fR19GtNu1caQeMr6k,15185
14
+ broccoli_ml-0.8.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.8.1.dist-info/METADATA,sha256=ofw4KzrRqP9e5i5OHlu-zxSpMO2DaohGA_N3RfAmt7s,1256
16
+ broccoli_ml-0.8.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.8.1.dist-info/RECORD,,