openscvx 2.dev6__py3-none-any.whl → 2.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,908 @@
1
+ """CVXPy-based convex subproblem solver for the penalized trust-region (PTR) SCP algorithm.
2
+
3
+ This module provides the default backend for :class:`PTRSolver`, using CVXPy's
4
+ modeling language and dispatching to any of its supported conic solvers
5
+ (QOCO, CLARABEL, ...). Optional code generation via cvxpygen is available
6
+ for improved per-iteration performance.
7
+
8
+ Companion backend: :class:`openscvx.solvers.qpax_ptr_solver.QPAXPTRSolver`,
9
+ which targets pure-JAX execution.
10
+ """
11
+
12
+ import os
13
+ from typing import TYPE_CHECKING, List, Optional, Union
14
+
15
+ import cvxpy as cp
16
+ import numpy as np
17
+
18
+ from openscvx.config import Config
19
+
20
+ from .ptr_solver import PTRSolver, PTRSolveResult
21
+
22
+ if TYPE_CHECKING:
23
+ from openscvx.lowered import LoweredProblem
24
+ from openscvx.lowered.cvxpy_variables import CVXPyVariables
25
+ from openscvx.lowered.jax_constraints import LoweredJaxConstraints
26
+ from openscvx.lowered.unified import UnifiedControl, UnifiedState
27
+
28
# cvxpygen is an optional dependency; it is only needed when CVXPyPTRSolver
# is constructed with ``cvxpygen=True``. ``cpg`` is None when it is missing.
try:
    from cvxpygen import cpg
except ImportError:
    CVXPYGEN_AVAILABLE = False
    cpg = None
else:
    CVXPYGEN_AVAILABLE = True
