ecliseutils 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,99 @@
1
+ """ecliseutils: shared general-purpose research utilities.
2
+
3
+ Submodules group the utilities by concern (``modules``, ``arrays``, ``linalg``,
4
+ ``are``, ``ode``, ``labeled_array``, ``ensemble``, ``fast_conv_scan``,
5
+ ``recursive``, ``dicts``, ``io``, ``memory``, ``timing``, ``plotting``,
6
+ ``settings``, ``types``). The most commonly used names are re-exported here for
7
+ convenience (``import ecliseutils as eu; eu.stack_module_arr(...)``).
8
+
9
+ Projects should call :func:`ecliseutils.configure` early (typically from their
10
+ own ``settings`` module) to set device/dtype/precision/seed.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from . import (
16
+ arrays,
17
+ are,
18
+ dicts,
19
+ ensemble,
20
+ fast_conv_scan,
21
+ io,
22
+ labeled_array,
23
+ linalg,
24
+ memory,
25
+ modules,
26
+ ode,
27
+ plotting,
28
+ recursive,
29
+ settings,
30
+ timing,
31
+ types,
32
+ )
33
+
34
+ from .settings import configure, set_debug, register_safe_globals, default_dtype
35
+ from .types import ModelPair
36
+
37
+ from .labeled_array import LabeledArray, LabeledDataset, array_of, put_object, as_ndarray
38
+ from .modules import (
39
+ stack_tensor_arr,
40
+ stack_module_arr,
41
+ stack_module_arr_preserve_reference,
42
+ run_module_arr,
43
+ multi_vmap,
44
+ buffer_dict,
45
+ td_items,
46
+ td_get,
47
+ parameter_td,
48
+ mask_dataset_with_total_sequence_length,
49
+ broadcast_shapes,
50
+ get_all_hooks,
51
+ )
52
+ from .arrays import (
53
+ multi_iter,
54
+ multi_enumerate,
55
+ multi_map,
56
+ multi_zip,
57
+ dim_array_like,
58
+ broadcast_dim_array_shapes,
59
+ broadcast_dim_arrays,
60
+ take_from_dim_array,
61
+ )
62
+ from .linalg import (
63
+ pow_series,
64
+ batch_trace,
65
+ kl_div,
66
+ sqrtm,
67
+ complex,
68
+ ceildiv,
69
+ ceil,
70
+ T,
71
+ hadamard_conjugation,
72
+ hadamard_conjugation_diff_order1,
73
+ hadamard_conjugation_diff_order2,
74
+ inverse,
75
+ eig_some,
76
+ )
77
+ from .are import (
78
+ solve_discrete_are,
79
+ solve_continuous_are,
80
+ test_discrete_are,
81
+ test_continuous_are,
82
+ )
83
+ from .ode import linspace, geomspace, batch_odeint
84
+ from .ensemble import EnsembleModule, DEFAULT_SPLIT_SIZE
85
+ from .fast_conv_scan import ConvScanFn, conv_scan
86
+ from .recursive import rgetattr, rsetattr, rhasattr, rgetitem, rsetitem
87
+ from .dicts import flatten_nested_dict, map_dict, nested_type, print_dict, call_func_with_kwargs, hash_hex
88
+ from .io import torch_load, empty_cache, reset_seed, model_size
89
+ from .memory import (
90
+ get_tensors_in_memory,
91
+ print_tensors_in_memory,
92
+ get_tensors_in_memory_shape,
93
+ track_tensor_diff,
94
+ )
95
+ from .timing import Timer, identity, PTR, print_disabled, print_enabled, track_calls
96
+ from .plotting import color, confidence_ellipse
97
+
98
+
99
+ __version__ = "0.1.0"
ecliseutils/are.py ADDED
@@ -0,0 +1,409 @@
1
+ """Algebraic Riccati equation solvers (batched, PyTorch, differentiable).
2
+
3
+ A single in-house implementation for both the discrete (DARE) and continuous
4
+ (CARE) algebraic Riccati equations that is:
5
+
6
+ * **batched / parallel** -- everything is batched torch, no per-element scipy
7
+ loop;
8
+ * **robust to degenerate ``R`` (e.g. ``R = 0``) for the DARE** -- the forward
9
+ builds the van Dooren *extended pencil*, which never forms ``R^{-1}``, then
10
+ works on the deflated pencil ``(Lhat, Mhat)`` via the disk-function matrix
11
+ ``(Lhat + Mhat)^{-1}(Lhat - Mhat)``. Inverting neither ``Lhat`` nor ``Mhat``
12
+ means a deadbeat / ``R = 0`` spectrum (closed-loop eigenvalues at ``mu = 0``
13
+ with symplectic partners at ``mu = inf``) is handled cleanly. (The CARE with
14
+ singular ``R`` is genuinely ill-posed -- ``B R^{-1} B^T`` diverges and the
15
+ closed-loop eigenvalues escape to ``inf`` *on* the imaginary-axis splitting
16
+ boundary -- so CARE requires a nonsingular ``R``.)
17
+ * **robust to defective / non-diagonalizable matrices** -- the stable subspace
18
+ is extracted with the matrix **sign** function, a spectral projector that
19
+ never computes eigenvectors, so a rank-deficient or Jordan-block eigenvector
20
+ basis (which breaks ``torch.linalg.eig``) cannot break it;
21
+ * **differentiable** -- the (non-differentiable) forward is wrapped in a custom
22
+ :class:`torch.autograd.Function` whose backward solves a Lyapunov/Sylvester
23
+ adjoint via implicit differentiation. Gradient stability is therefore
24
+ decoupled from the forward method.
25
+
26
+ Why not an in-house Schur/QZ? No CUDA library (cuSOLVER, MAGMA, ``torch.linalg``)
27
+ exposes a batched nonsymmetric real-Schur / QZ with invariant-subspace
28
+ reordering, and hand-rolling Francis double-shift iterations is branchy,
29
+ sequential and GPU-hostile. We only need the stable invariant/deflating
30
+ subspace, which the sign/disk function computes with batched ``matmul`` /
31
+ ``solve`` / ``svd`` (all well-supported on CUDA, unlike ``eig``).
32
+
33
+ Caveats:
34
+
35
+ * **Solvability assumption.** The deflated pencil ``(Lhat, Mhat)`` must have no
36
+ eigenvalue exactly on the splitting boundary (imaginary axis for CARE, unit
37
+ circle for DARE). This is the standard existence condition for a stabilizing
38
+ solution; the sign/disk function is well-defined precisely under it.
39
+ Off-boundary defectiveness / Jordan blocks are fine. (Singular factors --
40
+ eigenvalues at ``0`` or ``inf`` -- are fine for the DARE disk form but not for
41
+ the CARE sign form, hence the CARE nonsingular-``R`` requirement above.)
42
+ * **Backward w.r.t. ``R`` when ``R`` is exactly singular** is ill-defined (the
43
+ gradient involves ``R^{-1}``); the forward ``P`` is still produced and a
44
+ pseudo-inverse convention is used.
45
+
46
+ ``test_discrete_are`` / ``test_continuous_are`` return the Riccati residual and
47
+ are handy for accuracy checks (especially in the degenerate cases scipy cannot
48
+ even run).
49
+ """
50
+
51
+ from __future__ import annotations
52
+
53
+ from typing import Optional, Tuple
54
+
55
+ import torch
56
+
57
+
58
+ __all__ = [
59
+ "solve_discrete_are",
60
+ "solve_continuous_are",
61
+ "test_discrete_are",
62
+ "test_continuous_are",
63
+ ]
64
+
65
+
66
+ # SECTION: small batched linear-algebra helpers
67
def _sym(X: torch.Tensor) -> torch.Tensor:
    """Return the symmetric part ``(X + X^T) / 2`` taken over the last two dims."""
    transposed = X.mT
    return 0.5 * (X + transposed)
69
+
70
+
71
def _safe_solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return ``A^{-1} B``, falling back to ``pinv(A) @ B`` if the solve fails.

    The fallback is triggered only when ``torch.linalg.solve`` raises its
    linear-algebra error. Because the ``try`` wraps the whole (possibly batched)
    call, one failing batch element sends the entire batch through the
    pseudo-inverse path.

    Args:
        A: coefficient matrix (batched over leading dims).
        B: right-hand side, compatible with ``A`` for ``torch.linalg.solve``.

    Returns:
        The solution tensor, with the shape ``torch.linalg.solve`` would return.
    """
    try:
        return torch.linalg.solve(A, B)
    # Public alias of the error type; avoids depending on the private torch._C module.
    except torch.linalg.LinAlgError:
        return torch.linalg.pinv(A) @ B
