minmaxrnc 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
minmaxrnc/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # SPDX-FileCopyrightText: 2026 Alessandro Ronca
2
+ # SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
3
+
4
+ from .minmax_rnc import MinMaxRNC, MinMaxRNCConfig
5
+ from .minmax_rnc_lm import MinMaxRNC_LM, MinMaxRNCLMConfig
6
+
7
+ __all__ = [
8
+ "MinMaxRNC",
9
+ "MinMaxRNCConfig",
10
+ "MinMaxRNC_LM",
11
+ "MinMaxRNCLMConfig",
12
+ ]
@@ -0,0 +1,134 @@
1
+ # SPDX-FileCopyrightText: 2026 Alessandro Ronca
2
+ # SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
3
+
4
+ from typing import Sequence, Optional, List, Tuple, Union, Literal
5
+ from dataclasses import dataclass, replace
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .minmax_neuron import MinMaxNeuron, MinMaxNeuronConfig
11
+
12
+ from .modules.feedforward import FeedForwardConfig, create_feedforward
13
+ from .modules.basic_conv import BasicConv, BasicConvConfig
14
+ from .modules.gated_conv import GatedConv, GatedConvConfig
15
+
16
+
17
+ NormType = Literal['none', 'layernorm', 'rmsnorm']
18
+
19
+
20
@dataclass(frozen=True)
class MinMaxLayerConfig:
    """
    Immutable settings for a single MinMax layer.

    Instances are usually produced by ``MinMaxRNCConfig.layer_cfg``; build one
    by hand only when a layer needs a non-standard shape.

    Fields
    ------
    neuron : MinMaxNeuronConfig
        Settings handed to the layer's MinMax Neuron.
    conv : BasicConvConfig | GatedConvConfig
        Settings for the short-range convolution that runs ahead of the FFN.
        The config type decides which convolution module the layer builds.
    d_model : int
        Width of the residual stream (must match ``neuron.d_model``).
    first_in_dropout : float
        FFN dropout used when the layer is the *first* one of the stack, so
        that input-level dropout can differ from the deeper layers.
    feedforward : FeedForwardConfig | None
        Settings for the feed-forward sub-layer.  The field is optional in
        the type only: the layer currently refuses ``None`` at construction.
    norm : 'none' | 'layernorm' | 'rmsnorm'
        Pre-normalisation placed in front of each of the three sub-layers
        (conv, FFN, neuron).
    """

    neuron: MinMaxNeuronConfig
    conv: Union[BasicConvConfig, GatedConvConfig]
    d_model: int
    first_in_dropout: float = 0.0
    feedforward: Optional[FeedForwardConfig] = None
    norm: NormType = 'layernorm'
