ecliseutils 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ecliseutils/__init__.py +99 -0
- ecliseutils/are.py +409 -0
- ecliseutils/arrays.py +90 -0
- ecliseutils/dicts.py +77 -0
- ecliseutils/ensemble.py +129 -0
- ecliseutils/fast_conv_scan.py +69 -0
- ecliseutils/io.py +38 -0
- ecliseutils/labeled_array.py +213 -0
- ecliseutils/linalg.py +158 -0
- ecliseutils/memory.py +61 -0
- ecliseutils/modules.py +208 -0
- ecliseutils/ode.py +73 -0
- ecliseutils/plotting.py +50 -0
- ecliseutils/recursive.py +50 -0
- ecliseutils/settings.py +137 -0
- ecliseutils/timing.py +96 -0
- ecliseutils/types.py +19 -0
- ecliseutils-0.1.0.dist-info/METADATA +105 -0
- ecliseutils-0.1.0.dist-info/RECORD +22 -0
- ecliseutils-0.1.0.dist-info/WHEEL +5 -0
- ecliseutils-0.1.0.dist-info/licenses/LICENSE +21 -0
- ecliseutils-0.1.0.dist-info/top_level.txt +1 -0
ecliseutils/__init__.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""ecliseutils: shared general-purpose research utilities.
|
|
2
|
+
|
|
3
|
+
Submodules group the utilities by concern (``modules``, ``arrays``, ``linalg``,
|
|
4
|
+
``are``, ``ode``, ``labeled_array``, ``ensemble``, ``fast_conv_scan``,
|
|
5
|
+
``recursive``, ``dicts``, ``io``, ``memory``, ``timing``, ``plotting``,
|
|
6
|
+
``settings``, ``types``). The most commonly used names are re-exported here for
|
|
7
|
+
convenience (``import ecliseutils as eu; eu.stack_module_arr(...)``).
|
|
8
|
+
|
|
9
|
+
Projects should call :func:`ecliseutils.configure` early (typically from their
|
|
10
|
+
own ``settings`` module) to set device/dtype/precision/seed.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from . import (
|
|
16
|
+
arrays,
|
|
17
|
+
are,
|
|
18
|
+
dicts,
|
|
19
|
+
ensemble,
|
|
20
|
+
fast_conv_scan,
|
|
21
|
+
io,
|
|
22
|
+
labeled_array,
|
|
23
|
+
linalg,
|
|
24
|
+
memory,
|
|
25
|
+
modules,
|
|
26
|
+
ode,
|
|
27
|
+
plotting,
|
|
28
|
+
recursive,
|
|
29
|
+
settings,
|
|
30
|
+
timing,
|
|
31
|
+
types,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
from .settings import configure, set_debug, register_safe_globals, default_dtype
|
|
35
|
+
from .types import ModelPair
|
|
36
|
+
|
|
37
|
+
from .labeled_array import LabeledArray, LabeledDataset, array_of, put_object, as_ndarray
|
|
38
|
+
from .modules import (
|
|
39
|
+
stack_tensor_arr,
|
|
40
|
+
stack_module_arr,
|
|
41
|
+
stack_module_arr_preserve_reference,
|
|
42
|
+
run_module_arr,
|
|
43
|
+
multi_vmap,
|
|
44
|
+
buffer_dict,
|
|
45
|
+
td_items,
|
|
46
|
+
td_get,
|
|
47
|
+
parameter_td,
|
|
48
|
+
mask_dataset_with_total_sequence_length,
|
|
49
|
+
broadcast_shapes,
|
|
50
|
+
get_all_hooks,
|
|
51
|
+
)
|
|
52
|
+
from .arrays import (
|
|
53
|
+
multi_iter,
|
|
54
|
+
multi_enumerate,
|
|
55
|
+
multi_map,
|
|
56
|
+
multi_zip,
|
|
57
|
+
dim_array_like,
|
|
58
|
+
broadcast_dim_array_shapes,
|
|
59
|
+
broadcast_dim_arrays,
|
|
60
|
+
take_from_dim_array,
|
|
61
|
+
)
|
|
62
|
+
from .linalg import (
|
|
63
|
+
pow_series,
|
|
64
|
+
batch_trace,
|
|
65
|
+
kl_div,
|
|
66
|
+
sqrtm,
|
|
67
|
+
complex,
|
|
68
|
+
ceildiv,
|
|
69
|
+
ceil,
|
|
70
|
+
T,
|
|
71
|
+
hadamard_conjugation,
|
|
72
|
+
hadamard_conjugation_diff_order1,
|
|
73
|
+
hadamard_conjugation_diff_order2,
|
|
74
|
+
inverse,
|
|
75
|
+
eig_some,
|
|
76
|
+
)
|
|
77
|
+
from .are import (
|
|
78
|
+
solve_discrete_are,
|
|
79
|
+
solve_continuous_are,
|
|
80
|
+
test_discrete_are,
|
|
81
|
+
test_continuous_are,
|
|
82
|
+
)
|
|
83
|
+
from .ode import linspace, geomspace, batch_odeint
|
|
84
|
+
from .ensemble import EnsembleModule, DEFAULT_SPLIT_SIZE
|
|
85
|
+
from .fast_conv_scan import ConvScanFn, conv_scan
|
|
86
|
+
from .recursive import rgetattr, rsetattr, rhasattr, rgetitem, rsetitem
|
|
87
|
+
from .dicts import flatten_nested_dict, map_dict, nested_type, print_dict, call_func_with_kwargs, hash_hex
|
|
88
|
+
from .io import torch_load, empty_cache, reset_seed, model_size
|
|
89
|
+
from .memory import (
|
|
90
|
+
get_tensors_in_memory,
|
|
91
|
+
print_tensors_in_memory,
|
|
92
|
+
get_tensors_in_memory_shape,
|
|
93
|
+
track_tensor_diff,
|
|
94
|
+
)
|
|
95
|
+
from .timing import Timer, identity, PTR, print_disabled, print_enabled, track_calls
|
|
96
|
+
from .plotting import color, confidence_ellipse
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
__version__ = "0.1.0"
|
ecliseutils/are.py
ADDED
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
"""Algebraic Riccati equation solvers (batched, PyTorch, differentiable).
|
|
2
|
+
|
|
3
|
+
A single in-house implementation for both the discrete (DARE) and continuous
|
|
4
|
+
(CARE) algebraic Riccati equations that is:
|
|
5
|
+
|
|
6
|
+
* **batched / parallel** -- everything is batched torch, no per-element scipy
|
|
7
|
+
loop;
|
|
8
|
+
* **robust to degenerate ``R`` (e.g. ``R = 0``) for the DARE** -- the forward
|
|
9
|
+
builds the van Dooren *extended pencil*, which never forms ``R^{-1}``, then
|
|
10
|
+
works on the deflated pencil ``(Lhat, Mhat)`` via the disk-function matrix
|
|
11
|
+
``(Lhat + Mhat)^{-1}(Lhat - Mhat)``. Inverting neither ``Lhat`` nor ``Mhat``
|
|
12
|
+
means a deadbeat / ``R = 0`` spectrum (closed-loop eigenvalues at ``mu = 0``
|
|
13
|
+
with symplectic partners at ``mu = inf``) is handled cleanly. (The CARE with
|
|
14
|
+
singular ``R`` is genuinely ill-posed -- ``B R^{-1} B^T`` diverges and the
|
|
15
|
+
closed-loop eigenvalues escape to ``inf`` *on* the imaginary-axis splitting
|
|
16
|
+
boundary -- so CARE requires a nonsingular ``R``.)
|
|
17
|
+
* **robust to defective / non-diagonalizable matrices** -- the stable subspace
|
|
18
|
+
is extracted with the matrix **sign** function, a spectral projector that
|
|
19
|
+
never computes eigenvectors, so a rank-deficient or Jordan-block eigenvector
|
|
20
|
+
basis (which breaks ``torch.linalg.eig``) cannot break it;
|
|
21
|
+
* **differentiable** -- the (non-differentiable) forward is wrapped in a custom
|
|
22
|
+
:class:`torch.autograd.Function` whose backward solves a Lyapunov/Sylvester
|
|
23
|
+
adjoint via implicit differentiation. Gradient stability is therefore
|
|
24
|
+
decoupled from the forward method.
|
|
25
|
+
|
|
26
|
+
Why not an in-house Schur/QZ? No CUDA library (cuSOLVER, MAGMA, ``torch.linalg``)
|
|
27
|
+
exposes a batched nonsymmetric real-Schur / QZ with invariant-subspace
|
|
28
|
+
reordering, and hand-rolling Francis double-shift iterations is branchy,
|
|
29
|
+
sequential and GPU-hostile. We only need the stable invariant/deflating
|
|
30
|
+
subspace, which the sign/disk function computes with batched ``matmul`` /
|
|
31
|
+
``solve`` / ``svd`` (all well-supported on CUDA, unlike ``eig``).
|
|
32
|
+
|
|
33
|
+
Caveats:
|
|
34
|
+
|
|
35
|
+
* **Solvability assumption.** The deflated pencil ``(Lhat, Mhat)`` must have no
|
|
36
|
+
eigenvalue exactly on the splitting boundary (imaginary axis for CARE, unit
|
|
37
|
+
circle for DARE). This is the standard existence condition for a stabilizing
|
|
38
|
+
solution; the sign/disk function is well-defined precisely under it.
|
|
39
|
+
Off-boundary defectiveness / Jordan blocks are fine. (Singular factors --
|
|
40
|
+
eigenvalues at ``0`` or ``inf`` -- are fine for the DARE disk form but not for
|
|
41
|
+
the CARE sign form, hence the CARE nonsingular-``R`` requirement above.)
|
|
42
|
+
* **Backward w.r.t. ``R`` when ``R`` is exactly singular** is ill-defined (the
|
|
43
|
+
gradient involves ``R^{-1}``); the forward ``P`` is still produced and a
|
|
44
|
+
pseudo-inverse convention is used.
|
|
45
|
+
|
|
46
|
+
``test_discrete_are`` / ``test_continuous_are`` return the Riccati residual and
|
|
47
|
+
are handy for accuracy checks (especially in the degenerate cases scipy cannot
|
|
48
|
+
even run).
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
from __future__ import annotations
|
|
52
|
+
|
|
53
|
+
from typing import Optional, Tuple
|
|
54
|
+
|
|
55
|
+
import torch
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
__all__ = [
|
|
59
|
+
"solve_discrete_are",
|
|
60
|
+
"solve_continuous_are",
|
|
61
|
+
"test_discrete_are",
|
|
62
|
+
"test_continuous_are",
|
|
63
|
+
]
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# SECTION: small batched linear-algebra helpers
|
|
67
|
+
def _sym(X: torch.Tensor) -> torch.Tensor:
|
|
68
|
+
return 0.5 * (X + X.mT)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _safe_solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
|
|
72
|
+
"""``A^{-1} B`` with a pseudo-inverse fallback when ``A`` is (numerically) singular."""
|
|
73
|
+
try:
|
|
74
|
+
return torch.linalg.solve(A, B)
|
|
75
|
+
except torch._C._LinAlgError:
|
|
76
|
+
return torch.linalg.pinv(A) @ B
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _bkron(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
|
|
80
|
+
"""Batched Kronecker product matching ``torch.kron`` on the trailing 2 dims.
|
|
81
|
+
|
|
82
|
+
``A`` is ``[B... x p x q]`` and ``B`` is ``[B... x r x s]`` (batch dims
|
|
83
|
+
broadcast, and either operand may be unbatched). Returns ``[B... x pr x qs]``.
|
|
84
|
+
"""
|
|
85
|
+
p, q = A.shape[-2:]
|
|
86
|
+
r, s = B.shape[-2:]
|
|
87
|
+
out = torch.einsum("...ij,...kl->...ikjl", A, B)
|
|
88
|
+
return out.reshape(*out.shape[:-4], p * r, q * s)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _commutation(n: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
|
|
92
|
+
"""Commutation matrix ``K`` with ``K vec(X) = vec(X^T)`` (column-major vec)."""
|
|
93
|
+
idx = torch.arange(n * n, device=device)
|
|
94
|
+
perm = idx // n + n * (idx % n)
|
|
95
|
+
return torch.eye(n * n, dtype=dtype, device=device)[perm]
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _vecc_row(X: torch.Tensor) -> torch.Tensor:
|
|
99
|
+
"""Column-major ``vec(X)`` returned as a row vector ``[B... x 1 x (a*b)]``."""
|
|
100
|
+
a, b = X.shape[-2:]
|
|
101
|
+
return X.mT.reshape(*X.shape[:-2], 1, a * b)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _unvec(v: torch.Tensor, a: int, b: int) -> torch.Tensor:
|
|
105
|
+
"""Inverse of a row-vector ``[B... x 1 x (a*b)]`` via ``.view(a, b)`` (no transpose)."""
|
|
106
|
+
v = v.squeeze(-2)
|
|
107
|
+
return v.reshape(*v.shape[:-1], a, b)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _unvecT(v: torch.Tensor, a: int, b: int) -> torch.Tensor:
    """Fold a row vector into ``(a, b)`` and transpose -> ``[... x b x a]``.

    This is the column-major counterpart of :func:`_unvec`.
    """
    folded = _unvec(v, a, b)
    return folded.transpose(-2, -1)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# SECTION: forward -- deflated extended pencil + matrix sign/disk projector
|
|
116
|
+
def _extended_pencil(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    discrete: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Van Dooren extended pencil ``L - lambda M`` (size ``2m + n``); no ``R^{-1}``.

    ``m, n`` are taken from ``B.shape[-2:]`` and the batch shape from
    ``A.shape[:-2]``. Both returned matrices are assembled purely by block
    concatenation, so ``R`` is never inverted.

    Block layout (rows separated by ``;``):

    * discrete:   ``L = [A 0 B; -Q I 0; 0 0 R]``,
      ``M = [I 0 0; 0 A^T 0; 0 -B^T 0]``
    * continuous: ``L = [A 0 B; -Q -A^T 0; 0 B^T R]``,
      ``M = [I 0 0; 0 I 0; 0 0 0]``

    NOTE(review): the concatenations only line up when ``A`` and ``Q`` are
    ``[... x m x m]``, ``R`` is ``[... x n x n]`` and all operands share
    ``A``'s batch shape -- assumed, not checked here.

    Returns:
        ``(L, M)``, each ``[... x (2m+n) x (2m+n)]``.
    """
    bsz = A.shape[:-2]
    m, n = B.shape[-2:]

    # Identity / zero blocks, expanded or allocated with A's batch shape,
    # dtype and device.
    I_m = torch.eye(m, dtype=A.dtype, device=A.device).expand(*bsz, m, m)
    Zmm = A.new_zeros((*bsz, m, m))
    Zmn = A.new_zeros((*bsz, m, n))
    Znm = A.new_zeros((*bsz, n, m))
    Znn = A.new_zeros((*bsz, n, n))

    if discrete:
        L = torch.cat([
            torch.cat([A, Zmm, B], dim=-1),
            torch.cat([-Q, I_m, Zmn], dim=-1),
            torch.cat([Znm, Znm, R], dim=-1),
        ], dim=-2)
        M = torch.cat([
            torch.cat([I_m, Zmm, Zmn], dim=-1),
            torch.cat([Zmm, A.mT, Zmn], dim=-1),
            torch.cat([Znm, -B.mT, Znn], dim=-1),
        ], dim=-2)
    else:
        L = torch.cat([
            torch.cat([A, Zmm, B], dim=-1),
            torch.cat([-Q, -A.mT, Zmn], dim=-1),
            torch.cat([Znm, B.mT, R], dim=-1),
        ], dim=-2)
        M = torch.cat([
            torch.cat([I_m, Zmm, Zmn], dim=-1),
            torch.cat([Zmm, I_m, Zmn], dim=-1),
            torch.cat([Znm, Znm, Znn], dim=-1),
        ], dim=-2)
    return L, M
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _deflate(L: torch.Tensor, M: torch.Tensor, m: int, n: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
159
|
+
"""Compress the trailing ``n`` columns -> regular ``2m x 2m`` pencil ``(Lhat, Mhat)``.
|
|
160
|
+
|
|
161
|
+
The QR of ``L``'s trailing control columns and projection onto their
|
|
162
|
+
orthogonal complement eliminates the control variable without ever forming
|
|
163
|
+
``R^{-1}`` -- this is what makes singular / zero ``R`` admissible.
|
|
164
|
+
"""
|
|
165
|
+
q_of_qr = torch.linalg.qr(L[..., :, -n:], mode="complete")[0] # [B... x (2m+n) x (2m+n)]
|
|
166
|
+
defl = q_of_qr[..., :, n:] # [B... x (2m+n) x 2m]
|
|
167
|
+
Lhat = defl.mT @ L[..., :, :2 * m] # [B... x 2m x 2m]
|
|
168
|
+
Mhat = defl.mT @ M[..., :, :2 * m] # [B... x 2m x 2m]
|
|
169
|
+
return Lhat, Mhat
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _matrix_sign(C: torch.Tensor, max_iter: int = 64, tol: float = 1e-12) -> torch.Tensor:
|
|
173
|
+
"""Matrix sign function via scaled Newton iteration ``X <- 1/2 (cX + (cX)^{-1})``."""
|
|
174
|
+
sz = C.shape[-1]
|
|
175
|
+
eps = torch.finfo(C.dtype).eps
|
|
176
|
+
X = C
|
|
177
|
+
for _ in range(max_iter):
|
|
178
|
+
Xinv = torch.linalg.inv(X)
|
|
179
|
+
det = torch.linalg.det(X).abs().clamp_min(eps)
|
|
180
|
+
c = det.pow(-1.0 / sz)[..., None, None]
|
|
181
|
+
Xnew = 0.5 * (c * X + Xinv / c)
|
|
182
|
+
denom = X.norm(dim=(-2, -1)).clamp_min(eps)
|
|
183
|
+
diff = (Xnew - X).norm(dim=(-2, -1)) / denom
|
|
184
|
+
X = Xnew
|
|
185
|
+
if torch.all(diff < tol):
|
|
186
|
+
break
|
|
187
|
+
return X
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _stable_subspace_P(Lhat: torch.Tensor, Mhat: torch.Tensor, discrete: bool) -> torch.Tensor:
    """Stable spectral projector of the pencil ``(Lhat, Mhat)`` -> ``P = U21 U11^{-1}``.

    One matrix sign function serves both problems; only the matrix it is
    applied to differs, chosen so the **stable** eigenvalues land in
    ``Re < 0``:

    * CARE (stable ``Re mu < 0``): ``Chat = Mhat^{-1} Lhat``.
    * DARE (stable ``|mu| < 1``): the disk-function matrix
      ``Chat = (Lhat + Mhat)^{-1} (Lhat - Mhat)`` with eigenvalues
      ``(mu - 1)/(mu + 1)``; neither ``Lhat`` nor ``Mhat`` is inverted on
      its own.

    The projector ``(I - sign(Chat)) / 2`` is decomposed by SVD, its leading
    ``m`` left singular vectors are split into top/bottom halves ``U11`` /
    ``U21``, and the symmetrized ``U21 U11^{-1}`` is returned.
    """
    size = Lhat.shape[-1]
    half = size // 2

    if discrete:
        lhs, rhs = Lhat + Mhat, Lhat - Mhat
    else:
        lhs, rhs = Mhat, Lhat
    Chat = _safe_solve(lhs, rhs)

    sign = _matrix_sign(Chat)
    identity = torch.eye(size, dtype=Lhat.dtype, device=Lhat.device)
    stable_projector = 0.5 * (identity - sign)

    left_vectors, _, _ = torch.linalg.svd(stable_projector)
    basis = left_vectors[..., :, :half]
    upper, lower = basis[..., :half, :], basis[..., half:, :]
    # Solve on the transposed system to obtain lower @ upper^{-1}.
    P = _safe_solve(upper.mT, lower.mT).mT
    return _sym(P)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _are_forward(A: torch.Tensor, B: torch.Tensor, Q: torch.Tensor, R: torch.Tensor, discrete: bool) -> torch.Tensor:
    """Forward ARE solve: extended pencil -> deflation -> stable-subspace ``P``."""
    state_dim, control_dim = B.shape[-2:]
    pencil = _extended_pencil(A, B, Q, R, discrete)
    deflated = _deflate(*pencil, state_dim, control_dim)
    return _stable_subspace_P(*deflated, discrete)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# SECTION: backward -- implicit differentiation (shared adjoint machinery)
|
|
228
|
+
def _solve_continuous_lyapunov(A_cl: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    """Solve ``A_cl W + W A_cl^T = C`` for ``W`` (batched).

    The equation is vectorized column-major into the linear system
    ``(I kron A_cl + A_cl kron I) vec(W) = vec(C)`` of size ``m^2`` and
    solved with ``torch.linalg.solve``.
    """
    dim = A_cl.shape[-1]
    batch_shape = C.shape[:-2]
    eye = torch.eye(dim, dtype=A_cl.dtype, device=A_cl.device)
    operator = _bkron(eye, A_cl) + _bkron(A_cl, eye)  # [B... x m^2 x m^2]
    vec_C = C.mT.reshape(*batch_shape, dim * dim, 1)
    vec_W = torch.linalg.solve(operator, vec_C)
    return vec_W.reshape(*batch_shape, dim, dim).mT
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _care_backward(
    P: torch.Tensor, A: torch.Tensor, B: torch.Tensor, Q: torch.Tensor, R: torch.Tensor,
    grad_output: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Implicit-differentiation gradients of the CARE solution.

    With gain ``K = pinv(R) B^T P`` and closed loop ``A_cl = A - B K``, the
    adjoint ``W`` solves ``A_cl W + W A_cl^T = sym(grad_output)``. ``Q`` is
    accepted for a uniform signature but is not read.

    Returns:
        ``(dA, dB, dQ, dR)`` in the order of the forward inputs.
    """
    gain = torch.linalg.pinv(R) @ B.mT @ P        # [B... x n x m]
    closed_loop = A - B @ gain                    # [B... x m x m]

    adjoint = _solve_continuous_lyapunov(closed_loop, _sym(grad_output))
    P_adjoint = P @ adjoint

    grad_A = -2.0 * P_adjoint
    grad_B = 2.0 * P_adjoint @ gain.mT
    grad_Q = -_sym(adjoint)
    grad_R = -_sym(gain @ adjoint @ gain.mT)
    return grad_A, grad_B, grad_Q, grad_R
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _dare_backward(
    P: torch.Tensor, A: torch.Tensor, B: torch.Tensor, Q: torch.Tensor, R: torch.Tensor,
    grad_output: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Implicit-diff gradients of DARE.

    Batched port of the exact closed-form expressions in ``are_temp.txt``
    (``Riccati.backward``); ``m`` is the state dim and ``n`` the control dim.

    NOTE(review): ``are_temp.txt`` is not among the files shipped in this
    package -- confirm the reference or replace it with a citable source.

    The gradient is assembled in vectorized form: ``grad_output`` is turned
    into a column-major row vector, multiplied by ``m^2``-sized Jacobian
    matrices built from Kronecker products, and folded back into matrices.
    ``Q`` is accepted for a uniform signature but is not read. ``dQ`` and
    ``dR`` are symmetrized before being returned.

    ``torch.linalg.inv`` is applied directly to ``R + B^T P B`` and to the
    ``m^2 x m^2`` system matrix (no pseudo-inverse fallback here).

    Returns:
        ``(dA, dB, dQ, dR)`` in the order of the forward inputs.
    """
    m, n = B.shape[-2:]
    dtype, device = P.dtype, P.device

    go = _vecc_row(grad_output)                       # [B... x 1 x m^2]

    M2 = torch.linalg.inv(R + B.mT @ P @ B)           # [B... x n x n]
    PB = P @ B                                        # [B... x m x n]
    PBM2 = PB @ M2                                    # [B... x m x n]
    PBM2BT = PBM2 @ B.mT                              # [B... x m x m]
    M1 = P - PBM2BT @ P                               # [B... x m x m]

    # Unbatched identities / commutation matrices; they broadcast against
    # the batched Kronecker factors below.
    I_m = torch.eye(m, dtype=dtype, device=device)
    I_n = torch.eye(n, dtype=dtype, device=device)
    I_m2 = torch.eye(m * m, dtype=dtype, device=device)
    I_n2 = torch.eye(n * n, dtype=dtype, device=device)
    Vp_m = _commutation(m, dtype, device)
    Vp_n = _commutation(n, dtype, device)

    AT_kron = _bkron(A.mT, A.mT)                      # [B... x m^2 x m^2]
    PB_kron = _bkron(PB, PB)                          # [B... x m^2 x n^2]
    M2_kron = _bkron(M2, M2)                          # [B... x n^2 x n^2]

    # System matrix shared by all four gradients; inverted once.
    LHS = PB_kron @ M2_kron @ _bkron(B.mT, B.mT)
    LHS = LHS - _bkron(I_m, PBM2BT) - _bkron(PBM2BT, I_m) + I_m2
    LHS = I_m2 - AT_kron @ LHS
    invLHS = torch.linalg.inv(LHS)                    # [B... x m^2 x m^2]

    # dA
    rhs = Vp_m + I_m2
    dA_mat = invLHS @ rhs @ _bkron(I_m, A.mT @ M1)
    dA = _unvecT(go @ dA_mat, m, m)

    # dB
    rhs = _bkron(I_n, B.mT @ P)                       # [B... x n^2 x (n*m)]
    rhs = (I_n2 + Vp_n) @ rhs
    rhs = PB_kron @ M2_kron @ rhs                     # [B... x m^2 x (n*m)]
    rhs = rhs - (I_m2 + Vp_m) @ _bkron(PBM2, P)
    dB_mat = invLHS @ AT_kron @ rhs
    dB = _unvecT(go @ dB_mat, n, m)

    # dQ
    dQ = _sym(_unvec(go @ invLHS, m, m))

    # dR
    rhs = AT_kron @ PB_kron @ M2_kron                 # [B... x m^2 x n^2]
    dR_mat = invLHS @ rhs
    dR = _sym(_unvec(go @ dR_mat, n, n))

    return dA, dB, dQ, dR
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
# SECTION: autograd.Function wrappers
|
|
319
|
+
class _DiscreteARE(torch.autograd.Function):
    """Autograd wrapper for the DARE: gradient-free forward, implicit-diff backward."""

    @staticmethod
    def forward(ctx, A, B, Q, R):
        """Solve the DARE on symmetrized ``Q``/``R`` and stash tensors for backward."""
        Q_sym = _sym(Q)
        R_sym = _sym(R)
        with torch.no_grad():
            solution = _are_forward(A, B, Q_sym, R_sym, discrete=True)
        ctx.save_for_backward(solution, A, B, Q_sym, R_sym)
        return solution

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(dA, dB, dQ, dR)`` computed by :func:`_dare_backward`."""
        saved = ctx.saved_tensors  # (P, A, B, Q, R) in the order they were saved
        return _dare_backward(*saved, grad_output)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
class _ContinuousARE(torch.autograd.Function):
    """Autograd wrapper for the CARE: gradient-free forward, implicit-diff backward."""

    @staticmethod
    def forward(ctx, A, B, Q, R):
        """Solve the CARE on symmetrized ``Q``/``R`` and stash tensors for backward."""
        Q_sym = _sym(Q)
        R_sym = _sym(R)
        with torch.no_grad():
            solution = _are_forward(A, B, Q_sym, R_sym, discrete=False)
        ctx.save_for_backward(solution, A, B, Q_sym, R_sym)
        return solution

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(dA, dB, dQ, dR)`` computed by :func:`_care_backward`."""
        saved = ctx.saved_tensors  # (P, A, B, Q, R) in the order they were saved
        return _care_backward(*saved, grad_output)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
# SECTION: public API
|
|
350
|
+
def solve_discrete_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    precision: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Solve the discrete-time algebraic Riccati equation (batched, differentiable).

    ``A^T P A - P - A^T P B (R + B^T P B)^{-1} B^T P A + Q = 0``.

    Args:
        A, B, Q, R: Problem matrices.
        precision: If given, all four inputs are cast to this dtype for the
            solve and the result is cast back to ``A``'s original dtype.
            ``None`` (the default) solves in the inputs' own dtype.

    Returns:
        The solution ``P``.
    """
    if precision is not None:
        out_dtype = A.dtype
        cast = [t.to(precision) for t in (A, B, Q, R)]
        return _DiscreteARE.apply(*cast).to(out_dtype)
    return _DiscreteARE.apply(A, B, Q, R)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def solve_continuous_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    precision: Optional[torch.dtype] = torch.float64,
) -> torch.Tensor:
    """Solve the continuous-time algebraic Riccati equation (batched, differentiable).

    ``A^T P + P A - P B R^{-1} B^T P + Q = 0``.

    Args:
        A, B, Q, R: Problem matrices.
        precision: Dtype used for the solve (default ``torch.float64``); the
            result is cast back to ``A``'s original dtype. Pass ``None`` to
            solve in the inputs' own dtype.

    Returns:
        The solution ``P``.
    """
    if precision is not None:
        out_dtype = A.dtype
        cast = [t.to(precision) for t in (A, B, Q, R)]
        return _ContinuousARE.apply(*cast).to(out_dtype)
    return _ContinuousARE.apply(A, B, Q, R)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def test_discrete_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    P: torch.Tensor,
) -> torch.Tensor:
    """Return the DARE residual of a candidate solution ``P``.

    Evaluates ``A^T P A - P - A^T P B pinv(R + B^T P B) B^T P A + Q`` with
    ``P`` symmetrized first; the result is the zero matrix for an exact
    solution.
    """
    P_sym = _sym(P)
    At_P = A.mT @ P_sym
    At_P_B = At_P @ B
    gram_pinv = torch.linalg.pinv(R + B.mT @ P_sym @ B)
    return At_P @ A - P_sym - At_P_B @ gram_pinv @ At_P_B.mT + Q
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def test_continuous_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    P: torch.Tensor,
) -> torch.Tensor:
    """Return the CARE residual of a candidate solution ``P``.

    Evaluates ``A^T P + P A - P B pinv(R) B^T P + Q`` with ``P`` symmetrized
    first; the result is the zero matrix for an exact solution.
    """
    P_sym = _sym(P)
    At_P = A.mT @ P_sym
    P_B = P_sym @ B
    R_pinv = torch.linalg.pinv(R)
    return At_P + At_P.mT - P_B @ R_pinv @ P_B.mT + Q
|
ecliseutils/arrays.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""NumPy object-array comprehensions and ``LabeledArray`` alignment helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import OrderedDict
|
|
6
|
+
from typing import Any, Callable, Iterable, Iterator, Sequence
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
from .labeled_array import LabeledArray, LabeledDataset, array_of, as_ndarray
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"array_of",
|
|
15
|
+
"multi_iter",
|
|
16
|
+
"multi_enumerate",
|
|
17
|
+
"multi_map",
|
|
18
|
+
"multi_zip",
|
|
19
|
+
"dim_array_like",
|
|
20
|
+
"broadcast_dim_array_shapes",
|
|
21
|
+
"broadcast_dim_arrays",
|
|
22
|
+
"take_from_dim_array",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def multi_iter(arr: "np.ndarray | LabeledArray") -> Iterable[Any]:
    """Lazily yield every element of ``arr`` (object arrays allowed)."""
    cells = np.nditer(as_ndarray(arr), flags=["refs_ok"])
    # Each nditer item is a 0-d view; ``[()]`` unwraps it to the element.
    yield from (cell[()] for cell in cells)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def multi_enumerate(arr: "np.ndarray | LabeledArray") -> Iterable[tuple[Sequence[int], Any]]:
    """Lazily yield ``(multi_index, element)`` pairs for every element of ``arr``."""
    walker = np.nditer(as_ndarray(arr), flags=["multi_index", "refs_ok"])
    for cell in walker:
        position = walker.multi_index
        # ``cell`` is a 0-d view; ``[()]`` unwraps it to the element.
        yield position, cell[()]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def multi_map(func: Callable[[Any], Any], arr: "np.ndarray | LabeledArray", dtype: "type | None" = None):
    """Apply ``func`` to every element of ``arr`` and collect the results.

    Args:
        func: Callable applied to each element.
        arr: Input array; a ``LabeledArray`` input yields a ``LabeledArray``
            with the same ``dims``, anything else yields a plain ndarray.
        dtype: dtype of the result array. When ``None`` it is inferred as the
            type of ``func`` applied to the element at index ``(0, ..., 0)``;
            for an empty input (nothing to probe) ``object`` is used.

    Returns:
        An array shaped like ``arr`` holding ``func(element)`` per cell.

    ``func`` is called exactly once per element: the probe result used for
    dtype inference is reused rather than recomputed, which matters for
    expensive or side-effecting callables.
    """
    base = as_ndarray(arr)
    first_idx = (0,) * base.ndim
    have_first = False
    first_out = None

    if dtype is None:
        if base.size == 0:
            # Nothing to probe; indexing the first element would raise.
            dtype = object
        else:
            first_out = func(base[first_idx])
            have_first = True
            dtype = type(first_out)

    result = np.empty_like(base, dtype=dtype)
    # nditer (used by multi_enumerate) rejects zero-sized operands unless
    # explicitly allowed, so skip the loop for empty inputs.
    if base.size:
        for idx, x in multi_enumerate(base):
            if have_first and idx == first_idx:
                result[idx] = first_out
            else:
                result[idx] = func(x)
    return LabeledArray(result, arr.dims) if isinstance(arr, LabeledArray) else result
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def multi_zip(*arrs: "np.ndarray | LabeledArray") -> np.ndarray:
    """Zip arrays element-wise into one object array of tuples.

    The result takes the shape of the first array; cell ``idx`` holds
    ``(arrs[0][idx], arrs[1][idx], ...)``.

    A plain object array of tuples (rather than a structured recarray) keeps
    numpy from introspecting ``.dtype``/``.names`` on the elements, so cells
    may safely hold ``Tensor``/``TensorDict`` values.
    """
    bases = [as_ndarray(arr) for arr in arrs]
    shape = bases[0].shape
    zipped = np.empty(shape, dtype=object)
    for position in np.ndindex(shape):
        zipped[position] = tuple(base[position] for base in bases)
    return zipped
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def dim_array_like(arr: LabeledArray, dtype: type) -> LabeledArray:
    """Return a ``LabeledArray`` shaped and labeled like ``arr``, filled with ``None`` cast to ``dtype``."""
    template = as_ndarray(arr)
    blank = np.full_like(template, fill_value=None, dtype=dtype)
    return LabeledArray(blank, arr.dims)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def broadcast_dim_array_shapes(*dim_arrs: Iterable[LabeledArray]) -> "OrderedDict[str, int]":
    """Compute the broadcast length of every named dimension.

    Each argument must expose ``dims`` (names) and ``shape`` (lengths). The
    lengths seen for a name are combined with ``np.broadcast_shapes``; names
    appear in the result in order of first occurrence.
    """
    lengths_by_dim = OrderedDict()
    for dim_arr in dim_arrs:
        for name, length in zip(dim_arr.dims, dim_arr.shape):
            if name not in lengths_by_dim:
                lengths_by_dim[name] = []
            lengths_by_dim[name].append(length)

    merged = OrderedDict()
    for name, lengths in lengths_by_dim.items():
        merged[name] = np.broadcast_shapes(*lengths)[0]
    return merged
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def broadcast_dim_arrays(*dim_arrs: Iterable[np.ndarray]) -> Iterator[LabeledArray]:
    """Broadcast the arguments against each other by dimension name.

    ``LabeledArray`` arguments are used as-is. A plain ndarray must be 0-d
    and is wrapped with no dims; any other value is wrapped via ``array_of``
    with no dims. Returns a generator of the broadcast arrays in argument
    order.
    """
    labeled = []
    for item in dim_arrs:
        if not isinstance(item, LabeledArray):
            if isinstance(item, np.ndarray):
                assert item.ndim == 0
            else:
                item = array_of(item)
            item = LabeledArray(item, ())
        labeled.append(item)

    target = broadcast_dim_array_shapes(*labeled)
    return (entry.broadcast(target) for entry in labeled)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def take_from_dim_array(dim_arr: "LabeledArray | LabeledDataset", idx: dict[str, Any]):
    """Delegate to ``dim_arr.take(indices=idx)`` and return its result."""
    take = dim_arr.take
    return take(indices=idx)
|
ecliseutils/dicts.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""Dict / nested-structure and function-call helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import inspect
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"flatten_nested_dict",
|
|
12
|
+
"map_dict",
|
|
13
|
+
"nested_type",
|
|
14
|
+
"print_dict",
|
|
15
|
+
"call_func_with_kwargs",
|
|
16
|
+
"hash_hex",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def hash_hex(s: str) -> str:
    """Return the first 8 hex characters of the SHA-256 digest of ``s`` (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()[:8]
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def flatten_nested_dict(d: dict[str, Any]) -> dict[str, Any]:
    """Flatten nested dicts into one dict keyed by dot-joined paths.

    ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``. Only ``dict`` instances are
    descended into; an empty nested dict contributes no entries.
    """

    def _walk(prefix: tuple[str, ...], node: dict[str, Any]):
        for key, value in node.items():
            path = (*prefix, key)
            if isinstance(value, dict):
                yield from _walk(path, value)
            else:
                yield ".".join(path), value

    return dict(_walk((), d))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def map_dict(d: dict[str, Any], func: Callable[[Any], Any]) -> dict[str, Any]:
    """Apply ``func`` to every leaf of a nested mapping, preserving structure.

    Any value exposing an ``items`` attribute is treated as a nested mapping
    and recursed into; the result always uses plain ``dict`` containers.
    """
    mapped = {}
    for key, value in d.items():
        if hasattr(value, "items"):
            mapped[key] = map_dict(value, func)
        else:
            mapped[key] = func(value)
    return mapped
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def nested_type(o: object) -> object:
    """Replace every leaf of a nested list/tuple/dict structure by its type.

    Only exact ``list``, ``tuple`` and ``dict`` instances are recursed into
    (subclasses are treated as leaves); containers keep their kind and dict
    keys are preserved.
    """
    kind = type(o)
    if kind is dict:
        return {key: nested_type(value) for key, value in o.items()}
    if kind in (list, tuple):
        return kind(nested_type(item) for item in o)
    return kind
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def print_dict(d: "dict[str, Any] | object", n: int = 0, indent: int = 4) -> None:
    """Pretty-print a nested dict as an indented outline on stdout.

    Args:
        d: A dict to walk, or any other object to print as a leaf.
        n: Current nesting depth (callers normally leave this at 0).
        indent: Number of spaces per nesting level.

    Each key is printed on its own line at depth ``n``, followed by its
    value at depth ``n + 1``. A leaf is rendered with ``str`` and every line
    of a multi-line rendering is indented.
    """
    pad = " " * (n * indent)
    if isinstance(d, dict):
        for k, v in d.items():
            # str() so that non-string keys (ints, tuples, ...) do not raise
            # TypeError on concatenation.
            print(pad + str(k))
            print_dict(v, n=n + 1, indent=indent)
    else:
        print("\n".join(pad + line for line in str(d).split("\n")))
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def call_func_with_kwargs(func: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]):
    """Call ``func`` with ``args``/``kwargs``, dropping keywords it cannot accept.

    Args:
        func: Callable to invoke; its signature is inspected.
        args: Positional arguments, aligned with ``func``'s parameters.
        kwargs: Candidate keyword arguments; may contain extras.

    Returns:
        Whatever ``func`` returns.

    Binding rules:

    * Every positional-or-keyword parameter without a default is passed
      positionally, taken from ``kwargs`` by name when present, otherwise
      from ``args`` at the parameter's index.
    * The remaining entries of ``args`` are passed positionally after those.
    * A keyword is forwarded if it names a parameter with a default or a
      keyword-only parameter (with or without default). Keywords that name
      no parameter are forwarded only when ``func`` accepts ``**kwargs``.
    """
    params = inspect.signature(func).parameters
    empty = inspect.Parameter.empty

    required_args = [
        kwargs[name] if name in kwargs else args[i]
        for i, (name, param) in enumerate(params.items())
        if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and param.default is empty
    ]
    additional_args = args[len(required_args):]

    allow_var_keywords = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

    def _forwardable(name: str) -> bool:
        param = params.get(name)
        if param is None:
            return allow_var_keywords
        # Required keyword-only parameters can only be supplied by keyword,
        # so they must not be filtered out with the other default-less names
        # (which are already passed positionally above).
        if param.kind is inspect.Parameter.KEYWORD_ONLY:
            return True
        return param.default is not empty

    valid_kwargs = {k: v for k, v in kwargs.items() if _forwardable(k)}
    return func(*required_args, *additional_args, **valid_kwargs)
|