sparse-layers 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. sparse_layers-0.2.0/.github/workflows/publish.yml +26 -0
  2. sparse_layers-0.2.0/.gitignore +11 -0
  3. sparse_layers-0.2.0/CHANGELOG.md +17 -0
  4. sparse_layers-0.2.0/LICENSE +21 -0
  5. sparse_layers-0.2.0/PKG-INFO +113 -0
  6. sparse_layers-0.2.0/README.md +87 -0
  7. sparse_layers-0.2.0/pyproject.toml +46 -0
  8. sparse_layers-0.2.0/src/sparse_layers/__init__.py +69 -0
  9. sparse_layers-0.2.0/src/sparse_layers/layers/__init__.py +23 -0
  10. sparse_layers-0.2.0/src/sparse_layers/layers/butterfly_linear.py +225 -0
  11. sparse_layers-0.2.0/src/sparse_layers/layers/butterfly_mlp.py +145 -0
  12. sparse_layers-0.2.0/src/sparse_layers/layers/butterfly_multi_head_attention.py +93 -0
  13. sparse_layers-0.2.0/src/sparse_layers/layers/custom_linear.py +46 -0
  14. sparse_layers-0.2.0/src/sparse_layers/layers/custom_mlp.py +54 -0
  15. sparse_layers-0.2.0/src/sparse_layers/layers/multi_head_attention.py +81 -0
  16. sparse_layers-0.2.0/src/sparse_layers/layers/padded_butterfly_linear.py +70 -0
  17. sparse_layers-0.2.0/src/sparse_layers/layers/simple_mlp.py +60 -0
  18. sparse_layers-0.2.0/src/sparse_layers/sse/__init__.py +55 -0
  19. sparse_layers-0.2.0/src/sparse_layers/sse/attention.py +242 -0
  20. sparse_layers-0.2.0/src/sparse_layers/sse/attention_adaptive.py +117 -0
  21. sparse_layers-0.2.0/src/sparse_layers/sse/linear_attention.py +68 -0
  22. sparse_layers-0.2.0/src/sparse_layers/sse/linear_attention_config.py +22 -0
  23. sparse_layers-0.2.0/src/sparse_layers/sse/masking_ops.py +107 -0
  24. sparse_layers-0.2.0/src/sparse_layers/sse/multi_head_attention.py +196 -0
  25. sparse_layers-0.2.0/src/sparse_layers/sse/multi_partition_state.py +367 -0
  26. sparse_layers-0.2.0/src/sparse_layers/sse/multi_partition_state_config.py +15 -0
  27. sparse_layers-0.2.0/src/sparse_layers/sse/partition_selector.py +43 -0
  28. sparse_layers-0.2.0/src/sparse_layers/sse/partition_selector_config.py +26 -0
  29. sparse_layers-0.2.0/src/sparse_layers/sse/sparse_softmax.py +99 -0
  30. sparse_layers-0.2.0/src/sparse_layers/sse/sparse_softmax_config.py +37 -0
  31. sparse_layers-0.2.0/src/sparse_layers/sse/varlen_ops.py +212 -0
  32. sparse_layers-0.2.0/tests/__init__.py +0 -0
  33. sparse_layers-0.2.0/tests/test_boundary.py +36 -0
  34. sparse_layers-0.2.0/tests/test_butterfly_linear.py +71 -0
  35. sparse_layers-0.2.0/tests/test_butterfly_mlp.py +80 -0
  36. sparse_layers-0.2.0/tests/test_butterfly_multi_head_attention.py +62 -0
  37. sparse_layers-0.2.0/tests/test_custom_linear.py +84 -0
  38. sparse_layers-0.2.0/tests/test_custom_mlp.py +83 -0
  39. sparse_layers-0.2.0/tests/test_linear_attention.py +373 -0
  40. sparse_layers-0.2.0/tests/test_multi_head_attention.py +33 -0
  41. sparse_layers-0.2.0/tests/test_padded_butterfly_linear.py +70 -0
  42. sparse_layers-0.2.0/tests/test_simple_mlp.py +58 -0
  43. sparse_layers-0.2.0/tests/test_sse_attention.py +501 -0
  44. sparse_layers-0.2.0/tests/test_sse_attention_adaptive.py +136 -0
  45. sparse_layers-0.2.0/tests/test_sse_masking_ops.py +187 -0
  46. sparse_layers-0.2.0/tests/test_sse_multi_head_attention.py +306 -0
  47. sparse_layers-0.2.0/tests/test_sse_multi_partition_state.py +758 -0
  48. sparse_layers-0.2.0/tests/test_sse_partition_selector.py +350 -0
  49. sparse_layers-0.2.0/tests/test_sse_sparse_softmax.py +597 -0
  50. sparse_layers-0.2.0/tests/test_sse_varlen_ops.py +280 -0
