lrnnx 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137) hide show
  1. lrnnx-1.0.0/LICENSE +21 -0
  2. lrnnx-1.0.0/MANIFEST.in +1 -0
  3. lrnnx-1.0.0/PKG-INFO +183 -0
  4. lrnnx-1.0.0/README.md +157 -0
  5. lrnnx-1.0.0/benchmarks/__init__.py +1 -0
  6. lrnnx-1.0.0/benchmarks/benchmark_cauchy.py +62 -0
  7. lrnnx-1.0.0/benchmarks/benchmark_cauchy_tune.py +88 -0
  8. lrnnx-1.0.0/benchmarks/benchmark_inference.py +191 -0
  9. lrnnx-1.0.0/benchmarks/benchmark_training.py +159 -0
  10. lrnnx-1.0.0/benchmarks/run_all.py +231 -0
  11. lrnnx-1.0.0/csrc/common.h +410 -0
  12. lrnnx-1.0.0/csrc/reverse_scan.cuh +416 -0
  13. lrnnx-1.0.0/csrc/s4/cauchy.cpp +102 -0
  14. lrnnx-1.0.0/csrc/s4/cauchy.py +116 -0
  15. lrnnx-1.0.0/csrc/s4/cauchy_cuda.cu +368 -0
  16. lrnnx-1.0.0/csrc/s4/map.h +72 -0
  17. lrnnx-1.0.0/csrc/s4/tune_cauchy.py +91 -0
  18. lrnnx-1.0.0/csrc/s4/tuner.py +219 -0
  19. lrnnx-1.0.0/csrc/s4/tuning_setup.py +43 -0
  20. lrnnx-1.0.0/csrc/s4/vandermonde.py +53 -0
  21. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_bilinear.cu +3 -0
  22. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_dirac.cu +3 -0
  23. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_mamba.cu +3 -0
  24. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_complex_zoh.cu +3 -0
  25. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_bilinear.cu +3 -0
  26. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_dirac.cu +3 -0
  27. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_mamba.cu +3 -0
  28. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_rglru.cu +3 -0
  29. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_s7.cu +3 -0
  30. lrnnx-1.0.0/csrc/selective_scan/backward_kernels/selective_scan_fp32_real_zoh.cu +3 -0
  31. lrnnx-1.0.0/csrc/selective_scan/bindings.cpp +8 -0
  32. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_bilinear.cu +3 -0
  33. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_dirac.cu +3 -0
  34. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_mamba.cu +3 -0
  35. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_complex_zoh.cu +3 -0
  36. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_bilinear.cu +3 -0
  37. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_dirac.cu +3 -0
  38. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_mamba.cu +3 -0
  39. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_rglru.cu +3 -0
  40. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_s7.cu +3 -0
  41. lrnnx-1.0.0/csrc/selective_scan/forward_kernels/selective_scan_fp32_real_zoh.cu +3 -0
  42. lrnnx-1.0.0/csrc/selective_scan/selective_scan.cuh +451 -0
  43. lrnnx-1.0.0/csrc/selective_scan/selective_scan.h +138 -0
  44. lrnnx-1.0.0/csrc/selective_scan/selective_scan_backward.cuh +924 -0
  45. lrnnx-1.0.0/csrc/selective_scan/selective_scan_cpu.cpp +632 -0
  46. lrnnx-1.0.0/csrc/selective_scan/uninitialized_copy.cuh +50 -0
  47. lrnnx-1.0.0/csrc/simplified_scan/backward_kernels/simplified_scan_bilinear.cu +3 -0
  48. lrnnx-1.0.0/csrc/simplified_scan/backward_kernels/simplified_scan_dirac.cu +3 -0
  49. lrnnx-1.0.0/csrc/simplified_scan/backward_kernels/simplified_scan_zoh.cu +3 -0
  50. lrnnx-1.0.0/csrc/simplified_scan/bindings.cpp +8 -0
  51. lrnnx-1.0.0/csrc/simplified_scan/forward_kernels/simplified_scan_fp32_bilinear.cu +3 -0
  52. lrnnx-1.0.0/csrc/simplified_scan/forward_kernels/simplified_scan_fp32_dirac.cu +3 -0
  53. lrnnx-1.0.0/csrc/simplified_scan/forward_kernels/simplified_scan_fp32_zoh.cu +3 -0
  54. lrnnx-1.0.0/csrc/simplified_scan/simplified_scan.cuh +315 -0
  55. lrnnx-1.0.0/csrc/simplified_scan/simplified_scan.h +111 -0
  56. lrnnx-1.0.0/csrc/simplified_scan/simplified_scan_backward.cuh +482 -0
  57. lrnnx-1.0.0/csrc/simplified_scan/simplified_scan_cpu.cpp +331 -0
  58. lrnnx-1.0.0/docs/source/api/__init__.py +0 -0
  59. lrnnx-1.0.0/docs/source/conf.py +0 -0
  60. lrnnx-1.0.0/docs/source/guides/__init__.py +0 -0
  61. lrnnx-1.0.0/docs/source/tutorials/__init__.py +0 -0
  62. lrnnx-1.0.0/lrnnx/__init__.py +0 -0
  63. lrnnx-1.0.0/lrnnx/architectures/__init__.py +0 -0
  64. lrnnx-1.0.0/lrnnx/architectures/classifier.py +450 -0
  65. lrnnx-1.0.0/lrnnx/architectures/embedding.py +82 -0
  66. lrnnx-1.0.0/lrnnx/architectures/language_model.py +825 -0
  67. lrnnx-1.0.0/lrnnx/architectures/lru_unet.py +247 -0
  68. lrnnx-1.0.0/lrnnx/core/__init__.py +0 -0
  69. lrnnx-1.0.0/lrnnx/core/base.py +79 -0
  70. lrnnx-1.0.0/lrnnx/core/convolution.py +210 -0
  71. lrnnx-1.0.0/lrnnx/core/discretization.py +164 -0
  72. lrnnx-1.0.0/lrnnx/layers/block.py +136 -0
  73. lrnnx-1.0.0/lrnnx/layers/mha.py +502 -0
  74. lrnnx-1.0.0/lrnnx/layers/mlp.py +77 -0
  75. lrnnx-1.0.0/lrnnx/models/__init__.py +0 -0
  76. lrnnx-1.0.0/lrnnx/models/lti/__init__.py +25 -0
  77. lrnnx-1.0.0/lrnnx/models/lti/base.py +152 -0
  78. lrnnx-1.0.0/lrnnx/models/lti/centaurus.py +568 -0
  79. lrnnx-1.0.0/lrnnx/models/lti/lru.py +263 -0
  80. lrnnx-1.0.0/lrnnx/models/lti/s4.py +435 -0
  81. lrnnx-1.0.0/lrnnx/models/lti/s4d.py +438 -0
  82. lrnnx-1.0.0/lrnnx/models/lti/s5.py +285 -0
  83. lrnnx-1.0.0/lrnnx/models/ltv/__init__.py +13 -0
  84. lrnnx-1.0.0/lrnnx/models/ltv/base.py +159 -0
  85. lrnnx-1.0.0/lrnnx/models/ltv/mamba.py +548 -0
  86. lrnnx-1.0.0/lrnnx/models/ltv/rglru.py +388 -0
  87. lrnnx-1.0.0/lrnnx/models/ltv/s5.py +299 -0
  88. lrnnx-1.0.0/lrnnx/models/ltv/s7.py +267 -0
  89. lrnnx-1.0.0/lrnnx/ops/__init__.py +25 -0
  90. lrnnx-1.0.0/lrnnx/ops/rglru_scan.py +679 -0
  91. lrnnx-1.0.0/lrnnx/ops/s4_kernel_interface.py +588 -0
  92. lrnnx-1.0.0/lrnnx/ops/s4_utils.py +860 -0
  93. lrnnx-1.0.0/lrnnx/ops/s7_scan.py +552 -0
  94. lrnnx-1.0.0/lrnnx/ops/selective_scan.py +850 -0
  95. lrnnx-1.0.0/lrnnx/ops/simplified_scan.py +640 -0
  96. lrnnx-1.0.0/lrnnx/ops/torch.py +24 -0
  97. lrnnx-1.0.0/lrnnx/ops/triton/__init__.py +0 -0
  98. lrnnx-1.0.0/lrnnx/ops/triton/layer_norm.py +1243 -0
  99. lrnnx-1.0.0/lrnnx/ops/triton/selective_state_update.py +599 -0
  100. lrnnx-1.0.0/lrnnx/ops/triton/simplified_state_update.py +469 -0
  101. lrnnx-1.0.0/lrnnx/ops/triton/softplus.py +18 -0
  102. lrnnx-1.0.0/lrnnx/utils/__init__.py +9 -0
  103. lrnnx-1.0.0/lrnnx/utils/generation.py +290 -0
  104. lrnnx-1.0.0/lrnnx/utils/init.py +108 -0
  105. lrnnx-1.0.0/lrnnx.egg-info/PKG-INFO +183 -0
  106. lrnnx-1.0.0/lrnnx.egg-info/SOURCES.txt +135 -0
  107. lrnnx-1.0.0/lrnnx.egg-info/dependency_links.txt +1 -0
  108. lrnnx-1.0.0/lrnnx.egg-info/requires.txt +16 -0
  109. lrnnx-1.0.0/lrnnx.egg-info/top_level.txt +10 -0
  110. lrnnx-1.0.0/pyproject.toml +51 -0
  111. lrnnx-1.0.0/setup.cfg +4 -0
  112. lrnnx-1.0.0/setup.py +147 -0
  113. lrnnx-1.0.0/tests/__init__.py +0 -0
  114. lrnnx-1.0.0/tests/architectures/test_language_model.py +481 -0
  115. lrnnx-1.0.0/tests/models/test_lti/test_centaurus.py +261 -0
  116. lrnnx-1.0.0/tests/models/test_lti/test_lru.py +205 -0
  117. lrnnx-1.0.0/tests/models/test_lti/test_s4.py +292 -0
  118. lrnnx-1.0.0/tests/models/test_lti/test_s4d.py +285 -0
  119. lrnnx-1.0.0/tests/models/test_lti/test_s5_lti.py +198 -0
  120. lrnnx-1.0.0/tests/models/test_ltv/test_event_mamba.py +157 -0
  121. lrnnx-1.0.0/tests/models/test_ltv/test_event_s5.py +155 -0
  122. lrnnx-1.0.0/tests/models/test_ltv/test_mamba.py +169 -0
  123. lrnnx-1.0.0/tests/models/test_ltv/test_rglru.py +149 -0
  124. lrnnx-1.0.0/tests/models/test_ltv/test_s5_ltv.py +277 -0
  125. lrnnx-1.0.0/tests/models/test_ltv/test_s7.py +162 -0
  126. lrnnx-1.0.0/tests/ops/mamba/test_selective_scan.py +389 -0
  127. lrnnx-1.0.0/tests/ops/mamba/test_selective_scan_async.py +424 -0
  128. lrnnx-1.0.0/tests/ops/mamba/test_selective_state_update.py +681 -0
  129. lrnnx-1.0.0/tests/ops/rglru/test_rglru_scan.py +354 -0
  130. lrnnx-1.0.0/tests/ops/s4/test_cauchy.py +119 -0
  131. lrnnx-1.0.0/tests/ops/s4/test_vandermonde.py +57 -0
  132. lrnnx-1.0.0/tests/ops/s5/test_simplified_scan.py +198 -0
  133. lrnnx-1.0.0/tests/ops/s5/test_simplified_scan_async.py +326 -0
  134. lrnnx-1.0.0/tests/ops/s5/test_simplified_state_update.py +149 -0
  135. lrnnx-1.0.0/tests/ops/s7/test_s7_scan.py +150 -0
  136. lrnnx-1.0.0/tests/utils/__init__.py +0 -0
  137. lrnnx-1.0.0/tests/utils/test_generation.py +218 -0