53
+
54
+
55
class MinMaxLayer(nn.Module):
    """
    One residual layer of the MinMax RNC backbone.

    Data flow (``u`` is the layer input; every sub-layer receives a
    pre-normalised tensor):

        conv_out = Conv( norm_conv(u) )              # short-range context
        ffn_out  = FFN( norm_ffn(u + conv_out) )
        neur_out = Neuron( norm_neuron(u + ffn_out) )
        output   = u + neur_out

    Parameters
    ----------
    cfg : MinMaxLayerConfig
        Layer settings.  ``cfg.feedforward`` must not be None.
    first : bool
        True for the first layer of the stack; the FFN then uses
        ``cfg.first_in_dropout`` instead of ``cfg.feedforward.dropout``.

    Raises
    ------
    ValueError
        If ``cfg.feedforward`` is None.
    """

    def __init__(self, cfg: MinMaxLayerConfig, first: bool):
        super().__init__()

        self.cfg = cfg

        # Validate with a real exception: an ``assert`` is stripped under
        # ``python -O``, after which a missing FFN config would only show up
        # as an AttributeError on ``None.dropout`` further down.
        if cfg.feedforward is None:
            raise ValueError(
                "MinMaxLayerConfig.feedforward must be provided; "
                "layers without a feed-forward block are not supported"
            )
        self.use_ffn = True

        # Exact type match is kept on purpose: the relationship between the
        # two conv config classes is not visible here, so an isinstance()
        # test could change which module gets built.
        if type(cfg.conv) is BasicConvConfig:
            self.conv = BasicConv(cfg.conv)
        else:
            self.conv = GatedConv(cfg.conv)

        self.neuron = MinMaxNeuron(cfg.neuron)

        # The first layer replaces the regular FFN dropout with its own rate.
        ffn_dropout = cfg.first_in_dropout if first else cfg.feedforward.dropout

        self.ffn = create_feedforward(
            config=replace(
                cfg.feedforward,
                embedding_dim=cfg.d_model,
                embedding_dim_out=cfg.d_model,
                dropout=ffn_dropout,
            )
        )

        # Registration order of the norm modules is kept as in the original
        # so that parameter ordering does not change.
        if cfg.norm == 'layernorm':
            self.norm_ffn = nn.LayerNorm(cfg.d_model)
            self.norm_neuron = nn.LayerNorm(cfg.d_model)
            self.norm_conv = nn.LayerNorm(cfg.d_model)
        elif cfg.norm == 'rmsnorm':
            self.norm_ffn = nn.RMSNorm(cfg.d_model)
            self.norm_neuron = nn.RMSNorm(cfg.d_model)
            self.norm_conv = nn.RMSNorm(cfg.d_model)
        else:
            # Any other value (normally 'none') disables normalisation.
            self.norm_conv = nn.Identity()
            self.norm_ffn = nn.Identity()
            self.norm_neuron = nn.Identity()

    @property
    def initial_state(self):
        """Return the default state: the sub-modules' initial states keyed by 'neuron' and 'conv'."""
        return {
            'neuron': self.neuron.initial_state,
            'conv': self.conv.initial_state,
        }

    def forward(self, u: torch.Tensor, state: dict):
        """
        Run the layer on ``u``.

        ``state`` must provide the entries 'conv' and 'neuron'.  Returns the
        layer output together with a new dict holding the updated 'conv' and
        'neuron' states.
        """
        conv, conv_state = self.conv(self.norm_conv(u), state['conv'])

        ffn = self.ffn(self.norm_ffn(u + conv))

        neuron, neuron_state = self.neuron(self.norm_neuron(u + ffn), state['neuron'])

        # Only the neuron output is added back onto the layer input.
        output = u + neuron

        state = {'conv': conv_state, 'neuron': neuron_state}

        return output, state
133
+
134
+
@@ -0,0 +1,148 @@
1
+ # SPDX-FileCopyrightText: 2026 Alessandro Ronca
2
+ # SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from dataclasses import dataclass
7
+ from .modules.initialisers import wang_init_, small_init_init_
8
+
9
+ from . import minmax_scan
10
+
11
+
12
@dataclass(frozen=True)
class MinMaxNeuronConfig:
    """
    Configuration for a single MinMax Neuron.

    Fields
    ------
    _num_blocks : int
        Total number of residual blocks in the enclosing model.  It is passed
        to ``wang_init_`` when the output projection is initialised; the
        intent stated by the author is to keep the combined contribution of
        all blocks to the residual stream O(1).
    d_model : int
        Dimension of the input and output (the residual-stream width).
    d_state : int
        Dimension of the hidden state x_t.  The projection sizes, and hence
        the parameter count, grow linearly with it.
    dropout : float
        Dropout probability applied to the input u before projection.
    train_init : bool
        If True, the initial hidden state x_0 is a learned parameter.
        If False (default), x_0 is fixed at zero.
    output_gate : bool
        If True, the state is multiplied element-wise by a linear projection
        of the input before the final linear map: y = W_o(x ⊙ W_g u).
        NOTE(review): earlier documentation described the gate as
        sigmoid-activated, but MinMaxNeuron.forward applies no sigmoid —
        confirm which of the two is intended.
    """

    _num_blocks: int
    d_model: int
    d_state: int
    dropout: float = 0.0
    train_init: bool = False
    output_gate: bool = True
