neural-ssm 0.1.7 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- /dev/null
+++ b/PKG-INFO
@@ -0,0 +1,75 @@
+ Metadata-Version: 2.4
+ Name: neural-ssm
+ Version: 0.1.7
+ Summary: Neural state space models and LRU variants in PyTorch
+ Author: Leonardo Massai
+ License: MIT
+ Project-URL: Homepage, https://github.com/LeoMassai
+ Project-URL: Repository, https://github.com/LeoMassai/neural-ssm
+ Project-URL: Issues, https://github.com/LeoMassai/neural-ssm/issues
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Requires-Python: >=3.9
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch>=2.2
+ Requires-Dist: numpy>=1.23
+ Requires-Dist: scipy>=1.10
+ Requires-Dist: h5py>=3.9
+ Requires-Dist: tqdm>=4.66
+ Requires-Dist: matplotlib>=3.7
+ Requires-Dist: einops>=0.6
+ Requires-Dist: typing-extensions>=4.5
+ Requires-Dist: jax>=0.4.26
+ Requires-Dist: jaxlib>=0.4.26
+ Requires-Dist: deel-torchlip
+ Provides-Extra: dev
+ Requires-Dist: black>=24.3.0; extra == "dev"
+ Requires-Dist: ruff>=0.3.0; extra == "dev"
+ Requires-Dist: mypy>=1.6.0; extra == "dev"
+ Requires-Dist: pytest>=7.4; extra == "dev"
+ Requires-Dist: pytest-cov>=4.1; extra == "dev"
+ Requires-Dist: ipykernel>=6.29; extra == "dev"
+ Requires-Dist: jupyter>=1.0; extra == "dev"
+ Provides-Extra: examples
+ Requires-Dist: seaborn>=0.13; extra == "examples"
+ Requires-Dist: pandas>=2.1; extra == "examples"
+ Provides-Extra: docs
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
+ Requires-Dist: mkdocs-material>=9.5; extra == "docs"
+
+ # PyTorch L2RU Architecture: LRU with L2 stability guarantees and a prescribed bound
+
+ A PyTorch implementation of the L2RU architecture introduced in the paper [Free Parametrization of L2-bounded State Space Models](https://arxiv.org/abs/2503.23818). An application to system identification is included as an example.
+
+ ## L2RU block
+ The L2RU block is a discrete-time linear time-invariant system implemented in state-space form as:
+ ```math
+ \begin{align}
+ x_{k+1} &= A x_k + B u_k,\\
+ y_k &= C x_k + D u_k,
+ \end{align}
+ ```
+ A parametrization is provided for the matrices `(A, B, C, D)`, guaranteeing a prescribed L2 bound for the overall SSM.
+ Moreover, the use of [parallel scan algorithms](https://en.wikipedia.org/wiki/Prefix_sum) makes execution very fast on modern hardware whenever the computation is not core-bound.
+
+ ## Deep L2RU Architecture
+
+ L2RU units are typically stacked in a deep LRU-style architecture:
+
+ <div align="center">
+ <img src="architecture/L2RU.png" alt="Deep L2RU architecture" width="800">
+ </div>
+
--- /dev/null
+++ b/README.md
@@ -0,0 +1,28 @@
+ # PyTorch L2RU Architecture: LRU with L2 stability guarantees and a prescribed bound
+
+ A PyTorch implementation of the L2RU architecture introduced in the paper [Free Parametrization of L2-bounded State Space Models](https://arxiv.org/abs/2503.23818). An application to system identification is included as an example.
+
+ ## L2RU block
+ The L2RU block is a discrete-time linear time-invariant system implemented in state-space form as:
+ ```math
+ \begin{align}
+ x_{k+1} &= A x_k + B u_k,\\
+ y_k &= C x_k + D u_k,
+ \end{align}
+ ```
+ A parametrization is provided for the matrices `(A, B, C, D)`, guaranteeing a prescribed L2 bound for the overall SSM.
+ Moreover, the use of [parallel scan algorithms](https://en.wikipedia.org/wiki/Prefix_sum) makes execution very fast on modern hardware whenever the computation is not core-bound.
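+
+ Concretely, writing γ for the prescribed bound, an L2 bound means the standard finite-L2-gain inequality holds (stated here for zero initial state):
+ ```math
+ \sum_{k=0}^{T} \|y_k\|^2 \;\le\; \gamma^2 \sum_{k=0}^{T} \|u_k\|^2 \qquad \text{for all } T \ge 0 \text{ and all inputs } u.
+ ```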
+
+ ## Deep L2RU Architecture
+
+ L2RU units are typically stacked in a deep LRU-style architecture, as sketched below:
+
+ <div align="center">
+ <img src="architecture/L2RU.png" alt="Deep L2RU architecture" width="800">
+ </div>
+
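+ ## Quick start (illustrative)
+
+ A minimal usage sketch. `L2RU` is exported at the package top level; the constructor arguments shown here (sizes and a prescribed gain) are illustrative assumptions, so check the class definition for the exact signature:
+
+ ```python
+ import torch
+ from neural_ssm import L2RU
+
+ # Hypothetical constructor arguments -- consult the L2RU class for the real ones.
+ block = L2RU(dim_in=1, dim_out=1, dim_internal=16, gamma=2.0)
+
+ u = torch.randn(8, 100, 1)  # (batch, time, input) sequence
+ y = block(u)                # output sequence with L2 gain at most gamma
+ ```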
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,107 @@
+ [build-system]
+ requires = ["setuptools>=68", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "neural-ssm"
+ version = "0.1.7"
+ description = "Neural state space models and LRU variants in PyTorch"
+ readme = "README.md"
+ requires-python = ">=3.9"
+ license = { text = "MIT" }
+ authors = [
+     { name = "Leonardo Massai" }
+ ]
+ classifiers = [
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3 :: Only",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "License :: OSI Approved :: MIT License",
+     "Operating System :: OS Independent",
+     "Intended Audience :: Science/Research",
+     "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ ]
+ dependencies = [
+     # Core
+     "torch>=2.2",
+     "numpy>=1.23",
+     # SciPy for linear algebra and scipy.io.loadmat (MAT files in Test_files)
+     "scipy>=1.10",
+     # Optional MAT v7.3 support via HDF5 (some .mat files may require it)
+     "h5py>=3.9",
+     # Utilities commonly used in training/eval and plotting
+     "tqdm>=4.66",
+     "matplotlib>=3.7",
+     # Tensor/array rearrangements used in layers
+     "einops>=0.6",
+     # Typing backports for older Python versions and PyTorch type hints
+     "typing-extensions>=4.5",
+     # JAX (CPU by default; GPU wheels are user specific)
+     "jax>=0.4.26",
+     "jaxlib>=0.4.26",
+     "deel-torchlip",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "black>=24.3.0",
+     "ruff>=0.3.0",
+     "mypy>=1.6.0",
+     "pytest>=7.4",
+     "pytest-cov>=4.1",
+     "ipykernel>=6.29",
+     "jupyter>=1.0",
+ ]
+ examples = [
+     "seaborn>=0.13",
+     "pandas>=2.1",
+ ]
+ docs = [
+     "mkdocs>=1.5",
+     "mkdocs-material>=9.5",
+ ]
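+ # Extras install with standard pip extras syntax, e.g. `pip install "neural-ssm[dev]"`
+ # (combine as "neural-ssm[dev,docs]" if needed).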
+
+ [project.urls]
+ Homepage = "https://github.com/LeoMassai"
+ Repository = "https://github.com/LeoMassai/neural-ssm"
+ Issues = "https://github.com/LeoMassai/neural-ssm/issues"
+
+ [tool.setuptools]
+ include-package-data = true
+ license-files = ["LICEN[CS]E*"]
+
+ [tool.setuptools.packages.find]
+ where = ["src"]
+ include = ["neural_ssm*"]
+
+ [tool.pytest.ini_options]
+ addopts = "-q"
+ testpaths = ["tests", "Test_files"]
+
+ [tool.black]
+ line-length = 100
+ target-version = ["py39", "py310", "py311", "py312"]
+ include = "\\.pyi?$"
+
+ [tool.ruff]
+ line-length = 100
+ target-version = "py39"
+ src = ["src", "tests", "Test_files"]
+ extend-exclude = ["build", "dist", ".venv", ".mypy_cache", ".ruff_cache"]
+
+ # select/ignore live under [tool.ruff.lint] on ruff >= 0.3 (top-level is deprecated)
+ [tool.ruff.lint]
+ select = ["E", "F", "I", "UP", "B"]
+ ignore = []
+
+ [tool.mypy]
+ python_version = "3.9"
+ ignore_missing_imports = true
+ warn_unused_ignores = true
+ warn_return_any = true
+ warn_unused_configs = true
+ no_implicit_optional = true
+ strict_equality = true
+ show_error_codes = true
+ pretty = true
+ mypy_path = ["src"]
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+
--- /dev/null
+++ b/src/neural_ssm/__init__.py
@@ -0,0 +1,42 @@
+ # file: src/neural_ssm/__init__.py
+ from importlib import import_module
+
+ # Re-export subpackages for discoverability
+ from . import ssm as ssm
+ from . import rens as rens
+ from . import static_layers as layers  # public alias
+
+ # Top-level classes and configs
+ from .ssm.lru import LRU, L2RU, lruz, SSMConfig, SSL, DeepSSM, PureLRUR
+ from .rens.ren import REN
+
+ # Common layers exposed at the top level for convenience (optional)
+ try:
+     from .static_layers.generic_layers import LayerConfig, GLU, MLP, TLIP
+ except ImportError:
+     pass
+ try:
+     from .static_layers.lipschitz_mlps import LMLP
+ except ImportError:
+     pass
+
+ __all__ = [
+     "LRU", "L2RU", "lruz", "SSMConfig", "SSL", "DeepSSM", "PureLRUR",
+     "REN",
+     "layers", "ssm", "rens",
+     "LayerConfig", "GLU", "MLP", "LMLP", "TLIP",
+ ]
+
+ __version__ = "0.1.7"
+
+ def __getattr__(name):
+     # Optional lazy/compat shims; keep internals movable.
+     # Note: import_module must stay bound at module scope, since this
+     # function runs after module initialization.
+     redirects = {
+         "layers": "neural_ssm.static_layers",
+     }
+     if name in redirects:
+         return import_module(redirects[name])
+     raise AttributeError(f"module 'neural_ssm' has no attribute {name!r}")
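+
+ # Illustrative note on the alias above: `neural_ssm.layers` and
+ # `neural_ssm.static_layers` name the same subpackage object, e.g.
+ #   import neural_ssm
+ #   assert neural_ssm.layers is neural_ssm.static_layers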
--- /dev/null
+++ b/src/neural_ssm/rens/__init__.py
@@ -0,0 +1,5 @@
+ # file: src/neural_ssm/rens/__init__.py
+ from .ren import REN
+
+ __all__ = ["REN"]
--- /dev/null
+++ b/src/neural_ssm/rens/ren.py
@@ -0,0 +1,210 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ # Robust REN implementation in the acyclic version
+ class REN(nn.Module):
+     # Implementation of the REN model, modified from "Recurrent Equilibrium
+     # Networks: Flexible Dynamic Models with Guaranteed Stability and
+     # Robustness" by Max Revay et al.
+     def __init__(self, dim_in: int, dim_out: int, dim_internal: int, dim_nl: int,
+                  initialization_std: float = 0.5, internal_state_init=None,
+                  gammat=None, mode="l2stable", Q=None, R=None, S=None,
+                  posdef_tol: float = 0.001):
+         super().__init__()
+
+         # set dimensions
+         self.dim_in = dim_in              # input dimension m
+         self.dim_internal = dim_internal  # state dimension n
+         self.dim_nl = dim_nl              # dimension l of v(t) and w(t)
+         self.dim_out = dim_out            # output dimension p
+
+         self.mode = mode
+         self.gammat = gammat
+         self.epsilon = posdef_tol
+         # # # # # # # # # IQC specification # # # # # # # # #
+         self.Q = Q
+         self.R = R
+         self.S = S
+         # # # # # # # # # Training parameters # # # # # # # # #
+         # shapes of the trainable matrices
+         self.X_shape = (2 * dim_internal + dim_nl, 2 * dim_internal + dim_nl)
+         self.Y_shape = (dim_internal, dim_internal)
+         # nn state dynamics
+         self.B2_shape = (dim_internal, dim_in)
+         # nn output
+         self.C2_shape = (dim_out, dim_internal)
+         self.D22_shape = (dim_out, dim_in)
+         # v signal
+         self.D12_shape = (dim_nl, dim_in)
+         self.Z3_shape = (abs(dim_out - dim_in), min(dim_out, dim_in))
+         self.X3_shape = (min(dim_out, dim_in), min(dim_out, dim_in))
+         self.Y3_shape = (min(dim_out, dim_in), min(dim_out, dim_in))
+         self.gamma_shape = (1, 1)
+
+         self.training_param_names = ['X', 'Y', 'B2', 'C2', 'Z3', 'X3', 'Y3', 'D12']
+
+         # optionally make gamma trainable
+         if self.gammat is None:
+             self.training_param_names.append('gamma')
+         else:
+             self.gamma = gammat
+
+         # define the trainable params
+         self._init_trainable_params(initialization_std)
+
+         # # # # # # # # # Non-trainable parameters and constant tensors # # # # # # # # #
+         # masks
+         self.register_buffer('eye_mask_min', torch.eye(min(dim_in, dim_out)))
+         self.register_buffer('eye_mask_dim_in', torch.eye(dim_in))
+         self.register_buffer('eye_mask_dim_out', torch.eye(dim_out))
+         self.register_buffer('eye_mask_dim_state', torch.eye(dim_internal))
+         self.register_buffer('eye_mask_H', torch.eye(2 * dim_internal + dim_nl))
+         self.register_buffer('zeros_mask_S', torch.zeros(dim_in, dim_out))
+         self.register_buffer('zeros_mask_Q', torch.zeros(dim_out, dim_out))
+         self.register_buffer('zeros_mask_R', torch.zeros(dim_in, dim_in))
+         self.register_buffer('zeros_mask_so', torch.zeros(dim_internal, dim_out))
+         self.register_buffer('eye_mask_w', torch.eye(dim_nl))
+         self.register_buffer('D21', torch.zeros(dim_out, dim_nl))
+
+         # initialize the internal state on the default device (a hard-coded
+         # CUDA device would break CPU-only runs)
+         if internal_state_init is None:
+             self.x = torch.zeros(1, 1, self.dim_internal)
+         else:
+             assert isinstance(internal_state_init, torch.Tensor)
+             self.x = internal_state_init.reshape(1, 1, self.dim_internal)
+         self.register_buffer('init_x', self.x.detach().clone())
+
+         # auxiliary elements
+         self.set_param()
+
+     def set_param(self, gamman=None):
+         if gamman is not None:
+             self.gamma = gamman
+         gamma = torch.abs(self.gamma)
+         dim_internal, dim_nl, dim_in, dim_out = self.dim_internal, self.dim_nl, self.dim_in, self.dim_out
+
+         # update Q, S, R, which depend on gamma when gamma is trainable
+         self.Q, self.R, self.S = self._set_mode(self.mode, gamma, self.Q, self.R, self.S)
+         M = (F.linear(self.X3.T, self.X3.T) + self.Y3 - self.Y3.T
+              + F.linear(self.Z3.T, self.Z3.T) + self.epsilon * self.eye_mask_min)
+         if dim_out >= dim_in:
+             N = torch.vstack((
+                 F.linear(self.eye_mask_dim_in - M, torch.inverse(self.eye_mask_dim_in + M).T),
+                 -2 * F.linear(self.Z3, torch.inverse(self.eye_mask_dim_in + M).T),
+             ))
+         else:
+             N = torch.hstack((
+                 F.linear(torch.inverse(self.eye_mask_dim_out + M), (self.eye_mask_dim_out - M).T),
+                 -2 * F.linear(torch.inverse(self.eye_mask_dim_out + M), self.Z3),
+             ))
+
+         Lq = torch.linalg.cholesky(-self.Q).T
+         Lr = torch.linalg.cholesky(self.R - self.S @ torch.inverse(self.Q) @ self.S.T).T
+         self.D22 = -torch.inverse(self.Q) @ self.S.T + torch.inverse(Lq) @ N @ Lr
+         # calculate psi_r:
+         R_cal = (self.R + self.S @ self.D22 + (self.S @ self.D22).T
+                  + self.D22.T @ self.Q @ self.D22)
+         R_cal_inv = torch.inverse(R_cal)
+         C2_cal = (self.D22.T @ self.Q + self.S) @ self.C2
+         D21_cal = (self.D22.T @ self.Q + self.S) @ self.D21 - self.D12.T
+         vec_r = torch.cat((C2_cal.T, D21_cal.T, self.B2), dim=0)
+         psi_r = vec_r @ R_cal_inv @ vec_r.T
+         # calculate psi_q:
+         vec_q = torch.cat((self.C2.T, self.D21.T, self.zeros_mask_so), dim=0)
+         psi_q = vec_q @ self.Q @ vec_q.T
+         # create the H matrix:
+         H = self.X.T @ self.X + self.epsilon * self.eye_mask_H + psi_r - psi_q
+         h1, h2, h3 = torch.split(H, [dim_internal, dim_nl, dim_internal], dim=0)
+         H11, H12, H13 = torch.split(h1, [dim_internal, dim_nl, dim_internal], dim=1)
+         H21, H22, _ = torch.split(h2, [dim_internal, dim_nl, dim_internal], dim=1)
+         H31, H32, H33 = torch.split(h3, [dim_internal, dim_nl, dim_internal], dim=1)
+         self.P_cal = H33
+         # NN state dynamics:
+         self.F = H31
+         self.B1 = H32
+         # NN output:
+         self.E = 0.5 * (H11 + self.P_cal + self.Y - self.Y.T)
+         # v signal (change the following two lines for a non-strictly-acyclic REN):
+         self.Lambda = 0.5 * torch.diag(H22)
+         self.D11 = -torch.tril(H22, diagonal=-1)
+         self.C1 = -H21
+         # matrix P (unused here):
+         # self.P = self.E.T @ torch.inverse(self.P_cal) @ self.E
+
+     def forward(self, u):
+         batch_size = u.shape[0]
+         w = torch.zeros(batch_size, 1, self.dim_nl, device=u.device)
+         # update each entry of w using Eq. (8); D11 is strictly lower
+         # triangular, so the entries can be resolved sequentially
+         for i in range(self.dim_nl):
+             # v is element i of v(t), with shape (batch_size, 1)
+             v = F.linear(self.x, self.C1[i, :]) + F.linear(w, self.D11[i, :]) + F.linear(u, self.D12[i, :])
+             w = w + (self.eye_mask_w[i, :] * torch.tanh(v / self.Lambda[i])).reshape(batch_size, 1, self.dim_nl)
+
+         # compute the next state using Eq. (18)
+         self.x = F.linear(
+             F.linear(self.x, self.F) + F.linear(w, self.B1) + F.linear(u, self.B2),
+             self.E.inverse())
+
+         # compute the output
+         y = F.linear(self.x, self.C2) + F.linear(w, self.D21) + F.linear(u, self.D22)
+
+         return y
+
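+     # The (Q, R, S) triple specifies an incremental quadratic constraint on
+     # the input/output pair. Assuming the standard REN-paper convention for
+     # the supply rate,
+     #
+     #   s(u, y) = [y; u]^T [[Q,  S^T],
+     #                       [S,  R  ]] [y; u],
+     #
+     # the "l2stable" branch below (Q = -(1/gamma) I, R = gamma I, S = 0)
+     # corresponds to an l2 gain of at most gamma.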
+     def _set_mode(self, mode, gamma, Q, R, S, eps: float = 1e-4):
+         # Q must be negative definite; if Q is only negative semidefinite we
+         # use Q - eps * I. I.e., the Q defined here is the matrix denoted
+         # \mathcal{Q} in the REN paper.
+         if mode == "l2stable":
+             Q = -(1. / gamma) * self.eye_mask_dim_out
+             R = gamma * self.eye_mask_dim_in
+             S = self.zeros_mask_S
+         elif mode == "input_p":
+             if self.dim_out != self.dim_in:
+                 raise ValueError("Dimensions of u(t) and y(t) need to be the same for enforcing input passivity.")
+             Q = self.zeros_mask_Q - eps * self.eye_mask_dim_out
+             R = -2. * gamma * self.eye_mask_dim_in
+             S = self.eye_mask_dim_out
+         elif mode == "output_p":
+             if self.dim_out != self.dim_in:
+                 raise ValueError("Dimensions of u(t) and y(t) need to be the same for enforcing output passivity.")
+             Q = -2. * gamma * self.eye_mask_dim_out
+             R = self.zeros_mask_R
+             S = self.eye_mask_dim_in
+         else:
+             print("Using the matrices Q, R, S given by the user.")
+             # check dimensions:
+             if not (len(R.shape) == 2 and R.shape[0] == R.shape[1] and R.shape[0] == self.dim_in):
+                 raise ValueError("The matrix R is not valid. It must be a square %ix%i matrix." % (self.dim_in, self.dim_in))
+             if not (len(Q.shape) == 2 and Q.shape[0] == Q.shape[1] and Q.shape[0] == self.dim_out):
+                 raise ValueError("The matrix Q is not valid. It must be a square %ix%i matrix." % (self.dim_out, self.dim_out))
+             if not (len(S.shape) == 2 and S.shape[0] == self.dim_in and S.shape[1] == self.dim_out):
+                 raise ValueError("The matrix S is not valid. It must be a %ix%i matrix." % (self.dim_in, self.dim_out))
+             # check that R is symmetric:
+             if not torch.allclose(R, R.T):
+                 raise ValueError("The matrix R is not valid. It must be symmetric.")
+             # check that Q is negative semidefinite:
+             eigs, _ = torch.linalg.eig(Q)
+             if not bool((eigs.real <= 0).all()):
+                 raise ValueError("The matrix Q is not valid. It must be negative semidefinite.")
+             if not bool((eigs.real < 0).all()):
+                 # make Q negative definite (\mathcal{Q} in the REN paper)
+                 Q = Q - eps * self.eye_mask_dim_out
+         return Q, R, S
+
+     def _init_trainable_params(self, initialization_std):
+         for training_param_name in self.training_param_names:  # e.g., 'X'
+             # read the declared shape of the selected training param, e.g., X_shape
+             shape = getattr(self, training_param_name + '_shape')
+             # gamma uses a larger initialization scale than the other params
+             std = 3 if training_param_name == 'gamma' else initialization_std
+             # register the selected param (e.g., self.X) as an nn.Parameter
+             setattr(self, training_param_name, nn.Parameter(torch.randn(*shape) * std))
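+
+ # Minimal usage sketch (illustrative, based on the signatures above):
+ # forward() consumes one time step at a time, with u of shape
+ # (batch, 1, dim_in), and updates the internal state self.x in place.
+ #
+ #   ren = REN(dim_in=2, dim_out=2, dim_internal=8, dim_nl=4, mode="l2stable")
+ #   u = torch.randn(16, 1, 2)
+ #   y = ren(u)  # shape (16, 1, 2)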
--- /dev/null
+++ b/src/neural_ssm/ssm/__init__.py
@@ -0,0 +1,5 @@
+ # file: src/neural_ssm/ssm/__init__.py
+ from .lru import LRU, L2RU, lruz, SSMConfig, SSL, DeepSSM, PureLRUR
+
+ __all__ = ["LRU", "L2RU", "lruz", "SSMConfig", "SSL", "DeepSSM", "PureLRUR"]