lrnnx-1.0.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Society for Artificial Intelligence and Deep Learning
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ recursive-include csrc *.cpp *.cu *.cuh *.h
lrnnx-1.0.0/PKG-INFO ADDED
@@ -0,0 +1,183 @@
1
+ Metadata-Version: 2.4
2
+ Name: lrnnx
3
+ Version: 1.0.0
4
+ Summary: A library for Linear RNNs
5
+ Author: SAiDl Team
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/SforAiDl/lrnnx
8
+ Requires-Python: >=3.9
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: numpy
12
+ Requires-Dist: torch
13
+ Requires-Dist: einops
14
+ Requires-Dist: ninja
15
+ Requires-Dist: packaging
16
+ Requires-Dist: opt-einsum
17
+ Requires-Dist: pykeops
18
+ Provides-Extra: dev
19
+ Requires-Dist: pytest; extra == "dev"
20
+ Requires-Dist: black; extra == "dev"
21
+ Requires-Dist: isort; extra == "dev"
22
+ Requires-Dist: mypy; extra == "dev"
23
+ Provides-Extra: conv1d
24
+ Requires-Dist: causal-conv1d; extra == "conv1d"
25
+ Dynamic: license-file
26
+
27
+ <!---
28
+ Copyright 2025 SAiDL Team. All rights reserved.
29
+
30
+ Licensed under the MIT License; you may not use this file except in compliance
31
+ with the License. You may obtain a copy of the License in the LICENSE file.
32
+ -->
33
+
34
+ # lrnnx: A library for Linear RNNs
35
+ <p>
36
+ <a href="LICENSE"><img alt="License" src="https://img.shields.io/github/license/SforAiDl/lrnnx.svg?color=blue"></a>
37
+ <a href="https://arxiv.org/abs/2602.08810"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2602.08810-b31b1b.svg"></a>
38
+ </p>
39
+
40
+ A unified PyTorch library providing easy access to state-of-the-art Linear RNN architectures for sequence modeling.
41
+ The technical report of this system was accepted to [EACL Student Research Workshop 2026](https://2026.eacl.org/calls/srw/).
42
+ We recommend reading the report before using / contributing to the library.
43
+
44
+ ## Installation
45
+
46
+ ### From PyPI
47
+ Since this library compiles custom CUDA kernels upon installation, we recommend using `--no-build-isolation` to avoid downloading a duplicate version of PyTorch during the build process.
48
+ ```bash
49
+ # standard installation
50
+ pip install lrnnx --no-build-isolation
51
+ # with optional causal-conv1d
52
+ pip install "lrnnx[conv1d]" --no-build-isolation
53
+ # for development
54
+ pip install "lrnnx[dev]" --no-build-isolation
55
+ ```
56
+
57
+ ### From Source
58
+ We recommend installation with [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for fast, reliable dependency management, though standard `pip` is fully supported.
59
+
60
+ #### Using uv
61
+ ```bash
62
+ git clone https://github.com/SforAiDl/lrnnx.git
63
+ cd lrnnx
64
+ # standard installation
65
+ uv sync
66
+ # with optional causal-conv1d
67
+ uv sync --extra conv1d
68
+ # for development
69
+ uv sync --extra dev
70
+ ```
71
+
72
+ #### Using pip
73
+ ```bash
74
+ git clone https://github.com/SforAiDl/lrnnx.git
75
+ cd lrnnx
76
+ # standard installation
77
+ pip install -e . --no-build-isolation
78
+ # with optional causal-conv1d
79
+ pip install -e ".[conv1d]" --no-build-isolation
80
+ # for development
81
+ pip install -e ".[dev]" --no-build-isolation
82
+ ```
83
+
84
+ Note that since our library builds several custom CUDA kernels, it can take time for this installation to finish.
85
+ Along with `causal-conv1d` the full installation can take about 30 minutes, depending on the number of CPUs available.
86
+
87
+ ## Model Zoo
88
+ Our library provides implementations of the following Linear RNN architectures:
89
+ - [S4](https://openreview.net/forum?id=uYLFoz1vlAC)
90
+ - [S4D](https://dl.acm.org/doi/10.5555/3600270.3602877)
91
+ - [S5](https://openreview.net/forum?id=Ai8Hw3AXqks)
92
+ - [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va) (inside `S5`, use by passing `integration_timesteps`)
93
+ - [LRU](https://dl.acm.org/doi/10.5555/3618408.3619518)
94
+ - [S6](https://openreview.net/forum?id=tEYskw1VY2) (we implemented other discretizations)
95
+ - [STREAM](https://arxiv.org/abs/2411.12603) (inside `S6`, use by passing `integration_timesteps`)
96
+ - [RG-LRU](https://arxiv.org/abs/2402.19427)
97
+ - [S7](https://arxiv.org/abs/2410.03464)
98
+ - [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html)
99
+
100
+ We expose several levels of API for each model, including a scan, a recurrent step, and a full layer API matching the paper.
101
+ For S5 we implement both a convolution based approach and a parallel scan approach.
102
+ The latter is more stable and faster for most use cases, but the convolution based approach can be faster for very long sequences.
103
+
104
+ ## Usage
105
+
106
+ ### Training
107
+ It is easy to instantiate a model from our library
108
+ ```python
109
+ from lrnnx.models.lti import LRU
110
+ from lrnnx.models.ltv import Mamba
111
+
112
+ model_lti = LRU(d_model, d_state).cuda()
113
+ x = torch.randn(
114
+ batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
115
+ )
116
+ output = model_lti(x)
117
+
118
+ model_ltv = Mamba(d_model, d_state).cuda()
119
+ x = torch.randn(
120
+ batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
121
+ )
122
+ output = model_ltv(x)
123
+ ```
124
+
125
+ ### Inference
126
+ Linear RNNs in torch require special handling during inference. Following [mamba](https://github.com/state-spaces/mamba), we also implement CUDA-graphs-based inference, which reduces CPU overheads and leads to a > 10x speedup compared to using a simple for loop over the sequence length.
127
+ The main file is [generation.py](lrnnx/utils/generation.py) which provides a simple API for autoregressive generation with any of the models in our library.
128
+ You can see a simple way to use it in our [benchmarking script](benchmarks/benchmark_inference.py).
129
+
130
+ ### Reproducing the Benchmarks from the paper
131
+ This script will run both training and inference benchmarks.
132
+ ```bash
133
+ python -m benchmarks.run_all
134
+ ```
135
+
136
+ ### Architectures
137
+ We also implement some common architectures based on the models in our library, such as a U-Net (inspired by [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html)) and a hierarchical classifier (inspired by [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va)).
138
+ Additionally, there is a [Language Model](lrnnx/architectures/language_model.py) architecture inspired by [Mamba](https://github.com/state-spaces/mamba) and [RG-LRU](https://arxiv.org/abs/2402.19427) which can be used for language modeling tasks, with replaceable LRNN and attention layers.
139
+ This can be used as
140
+ ```python
141
+ from lrnnx.architectures.language_model import LRNNLMHeadModel
142
+
143
+ model = LRNNLMHeadModel(
144
+ d_model, d_state, num_layers, vocab_size, mixer_types=["s5", "s6", "attn"]
145
+ )
146
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
147
+ logits = model(input_ids)
148
+ ```
149
+
150
+ ### Tutorial Overview
151
+
152
+ Based on the architectures, there are tutorials on how to use them for 2 very popular use cases:
153
+ 1. [U-Net Seq2Seq for audio denoising Tutorial](tutorials/notebooks/01_UNet.ipynb)
154
+ 2. [Hierarchical Classification Tutorial](tutorials/notebooks/02_hierarchical_classifier.ipynb)
155
+
156
+ ## Contributing
157
+
158
+ Please check out our [Contributing Guide](CONTRIBUTING.rst) for details on how to contribute to this project.
159
+
160
+ ## Citation
161
+
162
+ If you use lrnnx in your research, please cite:
163
+
164
+ ```bibtex
165
+ @misc{bania2026textttlrnnxlibrarylinearrnns,
166
+ title={$\texttt{lrnnx}$: A library for Linear RNNs},
167
+ author={Karan Bania and Soham Kalburgi and Manit Tanwar and Dhruthi and Aditya Nagarsekar and Harshvardhan Mestha and Naman Chibber and Raj Deshmukh and Anish Sathyanarayanan and Aarush Rathore and Pratham Chheda},
168
+ year={2026},
169
+ eprint={2602.08810},
170
+ archivePrefix={arXiv},
171
+ primaryClass={cs.LG},
172
+ url={https://arxiv.org/abs/2602.08810},
173
+ }
174
+ ```
175
+
176
+ ## License
177
+
178
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
179
+
180
+ ## Acknowledgments
181
+
182
+ This library builds upon the excellent work of researchers who developed the individual LRNN models.
183
+ Please see individual model documentation for proper citations of the original papers.
lrnnx-1.0.0/README.md ADDED
@@ -0,0 +1,157 @@
1
+ <!---
2
+ Copyright 2025 SAiDL Team. All rights reserved.
3
+
4
+ Licensed under the MIT License; you may not use this file except in compliance
5
+ with the License. You may obtain a copy of the License in the LICENSE file.
6
+ -->
7
+
8
+ # lrnnx: A library for Linear RNNs
9
+ <p>
10
+ <a href="LICENSE"><img alt="License" src="https://img.shields.io/github/license/SforAiDl/lrnnx.svg?color=blue"></a>
11
+ <a href="https://arxiv.org/abs/2602.08810"><img alt="arXiv" src="https://img.shields.io/badge/arXiv-2602.08810-b31b1b.svg"></a>
12
+ </p>
13
+
14
+ A unified PyTorch library providing easy access to state-of-the-art Linear RNN architectures for sequence modeling.
15
+ The technical report of this system was accepted to [EACL Student Research Workshop 2026](https://2026.eacl.org/calls/srw/).
16
+ We recommend reading the report before using / contributing to the library.
17
+
18
+ ## Installation
19
+
20
+ ### From PyPI
21
+ Since this library compiles custom CUDA kernels upon installation, we recommend using `--no-build-isolation` to avoid downloading a duplicate version of PyTorch during the build process.
22
+ ```bash
23
+ # standard installation
24
+ pip install lrnnx --no-build-isolation
25
+ # with optional causal-conv1d
26
+ pip install "lrnnx[conv1d]" --no-build-isolation
27
+ # for development
28
+ pip install "lrnnx[dev]" --no-build-isolation
29
+ ```
30
+
31
+ ### From Source
32
+ We recommend installation with [`uv`](https://docs.astral.sh/uv/getting-started/installation/) for fast, reliable dependency management, though standard `pip` is fully supported.
33
+
34
+ #### Using uv
35
+ ```bash
36
+ git clone https://github.com/SforAiDl/lrnnx.git
37
+ cd lrnnx
38
+ # standard installation
39
+ uv sync
40
+ # with optional causal-conv1d
41
+ uv sync --extra conv1d
42
+ # for development
43
+ uv sync --extra dev
44
+ ```
45
+
46
+ #### Using pip
47
+ ```bash
48
+ git clone https://github.com/SforAiDl/lrnnx.git
49
+ cd lrnnx
50
+ # standard installation
51
+ pip install -e . --no-build-isolation
52
+ # with optional causal-conv1d
53
+ pip install -e ".[conv1d]" --no-build-isolation
54
+ # for development
55
+ pip install -e ".[dev]" --no-build-isolation
56
+ ```
57
+
58
+ Note that since our library builds several custom CUDA kernels, it can take time for this installation to finish.
59
+ Along with `causal-conv1d` the full installation can take about 30 minutes, depending on the number of CPUs available.
60
+
61
+ ## Model Zoo
62
+ Our library provides implementations of the following Linear RNN architectures:
63
+ - [S4](https://openreview.net/forum?id=uYLFoz1vlAC)
64
+ - [S4D](https://dl.acm.org/doi/10.5555/3600270.3602877)
65
+ - [S5](https://openreview.net/forum?id=Ai8Hw3AXqks)
66
+ - [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va) (inside `S5`, use by passing `integration_timesteps`)
67
+ - [LRU](https://dl.acm.org/doi/10.5555/3618408.3619518)
68
+ - [S6](https://openreview.net/forum?id=tEYskw1VY2) (we implemented other discretizations)
69
+ - [STREAM](https://arxiv.org/abs/2411.12603) (inside `S6`, use by passing `integration_timesteps`)
70
+ - [RG-LRU](https://arxiv.org/abs/2402.19427)
71
+ - [S7](https://arxiv.org/abs/2410.03464)
72
+ - [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html)
73
+
74
+ We expose several levels of API for each model, including a scan, a recurrent step, and a full layer API matching the paper.
75
+ For S5 we implement both a convolution based approach and a parallel scan approach.
76
+ The latter is more stable and faster for most use cases, but the convolution based approach can be faster for very long sequences.
77
+
78
+ ## Usage
79
+
80
+ ### Training
81
+ It is easy to instantiate a model from our library
82
+ ```python
83
+ from lrnnx.models.lti import LRU
84
+ from lrnnx.models.ltv import Mamba
85
+
86
+ model_lti = LRU(d_model, d_state).cuda()
87
+ x = torch.randn(
88
+ batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
89
+ )
90
+ output = model_lti(x)
91
+
92
+ model_ltv = Mamba(d_model, d_state).cuda()
93
+ x = torch.randn(
94
+ batch_size, seq_len, d_model, dtype=torch.float32, device="cuda"
95
+ )
96
+ output = model_ltv(x)
97
+ ```
98
+
99
+ ### Inference
100
+ Linear RNNs in torch require special handling during inference. Following [mamba](https://github.com/state-spaces/mamba), we also implement CUDA-graphs-based inference, which reduces CPU overheads and leads to a > 10x speedup compared to using a simple for loop over the sequence length.
101
+ The main file is [generation.py](lrnnx/utils/generation.py) which provides a simple API for autoregressive generation with any of the models in our library.
102
+ You can see a simple way to use it in our [benchmarking script](benchmarks/benchmark_inference.py).
103
+
104
+ ### Reproducing the Benchmarks from the paper
105
+ This script will run both training and inference benchmarks.
106
+ ```bash
107
+ python -m benchmarks.run_all
108
+ ```
109
+
110
+ ### Architectures
111
+ We also implement some common architectures based on the models in our library, such as a U-Net (inspired by [aTENNuate](https://www.isca-archive.org/interspeech_2025/pei25_interspeech.html)) and a hierarchical classifier (inspired by [Event-SSM](https://www.computer.org/csdl/proceedings-article/icons/2024/686500a124/22lEawhJ0Va)).
112
+ Additionally, there is a [Language Model](lrnnx/architectures/language_model.py) architecture inspired by [Mamba](https://github.com/state-spaces/mamba) and [RG-LRU](https://arxiv.org/abs/2402.19427) which can be used for language modeling tasks, with replaceable LRNN and attention layers.
113
+ This can be used as
114
+ ```python
115
+ from lrnnx.architectures.language_model import LRNNLMHeadModel
116
+
117
+ model = LRNNLMHeadModel(
118
+ d_model, d_state, num_layers, vocab_size, mixer_types=["s5", "s6", "attn"]
119
+ )
120
+ input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
121
+ logits = model(input_ids)
122
+ ```
123
+
124
+ ### Tutorial Overview
125
+
126
+ Based on the architectures, there are tutorials on how to use them for 2 very popular use cases:
127
+ 1. [U-Net Seq2Seq for audio denoising Tutorial](tutorials/notebooks/01_UNet.ipynb)
128
+ 2. [Hierarchical Classification Tutorial](tutorials/notebooks/02_hierarchical_classifier.ipynb)
129
+
130
+ ## Contributing
131
+
132
+ Please check out our [Contributing Guide](CONTRIBUTING.rst) for details on how to contribute to this project.
133
+
134
+ ## Citation
135
+
136
+ If you use lrnnx in your research, please cite:
137
+
138
+ ```bibtex
139
+ @misc{bania2026textttlrnnxlibrarylinearrnns,
140
+ title={$\texttt{lrnnx}$: A library for Linear RNNs},
141
+ author={Karan Bania and Soham Kalburgi and Manit Tanwar and Dhruthi and Aditya Nagarsekar and Harshvardhan Mestha and Naman Chibber and Raj Deshmukh and Anish Sathyanarayanan and Aarush Rathore and Pratham Chheda},
142
+ year={2026},
143
+ eprint={2602.08810},
144
+ archivePrefix={arXiv},
145
+ primaryClass={cs.LG},
146
+ url={https://arxiv.org/abs/2602.08810},
147
+ }
148
+ ```
149
+
150
+ ## License
151
+
152
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
153
+
154
+ ## Acknowledgments
155
+
156
+ This library builds upon the excellent work of researchers who developed the individual LRNN models.
157
+ Please see individual model documentation for proper citations of the original papers.
@@ -0,0 +1 @@
1
+ """Benchmark utilities for lrnnx models."""
@@ -0,0 +1,62 @@
1
+ import math
2
+ from functools import partial
3
+
4
+ import torch
5
+ from benchmark.utils import (
6
+ benchmark_all,
7
+ benchmark_backward,
8
+ benchmark_combined,
9
+ benchmark_forward,
10
+ )
11
+ from einops import rearrange
12
+
13
+ from .cauchy import cauchy_mult, cauchy_mult_keops, cauchy_mult_torch
14
+
15
+
16
+ def generate_data(batch_size, N, L, symmetric=True, device="cuda"):
17
+ if not symmetric:
18
+ v = torch.randn(
19
+ batch_size,
20
+ N,
21
+ dtype=torch.complex64,
22
+ device=device,
23
+ requires_grad=True,
24
+ )
25
+ w = torch.randn(
26
+ batch_size,
27
+ N,
28
+ dtype=torch.complex64,
29
+ device=device,
30
+ requires_grad=True,
31
+ )
32
+ z = torch.randn(L, dtype=torch.complex64, device=device)
33
+ else:
34
+ assert N % 2 == 0
35
+ v_half = torch.randn(
36
+ batch_size, N // 2, dtype=torch.complex64, device=device
37
+ )
38
+ v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
39
+ w_half = torch.randn(
40
+ batch_size, N // 2, dtype=torch.complex64, device=device
41
+ )
42
+ w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
43
+ z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
44
+ return v, z, w
45
+
46
+
47
+ if __name__ == "__main__":
48
+ device = "cuda"
49
+ bs = 1024
50
+ N = 64
51
+ L = 16384
52
+
53
+ v, z, w = generate_data(bs, N, L, symmetric=True)
54
+ v_half = v[:, : N // 2].clone().detach().requires_grad_(True)
55
+ w_half = w[:, : N // 2].clone().detach().requires_grad_(True)
56
+
57
+ repeat = 30
58
+ benchmark_all(repeat, cauchy_mult_keops, v, z, w, desc="Cauchy mult keops")
59
+ fn = partial(cauchy_mult, symmetric=False)
60
+ benchmark_all(repeat, fn, v, z, w, desc="Cauchy mult")
61
+ fn = partial(cauchy_mult, symmetric=True)
62
+ benchmark_all(repeat, fn, v_half, z, w_half, desc="Cauchy mult symmetric")
@@ -0,0 +1,88 @@
1
+ import argparse
2
+ import importlib
3
+ import json
4
+
5
+ import torch
6
+ from benchmark.utils import benchmark_forward
7
+
8
+
9
+ def generate_data(batch_size, N, L, symmetric=True, device="cuda"):
10
+ if not symmetric:
11
+ v = torch.randn(
12
+ batch_size,
13
+ N,
14
+ dtype=torch.complex64,
15
+ device=device,
16
+ requires_grad=True,
17
+ )
18
+ w = torch.randn(
19
+ batch_size,
20
+ N,
21
+ dtype=torch.complex64,
22
+ device=device,
23
+ requires_grad=True,
24
+ )
25
+ z = torch.randn(L, dtype=torch.complex64, device=device)
26
+ else:
27
+ assert N % 2 == 0
28
+ v_half = torch.randn(
29
+ batch_size, N // 2, dtype=torch.complex64, device=device
30
+ )
31
+ v = torch.cat([v_half, v_half.conj()], dim=-1).requires_grad_(True)
32
+ w_half = torch.randn(
33
+ batch_size, N // 2, dtype=torch.complex64, device=device
34
+ )
35
+ w = torch.cat([w_half, w_half.conj()], dim=-1).requires_grad_(True)
36
+ z = torch.exp(1j * torch.randn(L, dtype=torch.float32, device=device))
37
+ return v, z, w
38
+
39
+
40
+ parser = argparse.ArgumentParser(description="Tuning Cauchy multiply")
41
+ parser.add_argument("--name", default="cauchy_mult")
42
+ parser.add_argument(
43
+ "--mode", default="forward", choices=["forward", "backward"]
44
+ )
45
+ parser.add_argument("-bs", "--batch-size", default=1024, type=int)
46
+ parser.add_argument("-N", default=64, type=int)
47
+ parser.add_argument("-L", default=2**14, type=int)
48
+
49
+
50
+ if __name__ == "__main__":
51
+ args = parser.parse_args()
52
+ device = "cuda"
53
+ bs = args.batch_size
54
+ N = args.N
55
+ L = args.L
56
+ repeat = 30
57
+ v, z, w = generate_data(bs, N, L, symmetric=True)
58
+ v_half = v[:, : N // 2].clone().detach().requires_grad_(True)
59
+ w_half = w[:, : N // 2].clone().detach().requires_grad_(True)
60
+
61
+ tuning_extension_name = args.name
62
+ # print('Extension name:', tuning_extension_name)
63
+ module = importlib.import_module(tuning_extension_name)
64
+ if args.mode == "forward":
65
+ _, m = benchmark_forward(
66
+ repeat,
67
+ module.cauchy_mult_sym_fwd,
68
+ v_half,
69
+ z,
70
+ w_half,
71
+ verbose=False,
72
+ desc="Cauchy mult symmetric fwd",
73
+ )
74
+ else:
75
+ out = module.cauchy_mult_sym_fwd(v_half, z, w_half)
76
+ dout = torch.randn_like(out)
77
+ _, m = benchmark_forward(
78
+ repeat,
79
+ module.cauchy_mult_sym_bwd,
80
+ v_half,
81
+ z,
82
+ w_half,
83
+ dout,
84
+ verbose=False,
85
+ desc="Cauchy mult symmetric bwd",
86
+ )
87
+ result_dict = dict(time_mean=m.mean, time_iqr=m.iqr)
88
+ print(json.dumps(result_dict))