neural-ssm 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neural_ssm-0.1.7/PKG-INFO +75 -0
- neural_ssm-0.1.7/README.md +28 -0
- neural_ssm-0.1.7/pyproject.toml +107 -0
- neural_ssm-0.1.7/setup.cfg +4 -0
- neural_ssm-0.1.7/src/neural_ssm/__init__.py +42 -0
- neural_ssm-0.1.7/src/neural_ssm/rens/__init__.py +5 -0
- neural_ssm-0.1.7/src/neural_ssm/rens/ren.py +210 -0
- neural_ssm-0.1.7/src/neural_ssm/ssm/__init__.py +5 -0
- neural_ssm-0.1.7/src/neural_ssm/ssm/lru.py +780 -0
- neural_ssm-0.1.7/src/neural_ssm/ssm/scan_utils.py +280 -0
- neural_ssm-0.1.7/src/neural_ssm/static_layers/__init__.py +9 -0
- neural_ssm-0.1.7/src/neural_ssm/static_layers/generic_layers.py +56 -0
- neural_ssm-0.1.7/src/neural_ssm/static_layers/lipschitz_mlps.py +141 -0
- neural_ssm-0.1.7/src/neural_ssm.egg-info/PKG-INFO +75 -0
- neural_ssm-0.1.7/src/neural_ssm.egg-info/SOURCES.txt +16 -0
- neural_ssm-0.1.7/src/neural_ssm.egg-info/dependency_links.txt +1 -0
- neural_ssm-0.1.7/src/neural_ssm.egg-info/requires.txt +28 -0
- neural_ssm-0.1.7/src/neural_ssm.egg-info/top_level.txt +1 -0

--- /dev/null
+++ neural_ssm-0.1.7/PKG-INFO
@@ -0,0 +1,75 @@
+Metadata-Version: 2.4
+Name: neural-ssm
+Version: 0.1.7
+Summary: Neural state space models and LRU variants in PyTorch
+Author: Leonardo Massai
+License: MIT
+Project-URL: Homepage, https://github.com/LeoMassai
+Project-URL: Repository, https://github.com/LeoMassai/neural-ssm
+Project-URL: Issues, https://github.com/LeoMassai/neural-ssm/issues
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.9
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: OS Independent
+Classifier: Intended Audience :: Science/Research
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Python: >=3.9
+Description-Content-Type: text/markdown
+Requires-Dist: torch>=2.2
+Requires-Dist: numpy>=1.23
+Requires-Dist: scipy>=1.10
+Requires-Dist: h5py>=3.9
+Requires-Dist: tqdm>=4.66
+Requires-Dist: matplotlib>=3.7
+Requires-Dist: einops>=0.6
+Requires-Dist: typing-extensions>=4.5
+Requires-Dist: jax>=0.4.26
+Requires-Dist: jaxlib>=0.4.26
+Requires-Dist: deel-torchlip
+Provides-Extra: dev
+Requires-Dist: black>=24.3.0; extra == "dev"
+Requires-Dist: ruff>=0.3.0; extra == "dev"
+Requires-Dist: mypy>=1.6.0; extra == "dev"
+Requires-Dist: pytest>=7.4; extra == "dev"
+Requires-Dist: pytest-cov>=4.1; extra == "dev"
+Requires-Dist: ipykernel>=6.29; extra == "dev"
+Requires-Dist: jupyter>=1.0; extra == "dev"
+Provides-Extra: examples
+Requires-Dist: seaborn>=0.13; extra == "examples"
+Requires-Dist: pandas>=2.1; extra == "examples"
+Provides-Extra: docs
+Requires-Dist: mkdocs>=1.5; extra == "docs"
+Requires-Dist: mkdocs-material>=9.5; extra == "docs"
+
+# PyTorch L2RU Architecture: an LRU with L2 stability guarantees and a prescribed bound
+
+A PyTorch implementation of the L2RU architecture introduced in the paper [Free Parametrization of L2-bounded State Space Models](https://arxiv.org/abs/2503.23818). An application to system identification is included as an example.
+
+## L2RU block
+The L2RU block is a discrete-time linear time-invariant system implemented in state-space form as:
+```math
+\begin{align}
+x_{k+1} &= A x_k + B u_k\\
+y_k &= C x_k + D u_k,
+\end{align}
+```
+A parametrization of the matrices `(A, B, C, D)` is provided that guarantees a prescribed L2 bound for the overall SSM.
+Moreover, the use of [parallel scan algorithms](https://en.wikipedia.org/wiki/Prefix_sum) makes execution extremely fast on modern hardware whenever the workload is not core-bound.
+
+## Deep L2RU Architecture
+
+L2RU units are typically stacked in a deep LRU architecture like the one below:
+
+<div align="center">
+<img src="architecture/L2RU.png" alt="Deep L2RU architecture" width="800">
+</div>
+
+
+
+
+
+
--- /dev/null
+++ neural_ssm-0.1.7/README.md
@@ -0,0 +1,28 @@
+# PyTorch L2RU Architecture: an LRU with L2 stability guarantees and a prescribed bound
+
+A PyTorch implementation of the L2RU architecture introduced in the paper [Free Parametrization of L2-bounded State Space Models](https://arxiv.org/abs/2503.23818). An application to system identification is included as an example.
+
+## L2RU block
+The L2RU block is a discrete-time linear time-invariant system implemented in state-space form as:
+```math
+\begin{align}
+x_{k+1} &= A x_k + B u_k\\
+y_k &= C x_k + D u_k,
+\end{align}
+```
+A parametrization of the matrices `(A, B, C, D)` is provided that guarantees a prescribed L2 bound for the overall SSM.
+Moreover, the use of [parallel scan algorithms](https://en.wikipedia.org/wiki/Prefix_sum) makes execution extremely fast on modern hardware whenever the workload is not core-bound.
+
+## Deep L2RU Architecture
+
+L2RU units are typically stacked in a deep LRU architecture like the one below:
+
+<div align="center">
+<img src="architecture/L2RU.png" alt="Deep L2RU architecture" width="800">
+</div>
+
+
+
+
+
+
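
The README's parallel-scan claim rests on the fact that each step of the recurrence above is an affine map x -> a*x + b, and composing two such maps gives another one: (a2, b2) after (a1, b1) is (a2*a1, a2*b1 + b2). Because this composition is associative, T sequential steps collapse into a logarithmic-depth prefix scan. The package ships its own scan utilities in src/neural_ssm/ssm/scan_utils.py; the sketch below is only a minimal, standalone illustration of the composition rule for a diagonal A, not the library's implementation.

```python
import torch

def sequential_states(a, bu):
    # reference loop: x_0 = 0, x_{k+1} = a * x_k + bu_k (diagonal A, elementwise a)
    x, xs = torch.zeros_like(bu[0]), []
    for k in range(bu.shape[0]):
        x = a * x + bu[k]
        xs.append(x)
    return torch.stack(xs)

def scan_states(a, bu):
    # Hillis-Steele inclusive scan over the affine maps (a, b): x -> a * x + b.
    # Each of the log2(T) passes is fully vectorized over the time axis.
    T = bu.shape[0]
    A, B = a.expand_as(bu).clone(), bu.clone()
    shift = 1
    while shift < T:
        # combine each element with the one `shift` steps earlier;
        # positions without a predecessor compose with the identity map (1, 0)
        A_prev = torch.cat([torch.ones_like(A[:shift]), A[:-shift]])
        B_prev = torch.cat([torch.zeros_like(B[:shift]), B[:-shift]])
        A, B = A * A_prev, A * B_prev + B
        shift *= 2
    return B  # B[k] equals x_{k+1}

a = torch.rand(4) * 0.9     # stable diagonal dynamics, entries in (0, 0.9)
bu = torch.randn(128, 4)    # precomputed B @ u_k terms, shape (T, state_dim)
assert torch.allclose(sequential_states(a, bu), scan_states(a, bu), atol=1e-4)
```

On a GPU each scan pass runs in parallel over all time steps, which is where the speedup over the sequential loop comes from.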
--- /dev/null
+++ neural_ssm-0.1.7/pyproject.toml
@@ -0,0 +1,107 @@
+[build-system]
+requires = ["setuptools>=68", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "neural-ssm"
+version = "0.1.7"
+description = "Neural state space models and LRU variants in PyTorch"
+readme = "README.md"
+requires-python = ">=3.9"
+license = { text = "MIT" }
+authors = [
+    { name = "Leonardo Massai" }
+]
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3 :: Only",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+    "Intended Audience :: Science/Research",
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+dependencies = [
+    # Core
+    "torch>=2.2",
+    "numpy>=1.23",
+    # SciPy for linear algebra and scipy.io.loadmat (MAT files in Test_files)
+    "scipy>=1.10",
+    # Optional MAT v7.3 support via HDF5 (some .mat files may require it)
+    "h5py>=3.9",
+    # Utilities commonly used in training/eval and plotting
+    "tqdm>=4.66",
+    "matplotlib>=3.7",
+    # Tensor/array rearrangements if used in layers
+    "einops>=0.6",
+    # Backport typing if running older Python or for PyTorch type hints
+    "typing-extensions>=4.5",
+    # JAX (CPU by default; GPU wheels are user specific)
+    "jax>=0.4.26",
+    "jaxlib>=0.4.26",
+    "deel-torchlip",
+]
+
+[project.optional-dependencies]
+dev = [
+    "black>=24.3.0",
+    "ruff>=0.3.0",
+    "mypy>=1.6.0",
+    "pytest>=7.4",
+    "pytest-cov>=4.1",
+    "ipykernel>=6.29",
+    "jupyter>=1.0",
+]
+examples = [
+    "seaborn>=0.13",
+    "pandas>=2.1",
+]
+docs = [
+    "mkdocs>=1.5",
+    "mkdocs-material>=9.5",
+]
+
+[project.urls]
+Homepage = "https://github.com/LeoMassai"
+Repository = "https://github.com/LeoMassai/neural-ssm"
+Issues = "https://github.com/LeoMassai/neural-ssm/issues"
+
+[tool.setuptools]
+include-package-data = true
+license-files = ["LICEN[CS]E*"]
+
+[tool.setuptools.packages.find]
+where = ["src"]
+include = ["neural_ssm*"]
+
+[tool.pytest.ini_options]
+addopts = "-q"
+testpaths = ["tests", "Test_files"]
+
+[tool.black]
+line-length = 100
+target-version = ["py39", "py310", "py311", "py312"]
+include = "\\.pyi?$"
+
+[tool.ruff]
+line-length = 100
+target-version = "py39"
+src = ["src", "tests", "Test_files"]
+extend-exclude = ["build", "dist", ".venv", ".mypy_cache", ".ruff_cache"]
+select = ["E", "F", "I", "UP", "B"]
+ignore = []
+
+[tool.mypy]
+python_version = "3.9"
+ignore_missing_imports = true
+warn_unused_ignores = true
+warn_return_any = true
+warn_unused_configs = true
+no_implicit_optional = true
+strict_equality = true
+show_error_codes = true
+pretty = true
+mypy_path = ["src"]
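
Since `dev`, `examples`, and `docs` are declared as standard optional-dependency groups, they can be pulled in with pip's extras syntax, e.g. `pip install "neural-ssm[dev]"` or `pip install "neural-ssm[examples,docs]"`.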
--- /dev/null
+++ neural_ssm-0.1.7/src/neural_ssm/__init__.py
@@ -0,0 +1,42 @@
+# python
+# file: src/neural_ssm/__init__.py
+from importlib import import_module as _imp
+
+# Re-export subpackages for discoverability
+from . import ssm as ssm
+from . import rens as rens
+from . import static_layers as layers  # public alias
+
+# Top-level classes and configs
+from .ssm.lru import LRU, L2RU, lruz, SSMConfig, SSL, DeepSSM, PureLRUR
+from .rens.ren import REN
+
+# Common layers exposed at top level for convenience
+try:
+    from .static_layers.generic_layers import LayerConfig, GLU, MLP, TLIP
+except Exception:
+    pass
+try:
+    from .static_layers.lipschitz_mlps import LMLP
+except Exception:
+    pass
+
+__all__ = [
+    "LRU", "L2RU", "lruz", "SSMConfig", "SSL", "DeepSSM", "PureLRUR",
+    "REN",
+    "layers", "ssm", "rens",
+    "LayerConfig", "GLU", "MLP", "LMLP", "TLIP",
+]
+
+__version__ = "0.1.7"
+
+
+def __getattr__(name):
+    # Optional lazy/compat shims; keep internals movable.
+    # Note: _imp must stay bound at module level because this hook runs
+    # at attribute-access time, long after import, so it is not deleted.
+    redirects = {
+        "layers": "neural_ssm.static_layers",
+    }
+    if name in redirects:
+        return _imp(redirects[name])
+    raise AttributeError(f"module neural_ssm has no attribute {name!r}")
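
A quick sketch of the surface this `__init__.py` exposes, assuming the package is installed:

```python
import neural_ssm

# top-level re-exports declared above
from neural_ssm import LRU, L2RU, REN

print(neural_ssm.__version__)
# the public alias and the subpackage are the same module object
print(neural_ssm.layers is neural_ssm.static_layers)  # True
```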
--- /dev/null
+++ neural_ssm-0.1.7/src/neural_ssm/rens/ren.py
@@ -0,0 +1,210 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Robust REN implementation in the acyclic version
+class REN(nn.Module):
+    # Implementation of the REN model, modified from "Recurrent Equilibrium Networks:
+    # Flexible Dynamic Models with Guaranteed Stability and Robustness" by Max Revay et al.
+    def __init__(self, dim_in: int, dim_out: int, dim_internal: int, dim_nl: int,
+                 initialization_std: float = 0.5, internal_state_init=None,
+                 gammat=None, mode="l2stable", Q=None, R=None, S=None,
+                 posdef_tol: float = 0.001):
+        super().__init__()
+
+        # set dimensions
+        self.dim_in = dim_in              # input dimension m
+        self.dim_internal = dim_internal  # state dimension n
+        self.dim_nl = dim_nl              # dimension l of v(t) and w(t)
+        self.dim_out = dim_out            # output dimension p
+
+        self.mode = mode
+        self.gammat = gammat
+        self.epsilon = posdef_tol
+
+        # # # # # # # # # IQC specification # # # # # # # # #
+        self.Q = Q
+        self.R = R
+        self.S = S
+
+        # # # # # # # # # Training parameters # # # # # # # # #
+        # define matrix shapes
+        self.X_shape = (2 * dim_internal + dim_nl, 2 * dim_internal + dim_nl)
+        self.Y_shape = (dim_internal, dim_internal)
+        # nn state dynamics
+        self.B2_shape = (dim_internal, dim_in)
+        # nn output
+        self.C2_shape = (dim_out, dim_internal)
+        self.D22_shape = (dim_out, dim_in)
+        # v signal
+        self.D12_shape = (dim_nl, dim_in)
+        self.Z3_shape = (abs(dim_out - dim_in), min(dim_out, dim_in))
+        self.X3_shape = (min(dim_out, dim_in), min(dim_out, dim_in))
+        self.Y3_shape = (min(dim_out, dim_in), min(dim_out, dim_in))
+        self.gamma_shape = (1, 1)
+
+        self.training_param_names = ['X', 'Y', 'B2', 'C2', 'Z3', 'X3', 'Y3', 'D12']
+
+        # Optionally define a trainable gamma
+        if self.gammat is None:
+            self.training_param_names.append('gamma')
+        else:
+            self.gamma = gammat
+
+        # define trainable params
+        self._init_trainable_params(initialization_std)
+
+        # # # # # # # # # Non-trainable parameters and constant tensors # # # # # # # # #
+        # masks
+        self.register_buffer('eye_mask_min', torch.eye(min(dim_in, dim_out)))
+        self.register_buffer('eye_mask_dim_in', torch.eye(dim_in))
+        self.register_buffer('eye_mask_dim_out', torch.eye(dim_out))
+        self.register_buffer('eye_mask_dim_state', torch.eye(dim_internal))
+        self.register_buffer('eye_mask_H', torch.eye(2 * dim_internal + dim_nl))
+        self.register_buffer('zeros_mask_S', torch.zeros(dim_in, dim_out))
+        self.register_buffer('zeros_mask_Q', torch.zeros(dim_out, dim_out))
+        self.register_buffer('zeros_mask_R', torch.zeros(dim_in, dim_in))
+        self.register_buffer('zeros_mask_so', torch.zeros(dim_internal, dim_out))
+        self.register_buffer('eye_mask_w', torch.eye(dim_nl))
+        self.register_buffer('D21', torch.zeros(dim_out, dim_nl))
+
+        # initialize the internal state on CPU; move the module and the state
+        # together when training on an accelerator
+        if internal_state_init is None:
+            self.x = torch.zeros(1, 1, self.dim_internal)
+        else:
+            assert isinstance(internal_state_init, torch.Tensor)
+            self.x = internal_state_init.reshape(1, 1, self.dim_internal)
+        self.register_buffer('init_x', self.x.detach().clone())
+
+        # Auxiliary elements
+        self.set_param()
+
+    def set_param(self, gamman=None):
+        if gamman is not None:
+            self.gamma = gamman
+        gamma = torch.abs(self.gamma)
+        dim_internal, dim_nl = self.dim_internal, self.dim_nl
+        dim_in, dim_out = self.dim_in, self.dim_out
+
+        # update Q, R, S with the current gamma if needed
+        self.Q, self.R, self.S = self._set_mode(self.mode, gamma, self.Q, self.R, self.S)
+        M = (F.linear(self.X3.T, self.X3.T) + self.Y3 - self.Y3.T
+             + F.linear(self.Z3.T, self.Z3.T) + self.epsilon * self.eye_mask_min)
+        if dim_out >= dim_in:
+            N = torch.vstack((
+                F.linear(self.eye_mask_dim_in - M, torch.inverse(self.eye_mask_dim_in + M).T),
+                -2 * F.linear(self.Z3, torch.inverse(self.eye_mask_dim_in + M).T),
+            ))
+        else:
+            N = torch.hstack((
+                F.linear(torch.inverse(self.eye_mask_dim_out + M), (self.eye_mask_dim_out - M).T),
+                -2 * F.linear(torch.inverse(self.eye_mask_dim_out + M), self.Z3),
+            ))
+
+        Lq = torch.linalg.cholesky(-self.Q).T
+        Lr = torch.linalg.cholesky(self.R - self.S @ torch.inverse(self.Q) @ self.S.T).T
+        self.D22 = -torch.inverse(self.Q) @ self.S.T + torch.inverse(Lq) @ N @ Lr
+
+        # calculate psi_r:
+        R_cal = (self.R + self.S @ self.D22 + (self.S @ self.D22).T
+                 + self.D22.T @ self.Q @ self.D22)
+        R_cal_inv = torch.inverse(R_cal)
+        C2_cal = (self.D22.T @ self.Q + self.S) @ self.C2
+        D21_cal = (self.D22.T @ self.Q + self.S) @ self.D21 - self.D12.T
+        vec_r = torch.cat((C2_cal.T, D21_cal.T, self.B2), dim=0)
+        psi_r = vec_r @ R_cal_inv @ vec_r.T
+        # calculate psi_q:
+        vec_q = torch.cat((self.C2.T, self.D21.T, self.zeros_mask_so), dim=0)
+        psi_q = vec_q @ self.Q @ vec_q.T
+        # create the H matrix:
+        H = torch.matmul(self.X.T, self.X) + self.epsilon * self.eye_mask_H + psi_r - psi_q
+        h1, h2, h3 = torch.split(H, [dim_internal, dim_nl, dim_internal], dim=0)
+        H11, H12, H13 = torch.split(h1, [dim_internal, dim_nl, dim_internal], dim=1)
+        H21, H22, _ = torch.split(h2, [dim_internal, dim_nl, dim_internal], dim=1)
+        H31, H32, H33 = torch.split(h3, [dim_internal, dim_nl, dim_internal], dim=1)
+        self.P_cal = H33
+        # NN state dynamics:
+        self.F = H31
+        self.B1 = H32
+        # NN output:
+        self.E = 0.5 * (H11 + self.P_cal + self.Y - self.Y.T)
+        # v signal [change the following two lines for a REN that is not strictly acyclic]:
+        self.Lambda = 0.5 * torch.diag(H22)
+        self.D11 = -torch.tril(H22, diagonal=-1)
+        self.C1 = -H21
+
+    def forward(self, u):
+        batch_size = u.shape[0]
+        w = torch.zeros(batch_size, 1, self.dim_nl, device=u.device)
+        # update each row of w using Eq. (8), exploiting the lower-triangular D11
+        for i in range(self.dim_nl):
+            # v is element i of v(t), with shape (batch_size, 1)
+            v = F.linear(self.x, self.C1[i, :]) + F.linear(w, self.D11[i, :]) + F.linear(u, self.D12[i, :])
+            w = w + (self.eye_mask_w[i, :] * torch.tanh(v / self.Lambda[i])).reshape(batch_size, 1, self.dim_nl)
+
+        # compute the next state using Eq. (18)
+        self.x = F.linear(
+            F.linear(self.x, self.F) + F.linear(w, self.B1) + F.linear(u, self.B2),
+            self.E.inverse())
+
+        # compute the output
+        y = F.linear(self.x, self.C2) + F.linear(w, self.D21) + F.linear(u, self.D22)
+
+        return y
+
+    def _set_mode(self, mode, gamma, Q, R, S, eps: float = 1e-4):
+        # We need Q negative definite; if Q is merely negative semidefinite we
+        # use Q - eps * I, i.e., the Q defined here is \mathcal{Q} in the REN paper.
+        if mode == "l2stable":
+            Q = -(1. / gamma) * self.eye_mask_dim_out
+            R = gamma * self.eye_mask_dim_in
+            S = self.zeros_mask_S
+        elif mode == "input_p":
+            if self.dim_out != self.dim_in:
+                raise ValueError("Dimensions of u(t) and y(t) need to be the same for enforcing input passivity.")
+            Q = self.zeros_mask_Q - eps * self.eye_mask_dim_out
+            R = -2. * gamma * self.eye_mask_dim_in
+            S = self.eye_mask_dim_out  # dim_in == dim_out here
+        elif mode == "output_p":
+            if self.dim_out != self.dim_in:
+                raise ValueError("Dimensions of u(t) and y(t) need to be the same for enforcing output passivity.")
+            Q = -2. * gamma * self.eye_mask_dim_out
+            R = self.zeros_mask_R
+            S = self.eye_mask_dim_in  # dim_in == dim_out here
+        else:
+            print("Using the matrices Q, R, S given by the user.")
+            # check dimensions:
+            if not (len(R.shape) == 2 and R.shape[0] == R.shape[1] and R.shape[0] == self.dim_in):
+                raise ValueError("The matrix R is not valid. It must be a square matrix of %ix%i." % (self.dim_in, self.dim_in))
+            if not (len(Q.shape) == 2 and Q.shape[0] == Q.shape[1] and Q.shape[0] == self.dim_out):
+                raise ValueError("The matrix Q is not valid. It must be a square matrix of %ix%i." % (self.dim_out, self.dim_out))
+            if not (len(S.shape) == 2 and S.shape[0] == self.dim_in and S.shape[1] == self.dim_out):
+                raise ValueError("The matrix S is not valid. It must be a matrix of %ix%i." % (self.dim_in, self.dim_out))
+            # check that R is symmetric:
+            if not (R == R.T).prod():
+                raise ValueError("The matrix R is not valid. It must be symmetric.")
+            # check that Q is negative semidefinite:
+            eigs, _ = torch.linalg.eig(Q)
+            if not (eigs.real <= 0).prod():
+                raise ValueError("The matrix Q is not valid. It must be negative semidefinite.")
+            if not (eigs.real < 0).prod():
+                # make Q strictly negative definite (\mathcal{Q} in the REN paper)
+                Q = Q - eps * self.eye_mask_dim_out
+        return Q, R, S
+
+    def _init_trainable_params(self, initialization_std):
+        for training_param_name in self.training_param_names:  # e.g., 'X'
+            # read the defined shape of the selected training param, e.g., X_shape
+            shape = getattr(self, training_param_name + '_shape')
+            # gamma gets its own, larger initialization scale
+            if training_param_name == 'gamma':
+                initialization_std = 3
+            # define the selected param (e.g., self.X) as an nn.Parameter
+            setattr(self, training_param_name, nn.Parameter(torch.randn(*shape) * initialization_std))