PyDiffGame 1.0.0__py3-none-any.whl → 2.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
PyDiffGame/__init__.py CHANGED
@@ -0,0 +1,50 @@
1
+ """PyDiffGame — Nash-equilibrium solutions to linear-quadratic differential games.
2
+
3
+ PyDiffGame solves multi-objective dynamical control problems by reducing the
4
+ Game Hamilton–Jacobi–Bellman equations to coupled algebraic / differential
5
+ Riccati equations.
6
+
7
+ Quick start
8
+ -----------
9
+ >>> import numpy as np
10
+ >>> from PyDiffGame import ContinuousLQR
11
+ >>> A = np.array([[0.0, 1.0], [0.0, 0.0]])
12
+ >>> B = np.array([[0.0], [1.0]])
13
+ >>> lqr = ContinuousLQR(A=A, B=B, Q=np.eye(2), R=1.0).solve()
14
+ >>> lqr.is_closed_loop_stable()
15
+ True
16
+
17
+ The public API is intentionally small:
18
+
19
+ * :class:`~PyDiffGame.objective.Objective` (+ :func:`LQRObjective`,
20
+ :func:`GameObjective` helpers) — describe each player's cost.
21
+ * :class:`~PyDiffGame.continuous.ContinuousPyDiffGame` /
22
+ :class:`~PyDiffGame.discrete.DiscretePyDiffGame` — the solvers.
23
+ * :class:`~PyDiffGame.lqr.ContinuousLQR` / :class:`~PyDiffGame.lqr.DiscreteLQR`
24
+ — single-objective convenience wrappers.
25
+ * :class:`~PyDiffGame.comparison.PyDiffGameLQRComparison` — compare designs.
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ from PyDiffGame.base import PyDiffGame
31
+ from PyDiffGame.comparison import PyDiffGameLQRComparison
32
+ from PyDiffGame.continuous import ContinuousPyDiffGame
33
+ from PyDiffGame.discrete import DiscretePyDiffGame
34
+ from PyDiffGame.lqr import ContinuousLQR, DiscreteLQR
35
+ from PyDiffGame.objective import GameObjective, LQRObjective, Objective
36
+
37
__version__ = "2.0.1"

# Public API, grouped by role; ``from PyDiffGame import *`` exports exactly these.
__all__ = [
    # abstract base class and the two time-domain solvers
    "PyDiffGame",
    "ContinuousPyDiffGame",
    "DiscretePyDiffGame",
    # single-objective convenience wrappers
    "ContinuousLQR",
    "DiscreteLQR",
    # description of each player's cost, plus its helpers
    "Objective",
    "LQRObjective",
    "GameObjective",
    # side-by-side comparison of designs
    "PyDiffGameLQRComparison",
    # package metadata
    "__version__",
]
PyDiffGame/_typing.py ADDED
@@ -0,0 +1,25 @@
1
+ """Shared type aliases for :mod:`PyDiffGame`.
2
+
3
+ Kept in one place so the rest of the package can use precise, readable
4
+ annotations without repeating ``numpy.typing`` boilerplate. The aliases are
5
+ plain assignments (rather than :pep:`695` ``type`` statements) so the package
6
+ keeps importing on Python 3.11, while still reading cleanly on 3.14.
7
+ """
8
+
9
+ from collections.abc import Sequence
10
+ from os import PathLike
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+ from numpy.typing import NDArray
15
+
16
+ #: A real-valued, floating point numpy array (vector or matrix).
17
+ FloatArray = NDArray[np.floating[Any]]
18
+
19
+ #: Anything that can be coerced into a float array (lists, tuples, scalars...).
20
+ ArrayLike = FloatArray | Sequence[float] | Sequence[Sequence[float]] | float | int
21
+
22
+ #: A filesystem path, accepted either as ``str`` or :class:`os.PathLike`.
23
+ PathInput = str | PathLike[str]
24
+
25
+ __all__ = ["FloatArray", "ArrayLike", "PathInput"]
PyDiffGame/base.py ADDED
@@ -0,0 +1,468 @@
1
+ """Abstract base class for differential games.
2
+
3
+ :class:`PyDiffGame` holds everything that is independent of the time domain:
4
+ construction and validation of the problem data, the
5
+ virtual-input decomposition, controllability tests, cost evaluation, the
6
+ solve/simulate orchestration and plotting. The continuous- and discrete-time
7
+ specifics live in :mod:`PyDiffGame.continuous` and :mod:`PyDiffGame.discrete`.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from abc import ABC, abstractmethod
13
+ from collections.abc import Sequence
14
+ from pathlib import Path
15
+ from typing import Any, Final
16
+
17
+ import numpy as np
18
+
19
+ from PyDiffGame import plotting
20
+ from PyDiffGame._typing import ArrayLike, FloatArray, PathInput
21
+ from PyDiffGame.objective import Objective
22
+
23
+
24
+ class PyDiffGame(ABC):
25
+ r"""Differential game abstract base class.
26
+
27
+ Considers control design for the linear system
28
+
29
+ .. math:: \dot{x}(t) = A x(t) + \sum_{i=1}^N B_i v_i(t)
30
+
31
+ where each player :math:`i` minimises a quadratic cost weighted by
32
+ :math:`Q_i` (state) and :math:`R_{ii}` (input). When the players share a
33
+ single physical input :math:`u` through decomposition matrices
34
+ :math:`M_i`, pass the full input matrix ``B`` together with objectives that
35
+ carry an ``M``; otherwise pass the per-player matrices ``Bs`` directly.
36
+
37
+ Parameters
38
+ ----------
39
+ A:
40
+ System dynamics matrix of shape ``(n, n)``.
41
+ objectives:
42
+ One :class:`~PyDiffGame.objective.Objective` per player.
43
+ B:
44
+ Full input matrix of shape ``(n, m)``, used together with each
45
+ objective's decomposition matrix ``M``.
46
+ Bs:
47
+ Per-player input matrices, an alternative to ``B`` + ``M``.
48
+ x_0, x_T:
49
+ Optional initial / terminal state vectors of shape ``(n,)``.
50
+ T_f:
51
+ Finite horizon length. ``None`` (the default) solves the
52
+ infinite-horizon problem.
53
+ P_f:
54
+ Terminal Riccati condition(s); defaults to the uncoupled algebraic
55
+ Riccati solutions.
56
+ L:
57
+ Number of time samples.
58
+ eta:
59
+ Number of trailing matrix norms inspected for convergence.
60
+ epsilon_x, epsilon_P:
61
+ Convergence tolerances for the state and for the Riccati matrices.
62
+ state_variables_names:
63
+ LaTeX names (without ``$``) for the ``n`` state variables, used in plots.
64
+ show_legend:
65
+ Whether plots include a legend.
66
+ debug:
67
+ Emit verbose diagnostics while solving.
68
+ """
69
+
70
+ T_f_default: Final[float] = 20.0
71
+ epsilon_x_default: Final[float] = 1e-7
72
+ epsilon_P_default: Final[float] = 1e-7
73
+ L_default: Final[int] = 1000
74
+ eta_default: Final[int] = 5
75
+ eigenvalue_tolerance: Final[float] = 1e-7
76
+ #: Standard gravitational acceleration, handy for the mechanical examples.
77
+ g: Final[float] = 9.81
78
+
79
+ max_x_iterations: Final[int] = 100
80
+ max_P_iterations: Final[int] = 100
81
+
82
+ default_figures_path: Final[Path] = Path.cwd() / "figures"
83
+ default_figures_filename: Final[str] = "image"
84
+
85
+ def __init__(
86
+ self,
87
+ A: ArrayLike,
88
+ objectives: Sequence[Objective],
89
+ *,
90
+ B: ArrayLike | None = None,
91
+ Bs: Sequence[ArrayLike] | None = None,
92
+ x_0: ArrayLike | None = None,
93
+ x_T: ArrayLike | None = None,
94
+ T_f: float | None = None,
95
+ P_f: Sequence[ArrayLike] | None = None,
96
+ L: int = L_default,
97
+ eta: int = eta_default,
98
+ epsilon_x: float = epsilon_x_default,
99
+ epsilon_P: float = epsilon_P_default,
100
+ state_variables_names: Sequence[str] | None = None,
101
+ show_legend: bool = True,
102
+ debug: bool = False,
103
+ ) -> None:
104
+ self._A = np.asarray(A, dtype=np.float64)
105
+ self._n = self._A.shape[0]
106
+ self._objectives = list(objectives)
107
+ self._N = len(self._objectives)
108
+ self._Qs = [o.Q for o in self._objectives]
109
+ self._Rs = [o.R for o in self._objectives]
110
+
111
+ self._B = None if B is None else np.asarray(B, dtype=np.float64)
112
+ self._M: FloatArray | None = None
113
+ self._M_inv: FloatArray | None = None
114
+ self._Bs = self._resolve_input_matrices(Bs)
115
+ self._m = sum(B_i.shape[1] for B_i in self._Bs)
116
+
117
+ self._x_0 = None if x_0 is None else np.asarray(x_0, dtype=np.float64).ravel()
118
+ self._x_T = None if x_T is None else np.asarray(x_T, dtype=np.float64).ravel()
119
+
120
+ self._infinite_horizon = T_f is None
121
+ self._T_f = float(self.T_f_default if T_f is None else T_f)
122
+ self._L = int(L)
123
+ self._eta = int(eta)
124
+ self._epsilon_x = float(epsilon_x)
125
+ self._epsilon_P = float(epsilon_P)
126
+ self._delta = self._T_f / self._L
127
+
128
+ self._state_variables_names = None if state_variables_names is None else list(state_variables_names)
129
+ self._show_legend = bool(show_legend)
130
+ self._debug = bool(debug)
131
+
132
+ self._forward_time = np.linspace(0.0, self._T_f, self._L)
133
+
134
+ # Per-player feedback gains S_i = B_i R_ii^{-1} B_i^T.
135
+ self._S = [B_i @ np.linalg.solve(R_i, B_i.T) for B_i, R_i in zip(self._Bs, self._Rs)]
136
+
137
+ self._P_f = (
138
+ self._uncoupled_are_solutions()
139
+ if P_f is None
140
+ else [np.asarray(P_f_i, dtype=np.float64) for P_f_i in P_f]
141
+ )
142
+
143
+ # Solver outputs, populated by ``solve`` / ``simulate``. ``_P`` / ``_K``
144
+ # are intentionally dynamic: a list of constant matrices for the
145
+ # infinite-horizon case, a time-indexed array for the finite-horizon one.
146
+ self._P: Any = []
147
+ self._K: Any = []
148
+ self._A_cl: FloatArray = np.empty_like(self._A)
149
+ self._x: FloatArray | None = None
150
+ self._solved = False
151
+ self._fig: Any = None
152
+
153
+ self._validate()
154
+
155
+ # ------------------------------------------------------------------ #
156
+ # Construction helpers
157
+ # ------------------------------------------------------------------ #
158
+ def _resolve_input_matrices(self, Bs: Sequence[ArrayLike] | None) -> list[FloatArray]:
159
+ """Determine the per-player input matrices ``B_i``.
160
+
161
+ Either ``Bs`` is given explicitly, or it is derived from the full input
162
+ matrix ``B`` and the objectives' decomposition matrices ``M_i`` through
163
+ :math:`B_i = B M^{-1}[:, \\text{block}_i]` where ``M`` stacks the
164
+ ``M_i``.
165
+ """
166
+
167
+ if Bs is not None:
168
+ return [np.asarray(B_i, dtype=np.float64) for B_i in Bs]
169
+
170
+ if self._B is None:
171
+ raise ValueError("Either B or Bs must be provided")
172
+
173
+ decomposition = [o.M for o in self._objectives if o.M is not None]
174
+ if not decomposition:
175
+ # Plain LQR / single shared input: one player drives the full input.
176
+ return [self._B]
177
+
178
+ self._M = np.concatenate(decomposition, axis=0)
179
+ try:
180
+ M_inv = np.linalg.inv(self._M)
181
+ except np.linalg.LinAlgError:
182
+ M_inv = np.linalg.pinv(self._M)
183
+ self._M_inv = M_inv
184
+
185
+ Bs_resolved: list[FloatArray] = []
186
+ column = 0
187
+ for M_i in decomposition:
188
+ m_i = M_i.shape[0]
189
+ Bs_resolved.append(self._B @ M_inv[:, column : column + m_i])
190
+ column += m_i
191
+ return Bs_resolved
192
+
193
+ def _uncoupled_are_solutions(self) -> list[FloatArray]:
194
+ """Solve each player's decoupled algebraic Riccati equation.
195
+
196
+ Used as the terminal Riccati condition / iteration seed. Falls back to
197
+ ``Q_i`` when the decoupled equation has no stabilising solution.
198
+ """
199
+
200
+ solutions: list[FloatArray] = []
201
+ for B_i, Q_i, R_i in zip(self._Bs, self._Qs, self._Rs):
202
+ try:
203
+ solutions.append(self._solve_are(B_i, Q_i, R_i))
204
+ except (np.linalg.LinAlgError, ValueError):
205
+ solutions.append(Q_i)
206
+ return solutions
207
+
208
+ def _converged(self, norms: Sequence[float]) -> bool:
209
+ """Relative convergence test over a trailing window of Riccati norms.
210
+
211
+ Inspects the last ``eta`` successive changes in the summed Riccati-matrix
212
+ norm and reports convergence when *all* of them are below ``epsilon_P``
213
+ *relative* to the current norm. Using a window of ``eta`` (rather than a
214
+ single step) guards against a premature stop on one lucky iteration, and
215
+ the relative scaling makes the tolerance meaningful across problem sizes.
216
+ """
217
+
218
+ if len(norms) <= self._eta:
219
+ return False
220
+ window = norms[-(self._eta + 1) :]
221
+ scale = max(abs(window[-1]), 1.0)
222
+ return all(abs(b - a) <= self._epsilon_P * scale for a, b in zip(window[:-1], window[1:]))
223
+
224
+ @abstractmethod
225
+ def _solve_are(self, B: FloatArray, Q: FloatArray, R: FloatArray) -> FloatArray:
226
+ """Solve the (continuous or discrete) algebraic Riccati equation for ``(A, B)``."""
227
+
228
+ def _validate(self) -> None:
229
+ if self._N == 0:
230
+ raise ValueError("At least one objective must be specified")
231
+ if self._A.shape != (self._n, self._n):
232
+ raise ValueError(f"A must be square (n, n) = ({self._n}, {self._n})")
233
+ for Q_i in self._Qs:
234
+ if Q_i.shape != (self._n, self._n):
235
+ raise ValueError(f"every Q must have shape ({self._n}, {self._n})")
236
+ for B_i, R_i in zip(self._Bs, self._Rs):
237
+ if B_i.shape[0] != self._n:
238
+ raise ValueError(f"every B_i must have n = {self._n} rows")
239
+ if B_i.shape[1] != R_i.shape[0]:
240
+ raise ValueError("each B_i column count must match its R_ii size")
241
+ for vec, name in ((self._x_0, "x_0"), (self._x_T, "x_T")):
242
+ if vec is not None and vec.shape != (self._n,):
243
+ raise ValueError(f"{name} must have length n = {self._n}")
244
+ if self._T_f <= 0:
245
+ raise ValueError("T_f must be a positive real number")
246
+ if self._L <= 0:
247
+ raise ValueError("L (number of samples) must be a positive integer")
248
+ if self._eta <= 0:
249
+ raise ValueError("eta must be a positive integer")
250
+ if not 0 < self._epsilon_x < 1:
251
+ raise ValueError("epsilon_x must lie in the open interval (0, 1)")
252
+ if not 0 < self._epsilon_P < 1:
253
+ raise ValueError("epsilon_P must lie in the open interval (0, 1)")
254
+ if self._state_variables_names is not None and len(self._state_variables_names) != self._n:
255
+ raise ValueError(f"state_variables_names must have length n = {self._n}")
256
+ if self._B is not None and not self.is_controllable():
257
+ import warnings
258
+
259
+ warnings.warn("the given system is not fully controllable", stacklevel=2)
260
+
261
+ # ------------------------------------------------------------------ #
262
+ # System properties
263
+ # ------------------------------------------------------------------ #
264
+ def is_controllable(self) -> bool:
265
+ """Whether ``(A, B)`` is fully controllable (full-rank controllability matrix)."""
266
+
267
+ if self._B is None:
268
+ B = np.concatenate(self._Bs, axis=1)
269
+ else:
270
+ B = self._B
271
+ controllability = np.concatenate(
272
+ [np.linalg.matrix_power(self._A, i) @ B for i in range(self._n)], axis=1
273
+ )
274
+ return int(np.linalg.matrix_rank(controllability)) == self._n
275
+
276
+ @property
277
+ def is_lqr(self) -> bool:
278
+ """Whether the game reduces to a single-objective LQR problem."""
279
+
280
+ return self._N == 1
281
+
282
+ def _closed_loop(self, gains: Sequence[FloatArray]) -> FloatArray:
283
+ r"""Closed-loop matrix :math:`A - \sum_i B_i K_i` for the given gains."""
284
+
285
+ return self._A - sum(B_i @ K_i for B_i, K_i in zip(self._Bs, gains))
286
+
287
+ # ------------------------------------------------------------------ #
288
+ # Solve / simulate orchestration
289
+ # ------------------------------------------------------------------ #
290
+ @abstractmethod
291
+ def solve(self) -> PyDiffGame:
292
+ """Solve the Riccati equation(s) and store ``P`` and the gains ``K``."""
293
+
294
+ @abstractmethod
295
+ def simulate(self) -> PyDiffGame:
296
+ """Propagate the closed-loop state trajectory from ``x_0``."""
297
+
298
+ @abstractmethod
299
+ def is_closed_loop_stable(self) -> bool:
300
+ """Whether the converged closed loop is stable in the relevant sense."""
301
+
302
+ def _require_solved(self) -> None:
303
+ if not self._solved:
304
+ raise RuntimeError("the game must be solved first; call solve()")
305
+
306
+ def run(
307
+ self,
308
+ *,
309
+ plot_state_space: bool = True,
310
+ save_figure: bool = False,
311
+ figure_path: PathInput = default_figures_path,
312
+ figure_filename: str = default_figures_filename,
313
+ ) -> PyDiffGame:
314
+ """Solve, simulate (if ``x_0`` is set) and optionally plot the state."""
315
+
316
+ self.solve()
317
+ if self._x_0 is not None:
318
+ self.simulate()
319
+ if plot_state_space:
320
+ self.plot_state_variables(
321
+ save_figure=save_figure,
322
+ figure_path=figure_path,
323
+ figure_filename=figure_filename,
324
+ )
325
+ return self
326
+
327
+ def __call__(self, **kwargs) -> PyDiffGame:
328
+ """Alias for :meth:`run` so a game instance is callable."""
329
+
330
+ return self.run(**kwargs)
331
+
332
+ # ------------------------------------------------------------------ #
333
+ # Cost
334
+ # ------------------------------------------------------------------ #
335
+ def cost(self, objective: Objective, *, state_only: bool = False) -> float:
336
+ r"""Trapezoidal approximation of an objective's cost along the trajectory.
337
+
338
+ .. math:: J \approx \delta \left[ \tfrac12 (J_0 + J_L)
339
+ + \sum_{l=1}^{L-1} J_l \right]
340
+ """
341
+
342
+ self._require_solved()
343
+ if self._x is None:
344
+ raise RuntimeError("no state trajectory; call simulate() with an x_0")
345
+
346
+ Q, R = objective.Q, objective.R
347
+ target = self._x_T if self._x_T is not None else np.zeros(self._n)
348
+ gains = self._gain_at # bound method, time-aware per subclass
349
+
350
+ total = 0.0
351
+ for l in range(self._L):
352
+ x_tilde = target - self._x[l]
353
+ cost_l = float(x_tilde @ Q @ x_tilde)
354
+ if not state_only:
355
+ u_tilde = -gains(l) @ x_tilde
356
+ cost_l += float(u_tilde @ R @ u_tilde)
357
+ if l in (0, self._L - 1):
358
+ cost_l *= 0.5
359
+ total += cost_l
360
+ return total * self._delta
361
+
362
+ @abstractmethod
363
+ def _gain_at(self, l: int) -> FloatArray:
364
+ """Aggregate physical-input gain at sample ``l`` (time-aware per domain)."""
365
+
366
+ # ------------------------------------------------------------------ #
367
+ # Plotting
368
+ # ------------------------------------------------------------------ #
369
+ def plot_state_variables(
370
+ self,
371
+ *,
372
+ save_figure: bool = False,
373
+ figure_path: PathInput = default_figures_path,
374
+ figure_filename: str = default_figures_filename,
375
+ ) -> None:
376
+ """Plot the state trajectory against time."""
377
+
378
+ self._require_solved()
379
+ if self._x is None:
380
+ raise RuntimeError("no state trajectory; call simulate() with an x_0")
381
+
382
+ labels = None
383
+ if self._show_legend:
384
+ if self._state_variables_names is not None:
385
+ labels = [f"${name}$" for name in self._state_variables_names]
386
+ else:
387
+ labels = [rf"$\mathbf{{x}}_{{{j}}}$" for j in range(1, self._n + 1)]
388
+
389
+ self._fig = plotting.plot_temporal(
390
+ self._forward_time,
391
+ self._x,
392
+ labels=labels,
393
+ show_legend=self._show_legend,
394
+ )
395
+ if save_figure:
396
+ plotting.save_figure(self._fig, figure_path, figure_filename)
397
+
398
+ # ------------------------------------------------------------------ #
399
+ # Public accessors
400
+ # ------------------------------------------------------------------ #
401
+ @property
402
+ def A(self) -> FloatArray:
403
+ return self._A
404
+
405
+ @property
406
+ def Bs(self) -> list[FloatArray]:
407
+ return self._Bs
408
+
409
+ @property
410
+ def n(self) -> int:
411
+ return self._n
412
+
413
+ @property
414
+ def N(self) -> int:
415
+ return self._N
416
+
417
+ @property
418
+ def P(self) -> list[FloatArray] | FloatArray:
419
+ self._require_solved()
420
+ return self._P
421
+
422
+ @property
423
+ def K(self) -> list[FloatArray] | FloatArray:
424
+ self._require_solved()
425
+ return self._K
426
+
427
+ @property
428
+ def x(self) -> FloatArray | None:
429
+ return self._x
430
+
431
+ @property
432
+ def x_0(self) -> FloatArray | None:
433
+ return self._x_0
434
+
435
+ @property
436
+ def x_T(self) -> FloatArray | None:
437
+ return self._x_T
438
+
439
+ @property
440
+ def forward_time(self) -> FloatArray:
441
+ return self._forward_time
442
+
443
+ @property
444
+ def T_f(self) -> float:
445
+ return self._T_f
446
+
447
+ @property
448
+ def L(self) -> int:
449
+ return self._L
450
+
451
+ @property
452
+ def M_inv(self) -> FloatArray | None:
453
+ return self._M_inv
454
+
455
+ def __len__(self) -> int:
456
+ return self._N
457
+
458
+ def __getitem__(self, i: int) -> Objective:
459
+ return self._objectives[i]
460
+
461
+ def __repr__(self) -> str:
462
+ domain = type(self).__name__
463
+ horizon = "infinite" if self._infinite_horizon else f"T_f={self._T_f:g}"
464
+ kind = "LQR" if self.is_lqr else f"{self._N}-player game"
465
+ return f"<{domain}: {kind}, n={self._n}, {horizon}>"
466
+
467
+
468
# Only the abstract base class is public from this module.
__all__ = [
    "PyDiffGame",
]
PyDiffGame/comparison.py ADDED
@@ -0,0 +1,121 @@
1
+ """Compare several control designs (e.g. an LQR baseline vs. a multi-player game).
2
+
3
+ A :class:`PyDiffGameLQRComparison` builds one game per objective group, all
4
+ sharing the same model (``A``, ``B``, ``x_0`` ...), then solves, simulates and
5
+ optionally plots them and reports their costs on a common yardstick.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import itertools
11
+ import time
12
+ from collections.abc import Callable, Sequence
13
+ from concurrent.futures import ProcessPoolExecutor
14
+
15
+ from tqdm import tqdm
16
+
17
+ from PyDiffGame._typing import ArrayLike
18
+ from PyDiffGame.base import PyDiffGame
19
+ from PyDiffGame.continuous import ContinuousPyDiffGame
20
+ from PyDiffGame.discrete import DiscretePyDiffGame
21
+ from PyDiffGame.objective import Objective
22
+
23
+
24
class PyDiffGameLQRComparison:
    """Build and compare several games that share the same linear model.

    Parameters
    ----------
    A:
        Shared dynamics matrix.
    B:
        Shared input matrix.
    games_objectives:
        A sequence of objective groups; one game is built per group. A group
        with a single LQR objective yields an LQR; a group of game objectives
        (each carrying an ``M``) yields a multi-player game.
    continuous:
        Whether to build continuous- or discrete-time games.
    **game_kwargs:
        Forwarded to every game (``x_0``, ``x_T``, ``T_f``, ``L`` ...).
    """

    def __init__(
        self,
        A: ArrayLike,
        B: ArrayLike,
        games_objectives: Sequence[Sequence[Objective]],
        *,
        continuous: bool = True,
        **game_kwargs,
    ) -> None:
        game_class = ContinuousPyDiffGame if continuous else DiscretePyDiffGame
        self._games: list[PyDiffGame] = [
            game_class(A, list(objectives), B=B, **game_kwargs) for objectives in games_objectives
        ]
        # The objective of the first single-objective (LQR) game, if any, is
        # the default yardstick used by ``costs``.
        self._lqr_objective: Objective | None = next(
            (game[0] for game in self._games if game.is_lqr), None
        )

    @property
    def games(self) -> list[PyDiffGame]:
        """The games being compared, in the order their objective groups were given."""
        return self._games

    def are_all_controllable(self) -> bool:
        """Whether every game reports itself controllable (``PyDiffGame.is_controllable``)."""
        return all(game.is_controllable() for game in self._games)

    def solve(self) -> PyDiffGameLQRComparison:
        """Solve every game; returns ``self`` for chaining."""
        for game in self._games:
            game.solve()
        return self

    def run(
        self,
        *,
        plot_state_spaces: bool = True,
        save_figure: bool = False,
        figure_filename: str | Callable[[PyDiffGame], str] = PyDiffGame.default_figures_filename,
    ) -> PyDiffGameLQRComparison:
        """Solve, simulate (when ``x_0`` is set) and optionally plot every game.

        ``figure_filename`` may be a fixed name or a callable mapping each game
        to its own file name.
        """

        for game in self._games:
            name = figure_filename(game) if callable(figure_filename) else figure_filename
            game.run(
                plot_state_space=plot_state_spaces,
                save_figure=save_figure,
                figure_filename=name,
            )
        return self

    def __call__(self, **kwargs) -> PyDiffGameLQRComparison:
        """Alias for :meth:`run`."""
        return self.run(**kwargs)

    def costs(self, objective: Objective | None = None) -> list[float]:
        """Cost of every (solved, simulated) game measured by a common objective.

        Defaults to the objective of the first LQR game in the comparison.

        Raises
        ------
        ValueError
            If no objective is given and no game in the comparison is an LQR.
        """

        # Explicit ``is None`` test: an Objective that happens to be falsy
        # (e.g. defines ``__len__``) must not be silently replaced.
        if objective is None:
            objective = self._lqr_objective
        if objective is None:
            raise ValueError("no LQR objective available; pass one explicitly")
        return [game.cost(objective) for game in self._games]

    @staticmethod
    def run_multiprocess(
        worker: Callable[..., object],
        values: Sequence[Sequence[object]],
    ) -> list[object]:
        """Run ``worker`` over the Cartesian product of ``values`` in parallel.

        Returns the workers' results in the order of
        ``itertools.product(*values)``. A worker's exception is re-raised.
        """

        from concurrent.futures import as_completed

        start = time.perf_counter()
        combinations = list(itertools.product(*values))
        with ProcessPoolExecutor() as executor:
            futures = [executor.submit(worker, *combo) for combo in combinations]
            # Advance the progress bar as tasks finish, not in submission order,
            # so one slow early task does not freeze the display; ``result()``
            # re-raises a worker's exception as soon as it completes.
            for future in tqdm(as_completed(futures), total=len(futures)):
                future.result()
        print(f"Total time: {time.perf_counter() - start:.3f} s")
        return [future.result() for future in futures]

    def __len__(self) -> int:
        return len(self._games)

    def __getitem__(self, i: int) -> PyDiffGame:
        return self._games[i]
119
+
120
+
121
# The comparison helper is the only public name of this module.
__all__ = [
    "PyDiffGameLQRComparison",
]