@@ -0,0 +1,26 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ jobs:
9
+ publish:
10
+ runs-on: ubuntu-latest
11
+ permissions:
12
+ id-token: write
13
+ steps:
14
+ - uses: actions/checkout@v4
15
+
16
+ - uses: actions/setup-python@v5
17
+ with:
18
+ python-version: "3.12"
19
+
20
+ - name: Build
21
+ run: |
22
+ pip install build
23
+ python -m build
24
+
25
+ - name: Publish to PyPI
26
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,11 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *.egg-info/
4
+ dist/
5
+ build/
6
+ .eggs/
7
+ *.egg
8
+ .pytest_cache/
9
+ .coverage
10
+ htmlcov/
11
+ .pixi/
@@ -0,0 +1,17 @@
1
+ # Changelog
2
+
3
+ ## v0.2.0 - 2026-03-23
4
+
5
+ Initial public release as `sparse-layers` (renamed from `butterfly-layers`).
6
+
7
+ ### Added
8
+
9
+ - Butterfly-factorized linear layer (`ButterflyLinear`) — O(n log n) parameters as a drop-in replacement for `nn.Linear`
10
+ - `PaddedButterflyLinear` for non-power-of-2 dimensions
11
+ - `ButterflyMLP` and `ButterflyMultiHeadAttention` — sparse MLP and attention building blocks
12
+ - SSE (State Space Exploration) attention modules: `SSEAttention`, `SSEAttentionAdaptive`, `SSEMultiHeadAttention`
13
+ - SSE infrastructure: partition selector, sparse softmax, multi-partition state management
14
+ - `LinearAttention` baseline with O(n) complexity
15
+ - Variable-length sequence operations (`SSEVarlenOps`) and masking utilities (`SSEMaskingOps`)
16
+ - Pydantic-based configuration for all SSE modules
17
+ - Boundary enforcement test — verifies the package has no infrastructure dependencies
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Philipp Hematty
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,113 @@
1
+ Metadata-Version: 2.4
2
+ Name: sparse-layers
3
+ Version: 0.2.0
4
+ Summary: Structured sparse layers for building memory-efficient neural networks
5
+ Project-URL: Homepage, https://github.com/PhilHem/sparse-layers
6
+ Project-URL: Repository, https://github.com/PhilHem/sparse-layers
7
+ Author-email: Philipp Hematty <34508934+PhilHem@users.noreply.github.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Keywords: attention,butterfly,neural-network,pytorch,sparse,sse
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
18
+ Requires-Python: >=3.11
19
+ Requires-Dist: pydantic<3,>=2.12.4
20
+ Requires-Dist: torch<3,>=2.7.1
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest; extra == 'dev'
23
+ Requires-Dist: pytest-cov<8,>=7.0.0; extra == 'dev'
24
+ Requires-Dist: pytest-xdist<4,>=3.8.0; extra == 'dev'
25
+ Description-Content-Type: text/markdown
26
+
27
+ # sparse-layers
28
+
29
+ Structured sparse layers for building memory-efficient neural networks on PyTorch. Drop-in replacements for standard layers using butterfly factorization, SSE attention, and other sparse primitives.
30
+
31
+ ## Install
32
+
33
+ ```bash
34
+ pip install sparse-layers
35
+ ```
36
+
37
+ ## Usage
38
+
39
+ ```python
40
+ import torch
41
+ from sparse_layers import ButterflyLinear, ButterflyMLP
42
+
43
+ # Drop-in replacement for nn.Linear with O(n log n) parameters
44
+ layer = ButterflyLinear(in_features=256, out_features=256)
45
+ x = torch.randn(32, 256)
46
+ y = layer(x)
47
+
48
+ # MLP with butterfly-factorized linear layers
49
+ mlp = ButterflyMLP(in_features=256, hidden_features=512, out_features=256)
50
+ y = mlp(x)
51
+ ```
52
+
53
+ ### Attention
54
+
55
+ ```python
56
+ from sparse_layers import ButterflyMultiHeadAttention, MultiHeadAttention
57
+
58
+ # Standard multi-head attention
59
+ attn = MultiHeadAttention(d_model=256, num_heads=8)
60
+
61
+ # Butterfly-factorized variant (fewer parameters, same interface)
62
+ bf_attn = ButterflyMultiHeadAttention(d_model=256, num_heads=8)
63
+
64
+ seq = torch.randn(32, 128, 256) # (batch, seq_len, d_model)
65
+ out = bf_attn(seq, seq, seq)
66
+ ```
67
+
68
+ ### SSE Attention
69
+
70
+ State Space Exploration modules for efficient sequence modeling with sparse attention patterns.
71
+
72
+ ```python
73
+ from sparse_layers.sse import SSEAttention, SSEAttentionConfig
74
+
75
+ config = SSEAttentionConfig(d_model=256, num_partitions=4)
76
+ sse = SSEAttention(config)
77
+
78
+ x = torch.randn(32, 128, 256)
79
+ out = sse(x)
80
+ ```
81
+
82
+ ## Modules
83
+
84
+ ### Layers (`sparse_layers.layers`)
85
+
86
+ | Module | Description |
87
+ |--------|-------------|
88
+ | `ButterflyLinear` | Linear layer using butterfly matrix factorization — O(n log n) parameters instead of O(n²) |
89
+ | `PaddedButterflyLinear` | ButterflyLinear with automatic padding for non-power-of-2 dimensions |
90
+ | `ButterflyMLP` | Two-layer MLP with butterfly-factorized linear layers |
91
+ | `ButterflyMultiHeadAttention` | Multi-head attention with butterfly-factorized Q/K/V projections |
92
+ | `MultiHeadAttention` | Standard multi-head attention (baseline) |
93
+ | `CustomLinear` | Linear layer with pluggable weight initialization |
94
+ | `CustomMLP` | MLP with CustomLinear layers |
95
+ | `SimpleMLP` | Minimal MLP baseline |
96
+
97
+ ### SSE (`sparse_layers.sse`)
98
+
99
+ | Module | Description |
100
+ |--------|-------------|
101
+ | `SSEAttention` | Sparse attention with state-space-inspired partitioning |
102
+ | `SSEAttentionAdaptive` | SSE with adaptive implementation selection (naive/batched) |
103
+ | `SSEMultiHeadAttention` | Multi-head variant of SSE attention |
104
+ | `SSEMultiPartitionState` | Manages partition states across sequence chunks |
105
+ | `SSEPartitionSelector` | Selects active partitions per query position |
106
+ | `SSESparseSoftmax` | Sparse softmax over selected partitions |
107
+ | `LinearAttention` | Linear attention baseline (O(n) complexity) |
108
+ | `SSEMaskingOps` | Masking utilities for variable-length SSE |
109
+ | `SSEVarlenOps` | Variable-length sequence operations |
110
+
111
+ ## License
112
+
113
+ MIT
@@ -0,0 +1,87 @@
1
+ # sparse-layers
2
+
3
+ Structured sparse layers for building memory-efficient neural networks on PyTorch. Drop-in replacements for standard layers using butterfly factorization, SSE attention, and other sparse primitives.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install sparse-layers
9
+ ```
10
+
11
+ ## Usage
12
+
13
+ ```python
14
+ import torch
15
+ from sparse_layers import ButterflyLinear, ButterflyMLP
16
+
17
+ # Drop-in replacement for nn.Linear with O(n log n) parameters
18
+ layer = ButterflyLinear(in_features=256, out_features=256)
19
+ x = torch.randn(32, 256)
20
+ y = layer(x)
21
+
22
+ # MLP with butterfly-factorized linear layers
23
+ mlp = ButterflyMLP(in_features=256, hidden_features=512, out_features=256)
24
+ y = mlp(x)
25
+ ```
26
+
27
+ ### Attention
28
+
29
+ ```python
30
+ from sparse_layers import ButterflyMultiHeadAttention, MultiHeadAttention
31
+
32
+ # Standard multi-head attention
33
+ attn = MultiHeadAttention(d_model=256, num_heads=8)
34
+
35
+ # Butterfly-factorized variant (fewer parameters, same interface)
36
+ bf_attn = ButterflyMultiHeadAttention(d_model=256, num_heads=8)
37
+
38
+ seq = torch.randn(32, 128, 256) # (batch, seq_len, d_model)
39
+ out = bf_attn(seq, seq, seq)
40
+ ```
41
+
42
+ ### SSE Attention
43
+
44
+ State Space Exploration modules for efficient sequence modeling with sparse attention patterns.
45
+
46
+ ```python
47
+ from sparse_layers.sse import SSEAttention, SSEAttentionConfig
48
+
49
+ config = SSEAttentionConfig(d_model=256, num_partitions=4)
50
+ sse = SSEAttention(config)
51
+
52
+ x = torch.randn(32, 128, 256)
53
+ out = sse(x)
54
+ ```
55
+
56
+ ## Modules
57
+
58
+ ### Layers (`sparse_layers.layers`)
59
+
60
+ | Module | Description |
61
+ |--------|-------------|
62
+ | `ButterflyLinear` | Linear layer using butterfly matrix factorization — O(n log n) parameters instead of O(n²) |
63
+ | `PaddedButterflyLinear` | ButterflyLinear with automatic padding for non-power-of-2 dimensions |
64
+ | `ButterflyMLP` | Two-layer MLP with butterfly-factorized linear layers |
65
+ | `ButterflyMultiHeadAttention` | Multi-head attention with butterfly-factorized Q/K/V projections |
66
+ | `MultiHeadAttention` | Standard multi-head attention (baseline) |
67
+ | `CustomLinear` | Linear layer with pluggable weight initialization |
68
+ | `CustomMLP` | MLP with CustomLinear layers |
69
+ | `SimpleMLP` | Minimal MLP baseline |
70
+
71
+ ### SSE (`sparse_layers.sse`)
72
+
73
+ | Module | Description |
74
+ |--------|-------------|
75
+ | `SSEAttention` | Sparse attention with state-space-inspired partitioning |
76
+ | `SSEAttentionAdaptive` | SSE with adaptive implementation selection (naive/batched) |
77
+ | `SSEMultiHeadAttention` | Multi-head variant of SSE attention |
78
+ | `SSEMultiPartitionState` | Manages partition states across sequence chunks |
79
+ | `SSEPartitionSelector` | Selects active partitions per query position |
80
+ | `SSESparseSoftmax` | Sparse softmax over selected partitions |
81
+ | `LinearAttention` | Linear attention baseline (O(n) complexity) |
82
+ | `SSEMaskingOps` | Masking utilities for variable-length SSE |
83
+ | `SSEVarlenOps` | Variable-length sequence operations |
84
+
85
+ ## License
86
+
87
+ MIT
@@ -0,0 +1,46 @@
1
+ [project]
2
+ name = "sparse-layers"
3
+ version = "0.2.0"
4
+ requires-python = ">= 3.11"
5
+ description = "Structured sparse layers for building memory-efficient neural networks"
6
+ license = "MIT"
7
+ authors = [
8
+ { name = "Philipp Hematty", email = "34508934+PhilHem@users.noreply.github.com" },
9
+ ]
10
+ readme = "README.md"
11
+ keywords = ["butterfly", "neural-network", "sparse", "pytorch", "attention", "sse"]
12
+ classifiers = [
13
+ "Development Status :: 3 - Alpha",
14
+ "Intended Audience :: Science/Research",
15
+ "License :: OSI Approved :: MIT License",
16
+ "Programming Language :: Python :: 3",
17
+ "Programming Language :: Python :: 3.11",
18
+ "Programming Language :: Python :: 3.12",
19
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
20
+ ]
21
+ dependencies = [
22
+ "torch>=2.7.1,<3",
23
+ "pydantic>=2.12.4,<3",
24
+ ]
25
+
26
+ [project.urls]
27
+ Homepage = "https://github.com/PhilHem/sparse-layers"
28
+ Repository = "https://github.com/PhilHem/sparse-layers"
29
+
30
+ [project.optional-dependencies]
31
+ dev = ["pytest", "pytest-xdist>=3.8.0,<4", "pytest-cov>=7.0.0,<8"]
32
+
33
+ [build-system]
34
+ build-backend = "hatchling.build"
35
+ requires = ["hatchling"]
36
+
37
+ [tool.hatch.build.targets.wheel]
38
+ packages = ["src/sparse_layers"]
39
+
40
+ [tool.pytest.ini_options]
41
+ testpaths = ["tests"]
42
+ addopts = "-v --tb=short"
43
+
44
+ [tool.coverage.run]
45
+ source = ["src/sparse_layers"]
46
+ omit = ["*/tests/*"]
@@ -0,0 +1,69 @@
1
"""sparse-layers: structured sparse layers for building memory-efficient neural networks.

Top-level package namespace. Re-exports the public classes from
``sparse_layers.layers`` (butterfly-factorized building blocks plus dense
baselines) and ``sparse_layers.sse`` (SSE attention modules and their
pydantic configs) so callers can import everything directly from
``sparse_layers``. ``__all__`` mirrors the two import groups below.
"""

from sparse_layers.layers import (
    ButterflyLinear,
    ButterflyMLP,
    ButterflyMultiHeadAttention,
    CustomLinear,
    CustomMLP,
    MultiHeadAttention,
    PaddedButterflyLinear,
    SimpleMLP,
)
from sparse_layers.sse import (
    LinearAttention,
    LinearAttentionConfig,
    NaiveMultiPartitionState,
    NaiveSSEAttention,
    NaiveSSEMultiHeadAttention,
    SSEAttention,
    SSEAttentionAdaptive,
    SSEAttentionAdaptiveConfig,
    SSEAttentionConfig,
    SSEMaskingOps,
    SSEMaskingOpsConfig,
    SSEMultiHeadAttention,
    SSEMultiHeadAttentionConfig,
    SSEMultiPartitionState,
    SSEMultiPartitionStateConfig,
    SSEPartitionSelector,
    SSEPartitionSelectorConfig,
    SSESparseSoftmax,
    SSESparseSoftmaxConfig,
    SSEVarlenOps,
    SSEVarlenOpsConfig,
)

__all__ = [
    # Layers
    "ButterflyLinear",
    "ButterflyMLP",
    "ButterflyMultiHeadAttention",
    "CustomLinear",
    "CustomMLP",
    "MultiHeadAttention",
    "PaddedButterflyLinear",
    "SimpleMLP",
    # SSE
    "LinearAttention",
    "LinearAttentionConfig",
    "NaiveMultiPartitionState",
    "NaiveSSEAttention",
    "NaiveSSEMultiHeadAttention",
    "SSEAttention",
    "SSEAttentionAdaptive",
    "SSEAttentionAdaptiveConfig",
    "SSEAttentionConfig",
    "SSEMaskingOps",
    "SSEMaskingOpsConfig",
    "SSEMultiHeadAttention",
    "SSEMultiHeadAttentionConfig",
    "SSEMultiPartitionState",
    "SSEMultiPartitionStateConfig",
    "SSEPartitionSelector",
    "SSEPartitionSelectorConfig",
    "SSESparseSoftmax",
    "SSESparseSoftmaxConfig",
    "SSEVarlenOps",
    "SSEVarlenOpsConfig",
]
@@ -0,0 +1,23 @@
1
"""Core neural network layers with butterfly factorization support.

Re-exports every public layer class from its implementation module so they
can be imported as ``from sparse_layers.layers import ButterflyLinear`` etc.
``__all__`` lists exactly the re-exported names, in alphabetical order.
"""

from sparse_layers.layers.butterfly_linear import ButterflyLinear
from sparse_layers.layers.butterfly_mlp import ButterflyMLP
from sparse_layers.layers.butterfly_multi_head_attention import (
    ButterflyMultiHeadAttention,
)
from sparse_layers.layers.custom_linear import CustomLinear
from sparse_layers.layers.custom_mlp import CustomMLP
from sparse_layers.layers.multi_head_attention import MultiHeadAttention
from sparse_layers.layers.padded_butterfly_linear import PaddedButterflyLinear
from sparse_layers.layers.simple_mlp import SimpleMLP

__all__ = [
    "ButterflyLinear",
    "ButterflyMLP",
    "ButterflyMultiHeadAttention",
    "CustomLinear",
    "CustomMLP",
    "MultiHeadAttention",
    "PaddedButterflyLinear",
    "SimpleMLP",
]
@@ -0,0 +1,225 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def _is_power_of_two(value: int) -> bool:
11
+ return value > 0 and (value & (value - 1) == 0)
12
+
13
+
14
+ class ButterflyLinear(nn.Module):
15
+ """Butterfly factorization based linear layer.
16
+
17
+ This layer follows the structure described in section 2.3.1 of the
18
+ `reducing_memory_requirements_ipu_butterfly.md` reference. It consists of
19
+ log2(N) stages of block-diagonal 2x2 butterfly factors, each stored as a
20
+ learnable parameter. The layer currently supports square power-of-two
21
+ dimensions and acts as a drop-in replacement for :class:`torch.nn.Linear`.
22
+ """
23
+
24
+ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
25
+ super().__init__()
26
+
27
+ if in_features <= 0 or out_features <= 0:
28
+ raise ValueError("in_features and out_features must be positive integers")
29
+
30
+ if in_features != out_features:
31
+ raise ValueError("ButterflyLinear requires in_features == out_features")
32
+
33
+ if not _is_power_of_two(in_features):
34
+ raise ValueError(
35
+ "ButterflyLinear requires dimensions that are a power of two"
36
+ )
37
+
38
+ self.in_features = in_features
39
+ self.out_features = out_features
40
+ self._depth = int(math.log2(in_features))
41
+
42
+ num_blocks = in_features // 2
43
+ default_dtype = torch.get_default_dtype()
44
+ identity_block = torch.eye(2, dtype=default_dtype).unsqueeze(0)
45
+
46
+ self.factors = nn.ParameterList()
47
+ for _ in range(self._depth):
48
+ factor = identity_block.repeat(num_blocks, 1, 1)
49
+ factor += 0.01 * torch.randn_like(factor)
50
+ self.factors.append(nn.Parameter(factor))
51
+
52
+ if bias:
53
+ self.bias = nn.Parameter(torch.zeros(out_features, dtype=default_dtype))
54
+ else:
55
+ self.register_parameter("bias", None)
56
+
57
+ @classmethod
58
+ def from_linear(
59
+ cls,
60
+ layer: nn.Linear,
61
+ *,
62
+ optimization_steps: int = 4000,
63
+ learning_rate: float = 0.1,
64
+ tolerance: float = 1e-7,
65
+ seed: int | None = None,
66
+ ) -> "ButterflyLinear":
67
+ """Construct a butterfly layer approximating a dense :class:`nn.Linear`.
68
+
69
+ The method uses gradient-based fitting over the canonical basis to match
70
+ the transformation represented by ``layer``. The bias term is copied
71
+ directly before optimisation to accelerate convergence.
72
+ """
73
+
74
+ if layer.in_features != layer.out_features:
75
+ raise ValueError("ButterflyLinear.from_linear requires a square nn.Linear")
76
+
77
+ if not _is_power_of_two(layer.in_features):
78
+ raise ValueError(
79
+ "ButterflyLinear.from_linear requires dimensions that are a power of two"
80
+ )
81
+
82
+ device = layer.weight.device
83
+ dtype = layer.weight.dtype
84
+
85
+ if seed is not None:
86
+ torch.manual_seed(seed)
87
+
88
+ result = cls(layer.in_features, layer.out_features, bias=layer.bias is not None)
89
+ result = result.to(device=device, dtype=dtype)
90
+
91
+ if result.bias is not None and layer.bias is not None:
92
+ with torch.no_grad():
93
+ result.bias.copy_(layer.bias)
94
+
95
+ params = list(result.factors.parameters())
96
+ if not params:
97
+ return result
98
+
99
+ optimizer_adam = torch.optim.Adam(params, lr=learning_rate)
100
+
101
+ eye_input = torch.eye(layer.in_features, device=device, dtype=dtype)
102
+ with torch.no_grad():
103
+ target = layer(eye_input)
104
+
105
+ best_loss = float("inf")
106
+
107
+ for step in range(optimization_steps):
108
+ optimizer_adam.zero_grad()
109
+ output = result(eye_input)
110
+ loss = F.mse_loss(output, target)
111
+ loss.backward()
112
+ optimizer_adam.step()
113
+
114
+ current_loss = loss.item()
115
+ if current_loss < best_loss:
116
+ best_loss = current_loss
117
+
118
+ if best_loss <= tolerance:
119
+ break
120
+
121
+ if best_loss > tolerance:
122
+ optimizer_lbfgs = torch.optim.LBFGS(
123
+ params,
124
+ lr=1.0,
125
+ max_iter=100,
126
+ tolerance_grad=1e-12,
127
+ tolerance_change=1e-12,
128
+ line_search_fn="strong_wolfe",
129
+ )
130
+
131
+ def closure() -> Tensor:
132
+ optimizer_lbfgs.zero_grad()
133
+ output_lbfgs = result(eye_input)
134
+ lbfgs_loss = F.mse_loss(output_lbfgs, target)
135
+ lbfgs_loss.backward()
136
+ return lbfgs_loss
137
+
138
+ for _ in range(20):
139
+ loss = optimizer_lbfgs.step(closure)
140
+ best_loss = min(best_loss, loss.item())
141
+ if best_loss <= tolerance:
142
+ break
143
+
144
+ return result
145
+
146
+ def forward(self, input: Tensor) -> Tensor:
147
+ if input.shape[-1] != self.in_features:
148
+ raise ValueError(
149
+ f"Expected last dimension {self.in_features}, got {input.shape[-1]}"
150
+ )
151
+
152
+ original_shape = input.shape[:-1]
153
+ x = input.reshape(-1, self.in_features)
154
+
155
+ for stage_index, factor in enumerate(self.factors):
156
+ x = self._apply_stage(x, factor, stage_index)
157
+
158
+ if self.bias is not None:
159
+ x = x + self.bias
160
+
161
+ return x.reshape(*original_shape, self.out_features)
162
+
163
+ def extra_repr(self) -> str:
164
+ return (
165
+ f"in_features={self.in_features}, out_features={self.out_features}, "
166
+ f"bias={self.bias is not None}"
167
+ )
168
+
169
+ def _apply_stage(self, x: Tensor, factor: Tensor, stage: int) -> Tensor:
170
+ batch = x.shape[0]
171
+ n = self.in_features
172
+ block = 1 << (stage + 1)
173
+ half = block >> 1
174
+
175
+ staged = x.reshape(batch, -1, block)
176
+ staged = staged.reshape(batch, -1, half, 2)
177
+ staged = staged.permute(0, 1, 3, 2).contiguous()
178
+ pairs = staged.reshape(batch, -1, 2)
179
+
180
+ transformed = torch.einsum("bnc,ncd->bnd", pairs, factor)
181
+
182
+ transformed = transformed.reshape(batch, -1, 2, half)
183
+ transformed = transformed.permute(0, 1, 3, 2).contiguous()
184
+ transformed = transformed.reshape(batch, -1, block)
185
+ return transformed.reshape(batch, n)
186
+
187
+ def to_linear(self) -> nn.Linear:
188
+ """Return a dense :class:`nn.Linear` with identical behaviour."""
189
+
190
+ factor_tensor = self.factors[0] if len(self.factors) > 0 else None
191
+ if factor_tensor is not None:
192
+ device = factor_tensor.device
193
+ dtype = factor_tensor.dtype
194
+ elif self.bias is not None:
195
+ device = self.bias.device
196
+ dtype = self.bias.dtype
197
+ else:
198
+ device = torch.device("cpu")
199
+ dtype = torch.get_default_dtype()
200
+
201
+ linear = nn.Linear(
202
+ self.in_features,
203
+ self.out_features,
204
+ bias=self.bias is not None,
205
+ device=device,
206
+ dtype=dtype,
207
+ )
208
+
209
+ with torch.no_grad():
210
+ if self.bias is not None:
211
+ bias_backup = self.bias.data.clone()
212
+ self.bias.zero_()
213
+ else:
214
+ bias_backup = None
215
+
216
+ identity = torch.eye(self.in_features, device=device, dtype=dtype)
217
+ weight_matrix = self(identity)
218
+
219
+ if bias_backup is not None:
220
+ self.bias.copy_(bias_backup)
221
+ linear.bias.copy_(bias_backup)
222
+
223
+ linear.weight.copy_(weight_matrix.t())
224
+
225
+ return linear