broccoli-ml 0.33.2__tar.gz → 0.34.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.33.2
+Version: 0.34.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -238,22 +238,22 @@ class FeedforwardBlock(nn.Module):
         linear_module=nn.Linear,
         pre_norm=True,
         normformer=False,
-        raw_input=False,
+        post_norm=True,
+        residual_path=True,
     ):
         super().__init__()
 
+        self.residual_path = residual_path
+        self.post_norm = post_norm
+
+        if self.post_norm:
+            self.layernorm = nn.LayerNorm(output_features)
+
         if activation_kwargs is not None:
             self.activation = activation(**activation_kwargs)
         else:
             self.activation = activation()
 
-        if raw_input:
-            self.memory_type = AnchoredLinear
-            self.memory_bias = False
-        else:
-            self.memory_type = nn.Linear
-            self.memory_bias = True
-
         self.dropout = nn.Dropout(dropout)
 
         self.max_features = (
@@ -268,15 +268,18 @@ class FeedforwardBlock(nn.Module):
                 linear_module(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(ratio * output_features) if normformer else nn.Identity(),
-                self.memory_type(
-                    ratio * output_features, output_features, bias=self.memory_bias
-                ),
+                linear_module(ratio * output_features, output_features),
                 self.dropout,
             ]
         )
 
     def forward(self, x):
-        return self.process(x)
+        if self.residual_path and self.post_norm:
+            return self.layernorm(x + self.process(x))
+        elif self.residual_path:
+            return x + self.process(x)
+        else:
+            return x
 
 
 class TransformerBlock(nn.Module):
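
Taken together, the two hunks above move the residual connection and the optional trailing LayerNorm into FeedforwardBlock itself: the old raw_input / AnchoredLinear / memory_bias machinery is gone, and forward() now chooses between a post-normed residual path, a plain residual path, or passing the input through unchanged. A minimal, self-contained sketch of that forward behaviour in plain PyTorch (hypothetical class name and layer sizes; only the residual_path / post_norm logic is taken from the hunks above):

    import torch
    import torch.nn as nn

    class ResidualFFSketch(nn.Module):
        # Illustrative stand-in for the 0.34.0 FeedforwardBlock behaviour,
        # not the package's actual class.
        def __init__(self, features, hidden, post_norm=True, residual_path=True):
            super().__init__()
            self.residual_path = residual_path
            self.post_norm = post_norm
            if post_norm:
                self.layernorm = nn.LayerNorm(features)
            self.process = nn.Sequential(
                nn.Linear(features, hidden),
                nn.GELU(),
                nn.Linear(hidden, features),
                nn.Dropout(0.1),
            )

        def forward(self, x):
            if self.residual_path and self.post_norm:
                return self.layernorm(x + self.process(x))
            elif self.residual_path:
                return x + self.process(x)
            else:
                return x  # mirrors the new else branch: input passes through untouched

    x = torch.randn(2, 16, 64)
    print(ResidualFFSketch(64, 256)(x).shape)  # torch.Size([2, 16, 64])
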
@@ -305,11 +308,14 @@ class TransformerBlock(nn.Module):
         causal=False,
         linear_module=nn.Linear,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         super().__init__()
 
         self.pre_norm = pre_norm
+        self.post_norm = post_norm
+        self.normformer = normformer
 
         self.identity_probability = identity_probability
 
@@ -351,6 +357,8 @@ class TransformerBlock(nn.Module):
             linear_module=linear_module,
             pre_norm=pre_norm,
             normformer=normformer,
+            post_norm=post_norm,
+            residual_path=True,
         )
 
     @property
@@ -371,19 +379,18 @@ class TransformerBlock(nn.Module):
         identity_x = shuffled[:identity_count, :, :]
         process_x = shuffled[identity_count:, :, :]
 
+        residual_x = process_x
+
         if self.pre_norm:
-            norm_process_x = self.layer_norm_1(process_x)
-            process_x = process_x + self.attn(
-                norm_process_x, norm_process_x, norm_process_x
-            )
-            process_x = process_x + self.ff(process_x)
-        else:  # post-norm
-            process_x = process_x + self.attn(process_x, process_x, process_x)
             process_x = self.layer_norm_1(process_x)
-            process_x = process_x + self.ff(process_x)
+
+        process_x = residual_x + self.attn(process_x, process_x, process_x)
+
+        if self.post_norm:
             process_x = self.layer_norm_2(process_x)
 
-        # Always post norm as eventually we reach the classification head!
+        process_x = self.ff(process_x)
+
         x = torch.cat([identity_x, process_x])[unshuffle_indices, :, :].contiguous()
 
         return x
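
With post_norm threaded through, TransformerBlock.forward now follows a single path instead of separate pre-norm and post-norm branches: an optional pre-attention LayerNorm, the attention residual taken from the un-normed input, an optional post-attention LayerNorm, then the feedforward sub-block, which (per the FeedforwardBlock changes above) applies its own residual and post-norm. A compressed sketch of that ordering, a hedged reconstruction from the hunk above that ignores the identity/shuffle bookkeeping:

    # attn, ff, layer_norm_1 and layer_norm_2 stand in for the block's submodules.
    def transformer_block_flow(x, pre_norm, post_norm, attn, ff, layer_norm_1, layer_norm_2):
        residual = x
        if pre_norm:
            x = layer_norm_1(x)           # normalise only the attention input
        x = residual + attn(x, x, x)      # residual always from the un-normed input
        if post_norm:
            x = layer_norm_2(x)           # optional post-attention norm
        return ff(x)                      # ff handles its own residual / post-norm
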
@@ -414,6 +421,7 @@ class TransformerEncoder(nn.Module):
         bos_tokens=0,
         return_bos_tokens=False,
         pre_norm=True,
+        post_norm=False,
         normformer=False,
     ):
         if position_embedding_type == "relative":
@@ -474,6 +482,7 @@ class TransformerEncoder(nn.Module):
                     causal=causal,
                     linear_module=linear_module,
                     pre_norm=pre_norm,
+                    post_norm=post_norm,
                     normformer=normformer,
                 )
                 for i in range(n_layers)
@@ -53,9 +53,7 @@ class ClassificationHead(nn.Module):
     A general classification head for a ViT
     """
 
-    def __init__(
-        self, d_model, linear_module, n_classes, layer_norm=True, batch_norm=True
-    ):
+    def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
         super().__init__()
         self.d_model = d_model
         self.summarize = GetCLSToken()
@@ -67,7 +65,6 @@ class ClassificationHead(nn.Module):
 
         self.classification_process = nn.Sequential(
             *[
-                nn.LayerNorm(d_model) if layer_norm else nn.Identity(),
                 self.summarize,
                 self.projection,
                 self.batch_norm,
@@ -120,8 +117,11 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -308,10 +308,14 @@ class ViTEncoder(nn.Module):
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
-                linear_module=linear_module,
+                linear_module=(
+                    transformer_initial_ff_linear_module
+                    if transformer_initial_ff_linear_module is not None
+                    else linear_module
+                ),
                 pre_norm=transformer_pre_norm,
                 normformer=transformer_normformer,
-                raw_input=not cnn,
+                residual_path=transformer_initial_ff_residual_path,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -365,8 +369,11 @@ class ViT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         transformer_feedforward_first=True,
+        transformer_initial_ff_residual_path=True,
+        transformer_initial_ff_linear_module=None,
         transformer_pre_norm=True,
         transformer_normformer=False,
+        transformer_post_norm=False,
         transformer_position_embedding="relative",  # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
@@ -421,8 +428,11 @@ class ViT(nn.Module):
             pooling_kernel_stride=pooling_kernel_stride,
             pooling_padding=pooling_padding,
             transformer_feedforward_first=transformer_feedforward_first,
+            transformer_initial_ff_residual_path=transformer_initial_ff_residual_path,
+            transformer_initial_ff_linear_module=transformer_initial_ff_linear_module,
             transformer_pre_norm=transformer_pre_norm,
             transformer_normformer=transformer_normformer,
+            transformer_post_norm=transformer_post_norm,
             transformer_position_embedding=transformer_position_embedding,
             transformer_embedding_size=transformer_embedding_size,
             transformer_layers=transformer_layers,
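
At the user-facing level, ViT and ViTEncoder gain three keyword arguments, all forwarded down to the encoder and its initial feedforward block. A hedged usage sketch follows; the import path and every other constructor argument are assumptions, and only the three new keywords come from the hunks above:

    # Hypothetical import path; the diff does not show the package's module layout.
    # from broccoli import ViT

    new_transformer_kwargs = dict(
        transformer_post_norm=True,                  # post-norm inside every TransformerBlock
        transformer_initial_ff_residual_path=False,  # initial FF block without a residual add
        transformer_initial_ff_linear_module=None,   # None falls back to the model-wide linear_module
    )

    # model = ViT(..., **new_transformer_kwargs)  # remaining arguments unchanged from 0.33.x
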
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.33.2"
+version = "0.34.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}