broccoli-ml 0.23.1__tar.gz → 0.24.0__tar.gz

This diff shows the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: broccoli-ml
-Version: 0.23.1
+Version: 0.24.0
 Summary: Some useful Pytorch models, circa 2025
 License: MIT
 Author: Nicholas Bailey
@@ -1,10 +1,48 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from einops import rearrange
 
 
-class SwiGLU(nn.Module):
+class ReLU(nn.Module):
+    """
+    A ReLU activation function with optional clamp and leakiness.
+    """
+
+    def __init__(self, clamp=True, leaky=True, leaky_slope=0.01, clamp_max=6.0) -> None:
+        super().__init__()
+        self.clamp = clamp
+        self.leaky = leaky
+        self.leaky_slope = leaky_slope
+        self.clamp_max = clamp_max
+
+    def forward(self, x):
+        if self.leaky:
+            relu = F.leaky_relu(x, leaky_slope=self.leaky_slope)
+        else:
+            relu = F.relu(x)
+        if self.clamp:
+            relu = torch.clamp(relu, max=self.clamp_max)
+        return relu
+
+
+class GELU(nn.Module):
+    """
+    A GELU activation function with optional clamp.
+    """
+
+    def __init__(self, clamp=True) -> None:
+        super().__init__()
+        self.clamp = clamp
+        self.gelu = nn.GELU()
+
+    def forward(self, x):
+        gelu = self.gelu(x)
+        if self.clamp:
+            gelu = torch.clamp(gelu, max=6)
+        return gelu
+
+
+class Swish(nn.Module):
     """
     Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
     (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
@@ -16,12 +54,10 @@ class SwiGLU(nn.Module):
         super().__init__()
         # Learnable parameter is called "swiglu beta" so that it is easy to find
         # and exclude from weight decay
-        self.swiglu_beta = nn.Parameter(torch.tensor([1.0]))
+        self.swish_beta = nn.Parameter(torch.tensor([1.0]))
 
     def forward(self, x):
-        gate, value = rearrange(x, "... (split c) -> split ... c", split=2)
-        beta_swish = gate * F.sigmoid(self.swiglu_beta * gate)
-        return beta_swish * value
+        return x * F.sigmoid(self.swish_beta * x)
 
 
 class SquaredReLU(nn.Module):
@@ -32,54 +68,52 @@ class SquaredReLU(nn.Module):
     https://azizbelaweid.substack.com/p/what-is-swiglu-how-to-implement-it
     """
 
-    def __init__(self, clamp=True, leaky=True) -> None:
+    def __init__(
+        self, clamp=True, leaky=True, leaky_slope: float = 0.01, clamp_max=6
+    ) -> None:
         super().__init__()
         self.clamp = clamp
         self.leaky = leaky
+        self.leaky_slope = leaky_slope
+        self.clamp_max = clamp_max
 
     def forward(self, x):
         if self.leaky:
-            relu = F.leaky_relu(x)
+            relu = F.leaky_relu(x, leaky_slope=self.leaky_slope)
         else:
             relu = F.relu(x)
         relu_squared = relu**2
         if self.clamp:
-            relu_squared = torch.clamp(relu_squared, max=6)
+            relu_squared = torch.clamp(relu_squared, max=self.clamp_max)
         return relu_squared
 
 
-class ReLU(nn.Module):
+class XGLU(nn.Module):
     """
-    A ReLU activation function with optional clamp and leakiness.
+    Generic Gated Linear Unit
     """
 
-    def __init__(self, clamp=True, leaky=True) -> None:
+    def __init__(self, activation_module: nn.Module) -> None:
         super().__init__()
-        self.clamp = clamp
-        self.leaky = leaky
+        self.activation = activation_module
 
-    def forward(self, x):
-        if self.leaky:
-            relu = F.leaky_relu(x)
-        else:
-            relu = F.relu(x)
-        if self.clamp:
-            relu = torch.clamp(relu, max=6)
-        return relu
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate, value = x.chunk(2, dim=-1)
+        return self.activation(gate) * value
 
 
-class GELU(nn.Module):
+def SquaredReGLU(clamp=True, leaky=True, leaky_slope=0.01, clamp_max=6.0) -> XGLU:
     """
-    A ReLU activation function with optional clamp and leakiness.
+    Factory function that creates a GLU with a SquaredReLU activation.
     """
+    activation_module = SquaredReLU(
+        clamp=clamp, leaky=leaky, leaky_slope=leaky_slope, clamp_max=clamp_max
+    )
+    return XGLU(activation_module)
 
-    def __init__(self, clamp=True) -> None:
-        super().__init__()
-        self.clamp = clamp
-        self.gelu = nn.GELU()
 
-    def forward(self, x):
-        gelu = self.gelu(x)
-        if self.clamp:
-            gelu = torch.clamp(gelu, max=6)
-        return gelu
+def SwiGLU() -> XGLU:
+    """
+    Factory function that creates a GLU with a Swish activation.
+    """
+    return XGLU(Swish())
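
The net effect of this refactor is that gating and activation are now composable: XGLU halves the last dimension into a gate and a value, and SwiGLU / SquaredReGLU are thin factories over it. A minimal usage sketch, assuming these classes are exported from the package's activations module (the import path below is illustrative, not confirmed by the diff):

    import torch
    from broccoli.activations import SwiGLU, SquaredReGLU, XGLU, GELU  # import path assumed

    x = torch.randn(4, 128)          # last dim must be even: XGLU chunks it into two halves

    swiglu = SwiGLU()                # XGLU wrapping Swish, with a learnable swish_beta
    y = swiglu(x)                    # shape (4, 64): value half gated by swish(gate half)

    sq_reglu = SquaredReGLU(clamp=True, leaky=True)   # XGLU wrapping SquaredReLU
    z = sq_reglu(x)                  # also shape (4, 64)

    geglu = XGLU(GELU(clamp=True))   # any nn.Module can serve as the gate activation
    w = geglu(x)                     # shape (4, 64)

Note that, unlike the removed SwiGLU class, the gate/value split now uses Tensor.chunk rather than einops.rearrange, which is why the einops import was dropped.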
@@ -0,0 +1,95 @@
+# UNDER CONSTRUCTION
+
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .tensor import SigmaReparamTensor, AnchoredReparamTensor
+
+
+class SpectralNormLinear(nn.Module):
+    """
+    Inspired by Apple's Spectral Normed Linear Layers
+    (https://github.com/apple/ml-sigma-reparam)
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = SigmaReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"SpectralNormFeedForward(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
+
+
+class AnchoredLinear(nn.Module):
+    """
+    ...
+    """
+
+    def __init__(self, in_features: int, out_features: int, bias: bool = True):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.use_bias = bias
+
+        self.weights = None
+
+        # Define the bias vector as a learnable parameter if required.
+        if self.use_bias:
+            self.bias = nn.Parameter(torch.empty(out_features))
+        else:
+            # If no bias, register it as None.
+            # This is important so that PyTorch doesn't complain when saving/loading the model.
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        weights = torch.empty(self.out_features, self.in_features)
+        nn.init.kaiming_uniform_(weights, a=math.sqrt(5))
+        if self.use_bias:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        self.weights = AnchoredReparamTensor(weights)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weights(), self.bias)
+
+    def __repr__(self) -> str:
+        # Optional: A nice representation for printing the module.
+        return (
+            f"AnchoredLinear(in_features={self.in_features},"
+            f"out_features={self.out_features}, bias={self.use_bias})"
+        )
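
AnchoredLinear mirrors SpectralNormLinear but holds its weight in an AnchoredReparamTensor instead. A minimal sketch of how it might be used as a drop-in replacement for nn.Linear, assuming the module path shown (not confirmed by the diff):

    import torch
    from broccoli.linear import AnchoredLinear  # import path assumed

    layer = AnchoredLinear(in_features=256, out_features=512, bias=True)
    x = torch.randn(8, 256)
    y = layer(x)                                 # shape (8, 512), same call pattern as nn.Linear

    # The effective weight matrix is rebuilt from the reparameterisation on every forward,
    # so the trainable parameters are the bias plus the inner weight and its scale.
    print([name for name, _ in layer.named_parameters()])
    # expected (order may vary): ['bias', 'weights.weight', 'weights.scale']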
@@ -54,3 +54,43 @@ class SigmaReparamTensor(nn.Module):
         return self.sigma_reparam_scale * (
             self.sigma_reparam_tensor / self.approx_spectral_norm
         )
+
+
+class AnchoredReparamTensor(nn.Module):
+    """
+    Reparameterise a tensor as a normalised tensor of weights multiplied by a
+    learnable scaling factor.
+
+    The tensor of weights is also reparameterised as the product of a learnable
+    weight tensor with the (fixed) dominant right-singular vector of the
+    weight tensor as it was initialised.
+
+    i.e. this module represents a tensor reparameterised as:
+
+        W_reparam = scale * (W / ||W @ v_0||_2)
+
+    where v_0 is the dominant right-singular vector of the initial tensor W_init.
+    """
+
+    def __init__(self, init_tensor: torch.Tensor):
+        assert init_tensor.ndim == 2, "Input tensor must be a 2D matrix."
+        super().__init__()
+
+        self.weight = nn.Parameter(init_tensor.clone(), requires_grad=True)
+
+        # At initialization, compute the dominant right-singular vector (v_0)
+        # and store it in a non-trainable buffer.
+        with torch.no_grad():
+            _, _, v_transpose = torch.linalg.svd(self.weight, full_matrices=False)
+            # v_transpose[0] is the first row of V^T, which is the first right-singular vector.
+            self.register_buffer("anchor_vector", v_transpose[0])
+
+        initial_norm = torch.linalg.vector_norm(self.weight.mv(self.anchor_vector))
+        self.scale = nn.Parameter(initial_norm.clone().detach(), requires_grad=True)
+
+    def forward(self) -> torch.Tensor:
+        # Calculate the L2 norm of the matrix-vector product W @ v_0
+        norm = torch.linalg.vector_norm(self.weight.mv(self.anchor_vector))
+
+        # Return the reparameterized tensor.
+        return self.scale * (self.weight / (norm + 1e-6))
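
Where SigmaReparamTensor normalises by an approximate spectral norm, AnchoredReparamTensor normalises by ||W @ v_0|| with the anchor direction v_0 frozen at initialisation (stored as a buffer, not a parameter). One consequence, sketched below, is that the module reproduces its initial weights almost exactly at initialisation, since scale starts equal to the same norm the forward pass divides by. The import path is assumed:

    import torch
    from broccoli.tensor import AnchoredReparamTensor  # import path assumed

    w = torch.randn(64, 32)
    reparam = AnchoredReparamTensor(w)

    # At init, scale == ||W @ v_0|| and forward divides by that same norm (plus 1e-6),
    # so the reparameterised tensor matches the original weights to high precision.
    print(torch.allclose(reparam(), w, atol=1e-4))   # expected: True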
@@ -10,7 +10,7 @@ import torch.nn.functional as F
 from einops import rearrange
 
 from .rope import RotaryEmbedding, apply_rotary_emb
-from .linear import SpectralNormLinear
+from .linear import AnchoredLinear
 
 
 class MHAttention(nn.Module):
@@ -236,7 +236,7 @@ class FeedforwardBlock(nn.Module):
         activation_kwargs=None,
         dropout=0.0,
         linear_module=nn.Linear,
-        sigma_reparam=False,
+        reparam=False,
     ):
         super().__init__()
 
@@ -253,8 +253,8 @@
             else ratio * output_features
         )
 
-        if sigma_reparam:
-            self.memory_type = SpectralNormLinear
+        if reparam:
+            self.memory_type = AnchoredLinear
         else:
             self.memory_type = linear_module
 
@@ -263,7 +263,7 @@
             nn.LayerNorm(input_features),
             linear_module(input_features, self.max_features),
             self.activation,
-            nn.LayerNorm(ratio * output_features),
+            # nn.LayerNorm(ratio * output_features),
             self.memory_type(ratio * output_features, output_features),
             self.dropout,
         ]
@@ -295,14 +295,14 @@ class ViTEncoder(nn.Module):
 
         if transformer_feedforward_first:
             self.initial_ff = FeedforwardBlock(
-                transformer_embedding_size,
+                max(transformer_embedding_size, pooling_out_channels),
                 transformer_mlp_ratio,
                 transformer_embedding_size,
                 activation=transformer_activation,
                 activation_kwargs=transformer_activation_kwargs,
                 dropout=transformer_mlp_dropout,
                 linear_module=linear_module,
-                sigma_reparam=not cnn,
+                reparam=not cnn,
             )
         else:
             self.initial_ff = nn.Identity()
@@ -1,6 +1,6 @@
 [project]
 name = "broccoli-ml"
-version = "0.23.1"
+version = "0.24.0"
 description = "Some useful Pytorch models, circa 2025"
 authors = [
     {name = "Nicholas Bailey"}
@@ -1,89 +0,0 @@
-# UNDER CONSTRUCTION
-
-import math
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from .tensor import SigmaReparamTensor
-
-
-class SpectralNormLinear(nn.Module):
-    """
-    Inspired by Apple's Spectral Normed Linear Layers
-    (https://github.com/apple/ml-sigma-reparam)
-    """
-
-    def __init__(self, in_features: int, out_features: int, bias: bool = True):
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.weights = None
-
-        self.weight_init = nn.Parameter(torch.empty(out_features, in_features))
-
-        # Define the bias vector as a learnable parameter if required.
-        if self.use_bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            # If no bias, register it as None.
-            # This is important so that PyTorch doesn't complain when saving/loading the model.
-            self.register_parameter("bias", None)
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        nn.init.kaiming_uniform_(self.weight_init, a=math.sqrt(5))
-        if self.use_bias:
-            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_init)
-            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
-            nn.init.uniform_(self.bias, -bound, bound)
-        self.weights = SigmaReparamTensor(self.weight_init)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weights(), self.bias)
-
-    def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
-        return (
-            f"SpectralNormFeedForward(in_features={self.in_features}",
-            f"out_features={self.out_features}, bias={self.use_bias})",
-        )
-
-
-class RandomLinear(nn.Linear):
-    """ """
-
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = False,  # <---- TODO: explain this
-        beta=0.1,
-        forward_looks_random=True,
-    ):
-        super().__init__(in_features, out_features, bias=False)
-        self.beta = beta
-        self.forward_looks_random = forward_looks_random
-
-    def forward(self, inputs: torch.Tensor):
-        if not self.training:
-            return F.linear(inputs, self.weight)
-        else:
-            # Initialise self.random_weights
-            random_weights = torch.empty_like(self.weight)
-            nn.init.trunc_normal_(random_weights)
-            random_weights *= self.beta
-
-            if self.forward_looks_random:
-                # Forward using a reparameterisation trick
-                a = F.linear(inputs.detach(), self.weight, self.bias)
-                b = F.linear(inputs, random_weights, bias=None)
-            else:
-                # Forward as (W_actual * input + W_random * input) + bias
-                a = F.linear(inputs, self.weight, self.bias)
-                b = F.linear(inputs, random_weights, bias=None)
-
-            return a + b