minmaxrnc 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- minmaxrnc/__init__.py +12 -0
- minmaxrnc/minmax_layer.py +134 -0
- minmaxrnc/minmax_neuron.py +148 -0
- minmaxrnc/minmax_operator.py +39 -0
- minmaxrnc/minmax_rnc.py +281 -0
- minmaxrnc/minmax_rnc_lm.py +88 -0
- minmaxrnc/minmax_scan.py +60 -0
- minmaxrnc/modules/basic_conv.py +77 -0
- minmaxrnc/modules/feedforward.py +167 -0
- minmaxrnc/modules/gated_conv.py +78 -0
- minmaxrnc/modules/initialisers.py +60 -0
- minmaxrnc-0.1.0.dist-info/METADATA +329 -0
- minmaxrnc-0.1.0.dist-info/RECORD +17 -0
- minmaxrnc-0.1.0.dist-info/WHEEL +5 -0
- minmaxrnc-0.1.0.dist-info/licenses/LICENSE +143 -0
- minmaxrnc-0.1.0.dist-info/licenses/NOTICE +10 -0
- minmaxrnc-0.1.0.dist-info/top_level.txt +1 -0
minmaxrnc/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2026 Alessandro Ronca
|
|
2
|
+
# SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
|
|
3
|
+
|
|
4
|
+
from .minmax_rnc import MinMaxRNC, MinMaxRNCConfig
|
|
5
|
+
from .minmax_rnc_lm import MinMaxRNC_LM, MinMaxRNCLMConfig
|
|
6
|
+
|
|
7
|
+
# Names re-exported at package level: the backbone and the language model,
# each with its configuration dataclass.
__all__ = ["MinMaxRNC", "MinMaxRNCConfig", "MinMaxRNC_LM", "MinMaxRNCLMConfig"]
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2026 Alessandro Ronca
|
|
2
|
+
# SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
|
|
3
|
+
|
|
4
|
+
from typing import Sequence, Optional, List, Tuple, Union, Literal
|
|
5
|
+
from dataclasses import dataclass, replace
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
|
|
10
|
+
from .minmax_neuron import MinMaxNeuron, MinMaxNeuronConfig
|
|
11
|
+
|
|
12
|
+
from .modules.feedforward import FeedForwardConfig, create_feedforward
|
|
13
|
+
from .modules.basic_conv import BasicConv, BasicConvConfig
|
|
14
|
+
from .modules.gated_conv import GatedConv, GatedConvConfig
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Allowed values for the pre-norm selector of a layer. MinMaxLayer builds
# nn.LayerNorm for 'layernorm', nn.RMSNorm for 'rmsnorm' and nn.Identity otherwise.
NormType = Literal['none', 'layernorm', 'rmsnorm']
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class MinMaxLayerConfig:
    """
    Settings for a single MinMaxLayer.

    You rarely build this by hand: MinMaxRNCConfig.layer_cfg derives it from the
    backbone configuration. Construct it directly only for unusual layer shapes.

    Fields
    ------
    neuron : MinMaxNeuronConfig
        Settings handed to the layer's MinMaxNeuron.
    conv : BasicConvConfig | GatedConvConfig
        Settings for the short-range convolution placed in front of the FFN;
        the config type decides whether MinMaxLayer builds a BasicConv or a
        GatedConv.
    d_model : int
        Width of the residual stream. Expected to equal neuron.d_model (this
        is not checked here).
    first_in_dropout : float
        FFN dropout used *instead of* feedforward.dropout when the layer is the
        first one of the stack, so the input level can be regularised more
        strongly than deeper layers.
    feedforward : FeedForwardConfig | None
        Settings for the feed-forward sub-layer. Typed as optional, but
        MinMaxLayer requires it and fails to build when it is None.
    norm : 'none' | 'layernorm' | 'rmsnorm'
        Kind of pre-norm placed in front of each of the three sub-layers
        (conv, FFN, neuron).
    """

    neuron: MinMaxNeuronConfig
    conv: Union[BasicConvConfig, GatedConvConfig]
    d_model: int
    first_in_dropout: float = 0.0  # only consulted by the first layer
    feedforward: Union[FeedForwardConfig, None] = None
    norm: NormType = 'layernorm'
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class MinMaxLayer(nn.Module):
    """
    One residual layer of the MinMax RNC backbone.

    Internal data flow (every sub-layer has its own pre-norm):

        conv_out = Conv( norm_conv(u) )                # short-range context
        ffn_out  = FFN( norm_ffn(u + conv_out) )
        neur_out = Neuron( norm_neuron(u + ffn_out) )
        output   = u + neur_out

    conv_out and ffn_out reach the output only through the neuron; the residual
    stream itself only receives neur_out.

    Parameters
    ----------
    cfg : MinMaxLayerConfig
        Layer configuration.
    first : bool
        True for the first layer of a stack. Its FFN dropout is then taken from
        ``cfg.first_in_dropout`` instead of ``cfg.feedforward.dropout``.

    Raises
    ------
    ValueError
        If ``cfg.feedforward`` is None (the FFN sub-layer is required).
    """

    def __init__(self, cfg: MinMaxLayerConfig, first: bool):
        super().__init__()

        self.cfg = cfg

        # isinstance (not an exact type comparison) so subclasses of
        # BasicConvConfig also get a BasicConv; any other config -> GatedConv.
        if isinstance(cfg.conv, BasicConvConfig):
            self.conv = BasicConv(cfg.conv)
        else:
            self.conv = GatedConv(cfg.conv)

        self.neuron = MinMaxNeuron(cfg.neuron)

        self.use_ffn = cfg.feedforward is not None
        if not self.use_ffn:
            # Explicit raise rather than `assert`, which `python -O` strips.
            raise ValueError("MinMaxLayer requires cfg.feedforward (got None)")

        # The first layer uses its own (input-level) dropout for the FFN.
        ffn_dropout = cfg.first_in_dropout if first else cfg.feedforward.dropout

        # FFN input and output widths are forced to the residual-stream width.
        self.ffn = create_feedforward(
            config=replace(
                cfg.feedforward,
                embedding_dim=cfg.d_model,
                embedding_dim_out=cfg.d_model,
                dropout=ffn_dropout,
            )
        )

        # Norms are registered in the order ffn, neuron, conv; keep this order so
        # the parameter order seen by optimizers and state_dict stays stable.
        self.norm_ffn = self._make_norm(cfg.norm, cfg.d_model)
        self.norm_neuron = self._make_norm(cfg.norm, cfg.d_model)
        self.norm_conv = self._make_norm(cfg.norm, cfg.d_model)

    @staticmethod
    def _make_norm(kind: str, dim: int) -> nn.Module:
        """Build one pre-norm; any kind other than 'layernorm'/'rmsnorm' gives Identity."""
        if kind == 'layernorm':
            return nn.LayerNorm(dim)
        if kind == 'rmsnorm':
            return nn.RMSNorm(dim)
        return nn.Identity()

    @property
    def initial_state(self):
        """Default layer state: ``{'neuron': ..., 'conv': ...}`` from the sub-modules."""
        return {
            'neuron': self.neuron.initial_state,
            'conv': self.conv.initial_state,
        }

    def forward(self, u: torch.Tensor, state: dict):
        """
        Apply the layer to one chunk of a sequence.

        Parameters
        ----------
        u : Tensor
            Residual-stream input; presumably (B, T, d_model) as passed by
            MinMaxRNC — TODO confirm for other callers.
        state : dict
            ``{'conv': ..., 'neuron': ...}`` as returned by ``initial_state`` or
            by a previous call.

        Returns
        -------
        output : Tensor
            ``u + neuron_out``, same shape as ``u``.
        state : dict
            Updated ``{'conv': ..., 'neuron': ...}`` after the last step.
        """
        conv_out, conv_state = self.conv(self.norm_conv(u), state['conv'])

        ffn_out = self.ffn(self.norm_ffn(u + conv_out))

        neuron_out, neuron_state = self.neuron(
            self.norm_neuron(u + ffn_out), state['neuron']
        )

        output = u + neuron_out
        return output, {'conv': conv_state, 'neuron': neuron_state}
