broccoli-ml 0.33.2__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
broccoli/transformer.py CHANGED
@@ -235,25 +235,26 @@ class FeedforwardBlock(nn.Module):
         activation=nn.ReLU,
         activation_kwargs=None,
         dropout=0.0,
-        linear_module=nn.Linear,
+        linear_module_up=nn.Linear,
+        linear_module_down=nn.Linear,
         pre_norm=True,
         normformer=False,
-        raw_input=False,
+        post_norm=True,
+        residual_path=True,
     ):
         super().__init__()
 
+        self.residual_path = residual_path
+        self.post_norm = post_norm
+
+        if self.post_norm:
+            self.layernorm = nn.LayerNorm(output_features)
+
         if activation_kwargs is not None:
             self.activation = activation(**activation_kwargs)
         else:
             self.activation = activation()
 
-        if raw_input:
-            self.memory_type = AnchoredLinear
-            self.memory_bias = False
-        else:
-            self.memory_type = nn.Linear
-            self.memory_bias = True
-
         self.dropout = nn.Dropout(dropout)
 
         self.max_features = (
@@ -265,18 +266,21 @@ class FeedforwardBlock(nn.Module):
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features) if pre_norm else nn.Identity(),
-                linear_module(input_features, self.max_features),
+                linear_module_up(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                self.memory_type(
-                    ratio * output_features, output_features, bias=self.memory_bias
-                ),
+                linear_module_down(ratio * output_features, output_features),
                 self.dropout,
             ]
         )
 
     def forward(self, x):
-        return self.process(x)
+        if self.residual_path and self.post_norm:
+            return self.layernorm(x + self.process(x))
+        elif self.residual_path:
+            return x + self.process(x)
+        else:
+            return x
 
 
 class TransformerBlock(nn.Module):
@@ -305,11 +309,14 @@ class TransformerBlock(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         super().__init__()
 
         self.pre_norm = pre_norm
+        self.post_norm = post_norm
+        self.normformer = normformer
 
         self.identity_probability = identity_probability
 
@@ -340,7 +347,7 @@ class TransformerBlock(nn.Module):
             bos_tokens=bos_tokens,
         )
 
-        # Submodules for the feedforward process
+        # Submodule for the feedforward process
         self.ff = FeedforwardBlock(
             d_model,
             mlp_ratio,
@@ -348,9 +355,12 @@ class TransformerBlock(nn.Module):
             activation=activation,
             activation_kwargs=activation_kwargs,
             dropout=mlp_dropout,
-            linear_module=linear_module,
+            linear_module_up=linear_module,
+            linear_module_down=linear_module,
             pre_norm=pre_norm,
             normformer=normformer,
+            post_norm=post_norm,
+            residual_path=True,
         )
 
     @property
@@ -371,19 +381,18 @@ class TransformerBlock(nn.Module):
         identity_x = shuffled[:identity_count, :, :]
         process_x = shuffled[identity_count:, :, :]
 
+        residual_x = process_x
+
         if self.pre_norm:
-            norm_process_x = self.layer_norm_1(process_x)
-            process_x = process_x + self.attn(
-                norm_process_x, norm_process_x, norm_process_x
-            )
-            process_x = process_x + self.ff(process_x)
-        else:  # post-norm
-            process_x = process_x + self.attn(process_x, process_x, process_x)
             process_x = self.layer_norm_1(process_x)
-            process_x = process_x + self.ff(process_x)
+
+        process_x = residual_x + self.attn(process_x, process_x, process_x)
+
+        if self.post_norm:
             process_x = self.layer_norm_2(process_x)
 
-        # Always post norm as eventually we reach the classification head!
+        process_x = self.ff(process_x)
+
         x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
 
         return x
@@ -414,6 +423,7 @@ class TransformerEncoder(nn.Module):
         bos_tokens=0,
         return_bos_tokens=False,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         if position_embedding_type == "relative":
@@ -474,6 +484,7 @@
                 causal=causal,
                 linear_module=linear_module,
                 pre_norm=pre_norm,
+                post_norm=post_norm,
                 normformer=normformer,
             )
             for i in range(n_layers)
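Taken together, the transformer.py changes replace the old raw_input/AnchoredLinear memory path with two independently configurable projections (linear_module_up, linear_module_down) and move the residual add and optional post-normalisation into FeedforwardBlock itself via the new residual_path and post_norm flags. The sketch below is a condensed paraphrase of that new wiring for illustration only, not the package's exact class: the names FFSketch, d_in and d_out are invented here, the activation is fixed to ReLU, and the up projection uses ratio * d_out where the real block uses its self.max_features computation.

import torch
import torch.nn as nn


class FFSketch(nn.Module):
    """Condensed paraphrase of the 0.35.0 FeedforwardBlock wiring."""

    def __init__(self, d_in, d_out, ratio=4, dropout=0.0,
                 linear_module_up=nn.Linear, linear_module_down=nn.Linear,
                 pre_norm=True, normformer=False,
                 post_norm=True, residual_path=True):
        super().__init__()
        self.residual_path = residual_path
        self.post_norm = post_norm
        if post_norm:
            self.layernorm = nn.LayerNorm(d_out)
        self.process = nn.Sequential(
            nn.LayerNorm(d_in) if pre_norm else nn.Identity(),
            linear_module_up(d_in, ratio * d_out),     # up projection
            nn.ReLU(),
            nn.LayerNorm(ratio * d_out) if normformer else nn.Identity(),
            linear_module_down(ratio * d_out, d_out),  # down projection
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # New in 0.35.0: the block owns its residual add and post-norm.
        if self.residual_path and self.post_norm:
            return self.layernorm(x + self.process(x))
        elif self.residual_path:
            return x + self.process(x)
        else:
            return x  # residual_path=False returns the input unchanged


tokens = torch.randn(8, 16, 256)         # (batch, sequence, features)
print(FFSketch(256, 256)(tokens).shape)  # torch.Size([8, 16, 256])

Note that TransformerBlock now always passes residual_path=True, so the pass-through branch is only reachable via the standalone initial_ff constructed in vit.py below.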
broccoli/vit.py CHANGED
@@ -53,9 +53,7 @@ class ClassificationHead(nn.Module):
     A general classification head for a ViT
     """
 
-    def __init__(
-        self, d_model, linear_module, n_classes, layer_norm=True, batch_norm=True
-    ):
+    def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
         super().__init__()
         self.d_model = d_model
         self.summarize = GetCLSToken()
@@ -67,7 +65,6 @@ class ClassificationHead(nn.Module):
 
         self.classification_process = nn.Sequential(
             *[
-                nn.LayerNorm(d_model) if layer_norm else nn.Identity(),
                 self.summarize,
                 self.projection,
                 self.batch_norm,
@@ -120,8 +117,12 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module_up=None,
+        transformer_initial_ff_linear_module_down=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -296,6 +297,7 @@ class ViTEncoder(nn.Module):
                 return_bos_tokens=transformer_return_bos_tokens,
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
+                post_norm=transformer_post_norm,
             )
         else:
             self.transformer = nn.Identity()
@@ -308,10 +310,20 @@ class ViTEncoder(nn.Module):
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
-                linear_module=linear_module,
+                linear_module_up=(
+                    transformer_initial_ff_linear_module_up
+                    if transformer_initial_ff_linear_module_up is not None
+                    else linear_module
+                ),
+                linear_module_down=(
+                    transformer_initial_ff_linear_module_down
+                    if transformer_initial_ff_linear_module_down is not None
+                    else linear_module
+                ),
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                raw_input=not cnn,
+                post_norm=transformer_post_norm,
+                residual_path=transformer_initial_ff_residual_path,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -365,8 +377,12 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module_up=None,
+        transformer_initial_ff_linear_module_down=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -421,8 +437,12 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
+            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
+            transformer_initial_ff_linear_module_up=transformer_initial_ff_linear_module_up,
+            transformer_initial_ff_linear_module_down=transformer_initial_ff_linear_module_down,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
+            transformer_post_norm=transformer_post_norm,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
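On the vit.py side, ClassificationHead drops its layer_norm option and the preceding nn.LayerNorm stage, while ViT and ViTEncoder grow pass-through knobs for the initial feedforward block. The snippet below simply collects the new keyword surface, with defaults copied from the signatures above; everything else a real ViT construction needs lies outside this diff, and the dict name is invented here.

# Keyword arguments added or re-routed in broccoli-ml 0.35.0; defaults
# copied from the ViT signature above. Leaving the up/down overrides at
# None falls back to the model-wide linear_module, per the ViTEncoder
# diff (a concrete override would be e.g. torch.nn.Linear).
new_vit_kwargs = dict(
    transformer_initial_ff_residual_path=True,
    transformer_initial_ff_linear_module_up=None,
    transformer_initial_ff_linear_module_down=None,
    transformer_post_norm=False,
)

# Hypothetical usage, assuming an existing ViT construction call:
#     model = ViT(..., **new_vit_kwargs)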
broccoli_ml-0.33.2.dist-info/METADATA → broccoli_ml-0.35.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.33.2
+Version: 0.35.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
broccoli_ml-0.33.2.dist-info/RECORD → broccoli_ml-0.35.0.dist-info/RECORD CHANGED
@@ -8,10 +8,10 @@ broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
 broccoli/linear.py,sha256=8Y9vD85ZEgNZsIQgO3uRQ3lOQR-JjwvabY8liCrfNCk,4831
 broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
 broccoli/tensor.py,sha256=ks2TRCdS10k2XvxEieh2sj_LzjTNRuiO6gekKFTtziI,4533
-broccoli/transformer.py,sha256=xXc9dnceGPCOaloITvspNxrkusdSCE-nRn5Xx7-L_XM,17061
+broccoli/transformer.py,sha256=t0gsADJC9UOlwjm7tDKdy0pAZ8l3clTcCnes86zvH-k,17203
 broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
-broccoli/vit.py,sha256=7lvP2Sak7N5xkKzgzBbwxIrKK2qMwj7GbBIYuoJNxIU,16214
-broccoli_ml-0.33.2.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
-broccoli_ml-0.33.2.dist-info/METADATA,sha256=1jnADGHLdddLy4GJnHxjX6noNfssrNuQIiZcJQ68ggA,1257
-broccoli_ml-0.33.2.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-broccoli_ml-0.33.2.dist-info/RECORD,,
+broccoli/vit.py,sha256=c-ZRHiLDOoQDJO9OJ51zD9HqaluG33flIwTXQQfms-g,17389
+broccoli_ml-0.35.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
+broccoli_ml-0.35.0.dist-info/METADATA,sha256=v0JSpcubSGwxA5dFPbDwz2r2oGZWSeqYND1Mu8WOiJY,1257
+broccoli_ml-0.35.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+broccoli_ml-0.35.0.dist-info/RECORD,,