broccoli-ml 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
broccoli/activation.py CHANGED
@@ -12,11 +12,11 @@ class SwiGLU(nn.Module):
12
12
  Halves the incoming parameter count, which should be scaled up before input.
13
13
  """
14
14
 
15
- def __init__(self, linear_module: nn.Module = nn.Linear) -> None:
15
+ def __init__(self) -> None:
16
16
  super().__init__()
17
17
  # Learnable parameter is called "swiglu beta" so that it is easy to find
18
18
  # and exclude from weight decay
19
- self.swiglu_beta = nn.Parameter(torch.tensor([0.0]))
19
+ self.swiglu_beta = nn.Parameter(torch.tensor([1.0]))
20
20
 
21
21
  def forward(self, x):
22
22
  gate, value = rearrange(x, "... (split c) -> split ... c", split=2)
@@ -38,11 +38,13 @@ class SquaredReLU(nn.Module):
38
38
  self.leaky = leaky
39
39
 
40
40
  def forward(self, x):
41
- relu_squared = F.relu(x) ** 2
41
+ if self.leaky:
42
+ relu = F.leaky_relu(x)
43
+ else:
44
+ relu = F.relu(x)
45
+ relu_squared = relu**2
42
46
  if self.clamp:
43
47
  relu_squared = torch.clamp(relu_squared, max=6)
44
- if self.leaky:
45
- relu_squared = relu_squared + 0.1 * x
46
48
  return relu_squared
47
49
 
48
50
 
@@ -57,11 +59,12 @@ class ReLU(nn.Module):
57
59
  self.leaky = leaky
58
60
 
59
61
  def forward(self, x):
60
- relu = F.relu(x)
62
+ if self.leaky:
63
+ relu = F.leaky_relu(x)
64
+ else:
65
+ relu = F.relu(x)
61
66
  if self.clamp:
62
67
  relu = torch.clamp(relu, max=6)
63
- if self.leaky:
64
- relu = relu + 0.1 * x
65
68
  return relu
66
69
 
67
70
 
@@ -70,16 +73,13 @@ class GELU(nn.Module):
70
73
  A ReLU activation function with optional clamp and leakiness.
71
74
  """
72
75
 
73
- def __init__(self, clamp=True, leaky=True) -> None:
76
+ def __init__(self, clamp=True) -> None:
74
77
  super().__init__()
75
78
  self.clamp = clamp
76
- self.leaky = leaky
77
79
  self.gelu = nn.GELU()
78
80
 
79
81
  def forward(self, x):
80
82
  gelu = self.gelu(x)
81
83
  if self.clamp:
82
84
  gelu = torch.clamp(gelu, max=6)
83
- if self.leaky:
84
- gelu = gelu + 0.1 * x
85
85
  return gelu
broccoli/tensor.py CHANGED
@@ -32,7 +32,7 @@ class SigmaReparamTensor(nn.Module):
32
32
  approx_right_singular_transpose, dim=0
33
33
  )
34
34
  updated_right_singular = self.tensor.T.mv(approx_right_singular_transpose)
35
- updated_right_singular = F.normalize(self.right_singular, dim=0)
35
+ updated_right_singular = F.normalize(updated_right_singular, dim=0)
36
36
  self.right_singular.data.copy_(updated_right_singular)
37
37
  rayleigh_quotient = torch.einsum(
38
38
  "m,mn,n->",
broccoli/vit.py CHANGED
@@ -176,6 +176,9 @@ class ViTEncoder(nn.Module):
176
176
  )
177
177
 
178
178
  if cnn:
179
+ # This block rhymes:
180
+ if cnn_activation.__name__.endswith("GLU"):
181
+ cnn_out_channels *= 2
179
182
  cnn_output_size = calculate_output_spatial_size(
180
183
  input_size,
181
184
  kernel_size=cnn_kernel_size,
@@ -206,9 +209,6 @@ class ViTEncoder(nn.Module):
206
209
  batchnormxd(cnn_activation_out_channels),
207
210
  ]
208
211
  )
