broccoli-ml 0.29.1__py3-none-any.whl → 10.0.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- broccoli/activation.py +1 -4
- broccoli/cnn.py +1 -289
- broccoli/linear.py +237 -7
- broccoli/rope.py +19 -4
- broccoli/tensor.py +36 -31
- broccoli/transformer.py +523 -186
- broccoli/utils.py +13 -7
- broccoli/vit.py +214 -56
- {broccoli_ml-0.29.1.dist-info → broccoli_ml-10.0.1.dist-info}/METADATA +5 -3
- broccoli_ml-10.0.1.dist-info/RECORD +13 -0
- broccoli/assets/2025_resnet_imagenet_1k_pretrained_state_dict.pkl +0 -0
- broccoli/assets/cifar100_eigenvectors_size_2.pt +0 -0
- broccoli/assets/cifar100_eigenvectors_size_3.pt +0 -0
- broccoli/eigenpatches.py +0 -49
- broccoli_ml-0.29.1.dist-info/RECORD +0 -17
- {broccoli_ml-0.29.1.dist-info → broccoli_ml-10.0.1.dist-info}/LICENSE +0 -0
- {broccoli_ml-0.29.1.dist-info → broccoli_ml-10.0.1.dist-info}/WHEEL +0 -0
broccoli/activation.py
CHANGED
```diff
@@ -46,10 +46,7 @@ class GELU(nn.Module):
 
 class Swish(nn.Module):
     """
-    Implementation of (beta)
-    (https://arxiv.org/abs/2002.05202v1) and used to great effect in LLaMa 2.0.
-
-    Halves the incoming parameter count, which should be scaled up before input.
+    Implementation of (beta) Swish
    """
 
     def __init__(self) -> None:
```
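Note: the old docstring described a GLU-style gated unit (the parameter-halving behaviour from the GLU-variants paper); the new one correctly says just "Swish". For reference, (beta-)Swish is `f(x) = x * sigmoid(beta * x)`; a minimal illustrative module is below. The `beta` constructor argument is an assumption for illustration only: the class in this diff takes no constructor arguments, so it presumably fixes beta (beta = 1 recovers SiLU).

```python
import torch
import torch.nn as nn


class BetaSwish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(beta * x). beta = 1 gives SiLU."""

    def __init__(self, beta: float = 1.0) -> None:  # `beta` arg is hypothetical
        super().__init__()
        self.beta = beta

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(self.beta * x)
```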
broccoli/cnn.py
CHANGED
```diff
@@ -1,300 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.modules.utils import _pair
-from einops import rearrange
 import math
-from typing import
+from typing import Union
 
 from einops.layers.torch import Rearrange
 
 
-# # Helper function to calculate padding for 'same' mode
-# # Adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py
-# def _calculate_same_padding(
-#     input_size: Tuple[int, int],
-#     kernel_size: Tuple[int, int],
-#     stride: Tuple[int, int],
-#     dilation: Tuple[int, int],
-# ) -> Tuple[int, int, int, int]:
-#     """Calculates padding for 'same' output shape."""
-#     ih, iw = input_size
-#     kh, kw = kernel_size
-#     sh, sw = stride
-#     dh, dw = dilation
-
-#     # Effective kernel size
-#     eff_kh = (kh - 1) * dh + 1
-#     eff_kw = (kw - 1) * dw + 1
-
-#     # Calculate required total padding
-#     out_h = (ih + sh - 1) // sh
-#     out_w = (iw + sw - 1) // sw
-#     pad_h = max((out_h - 1) * sh + eff_kh - ih, 0)
-#     pad_w = max((out_w - 1) * sw + eff_kw - iw, 0)
-
-#     # Distribute padding (similar to TensorFlow 'SAME' behavior)
-#     pad_top = pad_h // 2
-#     pad_bottom = pad_h - pad_top
-#     pad_left = pad_w // 2
-#     pad_right = pad_w - pad_left
-#     return (pad_left, pad_right, pad_top, pad_bottom)
-
-
-# # Custom Convolution Layer
-# class ConvLayer(nn.Module):
-#     """
-#     A 2D Convolution layer implemented using torch.nn.Unfold and a custom linear layer.
-
-#     This layer mimics the behavior of torch.nn.Conv2d but allows injecting
-#     a different linear layer implementation for processing the unfolded patches.
-
-#     Args:
-#         in_channels (int): Number of channels in the input image.
-#         out_channels (int): Number of channels produced by the convolution.
-#         kernel_size (int or tuple): Size of the convolving kernel.
-#         stride (int or tuple, optional): Stride of the convolution. Default: 1.
-#         padding (int, tuple or str, optional): Padding added to all four sides
-#             of the input. Can be an int, a tuple of two ints (padH, padW),
-#             a tuple of four ints (padLeft, padRight, padTop, padBottom),
-#             or the strings 'valid' (no padding) or 'same' (padding for same
-#             output spatial dims as input). Default: 0 ('valid').
-#         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1.
-#         bias (bool, optional): If True, adds a learnable bias to the output.
-#             The bias is handled by the underlying linear layer. Default: True.
-#         linear (Type[nn.Module], optional): The class of the linear layer
-#             to use for the kernel operation. Must accept (in_features, out_features, bias)
-#             in its constructor. Defaults to torch.nn.Linear.
-#     """
-
-#     def __init__(
-#         self,
-#         in_channels: int,
-#         out_channels: int,
-#         kernel_size: Union[int, Tuple[int, int]],
-#         stride: Union[int, Tuple[int, int]] = 1,
-#         padding: Union[
-#             int, Tuple[int, int], Tuple[int, int, int, int], Literal["valid", "same"]
-#         ] = 0,
-#         dilation: Union[int, Tuple[int, int]] = 1,
-#         bias: bool = True,
-#         linear_module: Type[nn.Module] = nn.Linear,
-#     ):
-#         super().__init__()
-#         self.in_channels = in_channels
-#         self.out_channels = out_channels
-#         self.kernel_size = _pair(kernel_size)
-#         self.stride = _pair(stride)
-#         self.dilation = _pair(dilation)
-#         self.bias = bias
-#         self.linear_module = linear_module
-#         self.padding_mode = (
-#             padding  # Store the original padding mode ('same', 'valid', int, or tuple)
-#         )
-
-#         # Calculate the number of input features for the linear layer
-#         # It's the number of channels times the kernel area
-#         self.linear_in_features = (
-#             in_channels * self.kernel_size[0] * self.kernel_size[1]
-#         )
-
-#         # Instantiate the linear layer (kernel)
-#         self.kernel = self.linear_module(
-#             self.linear_in_features, out_channels, bias=bias
-#         )
-
-#         # We will use F.pad for manual padding, so unfold padding is 0
-#         self.unfold = nn.Unfold(
-#             kernel_size=self.kernel_size,
-#             dilation=self.dilation,
-#             padding=0,  # Manual padding handled in forward
-#             stride=self.stride,
-#         )
-
-#         # Determine numeric padding values for F.pad
-#         if isinstance(padding, str):
-#             if padding not in ["valid", "same"]:
-#                 raise ValueError("padding must be 'valid', 'same', an int, or a tuple")
-#             # 'same' padding calculation depends on input size, defer to forward pass
-#             # 'valid' padding means 0
-#             self._padding_val = (
-#                 (0, 0, 0, 0) if padding == "valid" else None
-#             )  # None indicates 'same'
-#         elif isinstance(padding, int):
-#             self._padding_val = (padding,) * 4
-#         elif isinstance(padding, tuple) and len(padding) == 2:
-#             # (padH, padW) -> (padW_left, padW_right, padH_top, padH_bottom)
-#             self._padding_val = (padding[1], padding[1], padding[0], padding[0])
-#         elif isinstance(padding, tuple) and len(padding) == 4:
-#             # (padLeft, padRight, padTop, padBottom) - already in F.pad format
-#             self._padding_val = padding
-#         else:
-#             raise TypeError(
-#                 "padding must be 'valid', 'same', an int, or a tuple of 2 or 4 ints"
-#             )
-
-#     def _calculate_output_shape(self, h_in: int, w_in: int) -> Tuple[int, int]:
-#         """Calculates the output height and width."""
-#         if self._padding_val is None:  # 'same' padding
-#             # For 'same' padding, output size matches input size if stride is 1.
-#             # If stride > 1, output size is ceil(input_size / stride)
-#             # The _calculate_same_padding helper ensures this behavior.
-#             oh = math.ceil(h_in / self.stride[0])
-#             ow = math.ceil(w_in / self.stride[1])
-#             return oh, ow
-#         else:
-#             # Use the standard formula with the calculated numeric padding
-#             pad_h = self._padding_val[2] + self._padding_val[3]  # top + bottom
-#             pad_w = self._padding_val[0] + self._padding_val[1]  # left + right
-#             kh, kw = self.kernel_size
-#             sh, sw = self.stride
-#             dh, dw = self.dilation
-
-#             eff_kh = (kh - 1) * dh + 1
-#             eff_kw = (kw - 1) * dw + 1
-
-#             oh = math.floor((h_in + pad_h - eff_kh) / sh + 1)
-#             ow = math.floor((w_in + pad_w - eff_kw) / sw + 1)
-#             return oh, ow
-
-#     def forward(self, x: torch.Tensor) -> torch.Tensor:
-#         """
-#         Performs the forward pass.
-
-#         Args:
-#             x (torch.Tensor): Input tensor of shape (N, C_in, H_in, W_in).
-
-#         Returns:
-#             torch.Tensor: Output tensor of shape (N, C_out, H_out, W_out).
-#         """
-#         _, C, H, W = x.shape
-#         if C != self.in_channels:
-#             raise ValueError(
-#                 f"Input channels {C} does not match expected {self.in_channels}"
-#             )
-
-#         # 1. Calculate and Apply Padding
-#         if self._padding_val is None:  # 'same' padding mode
-#             pad_l, pad_r, pad_t, pad_b = _calculate_same_padding(
-#                 (H, W), self.kernel_size, self.stride, self.dilation
-#             )
-#             padded_x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
-#             # Update H, W for output shape calculation after padding
-#             # Note: _calculate_output_shape will correctly handle 'same' based on original H, W
-#         elif self._padding_val != (0, 0, 0, 0):
-#             padded_x = F.pad(x, self._padding_val)
-#         else:  # No padding ('valid' or explicit 0)
-#             padded_x = x
-
-#         # 2. Unfold to extract patches
-#         # Input: (N, C_in, H_pad, W_pad)
-#         # Output: (N, C_in * K_h * K_w, L), where L is the number of patches (H_out * W_out)
-#         patches = self.unfold(padded_x)
-#         num_patches = patches.shape[-1]  # L
-
-#         # 3. Reshape for the linear layer
-#         # We want (N, L, C_in * K_h * K_w) to apply the linear layer patch-wise
-#         # transpose switches the last two dimensions
-#         patches_transposed = patches.transpose(1, 2)  # Shape: (N, L, C_in * K_h * K_w)
-
-#         # 4. Apply the linear layer (kernel) to each patch
-#         # Input: (N, L, linear_in_features)
-#         # Output: (N, L, out_channels)
-#         linear_output = self.kernel(patches_transposed)
-
-#         # 5. Reshape back to image format
-#         # We need (N, out_channels, L) first
-#         output_transposed = linear_output.transpose(1, 2)  # Shape: (N, out_channels, L)
-
-#         # Calculate output spatial dimensions
-#         out_h, out_w = self._calculate_output_shape(H, W)  # Use original H, W
-
-#         # Check if the number of patches matches the calculated output dimensions
-#         if num_patches != out_h * out_w:
-#             # This might happen with certain combinations of stride/padding/dilation/input size
-#             # if the calculation logic has an issue. nn.Unfold is usually robust.
-#             print(
-#                 f"Warning: Mismatch in calculated patches. "
-#                 f"Expected L={out_h * out_w}, got {num_patches}. "
-#                 f"Using unfolded L={num_patches} to determine output shape."
-#             )
-#             # Attempt recovery if possible, though might indicate upstream calculation error
-#             # Find factors of num_patches close to expected out_h, out_w
-#             # This part is tricky and might not always yield the desired shape.
-#             # For simplicity, we'll rely on nn.Unfold's L and reshape.
-#             # A more robust solution might re-calculate H_out, W_out based *only* on L.
-#             # For now, let's stick to the reshape based on calculated out_h, out_w,
-#             # assuming they match L. If they don't, the reshape will fail.
-#             pass  # Proceed with calculated out_h, out_w
-
-#         # Reshape using einops (or tensor.view)
-#         # Input: (N, C_out, L) -> Output: (N, C_out, H_out, W_out)
-#         output = rearrange(output_transposed, "n c (h w) -> n c h w", h=out_h, w=out_w)
-#         # Alternative using view:
-#         # output = output_transposed.view(N, self.out_channels, out_h, out_w)
-
-#         return output
-
-#     def extra_repr(self) -> str:
-#         s = (
-#             "{in_channels}, {out_channels}, kernel_size={kernel_size}"
-#             ", stride={stride}"
-#         )
-#         if self.padding_mode != 0 and self.padding_mode != "valid":
-#             s += ", padding={padding_mode}"
-#         if self.dilation != (1,) * len(self.dilation):
-#             s += ", dilation={dilation}"
-#         # if self.groups != 1:  # Not implemented
-#         #     s += ', groups={groups}'
-#         if self.bias is False:
-#             s += ", bias=False"
-#         if self.linear_module != nn.Linear:
-#             s += f", linear={self.linear.__name__}"
-#         return s.format(**self.__dict__)
-
-
-# class WhiteningConv(ConvLayer):
-#     def __init__(
-#         self,
-#         in_channels: int,
-#         kernel_size: int,
-#         eigenvectors: torch.Tensor,
-#         bias: bool = True,
-#         linear_module: Type[nn.Module] = nn.Linear,
-#     ):
-#         """
-#         We end up using a concatenation of the eigenvector tensor with its negation,
-#         as the tendency to use e.g. ReLU in neural networks means that useful
-#         data may otherwise be lost (if one orientation of an eigenvector produces
-#         a strong negative signal, this will be clipped to zero by ReLU, but a
-#         strong positive signal from the negation of the eigenvector will be
-#         preserved). Assuming a square kernel, out channels is thus
-
-#             (kernel_size ** 2) * in_channels * 2
-
-#         where the trailing "* 2" accounts for the doubling of the size of the
-#         eigenvector tensor we're using by including the negative of each eigenvector
-#         as well.
-#         """
-#         out_channels = kernel_size**2 * in_channels * 2
-#         super().__init__(
-#             in_channels,
-#             out_channels,
-#             kernel_size,
-#             padding="same",
-#             bias=bias,
-#             linear_module=linear_module,
-#         )
-#         self.eigenvectors = torch.cat([eigenvectors, -eigenvectors], dim=0)
-#         # bias updates if `bias`=True but weight doesn't,
-#         # per Jordan (2024) https://arxiv.org/abs/2404.00498
-#         # but weight is set to `requires_grad = False`:
-#         # self.kernel.weight.requires_grad = False
-#         with torch.no_grad():
-#             self.kernel.weight.copy_(self.eigenvectors)
-#         assert self.kernel.weight.requires_grad
-
-
 def spatial_tuple(size: Union[int, tuple], spatial_dimensions):
     """
     Converts an integer x to `tuple([x] * spatial_dimensions)`.
```
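For context on what was deleted: the commented-out `ConvLayer` expressed 2D convolution as `nn.Unfold` (im2col) followed by a linear map over the extracted patches, which is what let it inject custom linear layers as convolution kernels. A self-contained sanity check of that equivalence, with arbitrary illustrative shapes:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)       # (N, C_in, H, W)
weight = torch.randn(5, 3, 3, 3)  # (C_out, C_in, kH, kW)

# im2col: (N, C_in * kH * kW, L), where L = H_out * W_out = 6 * 6 for 'valid'
patches = F.unfold(x, kernel_size=3)
# A single matmul over the flattened kernels plays the role of the injected
# linear layer ("kernel") in the deleted ConvLayer
out = weight.view(5, -1) @ patches  # (N, C_out, L)
out = out.view(2, 5, 6, 6)          # fold L back into spatial dims

assert torch.allclose(out, F.conv2d(x, weight), atol=1e-4)
```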
broccoli/linear.py
CHANGED
```diff
@@ -1,6 +1,8 @@
-# UNDER CONSTRUCTION
-
 import math
+import random
+import warnings
+from typing import Union, List, Iterable
+
 import torch
 from torch import nn
 from torch.nn import functional as F
```
```diff
@@ -34,7 +36,8 @@ class SpectralNormLinear(nn.Module):
 
     def reset_parameters(self) -> None:
         weights = torch.empty(self.out_features, self.in_features)
-
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
         if self.use_bias:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
```
```diff
@@ -77,7 +80,8 @@ class AnchoredLinear(nn.Module):
 
     def reset_parameters(self) -> None:
         weights = torch.empty(self.out_features, self.in_features)
-
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
         if self.use_bias:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
```
```diff
@@ -120,7 +124,8 @@ class WeightNormedLinear(nn.Module):
 
     def reset_parameters(self) -> None:
         weights = torch.empty(self.out_features, self.in_features)
-
+        stdv = 1.0 / math.sqrt(self.in_features)
+        nn.init.uniform_(weights, a=-stdv, b=stdv)
         if self.use_bias:
             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weights)
             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
```
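The same two-line initialisation was added to all three `reset_parameters` methods above: the classic uniform scheme U(-1/sqrt(fan_in), +1/sqrt(fan_in)), i.e. the same bound `nn.Linear` effectively uses for its own weights (previously `weights` was left uninitialised after `torch.empty`). Stand-alone, it amounts to:

```python
import math
import torch
from torch import nn

in_features, out_features = 64, 128  # arbitrary example sizes
weights = torch.empty(out_features, in_features)
stdv = 1.0 / math.sqrt(in_features)
nn.init.uniform_(weights, a=-stdv, b=stdv)  # every entry lands in [-stdv, stdv]
assert weights.abs().max() <= stdv
```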
```diff
@@ -131,8 +136,233 @@ class WeightNormedLinear(nn.Module):
         return F.linear(x, self.weights(), self.bias)
 
     def __repr__(self) -> str:
-        # Optional: A nice representation for printing the module.
         return (
-            f"
+            f"WeightNormedLinear(in_features={self.in_features},"
             f"out_features={self.out_features}, bias={self.use_bias})"
         )
+
+
+class RecyclingLinear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        row_recycling_rate: float = 0.0,
+        column_recycling_rate: float = 0.0,
+        adaptive=False,
+        xglu=False,
+    ):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.bias = bias
+        self.xglu = xglu
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+        self.row_recycling_rate = row_recycling_rate
+        self.column_recycling_rate = column_recycling_rate
+        self.adaptive = adaptive
+        self.optimisers = []
+        self.initial_learning_rates = []
+        self._warned_about_registration = False
+
+    def register_optimiser(self, optimiser: torch.optim.Optimizer):
+        self.optimisers.append(optimiser)
+        self.initial_learning_rates.append(self._get_learning_rate(optimiser))
+        if self.initial_learning_rates[-1] == 0.0:
+            warnings.warn(
+                "Learning rate of registered optimiser was 0.0 - make sure "
+                "you haven't initialised a scheduler before registering the "
+                "optimiser",
+                stacklevel=2,
+            )
+
+    def _get_learning_rate(self, optimiser: torch.optim.Optimizer):
+        for group in optimiser.param_groups:
+            for param in group["params"]:
+                if param is self.linear.weight:
+                    return group["lr"]
+
+    def _get_multiplier(self):
+        if not self.adaptive or not self.optimisers:
+            return 1.0
+        else:
+            init = self.initial_learning_rates
+            current = [self._get_learning_rate(o) for o in self.optimisers]
+            pairs = zip(current, init, strict=True)
+            multipliers = [a / b for a, b in pairs if b != 0.0]
+            return min(multipliers) if multipliers else 0.0
+
+    def reset_rows(self, indices, incoming_data=None):
+        """
+        Resets rows.
+        If incoming_data is provided, resets to the centroid (mean) of that data.
+        If not, resets to the mean of existing weights.
+        """
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        if incoming_data is not None:
+            target_center = self._mean_input_weights(incoming_data)
+        else:
+            target_center = self._mean_value_weights()
+
+        target_center = target_center.expand(idx_tensor.size(0), -1)
+
+        if self.xglu:
+            gate_indices = idx_tensor
+            value_indices = idx_tensor + (self.linear.out_features // 2)
+            self._update_weights(gate_indices, 0, target_center, self.optimisers)
+            self._update_weights(value_indices, 0, target_center, self.optimisers)
+        else:
+            self._update_weights(idx_tensor, 0, target_center, self.optimisers)
+
+    def reset_columns(self, indices):
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.size(0):
+            random_weights = self._random_weights(
+                self.linear.weight.size(0), indices.size(0)
+            )
+            # Make random col weights quiet so they don't introduce loud noise...
+            # ...but not so quiet that FP16 zeros them and ruins symmetry breaking!
+            random_weights *= 0.1
+            self._update_weights(indices, 1, random_weights, self.optimisers)  # dim
+        else:
+            return
+
+    def forward(self, x):
+        if self.training and self.optimisers:
+            self.reset_rows(self.get_reset_indices(0))
+            self.reset_columns(self.get_reset_indices(1))
+        elif self.training and not self._warned_about_registration:
+            warnings.warn(
+                "RecyclingLinear: No optimiser registered. Recycling disabled.",
+                stacklevel=2,
+            )
+            self._warned_about_registration = True
+
+        return self.linear(x)
+
+    def get_reset_indices(self, dim):
+        base_rate = self.row_recycling_rate if dim == 0 else self.column_recycling_rate
+        p = base_rate * self._get_multiplier()
+        if dim == 0:
+            if self.xglu:
+                sample_space = self.linear.out_features // 2
+            else:
+                sample_space = self.linear.out_features
+        elif dim == 1:
+            sample_space = self.linear.in_features
+        else:
+            raise ValueError("`dim` must be 0 or 1")
+
+        # Sample the indices
+        probs = torch.rand(sample_space, device=self.linear.weight.device)
+        mask = probs < p
+        if mask.any():
+            return torch.nonzero(mask).squeeze(-1)
+        else:
+            return torch.tensor([], dtype=torch.long, device=self.linear.weight.device)
+
+    def _random_weights(self, rows, columns):
+        device = self.linear.weight.device
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        random_weights = torch.rand(rows, columns, device=device)
+        random_weights -= 0.5  # Range [-0.5, +0.5]
+        random_weights *= 2.0 * stdv  # Range [-stdv, +stdv]
+        return random_weights
+
+    def _mean_input_weights(self, input):
+        reduce_dims = list(range(input.ndim - 1))
+        data_mean = input.detach().mean(dim=reduce_dims, keepdim=True)
+
+        weights = self.linear.weight.data
+        stdv = 1.0 / math.sqrt(weights.size(1))
+        data_norm = data_mean.std() + 1e-6
+        scale_factor = stdv / data_norm
+
+        return data_mean * scale_factor
+
+    def _mean_value_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        if self.xglu:
+            return self.linear.weight[int(rows / 2) :].data.mean(dim=0, keepdim=True)
+        else:
+            return self.linear.weight.data.mean(dim=0, keepdim=True)
+
+    def _mean_gate_weights(self):
+        """
+        Only used when self.xglu
+        """
+        weights = self.linear.weight.data
+        rows = weights.size(0)
+        return self.linear.weight[: int(rows / 2)].data.mean(dim=0, keepdim=True)
+
+    def _update_weights(
+        self,
+        indices: Iterable[int],
+        dim: int,
+        data: torch.Tensor,
+        optimisers: Union[
+            List[torch.optim.Optimizer], torch.optim.Optimizer, None
+        ] = None,
+    ):
+        if optimisers is None:
+            optimisers = []
+        if not isinstance(optimisers, list):
+            optimisers = [optimisers]
+
+        if not torch.is_tensor(indices):
+            idx_tensor = torch.as_tensor(
+                list(indices), dtype=torch.long, device=self.linear.weight.device
+            )
+        else:
+            idx_tensor = indices
+
+        if idx_tensor.numel() == 0:
+            return
+
+        with torch.no_grad():
+            if dim == 0:
+                self.linear.weight.data[idx_tensor] = data
+            elif dim == 1:
+                self.linear.weight.data[:, idx_tensor] = data
+            else:
+                raise ValueError("`dim` must be 0 or 1")
+
+        self._reset_optim_state(self.linear.weight, idx_tensor, optimisers, dim=dim)
+
+    def _reset_optim_state(self, param, idx_tensor, optimisers, dim):
+        """
+        Zeroes out the optimizer state for the given indices in a single operation.
+        """
+        for optimiser in optimisers:
+            if param not in optimiser.state:
+                continue
+            state = optimiser.state[param]
+
+            for _, buffer in state.items():
+                if torch.is_tensor(buffer) and buffer.shape == param.shape:
+                    # Vectorized zeroing
+                    if dim == 0:
+                        buffer[idx_tensor] = 0.0
+                    else:
+                        buffer[:, idx_tensor] = 0.0
```
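A hypothetical usage sketch for the new `RecyclingLinear`. The import path is assumed from this file's location, and the rates, shapes, and objective are made up; the one firm requirement visible in the code above is that `register_optimiser` must be called before any LR scheduler changes the learning rate, since adaptive recycling scales the reset probability by the current/initial LR ratio.

```python
import torch
from broccoli.linear import RecyclingLinear  # assumed import path

layer = RecyclingLinear(64, 128, row_recycling_rate=1e-3, adaptive=True)
optimiser = torch.optim.AdamW(layer.parameters(), lr=3e-4)
layer.register_optimiser(optimiser)  # before attaching any scheduler

for step in range(100):
    x = torch.randn(32, 64)
    loss = layer(x).pow(2).mean()  # dummy objective for illustration
    optimiser.zero_grad()
    loss.backward()
    # During training, forward() stochastically resets rows/columns and
    # zeroes the matching optimiser state so stale momentum doesn't linger
    optimiser.step()
```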
broccoli/rope.py
CHANGED
```diff
@@ -27,13 +27,28 @@ SOFTWARE.
 """
 
 from __future__ import annotations
-from math import pi
+from math import pi
 
 import torch
-
+
 from torch.nn import Module
 from torch import nn, einsum, broadcast_tensors, is_tensor, tensor, Tensor
 
+# Gracefully find the best way to import autocast
+try:
+    from torch.amp import autocast as autocast_factory
+except ImportError:
+    # Fallback: For PyTorch 1.6 to 1.9
+    from torch.cuda.amp import autocast
+
+    def autocast_factory(_, enabled=True):
+        """
+        A wrapper that mimics the modern autocast signature but calls the older
+        torch.cuda.amp.autocast, ignoring the device_type argument.
+        """
+        return autocast(enabled=enabled)
+
+
 from einops import rearrange, repeat
 
 from typing import Literal
```
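On a PyTorch new enough to have `torch.amp`, the shim resolves to `torch.amp.autocast` used as a decorator; on 1.6-1.9 the wrapper discards the device argument and delegates to `torch.cuda.amp.autocast`. Either way the effect is the same: the decorated function runs in full precision even inside an enclosing autocast region, the usual guard for numerically sensitive rotary-phase maths. A small sketch assuming the modern import path succeeds (the function name and shapes are illustrative, not from this package):

```python
import torch
from torch.amp import autocast as autocast_factory  # modern path

@autocast_factory("cuda", enabled=False)
def rotary_phases(positions: torch.Tensor, inv_freq: torch.Tensor) -> torch.Tensor:
    # Runs in fp32 even under an enclosing autocast("cuda") context:
    # phase angles at long sequence lengths degrade quickly in fp16
    return torch.einsum("i,j->ij", positions.float(), inv_freq.float())
```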
```diff
@@ -74,7 +89,7 @@ def rotate_half(x):
     return rearrange(x, "... d r -> ... (d r)")
 
 
-@
+@autocast_factory("cuda", enabled=False)
 def apply_rotary_emb(
     freqs, t, start_index=0, scale=1.0, seq_dim=-2, freqs_seq_dim=None
 ):
```
```diff
@@ -363,7 +378,7 @@ class RotaryEmbedding(Module):
         all_freqs = broadcast_tensors(*all_freqs)
         return torch.cat(all_freqs, dim=-1)
 
-    @
+    @autocast_factory("cuda", enabled=False)
     def forward(self, t: Tensor, seq_len: int | None = None, offset=0):
         should_cache = (
             self.cache_if_possible
```