broccoli-ml 0.12.0__py3-none-any.whl → 0.13.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/linear.py CHANGED
@@ -10,7 +10,8 @@ from .tensor import SigmaReparamTensor
10
10
 
11
11
  class SpectralNormLinear(nn.Module):
12
12
  """
13
- ...
13
+ Inspired by Apple's Spectral Normed Linear Layers
14
+ (https://github.com/apple/ml-sigma-reparam)
14
15
  """
15
16
 
16
17
  def __init__(self, in_features: int, out_features: int, bias: bool = True):
broccoli/transformer.py CHANGED
@@ -222,7 +222,7 @@ class MHAttention(nn.Module):
222
222
  return self.out_proj(output_without_heads)
223
223
 
224
224
 
225
- class FeedforwardLayer(nn.Module):
225
+ class FeedforwardBlock(nn.Module):
226
226
  """
227
227
  ...
228
228
  """
@@ -236,7 +236,6 @@ class FeedforwardLayer(nn.Module):
236
236
  activation_kwargs=None,
237
237
  dropout=0.0,
238
238
  linear_module=nn.Linear,
239
- norm_memory=False,
240
239
  ):
241
240
  super().__init__()
242
241
 
@@ -253,22 +252,13 @@ class FeedforwardLayer(nn.Module):
253
252
  else ratio * output_features
254
253
  )
255
254
 
256
- if norm_memory:
257
- self.memory_type = SpectralNormLinear
258
- self.bias_memories = False
259
- else:
260
- self.memory_type = linear_module
261
- self.bias_memories = True
262
-
263
255
  self.process = nn.Sequential(
264
256
  *[
265
257
  nn.LayerNorm(input_features),
266
258
  linear_module(input_features, self.max_features),
267
259
  self.activation,
268
260
  nn.LayerNorm(self.max_features),
269
- self.memory_type(
270
- ratio * output_features, output_features, bias=self.bias_memories
271
- ),
261
+ linear_module(ratio * output_features, output_features, bias=False),
272
262
  self.dropout,
273
263
  ]
274
264
  )
@@ -307,7 +297,6 @@ class TransformerBlock(nn.Module):
307
297
 
308
298
  self.identity_probability = identity_probability
309
299
 
310
- # Submodules for applying attention
311
300
  self.layer_norm = nn.LayerNorm(d_model)
312
301
 
313
302
  if position_embedding_type == "relative":
@@ -335,7 +324,7 @@ class TransformerBlock(nn.Module):
335
324
  )
336
325
 
337
326
  # Submodules for the feedforward process
