sparse-layers 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,69 @@
1
+ """sparse-layers: structured sparse layers for building memory-efficient neural networks."""
2
+
3
+ from sparse_layers.layers import (
4
+ ButterflyLinear,
5
+ ButterflyMLP,
6
+ ButterflyMultiHeadAttention,
7
+ CustomLinear,
8
+ CustomMLP,
9
+ MultiHeadAttention,
10
+ PaddedButterflyLinear,
11
+ SimpleMLP,
12
+ )
13
+ from sparse_layers.sse import (
14
+ LinearAttention,
15
+ LinearAttentionConfig,
16
+ NaiveMultiPartitionState,
17
+ NaiveSSEAttention,
18
+ NaiveSSEMultiHeadAttention,
19
+ SSEAttention,
20
+ SSEAttentionAdaptive,
21
+ SSEAttentionAdaptiveConfig,
22
+ SSEAttentionConfig,
23
+ SSEMaskingOps,
24
+ SSEMaskingOpsConfig,
25
+ SSEMultiHeadAttention,
26
+ SSEMultiHeadAttentionConfig,
27
+ SSEMultiPartitionState,
28
+ SSEMultiPartitionStateConfig,
29
+ SSEPartitionSelector,
30
+ SSEPartitionSelectorConfig,
31
+ SSESparseSoftmax,
32
+ SSESparseSoftmaxConfig,
33
+ SSEVarlenOps,
34
+ SSEVarlenOpsConfig,
35
+ )
36
+
37
# Explicit public API of the top-level package, grouped to mirror the two
# subpackages the names are re-exported from above.
__all__ = [
    # Layers
    "ButterflyLinear",
    "ButterflyMLP",
    "ButterflyMultiHeadAttention",
    "CustomLinear",
    "CustomMLP",
    "MultiHeadAttention",
    "PaddedButterflyLinear",
    "SimpleMLP",
    # SSE
    "LinearAttention",
    "LinearAttentionConfig",
    "NaiveMultiPartitionState",
    "NaiveSSEAttention",
    "NaiveSSEMultiHeadAttention",
    "SSEAttention",
    "SSEAttentionAdaptive",
    "SSEAttentionAdaptiveConfig",
    "SSEAttentionConfig",
    "SSEMaskingOps",
    "SSEMaskingOpsConfig",
    "SSEMultiHeadAttention",
    "SSEMultiHeadAttentionConfig",
    "SSEMultiPartitionState",
    "SSEMultiPartitionStateConfig",
    "SSEPartitionSelector",
    "SSEPartitionSelectorConfig",
    "SSESparseSoftmax",
    "SSESparseSoftmaxConfig",
    "SSEVarlenOps",
    "SSEVarlenOpsConfig",
]
@@ -0,0 +1,23 @@
1
+ """Core neural network layers with butterfly factorization support."""
2
+
3
+ from sparse_layers.layers.butterfly_linear import ButterflyLinear
4
+ from sparse_layers.layers.butterfly_mlp import ButterflyMLP
5
+ from sparse_layers.layers.butterfly_multi_head_attention import (
6
+ ButterflyMultiHeadAttention,
7
+ )
8
+ from sparse_layers.layers.custom_linear import CustomLinear
9
+ from sparse_layers.layers.custom_mlp import CustomMLP
10
+ from sparse_layers.layers.multi_head_attention import MultiHeadAttention
11
+ from sparse_layers.layers.padded_butterfly_linear import PaddedButterflyLinear
12
+ from sparse_layers.layers.simple_mlp import SimpleMLP
13
+
14
# Public, re-exported names of the layers subpackage (kept in sync with the
# imports at the top of this module).
__all__ = [
    "ButterflyLinear",
    "ButterflyMLP",
    "ButterflyMultiHeadAttention",
    "CustomLinear",
    "CustomMLP",
    "MultiHeadAttention",
    "PaddedButterflyLinear",
    "SimpleMLP",
]
@@ -0,0 +1,225 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ def _is_power_of_two(value: int) -> bool:
11
+ return value > 0 and (value & (value - 1) == 0)
12
+
13
+
14
class ButterflyLinear(nn.Module):
    """Butterfly factorization based linear layer.

    This layer follows the structure described in section 2.3.1 of the
    `reducing_memory_requirements_ipu_butterfly.md` reference. It consists of
    log2(N) stages of block-diagonal 2x2 butterfly factors, each stored as a
    learnable parameter. The layer currently supports square power-of-two
    dimensions and acts as a drop-in replacement for :class:`torch.nn.Linear`.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Create the per-stage butterfly parameters.

        Args:
            in_features: Input feature count; must equal ``out_features`` and
                be a power of two.
            out_features: Output feature count.
            bias: Whether to add a learnable additive bias.

        Raises:
            ValueError: If the dimensions are non-positive, unequal, or not a
                power of two.
        """
        super().__init__()

        if in_features <= 0 or out_features <= 0:
            raise ValueError("in_features and out_features must be positive integers")

        if in_features != out_features:
            raise ValueError("ButterflyLinear requires in_features == out_features")

        if not _is_power_of_two(in_features):
            raise ValueError(
                "ButterflyLinear requires dimensions that are a power of two"
            )

        self.in_features = in_features
        self.out_features = out_features
        # Number of butterfly stages: log2(N).
        self._depth = int(math.log2(in_features))

        # Each stage holds N/2 independent learnable 2x2 blocks.
        num_blocks = in_features // 2
        default_dtype = torch.get_default_dtype()
        identity_block = torch.eye(2, dtype=default_dtype).unsqueeze(0)

        self.factors = nn.ParameterList()
        for _ in range(self._depth):
            # Initialise near the identity so the composed map starts close to
            # a pass-through, with small noise to break symmetry.
            factor = identity_block.repeat(num_blocks, 1, 1)
            factor += 0.01 * torch.randn_like(factor)
            self.factors.append(nn.Parameter(factor))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, dtype=default_dtype))
        else:
            self.register_parameter("bias", None)

    @classmethod
    def from_linear(
        cls,
        layer: nn.Linear,
        *,
        optimization_steps: int = 4000,
        learning_rate: float = 0.1,
        tolerance: float = 1e-7,
        seed: int | None = None,
    ) -> "ButterflyLinear":
        """Construct a butterfly layer approximating a dense :class:`nn.Linear`.

        The method uses gradient-based fitting over the canonical basis to match
        the transformation represented by ``layer``. The bias term is copied
        directly before optimisation to accelerate convergence.

        Args:
            layer: Square, power-of-two-sized ``nn.Linear`` to approximate.
            optimization_steps: Maximum number of Adam steps.
            learning_rate: Adam learning rate.
            tolerance: Target MSE between this layer and ``layer``.
            seed: Optional seed; NOTE this mutates global torch RNG state.

        Raises:
            ValueError: If ``layer`` is not square or not power-of-two sized.
        """

        if layer.in_features != layer.out_features:
            raise ValueError("ButterflyLinear.from_linear requires a square nn.Linear")

        if not _is_power_of_two(layer.in_features):
            raise ValueError(
                "ButterflyLinear.from_linear requires dimensions that are a power of two"
            )

        device = layer.weight.device
        dtype = layer.weight.dtype

        if seed is not None:
            torch.manual_seed(seed)

        result = cls(layer.in_features, layer.out_features, bias=layer.bias is not None)
        result = result.to(device=device, dtype=dtype)

        # The bias is copied exactly; only the butterfly factors need fitting.
        if result.bias is not None and layer.bias is not None:
            with torch.no_grad():
                result.bias.copy_(layer.bias)

        params = list(result.factors.parameters())
        if not params:
            return result

        optimizer_adam = torch.optim.Adam(params, lr=learning_rate)

        # Fitting target: the dense layer's action on the canonical basis
        # (rows of the identity), which fully determines the affine map.
        eye_input = torch.eye(layer.in_features, device=device, dtype=dtype)
        with torch.no_grad():
            target = layer(eye_input)

        best_loss = float("inf")

        # Phase 1: Adam, with early exit once the tolerance is reached.
        for step in range(optimization_steps):
            optimizer_adam.zero_grad()
            output = result(eye_input)
            loss = F.mse_loss(output, target)
            loss.backward()
            optimizer_adam.step()

            current_loss = loss.item()
            if current_loss < best_loss:
                best_loss = current_loss

            if best_loss <= tolerance:
                break

        # Phase 2: refine with L-BFGS if Adam alone did not hit the tolerance.
        if best_loss > tolerance:
            optimizer_lbfgs = torch.optim.LBFGS(
                params,
                lr=1.0,
                max_iter=100,
                tolerance_grad=1e-12,
                tolerance_change=1e-12,
                line_search_fn="strong_wolfe",
            )

            def closure() -> Tensor:
                # L-BFGS re-evaluates the loss multiple times per step.
                optimizer_lbfgs.zero_grad()
                output_lbfgs = result(eye_input)
                lbfgs_loss = F.mse_loss(output_lbfgs, target)
                lbfgs_loss.backward()
                return lbfgs_loss

            for _ in range(20):
                loss = optimizer_lbfgs.step(closure)
                best_loss = min(best_loss, loss.item())
                if best_loss <= tolerance:
                    break

        return result

    def forward(self, input: Tensor) -> Tensor:
        """Apply all butterfly stages (and optional bias) to ``input``.

        Accepts any leading batch shape; only the last dimension must equal
        ``in_features``.

        Raises:
            ValueError: If the last dimension of ``input`` is wrong.
        """
        if input.shape[-1] != self.in_features:
            raise ValueError(
                f"Expected last dimension {self.in_features}, got {input.shape[-1]}"
            )

        original_shape = input.shape[:-1]
        # Flatten all leading dimensions into a single batch axis.
        x = input.reshape(-1, self.in_features)

        for stage_index, factor in enumerate(self.factors):
            x = self._apply_stage(x, factor, stage_index)

        if self.bias is not None:
            x = x + self.bias

        return x.reshape(*original_shape, self.out_features)

    def extra_repr(self) -> str:
        # Mirrors nn.Linear's repr format for familiarity.
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )

    def _apply_stage(self, x: Tensor, factor: Tensor, stage: int) -> Tensor:
        """Apply one block-diagonal butterfly stage to ``x`` of shape (batch, N).

        Stage ``s`` operates on blocks of size 2**(s+1): within each block the
        elements are regrouped into pairs, every pair is multiplied by its own
        learnable 2x2 factor, and the regrouping is then inverted so the layout
        matches the input again. NOTE(review): the exact pairing induced by the
        reshape/permute below determines the butterfly topology — verify
        against the reference document before changing it.
        """
        batch = x.shape[0]
        n = self.in_features
        # Block size doubles per stage: 2, 4, 8, ...
        block = 1 << (stage + 1)
        half = block >> 1

        staged = x.reshape(batch, -1, block)
        staged = staged.reshape(batch, -1, half, 2)
        # Bring paired elements next to each other before the 2x2 matmul.
        staged = staged.permute(0, 1, 3, 2).contiguous()
        pairs = staged.reshape(batch, -1, 2)

        # One independent 2x2 matrix per pair: (batch, N/2, 2) x (N/2, 2, 2).
        transformed = torch.einsum("bnc,ncd->bnd", pairs, factor)

        # Invert the regrouping to restore the original element layout.
        transformed = transformed.reshape(batch, -1, 2, half)
        transformed = transformed.permute(0, 1, 3, 2).contiguous()
        transformed = transformed.reshape(batch, -1, block)
        return transformed.reshape(batch, n)

    def to_linear(self) -> nn.Linear:
        """Return a dense :class:`nn.Linear` with identical behaviour."""

        # Pick a device/dtype from whichever parameter exists.
        factor_tensor = self.factors[0] if len(self.factors) > 0 else None
        if factor_tensor is not None:
            device = factor_tensor.device
            dtype = factor_tensor.dtype
        elif self.bias is not None:
            device = self.bias.device
            dtype = self.bias.dtype
        else:
            device = torch.device("cpu")
            dtype = torch.get_default_dtype()

        linear = nn.Linear(
            self.in_features,
            self.out_features,
            bias=self.bias is not None,
            device=device,
            dtype=dtype,
        )

        with torch.no_grad():
            # Temporarily zero our bias so the identity probe below recovers
            # only the multiplicative part of the map.
            if self.bias is not None:
                bias_backup = self.bias.data.clone()
                self.bias.zero_()
            else:
                bias_backup = None

            # Row i of weight_matrix is self(e_i), i.e. the transposed weight.
            identity = torch.eye(self.in_features, device=device, dtype=dtype)
            weight_matrix = self(identity)

            if bias_backup is not None:
                self.bias.copy_(bias_backup)
                linear.bias.copy_(bias_backup)

            linear.weight.copy_(weight_matrix.t())

        return linear
