lrnnx 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lrnnx-1.0.0/LICENSE +21 -0
- lrnnx-1.0.0/MANIFEST.in +1 -0
- lrnnx-1.0.0/PKG-INFO +183 -0
- lrnnx-1.0.0/README.md +157 -0
- lrnnx-1.0.0/benchmarks/__init__.py +1 -0
- lrnnx-1.0.0/benchmarks/benchmark_cauchy.py +62 -0
- lrnnx-1.0.0/benchmarks/benchmark_cauchy_tune.py +88 -0
- lrnnx-1.0.0/benchmarks/benchmark_inference.py +191 -0
- lrnnx-1.0.0/benchmarks/benchmark_training.py +159 -0
- lrnnx-1.0.0/benchmarks/run_all.py +231 -0
- lrnnx-1.0.0/csrc/common.h +410 -0
- lrnnx-1.0.0/csrc/reverse_scan.cuh +416 -0
- lrnnx-1.0.0/csrc/s4/cauchy.cpp +102 -0
- lrnnx-1.0.0/csrc/s4/cauchy.py +116 -0
- lrnnx-1.0.0/csrc/s4/cauchy_cuda.cu +368 -0
- lrnnx-1.0.0/csrc/s4/map.h +72 -0
- lrnnx-1.0.0/csrc/s4/tune_cauchy.py +91 -0
- lrnnx-1.0.0/csrc/s4/tuner.py +219 -0
- lrnnx-1.0.0/csrc/s4/tuning_setup.py +43 -0
- lrnnx-1.0.0/csrc/s4/vandermonde.py +53 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_bilinear.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_dirac.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_mamba.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_zoh.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_bilinear.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_dirac.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_mamba.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_rglru.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_s7.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_zoh.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/bindings.cpp +8 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_bilinear.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_dirac.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_mamba.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_zoh.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_bilinear.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_dirac.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_mamba.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_rglru.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_s7.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_zoh.cu +3 -0
- lrnnx-1.0.0/csrc/selective_scan/selective_scan.cuh +451 -0
- lrnnx-1.0.0/csrc/selective_scan/selective_scan.h +138 -0
- lrnnx-1.0.0/csrc/selective_scan/selective_scan_backward.cuh +924 -0
- lrnnx-1.0.0/csrc/selective_scan/selective_scan_cpu.cpp +632 -0
- lrnnx-1.0.0/csrc/selective_scan/uninitialized_copy.cuh +50 -0
- lrnnx-1.0.0/csrc/simplified_scan/backward_kernels/simplified_scan_bilinear.cu +3 -0
- lrnnx-1.0.0/csrc/simplified_scan/backward_kernels/simplified_scan_dirac.cu +3 -0
- lrnnx-1.0.0/csrc/simplified_scan/backward_kernels/simplified_scan_zoh.cu +3 -0
- lrnnx-1.0.0/csrc/simplified_scan/bindings.cpp +8 -0
- lrnnx-1.0.0/csrc/simplified_scan/forward_kernels/simplified_scan_fp32_bilinear.cu +3 -0
- lrnnx-1.0.0/csrc/simplified_scan/forward_kernels/simplified_scan_fp32_dirac.cu +3 -0
- lrnnx-1.0.0/csrc/simplified_scan/forward_kernels/simplified_scan_fp32_zoh.cu +3 -0
- lrnnx-1.0.0/csrc/simplified_scan/simplified_scan.cuh +315 -0
- lrnnx-1.0.0/csrc/simplified_scan/simplified_scan.h +111 -0
- lrnnx-1.0.0/csrc/simplified_scan/simplified_scan_backward.cuh +482 -0
- lrnnx-1.0.0/csrc/simplified_scan/simplified_scan_cpu.cpp +331 -0
- lrnnx-1.0.0/docs/source/api/__init__.py +0 -0
- lrnnx-1.0.0/docs/source/conf.py +0 -0
- lrnnx-1.0.0/docs/source/guides/__init__.py +0 -0
- lrnnx-1.0.0/docs/source/tutorials/__init__.py +0 -0
- lrnnx-1.0.0/lrnnx/__init__.py +0 -0
- lrnnx-1.0.0/lrnnx/architectures/__init__.py +0 -0
- lrnnx-1.0.0/lrnnx/architectures/classifier.py +450 -0
- lrnnx-1.0.0/lrnnx/architectures/embedding.py +82 -0
- lrnnx-1.0.0/lrnnx/architectures/language_model.py +825 -0
- lrnnx-1.0.0/lrnnx/architectures/lru_unet.py +247 -0
- lrnnx-1.0.0/lrnnx/core/__init__.py +0 -0
- lrnnx-1.0.0/lrnnx/core/base.py +79 -0
- lrnnx-1.0.0/lrnnx/core/convolution.py +210 -0
- lrnnx-1.0.0/lrnnx/core/discretization.py +164 -0
- lrnnx-1.0.0/lrnnx/layers/block.py +136 -0
- lrnnx-1.0.0/lrnnx/layers/mha.py +502 -0
- lrnnx-1.0.0/lrnnx/layers/mlp.py +77 -0
- lrnnx-1.0.0/lrnnx/models/__init__.py +0 -0
- lrnnx-1.0.0/lrnnx/models/lti/__init__.py +25 -0
- lrnnx-1.0.0/lrnnx/models/lti/base.py +152 -0
- lrnnx-1.0.0/lrnnx/models/lti/centaurus.py +568 -0
- lrnnx-1.0.0/lrnnx/models/lti/lru.py +263 -0
- lrnnx-1.0.0/lrnnx/models/lti/s4.py +435 -0
- lrnnx-1.0.0/lrnnx/models/lti/s4d.py +438 -0
- lrnnx-1.0.0/lrnnx/models/lti/s5.py +285 -0
- lrnnx-1.0.0/lrnnx/models/ltv/__init__.py +13 -0
- lrnnx-1.0.0/lrnnx/models/ltv/base.py +159 -0
- lrnnx-1.0.0/lrnnx/models/ltv/mamba.py +548 -0
- lrnnx-1.0.0/lrnnx/models/ltv/rglru.py +388 -0
- lrnnx-1.0.0/lrnnx/models/ltv/s5.py +299 -0
- lrnnx-1.0.0/lrnnx/models/ltv/s7.py +267 -0
- lrnnx-1.0.0/lrnnx/ops/__init__.py +25 -0
- lrnnx-1.0.0/lrnnx/ops/rglru_scan.py +679 -0
- lrnnx-1.0.0/lrnnx/ops/s4_kernel_interface.py +588 -0
- lrnnx-1.0.0/lrnnx/ops/s4_utils.py +860 -0
- lrnnx-1.0.0/lrnnx/ops/s7_scan.py +552 -0
- lrnnx-1.0.0/lrnnx/ops/selective_scan.py +850 -0
- lrnnx-1.0.0/lrnnx/ops/simplified_scan.py +640 -0
- lrnnx-1.0.0/lrnnx/ops/torch.py +24 -0
- lrnnx-1.0.0/lrnnx/ops/triton/__init__.py +0 -0
- lrnnx-1.0.0/lrnnx/ops/triton/layer_norm.py +1243 -0
- lrnnx-1.0.0/lrnnx/ops/triton/selective_state_update.py +599 -0
- lrnnx-1.0.0/lrnnx/ops/triton/simplified_state_update.py +469 -0
- lrnnx-1.0.0/lrnnx/ops/triton/softplus.py +18 -0
- lrnnx-1.0.0/lrnnx/utils/__init__.py +9 -0
- lrnnx-1.0.0/lrnnx/utils/generation.py +290 -0
- lrnnx-1.0.0/lrnnx/utils/init.py +108 -0
- lrnnx-1.0.0/lrnnx.egg-info/PKG-INFO +183 -0
- lrnnx-1.0.0/lrnnx.egg-info/SOURCES.txt +135 -0
- lrnnx-1.0.0/lrnnx.egg-info/dependency_links.txt +1 -0
- lrnnx-1.0.0/lrnnx.egg-info/requires.txt +16 -0
- lrnnx-1.0.0/lrnnx.egg-info/top_level.txt +10 -0
- lrnnx-1.0.0/pyproject.toml +51 -0
- lrnnx-1.0.0/setup.cfg +4 -0
- lrnnx-1.0.0/setup.py +147 -0
- lrnnx-1.0.0/tests/__init__.py +0 -0
- lrnnx-1.0.0/tests/architectures/test_language_model.py +481 -0
- lrnnx-1.0.0/tests/models/test_lti/test_centaurus.py +261 -0
- lrnnx-1.0.0/tests/models/test_lti/test_lru.py +205 -0
- lrnnx-1.0.0/tests/models/test_lti/test_s4.py +292 -0
- lrnnx-1.0.0/tests/models/test_lti/test_s4d.py +285 -0
- lrnnx-1.0.0/tests/models/test_lti/test_s5_lti.py +198 -0
- lrnnx-1.0.0/tests/models/test_ltv/test_event_mamba.py +157 -0
- lrnnx-1.0.0/tests/models/test_ltv/test_event_s5.py +155 -0
- lrnnx-1.0.0/tests/models/test_ltv/test_mamba.py +169 -0
- lrnnx-1.0.0/tests/models/test_ltv/test_rglru.py +149 -0
- lrnnx-1.0.0/tests/models/test_ltv/test_s5_ltv.py +277 -0
- lrnnx-1.0.0/tests/models/test_ltv/test_s7.py +162 -0
- lrnnx-1.0.0/tests/ops/mamba/test_selective_scan.py +389 -0
- lrnnx-1.0.0/tests/ops/mamba/test_selective_scan_async.py +424 -0
- lrnnx-1.0.0/tests/ops/mamba/test_selective_state_update.py +681 -0
- lrnnx-1.0.0/tests/ops/rglru/test_rglru_scan.py +354 -0
- lrnnx-1.0.0/tests/ops/s4/test_cauchy.py +119 -0
- lrnnx-1.0.0/tests/ops/s4/test_vandermonde.py +57 -0
- lrnnx-1.0.0/tests/ops/s5/test_simplified_scan.py +198 -0
- lrnnx-1.0.0/tests/ops/s5/test_simplified_scan_async.py +326 -0
- lrnnx-1.0.0/tests/ops/s5/test_simplified_state_update.py +149 -0
- lrnnx-1.0.0/tests/ops/s7/test_s7_scan.py +150 -0
- lrnnx-1.0.0/tests/utils/__init__.py +0 -0
- lrnnx-1.0.0/tests/utils/test_generation.py +218 -0
lrnnx-1.0.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Society for Artificial Intelligence and Deep Learning
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
lrnnx-1.0.0/MANIFEST.in
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
recursive-include csrc *.cpp *.cu *.cuh *.h
|
lrnnx-1.0.0/PKG-INFO
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: lrnnx
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: A library for Linear RNNs
|
|
5
|
+
Author: SAiDl Team
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/SforAiDl/lrnnx
|
|
8
|
+
Requires-Python: >=3.9
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: numpy
|
|
12
|
+
Requires-Dist: torch
|
|
13
|
+
Requires-Dist: einops
|
|
14
|
+
Requires-Dist: ninja
|
|
15
|
+
Requires-Dist: packaging
|
|
16
|
+
Requires-Dist: opt-einsum
|
|
17
|
+
Requires-Dist: pykeops
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: pytest; extra == "dev"
|
|
20
|
+
Requires-Dist: black; extra == "dev"
|
|
21
|
+
Requires-Dist: isort; extra == "dev"
|
|
22
|
+
Requires-Dist: mypy; extra == "dev"
|
|
23
|
+
Provides-Extra: conv1d
|
|
24
|
+
Requires-Dist: causal-conv1d; extra == "conv1d"
|
|
25
|
+
Dynamic: license-file
|
|
26
|
+
|
|
27
|
+
<!---
|
|
28
|
+
Copyright 2025 SAiDL Team. All rights reserved.
|
|
29
|
+
|
|
30
|
+
Licensed under the MIT License; you may not use this file except in compliance
|
|
31
|
+
with the License. You may obtain a copy of the License in the LICENSE file.
|
|
32
|
+
-->
|
|
33
|
+
|
|
34
|
+
# lrnnx: A library for Linear RNNs
|
|
35
|
+
<p>
|
|
36
|
+
<a href="LICENSE"><img alt="License" src="https://img.shields.io/github/license/SforAiDl/lrnnx.svg?color=blue"></a>
|
|
37
|
+
<a href="https://arxiv.org/abs/2602.08810"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2602.08810-b31b1b.svg"></a>
|
|
38
|
+
</p>
|
|
39
|
+
|
|
40
|
+
A unified PyTorch library providing easy access to state-of-the-art Linear RNN architectures for sequence modeling.
|
|
41
|
+
The technical report of this system was accepted to [EACL Student Research Workshop 2026](https://2026.eacl.org/calls/srw/).
|
|
42
|
+
We recommend reading the report before using / contributing to the library.
|
|
43
|
+
|
|
44
|
+
## Installation
|
|
45
|
+
|
|
46
|
+
### From PyPI
|
|
47
|
+
Since this library compiles custom CUDA kernels upon installation, we recommend using `--no-build-isolation` to avoid downloading a duplicate version of PyTorch during the build process.
|
|
48
|
+
```bash
|
|
49
|
+
# standard installation
|
|
50
|
+
pip install lrnnx --no-build-isolation
|
|
51
|
+
# with optional causal-conv1d
|
|
52
|
+
pip install "lrnnx[conv1d]" --no-build-isolation
|
|
53
|
+
# for development
|
|
54
|
+
pip install "lrnnx[dev]" --no-build-isolation
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
### From Source
|
|
58
|
+
We recommend installation with [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for fast, reliable dependency management, though standard `pip` is fully supported.
|
|
59
|
+
|
|
60
|
+
#### Using uv
|
|
61
|
+
```bash
|
|
62
|
+
git clone https://github.com/SforAiDl/lrnnx.git
|
|
63
|
+
cd lrnnx
|
|
64
|
+
# standard installation
|
|
65
|
+
uv sync
|
|
66
|
+
# with optional causal-conv1d
|
|
67
|
+
uv sync --extra conv1d
|
|
68
|
+
# for development
|
|
69
|
+
uv sync --extra dev
|
|
70
|
+
```
|
|
71
|
+
|
|
72
|
+
#### Using pip
|
|
73
|
+
```bash
|
|
74
|
+
git clone https://github.com/SforAiDl/lrnnx.git
|
|
75
|
+
cd lrnnx
|
|
76
|
+
# standard installation
|
|
77
|
+
pip install -e . --no-build-isolation
|
|
78
|
+
# with optional causal-conv1d
|
|
79
|
+
pip install -e ".[conv1d]" --no-build-isolation
|
|
80
|
+
# for development
|
|
81
|
+
pip install -e ".[dev]" --no-build-isolation
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
Note that since our library builds several custom CUDA kernels, it can take time for this installation to finish.
|
|
85
|
+
Along with `causal-conv1d` the full installation can take about 30 minutes, depending on the number of CPUs available.
|
|
86
|
+
|
|
87
|
+
## Model Zoo
|
|
88
|
+
Our library provides implementations of the following Linear RNN architectures:
|
|
89
|
+
- [S4](https://openreview.net/forum?id=uYLFoz1vlAC)
|
|
90
|
+
- [S4D](https://dl.acm.org/doi/10.5555/3600270.3602877)
|
|
91
|
+
- [S5](https://openreview.net/forum?id=Ai8Hw3AXqks)
|
|
92
|
+
- [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va) (inside `S5`, use by passing `integration_timesteps`)
|
|
93
|
+
- [LRU](https://dl.acm.org/doi/10.5555/3618408.3619518)
|
|
94
|
+
- [S6](https://openreview.net/forum?id=tEYskw1VY2) (we implemented other discretizations)
|
|
95
|
+
- [STREAM](https://arxiv.org/abs/2411.12603) (inside `S6`, use by passing `integration_timesteps`)
|
|
96
|
+
- [RG-LRU](https://arxiv.org/abs/2402.19427)
|
|
97
|
+
- [S7](https://arxiv.org/abs/2410.03464)
|
|
98
|
+
- [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html)
|
|
99
|
+
|
|
100
|
+
We expose several levels of API for each model, including a scan, a recurrent step, and a full layer API matching the paper.
|
|
101
|
+
For S5 we implement both a convolution based approach and a parallel scan approach.
|
|
102
|
+
The latter is more stable and faster for most use cases, but the convolution based approach can be faster for very long sequences.
|
|
103
|
+
|
|
104
|
+
## Usage
|
|
105
|
+
|
|
106
|
+
### Training
|
|
107
|
+
It is easy to instantiate a model from our library
|
|
108
|
+
```python
|
|
109
|
+
from lrnnx.models.lti import LRU
|
|
110
|
+
from lrnnx.models.ltv import Mamba
|
|
111
|
+
|
|
112
|
+
model_lti = LRU(d_model, d_state).cuda()
|
|
113
|
+
x = torch.randn(
|
|
114
|
+
batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
|
|
115
|
+
)
|
|
116
|
+
output = model_lti(x)
|
|
117
|
+
|
|
118
|
+
model_ltv = Mamba(d_model, d_state).cuda()
|
|
119
|
+
x = torch.randn(
|
|
120
|
+
batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
|
|
121
|
+
)
|
|
122
|
+
output = model_ltv(x)
|
|
123
|
+
```
|
|
124
|
+
|
|
125
|
+
### Inference
|
|
126
|
+
Linear RNNs in torch require special handling during inference, following [mamba](https://github.com/state-spaces/mamba), we also implement CUDA graphs based inference which reduces CPU overheads, this leads to > 10x speedup compared to using a simple for loop over the sequence length.
|
|
127
|
+
The main file is [generation.py](lrnnx/utils/generation.py) which provides a simple API for autoregressive generation with any of the models in our library.
|
|
128
|
+
You can see a simple way to use it in our [benchmarking script](benchmarks/benchmark_inference.py).
|
|
129
|
+
|
|
130
|
+
### Reproducing the Benchmarks from the paper
|
|
131
|
+
This script will run both training and inference benchmarks.
|
|
132
|
+
```bash
|
|
133
|
+
python -m benchmarks.run_all
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
### Architectures
|
|
137
|
+
We also implement some common architectures based on the models in our library, such as a U-Net (inspired from [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html) ) and a hierarchical classifier (inspired from [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va)).
|
|
138
|
+
Additionally, there is a [Language Model](lrnnx/architectures/language_model.py) architecture inspired from [Mamba](https://github.com/state-spaces/mamba) and [RG-LRU](https://arxiv.org/abs/2402.19427) which can be used for language modeling tasks, with replaceable LRNN and attention layers.
|
|
139
|
+
This can be used as
|
|
140
|
+
```python
|
|
141
|
+
from lrnnx.models.language_model import LRNNLMHeadModel
|
|
142
|
+
|
|
143
|
+
model = LRNNLMHeadModel(
|
|
144
|
+
d_model, d_state, num_layers, vocab_size, mixer_types=["s5", "s6", "attn"]
|
|
145
|
+
)
|
|
146
|
+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
|
|
147
|
+
logits = model(input_ids)
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
### Tutorial Overview
|
|
151
|
+
|
|
152
|
+
Based on the architectures, there are tutorials on how to use them for 2 very popular use cases:
|
|
153
|
+
1. [U-Net Seq2Seq for audio denoising Tutorial](tutorials/notebooks/01_UNet.ipynb)
|
|
154
|
+
2. [Hierarchical Classification Tutorial](tutorials/notebooks/02_hierarchical_classifier.ipynb)
|
|
155
|
+
|
|
156
|
+
## Contributing
|
|
157
|
+
|
|
158
|
+
Please check out our [Contributing Guide](CONTRIBUTING.rst) for details on how to contribute to this project.
|
|
159
|
+
|
|
160
|
+
## Citation
|
|
161
|
+
|
|
162
|
+
If you use lrnnx in your research, please cite:
|
|
163
|
+
|
|
164
|
+
```bibtex
|
|
165
|
+
@misc{bania2026textttlrnnxlibrarylinearrnns,
|
|
166
|
+
title={$\texttt{lrnnx}$: A library for Linear RNNs},
|
|
167
|
+
author={Karan Bania and Soham Kalburgi and Manit Tanwar and Dhruthi and Aditya Nagarsekar and Harshvardhan Mestha and Naman Chibber and Raj Deshmukh and Anish Sathyanarayanan and Aarush Rathore and Pratham Chheda},
|
|
168
|
+
year={2026},
|
|
169
|
+
eprint={2602.08810},
|
|
170
|
+
archivePrefix={arXiv},
|
|
171
|
+
primaryClass={cs.LG},
|
|
172
|
+
url={https://arxiv.org/abs/2602.08810},
|
|
173
|
+
}
|
|
174
|
+
```
|
|
175
|
+
|
|
176
|
+
## License
|
|
177
|
+
|
|
178
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
179
|
+
|
|
180
|
+
## Acknowledgments
|
|
181
|
+
|
|
182
|
+
This library builds upon the excellent work of researchers who developed the individual LRNN models.
|
|
183
|
+
Please see individual model documentation for proper citations of the original papers.
|
lrnnx-1.0.0/README.md
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
<!---
|
|
2
|
+
Copyright 2025 SAiDL Team. All rights reserved.
|
|
3
|
+
|
|
4
|
+
Licensed under the MIT License; you may not use this file except in compliance
|
|
5
|
+
with the License. You may obtain a copy of the License in the LICENSE file.
|
|
6
|
+
-->
|
|
7
|
+
|
|
8
|
+
# lrnnx: A library for Linear RNNs
|
|
9
|
+
<p>
|
|
10
|
+
<a href="LICENSE"><img alt="License" src="https://img.shields.io/github/license/SforAiDl/lrnnx.svg?color=blue"></a>
|
|
11
|
+
<a href="https://arxiv.org/abs/2602.08810"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2602.08810-b31b1b.svg"></a>
|
|
12
|
+
</p>
|
|
13
|
+
|
|
14
|
+
A unified PyTorch library providing easy access to state-of-the-art Linear RNN architectures for sequence modeling.
|
|
15
|
+
The technical report of this system was accepted to [EACL Student Research Workshop 2026](https://2026.eacl.org/calls/srw/).
|
|
16
|
+
We recommend reading the report before using / contributing to the library.
|
|
17
|
+
|
|
18
|
+
## Installation
|
|
19
|
+
|
|
20
|
+
### From PyPI
|
|
21
|
+
Since this library compiles custom CUDA kernels upon installation, we recommend using `--no-build-isolation` to avoid downloading a duplicate version of PyTorch during the build process.
|
|
22
|
+
```bash
|
|
23
|
+
# standard installation
|
|
24
|
+
pip install lrnnx --no-build-isolation
|
|
25
|
+
# with optional causal-conv1d
|
|
26
|
+
pip install "lrnnx[conv1d]" --no-build-isolation
|
|
27
|
+
# for development
|
|
28
|
+
pip install "lrnnx[dev]" --no-build-isolation
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
### From Source
|
|
32
|
+
We recommend installation with [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for fast, reliable dependency management, though standard `pip` is fully supported.
|
|
33
|
+
|
|
34
|
+
#### Using uv
|
|
35
|
+
```bash
|
|
36
|
+
git clone https://github.com/SforAiDl/lrnnx.git
|
|
37
|
+
cd lrnnx
|
|
38
|
+
# standard installation
|
|
39
|
+
uv sync
|
|
40
|
+
# with optional causal-conv1d
|
|
41
|
+
uv sync --extra conv1d
|
|
42
|
+
# for development
|
|
43
|
+
uv sync --extra dev
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
#### Using pip
|
|
47
|
+
```bash
|
|
48
|
+
git clone https://github.com/SforAiDl/lrnnx.git
|
|
49
|
+
cd lrnnx
|
|
50
|
+
# standard installation
|
|
51
|
+
pip install -e . --no-build-isolation
|
|
52
|
+
# with optional causal-conv1d
|
|
53
|
+
pip install -e ".[conv1d]" --no-build-isolation
|
|
54
|
+
# for development
|
|
55
|
+
pip install -e ".[dev]" --no-build-isolation
|
|
56
|
+
```
|
|
57
|
+
|
|
58
|
+
Note that since our library builds several custom CUDA kernels, it can take time for this installation to finish.
|
|
59
|
+
Along with `causal-conv1d` the full installation can take about 30 minutes, depending on the number of CPUs available.
|
|
60
|
+
|
|
61
|
+
## Model Zoo
|
|
62
|
+
Our library provides implementations of the following Linear RNN architectures:
|
|
63
|
+
- [S4](https://openreview.net/forum?id=uYLFoz1vlAC)
|
|
64
|
+
- [S4D](https://dl.acm.org/doi/10.5555/3600270.3602877)
|
|
65
|
+
- [S5](https://openreview.net/forum?id=Ai8Hw3AXqks)
|
|
66
|
+
- [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va) (inside `S5`, use by passing `integration_timesteps`)
|
|
67
|
+
- [LRU](https://dl.acm.org/doi/10.5555/3618408.3619518)
|
|
68
|
+
- [S6](https://openreview.net/forum?id=tEYskw1VY2) (we implemented other discretizations)
|
|
69
|
+
- [STREAM](https://arxiv.org/abs/2411.12603) (inside `S6`, use by passing `integration_timesteps`)
|
|
70
|
+
- [RG-LRU](https://arxiv.org/abs/2402.19427)
|
|
71
|
+
- [S7](https://arxiv.org/abs/2410.03464)
|
|
72
|
+
- [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html)
|
|
73
|
+
|
|
74
|
+
We expose several levels of API for each model, including a scan, a recurrent step, and a full layer API matching the paper.
|
|
75
|
+
For S5 we implement both a convolution based approach and a parallel scan approach.
|
|
76
|
+
The latter is more stable and faster for most use cases, but the convolution based approach can be faster for very long sequences.
|
|
77
|
+
|
|
78
|
+
## Usage
|
|
79
|
+
|
|
80
|
+
### Training
|
|
81
|
+
It is easy to instantiate a model from our library
|
|
82
|
+
```python
|
|
83
|
+
from lrnnx.models.lti import LRU
|
|
84
|
+
from lrnnx.models.ltv import Mamba
|
|
85
|
+
|
|
86
|
+
model_lti = LRU(d_model, d_state).cuda()
|
|
87
|
+
x = torch.randn(
|
|
88
|
+
batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
|
|
89
|
+
)
|
|
90
|
+
output = model_lti(x)
|
|
91
|
+
|
|
92
|
+
model_ltv = Mamba(d_model, d_state).cuda()
|
|
93
|
+
x = torch.randn(
|
|
94
|
+
batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
|
|
95
|
+
)
|
|
96
|
+
output = model_ltv(x)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
### Inference
|
|
100
|
+
Linear RNNs in torch require special handling during inference, following [mamba](https://github.com/state-spaces/mamba), we also implement CUDA graphs based inference which reduces CPU overheads, this leads to > 10x speedup compared to using a simple for loop over the sequence length.
|
|
101
|
+
The main file is [generation.py](lrnnx/utils/generation.py) which provides a simple API for autoregressive generation with any of the models in our library.
|
|
102
|
+
You can see a simple way to use it in our [benchmarking script](benchmarks/benchmark_inference.py).
|
|
103
|
+
|
|
104
|
+
### Reproducing the Benchmarks from the paper
|
|
105
|
+
This script will run both training and inference benchmarks.
|
|
106
|
+
```bash
|
|
107
|
+
python -m benchmarks.run_all
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
### Architectures
|
|
111
|
+
We also implement some common architectures based on the models in our library, such as a U-Net (inspired from [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html) ) and a hierarchical classifier (inspired from [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va)).
|
|
112
|
+
Additionally, there is a [Language Model](lrnnx/architectures/language_model.py) architecture inspired from [Mamba](https://github.com/state-spaces/mamba) and [RG-LRU](https://arxiv.org/abs/2402.19427) which can be used for language modeling tasks, with replaceable LRNN and attention layers.
|
|
113
|
+
This can be used as
|
|
114
|
+
```python
|
|
115
|
+
from lrnnx.models.language_model import LRNNLMHeadModel
|
|
116
|
+
|
|
117
|
+
model = LRNNLMHeadModel(
|
|
118
|
+
d_model, d_state, num_layers, vocab_size, mixer_types=["s5", "s6", "attn"]
|
|
119
|
+
)
|
|
120
|
+
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
|
|
121
|
+
logits = model(input_ids)
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
### Tutorial Overview
|
|
125
|
+
|
|
126
|
+
Based on the architectures, there are tutorials on how to use them for 2 very popular use cases:
|
|
127
|
+
1. [U-Net Seq2Seq for audio denoising Tutorial](tutorials/notebooks/01_UNet.ipynb)
|
|
128
|
+
2. [Hierarchical Classification Tutorial](tutorials/notebooks/02_hierarchical_classifier.ipynb)
|
|
129
|
+
|
|
130
|
+
## Contributing
|
|
131
|
+
|
|
132
|
+
Please check out our [Contributing Guide](CONTRIBUTING.rst) for details on how to contribute to this project.
|
|
133
|
+
|
|
134
|
+
## Citation
|
|
135
|
+
|
|
136
|
+
If you use lrnnx in your research, please cite:
|
|
137
|
+
|
|
138
|
+
```bibtex
|
|
139
|
+
@misc{bania2026textttlrnnxlibrarylinearrnns,
|
|
140
|
+
title={$\texttt{lrnnx}$: A library for Linear RNNs},
|
|
141
|
+
author={Karan Bania and Soham Kalburgi and Manit Tanwar and Dhruthi and Aditya Nagarsekar and Harshvardhan Mestha and Naman Chibber and Raj Deshmukh and Anish Sathyanarayanan and Aarush Rathore and Pratham Chheda},
|
|
142
|
+
year={2026},
|
|
143
|
+
eprint={2602.08810},
|
|
144
|
+
archivePrefix={arXiv},
|
|
145
|
+
primaryClass={cs.LG},
|
|
146
|
+
url={https://arxiv.org/abs/2602.08810},
|
|
147
|
+
}
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
## License
|
|
151
|
+
|
|
152
|
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
|
153
|
+
|
|
154
|
+
## Acknowledgments
|
|
155
|
+
|
|
156
|
+
This library builds upon the excellent work of researchers who developed the individual LRNN models.
|
|
157
|
+
Please see individual model documentation for proper citations of the original papers.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Benchmark utilities for lrnnx models."""
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from benchmark.utils import (
|
|
6
|
+
benchmark_all,
|
|
7
|
+
benchmark_backward,
|
|
8
|
+
benchmark_combined,
|
|
9
|
+
benchmark_forward,
|
|
10
|
+
)
|
|
11
|
+
from einops import rearrange
|
|
12
|
+
|
|
13
|
+
from .cauchy import cauchy_mult, cauchy_mult_keops, cauchy_mult_torch
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def generate_data(batch_size, N, L, symmetric=True, device="cuda"):
    """Sample random inputs ``(v, z, w)`` for Cauchy-multiply benchmarks.

    Args:
        batch_size: number of independent (v, w) pairs.
        N: state size; must be even when ``symmetric`` is True.
        L: number of evaluation nodes ``z``.
        symmetric: if True, ``v`` and ``w`` are conjugate-symmetric
            (second half is the conjugate of the first half) and ``z``
            lies on the unit circle.
        device: torch device string for all tensors.

    Returns:
        Tuple ``(v, z, w)`` with ``v``/``w`` of shape
        ``(batch_size, N)`` (complex64, requiring grad) and ``z`` of
        shape ``(L,)`` (complex64).
    """
    if symmetric:
        assert N % 2 == 0
        # Only the first N//2 entries are free; mirror them with their
        # conjugates so downstream symmetric kernels can exploit the structure.
        half_v = torch.randn(
            batch_size, N // 2, dtype=torch.complex64, device=device
        )
        v = torch.cat([half_v, half_v.conj()], dim=-1).requires_grad_(True)
        half_w = torch.randn(
            batch_size, N // 2, dtype=torch.complex64, device=device
        )
        w = torch.cat([half_w, half_w.conj()], dim=-1).requires_grad_(True)
        # Unit-modulus nodes: exp(i * theta) with real theta.
        z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
    else:
        v = torch.randn(
            batch_size,
            N,
            dtype=torch.complex64,
            device=device,
            requires_grad=True,
        )
        w = torch.randn(
            batch_size,
            N,
            dtype=torch.complex64,
            device=device,
            requires_grad=True,
        )
        z = torch.randn(L, dtype=torch.complex64, device=device)
    return v, z, w
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
if __name__ == "__main__":
    # Benchmark configuration: large batch, typical S4 state size, long sequence.
    device = "cuda"
    bs = 1024
    N = 64
    L = 16384

    v, z, w = generate_data(bs, N, L, symmetric=True)
    # Detached half-size copies feed the symmetric kernel variant.
    v_half = v[:, : N // 2].clone().detach().requires_grad_(True)
    w_half = w[:, : N // 2].clone().detach().requires_grad_(True)

    repeat = 30
    # (callable, positional args, label) for each benchmarked variant.
    runs = [
        (cauchy_mult_keops, (v, z, w), "Cauchy mult keops"),
        (partial(cauchy_mult, symmetric=False), (v, z, w), "Cauchy mult"),
        (
            partial(cauchy_mult, symmetric=True),
            (v_half, z, w_half),
            "Cauchy mult symmetric",
        ),
    ]
    for fn, fn_args, desc in runs:
        benchmark_all(repeat, fn, *fn_args, desc=desc)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import importlib
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from benchmark.utils import benchmark_forward
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def generate_data(batch_size, N, L, symmetric=True, device="cuda"):
    """Create random ``(v, z, w)`` tensors for tuning the Cauchy multiply.

    Args:
        batch_size: number of (v, w) rows.
        N: state size; required even when ``symmetric`` is True.
        L: number of nodes ``z``.
        symmetric: if True, build conjugate-symmetric ``v``/``w`` and
            place ``z`` on the unit circle; otherwise everything is
            plain complex Gaussian noise.
        device: torch device string.

    Returns:
        Tuple ``(v, z, w)``: ``v``/``w`` shaped ``(batch_size, N)``
        (complex64, grad-enabled) and ``z`` shaped ``(L,)`` (complex64).
    """
    cplx = dict(dtype=torch.complex64, device=device)
    if symmetric:
        assert N % 2 == 0
        # Free parameters live in the first half; the second half mirrors
        # them as conjugates.
        v_half = torch.randn(batch_size, N // 2, **cplx)
        v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
        w_half = torch.randn(batch_size, N // 2, **cplx)
        w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
        # exp(i * real noise) gives unit-modulus complex nodes.
        z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
    else:
        v = torch.randn(batch_size, N, requires_grad=True, **cplx)
        w = torch.randn(batch_size, N, requires_grad=True, **cplx)
        z = torch.randn(L, **cplx)
    return v, z, w
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# CLI for tuning a compiled Cauchy-multiply extension: import the extension
# by name and time either its symmetric forward or backward kernel.
parser = argparse.ArgumentParser(description="Tuning Cauchy multiply")
parser.add_argument("--name", default="cauchy_mult")
parser.add_argument(
    "--mode", default="forward", choices=["forward", "backward"]
)
parser.add_argument("-bs", "--batch-size", default=1024, type=int)
parser.add_argument("-N", default=64, type=int)
parser.add_argument("-L", default=2**14, type=int)


if __name__ == "__main__":
    args = parser.parse_args()
    device = "cuda"
    repeat = 30

    v, z, w = generate_data(args.batch_size, args.N, args.L, symmetric=True)
    # Half-size detached copies for the symmetric kernels.
    half = args.N // 2
    v_half = v[:, :half].clone().detach().requires_grad_(True)
    w_half = w[:, :half].clone().detach().requires_grad_(True)

    # The extension to tune is importable by name once built.
    module = importlib.import_module(args.name)
    if args.mode == "forward":
        kernel = module.cauchy_mult_sym_fwd
        call_args = (v_half, z, w_half)
        desc = "Cauchy mult symmetric fwd"
    else:
        # Backward needs an upstream gradient of the forward output's shape.
        out = module.cauchy_mult_sym_fwd(v_half, z, w_half)
        dout = torch.randn_like(out)
        kernel = module.cauchy_mult_sym_bwd
        call_args = (v_half, z, w_half, dout)
        desc = "Cauchy mult symmetric bwd"

    _, m = benchmark_forward(
        repeat,
        kernel,
        *call_args,
        verbose=False,
        desc=desc,
    )
    # Emit timing stats as JSON for the tuner driver to parse.
    result_dict = dict(time_mean=m.mean, time_iqr=m.iqr)
    print(json.dumps(result_dict))
|