broccoli-ml 3.2.0__tar.gz → 3.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 3.2.0
3
+ Version: 3.3.1
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -259,6 +259,13 @@ class FeedforwardBlock(nn.Module):
259
259
  self.residual_path = residual_path
260
260
  self.post_norm = post_norm
261
261
 
262
+ if self.residual_path and (output_features < input_features):
263
+ raise ValueError(
264
+ "If the number of output features will be less than "
265
+ "the number of input features, then `residual_path` "
266
+ "should be set to False."
267
+ )
268
+
262
269
  if self.post_norm:
263
270
  self.layernorm = nn.LayerNorm(output_features)
264
271
 
@@ -0,0 +1,15 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
class PadTensor(nn.Module):
    """Module wrapper around ``torch.nn.functional.pad``.

    All positional and keyword arguments given to the constructor are stored
    and forwarded verbatim to ``F.pad`` on every forward pass, so
    ``PadTensor((1, 1), value=0.0)(x)`` is equivalent to
    ``F.pad(x, (1, 1), value=0.0)``.

    When every entry of the padding specification is zero, the input tensor
    is returned unchanged and ``F.pad`` is not called.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def _pad_spec(self):
        """Return the padding sequence, given positionally or as ``pad=``."""
        # The padding is F.pad's parameter after the input tensor, so it is
        # either the first stored positional argument or the ``pad`` keyword.
        if self.args:
            return self.args[0]
        return self.kwargs.get("pad")

    def forward(self, x):
        """Pad ``x`` with the stored arguments, or return it if padding is zero.

        Args:
            x: tensor to pad.

        Returns:
            ``x`` itself when all padding amounts are zero, otherwise the
            result of ``F.pad(x, *self.args, **self.kwargs)``.
        """
        pad = self._pad_spec()
        # Skip only when every amount is zero. A sum-based check would treat
        # mixed-sign padding such as (1, -1), which F.pad applies as pad one
        # side / crop the other, as a no-op.
        if pad is not None and all(p == 0 for p in pad):
            return x
        return F.pad(x, *self.args, **self.kwargs)
@@ -4,25 +4,12 @@ from typing import Optional
4
4
  from .transformer import TransformerEncoder, FeedforwardBlock
5
5
  from .cnn import SpaceToDepth, calculate_output_spatial_size, spatial_tuple
6
6
  from .activation import ReLU, SquaredReLU, GELU, SwiGLU
7
+ from .utils import PadTensor
7
8
 
8
9
  from einops import einsum
9
10
  from einops.layers.torch import Rearrange
10
11
 
11
12
  import torch.nn as nn
12
- import torch.nn.functional as F
13
-
14
-
15
- class PadTensor(nn.Module):
16
- def __init__(self, *args, **kwargs):
17
- super().__init__()
18
- self.args = args
19
- self.kwargs = kwargs
20
-
21
- def forward(self, x):
22
- if sum(self.args[0]) == 0:
23
- return x
24
- else:
25
- return F.pad(x, *self.args, **self.kwargs)
26
13
 
27
14
 
28
15
  class GetCLSToken(nn.Module):
@@ -58,7 +45,6 @@ class ClassificationHead(nn.Module):
58
45
  self,
59
46
  d_model,
60
47
  n_classes,
61
- linear_module=nn.Linear,
62
48
  logit_projection_layer=nn.Linear,
63
49
  batch_norm_logits=True,
64
50
  ):
@@ -66,10 +52,11 @@ class ClassificationHead(nn.Module):
66
52
  self.d_model = d_model
67
53
  self.summarize = GetCLSToken()
68
54
 
69
- if logit_projection_layer is not None:
70
- self.projection = logit_projection_layer(d_model, n_classes)
71
- else:
55
+ if d_model == n_classes:
56
+ # No need to project
72
57
  self.projection = nn.Identity()
58
+ else:
59
+ self.projection = logit_projection_layer(d_model, n_classes)
73
60
 
74
61
  if batch_norm_logits:
75
62
  self.batch_norm = nn.BatchNorm1d(n_classes, affine=False)
@@ -99,19 +86,17 @@ class SequencePoolClassificationHead(ClassificationHead):
99
86
  self,
100
87
  d_model,
101
88
  n_classes,
102
- linear_module=nn.Linear,
103
89
  logit_projection_layer=nn.Linear,
104
90
  batch_norm_logits=True,
105
91
  ):
106
92
  super().__init__(
107
93
  d_model,
108
94
  n_classes,
109
- linear_module=linear_module,
110
95
  logit_projection_layer=logit_projection_layer,
111
96
  batch_norm_logits=batch_norm_logits,
112
97
  )
113
98
 
114
- self.summarize = SequencePool(d_model, linear_module)
99
+ self.summarize = SequencePool(d_model, logit_projection_layer)
115
100
  # Rebuild the classification process with the correct summary module:
116
101
  self.classification_process = nn.Sequential(
117
102
  *[
@@ -513,7 +498,6 @@ class ViT(nn.Module):
513
498
  self.pool = head(
514
499
  transformer_embedding_size,
515
500
  image_classes,
516
- linear_module=linear_module,
517
501
  logit_projection_layer=logit_projection_layer,
518
502
  batch_norm_logits=batch_norm_logits,
519
503
  )
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "broccoli-ml"
3
- version = "3.2.0"
3
+ version = "3.3.1"
4
4
  description = "Some useful Pytorch models, circa 2025"
5
5
  authors = [
6
6
  {name = "Nicholas Bailey"}
File without changes
File without changes
File without changes