|
|
133
|
+
|
|
134
|
+
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2026 Alessandro Ronca
|
|
2
|
+
# SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from .modules.initialisers import wang_init_, small_init_init_
|
|
8
|
+
|
|
9
|
+
from . import minmax_scan
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(frozen=True)
class MinMaxNeuronConfig:
    """
    Configuration for a single MinMax Neuron.

    Fields
    ------
    _num_blocks : int
        Total number of residual blocks in the enclosing model. Passed as
        ``num_blocks`` to wang_init_ when initialising the output projection,
        so the output scale at initialisation depends on model depth.
    d_model : int
        Dimension of the input and output (the residual-stream width).
    d_state : int
        Dimension of the hidden state x_t. Larger values give the neuron
        more memory capacity; the projection parameter count grows linearly.
    dropout : float
        Dropout probability applied to the input u before it is projected
        (the gate projection also sees the dropped-out input).
    train_init : bool
        If True, the initial hidden state x_0 is a learned parameter.
        If False (default), x_0 is fixed at zero.
    output_gate : bool
        If True (default), the hidden state is multiplied element-wise by a
        linear projection of the input before the final projection:
        y = W_o(x ⊙ W_g u). No sigmoid or other nonlinearity is applied to
        the gate.
    """

    _num_blocks: int
    d_model: int
    d_state: int
    dropout: float = 0.0
    train_init: bool = False
    output_gate: bool = True
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class MinMaxNeuron(nn.Module):
    """
    The core recurrent cell of the MinMax RNC.

    Maintains a hidden state x_t ∈ R^D updated by the MinMax recurrence:

        x_{t+1} = max(min(r_t, x_t), s_t)

    where r_t and s_t are linear projections of the (dropout-regularised)
    input u_t. The states of a whole chunk are computed at once by
    ``minmax_scan.all_states``, which implements a parallel prefix scan
    (O(log T) depth instead of O(T)).

    Output projection (the gate is a plain linear projection, no sigmoid):

        y_t = W_o x_t                 (output_gate=False)
        y_t = W_o (x_t ⊙ W_g u_t)     (output_gate=True)
    """

    def __init__(self, cfg: MinMaxNeuronConfig):

        super().__init__()

        self.cfg = cfg

        self.I = I = cfg.d_model  # input/output width
        self.D = D = cfg.d_state  # hidden-state width

        # x_0 is always registered as a parameter; it is only trained when
        # cfg.train_init is True (otherwise it stays at zero).
        self._initial_state = nn.Parameter(torch.zeros(D), requires_grad=cfg.train_init)

        self.drop = nn.Dropout(cfg.dropout)

        self.s = nn.Linear(I, D)  # produces s_t in the recurrence
        self.r = nn.Linear(I, D)  # produces r_t in the recurrence
        self.o = nn.Linear(D, I)  # output projection W_o
        if cfg.output_gate:
            self.o_g = nn.Linear(I, D)  # gate projection W_g (no nonlinearity)

        self.reset()

    def reset(self):
        """
        Re-initialise the s, r and o projections in place.

        NOTE(review): ``o_g`` (when present) and ``_initial_state`` are not
        touched here; ``o_g`` keeps PyTorch's default nn.Linear initialisation.
        """
        # Init 's'
        small_init_init_(self.s.weight, dim=self.I)
        if self.s.bias is not None:
            nn.init.zeros_(self.s.bias)
        # Init 'r'
        small_init_init_(self.r.weight, dim=self.I)
        if self.r.bias is not None:
            nn.init.zeros_(self.r.bias)
        # Init 'o' (scaled by the depth of the enclosing model)
        wang_init_(
            self.o.weight,
            dim=self.I,
            num_blocks=self.cfg._num_blocks,
        )
        if self.o.bias is not None:
            nn.init.zeros_(self.o.bias)

    @property
    def initial_state(self):
        """The (D,) initial hidden state x_0, shared by every sequence in a batch."""
        return self._initial_state

    def forward(self, u: torch.Tensor, state: torch.Tensor):
        """
        Run the neuron over a sequence, starting from ``state``.

        Parameters
        ----------
        u : Tensor (B, T, I)
            Input sequence.
        state : Tensor (D,) or (B, D)
            State before the first step of ``u``: either the shared initial
            state (D,), which is broadcast over the batch, or a per-sequence
            state (B, D) returned by a previous call.

        Returns
        -------
        output : Tensor (B, T, I)
            Sequence of outputs.
        x_latest : Tensor (B, D)
            Hidden state after the last step (before gating), to be passed
            back as ``state`` for the next chunk.
        """
        B, _, _ = u.shape  # unpacking also rejects inputs that are not 3-D
        D = self.D

        if state.dim() == 1:
            # Shared initial state: broadcast (D,) -> (B, D) without copying.
            x0 = state.unsqueeze(0).expand(B, D)
        else:
            x0 = state  # (B, D) carried over from a previous chunk

        u = self.drop(u)  # (B, T, I)
        s = self.s(u)     # (B, T, D)
        r = self.r(u)     # (B, T, D)

        # all_states appears to return T+1 states including x0; keep only the
        # post-update states x_1..x_T. (TODO confirm against minmax_scan.)
        x_post = minmax_scan.all_states(r, s, x0)[:, 1:, :]  # (B, T, D)

        # ----- Compute outputs -----
        x_latest = x_post[:, -1, :]  # (B, D), taken before gating
        if self.cfg.output_gate:
            x_post = x_post * self.o_g(u)  # (B, T, D)
        output = self.o(x_post)  # (B, T, I)

        return output, x_latest