338
- self.ff = FeedforwardLayer(
327
+ self.ff = FeedforwardBlock(
339
328
  d_model,
340
329
  mlp_ratio,
341
330
  d_model,
broccoli/vit.py CHANGED
@@ -1,9 +1,10 @@
1
1
  import math
2
2
  from typing import Optional
3
3
 
4
- from .transformer import TransformerEncoder, FeedforwardLayer
4
+ from .transformer import TransformerEncoder, FeedforwardBlock
5
5
  from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
6
6
  from .activation import ReLU, SquaredReLU, GELU, SwiGLU
7
+ from .linear import SpectralNormLinear
7
8
  from einops import einsum
8
9
  from einops.layers.torch import Rearrange
9
10
  import torch.nn as nn
@@ -20,37 +21,74 @@ class PadTensor(nn.Module):
20
21
  return F.pad(x, *self.args, **self.kwargs)
21
22
 
22
23
 
24
+ class GetCLSToken(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def forward(self, x):
29
+ return x[:, 0, :]
30
+
31
+
23
32
  class SequencePool(nn.Module):
33
+ def __init__(self, d_model, linear_module):
34
+ super().__init__()
35
+ self.attention = nn.Sequential(
36
+ *[
37
+ linear_module(d_model, 1),
38
+ Rearrange("batch seq 1 -> batch seq"),
39
+ nn.Softmax(dim=-1),
40
+ ]
41
+ )
42
+
43
+ def forward(self, x):
44
+ weights = self.attention(x)
45
+ return einsum(weights, x, "batch seq, batch seq d_model -> batch d_model")
46
+
47
+
48
+ class ClassificationHead(nn.Module):
24
49
  """
25
- As described in [Hasani et al. (2021) *''Escaping the Big Data Paradigm with
26
- Compact Transformers''*](https://arxiv.org/abs/2104.05704). It can be viewed
27
- as a generalisation of average pooling.
50
+ A general classification head for a ViT
28
51
  """
29
52
 
30
- def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
53
+ def __init__(self, d_model, linear_module, n_classes, batch_norm=True):
31
54
  super().__init__()
32
55
  self.d_model = d_model
33
- self.attention = nn.Sequential(
56
+ self.summarize = GetCLSToken()
57
+ self.process = nn.Sequential(
34
58
  *[
35
59
  linear_module(d_model, 1),
36
60
  Rearrange("batch seq 1 -> batch seq"),
37
61
  nn.Softmax(dim=-1),
38
62
  ]
39
63
  )
40
- self.projection = nn.Linear(d_model, out_dim)
41
- self.batch_norm = batch_norm
64
+ self.projection = nn.Linear(d_model, n_classes)
42
65
  if batch_norm:
43
- self.norm = nn.BatchNorm1d(out_dim, affine=False)
66
+ self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
44
67
  else:
45
- self.norm = None
68
+ self.batch_norm = nn.Identity()
46
69
 
47
- def forward(self, x):
48
- weights = self.attention(x)
49
- weighted_embedding = einsum(
50
- weights, x, "batch seq, batch seq d_model -> batch d_model"
70
+ self.classification_process = nn.Sequential(
71
+ *[
72
+ self.summarize,
73
+ self.projection,
74
+ self.batch_norm,
75
+ ]
51
76
  )
52
- projection = self.projection(weighted_embedding)
53
- return self.norm(projection) if self.batch_norm else projection
77
+
78
+ def forward(self, x):
79
+ return self.classification_process(x)
80
+
81
+
82
+ class SequencePoolClassificationHead(ClassificationHead):
83
+ """
84
+ As described in [Hasani et al. (2021) *''Escaping the Big Data Paradigm with
85
+ Compact Transformers''*](https://arxiv.org/abs/2104.05704). It can be viewed
86
+ as a generalisation of average pooling.
87
+ """
88
+
89
+ def __init__(self, d_model, linear_module, out_dim, batch_norm=True):
90
+ super().__init__(d_model, linear_module, out_dim, batch_norm=True)
91
+ self.summarize = SequencePool()
54
92
 
55
93
 
56
94
  class ViTEncoder(nn.Module):
@@ -66,6 +104,7 @@ class ViTEncoder(nn.Module):
66
104
  def __init__(
67
105
  self,
68
106
  input_size=(32, 32),
107
+ initial_batch_norm=True,
69
108
  cnn=True,
70
109
  cnn_in_channels=3,
71
110
  cnn_out_channels=16,
@@ -82,20 +121,19 @@ class ViTEncoder(nn.Module):
82
121
  pooling_kernel_stride=2,
83
122
  pooling_padding=1,
84
123
  intermediate_feedforward_layer=True,
85
- norm_intermediate_ff_memory=True,
86
124
  transformer_position_embedding="relative", # absolute or relative
87
125
  transformer_embedding_size=256,
88
126
  transformer_layers=7,
89
127
  transformer_heads=4,
90
128
  transformer_mlp_ratio=2,
91
129
  transformer_bos_tokens=0,
130
+ transformer_return_bos_tokens=False,
92
131
  transformer_activation: nn.Module = SquaredReLU,
93
132
  transformer_activation_kwargs: Optional[dict] = None,
94
133
  transformer_mlp_dropout=0.0,
95
134
  transformer_msa_dropout=0.1,
96
135
  transformer_stochastic_depth=0.1,
97
136
  linear_module=nn.Linear,
98
- initial_batch_norm=True,
99
137
  ):
100
138
  super().__init__()
101
139
 
@@ -250,13 +288,14 @@ class ViTEncoder(nn.Module):
250
288
  causal=False,
251
289
  linear_module=linear_module,
252
290
  bos_tokens=transformer_bos_tokens,
291
+ return_bos_tokens=transformer_return_bos_tokens,
253
292
  )
254
293
  else:
255
294
  self.transformer = nn.Identity()
256
295
 
257
296
  if intermediate_feedforward_layer:
258
297
  self.pooling_channels_padding = nn.Identity()
259
- self.intermediate_feedforward_layer = FeedforwardLayer(
298
+ self.intermediate_feedforward_layer = FeedforwardBlock(
260
299
  pooling_out_channels,
261
300
  transformer_mlp_ratio,
262
301
  transformer_embedding_size,
@@ -264,7 +303,6 @@ class ViTEncoder(nn.Module):
264
303
  activation_kwargs=transformer_activation_kwargs,
265
304
  dropout=transformer_mlp_dropout,
266
305
  linear_module=linear_module,
267
- norm_memory=norm_intermediate_ff_memory,
268
306
  )
269
307
  elif pooling_out_channels < transformer_embedding_size:
270
308
  self.intermediate_feedforward_layer = nn.Identity()
@@ -300,7 +338,7 @@ class ViTEncoder(nn.Module):
300
338
  return self.encoder(x)
301
339
 
302
340
 
303
- class CCT(nn.Module):
341
+ class ViT(nn.Module):
304
342
  """
305
343
  Denoising convolutional transformer
306
344
  Based on the Compact Convolutional Transformer (CCT) of [Hasani et al. (2021)
@@ -312,6 +350,8 @@ class CCT(nn.Module):
312
350
  def __init__(
313
351
  self,
314
352
  input_size=(32, 32),
353
+ image_classes=100,
354
+ initial_batch_norm=True,
315
355
  cnn=True,
316
356
  cnn_in_channels=3,
317
357
  cnn_out_channels=16,
@@ -328,22 +368,21 @@ class CCT(nn.Module):
328
368
  pooling_kernel_stride=2,
329
369
  pooling_padding=1,
330
370
  intermediate_feedforward_layer=True,
331
- norm_intermediate_ff_memory=True,
332
371
  transformer_position_embedding="relative", # absolute or relative
333
372
  transformer_embedding_size=256,
334
373
  transformer_layers=7,
335
374
  transformer_heads=4,
336
375
  transformer_mlp_ratio=2,
337
376
  transformer_bos_tokens=0,
377
+ transformer_return_bos_tokens=False,
338
378
  transformer_activation: nn.Module = SquaredReLU,
339
379
  transformer_activation_kwargs: Optional[dict] = None,
340
380
  transformer_mlp_dropout=0.0,
341
381
  transformer_msa_dropout=0.1,
342
382
  transformer_stochastic_depth=0.1,
343
383
  batch_norm_outputs=True,
344
- initial_batch_norm=True,
345
- linear_module=nn.Linear,
346
- image_classes=100,
384
+ linear_module=SpectralNormLinear,
385
+ head=SequencePoolClassificationHead,
347
386
  ):
348
387
 
349
388
  super().__init__()
@@ -366,6 +405,7 @@ class CCT(nn.Module):
366
405
 
367
406
  self.encoder = ViTEncoder(
368
407
  input_size=input_size,
408
+ initial_batch_norm=initial_batch_norm,
369
409
  cnn=cnn,
370
410
  cnn_in_channels=cnn_in_channels,
371
411
  cnn_out_channels=cnn_out_channels,
@@ -388,15 +428,16 @@ class CCT(nn.Module):
388
428
  transformer_heads=transformer_heads,
389
429
  transformer_mlp_ratio=transformer_mlp_ratio,
390
430
  transformer_bos_tokens=transformer_bos_tokens,
431
+ transformer_return_bos_tokens=transformer_return_bos_tokens,
391
432
  transformer_activation=transformer_activation,
392
433
  transformer_activation_kwargs=transformer_activation_kwargs,
393
434
  transformer_mlp_dropout=transformer_mlp_dropout,
394
435
  transformer_msa_dropout=transformer_msa_dropout,
395
436
  transformer_stochastic_depth=transformer_stochastic_depth,
396
437
  linear_module=linear_module,
397
- initial_batch_norm=initial_batch_norm,
398
438
  )
399
- self.pool = SequencePool(
439
+
440
+ self.pool = head(
400
441
  transformer_embedding_size,
401
442
  linear_module,
402
443
  image_classes,
{broccoli_ml-0.12.0.dist-info → broccoli_ml-0.13.1.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.12.0
3
+ Version: 0.13.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
{broccoli_ml-0.12.0.dist-info → broccoli_ml-0.13.1.dist-info}/RECORD RENAMED
@@ -5,13 +5,13 @@ broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYV
5
5
  broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
6
6
  broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
7
7
  broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
- broccoli/linear.py,sha256=9ZwqC6kkgkr0uPoEjdi_Uq1QFHb4wCXzuU1r2pDreXM,2910
8
+ broccoli/linear.py,sha256=jiGvLguxzkkmX14kRavaeg7IwN8jYJ06wn-NJ6Ivpzo,3008
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
10
  broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
11
- broccoli/transformer.py,sha256=niooSyrG9kZPk60IUPa-ZevEGaUQ8MI6AUxuOInAoqc,16265
11
+ broccoli/transformer.py,sha256=GzkHlzCe4k2-ALMbKpQ0wdsOEKTap6gjOK-FiA7KP3k,15929
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=J1-59oROlpbIEZ-grrXMhuluM3cLoJIAvixnIRSmgOs,15326
14
- broccoli_ml-0.12.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.12.0.dist-info/METADATA,sha256=ApiCIlFOB8rrdnUwLXgYv_0GPsQx85L3Uq2rUh3JhjQ,1257
16
- broccoli_ml-0.12.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.12.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=eEnb4hUwJUVymO3tD8V9JD-9i39ZkeNOYEDa9gwSL60,16398
14
+ broccoli_ml-0.13.1.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.13.1.dist-info/METADATA,sha256=WWNGeC9F48atFNGqfw1Kv0i9QCGaiOIcHpIgAeMAAAw,1257
16
+ broccoli_ml-0.13.1.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.13.1.dist-info/RECORD,,