209
- # This block rhymes:
210
- if cnn and cnn_activation.__name__.endswith("GLU"):
211
- cnn_out_channels *= 2
212
212
  else:
213
213
  self.cnn = nn.Identity()
214
214
  self.activate_and_dropout = nn.Identity()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: broccoli-ml
3
- Version: 0.17.0
3
+ Version: 0.18.0
4
4
  Summary: Some useful Pytorch models, circa 2025
5
5
  License: MIT
6
6
  Author: Nicholas Bailey
@@ -1,5 +1,5 @@
1
1
  broccoli/__init__.py,sha256=tmyspsVxqPZHRQCY_NRwpW4SMNBbtE8E_8z7l-SAzSo,127
2
- broccoli/activation.py,sha256=jmKSNcq3VfZdVm8Ed65iiUB0ZfqmP_7lmEGkAWSIMdQ,2519
2
+ broccoli/activation.py,sha256=AYTH6BNgXSBN8fnOMec_x94gISFj_r9y1IXIvhur7I4,2425
3
3
  broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl,sha256=RZpPupWxFaVfgZrK-gBgfW1hj78oMEGhVWTbjRB3qMo,46835797
4
4
  broccoli/assets/cifar100_eigenvectors_size_2.pt,sha256=DjXDOXMeuMpIqNuGhX9z-OWYVqZwIMScSXZApRr9JjU,2501
5
5
  broccoli/assets/cifar100_eigenvectors_size_3.pt,sha256=gL6k0xtXYiYP6ZSvEiMBdJ7kIkT0AngTpDJHFQqwgxA,7173
@@ -7,11 +7,11 @@ broccoli/cnn.py,sha256=jeRyKIAMWu1E3iyI14MGgSZuZivPMh12iqkqW9ilNjo,17785
7
7
  broccoli/eigenpatches.py,sha256=J6n2usN1oQuHEHYiBNyYpn_a9eQcHjOBiIlvSei520Y,2413
8
8
  broccoli/linear.py,sha256=g8YrxNl6g_WcHrWVmbaBHJU5hv6daFS0r4TxAoPJ9UE,3012
9
9
  broccoli/rope.py,sha256=hw7kBPNR9GQXj4GxyIAffsGKPfcTPOFh8Bc7oEHtaZY,12108
10
- broccoli/tensor.py,sha256=E2JK5mQwJf75e23-JGcDoT7QxQf89DJReUo2et1LhRY,1716
10
+ broccoli/tensor.py,sha256=txJiMLXTgWKTDd7Dh2In-ajiQLRUuuPezO0bKnEU3e8,1719
11
11
  broccoli/transformer.py,sha256=ZJpZIR5d4W_N1uO2AtRVVYzRJSjq1T6lj2mPT6raLg4,15923
12
12
  broccoli/utils.py,sha256=htq_hOsdhUhL0nJi9WkKiEYOjEoWqFpK5X49PtgTf-0,299
13
- broccoli/vit.py,sha256=LkMLqjLXG5SnNMeWLQ1LFfuG-RvQ5YY7BLZ-fNrVcPY,16607
14
- broccoli_ml-0.17.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
- broccoli_ml-0.17.0.dist-info/METADATA,sha256=VPOAFsxDXoFtOQQkZoeWmUxZ9aRzyu1I9LX7N3WrXUA,1257
16
- broccoli_ml-0.17.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
- broccoli_ml-0.17.0.dist-info/RECORD,,
13
+ broccoli/vit.py,sha256=--eEAxBzCCVGCIvaZWHsonbVXGe04wDEz0Q3V4YZXoI,16599
14
+ broccoli_ml-0.18.0.dist-info/LICENSE,sha256=0BAzJE5BqQ7Iixp_AFdB2W1uO-HCRX-Qfun8PHt6yVM,1073
15
+ broccoli_ml-0.18.0.dist-info/METADATA,sha256=1CcIDoXMMzX88L6sqjEgtrkXJV4HSyQhRymaGBTfQkY,1257
16
+ broccoli_ml-0.18.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
17
+ broccoli_ml-0.18.0.dist-info/RECORD,,