|
|
148
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2026 Alessandro Ronca
|
|
2
|
+
# SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def apply(a: torch.Tensor, b: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the element-wise MinMax operator f(x) = max(min(a, x), b).

    ``a`` caps ``x`` from above first, then ``b`` lifts the result from below,
    so wherever b > a the output equals b regardless of x. The three tensors
    are combined element-wise and only need to be mutually broadcastable
    (typically all of shape (..., D)).
    """
    capped = torch.minimum(x, a)
    return torch.maximum(capped, b)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def compose(a2: torch.Tensor, b2: torch.Tensor, a1: torch.Tensor, b1: torch.Tensor):
    """
    Compose two MinMax scalar operators into one.

    With f1(x) = max(min(a1, x), b1) and f2(x) = max(min(a2, x), b2), all
    tensors of shape (..., D), return ``(a, b)`` such that

        f2(f1(x)) = max(min(a, x), b)

    where

        a = min(a2, a1)
        b = max(min(a2, b1), b2)     # i.e. b = f2(b1)

    The caps combine by taking the tighter one, and the new floor is the old
    floor b1 pushed through f2.
    """
    floor = torch.maximum(torch.minimum(b1, a2), b2)
    cap = torch.minimum(a1, a2)
    return cap, floor
|
|
38
|
+
|
|
39
|
+
|
minmaxrnc/minmax_rnc.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2026 Alessandro Ronca
|
|
2
|
+
# SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
from typing import Optional, Literal, Union
|
|
8
|
+
from dataclasses import dataclass, replace
|
|
9
|
+
|
|
10
|
+
from .minmax_layer import MinMaxLayer, MinMaxLayerConfig, NormType
|
|
11
|
+
from .minmax_neuron import MinMaxNeuronConfig
|
|
12
|
+
from .modules.feedforward import FeedForwardConfig, FFType, InitType, create_feedforward
|
|
13
|
+
from .modules.basic_conv import BasicConvConfig
|
|
14
|
+
from .modules.gated_conv import GatedConvConfig
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Allowed values for MinMaxRNCConfig.conv_type. In MinMaxRNCConfig.layer_cfg,
# 'basic' selects BasicConvConfig and any other value selects GatedConvConfig.
ConvType = Literal['basic', 'gated']
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class MinMaxRNCConfig:
    """
    Configuration for the MinMax RNC backbone.

    Only top-level hyper-parameters live here; the configs for neuron, conv,
    FFN and layer are derived from them by the ``layer_cfg`` property.

    Core architecture
    -----------------
    d_model : int
        Residual-stream width. Every sub-module input and output has this
        dimension.
    n_layers : int
        Number of stacked MinMaxLayers. Also passed as ``_num_blocks`` to the
        neuron and FFN configs, where it is used for initialisation scaling.
    d_state : int
        Hidden-state dimension of each MinMax Neuron. Independent of d_model;
        larger values increase memory capacity at linear parameter cost.

    Normalisation
    -------------
    norm : 'layernorm' | 'rmsnorm' | 'none'
        Pre-norm type applied before each sub-layer inside each layer, and
        before the optional post-layers FFN. 'layernorm' is the default;
        'none' disables normalisation (identity).
    postlayers_norm : 'layernorm' | 'rmsnorm' | 'none'
        Norm applied to the backbone output: after the last layer and after
        the optional post-layers FFN. 'none' skips it.

    Feed-forward network (within each layer)
    -----------------------------------------
    ffn_type : 'gated' | 'basic'
        'gated' (default) — gated FFN (ReGLU / SwiGLU depending on
        act_fn), from Shazeer (2020). 'basic' — standard two-layer MLP.
    ffn_proj_factor : float
        Hidden-layer expansion factor relative to d_model. The hidden
        dimension is rounded to the nearest multiple of 2.
    ffn_act_fn : str
        Activation function name. Choices: 'relu', 'relu^2', 'gelu',
        'swish', 'sigmoid', 'selu'.
    ffn_dropout : float
        Dropout applied inside the FFN of every layer except the first (which
        always uses prelayers_dropout), and inside the post-layers FFN.
    ffn_init : 'default' | 'scaled'
        Weight initialisation scheme. 'scaled' uses small_init for the
        up-projection and wang_init for the down-projection.

    Neuron
    ------
    output_gate : bool
        If True, the neuron state is multiplied element-wise by a learned
        linear projection of the input (no sigmoid) before the neuron's
        output projection.
    train_init : bool
        If True, the neuron's initial hidden state x_0 is a learned parameter.
    neuron_dropout : float
        Dropout probability applied to the neuron input.

    Convolution
    -----------
    conv_type : 'gated' | 'basic'
        'gated' (default) — learned scalar gate interpolating between
        the previous and current token. 'basic' — learned linear mixing of
        the previous and current token representations.
    conv_init_val : float
        Initial value of the gate logit in GatedConv. 0.0 → gate ≈ 0.5
        (equal mix); negative values bias toward the current token.

    Pre/Post-layers
    ---------------
    prelayers_dropout : float
        FFN dropout for the first layer only; replaces ffn_dropout there.
        Useful as an input-level regulariser without penalising deeper layers.
    use_postlayers_ffn : bool
        If True, an extra residual FFN (with the same config as the in-layer
        FFN, pre-normed according to ``norm``) is applied after all layers,
        before postlayers_norm.

    Notes
    -----
    The chunk length of the forward pass (``unroll_steps``) is not a config
    field; it is an argument of ``MinMaxRNC.forward``.
    """

    # Core architecture
    d_model: int
    n_layers: int
    d_state: int

    # Normalisation (within layers and post-layers)
    norm: NormType = 'layernorm'
    postlayers_norm: NormType = 'layernorm'

    # FFN within each layer
    ffn_type: FFType = 'gated'
    ffn_proj_factor: float = 1.3
    ffn_act_fn: str = 'relu'
    ffn_dropout: float = 0.0
    ffn_init: InitType = 'scaled'

    # Neuron
    output_gate: bool = True
    train_init: bool = False
    neuron_dropout: float = 0.0

    # Conv
    conv_type: ConvType = 'gated'
    conv_init_val: float = 0.0

    # Per-layer options
    prelayers_dropout: float = 0.0

    # Post-layers
    use_postlayers_ffn: bool = False

    @property
    def layer_cfg(self) -> MinMaxLayerConfig:
        """Derive the (shared) per-layer config; a new object is built on every access."""
        neuron_cfg = MinMaxNeuronConfig(
            _num_blocks = self.n_layers,
            d_model = self.d_model,
            d_state = self.d_state,
            dropout = self.neuron_dropout,
            train_init = self.train_init,
            output_gate = self.output_gate,
        )
        # Any conv_type other than 'basic' selects the gated convolution.
        if self.conv_type == 'basic':
            conv_cfg = BasicConvConfig(embedding_dim=self.d_model)
        else:
            conv_cfg = GatedConvConfig(
                embedding_dim = self.d_model,
                init_val = self.conv_init_val,
            )
        ffn_cfg = FeedForwardConfig(
            _num_blocks = self.n_layers,
            ffn_type = self.ffn_type,
            proj_factor = self.ffn_proj_factor,
            act_fn = self.ffn_act_fn,
            dropout = self.ffn_dropout,
            init = self.ffn_init,
        )
        return MinMaxLayerConfig(
            d_model = self.d_model,
            neuron = neuron_cfg,
            conv = conv_cfg,
            feedforward = ffn_cfg,
            norm = self.norm,
            first_in_dropout = self.prelayers_dropout,
        )

    # ------------------------------------------------------------------
    # Preset factories
    # ------------------------------------------------------------------

    @classmethod
    def small(cls, n_layers: int = 2) -> 'MinMaxRNCConfig':
        """Preset: d_model=90, d_state=40."""
        return cls(d_model=90, n_layers=n_layers, d_state=40)

    @classmethod
    def medium(cls, n_layers: int = 8) -> 'MinMaxRNCConfig':
        """Preset: d_model=512, d_state=512."""
        return cls(d_model=512, n_layers=n_layers, d_state=512)

    @classmethod
    def large(cls, n_layers: int = 12) -> 'MinMaxRNCConfig':
        """Preset: d_model=728, d_state=1456."""
        return cls(d_model=728, n_layers=n_layers, d_state=1456)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class MinMaxRNC(nn.Module):
    """
    MinMax Recurrent Neural Cascade — the backbone sequence model.

    Stacks ``cfg.n_layers`` MinMaxLayers, each containing a short-range
    convolution, a feed-forward network, and a MinMax Neuron, all with
    pre-norm. After the last layer, an optional residual FFN
    (``cfg.use_postlayers_ffn``) and then an optional norm
    (``cfg.postlayers_norm``) are applied.

    Inputs
    ------
    u : Tensor (B, T, d_model)
        Continuous input sequence (e.g. token embeddings).
    unroll_steps : int | None
        Chunk length. The sequence is split along time into chunks of this
        length, which are processed one after another while the recurrent
        state is carried across chunks. unroll_steps=1 processes one token at
        a time; None (default) processes the whole sequence as one chunk.
        The result is intended to be the same for any chunk length; larger
        chunks use more peak memory. Must be >= 1 when given.
    state : list[dict] | None
        Per-layer recurrent state from a previous call. Pass None (or omit)
        to start from the default initial state.
    return_state : bool
        If True, also return the updated state after the last token.

    Outputs
    -------
    y : Tensor (B, T, d_model)
    state : list[dict] — only when return_state=True

    Raises
    ------
    ValueError
        If ``unroll_steps`` is given and smaller than 1.
    """

    def __init__(self, cfg: MinMaxRNCConfig):
        super().__init__()
        self.__cfg = cfg
        self.reset()

    def reset(self):
        """
        Build (or rebuild) every sub-module from the config.

        Fresh modules are created each time, so parameter references obtained
        earlier (e.g. by an optimizer) no longer belong to the model afterwards.
        """
        layer_cfg = self.__cfg.layer_cfg

        # Only the first layer gets first=True (and hence prelayers_dropout).
        self.layers = nn.ModuleList(
            MinMaxLayer(layer_cfg, first=(i == 0))
            for i in range(self.__cfg.n_layers)
        )

        self.postlayers_norm = None
        if self.__cfg.postlayers_norm == 'layernorm':
            self.postlayers_norm = nn.LayerNorm(self.__cfg.d_model)
        elif self.__cfg.postlayers_norm == 'rmsnorm':
            self.postlayers_norm = nn.RMSNorm(self.__cfg.d_model)

        self.postlayers_ffn = None
        self.postlayers_ffn_norm = None
        if self.__cfg.use_postlayers_ffn:
            # Same FFN config as inside the layers (including ffn_dropout).
            self.postlayers_ffn = create_feedforward(
                config=replace(
                    layer_cfg.feedforward,
                    embedding_dim = self.__cfg.d_model,
                    embedding_dim_out = self.__cfg.d_model,
                )
            )
            # Its pre-norm follows the in-layer `norm` setting.
            if self.__cfg.norm == 'layernorm':
                self.postlayers_ffn_norm = nn.LayerNorm(self.__cfg.d_model)
            elif self.__cfg.norm == 'rmsnorm':
                self.postlayers_ffn_norm = nn.RMSNorm(self.__cfg.d_model)
            else:
                self.postlayers_ffn_norm = nn.Identity()

    @property
    def initial_state(self):
        """Default state: one ``{'neuron': ..., 'conv': ...}`` dict per layer."""
        return [layer.initial_state for layer in self.layers]

    def _parallel_forward(self, u: torch.Tensor, state):
        """Process one chunk. u: [B, T, D] — returns output [B, T, D] and updated state."""
        updated_state = []
        y = u
        for layer, layer_state in zip(self.layers, state):
            y, updated_layer_state = layer(y, layer_state)
            updated_state.append(updated_layer_state)

        # Optional residual FFN, then optional output norm (in that order).
        if self.postlayers_ffn is not None:
            y = y + self.postlayers_ffn(self.postlayers_ffn_norm(y))

        if self.postlayers_norm is not None:
            y = self.postlayers_norm(y)

        return y, updated_state

    def forward(self, u: torch.Tensor, unroll_steps: Optional[int] = None, state=None,
                return_state: bool = False):
        """Run the backbone over ``u`` in chunks of ``unroll_steps`` (see class docstring)."""
        if unroll_steps is None:
            # Whole sequence as a single chunk (max(..., 1) keeps split() valid).
            unroll_steps = max(u.shape[1], 1)
        elif unroll_steps < 1:
            raise ValueError(f"unroll_steps must be a positive integer, got {unroll_steps}")

        if state is None:
            state = self.initial_state

        y_chunks = []
        for u_chunk in u.split(unroll_steps, dim=1):
            # The state returned for one chunk seeds the next.
            y_chunk, state = self._parallel_forward(u_chunk, state)
            y_chunks.append(y_chunk)
        y = torch.cat(y_chunks, dim=1)

        if return_state:
            return y, state
        return y
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2026 Alessandro Ronca
|
|
2
|
+
# SPDX-License-Identifier: PolyForm-Noncommercial-1.0.0
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
from .minmax_rnc import MinMaxRNC, MinMaxRNCConfig
|
|
10
|
+
from .modules.initialisers import small_init_init_
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
class MinMaxRNCLMConfig:
    """
    Settings for the MinMax RNC language model (MinMaxRNC_LM).

    Fields
    ------
    backbone : MinMaxRNCConfig
        Configuration of the underlying MinMaxRNC. Its d_model also fixes the
        width of the token embedding and of the LM head.
    head_dropout : float
        Dropout probability applied to the backbone output just before the LM
        head projection.
    tie_weights : bool
        When True (the default), the LM head reuses the token-embedding matrix
        as its weight, so both share a single parameter.
    """

    backbone: MinMaxRNCConfig
    head_dropout: float = 0.0  # 0.0 disables head dropout
    tie_weights: bool = True   # share embedding and output matrices
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MinMaxRNC_LM(MinMaxRNC):
    """
    MinMax RNC with a token embedding layer and a language-model head.

    Wraps MinMaxRNC with:
    - a token embedding (vocab_size × d_model),
    - a dropout before the output projection,
    - a bias-free linear LM head (d_model → vocab_size), optionally tied to
      the embedding.

    Parameters
    ----------
    vocab_size : int
        Number of tokens in the vocabulary.
    cfg : MinMaxRNCLMConfig
        LM configuration; ``cfg.backbone`` configures the MinMaxRNC backbone.

    Inputs
    ------
    tokens : LongTensor (B, T)
        Token indices in [0, vocab_size).
    unroll_steps : int
        Chunk length, forwarded to ``MinMaxRNC.forward``.
    state : list[dict] | None
        Recurrent state from a previous call.
    return_state : bool
        If True, also return the updated state.

    Outputs
    -------
    logits : Tensor (B, T, vocab_size)
    state : list[dict] — only when return_state=True
    """

    def __init__(self, vocab_size: int, cfg: MinMaxRNCLMConfig):
        # Set before super().__init__(): MinMaxRNC.__init__ calls self.reset(),
        # which resolves to the override below and reads these attributes.
        self.__lm_cfg = cfg
        self.__vocab_size = vocab_size
        # That reset() call already builds the backbone *and* the embedding/head,
        # so they are not built again here (which would allocate and randomly
        # initialise them a second time).
        super().__init__(cfg.backbone)

    def reset(self):
        """Rebuild the backbone, then the embedding, LM head and head dropout."""
        super().reset()
        self.__lm_reset()

    def __lm_reset(self):
        """Create and initialise the token embedding, LM head and head dropout."""
        d_model = self.__lm_cfg.backbone.d_model
        self.token_emb = nn.Embedding(self.__vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, self.__vocab_size, bias=False)
        self.head_drop = nn.Dropout(self.__lm_cfg.head_dropout)

        small_init_init_(self.token_emb.weight, dim=d_model)
        if self.__lm_cfg.tie_weights:
            # Both weights are (vocab_size, d_model), so the head can share the
            # embedding Parameter directly.
            self.lm_head.weight = self.token_emb.weight
        else:
            small_init_init_(self.lm_head.weight, dim=d_model)

    def forward(self, tokens: torch.Tensor, unroll_steps: int, state=None, return_state: bool = False):
        """Embed ``tokens``, run the backbone, and project to vocabulary logits."""
        # Always ask the backbone for its state; it is dropped below if unwanted.
        y, state = super().forward(
            self.token_emb(tokens), unroll_steps, state=state, return_state=True
        )
        logits = self.lm_head(self.head_drop(y))
        if return_state:
            return logits, state
        return logits
|