+
45
+
46
+ class MinMaxNeuron(nn.Module):
47
+ """
48
+ The core recurrent cell of the MinMax RNC.
49
+
50
+ Maintains a hidden state x_t ∈ R^D updated by the MinMax recurrence:
51
+
52
+ x_{t+1} = max(min(r_t, x_t), s_t)
53
+
54
+ All states for a sequence of length T are computed simultaneously via a parallel prefix scan in
55
+ O(log T) depth instead of O(T).
56
+
57
+ Output projection:
58
+
59
+ y_t = W_o x_t (output_gate=False)
60
+ y_t = W_o (x_t ⊙ W_g u_t) (output_gate=True)
61
+ """
62
+
63
+ def __init__(self, cfg: MinMaxNeuronConfig):
64
+
65
+ super().__init__()
66
+
67
+ self.cfg = cfg
68
+
69
+ self.I = I = cfg.d_model
70
+ self.D = D = cfg.d_state
71
+
72
+ self._initial_state = nn.Parameter(torch.zeros(D), requires_grad=cfg.train_init)
73
+
74
+ self.drop = nn.Dropout(cfg.dropout)
75
+
76
+ self.s = nn.Linear(I, D)
77
+ self.r = nn.Linear(I, D)
78
+ self.o = nn.Linear(D, I)
79
+ if cfg.output_gate:
80
+ self.o_g = nn.Linear(I, D)
81
+
82
+ self.reset()
83
+
84
+
85
+ def reset(self):
86
+ # Init 's'
87
+ small_init_init_(self.s.weight, dim=self.I)
88
+ if self.s.bias is not None:
89
+ nn.init.zeros_(self.s.bias)
90
+ small_init_init_(self.r.weight, dim=self.I)
91
+ # Init 'r'
92
+ if self.r.bias is not None:
93
+ nn.init.zeros_(self.r.bias)
94
+ # Init 'o'
95
+ wang_init_(
96
+ self.o.weight,
97
+ dim=self.I,
98
+ num_blocks= self.cfg._num_blocks,
99
+ )
100
+ if self.o.bias is not None:
101
+ nn.init.zeros_(self.o.bias)
102
+
103
+
104
+ @property
105
+ def initial_state(self):
106
+ return self._initial_state
107
+
108
+
109
+ def forward(self, u: torch.Tensor, state: torch.Tensor):
110
+ """
111
+ Compute updated state for a sequence using closed form with initial state.
112
+
113
+ u: (B, T, I)
114
+ state: (1/B,D,) (state before the first step in the input sequence)
115
+
116
+ Returns: 1) sequence of outputs: (B, T, I)
117
+ 2) last state: (B, D)
118
+ """
119
+
120
+ B, T, I = u.shape
121
+ D = self.D
122
+ device = u.device
123
+
124
+ if state.dim() == 1: # state is the initial initial state
125
+ x0 = state.unsqueeze(0).expand(B,D)
126
+ else:
127
+ x0 = state
128
+
129
+ # Shape of u: (B,T,I)
130
+ # Shape of x0: (B,D)
131
+
132
+
133
+ u = self.drop(u) # (B,T,I)
134
+ s = self.s(u) # (B,T,D)
135
+ r = self.r(u) # (B,T,D)
136
+
137
+ x_post = minmax_scan.all_states(r, s, x0)
138
+ x_post = x_post[:,1:,:]
139
+
140
+
141
+ # ----- Compute outputs -----
142
+ x_latest = x_post[:,-1,:] # (B,T,D)
143
+ if self.cfg.output_gate:
144
+ x_post = x_post * self.o_g(u) # (B,T,D)
145
+ output = self.o(x_post) # (B,T,I)
146
+
147
+ return output, x_latest
148
+
@@ -0,0 +1,39 @@
1
+ # SPDX-FileCopyrightText: 2026 Alessandro Ronca
2
+ # SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
3
+
4
+ import torch
5
+
6
+
7
def apply(a: torch.Tensor, b: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the MinMax scalar operator f(x) = max(min(a, x), b).

    Both min and max act element-wise, and the three arguments only need
    to be broadcastable against each other:
        a: (..., D)
        b: (..., D)
        x: (..., D)
    """
    capped = torch.minimum(a, x)
    return torch.maximum(capped, b)
18
+
19
+
20
def compose(a2: torch.Tensor, b2: torch.Tensor, a1: torch.Tensor, b1: torch.Tensor):
    """
    Fuse two MinMax scalar operators into one.

    The arguments have shape (..., D) and describe
        f1(x) = max(min(a1, x), b1),
        f2(x) = max(min(a2, x), b2).

    The returned pair (a, b), with
        a = min(a2, a1)
        b = max(min(a2, b1), b2)
    describes the operator that applies f1 first and f2 second:
        f(x) = f2(f1(x)) = max(min(a, x), b)
    """
    upper = torch.minimum(a2, a1)
    lower = torch.maximum(torch.minimum(a2, b1), b2)
    return upper, lower
38
+
39
+
@@ -0,0 +1,281 @@
1
+ # SPDX-FileCopyrightText: 2026 Alessandro Ronca
2
+ # SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from typing import Optional, Literal, Union
8
+ from dataclasses import dataclass, replace
9
+
10
+ from .minmax_layer import MinMaxLayer, MinMaxLayerConfig, NormType
11
+ from .minmax_neuron import MinMaxNeuronConfig
12
+ from .modules.feedforward import FeedForwardConfig, FFType, InitType, create_feedforward
13
+ from .modules.basic_conv import BasicConvConfig
14
+ from .modules.gated_conv import GatedConvConfig
15
+
16
+
17
+ ConvType = Literal['basic', 'gated']
18
+
19
+
20
@dataclass(frozen=True)
class MinMaxRNCConfig:
    """
    Configuration for the MinMax RNC backbone.

    The configs for neuron, conv, FFN and layer are derived from these
    fields (see ``layer_cfg``).

    Core architecture
    -----------------
    d_model : int
        Residual-stream width.  Every sub-module input and output has this
        dimension.
    n_layers : int
        Number of stacked MinMaxLayers.
    d_state : int
        Hidden-state dimension of each MinMax Neuron.  Independent of
        d_model.

    Normalisation
    -------------
    norm : 'layernorm' | 'rmsnorm' | 'none'
        Pre-norm type applied before each sub-layer inside each layer, and
        before the optional post-layers FFN.  'none' disables it.
    postlayers_norm : 'layernorm' | 'rmsnorm' | 'none'
        Norm applied as the very last step of the backbone, i.e. after the
        final layer and after the optional post-layers FFN.

    Feed-forward network (within each layer)
    -----------------------------------------
    ffn_type : 'gated' | 'basic'
        'gated' (default) — gated FFN; 'basic' — standard two-layer MLP
        (as described by the feed-forward module).
    ffn_proj_factor : float
        Hidden-layer expansion factor relative to d_model.  The author notes
        that the hidden dimension is rounded to the nearest multiple of 2;
        that logic lives in the feed-forward module.
    ffn_act_fn : str
        Activation function name.  Choices listed by the author: 'relu',
        'relu^2', 'gelu', 'swish', 'sigmoid', 'selu'.
    ffn_dropout : float
        Dropout inside the FFN of every layer except the first, which uses
        prelayers_dropout instead.  Also used by the post-layers FFN.
    ffn_init : 'default' | 'scaled'
        Weight initialisation scheme of the FFN.

    Neuron
    ------
    output_gate : bool
        If True, the neuron output is element-wise gated by a learned
        projection of the input.
    train_init : bool
        If True, the neuron's initial hidden state x_0 is a learned parameter.
    neuron_dropout : float
        Dropout probability applied to the neuron input.

    Convolution
    -----------
    conv_type : 'gated' | 'basic'
        'basic' selects BasicConvConfig; any other value (default 'gated')
        selects GatedConvConfig.
    conv_init_val : float
        ``init_val`` passed to GatedConvConfig (ignored for 'basic').
        According to the author it is the initial gate logit: 0.0 gives an
        equal mix, negative values favour the current token.

    First layer / post-layers
    -------------------------
    prelayers_dropout : float
        FFN dropout of the first layer only.  It always replaces
        ffn_dropout there, including when left at its default of 0.0.
    use_postlayers_ffn : bool
        If True, an extra residual FFN (same settings as the in-layer FFN)
        is applied after all layers, before postlayers_norm.

    Note
    ----
    The chunk size ``unroll_steps`` is not a config field: it is a required
    argument of ``MinMaxRNC.forward``.
    """

    # Core architecture
    d_model: int
    n_layers: int
    d_state: int

    # Normalisation (within layers and post-layers)
    norm: NormType = 'layernorm'
    postlayers_norm: NormType = 'layernorm'

    # FFN within each layer
    ffn_type: FFType = 'gated'
    ffn_proj_factor: float = 1.3
    ffn_act_fn: str = 'relu'
    ffn_dropout: float = 0.0
    ffn_init: InitType = 'scaled'

    # Neuron
    output_gate: bool = True
    train_init: bool = False
    neuron_dropout: float = 0.0

    # Conv
    conv_type: ConvType = 'gated'
    conv_init_val: float = 0.0

    # First layer only
    prelayers_dropout: float = 0.0

    # Post-layers
    use_postlayers_ffn: bool = False

    @property
    def layer_cfg(self) -> MinMaxLayerConfig:
        """Build the MinMaxLayerConfig shared by every layer of the stack."""
        neuron_cfg = MinMaxNeuronConfig(
            _num_blocks=self.n_layers,
            d_model=self.d_model,
            d_state=self.d_state,
            dropout=self.neuron_dropout,
            train_init=self.train_init,
            output_gate=self.output_gate,
        )
        if self.conv_type == 'basic':
            conv_cfg = BasicConvConfig(embedding_dim=self.d_model)
        else:
            conv_cfg = GatedConvConfig(
                embedding_dim=self.d_model,
                init_val=self.conv_init_val,
            )
        ffn_cfg = FeedForwardConfig(
            _num_blocks=self.n_layers,
            ffn_type=self.ffn_type,
            proj_factor=self.ffn_proj_factor,
            act_fn=self.ffn_act_fn,
            dropout=self.ffn_dropout,
            init=self.ffn_init,
        )
        return MinMaxLayerConfig(
            d_model=self.d_model,
            neuron=neuron_cfg,
            conv=conv_cfg,
            feedforward=ffn_cfg,
            norm=self.norm,
            first_in_dropout=self.prelayers_dropout,
        )

    # ------------------------------------------------------------------
    # Preset factories
    # ------------------------------------------------------------------

    @classmethod
    def small(cls, n_layers: int = 2) -> 'MinMaxRNCConfig':
        """Preset with d_model=90 and d_state=40."""
        return cls(d_model=90, n_layers=n_layers, d_state=40)

    @classmethod
    def medium(cls, n_layers: int = 8) -> 'MinMaxRNCConfig':
        """Preset with d_model=512 and d_state=512."""
        return cls(d_model=512, n_layers=n_layers, d_state=512)

    @classmethod
    def large(cls, n_layers: int = 12) -> 'MinMaxRNCConfig':
        """Preset with d_model=728 and d_state=1456."""
        return cls(d_model=728, n_layers=n_layers, d_state=1456)
+
187
+
188
+ class MinMaxRNC(nn.Module):
189
+ """
190
+ MinMax Recurrent Neural Cascade — the backbone sequence model.
191
+
192
+ Stacks ``cfg.n_layers`` MinMaxLayers, each containing a short-range
193
+ convolution, a feed-forward network, and a MinMax Neuron. All three
194
+ sub-layers use pre-norm and residual connections.
195
+
196
+ Inputs
197
+ ------
198
+ u : Tensor (B, T, d_model)
199
+ Continuous input sequence (e.g. token embeddings).
200
+ state : list[dict] | None
201
+ Per-layer recurrent state from a previous call. Pass None (or omit)
202
+ to start from the default initial state.
203
+ return_state : bool
204
+ If True, also return the updated state after the last token.
205
+
206
+ Outputs
207
+ -------
208
+ y : Tensor (B, T, d_model)
209
+ state : list[dict] — only when return_state=True
210
+ """
211
+
212
+ def __init__(self, cfg: MinMaxRNCConfig):
213
+ super().__init__()
214
+ self.__cfg = cfg
215
+ self.reset()
216
+
217
+ def reset(self):
218
+ layer_cfg = self.__cfg.layer_cfg
219
+
220
+ self.layers = nn.ModuleList()
221
+ firstlayer = True
222
+ for _ in range(self.__cfg.n_layers):
223
+ self.layers.append(MinMaxLayer(layer_cfg, first=firstlayer))
224
+ firstlayer = False
225
+
226
+ self.postlayers_norm = None
227
+ if self.__cfg.postlayers_norm == 'layernorm':
228
+ self.postlayers_norm = nn.LayerNorm(self.__cfg.d_model)
229
+ elif self.__cfg.postlayers_norm == 'rmsnorm':
230
+ self.postlayers_norm = nn.RMSNorm(self.__cfg.d_model)
231
+
232
+ self.postlayers_ffn = None
233
+ self.postlayers_ffn_norm = None
234
+ if self.__cfg.use_postlayers_ffn:
235
+ self.postlayers_ffn = create_feedforward(
236
+ config=replace(
237
+ layer_cfg.feedforward,
238
+ embedding_dim = self.__cfg.d_model,
239
+ embedding_dim_out = self.__cfg.d_model,
240
+ )
241
+ )
242
+ if self.__cfg.norm == 'layernorm':
243
+ self.postlayers_ffn_norm = nn.LayerNorm(self.__cfg.d_model)
244
+ elif self.__cfg.norm == 'rmsnorm':
245
+ self.postlayers_ffn_norm = nn.RMSNorm(self.__cfg.d_model)
246
+ else:
247
+ self.postlayers_ffn_norm = nn.Identity()
248
+
249
+ @property
250
+ def initial_state(self):
251
+ return [layer.initial_state for layer in self.layers]
252
+
253
+ def _parallel_forward(self, u: torch.Tensor, state):
254
+ """u: [B, T, D] — returns output [B, T, D] and updated state."""
255
+ updated_state = []
256
+ y = u
257
+ for layer, layer_state in zip(self.layers, state):
258
+ y, updated_layer_state = layer(y, layer_state)
259
+ updated_state.append(updated_layer_state)
260
+
261
+ if self.postlayers_ffn is not None:
262
+ y = y + self.postlayers_ffn(self.postlayers_ffn_norm(y))
263
+
264
+ if self.postlayers_norm is not None:
265
+ y = self.postlayers_norm(y)
266
+
267
+ return y, updated_state
268
+
269
+ def forward(self, u: torch.Tensor, unroll_steps: int, state=None, return_state: bool = False):
270
+ if state is None:
271
+ state = self.initial_state
272
+
273
+ y_chunks = []
274
+ for u_chunk in u.split(unroll_steps, dim=1):
275
+ y_chunk, state = self._parallel_forward(u_chunk, state)
276
+ y_chunks.append(y_chunk)
277
+ y = torch.cat(y_chunks, dim=1)
278
+
279
+ if return_state:
280
+ return y, state
281
+ return y
@@ -0,0 +1,88 @@
1
+ # SPDX-FileCopyrightText: 2026 Alessandro Ronca
2
+ # SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from dataclasses import dataclass
8
+
9
+ from .minmax_rnc import MinMaxRNC, MinMaxRNCConfig
10
+ from .modules.initialisers import small_init_init_
11
+
12
+
13
@dataclass(frozen=True)
class MinMaxRNCLMConfig:
    """
    Immutable settings for the MinMax RNC language model.

    Fields
    ------
    backbone : MinMaxRNCConfig
        Settings of the MinMaxRNC backbone; ``backbone.d_model`` is also the
        embedding dimension.
    head_dropout : float
        Dropout applied to the backbone output ahead of the LM head
        projection.
    tie_weights : bool
        When True (default), the LM head uses the token embedding matrix as
        its weight, so the two share a single set of parameters.
    """

    backbone: MinMaxRNCConfig
    head_dropout: float = 0.0
    tie_weights: bool = True
33
+
34
+
35
class MinMaxRNC_LM(MinMaxRNC):
    """
    MinMax RNC with a token embedding layer and a language-model head.

    Extends MinMaxRNC with:
      - A token embedding  (vocab_size × d_model)
      - A dropout before the output projection
      - A bias-free linear LM head (d_model → vocab_size), optionally tied
        to the embedding

    Inputs
    ------
    tokens : LongTensor (B, T)
        Token indices in [0, vocab_size).
    unroll_steps : int
        Chunk length along the time axis, forwarded to the backbone.
    state : list[dict] | None
        Recurrent state from a previous call.
    return_state : bool
        If True, also return the updated state.

    Outputs
    -------
    logits : Tensor (B, T, vocab_size)
    state : list[dict] — only when return_state=True
    """

    def __init__(self, vocab_size: int, cfg: MinMaxRNCLMConfig):
        # These plain attributes must exist before super().__init__():
        # the base constructor calls self.reset(), which dispatches to the
        # override below and therefore already runs __lm_reset().
        self.__lm_cfg = cfg
        self.__vocab_size = vocab_size
        super().__init__(cfg.backbone)
        # No second __lm_reset() here: it would allocate and initialise the
        # embedding and the head again and discard the first copies.

    def reset(self):
        """Rebuild the backbone modules, then the embedding, head and head dropout."""
        super().reset()
        self.__lm_reset()

    def __lm_reset(self):
        """Create the LM-specific modules and initialise their weights."""
        d_model = self.__lm_cfg.backbone.d_model
        self.token_emb = nn.Embedding(self.__vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, self.__vocab_size, bias=False)
        self.head_drop = nn.Dropout(self.__lm_cfg.head_dropout)

        small_init_init_(self.token_emb.weight, dim=d_model)
        if self.__lm_cfg.tie_weights:
            # Share one parameter between the embedding and the head.
            self.lm_head.weight = self.token_emb.weight
        else:
            small_init_init_(self.lm_head.weight, dim=d_model)

    def forward(self, tokens: torch.Tensor, unroll_steps: int, state=None, return_state: bool = False):
        """Embed ``tokens``, run the backbone, and project to vocabulary logits."""
        # Always request the state from the backbone so the return shape is
        # fixed here; it is dropped below unless the caller asked for it.
        y, state = super().forward(
            self.token_emb(tokens), unroll_steps, state=state, return_state=True
        )
        logits = self.lm_head(self.head_drop(y))
        if return_state:
            return logits, state
        return logits