broccoli-ml 0.9.0__py3-none-any.whl → 6.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
broccoli/activation.py CHANGED
@@ -1,10 +1,50 @@
  import torch
  from torch import nn
  from torch.nn import functional as F
- from einops import rearrange


- class SwiGLU(nn.Module):
+ class ReLU(nn.Module):
+     """
+     A ReLU activation function with optional clamp and leakiness.
+     """
+
+     def __init__(
+         self, clamp=True, leaky=True, negative_slope=0.01, clamp_max=6.0
+     ) -> None:
+         super().__init__()
+         self.clamp = clamp
+         self.leaky = leaky
+         self.negative_slope = negative_slope
+         self.clamp_max = clamp_max
+
+     def forward(self, x):
+         if self.leaky:
+             relu = F.leaky_relu(x, negative_slope=self.negative_slope)
+         else:
+             relu = F.relu(x)
+         if self.clamp:
+             relu = torch.clamp(relu, max=self.clamp_max)
+         return relu
+
+
+ class GELU(nn.Module):
+     """
+     A GELU activation function with optional clamp.
+     """
+
+     def __init__(self, clamp=True) -> None:
+         super().__init__()
+         self.clamp = clamp
+         self.gelu = nn.GELU()
+
+     def forward(self, x):
+         gelu = self.gelu(x)
+         if self.clamp:
+             gelu = torch.clamp(gelu, max=6)
+         return gelu
+
+
+ class Swish(nn.Module):
      """
      Implementation of (beta) SwiGLU, as introduced in "GLU Variants Improve Transformer"
      (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
@@ -12,16 +52,14 @@ class SwiGLU(nn.Module):
      Halves the incoming parameter count, which should be scaled up before input.
      """

-     def __init__(self, linear_module: nn.Module = nn.Linear) -> None:
+     def __init__(self) -> None:
          super().__init__()
          # Learnable parameter is called "swiglu beta" so that it is easy to find
          # and exclude from weight decay
-         self.swiglu_beta = nn.Parameter(torch.tensor([0.0]))
+         self.swish_beta = nn.Parameter(torch.tensor([1.0]))

      def forward(self, x):
-         gate, value = rearrange(x, "... (split c) -> split ... c", split=2)
-         beta_swish = gate * F.sigmoid(self.swiglu_beta * gate)
-         return beta_swish * value
+         return x * F.sigmoid(self.swish_beta * x)


  class SquaredReLU(nn.Module):
@@ -32,54 +70,52 @@ class SquaredReLU(nn.Module):
      https://azizbelaweid.substack.com/p/what-is-swiglu-how-to-implement-it
      """

-     def __init__(self, clamp=True, leaky=True) -> None:
+     def __init__(
+         self, clamp=True, leaky=True, negative_slope: float = 0.01, clamp_max=6
+     ) -> None:
          super().__init__()
          self.clamp = clamp
          self.leaky = leaky
+         self.negative_slope = negative_slope
+         self.clamp_max = clamp_max

      def forward(self, x):
-         relu_squared = F.relu(x) ** 2
-         if self.clamp:
-             relu_squared = torch.clamp(relu_squared, max=6)
          if self.leaky:
-             relu_squared = relu_squared + 0.1 * x
+             relu = F.leaky_relu(x, negative_slope=self.negative_slope)
+         else:
+             relu = F.relu(x)
+         relu_squared = relu**2
+         if self.clamp:
+             relu_squared = torch.clamp(relu_squared, max=self.clamp_max)
          return relu_squared


- class ReLU(nn.Module):
+ class XGLU(nn.Module):
      """
-     A ReLU activation function with optional clamp and leakiness.
+     Generic Gated Linear Unit
      """

-     def __init__(self, clamp=True, leaky=True) -> None:
+     def __init__(self, activation_module: nn.Module) -> None:
          super().__init__()
-         self.clamp = clamp
-         self.leaky = leaky
+         self.activation = activation_module

-     def forward(self, x):
-         relu = F.relu(x)
-         if self.clamp:
-             relu = torch.clamp(relu, max=6)
-         if self.leaky:
-             relu = relu + 0.1 * x
-         return relu
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         gate, value = x.chunk(2, dim=-1)
+         return self.activation(gate) * value


- class GELU(nn.Module):
+ def SquaredReGLU(clamp=True, leaky=True, negative_slope=0.01, clamp_max=6.0) -> XGLU:
      """
-     A ReLU activation function with optional clamp and leakiness.
+     Factory function that creates a GLU with a SquaredReLU activation.
      """
+     activation_module = SquaredReLU(
+         clamp=clamp, leaky=leaky, negative_slope=negative_slope, clamp_max=clamp_max
+     )
+     return XGLU(activation_module)

-     def __init__(self, clamp=True, leaky=True) -> None:
-         super().__init__()
-         self.clamp = clamp
-         self.leaky = leaky
-         self.gelu = nn.GELU()

-     def forward(self, x):
-         gelu = self.gelu(x)
-         if self.clamp:
-             gelu = torch.clamp(gelu, max=6)
-         if self.leaky:
-             gelu = gelu + 0.1 * x
-         return gelu
+ def SwiGLU() -> XGLU:
+     """
+     Factory function that creates a GLU with a Swish activation.
+     """
+     return XGLU(Swish())
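
A quick usage sketch of the new gated units (illustrative only, not taken from the package's docs or tests; it assumes broccoli-ml 6.0.0 exposes the classes exactly as shown above). `XGLU` chunks the last dimension into a gate and a value, so the incoming feature dimension must be even and is halved on output, matching the "halves the incoming parameter count" remark in the docstring:

    import torch
    from broccoli.activation import SwiGLU, SquaredReGLU

    glu = SwiGLU()                 # XGLU wrapping a Swish gate with a learnable swish_beta
    x = torch.randn(2, 16, 512)    # last dimension is split into gate and value halves
    y = glu(x)                     # shape (2, 16, 256): feature count is halved

    reglu = SquaredReGLU(clamp=True, clamp_max=6.0)
    z = reglu(x)                   # same halving, gated by a clamped squared (leaky) ReLU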
broccoli/cnn.py CHANGED
@@ -1,300 +1,12 @@
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from torch.nn.modules.utils import _pair
- from einops import rearrange
  import math
- from typing import Type, Union, Tuple, Optional, Literal
+ from typing import Union

  from einops.layers.torch import Rearrange


- # # Helper function to calculate padding for 'same' mode
- # # Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
- # def _calculate_same_padding(
- #     input_size: Tuple[int, int],
- #     kernel_size: Tuple[int, int],
- #     stride: Tuple[int, int],
- #     dilation: Tuple[int, int],
- # ) -> Tuple[int, int, int, int]:
- #     """Calculates padding for 'same' output shape."""
- #     ih, iw = input_size
- #     kh, kw = kernel_size
- #     sh, sw = stride
- #     dh, dw = dilation
-
- #     # Effective kernel size
- #     eff_kh = (kh - 1) * dh + 1
- #     eff_kw = (kw - 1) * dw + 1
-
- #     # Calculate required total padding
- #     out_h = (ih + sh - 1) // sh
- #     out_w = (iw + sw - 1) // sw
- #     pad_h = max((out_h - 1) * sh + eff_kh - ih, 0)
- #     pad_w = max((out_w - 1) * sw + eff_kw - iw, 0)
-
- #     # Distribute padding (similar to TensorFlow 'SAME' behavior)
- #     pad_top = pad_h // 2
- #     pad_bottom = pad_h - pad_top
- #     pad_left = pad_w // 2
- #     pad_right = pad_w - pad_left
- #     return (pad_left, pad_right, pad_top, pad_bottom)
-
-
- # # Custom Convolution Layer
- # class ConvLayer(nn.Module):
- #     """
- #     A 2D Convolution layer implemented using torch.nn.Unfold and a custom linear layer.
-
- #     This layer mimics the behavior of torch.nn.Conv2d but allows injecting
- #     a different linear layer implementation for processing the unfolded patches.
-
- #     Args:
- #         in_channels (int): Number of channels in the input image.
- #         out_channels (int): Number of channels produced by the convolution.
- #         kernel_size (int or tuple): Size of the convolving kernel.
- #         stride (int or tuple, optional): Stride of the convolution. Default: 1.
- #         padding (int, tuple or str, optional): Padding added to all four sides
- #             of the input. Can be an int, a tuple of two ints (padH, padW),
- #             a tuple of four ints (padLeft, padRight, padTop, padBottom),
- #             or the strings 'valid' (no padding) or 'same' (padding for same
- #             output spatial dims as input). Default: 0 ('valid').
- #         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1.
- #         bias (bool, optional): If True, adds a learnable bias to the output.
- #             The bias is handled by the underlying linear layer. Default: True.
- #         linear (Type[nn.Module], optional): The class of the linear layer
- #             to use for the kernel operation. Must accept (in_features, out_features, bias)
- #             in its constructor. Defaults to torch.nn.Linear.
- #     """
-
- #     def __init__(
- #         self,
- #         in_channels: int,
- #         out_channels: int,
- #         kernel_size: Union[int, Tuple[int, int]],
- #         stride: Union[int, Tuple[int, int]] = 1,
- #         padding: Union[
- #             int, Tuple[int, int], Tuple[int, int, int, int], Literal["valid", "same"]
- #         ] = 0,
- #         dilation: Union[int, Tuple[int, int]] = 1,
- #         bias: bool = True,
- #         linear_module: Type[nn.Module] = nn.Linear,
- #     ):
- #         super().__init__()
- #         self.in_channels = in_channels
- #         self.out_channels = out_channels
- #         self.kernel_size = _pair(kernel_size)
- #         self.stride = _pair(stride)
- #         self.dilation = _pair(dilation)
- #         self.bias = bias
- #         self.linear_module = linear_module
- #         self.padding_mode = (
- #             padding  # Store the original padding mode ('same', 'valid', int, or tuple)
- #         )
-
- #         # Calculate the number of input features for the linear layer
- #         # It's the number of channels times the kernel area
- #         self.linear_in_features = (
- #             in_channels * self.kernel_size[0] * self.kernel_size[1]
- #         )
-
- #         # Instantiate the linear layer (kernel)
- #         self.kernel = self.linear_module(
- #             self.linear_in_features, out_channels, bias=bias
- #         )
-
- #         # We will use F.pad for manual padding, so unfold padding is 0
- #         self.unfold = nn.Unfold(
- #             kernel_size=self.kernel_size,
- #             dilation=self.dilation,
- #             padding=0,  # Manual padding handled in forward
- #             stride=self.stride,
- #         )
-
- #         # Determine numeric padding values for F.pad
- #         if isinstance(padding, str):
- #             if padding not in ["valid", "same"]:
- #                 raise ValueError("padding must be 'valid', 'same', an int, or a tuple")
- #             # 'same' padding calculation depends on input size, defer to forward pass
- #             # 'valid' padding means 0
- #             self._padding_val = (
- #                 (0, 0, 0, 0) if padding == "valid" else None
- #             )  # None indicates 'same'
- #         elif isinstance(padding, int):
- #             self._padding_val = (padding,) * 4
- #         elif isinstance(padding, tuple) and len(padding) == 2:
- #             # (padH, padW) -> (padW_left, padW_right, padH_top, padH_bottom)
- #             self._padding_val = (padding[1], padding[1], padding[0], padding[0])
- #         elif isinstance(padding, tuple) and len(padding) == 4:
- #             # (padLeft, padRight, padTop, padBottom) - already in F.pad format
- #             self._padding_val = padding
- #         else:
- #             raise TypeError(
- #                 "padding must be 'valid', 'same', an int, or a tuple of 2 or 4 ints"
- #             )
-
- #     def _calculate_output_shape(self, h_in: int, w_in: int) -> Tuple[int, int]:
- #         """Calculates the output height and width."""
- #         if self._padding_val is None:  # 'same' padding
- #             # For 'same' padding, output size matches input size if stride is 1.
- #             # If stride > 1, output size is ceil(input_size / stride)
- #             # The _calculate_same_padding helper ensures this behavior.
- #             oh = math.ceil(h_in / self.stride[0])
- #             ow = math.ceil(w_in / self.stride[1])
- #             return oh, ow
- #         else:
- #             # Use the standard formula with the calculated numeric padding
- #             pad_h = self._padding_val[2] + self._padding_val[3]  # top + bottom
- #             pad_w = self._padding_val[0] + self._padding_val[1]  # left + right
- #             kh, kw = self.kernel_size
- #             sh, sw = self.stride
- #             dh, dw = self.dilation
-
- #             eff_kh = (kh - 1) * dh + 1
- #             eff_kw = (kw - 1) * dw + 1
-
- #             oh = math.floor((h_in + pad_h - eff_kh) / sh + 1)
- #             ow = math.floor((w_in + pad_w - eff_kw) / sw + 1)
- #             return oh, ow
-
- #     def forward(self, x: torch.Tensor) -> torch.Tensor:
- #         """
- #         Performs the forward pass.
-
- #         Args:
- #             x (torch.Tensor): Input tensor of shape (N, C_in, H_in, W_in).
-
- #         Returns:
- #             torch.Tensor: Output tensor of shape (N, C_out, H_out, W_out).
- #         """
- #         _, C, H, W = x.shape
- #         if C != self.in_channels:
- #             raise ValueError(
- #                 f"Input channels {C} does not match expected {self.in_channels}"
- #             )
-
- #         # 1. Calculate and Apply Padding
- #         if self._padding_val is None:  # 'same' padding mode
- #             pad_l, pad_r, pad_t, pad_b = _calculate_same_padding(
- #                 (H, W), self.kernel_size, self.stride, self.dilation
- #             )
- #             padded_x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
- #             # Update H, W for output shape calculation after padding
- #             # Note: _calculate_output_shape will correctly handle 'same' based on original H, W
- #         elif self._padding_val != (0, 0, 0, 0):
- #             padded_x = F.pad(x, self._padding_val)
- #         else:  # No padding ('valid' or explicit 0)
- #             padded_x = x
-
- #         # 2. Unfold to extract patches
- #         # Input: (N, C_in, H_pad, W_pad)
- #         # Output: (N, C_in * K_h * K_w, L), where L is the number of patches (H_out * W_out)
- #         patches = self.unfold(padded_x)
- #         num_patches = patches.shape[-1]  # L
-
- #         # 3. Reshape for the linear layer
- #         # We want (N, L, C_in * K_h * K_w) to apply the linear layer patch-wise
- #         # transpose switches the last two dimensions
- #         patches_transposed = patches.transpose(1, 2)  # Shape: (N, L, C_in * K_h * K_w)
-
- #         # 4. Apply the linear layer (kernel) to each patch
- #         # Input: (N, L, linear_in_features)
- #         # Output: (N, L, out_channels)
- #         linear_output = self.kernel(patches_transposed)
-
- #         # 5. Reshape back to image format
- #         # We need (N, out_channels, L) first
- #         output_transposed = linear_output.transpose(1, 2)  # Shape: (N, out_channels, L)
-
- #         # Calculate output spatial dimensions
- #         out_h, out_w = self._calculate_output_shape(H, W)  # Use original H, W
-
- #         # Check if the number of patches matches the calculated output dimensions
- #         if num_patches != out_h * out_w:
- #             # This might happen with certain combinations of stride/padding/dilation/input size
- #             # if the calculation logic has an issue. nn.Unfold is usually robust.
- #             print(
- #                 f"Warning: Mismatch in calculated patches. "
- #                 f"Expected L={out_h * out_w}, got {num_patches}. "
- #                 f"Using unfolded L={num_patches} to determine output shape."
- #             )
- #             # Attempt recovery if possible, though might indicate upstream calculation error
- #             # Find factors of num_patches close to expected out_h, out_w
- #             # This part is tricky and might not always yield the desired shape.
- #             # For simplicity, we'll rely on nn.Unfold's L and reshape.
- #             # A more robust solution might re-calculate H_out, W_out based *only* on L.
- #             # For now, let's stick to the reshape based on calculated out_h, out_w,
- #             # assuming they match L. If they don't, the reshape will fail.
- #             pass  # Proceed with calculated out_h, out_w
-
- #         # Reshape using einops (or tensor.view)
- #         # Input: (N, C_out, L) -> Output: (N, C_out, H_out, W_out)
- #         output = rearrange(output_transposed, "n c (h w) -> n c h w", h=out_h, w=out_w)
- #         # Alternative using view:
- #         # output = output_transposed.view(N, self.out_channels, out_h, out_w)
-
- #         return output
-
- #     def extra_repr(self) -> str:
- #         s = (
- #             "{in_channels}, {out_channels}, kernel_size={kernel_size}"
- #             ", stride={stride}"
- #         )
- #         if self.padding_mode != 0 and self.padding_mode != "valid":
- #             s += ", padding={padding_mode}"
- #         if self.dilation != (1,) * len(self.dilation):
- #             s += ", dilation={dilation}"
- #         # if self.groups != 1:  # Not implemented
- #         #     s += ', groups={groups}'
- #         if self.bias is False:
- #             s += ", bias=False"
- #         if self.linear_module != nn.Linear:
- #             s += f", linear={self.linear.__name__}"
- #         return s.format(**self.__dict__)
-
-
- # class WhiteningConv(ConvLayer):
- #     def __init__(
- #         self,
- #         in_channels: int,
- #         kernel_size: int,
- #         eigenvectors: torch.Tensor,
- #         bias: bool = True,
- #         linear_module: Type[nn.Module] = nn.Linear,
- #     ):
- #         """
- #         We end up using a concatenation of the eigenvector tensor with its negation,
- #         as the tendency to use e.g. ReLU in neural networks means that useful
- #         data may otherwise be lost (if one orientation of an eigenvector produces
- #         a strong negative signal, this will be clipped to zero by ReLU, but a
- #         strong positive signal from the negation of the eigenvector will be
- #         preserved). Assuming a square kernel, out channels is thus
-
- #             (kernel_size ** 2) * in_channels * 2
-
- #         where the trailing "* 2" accounts for the doubling of the size of the
- #         eigenvector tensor we're using by including the negative of each eigenvector
- #         as well.
- #         """
- #         out_channels = kernel_size**2 * in_channels * 2
- #         super().__init__(
- #             in_channels,
- #             out_channels,
- #             kernel_size,
- #             padding="same",
- #             bias=bias,
- #             linear_module=linear_module,
- #         )
- #         self.eigenvectors = torch.cat([eigenvectors, -eigenvectors], dim=0)
- #         # bias updates if `bias`=True but weight doesn't,
- #         # per Jordan (2024) https://arxiv.org/abs/2404.00498
- #         # but weight is set to `requires_grad = False`:
- #         # self.kernel.weight.requires_grad = False
- #         with torch.no_grad():
- #             self.kernel.weight.copy_(self.eigenvectors)
- #         assert self.kernel.weight.requires_grad
-
-
  def spatial_tuple(size: Union[int, tuple], spatial_dimensions):
      """
      Converts an integer x to `tuple([x] * spatial_dimensions)`.
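
The only live code visible in this file's context lines is the `spatial_tuple` helper; judging from its docstring, it normalises an integer size into a per-dimension tuple. A minimal illustration (the tuple pass-through case is an assumption, since the rest of the function body is not shown in this diff):

    from broccoli.cnn import spatial_tuple

    spatial_tuple(3, 2)        # -> (3, 3)
    spatial_tuple((3, 5), 2)   # presumably returned unchanged as (3, 5)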
broccoli/linear.py CHANGED
@@ -1,41 +1,138 @@
- # UNDER CONSTRUCTION
-
+ import math
  import torch
  from torch import nn
  from torch.nn import functional as F

+ from .tensor import SigmaReparamTensor, AnchoredReparamTensor, NormReparamTensor
+
+
+ class SpectralNormLinear(nn.Module):
+     """
+     Inspired by Apple's Spectral Normed Linear Layers
+     (https://github.com/apple/ml-sigma-reparam)
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = SigmaReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         # Optional: A nice representation for printing the module.
+         return (
+             f"SpectralNormFeedForward(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class AnchoredLinear(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None

- class RandomLinear(nn.Linear):
-     """ """
-
-     def __init__(
-         self,
-         in_features: int,
-         out_features: int,
-         bias: bool = False,  # <---- TODO: explain this
-         beta=0.1,
-         forward_looks_random=True,
-     ):
-         super().__init__(in_features, out_features, bias=False)
-         self.beta = beta
-         self.forward_looks_random = forward_looks_random
-
-     def forward(self, inputs: torch.Tensor):
-         if not self.training:
-             return F.linear(inputs, self.weight)
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
          else:
-             # Initialise self.random_weights
-             random_weights = torch.empty_like(self.weight)
-             nn.init.trunc_normal_(random_weights)
-             random_weights *= self.beta
-
-             if self.forward_looks_random:
-                 # Forward using a reparameterisation trick
-                 a = F.linear(inputs.detach(), self.weight, self.bias)
-                 b = F.linear(inputs, random_weights, bias=None)
-             else:
-                 # Forward as (W_actual * input + W_random * input) + bias
-                 a = F.linear(inputs, self.weight, self.bias)
-                 b = F.linear(inputs, random_weights, bias=None)
-
-             return a + b
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = AnchoredReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         # Optional: A nice representation for printing the module.
+         return (
+             f"AnchoredLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
+
+
+ class WeightNormedLinear(nn.Module):
+     """
+     ...
+     """
+
+     def __init__(self, in_features: int, out_features: int, bias: bool = True):
+         super().__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.use_bias = bias
+
+         self.weights = None
+
+         # Define the bias vector as a learnable parameter if required.
+         if self.use_bias:
+             self.bias = nn.Parameter(torch.empty(out_features))
+         else:
+             # If no bias, register it as None.
+             # This is important so that PyTorch doesn't complain when saving/loading the model.
+             self.register_parameter("bias", None)
+
+         self.reset_parameters()
+
+     def reset_parameters(self) -> None:
+         weights = torch.empty(self.out_features, self.in_features)
+         stdv = 1.0 / math.sqrt(self.in_features)
+         nn.init.uniform_(weights, a=-stdv, b=stdv)
+         if self.use_bias:
+             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
+             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
+             nn.init.uniform_(self.bias, -bound, bound)
+         self.weights = NormReparamTensor(weights)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return F.linear(x, self.weights(), self.bias)
+
+     def __repr__(self) -> str:
+         return (
+             f"WeightNormedLinear(in_features={self.in_features},"
+             f"out_features={self.out_features}, bias={self.use_bias})"
+         )
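
All three new layers follow the same pattern: the weight matrix is initialised the way `nn.Linear` initialises its weights and bias, then wrapped in a reparameterisation object (`SigmaReparamTensor`, `AnchoredReparamTensor`, or `NormReparamTensor`) that is called at forward time to produce the effective weights. Those wrappers live in `broccoli/tensor.py`, which is not part of this diff, so the sketch below only assumes the constructor and forward signatures shown above:

    import torch
    from broccoli.linear import SpectralNormLinear

    layer = SpectralNormLinear(in_features=64, out_features=128, bias=True)
    x = torch.randn(8, 64)
    y = layer(x)    # shape (8, 128); the weight matrix comes from calling
                    # the SigmaReparamTensor wrapper on every forward pass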
broccoli/rope.py CHANGED
@@ -27,13 +27,28 @@ SOFTWARE.
  """

  from __future__ import annotations
- from math import pi, log
+ from math import pi

  import torch
- from torch.amp import autocast
+
  from torch.nn import Module
  from torch import nn, einsum, broadcast_tensors, is_tensor, tensor, Tensor

+ # Gracefully find the best way to import autocast
+ try:
+     from torch.amp import autocast as autocast_factory
+ except ImportError:
+     # Fallback: For PyTorch 1.6 to 1.9
+     from torch.cuda.amp import autocast
+
+     def autocast_factory(_, enabled=True):
+         """
+         A wrapper that mimics the modern autocast signature but calls the older
+         torch.cuda.amp.autocast, ignoring the device_type argument.
+         """
+         return autocast(enabled=enabled)
+
+
  from einops import rearrange, repeat

  from typing import Literal
@@ -74,7 +89,7 @@ def rotate_half(x):
      return rearrange(x, "... d r -> ... (d r)")


- @autocast("cuda", enabled=False)
+ @autocast_factory("cuda", enabled=False)
  def apply_rotary_emb(
      freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None
  ):
@@ -363,7 +378,7 @@ class RotaryEmbedding(Module):
          all_freqs = broadcast_tensors(*all_freqs)
          return torch.cat(all_freqs, dim=-1)

-     @autocast("cuda", enabled=False)
+     @autocast_factory("cuda", enabled=False)
      def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
          should_cache = (
              self.cache_if_possible
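
The rope.py change replaces the hard dependency on `torch.amp.autocast` with a shim that falls back to `torch.cuda.amp.autocast` on older PyTorch builds, simply dropping the device argument. A standalone sketch of the same fallback pattern (illustrative only, not code from the package):

    import torch

    try:
        # Newer PyTorch exposes a device-aware autocast here
        from torch.amp import autocast as autocast_factory
    except ImportError:
        from torch.cuda.amp import autocast

        def autocast_factory(_, enabled=True):
            # Older releases take no device_type argument, so it is ignored
            return autocast(enabled=enabled)

    @autocast_factory("cuda", enabled=False)
    def fp32_matmul(x):
        # Runs with autocast disabled even inside an autocast context
        return x @ x.T

    y = fp32_matmul(torch.randn(4, 4))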