stacked-linear 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ from .linear_layer import LinearLayer
2
+ from .stacked_linear_layer import StackedLinearLayer
3
+
4
+ __all__ = ["LinearLayer", "StackedLinearLayer"]
@@ -0,0 +1,55 @@
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ class LinearLayer(nn.Linear):
9
+ """Linear layer with support for output weight subsetting.
10
+
11
+ This layer behaves like a normal nn.Linear but adds the ability to
12
+ perform the forward pass on a subset of the output features.
13
+ """
14
+
15
+ def forward(self, x: torch.Tensor, output_subset: torch.Tensor | None = None) -> torch.Tensor:
16
+ """Forward pass with optional output subsetting.
17
+
18
+ Parameters
19
+ ----------
20
+ x
21
+ Input tensor with shape (..., in_features).
22
+ output_subset
23
+ Indices of the output features to compute. If None, all features
24
+ are computed.
25
+
26
+ Returns
27
+ -------
28
+ torch.Tensor
29
+ Output tensor with shape (..., out_features) or (..., len(output_subset)).
30
+
31
+ Examples
32
+ --------
33
+ >>> import torch
34
+ >>> layer = LinearLayer(10, 5)
35
+ >>> x = torch.randn(2, 10)
36
+ >>> # Standard forward pass
37
+ >>> out = layer(x)
38
+ >>> out.shape
39
+ torch.Size([2, 5])
40
+ >>> # Subset forward pass
41
+ >>> subset = torch.tensor([0, 2])
42
+ >>> out_subset = layer(x, output_subset=subset)
43
+ >>> out_subset.shape
44
+ torch.Size([2, 2])
45
+ """
46
+ if output_subset is None:
47
+ # x: (..., i) -> output: (..., o)
48
+ return super().forward(x)
49
+ elif output_subset.dim() == 1:
50
+ # x: (..., i) -> output_subset: (o_subset)
51
+ bias = self.bias[output_subset] if self.bias is not None else None # (o_subset)
52
+ weight = self.weight[output_subset] # (o_subset, i)
53
+ return F.linear(x, weight, bias) # (..., i) -> (..., o_subset)
54
+ else:
55
+ raise NotImplementedError()
@@ -0,0 +1,224 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ if TYPE_CHECKING:
10
+ from typing import Any
11
+
12
+
13
+ class StackedLinearLayer(nn.Module):
14
+ """A parallel stacked linear layer that applies multiple linear transformations in parallel.
15
+
16
+ This layer applies a linear transformation to multiple stacks/splits
17
+ of the input. It's particularly useful in additive decoders where
18
+ different splits should be calculated in parallel.
19
+
20
+ Parameters
21
+ ----------
22
+ n_stacks
23
+ Number of stacks/splits to process in parallel.
24
+ in_features
25
+ Number of input features per stack.
26
+ out_features
27
+ Number of output features per stack.
28
+ bias
29
+ Whether to include bias terms for each stack.
30
+ device
31
+ Device to place the layer on.
32
+ dtype
33
+ Data type for the layer parameters.
34
+
35
+ Notes
36
+ -----
37
+ The layer maintains separate weight and bias parameters for each stack:
38
+ - Weight shape: (n_stacks, in_features, out_features)
39
+ - Bias shape: (n_stacks, out_features) if bias=True, None otherwise
40
+
41
+ The forward pass applies the transformation to each stack independently:
42
+ output[b, s, o] = sum_i(x[b, s, i] * weight[s, i, o]) + bias[s, o]
43
+
44
+ This is equivalent to applying n_stacks separate linear layers in parallel,
45
+ which is more efficient than using separate nn.Linear layers.
46
+
47
+ Examples
48
+ --------
49
+ >>> import torch
50
+ >>> # Create a stacked linear layer with 4 stacks
51
+ >>> layer = StackedLinearLayer(n_stacks=4, in_features=64, out_features=128)
52
+ >>> # Input shape: (batch_size, n_stacks, in_features)
53
+ >>> x = torch.randn(32, 4, 64)
54
+ >>> # Forward pass
55
+ >>> output = layer(x)
56
+ >>> print(output.shape) # torch.Size([32, 4, 128])
57
+ >>> # Each stack has its own parameters
58
+ >>> print(layer.weight.shape) # torch.Size([4, 64, 128])
59
+ >>> print(layer.bias.shape) # torch.Size([4, 128])
60
+ """
61
+
62
+ __constants__ = ["n_stacks", "in_features", "out_features"]
63
+ n_stacks: int
64
+ in_features: int
65
+ out_features: int
66
+ weight: torch.Tensor
67
+ bias: torch.Tensor | None
68
+
69
+ def __init__(
70
+ self,
71
+ n_stacks: int,
72
+ in_features: int,
73
+ out_features: int,
74
+ bias: bool = True,
75
+ device: Any = None,
76
+ dtype: Any = None,
77
+ ) -> None:
78
+ factory_kwargs = {"device": device, "dtype": dtype}
79
+ super().__init__()
80
+ self.n_stacks = n_stacks
81
+ self.in_features = in_features
82
+ self.out_features = out_features
83
+ self.weight = nn.Parameter(torch.empty((n_stacks, in_features, out_features), **factory_kwargs))
84
+ if bias:
85
+ self.bias = nn.Parameter(torch.empty(n_stacks, out_features, **factory_kwargs))
86
+ else:
87
+ self.register_parameter("bias", None)
88
+ self.reset_parameters()
89
+
90
+ def reset_parameters(self) -> None:
91
+ """Reset the layer parameters to their initial values.
92
+
93
+ This method reinitializes both weights and biases using the same
94
+ initialization strategy as the default nn.Linear layer.
95
+
96
+ Notes
97
+ -----
98
+ The initialization follows PyTorch's default linear layer initialization:
99
+ - Weights: Uniform distribution in [-1/sqrt(in_features), 1/sqrt(in_features)]
100
+ - Biases: Uniform distribution in [-1/sqrt(in_features), 1/sqrt(in_features)]
101
+
102
+ This ensures that the variance of the output is approximately preserved
103
+ across the layer.
104
+ """
105
+ self._init_weight()
106
+ self._init_bias()
107
+
108
+ def _init_weight(self) -> None:
109
+ """Initialize the weight parameters.
110
+
111
+ Notes
112
+ -----
113
+ Uses the same initialization as default nn.Linear:
114
+ Uniform distribution in [-1/sqrt(in_features), 1/sqrt(in_features)]
115
+
116
+ This initialization helps maintain the variance of activations
117
+ across the network, which is important for training stability.
118
+ """
119
+ # Same as default nn.Linear (https://github.com/pytorch/pytorch/issues/57109)
120
+ fan_in = self.in_features
121
+ bound = 1 / math.sqrt(fan_in)
122
+ nn.init.uniform_(self.weight, -bound, bound)
123
+
124
+ def _init_bias(self) -> None:
125
+ """Initialize the bias parameters.
126
+
127
+ Notes
128
+ -----
129
+ Uses the same initialization as default nn.Linear:
130
+ Uniform distribution in [-1/sqrt(in_features), 1/sqrt(in_features)]
131
+
132
+ The bias initialization is independent of the weight initialization
133
+ and helps ensure that the layer can learn appropriate offsets.
134
+ """
135
+ if self.bias is not None:
136
+ fan_in = self.in_features
137
+ bound = 1 / math.sqrt(fan_in)
138
+ nn.init.uniform_(self.bias, -bound, bound)
139
+
140
+ def forward(
141
+ self,
142
+ x: torch.Tensor,
143
+ output_subset: torch.Tensor | None = None,
144
+ stack_subset: torch.Tensor | None = None,
145
+ ) -> torch.Tensor:
146
+ r"""Forward pass through the stacked linear layer.
147
+
148
+ Parameters
149
+ ----------
150
+ x
151
+ Input tensor with shape (batch_size, n_stacks, in_features).
152
+ output_subset
153
+ Subset of outputs to provide in the output.
154
+ stack_subset
155
+ Indices for stacks in operation.
156
+
157
+ Returns
158
+ -------
159
+ torch.Tensor
160
+ Output tensor with shape (batch_size, n_stacks, out_features).
161
+
162
+ Notes
163
+ -----
164
+ The forward pass applies the linear transformation to each stack:
165
+
166
+ .. math::
167
+ \text{output}[b, s, o] = \\sum_{i} \text{input}[b, s, i] \\cdot \text{weight}[s, i, o] + \text{bias}[s, o]
168
+
169
+ where:
170
+ - b: batch index
171
+ - s: stack index
172
+ - i: input feature index
173
+ - o: output feature index
174
+
175
+ The computation is performed efficiently using torch.bmm or broadcasting.
176
+
177
+ Examples
178
+ --------
179
+ >>> import torch
180
+ >>> # Create layer
181
+ >>> layer = StackedLinearLayer(n_stacks=3, in_features=10, out_features=5)
182
+ >>> # Input: batch_size=2, n_stacks=3, in_features=10
183
+ >>> x = torch.randn(2, 3, 10)
184
+ >>> # Forward pass
185
+ >>> output = layer(x)
186
+ >>> print(output.shape) # torch.Size([2, 3, 5])
187
+ """
188
+ if stack_subset is None:
189
+ if output_subset is None or output_subset.dim() == 1:
190
+ # weight: (s, i, o), bias: (s, o)
191
+ # x: (b, s, i), output_subset: (o_subset) -> output: (b, s, o_subset)
192
+ weight = self.weight if output_subset is None else self.weight[:, :, output_subset] # (s, i, o_subset)
193
+ # slower: mm = torch.einsum("bsi,sio->bso", x, weight)
194
+ mm = torch.bmm(x.transpose(0, 1), weight).transpose(0, 1) # (b, s, o_subset)
195
+ if self.bias is not None:
196
+ bias = self.bias if output_subset is None else self.bias[:, output_subset] # (s, o_subset)
197
+ mm = mm + bias # They (bso, so) will broadcast well
198
+ return mm
199
+ else:
200
+ raise NotImplementedError()
201
+ else:
202
+ # stack_subset: (b, s_subset)
203
+ # x: (b, s_subset, i), output_subset: (o_subset) -> output: (b, s_subset, o_subset)
204
+ weight = self.weight[stack_subset] # (b, s_subset, i, o)
205
+ bias = self.bias[stack_subset] if self.bias is not None else None # (b, s_subset, o)
206
+
207
+ if output_subset is None:
208
+ pass
209
+ elif output_subset.dim() == 1:
210
+ weight = weight[..., output_subset] # (b, s_subset, i, o_subset)
211
+ bias = bias[..., output_subset] if bias is not None else None # (b, s_subset, o_subset)
212
+ else:
213
+ raise NotImplementedError
214
+ mm = torch.matmul(x.unsqueeze(2), weight).squeeze(2) # (b, s_subset, o_subset)
215
+ if bias is not None:
216
+ mm = mm + bias # (b, s_subset, o_subset)
217
+ return mm
218
+
219
+ def extra_repr(self) -> str:
220
+ """String representation for printing the layer."""
221
+ return (
222
+ f"in_features={self.in_features}, out_features={self.out_features}, "
223
+ f"n_stacks={self.n_stacks}, bias={self.bias is not None}"
224
+ )
@@ -0,0 +1,118 @@
1
+ Metadata-Version: 2.4
2
+ Name: stacked-linear
3
+ Version: 0.1.0
4
+ Summary: Efficient implementation of stacked linear modules
5
+ Project-URL: Documentation, https://stacked-linear.readthedocs.io/
6
+ Project-URL: Homepage, https://github.com/moinfar/stacked-linear
7
+ Project-URL: Source, https://github.com/moinfar/stacked-linear
8
+ Author: Amir Ali Moinfar
9
+ Maintainer-email: Amir Ali Moinfar <moinfar.amirali@gmail.com>
10
+ License: BSD 3-Clause License
11
+
12
+ Copyright (c) 2026, Amir Ali Moinfar
13
+ All rights reserved.
14
+
15
+ Redistribution and use in source and binary forms, with or without
16
+ modification, are permitted provided that the following conditions are met:
17
+
18
+ 1. Redistributions of source code must retain the above copyright notice, this
19
+ list of conditions and the following disclaimer.
20
+
21
+ 2. Redistributions in binary form must reproduce the above copyright notice,
22
+ this list of conditions and the following disclaimer in the documentation
23
+ and/or other materials provided with the distribution.
24
+
25
+ 3. Neither the name of the copyright holder nor the names of its
26
+ contributors may be used to endorse or promote products derived from
27
+ this software without specific prior written permission.
28
+
29
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
30
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
31
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
32
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
33
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
34
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
35
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
36
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
37
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
38
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
39
+ License-File: LICENSE
40
+ Classifier: Programming Language :: Python :: 3 :: Only
41
+ Classifier: Programming Language :: Python :: 3.11
42
+ Classifier: Programming Language :: Python :: 3.12
43
+ Classifier: Programming Language :: Python :: 3.13
44
+ Classifier: Programming Language :: Python :: 3.14
45
+ Requires-Python: >=3.11
46
+ Requires-Dist: torch>=2
47
+ Description-Content-Type: text/markdown
48
+
49
+ # Parallel Stacked Linear Modules for PyTorch
50
+
51
+ [![Tests][badge-tests]][tests]
52
+ [![Documentation][badge-docs]][documentation]
53
+
54
+ Efficient implementation of stacked linear modules in PyTorch, with support for output and stack subsetting.
55
+
56
+ ## Features
57
+
58
+ - **`StackedLinearLayer`**: A parallelized linear layer that applies multiple independent transformations across different input stacks simultaneously. This is significantly more efficient than looping over multiple `nn.Linear` layers, and is useful for specialized neural architectures such as additive decoders.
59
+ - **Subsetting Support**: Both `LinearLayer` and `StackedLinearLayer` can compute only a subset of the output features during the forward pass; `StackedLinearLayer` additionally supports subsetting stacks.
60
+
61
+ ## Installation
62
+
63
+ ```bash
64
+ pip install stacked-linear
65
+ ```
66
+
67
+ Or install from source:
68
+
69
+ ```bash
70
+ pip install git+https://github.com/moinfar/stacked-linear.git
71
+ ```
72
+
73
+ ## Quick Start
74
+
75
+ ### Linear Layer with Output Subsetting
76
+
77
+ ```python
78
+ import torch
79
+ from stacked_linear import LinearLayer
80
+
81
+ # Initialize a layer (10 inputs, 5 outputs)
82
+ layer = LinearLayer(10, 5)
83
+ x = torch.randn(2, 10)
84
+
85
+ # Forward pass on a subset of output features (indices 0, 2, and 4)
86
+ subset = torch.tensor([0, 2, 4])
87
+ output = layer(x, output_subset=subset) # Shape: (2, 3)
88
+ ```
89
+
90
+ ### Stacked Linear Layer
91
+
92
+ ```python
93
+ import torch
94
+ from stacked_linear import StackedLinearLayer
95
+
96
+ # 3 parallel stacks, each mapping 10 inputs to 5 outputs
97
+ layer = StackedLinearLayer(n_stacks=3, in_features=10, out_features=5)
98
+ x = torch.randn(2, 3, 10) # (batch, stacks, features)
99
+
100
+ # Efficient parallel forward pass
101
+ output = layer(x) # Shape: (2, 3, 5)
102
+
103
+ # Forward pass on a subset of output features across all stacks
104
+ subset = torch.tensor([1, 3])
105
+ output_subset = layer(x, output_subset=subset) # Shape: (2, 3, 2)
106
+
107
+ # Forward pass on a subset of stacks
108
+ stack_subset = torch.tensor([[0, 2], [1, 2]]) # Indices for each batch item
109
+ x_subset = torch.randn(2, 2, 10)
110
+ output_stack_subset = layer(x_subset, stack_subset=stack_subset) # Shape: (2, 2, 5)
111
+ ```
112
+
113
+
114
+ [badge-tests]: https://img.shields.io/github/actions/workflow/status/moinfar/stacked-linear/test.yaml?branch=main
115
+ [badge-docs]: https://img.shields.io/readthedocs/stacked-linear
116
+ [tests]: https://github.com/moinfar/stacked-linear/actions/workflows/test.yaml
117
+ [documentation]: https://stacked-linear.readthedocs.io
118
+ [issue tracker]: https://github.com/moinfar/stacked-linear/issues
@@ -0,0 +1,7 @@
1
+ stacked_linear/__init__.py,sha256=ArUfrXt67Z0KtW-09mVssPxewwYMm5jBeNekCSi28u8,140
2
+ stacked_linear/linear_layer.py,sha256=VGiUYAxjWG2lKdWZOb2MDV_jey6YHHOkGACmLSU9fC8,1821
3
+ stacked_linear/stacked_linear_layer.py,sha256=10TtOcRZWNrAINs5IjKO99cvbopQ9qBKJ3RHSEaKfnI,8210
4
+ stacked_linear-0.1.0.dist-info/METADATA,sha256=ZFEzGuck2GFuv_xmcd-bNdPKcf0WzbsANGG_E_PEw8E,4921
5
+ stacked_linear-0.1.0.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
6
+ stacked_linear-0.1.0.dist-info/licenses/LICENSE,sha256=wm36XbiogTgMEVdMfi7uGmVSINOYbPiSE1RIOhhII7U,1524
7
+ stacked_linear-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.29.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,29 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2026, Amir Ali Moinfar
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.