sparse-layers 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sparse_layers/__init__.py +69 -0
- sparse_layers/layers/__init__.py +23 -0
- sparse_layers/layers/butterfly_linear.py +225 -0
- sparse_layers/layers/butterfly_mlp.py +145 -0
- sparse_layers/layers/butterfly_multi_head_attention.py +93 -0
- sparse_layers/layers/custom_linear.py +46 -0
- sparse_layers/layers/custom_mlp.py +54 -0
- sparse_layers/layers/multi_head_attention.py +81 -0
- sparse_layers/layers/padded_butterfly_linear.py +70 -0
- sparse_layers/layers/simple_mlp.py +60 -0
- sparse_layers/sse/__init__.py +55 -0
- sparse_layers/sse/attention.py +242 -0
- sparse_layers/sse/attention_adaptive.py +117 -0
- sparse_layers/sse/linear_attention.py +68 -0
- sparse_layers/sse/linear_attention_config.py +22 -0
- sparse_layers/sse/masking_ops.py +107 -0
- sparse_layers/sse/multi_head_attention.py +196 -0
- sparse_layers/sse/multi_partition_state.py +367 -0
- sparse_layers/sse/multi_partition_state_config.py +15 -0
- sparse_layers/sse/partition_selector.py +43 -0
- sparse_layers/sse/partition_selector_config.py +26 -0
- sparse_layers/sse/sparse_softmax.py +99 -0
- sparse_layers/sse/sparse_softmax_config.py +37 -0
- sparse_layers/sse/varlen_ops.py +212 -0
- sparse_layers-0.2.0.dist-info/METADATA +113 -0
- sparse_layers-0.2.0.dist-info/RECORD +28 -0
- sparse_layers-0.2.0.dist-info/WHEEL +4 -0
- sparse_layers-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""sparse-layers: structured sparse layers for building memory-efficient neural networks."""
|
|
2
|
+
|
|
3
|
+
from sparse_layers.layers import (
|
|
4
|
+
ButterflyLinear,
|
|
5
|
+
ButterflyMLP,
|
|
6
|
+
ButterflyMultiHeadAttention,
|
|
7
|
+
CustomLinear,
|
|
8
|
+
CustomMLP,
|
|
9
|
+
MultiHeadAttention,
|
|
10
|
+
PaddedButterflyLinear,
|
|
11
|
+
SimpleMLP,
|
|
12
|
+
)
|
|
13
|
+
from sparse_layers.sse import (
|
|
14
|
+
LinearAttention,
|
|
15
|
+
LinearAttentionConfig,
|
|
16
|
+
NaiveMultiPartitionState,
|
|
17
|
+
NaiveSSEAttention,
|
|
18
|
+
NaiveSSEMultiHeadAttention,
|
|
19
|
+
SSEAttention,
|
|
20
|
+
SSEAttentionAdaptive,
|
|
21
|
+
SSEAttentionAdaptiveConfig,
|
|
22
|
+
SSEAttentionConfig,
|
|
23
|
+
SSEMaskingOps,
|
|
24
|
+
SSEMaskingOpsConfig,
|
|
25
|
+
SSEMultiHeadAttention,
|
|
26
|
+
SSEMultiHeadAttentionConfig,
|
|
27
|
+
SSEMultiPartitionState,
|
|
28
|
+
SSEMultiPartitionStateConfig,
|
|
29
|
+
SSEPartitionSelector,
|
|
30
|
+
SSEPartitionSelectorConfig,
|
|
31
|
+
SSESparseSoftmax,
|
|
32
|
+
SSESparseSoftmaxConfig,
|
|
33
|
+
SSEVarlenOps,
|
|
34
|
+
SSEVarlenOpsConfig,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
__all__ = [
|
|
38
|
+
# Layers
|
|
39
|
+
"ButterflyLinear",
|
|
40
|
+
"ButterflyMLP",
|
|
41
|
+
"ButterflyMultiHeadAttention",
|
|
42
|
+
"CustomLinear",
|
|
43
|
+
"CustomMLP",
|
|
44
|
+
"MultiHeadAttention",
|
|
45
|
+
"PaddedButterflyLinear",
|
|
46
|
+
"SimpleMLP",
|
|
47
|
+
# SSE
|
|
48
|
+
"LinearAttention",
|
|
49
|
+
"LinearAttentionConfig",
|
|
50
|
+
"NaiveMultiPartitionState",
|
|
51
|
+
"NaiveSSEAttention",
|
|
52
|
+
"NaiveSSEMultiHeadAttention",
|
|
53
|
+
"SSEAttention",
|
|
54
|
+
"SSEAttentionAdaptive",
|
|
55
|
+
"SSEAttentionAdaptiveConfig",
|
|
56
|
+
"SSEAttentionConfig",
|
|
57
|
+
"SSEMaskingOps",
|
|
58
|
+
"SSEMaskingOpsConfig",
|
|
59
|
+
"SSEMultiHeadAttention",
|
|
60
|
+
"SSEMultiHeadAttentionConfig",
|
|
61
|
+
"SSEMultiPartitionState",
|
|
62
|
+
"SSEMultiPartitionStateConfig",
|
|
63
|
+
"SSEPartitionSelector",
|
|
64
|
+
"SSEPartitionSelectorConfig",
|
|
65
|
+
"SSESparseSoftmax",
|
|
66
|
+
"SSESparseSoftmaxConfig",
|
|
67
|
+
"SSEVarlenOps",
|
|
68
|
+
"SSEVarlenOpsConfig",
|
|
69
|
+
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Core neural network layers with butterfly factorization support."""
|
|
2
|
+
|
|
3
|
+
from sparse_layers.layers.butterfly_linear import ButterflyLinear
|
|
4
|
+
from sparse_layers.layers.butterfly_mlp import ButterflyMLP
|
|
5
|
+
from sparse_layers.layers.butterfly_multi_head_attention import (
|
|
6
|
+
ButterflyMultiHeadAttention,
|
|
7
|
+
)
|
|
8
|
+
from sparse_layers.layers.custom_linear import CustomLinear
|
|
9
|
+
from sparse_layers.layers.custom_mlp import CustomMLP
|
|
10
|
+
from sparse_layers.layers.multi_head_attention import MultiHeadAttention
|
|
11
|
+
from sparse_layers.layers.padded_butterfly_linear import PaddedButterflyLinear
|
|
12
|
+
from sparse_layers.layers.simple_mlp import SimpleMLP
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"ButterflyLinear",
|
|
16
|
+
"ButterflyMLP",
|
|
17
|
+
"ButterflyMultiHeadAttention",
|
|
18
|
+
"CustomLinear",
|
|
19
|
+
"CustomMLP",
|
|
20
|
+
"MultiHeadAttention",
|
|
21
|
+
"PaddedButterflyLinear",
|
|
22
|
+
"SimpleMLP",
|
|
23
|
+
]
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor, nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _is_power_of_two(value: int) -> bool:
|
|
11
|
+
return value > 0 and (value & (value - 1) == 0)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ButterflyLinear(nn.Module):
    """Butterfly factorization based linear layer.

    This layer follows the structure described in section 2.3.1 of the
    `reducing_memory_requirements_ipu_butterfly.md` reference. It consists of
    log2(N) stages of block-diagonal 2x2 butterfly factors, each stored as a
    learnable parameter. The layer currently supports square power-of-two
    dimensions and acts as a drop-in replacement for :class:`torch.nn.Linear`.

    Stage ``s`` (0-indexed) mixes element pairs that sit ``2**s`` positions
    apart, so after all ``log2(N)`` stages every output can depend on every
    input -- the defining property of a butterfly network.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Create the layer.

        Args:
            in_features: Input feature size; must equal ``out_features`` and
                be a positive power of two.
            out_features: Output feature size.
            bias: When ``True``, adds a learnable bias initialised to zero.

        Raises:
            ValueError: For non-positive, non-square, or non-power-of-two sizes.
        """
        super().__init__()

        if in_features <= 0 or out_features <= 0:
            raise ValueError("in_features and out_features must be positive integers")

        if in_features != out_features:
            raise ValueError("ButterflyLinear requires in_features == out_features")

        # Power of two <=> exactly one bit set (positivity checked above).
        if in_features & (in_features - 1) != 0:
            raise ValueError(
                "ButterflyLinear requires dimensions that are a power of two"
            )

        self.in_features = in_features
        self.out_features = out_features
        # One butterfly stage per doubling of the block size: log2(N) stages.
        self._depth = int(math.log2(in_features))

        # Each stage applies N/2 independent 2x2 mixing blocks.
        num_blocks = in_features // 2
        default_dtype = torch.get_default_dtype()
        identity_block = torch.eye(2, dtype=default_dtype).unsqueeze(0)

        # Initialise every stage near the identity (identity + small noise) so
        # the initial mapping is close to a pass-through.
        self.factors = nn.ParameterList()
        for _ in range(self._depth):
            factor = identity_block.repeat(num_blocks, 1, 1)
            factor += 0.01 * torch.randn_like(factor)
            self.factors.append(nn.Parameter(factor))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=default_dtype))
        else:
            # Keep the attribute present (as None), mirroring nn.Linear.
            self.register_parameter("bias", None)

    @classmethod
    def from_linear(
        cls,
        layer: nn.Linear,
        *,
        optimization_steps: int = 4000,
        learning_rate: float = 0.1,
        tolerance: float = 1e-7,
        seed: int | None = None,
    ) -> "ButterflyLinear":
        """Construct a butterfly layer approximating a dense :class:`nn.Linear`.

        The method uses gradient-based fitting over the canonical basis to match
        the transformation represented by ``layer``. The bias term is copied
        directly before optimisation to accelerate convergence. Adam runs first;
        if the fit is still worse than ``tolerance``, LBFGS polishes it.

        Note: when ``seed`` is given, the *global* torch RNG is reseeded.
        """
        if layer.in_features != layer.out_features:
            raise ValueError("ButterflyLinear.from_linear requires a square nn.Linear")

        if layer.in_features <= 0 or layer.in_features & (layer.in_features - 1) != 0:
            raise ValueError(
                "ButterflyLinear.from_linear requires dimensions that are a power of two"
            )

        device = layer.weight.device
        dtype = layer.weight.dtype

        if seed is not None:
            torch.manual_seed(seed)

        result = cls(layer.in_features, layer.out_features, bias=layer.bias is not None)
        result = result.to(device=device, dtype=dtype)

        if result.bias is not None and layer.bias is not None:
            with torch.no_grad():
                result.bias.copy_(layer.bias)

        params = list(result.factors.parameters())
        if not params:
            return result

        optimizer_adam = torch.optim.Adam(params, lr=learning_rate)

        # layer(I) evaluates the dense map on the canonical basis; both sides
        # include the (shared) bias, so the factors only have to match W^T.
        eye_input = torch.eye(layer.in_features, device=device, dtype=dtype)
        with torch.no_grad():
            target = layer(eye_input)

        best_loss = float("inf")

        for _ in range(optimization_steps):
            optimizer_adam.zero_grad()
            loss = F.mse_loss(result(eye_input), target)
            loss.backward()
            optimizer_adam.step()

            best_loss = min(best_loss, loss.item())
            if best_loss <= tolerance:
                break

        if best_loss > tolerance:
            optimizer_lbfgs = torch.optim.LBFGS(
                params,
                lr=1.0,
                max_iter=100,
                tolerance_grad=1e-12,
                tolerance_change=1e-12,
                line_search_fn="strong_wolfe",
            )

            def closure() -> Tensor:
                optimizer_lbfgs.zero_grad()
                lbfgs_loss = F.mse_loss(result(eye_input), target)
                lbfgs_loss.backward()
                return lbfgs_loss

            for _ in range(20):
                loss = optimizer_lbfgs.step(closure)
                best_loss = min(best_loss, loss.item())
                if best_loss <= tolerance:
                    break

        return result

    def forward(self, input: Tensor) -> Tensor:
        """Apply the butterfly product (and bias) to ``input``.

        Accepts any leading batch shape; the last dimension must equal
        ``in_features``.
        """
        if input.shape[-1] != self.in_features:
            raise ValueError(
                f"Expected last dimension {self.in_features}, got {input.shape[-1]}"
            )

        original_shape = input.shape[:-1]
        x = input.reshape(-1, self.in_features)

        for stage_index, factor in enumerate(self.factors):
            x = self._apply_stage(x, factor, stage_index)

        if self.bias is not None:
            x = x + self.bias

        return x.reshape(*original_shape, self.out_features)

    def extra_repr(self) -> str:
        """Summary string shown in the module's repr, mirroring nn.Linear."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )

    def _apply_stage(self, x: Tensor, factor: Tensor, stage: int) -> Tensor:
        """Apply one butterfly stage to ``x`` of shape (batch, N).

        Stage ``s`` works on contiguous blocks of size ``2**(s+1)`` and mixes
        the element pairs ``(j, j + 2**s)`` inside each block with a learnable
        2x2 matrix per pair.

        Bug fix: the previous implementation reshaped each block as
        ``(half, 2)`` instead of ``(2, half)``, which paired elements at a
        constant stride of 2 for every stage >= 1. For N >= 8 that left the
        network disconnected (inputs 0..3 could never influence outputs 4..7),
        so dense matrices could not be represented or fitted.
        """
        batch = x.shape[0]
        n = self.in_features
        block = 1 << (stage + 1)
        half = block >> 1

        # Split each block into its lower and upper halves, then interleave so
        # the trailing axis holds the (j, j + half) pairs.
        staged = x.reshape(batch, -1, block)
        staged = staged.reshape(batch, -1, 2, half)
        staged = staged.permute(0, 1, 3, 2).contiguous()
        pairs = staged.reshape(batch, -1, 2)

        # One 2x2 matrix per pair: out[d] = sum_c pair[c] * factor[c, d].
        transformed = torch.einsum("bnc,ncd->bnd", pairs, factor)

        # Invert the interleaving to restore the block-contiguous layout.
        transformed = transformed.reshape(batch, -1, half, 2)
        transformed = transformed.permute(0, 1, 3, 2).contiguous()
        transformed = transformed.reshape(batch, -1, block)
        return transformed.reshape(batch, n)

    def to_linear(self) -> nn.Linear:
        """Return a dense :class:`nn.Linear` with identical behaviour."""
        # Pick device/dtype from whichever parameter is available.
        factor_tensor = self.factors[0] if len(self.factors) > 0 else None
        if factor_tensor is not None:
            device = factor_tensor.device
            dtype = factor_tensor.dtype
        elif self.bias is not None:
            device = self.bias.device
            dtype = self.bias.dtype
        else:
            device = torch.device("cpu")
            dtype = torch.get_default_dtype()

        linear = nn.Linear(
            self.in_features,
            self.out_features,
            bias=self.bias is not None,
            device=device,
            dtype=dtype,
        )

        with torch.no_grad():
            # Temporarily zero the bias so self(I) probes the pure linear part.
            if self.bias is not None:
                bias_backup = self.bias.data.clone()
                self.bias.zero_()
            else:
                bias_backup = None

            # Rows of self(I) are the images of the basis vectors, i.e. W^T.
            identity = torch.eye(self.in_features, device=device, dtype=dtype)
            weight_matrix = self(identity)

            if bias_backup is not None:
                self.bias.copy_(bias_backup)
                linear.bias.copy_(bias_backup)

            linear.weight.copy_(weight_matrix.t())

        return linear
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from sparse_layers.layers.butterfly_linear import ButterflyLinear, _is_power_of_two
|
|
9
|
+
from sparse_layers.layers.simple_mlp import SimpleMLP
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ButterflyMLP(nn.Module):
    """A multi-layer perceptron composed of :class:`ButterflyLinear` layers.

    All layer dimensions must be identical powers of two because
    :class:`ButterflyLinear` only supports square power-of-two transforms.
    """

    def __init__(self, input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> None:
        """Build the network.

        Args:
            input_dim: Feature size of the input.
            hidden_dims: Sizes of the hidden layers; must be non-empty.
            output_dim: Feature size of the output.

        Raises:
            ValueError: If the dimensions are not identical positive powers of two.
        """
        super().__init__()

        self._validate_dimensions(input_dim, hidden_dims, output_dim)

        self.input_dim = input_dim
        self.hidden_dims = list(hidden_dims)
        self.output_dim = output_dim

        layer_dims = [input_dim, *self.hidden_dims, output_dim]

        modules: list[nn.Module] = []
        for idx in range(len(layer_dims) - 1):
            modules.append(ButterflyLinear(layer_dims[idx], layer_dims[idx + 1]))
            # ReLU between layers, but none after the final projection.
            if idx < len(layer_dims) - 2:
                modules.append(nn.ReLU())

        self.network = nn.Sequential(*modules)

    @staticmethod
    def _validate_dimensions(
        input_dim: int, hidden_dims: Sequence[int], output_dim: int
    ) -> None:
        """Raise ValueError unless all dimensions are equal positive powers of two."""
        if input_dim <= 0:
            raise ValueError("input_dim must be a positive integer")

        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one positive integer")

        if any(dim <= 0 for dim in hidden_dims):
            raise ValueError("hidden_dims must contain only positive integers")

        if output_dim <= 0:
            raise ValueError("output_dim must be a positive integer")

        all_dims = [input_dim, *hidden_dims, output_dim]
        if len(set(all_dims)) != 1:
            raise ValueError("ButterflyMLP requires all layer dimensions to be identical")

        if not all(_is_power_of_two(dim) for dim in all_dims):
            raise ValueError("ButterflyMLP requires power-of-two layer dimensions")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a 2D batch of feature vectors."""
        if x.dim() != 2:
            raise ValueError("ButterflyMLP expects a 2D input tensor of shape (batch, features)")

        if x.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, received {x.shape[1]}"
            )

        return self.network(x)

    def to_simple_mlp(self) -> SimpleMLP:
        """Return a :class:`SimpleMLP` with identical behaviour.

        Each butterfly layer is densified via :meth:`ButterflyLinear.to_linear`
        and its weights/bias copied into the corresponding dense layer.
        """
        simple = SimpleMLP(self.input_dim, self.hidden_dims, self.output_dim)

        sparse_layers = [module for module in self.network if isinstance(module, ButterflyLinear)]
        dense_layers = [module for module in simple.modules() if isinstance(module, nn.Linear)]

        if len(sparse_layers) != len(dense_layers):
            raise RuntimeError("Unexpected layer mismatch during conversion to SimpleMLP")

        for butterfly_layer, dense_layer in zip(sparse_layers, dense_layers):
            converted = butterfly_layer.to_linear()
            # Both copies run under no_grad: copy_ into a Parameter that
            # requires grad raises outside an inference context.
            with torch.no_grad():
                dense_layer.weight.copy_(converted.weight)
                if dense_layer.bias is not None and converted.bias is not None:
                    dense_layer.bias.copy_(converted.bias)

        return simple

    @classmethod
    def from_simple_mlp(
        cls,
        model: SimpleMLP,
        *,
        seed: int | None = None,
        optimization_steps: int = 4000,
        learning_rate: float = 0.1,
        tolerance: float = 1e-7,
    ) -> "ButterflyMLP":
        """Construct a :class:`ButterflyMLP` from a compatible :class:`SimpleMLP`.

        Every dense layer is approximated with
        :meth:`ButterflyLinear.from_linear` using the given optimisation
        settings.

        Raises:
            ValueError: If the model has no linear layers, or any layer is not
                a square power-of-two transform.
            RuntimeError: If the rebuilt network does not line up layer-for-layer.
        """
        linear_layers = [module for module in model.modules() if isinstance(module, nn.Linear)]

        if not linear_layers:
            raise ValueError("SimpleMLP must contain at least one linear layer")

        for layer in linear_layers:
            if layer.in_features != layer.out_features:
                raise ValueError(
                    "ButterflyMLP.from_simple_mlp requires square nn.Linear layers"
                )
            if not _is_power_of_two(layer.in_features):
                raise ValueError(
                    "ButterflyMLP.from_simple_mlp requires power-of-two dimensions"
                )

        input_dim = linear_layers[0].in_features
        hidden_dims = [layer.out_features for layer in linear_layers[:-1]]
        output_dim = linear_layers[-1].out_features

        result = cls(input_dim, hidden_dims, output_dim)

        butterfly_indices = [
            idx for idx, module in enumerate(result.network) if isinstance(module, ButterflyLinear)
        ]

        if len(butterfly_indices) != len(linear_layers):
            raise RuntimeError("Unexpected layer mismatch during reconstruction from SimpleMLP")

        for linear_layer, target_idx in zip(linear_layers, butterfly_indices):
            result.network[target_idx] = ButterflyLinear.from_linear(
                linear_layer,
                seed=seed,
                optimization_steps=optimization_steps,
                learning_rate=learning_rate,
                tolerance=tolerance,
            )

        return result
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor, nn
|
|
7
|
+
|
|
8
|
+
from sparse_layers.layers.butterfly_linear import ButterflyLinear
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ButterflyMultiHeadAttention(nn.Module):
    """Multi-head self-attention using ButterflyLinear projections.

    All four projections (query, key, value, output) are
    :class:`ButterflyLinear` layers, so ``d_model`` must be a power of two.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0) -> None:
        """Set up the projection layers and attention hyper-parameters.

        Raises:
            ValueError: On non-positive sizes, non-divisible head count,
                out-of-range dropout, or non-power-of-two ``d_model``.
        """
        super().__init__()

        if d_model <= 0:
            raise ValueError("d_model must be a positive integer")
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        if not (0.0 <= dropout < 1.0):
            raise ValueError("dropout must satisfy 0.0 <= dropout < 1.0")
        if not _is_power_of_two(d_model):
            raise ValueError("ButterflyMultiHeadAttention requires d_model to be a power of two")

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.query = ButterflyLinear(d_model, d_model)
        self.key = ButterflyLinear(d_model, d_model)
        self.value = ButterflyLinear(d_model, d_model)
        self.out = ButterflyLinear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

        # Standard scaled-dot-product scaling: 1/sqrt(head_dim).
        self._scaling = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        """Run self-attention over ``x`` of shape (batch, seq_len, d_model).

        ``mask`` (optional) is a bool tensor of shape (batch, seq_len);
        positions where it is True are excluded from attention.
        """
        if x.dim() != 3:
            raise ValueError("expected input of shape (batch, seq_len, d_model)")
        batch_size, seq_len, feature_dim = x.shape
        if feature_dim != self.d_model:
            raise ValueError(
                f"expected last dimension {self.d_model}, received {feature_dim}"
            )

        # Project and reshape each stream into per-head form.
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        scores = (q @ k.transpose(-2, -1)) * self._scaling

        if mask is not None:
            if mask.shape != (batch_size, seq_len):
                raise ValueError("mask shape must match (batch, seq_len)")
            if mask.dtype != torch.bool:
                raise ValueError("mask must have dtype torch.bool")
            # Broadcast (batch, seq) over heads and query positions.
            scores = scores.masked_fill(mask[:, None, None, :], float("-inf"))

        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = self._merge_heads(weights @ v)
        return self.out(context)

    def _split_heads(self, tensor: Tensor) -> Tensor:
        """Reshape (batch, seq, d_model) into (batch, heads, seq, head_dim)."""
        batch_size, seq_len = tensor.shape[0], tensor.shape[1]
        per_head = tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def _merge_heads(self, tensor: Tensor) -> Tensor:
        """Inverse of ``_split_heads``: back to (batch, seq, d_model)."""
        batch_size, seq_len = tensor.shape[0], tensor.shape[2]
        return tensor.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _is_power_of_two(value: int) -> bool:
|
|
89
|
+
return value > 0 and (value & (value - 1) == 0)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
__all__ = ["ButterflyMultiHeadAttention"]
|
|
93
|
+
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor, nn
|
|
7
|
+
from torch.nn.parameter import Parameter
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CustomLinear(nn.Module):
    """Drop-in replacement for nn.Linear with identical behavior.

    Computes ``y = x @ weight.T + bias`` with the same parameter shapes and
    the same Kaiming-uniform initialisation scheme as :class:`torch.nn.Linear`.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Allocate and initialise the parameters.

        Raises:
            ValueError: If ``in_features`` or ``out_features`` is not positive.
        """
        super().__init__()

        if in_features <= 0:
            raise ValueError("in_features must be a positive integer")
        if out_features <= 0:
            raise ValueError("out_features must be a positive integer")

        self.in_features = in_features
        self.out_features = out_features

        # Weight is (out, in), matching nn.Linear; bias is optional.
        self.weight: Parameter = Parameter(torch.empty(out_features, in_features))
        self.bias: Parameter | None = Parameter(torch.empty(out_features)) if bias else None

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise weight and bias exactly as torch.nn.Linear does."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = self.weight.shape[1]
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the affine transform; last dim of ``input`` must be in_features."""
        return nn.functional.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Summary string used by the module's repr."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from sparse_layers.layers.custom_linear import CustomLinear
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CustomMLP(nn.Module):
    """MLP that uses CustomLinear layers as drop-in replacements for nn.Linear."""

    def __init__(self, input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> None:
        """Build the network.

        Args:
            input_dim: Feature size of the input; must be positive.
            hidden_dims: Sizes of the hidden layers; must be non-empty and positive.
            output_dim: Feature size of the output; must be positive.

        Raises:
            ValueError: If any dimension is invalid.
        """
        super().__init__()

        self._validate_dimensions(input_dim, hidden_dims, output_dim)

        # Store the architecture so forward() can validate inputs and callers
        # can introspect it (consistent with ButterflyMLP).
        self.input_dim = input_dim
        self.hidden_dims = list(hidden_dims)
        self.output_dim = output_dim

        layer_dims = [input_dim, *self.hidden_dims, output_dim]

        modules: list[nn.Module] = []
        for idx in range(len(layer_dims) - 1):
            modules.append(CustomLinear(layer_dims[idx], layer_dims[idx + 1]))
            # ReLU between layers, but none after the final projection.
            if idx < len(layer_dims) - 2:
                modules.append(nn.ReLU())

        self.network = nn.Sequential(*modules)

    @staticmethod
    def _validate_dimensions(input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> None:
        """Raise ValueError unless every requested dimension is a positive integer."""
        if input_dim <= 0:
            raise ValueError("input_dim must be a positive integer")

        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one positive integer")

        if any(dim <= 0 for dim in hidden_dims):
            raise ValueError("hidden_dims must contain only positive integers")

        if output_dim <= 0:
            raise ValueError("output_dim must be a positive integer")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on a 2D batch of feature vectors.

        Raises:
            ValueError: If ``x`` is not 2D or its feature count mismatches.
        """
        if x.dim() != 2:
            raise ValueError("CustomMLP expects a 2D input tensor of shape (batch, features)")

        # Fail fast with a clear message (consistent with ButterflyMLP) instead
        # of surfacing a shape error from the first matmul.
        if x.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, received {x.shape[1]}"
            )

        return self.network(x)
|
|
54
|
+
|