broccoli-ml 0.16.0.tar.gz → 0.18.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.16.0
+Version: 0.18.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -12,11 +12,11 @@ class SwiGLU(nn.Module):
     Halves the incoming parameter count, which should be scaled up before input.
     """

-    def __init__(self, linear_module: nn.Module = nn.Linear) -> None:
+    def __init__(self) -> None:
         super().__init__()
         # Learnable parameter is called "swiglu beta" so that it is easy to find
         # and exclude from weight decay
-        self.swiglu_beta = nn.Parameter(torch.tensor([0.0]))
+        self.swiglu_beta = nn.Parameter(torch.tensor([1.0]))

     def forward(self, x):
         gate, value = rearrange(x, "... (split c) -> split ... c", split=2)
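Two things change in SwiGLU: the constructor drops an unused linear_module argument, and swiglu_beta now initialises to 1.0 rather than 0.0. Assuming the module gates with Swish-beta, i.e. value * gate * sigmoid(beta * gate) (the beta parameter and the gate/value split above strongly suggest this, but the full forward pass is not shown in the hunk), beta = 0 collapses the gate to a fixed 0.5 * gate scaling, whereas beta = 1 starts training from the standard SiLU. A minimal sketch of that assumed forward pass:

import torch
from einops import rearrange

def swiglu_forward(x, beta):
    # Split the channel dimension into a gate half and a value half
    gate, value = rearrange(x, "... (split c) -> split ... c", split=2)
    # Swish-beta gating: beta=1.0 recovers SiLU; beta=0.0 degenerates to 0.5 * gate
    return value * gate * torch.sigmoid(beta * gate)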
@@ -38,11 +38,13 @@ class SquaredReLU(nn.Module):
         self.leaky = leaky

     def forward(self, x):
-        relu_squared = F.relu(x) ** 2
+        if self.leaky:
+            relu = F.leaky_relu(x)
+        else:
+            relu = F.relu(x)
+        relu_squared = relu**2
         if self.clamp:
             relu_squared = torch.clamp(relu_squared, max=6)
-        if self.leaky:
-            relu_squared = relu_squared + 0.1 * x
         return relu_squared


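The leaky behaviour of SquaredReLU changes semantics, not just structure: 0.16.0 squared a plain ReLU and then added 0.1 * x, leaking negative inputs linearly into the squared output (and inflating the positive branch to x**2 + 0.1 * x), while 0.18.0 applies F.leaky_relu (default negative slope 0.01) before squaring, so negative inputs now map to small positive values instead. A quick comparison of the two formulas:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.5, 2.0])

old = F.relu(x) ** 2 + 0.1 * x  # 0.16.0 leaky path: linear leak added after squaring
new = F.leaky_relu(x) ** 2      # 0.18.0 leaky path: leak applied before squaring

print(old)  # tensor([-0.2000, -0.0500,  0.3000,  4.2000])
print(new)  # tensor([4.0000e-04, 2.5000e-05, 2.5000e-01, 4.0000e+00])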
@@ -57,11 +59,12 @@ class ReLU(nn.Module):
         self.leaky = leaky

     def forward(self, x):
-        relu = F.relu(x)
+        if self.leaky:
+            relu = F.leaky_relu(x)
+        else:
+            relu = F.relu(x)
         if self.clamp:
             relu = torch.clamp(relu, max=6)
-        if self.leaky:
-            relu = relu + 0.1 * x
         return relu


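ReLU gets the same restructuring, which also fixes a subtle slope bug: the old post-hoc + 0.1 * x gave the positive branch a slope of 1.1 (relu(x) + 0.1 * x = 1.1 * x for x > 0), whereas F.leaky_relu leaves positive inputs untouched:

import torch
import torch.nn.functional as F

x = torch.tensor([-1.0, 1.0])
print(F.relu(x) + 0.1 * x)  # tensor([-0.1000,  1.1000]) -- 0.16.0: positives scaled by 1.1
print(F.leaky_relu(x))      # tensor([-0.0100,  1.0000]) -- 0.18.0: identity on positives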
@@ -70,16 +73,13 @@ class GELU(nn.Module):
     A ReLU activation function with optional clamp and leakiness.
     """

-    def __init__(self, clamp=True, leaky=True) -> None:
+    def __init__(self, clamp=True) -> None:
         super().__init__()
         self.clamp = clamp
-        self.leaky = leaky
         self.gelu = nn.GELU()

     def forward(self, x):
         gelu = self.gelu(x)
         if self.clamp:
             gelu = torch.clamp(gelu, max=6)
-        if self.leaky:
-            gelu = gelu + 0.1 * x
         return gelu
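GELU loses its leaky option entirely rather than being switched to a built-in leaky variant: the removed + 0.1 * x term had the same slope-inflating effect on positive inputs as in the ReLU classes, and nn.GELU already passes a small negative signal near zero. What remains is a clamped GELU; a rough stand-alone equivalent of the 0.18.0 behaviour (this class name is illustrative, not the library's):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ClampedGELU(nn.Module):
    """Approximate stand-in for broccoli-ml 0.18.0's GELU with clamp=True."""
    def forward(self, x):
        return torch.clamp(F.gelu(x), max=6)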
@@ -32,7 +32,7 @@ class SigmaReparamTensor(nn.Module):
             approx_right_singular_transpose, dim=0
         )
         updated_right_singular = self.tensor.T.mv(approx_right_singular_transpose)
-        updated_right_singular = F.normalize(self.right_singular, dim=0)
+        updated_right_singular = F.normalize(updated_right_singular, dim=0)
         self.right_singular.data.copy_(updated_right_singular)
         rayleigh_quotient = torch.einsum(
             "m,mn,n->",
@@ -258,7 +258,7 @@ class FeedforwardBlock(nn.Module):
                 linear_module(input_features, self.max_features),
                 self.activation,
                 nn.LayerNorm(ratio * output_features),
-                linear_module(ratio * output_features, output_features, bias=False),
+                linear_module(ratio * output_features, output_features),
                 self.dropout,
             ]
         )
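The down-projection of the feedforward block regains its bias term (the bias=False argument is dropped). A sketch of the resulting stack, with concrete stand-ins for names the hunk does not define (the activation, the dropout rate, and max_features == ratio * output_features are assumptions here):

import torch.nn as nn

ratio, input_features, output_features = 4, 512, 512
block = nn.Sequential(
    nn.Linear(input_features, ratio * output_features),  # up-projection (assumed max_features)
    nn.ReLU(),                                           # stand-in for self.activation
    nn.LayerNorm(ratio * output_features),
    nn.Linear(ratio * output_features, output_features), # 0.18.0: bias is back on
    nn.Dropout(0.1),                                     # stand-in for self.dropout
)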
@@ -176,6 +176,9 @@ class ViTEncoder(nn.Module):
         )

         if cnn:
+            # This block rhymes:
+            if cnn_activation.__name__.endswith("GLU"):
+                cnn_out_channels *= 2
             cnn_output_size = calculate_output_spatial_size(
                 input_size,
                 kernel_size=cnn_kernel_size,
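The doubling exists because GLU-family activations (class names ending in "GLU") split their input channels into a gate half and a value half, so only half of the convolution's output channels survive the activation; emitting twice as many channels up front keeps the post-activation width as configured. A sketch of the halving that motivates it:

import torch

x = torch.randn(1, 64, 8, 8)        # pretend the conv emitted 64 channels
gate, value = x.chunk(2, dim=1)     # a GLU splits channels into gate and value halves
out = value * torch.sigmoid(gate)   # only 32 channels survive the activation
assert out.shape[1] == x.shape[1] // 2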
@@ -206,9 +209,6 @@ class ViTEncoder(nn.Module):
                     batchnormxd(cnn_activation_out_channels),
                 ]
             )
-            # This block rhymes:
-            if cnn and cnn_activation.__name__.endswith("GLU"):
-                cnn_out_channels *= 2
         else:
             self.cnn = nn.Identity()
             self.activate_and_dropout = nn.Identity()
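This hunk is the other half of the relocation above: in 0.16.0 the doubling ran only after the CNN stack had been assembled, apparently too late for calculate_output_spatial_size and the layer construction to see the doubled channel count. Moving it to the top of the if cnn: branch (previous hunk) lets every downstream consumer of cnn_out_channels use the corrected value, and makes the inner cnn check redundant.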
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.16.0"
+version = "0.18.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}