36
+
37
+
38
+ ## TODO: (fabio) add support for impulsive controls
39
+
40
+
41
class CVXPyPTRSolver(PTRSolver):
    """CVXPy-backed implementation of the PTR convex subproblem.

    Builds the subproblem as a DCP program through CVXPy and dispatches it to
    one of CVXPy's supported conic solvers (QOCO by default, CLARABEL, etc.).
    Optional code generation via cvxpygen is available for improved per-iteration
    performance.

    The solver builds the problem structure once during ``initialize()``, using
    CVXPy Parameters for values that change each iteration. The ``solve()``
    method then solves and returns a structured ``PTRSolveResult``.

    The cost and constraint formulations are defined in the ``cost()`` and
    ``constraints()`` methods, which can be overridden in subclasses to
    customize the convex subproblem. For example::

        class MyPTRSolver(CVXPyPTRSolver):
            def cost(self, settings, lowered):
                c = super().cost(settings, lowered)
                c += my_extra_term(self._ocp_vars)
                return c

    Example:
        Using CVXPyPTRSolver with the SCP framework::

            solver = CVXPyPTRSolver()
            solver.create_variables(N, x_unified, u_unified, jax_constraints)
            solver.initialize(lowered, settings)

            # Each iteration (parameter updates done by algorithm):
            result = solver.solve()
            x_sol = result.x  # Unscaled state trajectory

    Args:
        cvx_solver: CVXPY solver backend name. Defaults to ``"QOCO"``.
        solver_args: Keyword arguments forwarded to the CVXPY solver
            (e.g. tolerances). Defaults to
            ``{"abstol": 1e-6, "reltol": 1e-9, "enforce_dpp": True}``.
            Without cvxpygen, if CVXPy raises ``DPPError`` the solve is
            retried with ``enforce_dpp`` removed and ``ignore_dpp=True``.
        cvxpygen: Enable CVXPy code generation for faster solves.
            Defaults to ``False``.

            !!! warning
                Enabling cvxpygen currently disables sparse parameter
                declarations. cvxpygen does not yet support the N-D sparsity
                indices used by OpenSCvx's tiled parameters, so all parameters
                are created as dense when code generation is active. This may
                increase the generated solver's memory footprint and compile
                time but does not affect solution correctness.
        cvxpygen_override: Overwrite an existing ``./solver`` code directory
            without prompting. Defaults to ``False``.

    Attributes:
        ocp_vars: The CVXPy variables and parameters (available after create_variables())
    """

    def __init__(
        self,
        cvx_solver: str = "QOCO",
        solver_args: Optional[dict] = None,
        cvxpygen: bool = False,
        cvxpygen_override: bool = False,
    ):
        """Initialize CVXPyPTRSolver with solver configuration.

        Call create_variables() then initialize() to build the problem structure.
        """
        self.cvx_solver = cvx_solver
        self.solver_args = (
            solver_args
            if solver_args is not None
            else {"abstol": 1e-06, "reltol": 1e-09, "enforce_dpp": True}
        )
        self.cvxpygen = cvxpygen
        self.cvxpygen_override = cvxpygen_override

        self._ocp_vars: Optional["CVXPyVariables"] = None
        self._problem: Optional[cp.Problem] = None
        # Zero-argument callable installed by _setup_solve_function().
        self._solve_fn = None

    @property
    def ocp_vars(self) -> "CVXPyVariables":
        """The CVXPy variables and parameters.

        Returns:
            The CVXPyVariables dataclass, or None if create_variables() not called.
        """
        return self._ocp_vars

    def create_variables(
        self,
        N: int,
        x_unified: "UnifiedState",
        u_unified: "UnifiedControl",
        jax_constraints: "LoweredJaxConstraints",
        dynamics_sparsity: Optional[tuple] = None,
        constraint_sparsity: Optional[list] = None,
    ) -> None:
        """Create CVXPy optimization variables.

        Creates all CVXPy Variable and Parameter objects needed for the optimal
        control problem. This includes state/control variables, dynamics parameters,
        constraint linearization parameters, and scaling matrices.

        Args:
            N: Number of discretization nodes
            x_unified: Unified state interface with dimensions and scaling bounds
            u_unified: Unified control interface with dimensions and scaling bounds
            jax_constraints: Lowered JAX constraints (for sizing linearization params)
            dynamics_sparsity: Optional tuple ``(A_d, B_d, C_d)`` of boolean
                ndarrays giving the discrete-time Jacobian sparsity patterns.
                ``A_d`` has shape ``(n_x, n_x)``; ``B_d`` and ``C_d`` have
                shape ``(n_x, n_u)``.
            constraint_sparsity: Optional list of ``(x_mask, u_mask)`` boolean
                1-D arrays, one per nodal constraint.

        Raises:
            ValueError: If the continuous and impulsive control slices do not
                add up to the total control dimension.
        """
        from openscvx.symbolic.lower import _tile_sparsity, create_cvxpy_variables

        n_states = len(x_unified.max)
        n_controls = len(u_unified.max)
        slice_cont = u_unified.slice_continuous
        slice_imp = u_unified.slice_impulsive
        n_controls_cont = int(slice_cont.stop - slice_cont.start)
        n_controls_imp = int(slice_imp.stop - slice_imp.start)
        if n_controls_cont + n_controls_imp != n_controls:
            raise ValueError(
                "Unified control slices are inconsistent with control dimension. "
                f"continuous={n_controls_cont}, impulsive={n_controls_imp}, total={n_controls}."
            )

        S_x, c_x = self._scaling(x_unified)
        S_u, c_u = self._scaling(u_unified)

        # Convert boolean sparsity patterns to CVXPY index format (tiled over
        # the N - 1 discretization intervals).
        A_d_sp = B_d_sp = C_d_sp = None
        if dynamics_sparsity is not None:
            A_d_pat, B_d_pat, C_d_pat = dynamics_sparsity
            A_d_sp = _tile_sparsity(A_d_pat, N - 1)
            B_d_sp = _tile_sparsity(B_d_pat, N - 1)
            C_d_sp = _tile_sparsity(C_d_pat, N - 1)

        # TODO: (griffin-norris) Remove once cvxpygen supports N-D sparsity
        # indices. cvxpygen's handle_sparsity() assumes 2-D (rows, cols) but
        # our tiled parameters produce 3-D indices (slices, rows, cols).
        # Dropping sparsity here is safe — it only affects codegen performance.
        if self.cvxpygen:
            A_d_sp = B_d_sp = C_d_sp = None
            constraint_sparsity = None

        self._ocp_vars = create_cvxpy_variables(
            N=N,
            n_states=n_states,
            n_controls=n_controls,
            S_x=S_x,
            c_x=c_x,
            S_u=S_u,
            c_u=c_u,
            n_nodal_constraints=len(jax_constraints.nodal),
            n_cross_node_constraints=len(jax_constraints.cross_node),
            A_d_sparsity=A_d_sp,
            B_d_sparsity=B_d_sp,
            C_d_sparsity=C_d_sp,
            constraint_sparsity=constraint_sparsity,
        )

    def lower_convex_constraints(self, constraints, parameters=None):
        """Lower user ``.convex()`` constraints into CVXPy constraint objects.

        Delegates to :func:`openscvx.symbolic.lower.lower_cvxpy_constraints`,
        feeding it the unscaled-state and unscaled-control CVXPy expressions
        built by :meth:`create_variables`.

        Raises:
            RuntimeError: If create_variables() has not been called.
        """
        from openscvx.symbolic.lower import lower_cvxpy_constraints

        if self._ocp_vars is None:
            raise RuntimeError(
                "CVXPyPTRSolver.lower_convex_constraints() called before "
                "create_variables(); the CVXPy variables it needs don't "
                "exist yet."
            )
        return lower_cvxpy_constraints(
            constraints,
            self._ocp_vars.x_nonscaled,
            self._ocp_vars.u_nonscaled,
            parameters,
        )

    def initialize(
        self,
        lowered: "LoweredProblem",
        settings: "Config",
    ) -> None:
        """Build the CVXPy optimal control problem.

        Constructs the complete optimization problem by calling ``cost()`` and
        ``constraints()`` to build the objective and constraint formulations,
        then assembles them into a CVXPy Problem.

        If cvxpygen is enabled, generates compiled solver code into ``./solver``
        (see ``_generate_solver_code``).

        Note:
            ``create_variables()`` must be called before this method.

        Args:
            lowered: Lowered problem containing:
                - ``cvxpy_constraints``: Lowered convex constraints
                - ``jax_constraints``: JAX constraint functions (for structure)
            settings: Problem configuration (node count, scaling, etc.)

        Raises:
            RuntimeError: If create_variables() has not been called.
            ImportError: If cvxpygen is enabled but not installed, or its
                generated solver module cannot be imported.
        """
        if self._ocp_vars is None:
            raise RuntimeError(
                "CVXPyPTRSolver.initialize() called before create_variables(). "
                "Call create_variables() first to create optimization variables."
            )

        objective = self.cost(settings, lowered)
        constr = self.constraints(settings, lowered)
        prob = cp.Problem(cp.Minimize(objective), constr)

        if self.cvxpygen:
            self._generate_solver_code(prob)

        self._problem = prob
        self._setup_solve_function()

    def _generate_solver_code(self, prob: cp.Problem) -> None:
        """Generate cvxpygen solver code for ``prob`` into ``./solver``.

        Code is generated when the directory does not exist yet, when
        ``cvxpygen_override`` is set, or when the user answers ``y`` at the
        interactive prompt; otherwise the existing generated code is reused.

        Raises:
            ImportError: If cvxpygen is not installed.
        """
        if not CVXPYGEN_AVAILABLE:
            raise ImportError(
                "cvxpygen is required for code generation but not installed. "
                "Install it with: pip install openscvx[cvxpygen] or pip install cvxpygen"
            )
        code_dir = "solver"
        if os.path.exists(code_dir) and not self.cvxpygen_override:
            # Prompt the user to decide between overwriting the solver
            # directory and reusing the previously compiled solver.
            overwrite = input("Solver directory already exists. Overwrite? (y/n): ")
            if overwrite.strip().lower() != "y":
                return
        cpg.generate_code(prob, solver=self.cvx_solver, code_dir=code_dir, wrapper=True)

    def cost(
        self,
        settings: "Config",
        lowered: "LoweredProblem",
    ) -> cp.Expression:
        """Build the cost expression for the convex subproblem.

        Constructs the PTR objective function including:

        - Boundary condition costs (Minimize/Maximize state components)
        - Trust region penalty (deviation from linearization point)
        - Virtual control penalty (dynamics defect relaxation)
        - Virtual buffer penalty (nonconvex constraint violation relaxation)

        Override this method in subclasses to customize the cost formulation.
        Use ``super().cost(settings, lowered)`` to include the standard PTR
        cost terms and add to them.

        Args:
            settings: Configuration object with solver settings
            lowered: Lowered problem containing constraint structure

        Returns:
            CVXPy expression representing the total cost to minimize.
        """
        ocp_vars = self._ocp_vars
        sim = settings.sim
        jax_constraints = lowered.jax_constraints

        lam_prox = ocp_vars.lam_prox
        lam_cost = ocp_vars.lam_cost
        lam_vc = ocp_vars.lam_vc
        lam_vb_nodal = ocp_vars.lam_vb_nodal
        lam_vb_cross = ocp_vars.lam_vb_cross
        x = ocp_vars.x
        dx = ocp_vars.dx
        du = ocp_vars.du
        nu = ocp_vars.nu
        nu_vb = ocp_vars.nu_vb
        nu_vb_cross = ocp_vars.nu_vb_cross

        # Zero-weighted terms keep these penalty Parameters referenced by the
        # problem even when no term below uses them, so update_penalties() can
        # always set them through problem.param_dict.
        cost = cp.sum(lam_cost) * 0
        cost += cp.sum(lam_vb_nodal) * 0
        cost += cp.sum(lam_vb_cross) * 0

        # Boundary condition cost terms (use scaled x for numerical conditioning)
        for i in range(sim.true_state_slice.start, sim.true_state_slice.stop):
            for node, bc_type in ((0, sim.x.initial_type[i]), (-1, sim.x.final_type[i])):
                if bc_type == "Minimize":
                    cost += lam_cost[i] * x[node][i]
                elif bc_type == "Maximize":
                    cost -= lam_cost[i] * x[node][i]

        # Trust region cost (per-variable weighting of the stacked [dx, du])
        cost += sum(
            cp.sum(cp.multiply(lam_prox[i], cp.square(cp.hstack((dx[i], du[i])))))
            for i in range(sim.n)
        )

        # Virtual control slack: weighted L1 norm of the dynamics defect.
        # cp.multiply makes the elementwise weighting explicit (``*`` between
        # two non-scalar expressions is deprecated matrix multiplication).
        cost += sum(
            cp.sum(cp.multiply(lam_vc[i - 1], cp.abs(nu[i - 1]))) for i in range(1, sim.n)
        )

        # Virtual buffer penalty for nodal constraints (per-node weighting)
        for j in range(len(jax_constraints.nodal or ())):
            cost += lam_vb_nodal[:, j] @ cp.pos(nu_vb[j])

        # Virtual slack penalty for cross-node constraints
        for j in range(len(jax_constraints.cross_node or ())):
            cost += lam_vb_cross[j] * cp.pos(nu_vb_cross[j])

        return cost

    def constraints(
        self,
        settings: "Config",
        lowered: "LoweredProblem",
    ) -> list:
        """Build the constraint list for the convex subproblem.

        Constructs all PTR constraints including:

        - Linearized nodal constraints (from JAX-lowered nonconvex constraints)
        - Linearized cross-node constraints
        - Convex constraints (already lowered to CVXPy)
        - Boundary conditions (fixed initial/terminal states)
        - Uniform time grid constraints
        - State and control deviation definitions
        - Linearized dynamics
        - State and control box constraints
        - CTCS constraints

        Override this method in subclasses to customize the constraint
        formulation. Use ``super().constraints(settings, lowered)`` to include
        the standard PTR constraints and extend them.

        Args:
            settings: Configuration object with solver settings
            lowered: Lowered problem containing lowered constraints

        Returns:
            List of CVXPy constraints.
        """
        ocp_vars = self._ocp_vars
        sim = settings.sim
        n_nodes = sim.n
        jax_constraints = lowered.jax_constraints
        cvxpy_constraints = lowered.cvxpy_constraints

        x = ocp_vars.x
        dx = ocp_vars.dx
        x_bar = ocp_vars.x_bar
        x_init = ocp_vars.x_init
        x_term = ocp_vars.x_term
        u = ocp_vars.u
        du = ocp_vars.du
        u_bar = ocp_vars.u_bar
        A_d = ocp_vars.A_d
        B_d = ocp_vars.B_d
        C_d = ocp_vars.C_d
        x_prop = ocp_vars.x_prop
        x_prop_plus = ocp_vars.x_prop_plus
        E_d = ocp_vars.E_d
        nu = ocp_vars.nu
        g = ocp_vars.g
        grad_g_x = ocp_vars.grad_g_x
        grad_g_u = ocp_vars.grad_g_u
        nu_vb = ocp_vars.nu_vb
        g_cross = ocp_vars.g_cross
        grad_g_X_cross = ocp_vars.grad_g_X_cross
        grad_g_U_cross = ocp_vars.grad_g_U_cross
        nu_vb_cross = ocp_vars.nu_vb_cross
        inv_S_x = ocp_vars.inv_S_x
        c_x = ocp_vars.c_x
        inv_S_u = ocp_vars.inv_S_u
        c_u = ocp_vars.c_u
        x_nonscaled = ocp_vars.x_nonscaled
        u_nonscaled = ocp_vars.u_nonscaled
        dx_nonscaled = ocp_vars.dx_nonscaled
        du_nonscaled = ocp_vars.du_nonscaled

        slice_cont = sim.u.slice_continuous
        slice_imp = sim.u.slice_impulsive
        has_continuous = bool(slice_cont.stop > slice_cont.start)
        has_impulsive = bool(slice_imp.stop > slice_imp.start)

        constr = []

        # Linearized nodal constraints (from JAX-lowered non-convex):
        # g + grad_x g @ dx + grad_u g @ du == nu_vb at each of the constraint's nodes.
        for j, constraint in enumerate(jax_constraints.nodal or ()):
            # nodes should already be validated and normalized in preprocessing
            for node in constraint.nodes:
                residual = (
                    g[j][node] + grad_g_x[j][node] @ dx[node] + grad_g_u[j][node] @ du[node]
                )
                constr.append(residual == nu_vb[j][node])

        # Linearized cross-node constraints (from JAX-lowered non-convex):
        # g(X_bar, U_bar) + sum_k (grad_X g[k] @ dx[k] + grad_U g[k] @ du[k]) == nu_vb,
        # i.e. the gradient rows couple every trajectory node.
        for j in range(len(jax_constraints.cross_node or ())):
            residual = g_cross[j]
            for k in range(n_nodes):
                residual += grad_g_X_cross[j][k, :] @ dx[k]
                residual += grad_g_U_cross[j][k, :] @ du[k]
            constr.append(residual == nu_vb_cross[j])

        # Convex constraints (already lowered to CVXPy)
        if cvxpy_constraints.constraints:
            constr.extend(cvxpy_constraints.constraints)

        # Boundary conditions (Fix)
        for i in range(sim.true_state_slice.start, sim.true_state_slice.stop):
            if sim.x.initial_type[i] == "Fix":
                if has_impulsive:
                    # With impulsive controls the first node is tied to
                    # x_prop_plus[0] plus the impulsive control contribution,
                    # instead of directly to x_init.
                    constr.append(
                        x_nonscaled[0][i]
                        == x_prop_plus[0][i] + E_d[0][i, slice_imp] @ du_nonscaled[0][slice_imp]
                    )
                else:
                    constr.append(x_nonscaled[0][i] == x_init[i])  # Initial boundary condition
            if sim.x.final_type[i] == "Fix":
                constr.append(x_nonscaled[-1][i] == x_term[i])  # Final boundary condition

        # Uniform time grid: the time-dilation control is equal at all nodes.
        if sim._uniform_time_grid:
            td = sim.time_dilation_slice
            S_u_inv_td = inv_S_u[td, td]
            c_u_td = c_u[td]
            constr.extend(
                S_u_inv_td @ (u_nonscaled[i][td] - c_u_td)
                == S_u_inv_td @ (u_nonscaled[i - 1][td] - c_u_td)
                for i in range(1, n_nodes)
            )

        # Deviation definitions: scaled variable minus scaled linearization point.
        constr.extend(
            (x[i] - inv_S_x @ (x_bar[i] - c_x) - dx[i]) == 0 for i in range(n_nodes)
        )  # State error
        constr.extend(
            (u[i] - inv_S_u @ (u_bar[i] - c_u) - du[i]) == 0 for i in range(n_nodes)
        )  # Control error

        def _predicted_state_offset(i):
            """Linearized prediction of x[i] (unscaled) minus the state offset c_x."""
            expr = A_d[i - 1] @ dx_nonscaled[i - 1]
            if has_continuous:
                expr = expr + B_d[i - 1][:, slice_cont] @ du_nonscaled[i - 1][slice_cont]
                expr = expr + C_d[i - 1][:, slice_cont] @ du_nonscaled[i][slice_cont]
            if has_impulsive:
                expr = expr + E_d[i][:, slice_imp] @ du_nonscaled[i][slice_imp]
                expr = expr + x_prop_plus[i]
            else:
                expr = expr + x_prop[i - 1]
            return expr - c_x

        # Linearized dynamics, relaxed by the virtual control nu (in scaled coordinates).
        constr.extend(
            inv_S_x @ (x_nonscaled[i] - c_x) == inv_S_x @ _predicted_state_offset(i) + nu[i - 1]
            for i in range(1, n_nodes)
        )

        # Control box constraints
        constr.extend(
            inv_S_u @ (u_nonscaled[i] - c_u) <= inv_S_u @ (sim.u.max - c_u)
            for i in range(n_nodes)
        )
        constr.extend(
            inv_S_u @ (u_nonscaled[i] - c_u) >= inv_S_u @ (sim.u.min - c_u)
            for i in range(n_nodes)
        )

        # TODO: (norrisg) formalize this
        # State box constraints (also implemented in CTCS but included for numerical stability)
        constr.extend(
            inv_S_x @ (x_nonscaled[i] - c_x) <= inv_S_x @ (sim.x.max - c_x)
            for i in range(n_nodes)
        )
        constr.extend(
            inv_S_x @ (x_nonscaled[i] - c_x) >= inv_S_x @ (sim.x.min - c_x)
            for i in range(n_nodes)
        )

        # CTCS: bound the node-to-node change of each augmented state over its
        # active interval, and pin the augmented state to zero at the first node.
        for idx, nodes in zip(
            np.arange(sim.ctcs_slice.start, sim.ctcs_slice.stop),
            sim.ctcs_node_intervals,
        ):
            start_idx = 1 if nodes[0] == 0 else nodes[0]
            constr.extend(
                cp.abs(x_nonscaled[i][idx] - x_nonscaled[i - 1][idx]) <= sim.x.max[idx]
                for i in range(start_idx, nodes[1])
            )
            constr.append(x_nonscaled[0][idx] == 0)

        return constr

    def _setup_solve_function(self) -> None:
        """Configure the solve function based on solver settings.

        With cvxpygen, registers the generated ``solver.cpg_solver.cpg_solve``
        under the method name ``"CPG"``. Otherwise solves with ``cvx_solver``
        and, if CVXPy raises ``DPPError``, retries with ``enforce_dpp`` removed
        and ``ignore_dpp=True``; that decision is remembered for later solves.

        Raises:
            ImportError: If cvxpygen is enabled but the generated solver module
                cannot be imported.
        """
        if self.cvxpygen:
            try:
                from solver.cpg_solver import cpg_solve
            except ImportError as e:
                raise ImportError(
                    "cvxpygen solver not found. Make sure cvxpygen is installed and code "
                    "generation has been run. Install with: pip install openscvx[cvxpygen]"
                ) from e
            # The generated solver is registered on the in-memory problem; the
            # generated ``solver/problem.pickle`` is intentionally not unpickled
            # (its result was never used).
            self._problem.register_solve("CPG", cpg_solve)
            cpg_args = self.solver_args
            self._solve_fn = lambda: self._problem.solve(method="CPG", **cpg_args)
            return

        solver = self.cvx_solver
        solver_args = dict(self.solver_args)
        fallback_args = dict(solver_args)
        fallback_args.pop("enforce_dpp", None)
        fallback_args["ignore_dpp"] = True
        use_fallback = False

        def _solve_with_dpp_fallback():
            nonlocal use_fallback
            if not use_fallback:
                try:
                    return self._problem.solve(solver=solver, **solver_args)
                except cp.error.DPPError:
                    # DPP compliance depends only on problem structure, which is
                    # fixed after initialize(), so skip the failing attempt next time.
                    use_fallback = True
            return self._problem.solve(solver=solver, **fallback_args)

        self._solve_fn = _solve_with_dpp_fallback

    def update_dynamics_linearization(
        self,
        x_bar: np.ndarray,
        u_bar: np.ndarray,
        A_d: np.ndarray,
        B_d: np.ndarray,
        C_d: np.ndarray,
        x_prop: np.ndarray,
        x_prop_plus: Optional[np.ndarray] = None,
        D_d: Optional[np.ndarray] = None,
        E_d: Optional[np.ndarray] = None,
    ) -> None:
        """Update dynamics linearization point and matrices.

        Sets the current linearization point (previous iterate) and the
        discretized dynamics matrices for the convex subproblem.

        Args:
            x_bar: Previous state trajectory, shape (N, n_states)
            u_bar: Previous control trajectory, shape (N, n_controls)
            A_d: Discretized state Jacobian, shape (N-1, n_states, n_states)
            B_d: Discretized control Jacobian (current node), shape (N-1, n_states, n_controls)
            C_d: Discretized control Jacobian (next node), shape (N-1, n_states, n_controls)
            x_prop: Propagated state from continuous dynamics, shape (N-1, n_states)
            x_prop_plus: Optional impulsive/discrete propagated state, shape (N, n_states)
            D_d: Optional impulsive/discrete Jacobian wrt state, shape (N, n_states, n_states)
                or (N-1, n_states, n_states); it is left-multiplied into A_d, B_d and C_d.
            E_d: Optional impulsive/discrete Jacobian wrt control, shape (N, n_states, n_controls)

        Raises:
            ValueError: If ``D_d`` is not 3-D with a leading dimension equal to
                ``A_d.shape[0] + 1`` or ``A_d.shape[0]``.
        """
        self._set_param("x_bar", x_bar)
        self._set_param("u_bar", u_bar)

        A_eff = np.asarray(A_d)
        B_eff = np.asarray(B_d)
        C_eff = np.asarray(C_d)

        if D_d is not None:
            D_arr = np.asarray(D_d)
            n_steps = A_eff.shape[0]
            # Temporary DPP-safe workaround: absorb D_d into A/B/C numerically
            # so the CVXPY graph has only Parameter@Variable products.
            if D_arr.ndim == 3 and D_arr.shape[0] == n_steps + 1:
                D_steps = D_arr[1:]
            elif D_arr.ndim == 3 and D_arr.shape[0] == n_steps:
                D_steps = D_arr
            else:
                raise ValueError(
                    "Unexpected D_d shape for dynamics update: "
                    f"{D_arr.shape}, expected "
                    f"{(n_steps + 1, *A_eff.shape[1:])} "
                    f"or {tuple(A_eff.shape)}."
                )

            # Batched left-multiplication: X_eff[k] = D_steps[k] @ X_eff[k].
            A_eff = np.einsum("kij,kjl->kil", D_steps, A_eff)
            B_eff = np.einsum("kij,kjl->kil", D_steps, B_eff)
            C_eff = np.einsum("kij,kjl->kil", D_steps, C_eff)

        self._set_param("A_d", A_eff)
        self._set_param("B_d", B_eff)
        self._set_param("C_d", C_eff)
        if "x_prop" in self._problem.param_dict:
            self._set_param("x_prop", x_prop)
        elif self._ocp_vars.x_prop is not None:
            self._ocp_vars.x_prop.value = np.asarray(x_prop)
        if x_prop_plus is not None and self._ocp_vars.x_prop_plus is not None:
            self._ocp_vars.x_prop_plus.value = np.asarray(x_prop_plus)
        if E_d is not None and self._ocp_vars.E_d is not None:
            self._ocp_vars.E_d.value = np.asarray(E_d)

    def update_constraint_linearizations(
        self,
        nodal: Optional[List[dict]] = None,
        cross_node: Optional[List[dict]] = None,
    ) -> None:
        """Update linearized constraint values and gradients.

        Sets constraint function values and gradients at the current
        linearization point for both nodal and cross-node constraints.
        The i-th dict is written to parameters suffixed with ``_{i}``.

        Args:
            nodal: List of dicts for nodal constraints, each containing:
                - ``g``: Constraint value at linearization point
                - ``grad_g_x``: Gradient w.r.t. state
                - ``grad_g_u``: Gradient w.r.t. control
            cross_node: List of dicts for cross-node constraints, each containing:
                - ``g``: Constraint value at linearization point
                - ``grad_g_X``: Gradient w.r.t. full state trajectory
                - ``grad_g_U``: Gradient w.r.t. full control trajectory
        """
        for g_id, constraint_data in enumerate(nodal or ()):
            self._set_param(f"g_{g_id}", constraint_data["g"])
            self._set_param(f"grad_g_x_{g_id}", constraint_data["grad_g_x"])
            self._set_param(f"grad_g_u_{g_id}", constraint_data["grad_g_u"])

        for g_id, constraint_data in enumerate(cross_node or ()):
            self._set_param(f"g_cross_{g_id}", constraint_data["g"])
            self._set_param(f"grad_g_X_cross_{g_id}", constraint_data["grad_g_X"])
            self._set_param(f"grad_g_U_cross_{g_id}", constraint_data["grad_g_U"])

    def update_penalties(
        self,
        lam_prox: np.ndarray,
        lam_cost: Union[float, np.ndarray],
        lam_vc: np.ndarray,
        lam_vb_nodal: np.ndarray,
        lam_vb_cross: np.ndarray,
    ) -> None:
        """Update SCP penalty weights.

        Sets the penalty weights that balance competing objectives in the
        PTR convex subproblem.

        Args:
            lam_prox: Trust region weights, shape ``(N, n_states + n_controls)``.
            lam_cost: Cost function weight. Scalar or array of shape
                ``(n_states,)`` for per-state weighting.
            lam_vc: Virtual control penalty weights, shape (N-1, n_states)
            lam_vb_nodal: Virtual buffer penalty weights for nodal constraints,
                shape ``(N, n_nodal_constraints)``.
            lam_vb_cross: Virtual buffer penalty weights for cross-node
                constraints, shape ``(n_cross_node_constraints,)``.
        """
        self._set_param("lam_prox", lam_prox)
        self._set_param("lam_cost", lam_cost)
        self._set_param("lam_vc", lam_vc)
        self._set_param("lam_vb_nodal", lam_vb_nodal)
        self._set_param("lam_vb_cross", lam_vb_cross)

    def update_boundary_conditions(
        self,
        x_init: Optional[np.ndarray] = None,
        x_term: Optional[np.ndarray] = None,
    ) -> None:
        """Update boundary condition parameters.

        Sets initial and/or terminal state constraints. Only sets parameters
        that exist in the problem (some problems may not have both).

        Args:
            x_init: Initial state vector, shape (n_states,). Optional.
            x_term: Terminal state vector, shape (n_states,). Optional.
        """
        # No-op before initialize() — the CVXPy problem (and its param_dict)
        # isn't built yet. Callers like Problem._sync_boundary_conditions
        # may invoke this both before and after initialize().
        if self._problem is None:
            return
        if x_init is not None and "x_init" in self._problem.param_dict:
            self._set_param("x_init", x_init)
        if x_term is not None and "x_term" in self._problem.param_dict:
            self._set_param("x_term", x_term)

    def get_stats(self) -> dict:
        """Get solver statistics for diagnostics and printing.

        Returns:
            Dict containing (all zero before initialize()):
                - ``n_variables``: Total number of optimization variables
                - ``n_parameters``: Total number of parameters
                - ``n_constraints``: Total number of constraints
        """
        if self._problem is None:
            return {"n_variables": 0, "n_parameters": 0, "n_constraints": 0}

        return {
            "n_variables": sum(var.size for var in self._problem.variables()),
            "n_parameters": sum(param.size for param in self._problem.parameters()),
            "n_constraints": sum(constraint.size for constraint in self._problem.constraints),
        }

    def _set_param(self, name: str, value: np.ndarray) -> None:
        """Set a CVXPy parameter with helpful error messages on failure.

        A value whose shape differs from the parameter's is reshaped when the
        element counts agree (e.g. ``(n, 1)`` vs ``(n,)``).

        Args:
            name: The parameter name in problem.param_dict
            value: The value to assign (anything ``np.asarray`` accepts)

        Raises:
            ValueError: If the element count does not match the parameter
                shape, or if the value is not real (with the offending
                non-finite entries listed).
        """
        try:
            param = self._problem.param_dict[name]
            value_arr = np.asarray(value)

            # Make the value shape match the parameter shape exactly; NumPy and
            # CVXPy are strict about shapes (notably on Python 3.11+).
            expected_shape = getattr(param, "shape", None)
            if expected_shape is not None and value_arr.shape != expected_shape:
                # Squeezing never changes the element count, so a size mismatch
                # cannot be repaired.
                if value_arr.size != np.prod(expected_shape):
                    raise ValueError(
                        f"Parameter '{name}' shape mismatch: expected {expected_shape}, "
                        f"got {np.shape(value)} "
                        f"(after squeezing: {np.squeeze(value_arr).shape})"
                    )
                value_arr = value_arr.reshape(expected_shape)

            param.value = value_arr
        except ValueError as e:
            if "must be real" not in str(e):
                raise
            arr = np.asarray(value)
            nan_indices = np.argwhere(~np.isfinite(arr))

            index_value_strs = [
                f"  {tuple(int(i) for i in idx)} -> {arr[tuple(idx)]}"
                for idx in nan_indices[:20]
            ]
            if len(nan_indices) > 20:
                index_value_strs.append(f"  ... and {len(nan_indices) - 20} more")

            arr_str = np.array2string(arr, threshold=200, edgeitems=3, max_line_width=120)
            msg = (
                f"Parameter '{name}' with shape {arr.shape} contains "
                f"{len(nan_indices)} non-real value(s):\n"
                + "\n".join(index_value_strs)
                + f"\n\n{name} = {arr_str}"
            )
            raise ValueError(msg) from e

    def solve(self) -> PTRSolveResult:
        """Solve the convex subproblem and return structured results.

        Call ``update_dynamics_linearization()``, ``update_constraint_linearizations()``,
        and ``update_penalties()`` before calling this method.

        Returns:
            PTRSolveResult containing unscaled trajectories, slack variables,
            cost, and solver status.

        Raises:
            RuntimeError: If initialize() has not been called, or if the solver
                did not return values for the state/control variables.
        """
        if self._problem is None:
            raise RuntimeError(
                "CVXPyPTRSolver.solve() called before initialize(). "
                "Call initialize() first to build the problem structure."
            )

        self._solve_fn()

        x_scaled = self._problem.var_dict["x"].value  # (N, n_states)
        u_scaled = self._problem.var_dict["u"].value  # (N, n_controls)
        if x_scaled is None or u_scaled is None:
            # Variable values stay None when the solver returns no primal
            # solution; fail with the status instead of an AttributeError below.
            raise RuntimeError(
                "CVXPy subproblem returned no primal solution "
                f"(solver status: {self._problem.status!r})."
            )

        # Unscale state and control trajectories: z = S @ z_scaled + c per node.
        S_x = self._ocp_vars.S_x
        c_x = self._ocp_vars.c_x
        S_u = self._ocp_vars.S_u
        c_u = self._ocp_vars.c_u
        x = (S_x @ x_scaled.T + np.expand_dims(c_x, axis=1)).T
        u = (S_u @ u_scaled.T + np.expand_dims(c_u, axis=1)).T

        return PTRSolveResult(
            x=x,
            u=u,
            nu=self._problem.var_dict["nu"].value,  # virtual control slack
            nu_vb=[var.value for var in self._ocp_vars.nu_vb],  # nodal slacks
            nu_vb_cross=[var.value for var in self._ocp_vars.nu_vb_cross],  # cross-node slacks
            cost=self._problem.value,
            status=self._problem.status,
        )

    def citation(self) -> List[str]:
        """Return BibTeX citations for CVXPy.

        Returns:
            List containing BibTeX entries for the CVXPY paper and the CVXPY
            rewriting-system paper (Agrawal et al., 2018).
        """
        return [
            r"""@article{diamond2016cvxpy,
  title={CVXPY: A Python-embedded modeling language for convex optimization},
  author={Diamond, Steven and Boyd, Stephen},
  journal={Journal of Machine Learning Research},
  volume={17},
  number={83},
  pages={1--5},
  year={2016}
}""",
            r"""@article{agrawal2018rewriting,
  title={A rewriting system for convex optimization problems},
  author={Agrawal, Akshay and Verschueren, Robin and Diamond, Steven and Boyd, Stephen},
  journal={Journal of Control and Decision},
  volume={5},
  number={1},
  pages={42--60},
  year={2018},
  publisher={Taylor \& Francis}
}""",
        ]