sparse-layers 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sparse_layers-0.2.0/.github/workflows/publish.yml +26 -0
- sparse_layers-0.2.0/.gitignore +11 -0
- sparse_layers-0.2.0/CHANGELOG.md +17 -0
- sparse_layers-0.2.0/LICENSE +21 -0
- sparse_layers-0.2.0/PKG-INFO +113 -0
- sparse_layers-0.2.0/README.md +87 -0
- sparse_layers-0.2.0/pyproject.toml +46 -0
- sparse_layers-0.2.0/src/sparse_layers/__init__.py +69 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/__init__.py +23 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/butterfly_linear.py +225 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/butterfly_mlp.py +145 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/butterfly_multi_head_attention.py +93 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/custom_linear.py +46 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/custom_mlp.py +54 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/multi_head_attention.py +81 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/padded_butterfly_linear.py +70 -0
- sparse_layers-0.2.0/src/sparse_layers/layers/simple_mlp.py +60 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/__init__.py +55 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/attention.py +242 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/attention_adaptive.py +117 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/linear_attention.py +68 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/linear_attention_config.py +22 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/masking_ops.py +107 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/multi_head_attention.py +196 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/multi_partition_state.py +367 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/multi_partition_state_config.py +15 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/partition_selector.py +43 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/partition_selector_config.py +26 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/sparse_softmax.py +99 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/sparse_softmax_config.py +37 -0
- sparse_layers-0.2.0/src/sparse_layers/sse/varlen_ops.py +212 -0
- sparse_layers-0.2.0/tests/__init__.py +0 -0
- sparse_layers-0.2.0/tests/test_boundary.py +36 -0
- sparse_layers-0.2.0/tests/test_butterfly_linear.py +71 -0
- sparse_layers-0.2.0/tests/test_butterfly_mlp.py +80 -0
- sparse_layers-0.2.0/tests/test_butterfly_multi_head_attention.py +62 -0
- sparse_layers-0.2.0/tests/test_custom_linear.py +84 -0
- sparse_layers-0.2.0/tests/test_custom_mlp.py +83 -0
- sparse_layers-0.2.0/tests/test_linear_attention.py +373 -0
- sparse_layers-0.2.0/tests/test_multi_head_attention.py +33 -0
- sparse_layers-0.2.0/tests/test_padded_butterfly_linear.py +70 -0
- sparse_layers-0.2.0/tests/test_simple_mlp.py +58 -0
- sparse_layers-0.2.0/tests/test_sse_attention.py +501 -0
- sparse_layers-0.2.0/tests/test_sse_attention_adaptive.py +136 -0
- sparse_layers-0.2.0/tests/test_sse_masking_ops.py +187 -0
- sparse_layers-0.2.0/tests/test_sse_multi_head_attention.py +306 -0
- sparse_layers-0.2.0/tests/test_sse_multi_partition_state.py +758 -0
- sparse_layers-0.2.0/tests/test_sse_partition_selector.py +350 -0
- sparse_layers-0.2.0/tests/test_sse_sparse_softmax.py +597 -0
- sparse_layers-0.2.0/tests/test_sse_varlen_ops.py +280 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
name: Publish to PyPI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
tags:
|
|
6
|
+
- "v*"
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
publish:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
permissions:
|
|
12
|
+
id-token: write
|
|
13
|
+
steps:
|
|
14
|
+
- uses: actions/checkout@v4
|
|
15
|
+
|
|
16
|
+
- uses: actions/setup-python@v5
|
|
17
|
+
with:
|
|
18
|
+
python-version: "3.12"
|
|
19
|
+
|
|
20
|
+
- name: Build
|
|
21
|
+
run: |
|
|
22
|
+
pip install build
|
|
23
|
+
python -m build
|
|
24
|
+
|
|
25
|
+
- name: Publish to PyPI
|
|
26
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
## v0.2.0 - 2026-03-23
|
|
4
|
+
|
|
5
|
+
Initial public release as `sparse-layers` (renamed from `butterfly-layers`).
|
|
6
|
+
|
|
7
|
+
### Added
|
|
8
|
+
|
|
9
|
+
- Butterfly-factorized linear layer (`ButterflyLinear`) — O(n log n) parameters as a drop-in replacement for `nn.Linear`
|
|
10
|
+
- `PaddedButterflyLinear` for non-power-of-2 dimensions
|
|
11
|
+
- `ButterflyMLP` and `ButterflyMultiHeadAttention` — sparse MLP and attention building blocks
|
|
12
|
+
- SSE (State Space Exploration) attention modules: `SSEAttention`, `SSEAttentionAdaptive`, `SSEMultiHeadAttention`
|
|
13
|
+
- SSE infrastructure: partition selector, sparse softmax, multi-partition state management
|
|
14
|
+
- `LinearAttention` baseline with O(n) complexity
|
|
15
|
+
- Variable-length sequence operations (`SSEVarlenOps`) and masking utilities (`SSEMaskingOps`)
|
|
16
|
+
- Pydantic-based configuration for all SSE modules
|
|
17
|
+
- Boundary enforcement test — verifies the package has no infrastructure dependencies
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Philipp Hematty
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sparse-layers
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Structured sparse layers for building memory-efficient neural networks
|
|
5
|
+
Project-URL: Homepage, https://github.com/PhilHem/sparse-layers
|
|
6
|
+
Project-URL: Repository, https://github.com/PhilHem/sparse-layers
|
|
7
|
+
Author-email: Philipp Hematty <34508934+PhilHem@users.noreply.github.com>
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: attention,butterfly,neural-network,pytorch,sparse,sse
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
18
|
+
Requires-Python: >=3.11
|
|
19
|
+
Requires-Dist: pydantic<3,>=2.12.4
|
|
20
|
+
Requires-Dist: torch<3,>=2.7.1
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
23
|
+
Requires-Dist: pytest-cov<8,>=7.0.0; extra == 'dev'
|
|
24
|
+
Requires-Dist: pytest-xdist<4,>=3.8.0; extra == 'dev'
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# sparse-layers
|
|
28
|
+
|
|
29
|
+
Structured sparse layers for building memory-efficient neural networks on PyTorch. Drop-in replacements for standard layers using butterfly factorization, SSE attention, and other sparse primitives.
|
|
30
|
+
|
|
31
|
+
## Install
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
pip install sparse-layers
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
## Usage
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
import torch
|
|
41
|
+
from sparse_layers import ButterflyLinear, ButterflyMLP
|
|
42
|
+
|
|
43
|
+
# Drop-in replacement for nn.Linear with O(n log n) parameters
|
|
44
|
+
layer = ButterflyLinear(in_features=256, out_features=256)
|
|
45
|
+
x = torch.randn(32, 256)
|
|
46
|
+
y = layer(x)
|
|
47
|
+
|
|
48
|
+
# MLP with butterfly-factorized linear layers
|
|
49
|
+
mlp = ButterflyMLP(in_features=256, hidden_features=512, out_features=256)
|
|
50
|
+
y = mlp(x)
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
### Attention
|
|
54
|
+
|
|
55
|
+
```python
|
|
56
|
+
from sparse_layers import ButterflyMultiHeadAttention, MultiHeadAttention
|
|
57
|
+
|
|
58
|
+
# Standard multi-head attention
|
|
59
|
+
attn = MultiHeadAttention(d_model=256, num_heads=8)
|
|
60
|
+
|
|
61
|
+
# Butterfly-factorized variant (fewer parameters, same interface)
|
|
62
|
+
bf_attn = ButterflyMultiHeadAttention(d_model=256, num_heads=8)
|
|
63
|
+
|
|
64
|
+
seq = torch.randn(32, 128, 256) # (batch, seq_len, d_model)
|
|
65
|
+
out = bf_attn(seq, seq, seq)
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
### SSE Attention
|
|
69
|
+
|
|
70
|
+
State Space Exploration modules for efficient sequence modeling with sparse attention patterns.
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
from sparse_layers.sse import SSEAttention, SSEAttentionConfig
|
|
74
|
+
|
|
75
|
+
config = SSEAttentionConfig(d_model=256, num_partitions=4)
|
|
76
|
+
sse = SSEAttention(config)
|
|
77
|
+
|
|
78
|
+
x = torch.randn(32, 128, 256)
|
|
79
|
+
out = sse(x)
|
|
80
|
+
```
|
|
81
|
+
|
|
82
|
+
## Modules
|
|
83
|
+
|
|
84
|
+
### Layers (`sparse_layers.layers`)
|
|
85
|
+
|
|
86
|
+
| Module | Description |
|
|
87
|
+
|--------|-------------|
|
|
88
|
+
| `ButterflyLinear` | Linear layer using butterfly matrix factorization — O(n log n) parameters instead of O(n²) |
|
|
89
|
+
| `PaddedButterflyLinear` | ButterflyLinear with automatic padding for non-power-of-2 dimensions |
|
|
90
|
+
| `ButterflyMLP` | Two-layer MLP with butterfly-factorized linear layers |
|
|
91
|
+
| `ButterflyMultiHeadAttention` | Multi-head attention with butterfly-factorized Q/K/V projections |
|
|
92
|
+
| `MultiHeadAttention` | Standard multi-head attention (baseline) |
|
|
93
|
+
| `CustomLinear` | Linear layer with pluggable weight initialization |
|
|
94
|
+
| `CustomMLP` | MLP with CustomLinear layers |
|
|
95
|
+
| `SimpleMLP` | Minimal MLP baseline |
|
|
96
|
+
|
|
97
|
+
### SSE (`sparse_layers.sse`)
|
|
98
|
+
|
|
99
|
+
| Module | Description |
|
|
100
|
+
|--------|-------------|
|
|
101
|
+
| `SSEAttention` | Sparse attention with state-space-inspired partitioning |
|
|
102
|
+
| `SSEAttentionAdaptive` | SSE with adaptive implementation selection (naive/batched) |
|
|
103
|
+
| `SSEMultiHeadAttention` | Multi-head variant of SSE attention |
|
|
104
|
+
| `SSEMultiPartitionState` | Manages partition states across sequence chunks |
|
|
105
|
+
| `SSEPartitionSelector` | Selects active partitions per query position |
|
|
106
|
+
| `SSESparseSoftmax` | Sparse softmax over selected partitions |
|
|
107
|
+
| `LinearAttention` | Linear attention baseline (O(n) complexity) |
|
|
108
|
+
| `SSEMaskingOps` | Masking utilities for variable-length SSE |
|
|
109
|
+
| `SSEVarlenOps` | Variable-length sequence operations |
|
|
110
|
+
|
|
111
|
+
## License
|
|
112
|
+
|
|
113
|
+
MIT
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# sparse-layers
|
|
2
|
+
|
|
3
|
+
Structured sparse layers for building memory-efficient neural networks on PyTorch. Drop-in replacements for standard layers using butterfly factorization, SSE attention, and other sparse primitives.
|
|
4
|
+
|
|
5
|
+
## Install
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install sparse-layers
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Usage
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
import torch
|
|
15
|
+
from sparse_layers import ButterflyLinear, ButterflyMLP
|
|
16
|
+
|
|
17
|
+
# Drop-in replacement for nn.Linear with O(n log n) parameters
|
|
18
|
+
layer = ButterflyLinear(in_features=256, out_features=256)
|
|
19
|
+
x = torch.randn(32, 256)
|
|
20
|
+
y = layer(x)
|
|
21
|
+
|
|
22
|
+
# MLP with butterfly-factorized linear layers
|
|
23
|
+
mlp = ButterflyMLP(in_features=256, hidden_features=512, out_features=256)
|
|
24
|
+
y = mlp(x)
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
### Attention
|
|
28
|
+
|
|
29
|
+
```python
|
|
30
|
+
from sparse_layers import ButterflyMultiHeadAttention, MultiHeadAttention
|
|
31
|
+
|
|
32
|
+
# Standard multi-head attention
|
|
33
|
+
attn = MultiHeadAttention(d_model=256, num_heads=8)
|
|
34
|
+
|
|
35
|
+
# Butterfly-factorized variant (fewer parameters, same interface)
|
|
36
|
+
bf_attn = ButterflyMultiHeadAttention(d_model=256, num_heads=8)
|
|
37
|
+
|
|
38
|
+
seq = torch.randn(32, 128, 256) # (batch, seq_len, d_model)
|
|
39
|
+
out = bf_attn(seq, seq, seq)
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
### SSE Attention
|
|
43
|
+
|
|
44
|
+
State Space Exploration modules for efficient sequence modeling with sparse attention patterns.
|
|
45
|
+
|
|
46
|
+
```python
|
|
47
|
+
from sparse_layers.sse import SSEAttention, SSEAttentionConfig
|
|
48
|
+
|
|
49
|
+
config = SSEAttentionConfig(d_model=256, num_partitions=4)
|
|
50
|
+
sse = SSEAttention(config)
|
|
51
|
+
|
|
52
|
+
x = torch.randn(32, 128, 256)
|
|
53
|
+
out = sse(x)
|
|
54
|
+
```
|
|
55
|
+
|
|
56
|
+
## Modules
|
|
57
|
+
|
|
58
|
+
### Layers (`sparse_layers.layers`)
|
|
59
|
+
|
|
60
|
+
| Module | Description |
|
|
61
|
+
|--------|-------------|
|
|
62
|
+
| `ButterflyLinear` | Linear layer using butterfly matrix factorization — O(n log n) parameters instead of O(n²) |
|
|
63
|
+
| `PaddedButterflyLinear` | ButterflyLinear with automatic padding for non-power-of-2 dimensions |
|
|
64
|
+
| `ButterflyMLP` | Two-layer MLP with butterfly-factorized linear layers |
|
|
65
|
+
| `ButterflyMultiHeadAttention` | Multi-head attention with butterfly-factorized Q/K/V projections |
|
|
66
|
+
| `MultiHeadAttention` | Standard multi-head attention (baseline) |
|
|
67
|
+
| `CustomLinear` | Linear layer with pluggable weight initialization |
|
|
68
|
+
| `CustomMLP` | MLP with CustomLinear layers |
|
|
69
|
+
| `SimpleMLP` | Minimal MLP baseline |
|
|
70
|
+
|
|
71
|
+
### SSE (`sparse_layers.sse`)
|
|
72
|
+
|
|
73
|
+
| Module | Description |
|
|
74
|
+
|--------|-------------|
|
|
75
|
+
| `SSEAttention` | Sparse attention with state-space-inspired partitioning |
|
|
76
|
+
| `SSEAttentionAdaptive` | SSE with adaptive implementation selection (naive/batched) |
|
|
77
|
+
| `SSEMultiHeadAttention` | Multi-head variant of SSE attention |
|
|
78
|
+
| `SSEMultiPartitionState` | Manages partition states across sequence chunks |
|
|
79
|
+
| `SSEPartitionSelector` | Selects active partitions per query position |
|
|
80
|
+
| `SSESparseSoftmax` | Sparse softmax over selected partitions |
|
|
81
|
+
| `LinearAttention` | Linear attention baseline (O(n) complexity) |
|
|
82
|
+
| `SSEMaskingOps` | Masking utilities for variable-length SSE |
|
|
83
|
+
| `SSEVarlenOps` | Variable-length sequence operations |
|
|
84
|
+
|
|
85
|
+
## License
|
|
86
|
+
|
|
87
|
+
MIT
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "sparse-layers"
|
|
3
|
+
version = "0.2.0"
|
|
4
|
+
requires-python = ">= 3.11"
|
|
5
|
+
description = "Structured sparse layers for building memory-efficient neural networks"
|
|
6
|
+
license = "MIT"
|
|
7
|
+
authors = [
|
|
8
|
+
{ name = "Philipp Hematty", email = "34508934+PhilHem@users.noreply.github.com" },
|
|
9
|
+
]
|
|
10
|
+
readme = "README.md"
|
|
11
|
+
keywords = ["butterfly", "neural-network", "sparse", "pytorch", "attention", "sse"]
|
|
12
|
+
classifiers = [
|
|
13
|
+
"Development Status :: 3 - Alpha",
|
|
14
|
+
"Intended Audience :: Science/Research",
|
|
15
|
+
"License :: OSI Approved :: MIT License",
|
|
16
|
+
"Programming Language :: Python :: 3",
|
|
17
|
+
"Programming Language :: Python :: 3.11",
|
|
18
|
+
"Programming Language :: Python :: 3.12",
|
|
19
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
20
|
+
]
|
|
21
|
+
dependencies = [
|
|
22
|
+
"torch>=2.7.1,<3",
|
|
23
|
+
"pydantic>=2.12.4,<3",
|
|
24
|
+
]
|
|
25
|
+
|
|
26
|
+
[project.urls]
|
|
27
|
+
Homepage = "https://github.com/PhilHem/sparse-layers"
|
|
28
|
+
Repository = "https://github.com/PhilHem/sparse-layers"
|
|
29
|
+
|
|
30
|
+
[project.optional-dependencies]
|
|
31
|
+
dev = ["pytest", "pytest-xdist>=3.8.0,<4", "pytest-cov>=7.0.0,<8"]
|
|
32
|
+
|
|
33
|
+
[build-system]
|
|
34
|
+
build-backend = "hatchling.build"
|
|
35
|
+
requires = ["hatchling"]
|
|
36
|
+
|
|
37
|
+
[tool.hatch.build.targets.wheel]
|
|
38
|
+
packages = ["src/sparse_layers"]
|
|
39
|
+
|
|
40
|
+
[tool.pytest.ini_options]
|
|
41
|
+
testpaths = ["tests"]
|
|
42
|
+
addopts = "-v --tb=short"
|
|
43
|
+
|
|
44
|
+
[tool.coverage.run]
|
|
45
|
+
source = ["src/sparse_layers"]
|
|
46
|
+
omit = ["*/tests/*"]
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""sparse-layers: structured sparse layers for building memory-efficient neural networks."""
|
|
2
|
+
|
|
3
|
+
from sparse_layers.layers import (
|
|
4
|
+
ButterflyLinear,
|
|
5
|
+
ButterflyMLP,
|
|
6
|
+
ButterflyMultiHeadAttention,
|
|
7
|
+
CustomLinear,
|
|
8
|
+
CustomMLP,
|
|
9
|
+
MultiHeadAttention,
|
|
10
|
+
PaddedButterflyLinear,
|
|
11
|
+
SimpleMLP,
|
|
12
|
+
)
|
|
13
|
+
from sparse_layers.sse import (
|
|
14
|
+
LinearAttention,
|
|
15
|
+
LinearAttentionConfig,
|
|
16
|
+
NaiveMultiPartitionState,
|
|
17
|
+
NaiveSSEAttention,
|
|
18
|
+
NaiveSSEMultiHeadAttention,
|
|
19
|
+
SSEAttention,
|
|
20
|
+
SSEAttentionAdaptive,
|
|
21
|
+
SSEAttentionAdaptiveConfig,
|
|
22
|
+
SSEAttentionConfig,
|
|
23
|
+
SSEMaskingOps,
|
|
24
|
+
SSEMaskingOpsConfig,
|
|
25
|
+
SSEMultiHeadAttention,
|
|
26
|
+
SSEMultiHeadAttentionConfig,
|
|
27
|
+
SSEMultiPartitionState,
|
|
28
|
+
SSEMultiPartitionStateConfig,
|
|
29
|
+
SSEPartitionSelector,
|
|
30
|
+
SSEPartitionSelectorConfig,
|
|
31
|
+
SSESparseSoftmax,
|
|
32
|
+
SSESparseSoftmaxConfig,
|
|
33
|
+
SSEVarlenOps,
|
|
34
|
+
SSEVarlenOpsConfig,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
__all__ = [
|
|
38
|
+
# Layers
|
|
39
|
+
"ButterflyLinear",
|
|
40
|
+
"ButterflyMLP",
|
|
41
|
+
"ButterflyMultiHeadAttention",
|
|
42
|
+
"CustomLinear",
|
|
43
|
+
"CustomMLP",
|
|
44
|
+
"MultiHeadAttention",
|
|
45
|
+
"PaddedButterflyLinear",
|
|
46
|
+
"SimpleMLP",
|
|
47
|
+
# SSE
|
|
48
|
+
"LinearAttention",
|
|
49
|
+
"LinearAttentionConfig",
|
|
50
|
+
"NaiveMultiPartitionState",
|
|
51
|
+
"NaiveSSEAttention",
|
|
52
|
+
"NaiveSSEMultiHeadAttention",
|
|
53
|
+
"SSEAttention",
|
|
54
|
+
"SSEAttentionAdaptive",
|
|
55
|
+
"SSEAttentionAdaptiveConfig",
|
|
56
|
+
"SSEAttentionConfig",
|
|
57
|
+
"SSEMaskingOps",
|
|
58
|
+
"SSEMaskingOpsConfig",
|
|
59
|
+
"SSEMultiHeadAttention",
|
|
60
|
+
"SSEMultiHeadAttentionConfig",
|
|
61
|
+
"SSEMultiPartitionState",
|
|
62
|
+
"SSEMultiPartitionStateConfig",
|
|
63
|
+
"SSEPartitionSelector",
|
|
64
|
+
"SSEPartitionSelectorConfig",
|
|
65
|
+
"SSESparseSoftmax",
|
|
66
|
+
"SSESparseSoftmaxConfig",
|
|
67
|
+
"SSEVarlenOps",
|
|
68
|
+
"SSEVarlenOpsConfig",
|
|
69
|
+
]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Core neural network layers with butterfly factorization support."""
|
|
2
|
+
|
|
3
|
+
from sparse_layers.layers.butterfly_linear import ButterflyLinear
|
|
4
|
+
from sparse_layers.layers.butterfly_mlp import ButterflyMLP
|
|
5
|
+
from sparse_layers.layers.butterfly_multi_head_attention import (
|
|
6
|
+
ButterflyMultiHeadAttention,
|
|
7
|
+
)
|
|
8
|
+
from sparse_layers.layers.custom_linear import CustomLinear
|
|
9
|
+
from sparse_layers.layers.custom_mlp import CustomMLP
|
|
10
|
+
from sparse_layers.layers.multi_head_attention import MultiHeadAttention
|
|
11
|
+
from sparse_layers.layers.padded_butterfly_linear import PaddedButterflyLinear
|
|
12
|
+
from sparse_layers.layers.simple_mlp import SimpleMLP
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"ButterflyLinear",
|
|
16
|
+
"ButterflyMLP",
|
|
17
|
+
"ButterflyMultiHeadAttention",
|
|
18
|
+
"CustomLinear",
|
|
19
|
+
"CustomMLP",
|
|
20
|
+
"MultiHeadAttention",
|
|
21
|
+
"PaddedButterflyLinear",
|
|
22
|
+
"SimpleMLP",
|
|
23
|
+
]
|
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor, nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _is_power_of_two(value: int) -> bool:
|
|
11
|
+
return value > 0 and (value & (value - 1) == 0)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ButterflyLinear(nn.Module):
    """Butterfly factorization based linear layer.

    This layer follows the structure described in section 2.3.1 of the
    `reducing_memory_requirements_ipu_butterfly.md` reference. It consists of
    log2(N) stages of block-diagonal 2x2 butterfly factors, each stored as a
    learnable parameter (N * log2(N) / 2 blocks of 4 scalars in total, versus
    N**2 weights for a dense layer). The layer currently supports square
    power-of-two dimensions and acts as a drop-in replacement for
    :class:`torch.nn.Linear`.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Build the stage factors and optional bias.

        Args:
            in_features: Input dimension; must equal ``out_features`` and be
                a power of two.
            out_features: Output dimension.
            bias: When True, adds a learnable additive bias of shape
                ``(out_features,)``.

        Raises:
            ValueError: If a dimension is non-positive, the layer is not
                square, or the dimension is not a power of two.
        """
        super().__init__()

        if in_features <= 0 or out_features <= 0:
            raise ValueError("in_features and out_features must be positive integers")

        if in_features != out_features:
            raise ValueError("ButterflyLinear requires in_features == out_features")

        if not _is_power_of_two(in_features):
            raise ValueError(
                "ButterflyLinear requires dimensions that are a power of two"
            )

        self.in_features = in_features
        self.out_features = out_features
        # Number of butterfly stages: one per power of two up to N.
        self._depth = int(math.log2(in_features))

        # Each stage holds N/2 independent 2x2 mixing blocks.
        num_blocks = in_features // 2
        default_dtype = torch.get_default_dtype()
        identity_block = torch.eye(2, dtype=default_dtype).unsqueeze(0)

        self.factors = nn.ParameterList()
        for _ in range(self._depth):
            # Initialize near the identity so the layer starts close to a
            # pass-through map; the small noise breaks symmetry for training.
            factor = identity_block.repeat(num_blocks, 1, 1)
            factor += 0.01 * torch.randn_like(factor)
            self.factors.append(nn.Parameter(factor))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=default_dtype))
        else:
            # Register None so ``self.bias`` exists and state_dict stays consistent.
            self.register_parameter("bias", None)

    @classmethod
    def from_linear(
        cls,
        layer: nn.Linear,
        *,
        optimization_steps: int = 4000,
        learning_rate: float = 0.1,
        tolerance: float = 1e-7,
        seed: int | None = None,
    ) -> "ButterflyLinear":
        """Construct a butterfly layer approximating a dense :class:`nn.Linear`.

        The method uses gradient-based fitting over the canonical basis to match
        the transformation represented by ``layer``: it optimizes the stage
        factors so that ``result(I) ≈ layer(I)`` (MSE loss), first with Adam,
        then refining with LBFGS if the tolerance has not been met. The bias
        term is copied directly before optimization to accelerate convergence.

        Args:
            layer: Dense layer to approximate; must be square with a
                power-of-two dimension.
            optimization_steps: Maximum number of Adam iterations.
            learning_rate: Adam learning rate.
            tolerance: Target MSE; fitting stops early once reached.
            seed: Optional seed for reproducible factor initialization.
                NOTE(review): this calls ``torch.manual_seed`` and therefore
                mutates the *global* RNG state — confirm callers accept that.

        Raises:
            ValueError: If ``layer`` is not square or not power-of-two sized.
        """

        if layer.in_features != layer.out_features:
            raise ValueError("ButterflyLinear.from_linear requires a square nn.Linear")

        if not _is_power_of_two(layer.in_features):
            raise ValueError(
                "ButterflyLinear.from_linear requires dimensions that are a power of two"
            )

        device = layer.weight.device
        dtype = layer.weight.dtype

        if seed is not None:
            torch.manual_seed(seed)

        result = cls(layer.in_features, layer.out_features, bias=layer.bias is not None)
        result = result.to(device=device, dtype=dtype)

        # Copy the bias up front so the factor fitting only has to match the
        # linear part of the transformation.
        if result.bias is not None and layer.bias is not None:
            with torch.no_grad():
                result.bias.copy_(layer.bias)

        params = list(result.factors.parameters())
        if not params:
            # Degenerate case: no stages to fit (cannot occur for N >= 2).
            return result

        optimizer_adam = torch.optim.Adam(params, lr=learning_rate)

        # Feeding the identity through the layer produces its full response
        # matrix; matching it matches the layer on every input.
        eye_input = torch.eye(layer.in_features, device=device, dtype=dtype)
        with torch.no_grad():
            target = layer(eye_input)

        best_loss = float("inf")

        # Phase 1: Adam with early stopping once the tolerance is reached.
        for step in range(optimization_steps):
            optimizer_adam.zero_grad()
            output = result(eye_input)
            loss = F.mse_loss(output, target)
            loss.backward()
            optimizer_adam.step()

            current_loss = loss.item()
            if current_loss < best_loss:
                best_loss = current_loss

            if best_loss <= tolerance:
                break

        # Phase 2: LBFGS refinement if Adam alone did not reach the tolerance.
        if best_loss > tolerance:
            optimizer_lbfgs = torch.optim.LBFGS(
                params,
                lr=1.0,
                max_iter=100,
                tolerance_grad=1e-12,
                tolerance_change=1e-12,
                line_search_fn="strong_wolfe",
            )

            def closure() -> Tensor:
                # LBFGS re-evaluates the objective internally; it requires a
                # closure that recomputes loss and gradients on each call.
                optimizer_lbfgs.zero_grad()
                output_lbfgs = result(eye_input)
                lbfgs_loss = F.mse_loss(output_lbfgs, target)
                lbfgs_loss.backward()
                return lbfgs_loss

            for _ in range(20):
                loss = optimizer_lbfgs.step(closure)
                best_loss = min(best_loss, loss.item())
                if best_loss <= tolerance:
                    break

        return result

    def forward(self, input: Tensor) -> Tensor:
        """Apply the butterfly stages (and bias) to ``input``.

        Accepts any leading batch shape; only the last dimension must equal
        ``in_features``. Returns a tensor with the same leading shape and a
        last dimension of ``out_features``.

        Raises:
            ValueError: If the last dimension does not match ``in_features``.
        """
        if input.shape[-1] != self.in_features:
            raise ValueError(
                f"Expected last dimension {self.in_features}, got {input.shape[-1]}"
            )

        # Flatten all leading dims into one batch dim for the stage reshapes.
        original_shape = input.shape[:-1]
        x = input.reshape(-1, self.in_features)

        for stage_index, factor in enumerate(self.factors):
            x = self._apply_stage(x, factor, stage_index)

        if self.bias is not None:
            x = x + self.bias

        return x.reshape(*original_shape, self.out_features)

    def extra_repr(self) -> str:
        """Summary string shown by ``repr(module)``, mirroring nn.Linear's."""
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )

    def _apply_stage(self, x: Tensor, factor: Tensor, stage: int) -> Tensor:
        """Apply one butterfly stage to ``x`` (shape ``(batch, N)``).

        Within each block of ``2**(stage + 1)`` consecutive features, a
        reshape/transpose regroups the elements into pairs, each pair is mixed
        by its own learnable 2x2 matrix from ``factor`` (shape ``(N/2, 2, 2)``),
        and the inverse transpose/reshape restores the original feature order.
        NOTE(review): the pairing is whatever this reshape/permute sequence
        produces — confirm against the reference doc that it matches the
        intended butterfly pattern before relying on that characterization.
        """
        batch = x.shape[0]
        n = self.in_features
        # Block width doubles each stage: 2, 4, 8, ..., N.
        block = 1 << (stage + 1)
        half = block >> 1

        # Regroup features: split into blocks, then transpose within each
        # block so that the elements to be mixed end up adjacent.
        staged = x.reshape(batch, -1, block)
        staged = staged.reshape(batch, -1, half, 2)
        staged = staged.permute(0, 1, 3, 2).contiguous()
        pairs = staged.reshape(batch, -1, 2)

        # Batched 2x2 matmul: each of the N/2 pairs gets its own factor block.
        transformed = torch.einsum("bnc,ncd->bnd", pairs, factor)

        # Undo the regrouping to restore the original feature ordering.
        transformed = transformed.reshape(batch, -1, 2, half)
        transformed = transformed.permute(0, 1, 3, 2).contiguous()
        transformed = transformed.reshape(batch, -1, block)
        return transformed.reshape(batch, n)

    def to_linear(self) -> nn.Linear:
        """Return a dense :class:`nn.Linear` with identical behavior."""

        # Infer target device/dtype from the first available parameter.
        factor_tensor = self.factors[0] if len(self.factors) > 0 else None
        if factor_tensor is not None:
            device = factor_tensor.device
            dtype = factor_tensor.dtype
        elif self.bias is not None:
            device = self.bias.device
            dtype = self.bias.dtype
        else:
            device = torch.device("cpu")
            dtype = torch.get_default_dtype()

        linear = nn.Linear(
            self.in_features,
            self.out_features,
            bias=self.bias is not None,
            device=device,
            dtype=dtype,
        )

        with torch.no_grad():
            # Temporarily zero the bias so the identity pass below measures
            # only the linear part of the map.
            if self.bias is not None:
                bias_backup = self.bias.data.clone()
                self.bias.zero_()
            else:
                bias_backup = None

            # forward(I) yields rows f(e_i), i.e. the transpose of the weight
            # matrix nn.Linear expects (x @ W.T), hence the .t() below.
            identity = torch.eye(self.in_features, device=device, dtype=dtype)
            weight_matrix = self(identity)

            # Restore this layer's bias and copy it into the dense layer.
            if bias_backup is not None:
                self.bias.copy_(bias_backup)
                linear.bias.copy_(bias_backup)

            linear.weight.copy_(weight_matrix.t())

        return linear
|