77
+
78
+
79
def _bkron(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Kronecker product over the trailing two dims, with broadcast batch dims.

    For ``A`` of shape ``[B... x p x q]`` and ``B`` of shape ``[B... x r x s]``
    (either may carry no batch dims) the result is ``[B... x pr x qs]`` and, for
    plain matrices, agrees with ``torch.kron``.
    """
    rows_a, cols_a = A.shape[-2], A.shape[-1]
    rows_b, cols_b = B.shape[-2], B.shape[-1]
    # Output axes are ordered (A-row, B-row, A-col, B-col) so that merging
    # adjacent pairs produces the Kronecker block layout.
    blocks = torch.einsum("...ac,...bd->...abcd", A, B)
    batch_shape = blocks.shape[:-4]
    return blocks.reshape(*batch_shape, rows_a * rows_b, cols_a * cols_b)
89
+
90
+
91
def _commutation(n: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Build the ``n^2 x n^2`` commutation matrix ``K``.

    ``K`` is the permutation matrix satisfying ``K vec(X) = vec(X^T)`` for an
    ``n x n`` matrix ``X`` under column-major ``vec``.
    """
    size = n * n
    rows = torch.arange(size, device=device)
    # Row r carries a single 1 in column (r % n) * n + r // n.
    cols = rows // n + n * (rows % n)
    K = torch.zeros(size, size, dtype=dtype, device=device)
    K[rows, cols] = 1
    return K
96
+
97
+
98
def _vecc_row(X: torch.Tensor) -> torch.Tensor:
    """Stack the columns of ``X`` (column-major ``vec``) into a row vector.

    Input ``[B... x a x b]`` becomes ``[B... x 1 x (a*b)]``.
    """
    batch = X.shape[:-2]
    n_rows, n_cols = X.shape[-2], X.shape[-1]
    # Row-major flattening of the transpose equals column-major flattening of X.
    columns_first = X.transpose(-2, -1)
    return columns_first.reshape(*batch, 1, n_rows * n_cols)
102
+
103
+
104
def _unvec(v: torch.Tensor, a: int, b: int) -> torch.Tensor:
    """Reshape a row vector ``[B... x 1 x (a*b)]`` to ``[B... x a x b]`` in row-major order.

    No transpose is applied: consecutive entries of ``v`` fill each output row.
    """
    flat = v.squeeze(-2)
    target_shape = tuple(flat.shape[:-1]) + (a, b)
    return flat.reshape(target_shape)
108
+
109
+
110
def _unvecT(v: torch.Tensor, a: int, b: int) -> torch.Tensor:
    """Column-major un-vec: reshape to ``(a, b)`` row-major, then transpose.

    Takes ``[B... x 1 x (a*b)]`` and returns ``[B... x b x a]``.
    """
    row_major = _unvec(v, a, b)
    return row_major.mT
113
+
114
+
115
+ # SECTION: forward -- deflated extended pencil + matrix sign/disk projector
116
def _extended_pencil(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    discrete: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Assemble the van Dooren extended pencil ``L - lambda M`` without forming ``R^{-1}``.

    With ``m, n = B.shape[-2:]`` both returned matrices are square of size
    ``2m + n``; batch dims are taken from ``A``.

    Discrete case::

        L = [[ A, 0, B],      M = [[I,   0,   0],
             [-Q, I, 0],           [0,  A^T,  0],
             [ 0, 0, R]]           [0, -B^T,  0]]

    Continuous case::

        L = [[ A,   0,  B],   M = [[I, 0, 0],
             [-Q, -A^T, 0],        [0, I, 0],
             [ 0,  B^T, R]]        [0, 0, 0]]
    """
    batch = A.shape[:-2]
    m, n = B.shape[-2], B.shape[-1]

    def zeros(rows: int, cols: int) -> torch.Tensor:
        return A.new_zeros((*batch, rows, cols))

    def block_row(*blocks: torch.Tensor) -> torch.Tensor:
        return torch.cat(blocks, dim=-1)

    def block_col(*rows: torch.Tensor) -> torch.Tensor:
        return torch.cat(rows, dim=-2)

    eye_m = torch.eye(m, dtype=A.dtype, device=A.device).expand(*batch, m, m)
    z_mm = zeros(m, m)
    z_mn = zeros(m, n)
    z_nm = zeros(n, m)
    z_nn = zeros(n, n)

    # The first block row of L is shared by both variants.
    top = block_row(A, z_mm, B)
    if discrete:
        L = block_col(
            top,
            block_row(-Q, eye_m, z_mn),
            block_row(z_nm, z_nm, R),
        )
        M = block_col(
            block_row(eye_m, z_mm, z_mn),
            block_row(z_mm, A.mT, z_mn),
            block_row(z_nm, -B.mT, z_nn),
        )
    else:
        L = block_col(
            top,
            block_row(-Q, -A.mT, z_mn),
            block_row(z_nm, B.mT, R),
        )
        M = block_col(
            block_row(eye_m, z_mm, z_mn),
            block_row(z_mm, eye_m, z_mn),
            block_row(z_nm, z_nm, z_nn),
        )
    return L, M
156
+
157
+
158
def _deflate(L: torch.Tensor, M: torch.Tensor, m: int, n: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Project out the trailing ``n`` columns to get a ``2m x 2m`` pencil ``(Lhat, Mhat)``.

    A complete QR of the last ``n`` columns of ``L`` gives an orthogonal basis;
    the columns after the first ``n`` span the orthogonal complement of those
    columns. Left-multiplying the first ``2m`` columns of ``L`` and ``M`` by the
    transpose of that complement removes the trailing block without inverting
    ``R``, which is why a singular or zero ``R`` is admissible.
    """
    control_cols = L[..., :, -n:]
    full_q = torch.linalg.qr(control_cols, mode="complete").Q   # [B... x (2m+n) x (2m+n)]
    complement_T = full_q[..., :, n:].mT                        # [B... x 2m x (2m+n)]
    state_cols = slice(None, 2 * m)
    Lhat = complement_T @ L[..., :, state_cols]                 # [B... x 2m x 2m]
    Mhat = complement_T @ M[..., :, state_cols]                 # [B... x 2m x 2m]
    return Lhat, Mhat
170
+
171
+
172
def _matrix_sign(C: torch.Tensor, max_iter: int = 64, tol: float = 1e-12) -> torch.Tensor:
    """Matrix sign function via the determinant-scaled Newton iteration.

    Iterates ``X <- 1/2 (c X + (c X)^{-1})`` with ``c = |det X|^{-1/sz}``,
    starting from ``X = C``. It stops once the relative Frobenius change of
    every batch element is below the tolerance, or after ``max_iter`` steps.

    Args:
        C: square matrices ``[B... x sz x sz]``.
        max_iter: maximum number of Newton steps.
        tol: relative-change stopping threshold. It is floored at ``10 * eps``
            of ``C.dtype``, because a smaller threshold cannot be met reliably
            at that precision (e.g. ``1e-12`` in float32) and the loop would
            then always run all ``max_iter`` steps.

    Returns:
        The last iterate, same shape and dtype as ``C``.

    Raises:
        Whatever ``torch.linalg.inv`` raises when an iterate is singular.
    """
    import math  # local import: only needed for the log-domain clamp below

    sz = C.shape[-1]
    eps = torch.finfo(C.dtype).eps
    log_eps = math.log(eps)
    tol = max(tol, 10.0 * eps)
    X = C
    for _ in range(max_iter):
        Xinv = torch.linalg.inv(X)
        # Work with log|det X| instead of det X: the determinant itself
        # overflows to inf for large or high-dimensional inputs, which would
        # make c = 0 and poison the iterate with inf/NaN. Clamping at log(eps)
        # is the log-domain equivalent of clamping |det| at eps.
        logabsdet = torch.linalg.slogdet(X).logabsdet.clamp_min(log_eps)
        c = torch.exp(-logabsdet / sz)[..., None, None]
        Xnew = 0.5 * (c * X + Xinv / c)
        denom = X.norm(dim=(-2, -1)).clamp_min(eps)
        diff = (Xnew - X).norm(dim=(-2, -1)) / denom
        X = Xnew
        if torch.all(diff < tol):
            break
    return X
188
+
189
+
190
def _stable_subspace_P(Lhat: torch.Tensor, Mhat: torch.Tensor, discrete: bool) -> torch.Tensor:
    """Recover ``P = U21 U11^{-1}`` from the stable subspace of the pencil ``(Lhat, Mhat)``.

    One matrix sign evaluation serves both equations; only the matrix it is
    applied to differs, chosen so that stable eigenvalues land in ``Re < 0``:

    * CARE (stable means ``Re mu < 0``): ``Chat = Mhat^{-1} Lhat``.
    * DARE (stable means ``|mu| < 1``): ``Chat = (Lhat + Mhat)^{-1} (Lhat - Mhat)``,
      whose eigenvalues are ``(mu - 1) / (mu + 1)``. Neither ``Lhat`` nor
      ``Mhat`` is inverted on its own, so eigenvalues at ``mu = 0`` and
      ``mu = inf`` do not break the construction.

    The leading ``m`` left singular vectors of ``(I - sign(Chat)) / 2`` are used
    as a basis ``[U11; U21]`` of the stable subspace, and the symmetrized
    ``U21 U11^{-1}`` is returned.
    """
    dim = Lhat.shape[-1]
    half = dim // 2

    if discrete:
        Chat = _safe_solve(Lhat + Mhat, Lhat - Mhat)
    else:
        Chat = _safe_solve(Mhat, Lhat)

    identity = torch.eye(dim, dtype=Lhat.dtype, device=Lhat.device)
    stable_projector = 0.5 * (identity - _matrix_sign(Chat))

    left_vectors = torch.linalg.svd(stable_projector)[0]
    basis = left_vectors[..., :, :half]
    upper = basis[..., :half, :]
    lower = basis[..., half:, :]
    # Solve the transposed system so that the result is lower @ upper^{-1}.
    P = _safe_solve(upper.mT, lower.mT).mT
    return _sym(P)
218
+
219
+
220
def _are_forward(A: torch.Tensor, B: torch.Tensor, Q: torch.Tensor, R: torch.Tensor, discrete: bool) -> torch.Tensor:
    """Forward solve: build the extended pencil, deflate it, and extract ``P``.

    ``discrete`` selects the DARE (True) or CARE (False) variant in every stage.
    """
    state_dim, control_dim = B.shape[-2], B.shape[-1]
    pencil = _extended_pencil(A, B, Q, R, discrete)
    deflated = _deflate(*pencil, state_dim, control_dim)
    return _stable_subspace_P(*deflated, discrete)
225
+
226
+
227
+ # SECTION: backward -- implicit differentiation (shared adjoint machinery)
228
def _solve_continuous_lyapunov(A_cl: torch.Tensor, C: torch.Tensor) -> torch.Tensor:
    """Solve ``A_cl W + W A_cl^T = C`` for ``W`` (batched).

    The equation is vectorized column-major into the linear system
    ``(I kron A_cl + A_cl kron I) vec(W) = vec(C)`` and solved densely, so the
    system matrix has ``m^2`` rows and columns.
    """
    dim = A_cl.shape[-1]
    batch = C.shape[:-2]
    eye = torch.eye(dim, dtype=A_cl.dtype, device=A_cl.device)
    operator = _bkron(eye, A_cl) + _bkron(A_cl, eye)            # [B... x m^2 x m^2]
    # Row-major flattening of C^T is the column-major vec of C.
    vec_C = C.mT.reshape(*batch, dim * dim, 1)
    vec_W = torch.linalg.solve(operator, vec_C)
    return vec_W.reshape(*batch, dim, dim).mT
236
+
237
+
238
def _care_backward(
    P: torch.Tensor, A: torch.Tensor, B: torch.Tensor, Q: torch.Tensor, R: torch.Tensor,
    grad_output: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gradients of the CARE solution by implicit differentiation.

    Forms the gain ``K = pinv(R) B^T P`` and closed-loop matrix
    ``A_cl = A - B K``, solves the adjoint Lyapunov equation
    ``A_cl W + W A_cl^T = sym(grad_output)``, and assembles the gradients from
    ``W``. ``Q`` is accepted for signature symmetry and is not read.

    Returns:
        ``(dA, dB, dQ, dR)`` in the order of the forward inputs.
    """
    gain = torch.linalg.pinv(R) @ B.mT @ P                      # [B... x n x m]
    closed_loop = A - B @ gain                                  # [B... x m x m]

    adjoint = _solve_continuous_lyapunov(closed_loop, _sym(grad_output))
    P_adjoint = P @ adjoint

    grad_A = -2.0 * P_adjoint
    grad_B = 2.0 * P_adjoint @ gain.mT
    grad_Q = -_sym(adjoint)
    grad_R = -_sym(gain @ adjoint @ gain.mT)
    return grad_A, grad_B, grad_Q, grad_R
256
+
257
+
258
def _dare_backward(
    P: torch.Tensor, A: torch.Tensor, B: torch.Tensor, Q: torch.Tensor, R: torch.Tensor,
    grad_output: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Implicit-diff gradients of DARE.

    Batched port of the closed-form expressions in ``are_temp.txt``
    (``Riccati.backward``); ``m`` is the state dim and ``n`` the control dim,
    both read from ``B.shape[-2:]``.
    NOTE(review): ``are_temp.txt`` is not part of what is visible here, so the
    formulas below could not be cross-checked against it -- confirm the
    reference is still available.

    The computation works on vectorized matrices: ``grad_output`` is flattened
    column-major into a row vector ``go``, each gradient is obtained by
    right-multiplying ``go`` with a matrix built from Kronecker products and the
    shared inverse ``invLHS``, and the resulting row vector is reshaped back.

    Args:
        P: forward solution (``[B... x m x m]`` per the shape comments below).
        A, B, R: the forward inputs. ``Q`` is accepted but not read.
        grad_output: upstream gradient with respect to ``P``.

    Returns:
        ``(dA, dB, dQ, dR)``; ``dQ`` and ``dR`` are symmetrized.

    Raises:
        Whatever ``torch.linalg.inv`` raises if ``R + B^T P B`` or ``LHS`` is
        singular.
    """
    m, n = B.shape[-2:]
    dtype, device = P.dtype, P.device

    # Column-major vec of the upstream gradient, kept as a row vector so that
    # every gradient below is a single ``go @ matrix`` product.
    go = _vecc_row(grad_output)                      # [B... x 1 x m^2]

    M2 = torch.linalg.inv(R + B.mT @ P @ B)          # [B... x n x n]
    PB = P @ B                                       # [B... x m x n]
    PBM2 = PB @ M2                                   # [B... x m x n]
    PBM2BT = PBM2 @ B.mT                             # [B... x m x m]
    # M1 = P - P B (R + B^T P B)^{-1} B^T P
    M1 = P - PBM2BT @ P                              # [B... x m x m]

    I_m = torch.eye(m, dtype=dtype, device=device)
    I_n = torch.eye(n, dtype=dtype, device=device)
    I_m2 = torch.eye(m * m, dtype=dtype, device=device)
    I_n2 = torch.eye(n * n, dtype=dtype, device=device)
    # Commutation matrices (K vec(X) = vec(X^T)) for m x m and n x n blocks.
    Vp_m = _commutation(m, dtype, device)
    Vp_n = _commutation(n, dtype, device)

    AT_kron = _bkron(A.mT, A.mT)                     # [B... x m^2 x m^2]
    PB_kron = _bkron(PB, PB)                         # [B... x m^2 x n^2]
    M2_kron = _bkron(M2, M2)                         # [B... x n^2 x n^2]

    # Shared system matrix, assembled in three steps and inverted once; all
    # four gradients reuse ``invLHS``.
    LHS = PB_kron @ M2_kron @ _bkron(B.mT, B.mT)
    LHS = LHS - _bkron(I_m, PBM2BT) - _bkron(PBM2BT, I_m) + I_m2
    LHS = I_m2 - AT_kron @ LHS
    invLHS = torch.linalg.inv(LHS)                   # [B... x m^2 x m^2]

    # dA
    rhs = Vp_m + I_m2
    dA_mat = invLHS @ rhs @ _bkron(I_m, A.mT @ M1)
    dA = _unvecT(go @ dA_mat, m, m)

    # dB
    rhs = _bkron(I_n, B.mT @ P)                      # [B... x n^2 x (n*m)]
    rhs = (I_n2 + Vp_n) @ rhs
    rhs = PB_kron @ M2_kron @ rhs                    # [B... x m^2 x (n*m)]
    rhs = rhs - (I_m2 + Vp_m) @ _bkron(PBM2, P)
    dB_mat = invLHS @ AT_kron @ rhs
    # Reshaped to (n, m) and transposed, i.e. returned as [B... x m x n].
    dB = _unvecT(go @ dB_mat, n, m)

    # dQ
    # Symmetrized afterwards, so the un-vec ordering does not matter here.
    dQ = _sym(_unvec(go @ invLHS, m, m))

    # dR
    rhs = AT_kron @ PB_kron @ M2_kron                # [B... x m^2 x n^2]
    dR_mat = invLHS @ rhs
    dR = _sym(_unvec(go @ dR_mat, n, n))

    return dA, dB, dQ, dR
316
+
317
+
318
+ # SECTION: autograd.Function wrappers
319
class _DiscreteARE(torch.autograd.Function):
    """Autograd wrapper pairing the DARE forward solve with its implicit-diff backward."""

    @staticmethod
    def forward(ctx, A, B, Q, R):
        # Q and R are symmetrized before both the solve and the save, so the
        # backward sees the same matrices the forward used.
        Q = _sym(Q)
        R = _sym(R)
        with torch.no_grad():
            solution = _are_forward(A, B, Q, R, discrete=True)
        ctx.save_for_backward(solution, A, B, Q, R)
        return solution

    @staticmethod
    def backward(ctx, grad_output):
        # Saved order is (P, A, B, Q, R), matching _dare_backward's parameters.
        return _dare_backward(*ctx.saved_tensors, grad_output)
332
+
333
+
334
class _ContinuousARE(torch.autograd.Function):
    """Autograd wrapper pairing the CARE forward solve with its implicit-diff backward."""

    @staticmethod
    def forward(ctx, A, B, Q, R):
        # Q and R are symmetrized before both the solve and the save, so the
        # backward sees the same matrices the forward used.
        Q = _sym(Q)
        R = _sym(R)
        with torch.no_grad():
            solution = _are_forward(A, B, Q, R, discrete=False)
        ctx.save_for_backward(solution, A, B, Q, R)
        return solution

    @staticmethod
    def backward(ctx, grad_output):
        # Saved order is (P, A, B, Q, R), matching _care_backward's parameters.
        return _care_backward(*ctx.saved_tensors, grad_output)
347
+
348
+
349
+ # SECTION: public API
350
def solve_discrete_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    precision: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Solve the discrete-time algebraic Riccati equation (batched, differentiable).

    ``A^T P A - P - A^T P B (R + B^T P B)^{-1} B^T P A + Q = 0``.

    Args:
        A, B, Q, R: equation coefficients.
        precision: if given, all four inputs are cast to this dtype for the
            solve and the result is cast back to the original dtype of ``A``.
            ``None`` (the default) solves in the inputs' own dtype.

    Returns:
        The solution ``P``.
    """
    if precision is not None:
        result_dtype = A.dtype
        promoted = [t.to(precision) for t in (A, B, Q, R)]
        return _DiscreteARE.apply(*promoted).to(result_dtype)
    return _DiscreteARE.apply(A, B, Q, R)
366
+
367
+
368
def solve_continuous_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    precision: Optional[torch.dtype] = torch.float64,
) -> torch.Tensor:
    """Solve the continuous-time algebraic Riccati equation (batched, differentiable).

    ``A^T P + P A - P B R^{-1} B^T P + Q = 0``.

    Args:
        A, B, Q, R: equation coefficients.
        precision: dtype used for the solve (``torch.float64`` by default); the
            result is cast back to the original dtype of ``A``. Pass ``None``
            to solve in the inputs' own dtype without any casting.

    Returns:
        The solution ``P``.
    """
    if precision is not None:
        result_dtype = A.dtype
        promoted = [t.to(precision) for t in (A, B, Q, R)]
        return _ContinuousARE.apply(*promoted).to(result_dtype)
    return _ContinuousARE.apply(A, B, Q, R)
384
+
385
+
386
def test_discrete_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    P: torch.Tensor,
) -> torch.Tensor:
    """Return the DARE residual of a candidate solution ``P``.

    Evaluates ``A^T P A - P - A^T P B (R + B^T P B)^+ B^T P A + Q`` after
    symmetrizing ``P``; a pseudo-inverse is used so a singular
    ``R + B^T P B`` does not raise. The result is (numerically) zero when ``P``
    solves the equation.

    Despite its name this is a library helper, not a unit test.
    """
    P = 0.5 * (P + P.mT)
    ATP = A.mT @ P
    ATPB = ATP @ B
    gain_inverse = torch.linalg.pinv(R + B.mT @ P @ B)
    return ATP @ A - P - ATPB @ gain_inverse @ ATPB.mT + Q


# The ``test_`` prefix makes pytest collect this function whenever it is
# imported into a test module, where it fails with "fixture 'A' not found".
# Opt out of collection explicitly.
test_discrete_are.__test__ = False
397
+
398
+
399
def test_continuous_are(
    A: torch.Tensor,
    B: torch.Tensor,
    Q: torch.Tensor,
    R: torch.Tensor,
    P: torch.Tensor,
) -> torch.Tensor:
    """Return the CARE residual of a candidate solution ``P``.

    Evaluates ``A^T P + P A - P B R^+ B^T P + Q`` after symmetrizing ``P``; a
    pseudo-inverse of ``R`` is used so a singular ``R`` does not raise. The
    result is (numerically) zero when ``P`` solves the equation.

    Despite its name this is a library helper, not a unit test.
    """
    P = 0.5 * (P + P.mT)
    ATP = A.mT @ P
    PB = P @ B
    return ATP + ATP.mT - PB @ torch.linalg.pinv(R) @ PB.mT + Q


# The ``test_`` prefix makes pytest collect this function whenever it is
# imported into a test module, where it fails with "fixture 'A' not found".
# Opt out of collection explicitly.
test_continuous_are.__test__ = False
ecliseutils/arrays.py ADDED
@@ -0,0 +1,90 @@
1
+ """NumPy object-array comprehensions and ``LabeledArray`` alignment helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import OrderedDict
6
+ from typing import Any, Callable, Iterable, Iterator, Sequence
7
+
8
+ import numpy as np
9
+
10
+ from .labeled_array import LabeledArray, LabeledDataset, array_of, as_ndarray
11
+
12
+
13
+ __all__ = [
14
+ "array_of",
15
+ "multi_iter",
16
+ "multi_enumerate",
17
+ "multi_map",
18
+ "multi_zip",
19
+ "dim_array_like",
20
+ "broadcast_dim_array_shapes",
21
+ "broadcast_dim_arrays",
22
+ "take_from_dim_array",
23
+ ]
24
+
25
+
26
def multi_iter(arr: "np.ndarray | LabeledArray") -> Iterable[Any]:
    """Yield the elements of ``arr`` one at a time.

    ``arr`` is first passed through ``as_ndarray`` (presumably unwrapping a
    ``LabeledArray`` to its underlying ndarray -- defined elsewhere). Each item
    produced by ``np.nditer`` is indexed with ``()`` before being yielded, so
    callers receive the element itself rather than the iterator's 0-d wrapper.

    An empty array yields nothing.

    NOTE(review): ``np.nditer`` is used with its default iteration order, which
    NumPy documents as following memory layout rather than C index order --
    confirm callers do not rely on a particular order for non-contiguous input.
    """
    # refs_ok: needed to iterate object arrays.
    # zerosize_ok: without it np.nditer raises ValueError on zero-size input.
    for x in np.nditer(as_ndarray(arr), flags=["refs_ok", "zerosize_ok"]):
        yield x[()]
29
+
30
+
31
def multi_enumerate(arr: "np.ndarray | LabeledArray") -> Iterable[tuple[Sequence[int], Any]]:
    """Yield ``(multi_index, element)`` pairs for every element of ``arr``.

    N-dimensional analogue of ``enumerate``: ``multi_index`` is the index tuple
    reported by ``np.nditer`` and ``element`` is the value at that position
    (the iterator's 0-d item indexed with ``()``). An empty array yields nothing.
    """
    # refs_ok: needed to iterate object arrays.
    # zerosize_ok: without it np.nditer raises ValueError on zero-size input.
    it = np.nditer(as_ndarray(arr), flags=["multi_index", "refs_ok", "zerosize_ok"])
    for x in it:
        yield it.multi_index, x[()]
35
+
36
+
37
def multi_map(func: Callable[[Any], Any], arr: "np.ndarray | LabeledArray", dtype: type = None):
    """Apply ``func`` to every element of ``arr`` and collect the results in a same-shaped array.

    Args:
        func: element-wise callable.
        arr: input array; a ``LabeledArray`` input produces a ``LabeledArray``
            output carrying the same ``dims``.
        dtype: dtype of the result array. When ``None`` it is taken from the
            type of ``func`` applied to the first element, which means ``func``
            is called one extra time on that element.

    Returns:
        An ndarray, or a ``LabeledArray`` when ``arr`` is one.
    """
    source = as_ndarray(arr)
    if dtype is None:
        probe = func(source.ravel()[0])
        dtype = type(probe)
    mapped = np.empty_like(source, dtype=dtype)
    for position, element in multi_enumerate(source):
        mapped[position] = func(element)
    if isinstance(arr, LabeledArray):
        return LabeledArray(mapped, arr.dims)
    return mapped
45
+
46
+
47
def multi_zip(*arrs: "np.ndarray | LabeledArray") -> np.ndarray:
    """Zip arrays element-wise into an object array of tuples.

    The result has the shape of the first array; cell ``idx`` holds
    ``(arrs[0][idx], arrs[1][idx], ...)``.

    Plain tuples in an object array are used instead of a structured recarray
    so that numpy never introspects ``.dtype``/``.names`` on the elements,
    which keeps ``Tensor``/``TensorDict`` cells safe.
    """
    sources = list(map(as_ndarray, arrs))
    shape = sources[0].shape
    zipped = np.empty(shape, dtype=object)
    for position in np.ndindex(shape):
        zipped[position] = tuple(source[position] for source in sources)
    return zipped
59
+
60
+
61
def dim_array_like(arr: LabeledArray, dtype: type) -> LabeledArray:
    """Return a ``LabeledArray`` shaped and labeled like ``arr``, filled with ``None``.

    The fill goes through ``np.full_like(..., None, dtype=dtype)``, so how
    ``None`` is stored depends on ``dtype``.
    """
    template = as_ndarray(arr)
    filled = np.full_like(template, None, dtype=dtype)
    return LabeledArray(filled, arr.dims)
64
+
65
+
66
def broadcast_dim_array_shapes(*dim_arrs: "Iterable[LabeledArray]") -> "OrderedDict[str, int]":
    """Compute the broadcast length of every named dimension across the inputs.

    Each argument must expose ``.dims`` (dimension names) and ``.shape``
    (matching lengths). Lengths seen for the same name are combined with
    ``np.broadcast_shapes``, so incompatible lengths raise ``ValueError``.

    Returns:
        An ordered mapping ``name -> broadcast length``, in order of first
        appearance of each name.
    """
    lengths_by_dim = OrderedDict()
    for labeled in dim_arrs:
        for name, length in zip(labeled.dims, labeled.shape):
            if name not in lengths_by_dim:
                lengths_by_dim[name] = []
            lengths_by_dim[name].append(length)

    merged = OrderedDict()
    for name, lengths in lengths_by_dim.items():
        merged[name] = np.broadcast_shapes(*lengths)[0]
    return merged
72
+
73
+
74
def broadcast_dim_arrays(*dim_arrs: Iterable[np.ndarray]) -> Iterator[LabeledArray]:
    """Broadcast the inputs against each other by dimension name.

    Inputs that are not ``LabeledArray`` are first wrapped as dimensionless
    ``LabeledArray`` objects: a plain ndarray must be 0-d (asserted), and
    anything else goes through ``array_of``. Every wrapped input is then
    broadcast to the common named shape.

    Returns:
        A lazy generator of the broadcast arrays, in input order.
    """
    def _as_labeled(item):
        if isinstance(item, LabeledArray):
            return item
        if isinstance(item, np.ndarray):
            assert item.ndim == 0
            return LabeledArray(item, ())
        return LabeledArray(array_of(item), ())

    labeled = [_as_labeled(item) for item in dim_arrs]
    target_dims = broadcast_dim_array_shapes(*labeled)
    return (item.broadcast(target_dims) for item in labeled)
87
+
88
+
89
def take_from_dim_array(dim_arr: "LabeledArray | LabeledDataset", idx: dict[str, Any]):
    """Select from ``dim_arr`` by delegating to its ``take(indices=idx)`` method."""
    selection = dim_arr.take(indices=idx)
    return selection
ecliseutils/dicts.py ADDED
@@ -0,0 +1,77 @@
1
+ """Dict / nested-structure and function-call helpers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import inspect
7
+ from typing import Any, Callable
8
+
9
+
10
+ __all__ = [
11
+ "flatten_nested_dict",
12
+ "map_dict",
13
+ "nested_type",
14
+ "print_dict",
15
+ "call_func_with_kwargs",
16
+ "hash_hex",
17
+ ]
18
+
19
+
20
def hash_hex(s: str) -> str:
    """Return the first 8 hex characters of the SHA-256 digest of ``s`` (UTF-8 encoded).

    Handy as a short, stable identifier for config strings.
    """
    digest = hashlib.sha256()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()[:8]
23
+
24
+
25
def flatten_nested_dict(d: dict[str, Any]) -> dict[str, Any]:
    """Flatten nested dicts into a single level with dot-joined keys.

    ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``. Only ``dict`` instances are
    descended into; an empty nested dict contributes no entry.
    """
    def _leaves(prefix: tuple[str, ...], node: dict[str, Any]):
        for key, value in node.items():
            path = (*prefix, key)
            if isinstance(value, dict):
                yield from _leaves(path, value)
            else:
                yield ".".join(path), value

    return dict(_leaves((), d))
36
+
37
+
38
def map_dict(d: dict[str, Any], func: Callable[[Any], Any]) -> dict[str, Any]:
    """Apply ``func`` to every leaf of a nested mapping, keeping the structure.

    Any value exposing an ``items`` attribute is treated as a nested mapping
    and recursed into; nested levels come back as plain ``dict`` objects.
    """
    mapped = {}
    for key, value in d.items():
        if hasattr(value, "items"):
            mapped[key] = map_dict(value, func)
        else:
            mapped[key] = func(value)
    return mapped
43
+
44
+
45
def nested_type(o: object) -> object:
    """Return the type structure of ``o``.

    Lists, tuples and dicts (exact types only, not subclasses) are rebuilt with
    each element or value replaced by its own ``nested_type``; anything else
    maps to ``type(o)``.
    """
    kind = type(o)
    if kind is list or kind is tuple:
        return kind(nested_type(item) for item in o)
    if kind is dict:
        return {key: nested_type(value) for key, value in o.items()}
    return kind
52
+
53
+
54
def print_dict(d: "dict[str, Any] | object", n: int = 0, indent: int = 4) -> None:
    """Pretty-print a nested dict to stdout, one indentation level per nesting depth.

    Keys are printed on their own line, with their value printed one level
    deeper. Non-dict values are converted with ``str`` and every line of the
    result is indented.

    Args:
        d: dict to print, or a leaf value.
        n: current nesting depth.
        indent: number of spaces per depth level.
    """
    pad = " " * (n * indent)
    if not isinstance(d, dict):
        lines = str(d).split("\n")
        print("\n".join([pad + line for line in lines]))
        return
    for key, value in d.items():
        print(pad + key)
        print_dict(value, n=n + 1, indent=indent)
62
+
63
+
64
def call_func_with_kwargs(func: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]):
    """Call ``func`` with ``args``/``kwargs``, dropping keyword arguments it cannot accept.

    How the call is assembled:

    * Each required positional-or-keyword parameter takes its value from
      ``kwargs`` when its name is present there, otherwise from ``args`` at the
      parameter's own position in the signature. A keyword therefore overrides
      the positional value at the same index.
    * Positional values beyond the required ones are forwarded unchanged.
    * A keyword is forwarded if it names a parameter with a default or a
      keyword-only parameter. Keywords naming no parameter are forwarded only
      when ``func`` accepts ``**kwargs``. Everything else is silently dropped.

    Args:
        func: callable to invoke; must be introspectable by ``inspect.signature``.
        args: positional arguments.
        kwargs: candidate keyword arguments.

    Returns:
        Whatever ``func`` returns.

    Raises:
        IndexError: a required parameter is in neither ``kwargs`` nor ``args``.
    """
    params = inspect.signature(func).parameters
    empty = inspect.Parameter.empty

    required_args = [
        kwargs[name] if name in kwargs else args[i]
        for i, (name, param) in enumerate(params.items())
        if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD and param.default is empty
    ]
    additional_args = args[len(required_args):]

    allow_var_keywords = any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in params.values()
    )

    def _forwardable(name: str) -> bool:
        param = params.get(name)
        if param is None:
            return allow_var_keywords
        # Required keyword-only parameters can only be supplied by name, so
        # they must be forwarded even though they have no default. Required
        # positional-or-keyword ones were already consumed into required_args.
        return param.default is not empty or param.kind is inspect.Parameter.KEYWORD_ONLY

    valid_kwargs = {name: value for name, value in kwargs.items() if _forwardable(name)}
    return func(*required_args, *additional_args, **valid_kwargs)