@@ -0,0 +1,145 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from sparse_layers.layers.butterfly_linear import ButterflyLinear, _is_power_of_two
9
+ from sparse_layers.layers.simple_mlp import SimpleMLP
10
+
11
+
12
class ButterflyMLP(nn.Module):
    """A multi-layer perceptron composed of :class:`ButterflyLinear` layers.

    All layer dimensions must be identical powers of two, because
    :class:`ButterflyLinear` only supports square power-of-two transforms.
    Hidden layers are separated by ReLU activations; the final projection has
    no activation.
    """

    def __init__(self, input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> None:
        """Build the network.

        Args:
            input_dim: Input feature count (power of two).
            hidden_dims: Sizes of the hidden layers; must all equal ``input_dim``.
            output_dim: Output feature count; must equal ``input_dim``.

        Raises:
            ValueError: If any dimension is non-positive, the dimensions are
                not all identical, or they are not powers of two.
        """
        super().__init__()

        self._validate_dimensions(input_dim, hidden_dims, output_dim)

        self.input_dim = input_dim
        self.hidden_dims = list(hidden_dims)
        self.output_dim = output_dim

        layer_dims = [input_dim, *self.hidden_dims, output_dim]

        modules: list[nn.Module] = []
        for idx in range(len(layer_dims) - 1):
            modules.append(ButterflyLinear(layer_dims[idx], layer_dims[idx + 1]))
            # ReLU between layers, but not after the final projection.
            if idx < len(layer_dims) - 2:
                modules.append(nn.ReLU())

        self.network = nn.Sequential(*modules)

    @staticmethod
    def _validate_dimensions(
        input_dim: int, hidden_dims: Sequence[int], output_dim: int
    ) -> None:
        """Raise ValueError unless all dims are identical positive powers of two."""
        if input_dim <= 0:
            raise ValueError("input_dim must be a positive integer")

        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one positive integer")

        if any(dim <= 0 for dim in hidden_dims):
            raise ValueError("hidden_dims must contain only positive integers")

        if output_dim <= 0:
            raise ValueError("output_dim must be a positive integer")

        all_dims = [input_dim, *hidden_dims, output_dim]
        if len(set(all_dims)) != 1:
            raise ValueError("ButterflyMLP requires all layer dimensions to be identical")

        if not all(_is_power_of_two(dim) for dim in all_dims):
            raise ValueError("ButterflyMLP requires power-of-two layer dimensions")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a batch of feature vectors.

        Raises:
            ValueError: If ``x`` is not 2D or its feature count does not match
                ``input_dim``.
        """
        if x.dim() != 2:
            raise ValueError("ButterflyMLP expects a 2D input tensor of shape (batch, features)")

        if x.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, received {x.shape[1]}"
            )

        return self.network(x)

    def to_simple_mlp(self) -> SimpleMLP:
        """Return a :class:`SimpleMLP` with identical behaviour.

        Each butterfly layer is densified via :meth:`ButterflyLinear.to_linear`
        and its weights copied into the corresponding dense layer.
        """

        simple = SimpleMLP(self.input_dim, self.hidden_dims, self.output_dim)

        sparse_layers = [module for module in self.network if isinstance(module, ButterflyLinear)]
        dense_layers = [module for module in simple.modules() if isinstance(module, nn.Linear)]

        if len(sparse_layers) != len(dense_layers):
            raise RuntimeError("Unexpected layer mismatch during conversion to SimpleMLP")

        for butterfly_layer, dense_layer in zip(sparse_layers, dense_layers):
            converted = butterfly_layer.to_linear()
            with torch.no_grad():
                dense_layer.weight.copy_(converted.weight)
                if dense_layer.bias is not None and converted.bias is not None:
                    dense_layer.bias.copy_(converted.bias)

        return simple

    @classmethod
    def from_simple_mlp(
        cls,
        model: SimpleMLP,
        *,
        seed: int | None = None,
        optimization_steps: int = 4000,
        learning_rate: float = 0.1,
        tolerance: float = 1e-7,
    ) -> "ButterflyMLP":
        """Construct a :class:`ButterflyMLP` from a compatible :class:`SimpleMLP`.

        Every dense layer is approximated by gradient-based fitting via
        :meth:`ButterflyLinear.from_linear` with the given hyperparameters.

        Raises:
            ValueError: If ``model`` has no linear layers, or any layer is not
                square with power-of-two dimensions.
        """

        linear_layers = [module for module in model.modules() if isinstance(module, nn.Linear)]

        if not linear_layers:
            raise ValueError("SimpleMLP must contain at least one linear layer")

        # Every dense layer must be convertible to a butterfly factorization.
        for layer in linear_layers:
            if layer.in_features != layer.out_features:
                raise ValueError(
                    "ButterflyMLP.from_simple_mlp requires square nn.Linear layers"
                )
            if not _is_power_of_two(layer.in_features):
                raise ValueError(
                    "ButterflyMLP.from_simple_mlp requires power-of-two dimensions"
                )

        input_dim = linear_layers[0].in_features
        hidden_dims = [layer.out_features for layer in linear_layers[:-1]]
        output_dim = linear_layers[-1].out_features

        result = cls(input_dim, hidden_dims, output_dim)

        butterfly_indices = [
            idx for idx, module in enumerate(result.network) if isinstance(module, ButterflyLinear)
        ]

        if len(butterfly_indices) != len(linear_layers):
            raise RuntimeError("Unexpected layer mismatch during reconstruction from SimpleMLP")

        for linear_layer, target_idx in zip(linear_layers, butterfly_indices):
            result.network[target_idx] = ButterflyLinear.from_linear(
                linear_layer,
                seed=seed,
                optimization_steps=optimization_steps,
                learning_rate=learning_rate,
                tolerance=tolerance,
            )

        return result
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+
8
+ from sparse_layers.layers.butterfly_linear import ButterflyLinear
9
+
10
+
11
class ButterflyMultiHeadAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V/output projections are
    :class:`ButterflyLinear` layers instead of dense ``nn.Linear`` maps."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0) -> None:
        """Validate the configuration and create the four butterfly projections.

        Raises:
            ValueError: On non-positive sizes, indivisible ``d_model``,
                out-of-range ``dropout``, or a non power-of-two ``d_model``.
        """
        super().__init__()

        if d_model <= 0:
            raise ValueError("d_model must be a positive integer")
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        if not (0.0 <= dropout < 1.0):
            raise ValueError("dropout must satisfy 0.0 <= dropout < 1.0")

        if not _is_power_of_two(d_model):
            raise ValueError("ButterflyMultiHeadAttention requires d_model to be a power of two")

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Standard scaled-dot-product normalisation: 1 / sqrt(head_dim).
        self._scaling = 1.0 / math.sqrt(self.head_dim)

        self.query = ButterflyLinear(d_model, d_model)
        self.key = ButterflyLinear(d_model, d_model)
        self.value = ButterflyLinear(d_model, d_model)
        self.out = ButterflyLinear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor:
        """Run self-attention over ``x`` of shape (batch, seq_len, d_model).

        ``mask`` is an optional boolean (batch, seq_len) tensor; True entries
        are masked out (receive -inf scores before the softmax).
        """
        if x.dim() != 3:
            raise ValueError("expected input of shape (batch, seq_len, d_model)")
        batch, length, dim = x.shape
        if dim != self.d_model:
            raise ValueError(
                f"expected last dimension {self.d_model}, received {dim}"
            )

        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        scores = (q @ k.transpose(-2, -1)) * self._scaling

        if mask is not None:
            if mask.shape != (batch, length):
                raise ValueError("mask shape must match (batch, seq_len)")
            if mask.dtype != torch.bool:
                raise ValueError("mask must have dtype torch.bool")
            # Broadcast over the head and query dimensions.
            scores = scores.masked_fill(mask[:, None, None, :], float("-inf"))

        weights = self.dropout(torch.softmax(scores, dim=-1))

        context = self._merge_heads(weights @ v)
        return self.out(context)

    def _split_heads(self, tensor: Tensor) -> Tensor:
        """(batch, seq, d_model) -> (batch, heads, seq, head_dim)."""
        batch, length = tensor.shape[0], tensor.shape[1]
        return tensor.view(batch, length, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def _merge_heads(self, tensor: Tensor) -> Tensor:
        """(batch, heads, seq, head_dim) -> (batch, seq, d_model)."""
        batch, length = tensor.shape[0], tensor.shape[2]
        return tensor.permute(0, 2, 1, 3).reshape(batch, length, self.d_model)
86
+
87
+
88
+ def _is_power_of_two(value: int) -> bool:
89
+ return value > 0 and (value & (value - 1) == 0)
90
+
91
+
92
+ __all__ = ["ButterflyMultiHeadAttention"]
93
+
@@ -0,0 +1,46 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+
5
+ import torch
6
+ from torch import Tensor, nn
7
+ from torch.nn.parameter import Parameter
8
+
9
+
10
class CustomLinear(nn.Module):
    """A from-scratch re-implementation of :class:`torch.nn.Linear`.

    The weight is stored as ``(out_features, in_features)`` and the layer
    computes ``y = x @ W^T + b`` via ``nn.functional.linear``, using the same
    Kaiming-uniform initialisation scheme as the stock layer.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        """Allocate and initialise the parameters.

        Raises:
            ValueError: If either feature count is not a positive integer.
        """
        super().__init__()

        if in_features <= 0:
            raise ValueError("in_features must be a positive integer")
        if out_features <= 0:
            raise ValueError("out_features must be a positive integer")

        self.in_features = in_features
        self.out_features = out_features

        # One weight row per output unit, exactly like nn.Linear.
        self.weight: Parameter = Parameter(torch.empty(out_features, in_features))
        self.bias: Parameter | None = Parameter(torch.empty(out_features)) if bias else None

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise weight and bias with nn.Linear's default scheme."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = self.weight.size(1)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        """Apply the affine transform to ``input``."""
        return nn.functional.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        has_bias = self.bias is not None
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={has_bias}"
@@ -0,0 +1,54 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Sequence
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from sparse_layers.layers.custom_linear import CustomLinear
9
+
10
+
11
class CustomMLP(nn.Module):
    """MLP that uses CustomLinear layers as drop-in replacements for nn.Linear.

    Hidden layers are separated by ReLU activations; the final projection has
    no activation.
    """

    def __init__(self, input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> None:
        """Build the network.

        Args:
            input_dim: Number of input features.
            hidden_dims: Sizes of the hidden layers (at least one).
            output_dim: Number of output features.

        Raises:
            ValueError: If any dimension is non-positive or ``hidden_dims`` is
                empty.
        """
        super().__init__()

        self._validate_dimensions(input_dim, hidden_dims, output_dim)

        # Mirror ButterflyMLP: keep the layer sizes around for introspection
        # and for input validation in forward().
        self.input_dim = input_dim
        self.hidden_dims = list(hidden_dims)
        self.output_dim = output_dim

        layer_dims = [input_dim, *self.hidden_dims, output_dim]

        modules: list[nn.Module] = []
        for idx in range(len(layer_dims) - 1):
            modules.append(CustomLinear(layer_dims[idx], layer_dims[idx + 1]))
            # ReLU between layers, but not after the final projection.
            if idx < len(layer_dims) - 2:
                modules.append(nn.ReLU())

        self.network = nn.Sequential(*modules)

    @staticmethod
    def _validate_dimensions(input_dim: int, hidden_dims: Sequence[int], output_dim: int) -> None:
        """Raise ValueError for non-positive or missing layer dimensions."""
        if input_dim <= 0:
            raise ValueError("input_dim must be a positive integer")

        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one positive integer")

        if any(dim <= 0 for dim in hidden_dims):
            raise ValueError("hidden_dims must contain only positive integers")

        if output_dim <= 0:
            raise ValueError("output_dim must be a positive integer")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a batch of feature vectors.

        Raises:
            ValueError: If ``x`` is not 2D or its feature count does not match
                ``input_dim``.
        """
        if x.dim() != 2:
            raise ValueError("CustomMLP expects a 2D input tensor of shape (batch, features)")

        # Consistency with ButterflyMLP: fail fast with a clear message instead
        # of a shape error from deep inside the first linear layer.
        if x.shape[1] != self.input_dim:
            raise ValueError(
                f"Expected input with {self.input_dim} features, received {x.shape[1]}"
            )

        return self.network(x)