broccoli-ml 0.12.0__tar.gz → 0.13.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.12.0
+Version: 0.13.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -236,7 +236,6 @@ class FeedforwardLayer(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-        norm_memory=False,
     ):
         super().__init__()
 
@@ -253,22 +252,13 @@ class FeedforwardLayer(nn.Module):
             else ratio * output_features
         )
 
-        if norm_memory:
-            self.memory_type = SpectralNormLinear
-            self.bias_memories = False
-        else:
-            self.memory_type = linear_module
-            self.bias_memories = True
-
         self.process = nn.Sequential(
             *[
                 nn.LayerNorm(input_features),
                 linear_module(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(self.max_features),
-                self.memory_type(
-                    ratio * output_features, output_features, bias=self.bias_memories
-                ),
+                linear_module(ratio * output_features, output_features, bias=False),
                 self.dropout,
             ]
         )
@@ -307,7 +297,6 @@ class TransformerBlock(nn.Module):
 
         self.identity_probability = identity_probability
 
-        # Submodules for applying attention
         self.layer_norm = nn.LayerNorm(d_model)
 
         if position_embedding_type == "relative":
@@ -20,37 +20,74 @@ class PadTensor(nn.Module):
         return F.pad(x, *self.args, **self.kwargs)
 
 
+class GetCLSToken(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x[:, 0, :]
+
+
 class SequencePool(nn.Module):
+    def __init__(self, d_model, linear_module):
+        super().__init__()
+        self.attention = nn.Sequential(
+            *[
+                linear_module(d_model, 1),
+                Rearrange("batch seq 1 -> batch seq"),
+                nn.Softmax(dim=-1),
+            ]
+        )
+
+    def forward(self, x):
+        weights = self.attention(x)
+        return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
+
+
+class ClassificationHead(nn.Module):
     """
-    As described in [Hasani et al. (2021) *''Escaping the Big Data Paradigm with
-    Compact Transformers''*](https://arxiv.org/abs/2104.05704). It can be viewed
-    as a generalisation of average pooling.
+    A general classification head for a ViT
     """
 
-    def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
+    def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
         super().__init__()
         self.d_model = d_model
-        self.attention = nn.Sequential(
+        self.summarize = GetCLSToken()
+        self.process = nn.Sequential(
             *[
                 linear_module(d_model, 1),
                 Rearrange("batch seq 1 -> batch seq"),
                 nn.Softmax(dim=-1),
             ]
         )
-        self.projection = nn.Linear(d_model, out_dim)
-        self.batch_norm = batch_norm
+        self.projection = nn.Linear(d_model, n_classes)
         if batch_norm:
-            self.norm = nn.BatchNorm1d(out_dim, affine=False)
+            self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
         else:
-            self.norm = None
+            self.batch_norm = nn.Identity()
 
-    def forward(self, x):
-        weights = self.attention(x)
-        weighted_embedding = einsum(
-            weights, x, "batch seq, batch seq d_model -> batch d_model"
+        self.classification_process = nn.Sequential(
+            *[
+                self.summarize,
+                self.projection,
+                self.batch_norm,
+            ]
         )
-        projection = self.projection(weighted_embedding)
-        return self.norm(projection) if self.batch_norm else projection
+
+    def forward(self, x):
+        return self.classification_process(x)
+
+
+class SequencePoolClassificationHead(ClassificationHead):
+    """
+    As described in [Hasani et al. (2021) *''Escaping the Big Data Paradigm with
+    Compact Transformers''*](https://arxiv.org/abs/2104.05704). It can be viewed
+    as a generalisation of average pooling.
+    """
+
+    def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
+        super().__init__(d_model, linear_module, out_dim, batch_norm=True)
+        self.summarize = SequencePool()
 
 
 class ViTEncoder(nn.Module):
@@ -82,13 +119,13 @@ class ViTEncoder(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         intermediate_feedforward_layer=True,
-        norm_intermediate_ff_memory=True,
         transformer_position_embedding="relative", # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_bos_tokens=0,
+        transformer_return_bos_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_mlp_dropout=0.0,
@@ -250,6 +287,7 @@ class ViTEncoder(nn.Module):
                 causal=False,
                 linear_module=linear_module,
                 bos_tokens=transformer_bos_tokens,
+                return_bos_tokens=transformer_return_bos_tokens,
             )
         else:
             self.transformer = nn.Identity()
@@ -264,7 +302,6 @@ class ViTEncoder(nn.Module):
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
                 linear_module=linear_module,
-                norm_memory=norm_intermediate_ff_memory,
             )
         elif pooling_out_channels < transformer_embedding_size:
             self.intermediate_feedforward_layer = nn.Identity()
@@ -300,7 +337,7 @@ class ViTEncoder(nn.Module):
         return self.encoder(x)
 
 
-class CCT(nn.Module):
+class ViT(nn.Module):
     """
     Denoising convolutional transformer
     Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
@@ -328,13 +365,13 @@ class CCT(nn.Module):
         pooling_kernel_stride=2,
         pooling_padding=1,
         intermediate_feedforward_layer=True,
-        norm_intermediate_ff_memory=True,
         transformer_position_embedding="relative", # absolute or relative
         transformer_embedding_size=256,
         transformer_layers=7,
         transformer_heads=4,
         transformer_mlp_ratio=2,
         transformer_bos_tokens=0,
+        transformer_return_bos_tokens=False,
         transformer_activation: nn.Module = SquaredReLU,
         transformer_activation_kwargs: Optional[dict] = None,
         transformer_mlp_dropout=0.0,
@@ -344,6 +381,7 @@ class CCT(nn.Module):
         initial_batch_norm=True,
         linear_module=nn.Linear,
         image_classes=100,
+        head=SequencePoolClassificationHead,
     ):
 
         super().__init__()
@@ -388,6 +426,7 @@ class CCT(nn.Module):
             transformer_heads=transformer_heads,
             transformer_mlp_ratio=transformer_mlp_ratio,
             transformer_bos_tokens=transformer_bos_tokens,
+            transformer_return_bos_tokens=transformer_return_bos_tokens,
             transformer_activation=transformer_activation,
             transformer_activation_kwargs=transformer_activation_kwargs,
             transformer_mlp_dropout=transformer_mlp_dropout,
@@ -396,7 +435,8 @@ class CCT(nn.Module):
             linear_module=linear_module,
             initial_batch_norm=initial_batch_norm,
         )
-        self.pool = SequencePool(
+
+        self.pool = head(
             transformer_embedding_size,
             linear_module,
             image_classes,
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.12.0"
+version = "0.13.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
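A minimal usage sketch of the 0.13.0 API surface shown above, assuming the renamed ViT class and the SequencePoolClassificationHead are importable from a top-level broccoli module; the actual import path, input resolution, and the remaining constructor defaults are assumptions not confirmed by this diff:

    # Hypothetical sketch only: import path and input shape are assumptions.
    import torch
    from broccoli import ViT, SequencePoolClassificationHead

    # The new `head` argument selects the classification head; it defaults to
    # SequencePoolClassificationHead, which carries over the sequence-pooling
    # behaviour of the old SequencePool-based head.
    model = ViT(
        image_classes=10,
        head=SequencePoolClassificationHead,
    )

    images = torch.randn(8, 3, 32, 32)  # batch of RGB images; shape assumed
    logits = model(images)              # expected: one row of 10 logits per image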