openscvx 2.dev5__py3-none-any.whl → 2.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openscvx/__init__.py CHANGED
@@ -12,6 +12,7 @@ import openscvx.symbolic.expr.spatial as spatial
12
12
  import openscvx.symbolic.expr.stl as stl
13
13
  import openscvx.symbolic.expr.stljax as stljax
14
14
  from openscvx.algorithms import (
15
+ AdaptiveProximalWeight,
15
16
  AugmentedLagrangian,
16
17
  ConstantProximalWeight,
17
18
  PenalizedTrustRegion,
@@ -28,7 +29,10 @@ from openscvx.expert import ByofSpec
28
29
  from openscvx.integrations import DynamicsAdapter, MjxDynamics
29
30
  from openscvx.loader import load_dict, load_json, load_yaml
30
31
  from openscvx.problem import Problem
31
- from openscvx.solvers import PTRSolver
32
+ from openscvx.solvers import CVXPyPTRSolver, PTRSolver
33
+
34
+ # QPAXPTRSolver is exposed lazily via __getattr__ below to keep `import qpax`
35
+ # off the hot import path for users who don't install the optional extra.
32
36
  from openscvx.symbolic.expr import (
33
37
  CTCS,
34
38
  Abs,
@@ -90,6 +94,16 @@ from openscvx.utils.cache import clear_cache, get_cache_dir, get_cache_size
90
94
 
91
95
  load_results = OptimizationResults.load
92
96
 
97
+
98
def __getattr__(name: str):
    """Resolve lazily exported names on first module attribute access (PEP 562).

    ``QPAXPTRSolver`` is imported from ``openscvx.solvers.qpax_ptr_solver``
    only when it is actually requested, which keeps the optional qpax
    backend off the default ``import openscvx`` path. Every other unknown
    name raises ``AttributeError`` just as ordinary attribute lookup would.
    """
    if name != "QPAXPTRSolver":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from openscvx.solvers.qpax_ptr_solver import QPAXPTRSolver as lazy_solver_cls

    return lazy_solver_cls
105
+
106
+
93
107
  __all__ = [
94
108
  # Main Trajectory Optimization Entrypoint
95
109
  "Problem",
@@ -187,9 +201,12 @@ __all__ = [
187
201
  "VectorizeDiscretizeLinearize",
188
202
  # Convex Solver
189
203
  "PTRSolver",
204
+ "CVXPyPTRSolver",
205
+ "QPAXPTRSolver",
190
206
  # Algorithm & Autotuning
191
207
  "PenalizedTrustRegion",
192
208
  "AugmentedLagrangian",
209
+ "AdaptiveProximalWeight",
193
210
  "ConstantProximalWeight",
194
211
  "RampProximalWeight",
195
212
  ]
openscvx/_version.py CHANGED
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '2.dev5'
22
- __version_tuple__ = version_tuple = (2, 'dev5')
21
+ __version__ = version = '2.dev7'
22
+ __version_tuple__ = version_tuple = (2, 'dev7')
23
23
 
24
24
  __commit_id__ = commit_id = None
@@ -82,12 +82,19 @@ from typing import Annotated, Any, Dict, List, Optional, Union
82
82
 
83
83
  from pydantic import BaseModel, ConfigDict, Field, field_validator
84
84
 
85
- from .augmented_lagrangian import AugmentedLagrangian, AugmentedLagrangianSpec
85
+ from .autotuner import (
86
+ AdaptiveProximalWeight,
87
+ AdaptiveProximalWeightSpec,
88
+ AugmentedLagrangian,
89
+ AugmentedLagrangianSpec,
90
+ ConstantProximalWeight,
91
+ ConstantProximalWeightSpec,
92
+ RampProximalWeight,
93
+ RampProximalWeightSpec,
94
+ )
86
95
  from .base import Algorithm, AlgorithmState, AutotuningBase, DiscretizationResult
87
- from .constant_proximal_weight import ConstantProximalWeight, ConstantProximalWeightSpec
88
96
  from .optimization_results import OptimizationResults
89
- from .penalized_trust_region import PenalizedTrustRegion
90
- from .ramp_proximal_weight import RampProximalWeight, RampProximalWeightSpec
97
+ from .scvx import PenalizedTrustRegion
91
98
  from .weights import Weights
92
99
 
93
100
  # ---------------------------------------------------------------------------
@@ -97,6 +104,7 @@ from .weights import Weights
97
104
  AutotunerConfig = Annotated[
98
105
  Union[
99
106
  AugmentedLagrangianSpec,
107
+ AdaptiveProximalWeightSpec,
100
108
  RampProximalWeightSpec,
101
109
  ConstantProximalWeightSpec,
102
110
  ],
@@ -173,6 +181,7 @@ __all__ = [
173
181
  "PenalizedTrustRegion",
174
182
  "AutotuningBase",
175
183
  "AugmentedLagrangian",
184
+ "AdaptiveProximalWeight",
176
185
  "ConstantProximalWeight",
177
186
  "RampProximalWeight",
178
187
  # Config models
@@ -0,0 +1,17 @@
1
+ """SCP weight autotuning strategies."""
2
+
3
+ from .adaptive_proximal_weight import AdaptiveProximalWeight, AdaptiveProximalWeightSpec
4
+ from .augmented_lagrangian import AugmentedLagrangian, AugmentedLagrangianSpec
5
+ from .constant_proximal_weight import ConstantProximalWeight, ConstantProximalWeightSpec
6
+ from .ramp_proximal_weight import RampProximalWeight, RampProximalWeightSpec
7
+
8
# Public surface of the autotuner subpackage: each weight-update strategy
# class listed next to its ``*Spec`` configuration model.
__all__ = [
    "AdaptiveProximalWeight",
    "AdaptiveProximalWeightSpec",
    "AugmentedLagrangian",
    "AugmentedLagrangianSpec",
    "ConstantProximalWeight",
    "ConstantProximalWeightSpec",
    "RampProximalWeight",
    "RampProximalWeightSpec",
]
@@ -0,0 +1,190 @@
1
+ """Autotuning functions for SCP (Successive Convex Programming) parameters."""
2
+
3
+ from copy import deepcopy
4
+ from typing import TYPE_CHECKING, Literal
5
+
6
+ import numpy as np
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from openscvx.config import Config
10
+
11
+ from ..base import AutotuningBase
12
+ from .augmented_lagrangian import AugmentedLagrangian
13
+
14
+ if TYPE_CHECKING:
15
+ from openscvx.lowered import LoweredJaxConstraints
16
+
17
+ from ..base import AlgorithmState, CandidateIterate
18
+ from ..weights import Weights
19
+
20
+
21
class AdaptiveProximalWeight(AutotuningBase):
    """PTR-style proximal adaptation with fixed virtual penalty weights.

    Same acceptance-ratio logic as :class:`AugmentedLagrangian` for ``lam_prox``,
    but ``lam_vc`` and ``lam_vb_*`` are held constant at their current state values.

    On every iteration after the first, the acceptance ratio

        rho = (J_prev - candidate.J_nonlin) / (J_prev - candidate.J_lin)

    compares the actual reduction of the nonlinear objective with the reduction
    predicted by the candidate's ``J_lin``. ``lam_prox`` is then updated as:

    * ``rho < eta_0`` or ``rho`` is NaN: reject, ``lam_prox *= gamma_1``
      (capped at ``lam_prox_max``).
    * ``eta_0 <= rho < eta_1``: accept, ``lam_prox *= gamma_1`` (capped).
    * ``eta_1 <= rho < eta_2``: accept, ``lam_prox`` unchanged.
    * ``rho >= eta_2``: accept, ``lam_prox *= gamma_2`` (floored at
      ``lam_prox_min``).

    Args:
        gamma_1: Growth factor for ``lam_prox``. Defaults to 2.0.
        gamma_2: Shrink factor for ``lam_prox``. Defaults to 0.5.
        eta_0: Acceptance-ratio threshold below which the candidate is
            rejected. Defaults to 1e-2.
        eta_1: Threshold below which an accepted candidate still grows
            ``lam_prox``. Defaults to 1e-1.
        eta_2: Threshold at or above which ``lam_prox`` is shrunk.
            Defaults to 0.8.
        lam_prox_min: Floor applied when shrinking ``lam_prox``. Defaults to 1e-3.
        lam_prox_max: Cap applied when growing ``lam_prox``. Defaults to 1e4.
        lam_cost_drop: While ``state.k > lam_cost_drop`` the cost weight is
            ``state.lam_cost * lam_cost_relax``; otherwise ``weights.lam_cost``
            is used. Defaults to -1.
        lam_cost_relax: Multiplier applied to ``state.lam_cost`` (see
            ``lam_cost_drop``). Defaults to 1.0.
    """

    COLUMNS = AugmentedLagrangian.COLUMNS

    def __init__(
        self,
        gamma_1: float = 2.0,
        gamma_2: float = 0.5,
        eta_0: float = 1e-2,
        eta_1: float = 1e-1,
        eta_2: float = 0.8,
        lam_prox_min: float = 1e-3,
        lam_prox_max: float = 1e4,
        lam_cost_drop: int = -1,
        lam_cost_relax: float = 1.0,
    ):
        self.gamma_1 = gamma_1
        self.gamma_2 = gamma_2
        self.eta_0 = eta_0
        self.eta_1 = eta_1
        self.eta_2 = eta_2
        self.lam_prox_min = lam_prox_min
        self.lam_prox_max = lam_prox_max
        self.lam_cost_drop = lam_cost_drop
        self.lam_cost_relax = lam_cost_relax

    @staticmethod
    def _copy_virtual_weights(
        candidate: "CandidateIterate",
        state: "AlgorithmState",
    ) -> None:
        """Overwrite the candidate's virtual-control/buffer weights with the state's."""
        candidate.lam_vc = state.lam_vc
        candidate.lam_vb_nodal = state.lam_vb_nodal
        candidate.lam_vb_cross = state.lam_vb_cross

    def _evaluate_nonlinear_objective(
        self,
        x_prop,
        x,
        u,
        state: "AlgorithmState",
        nodal_constraints: "LoweredJaxConstraints",
        params: dict,
        settings: Config,
    ):
        """Return nonlinear cost + nonlinear penalty + nodal penalty at ``(x_prop, x, u)``.

        The penalties are always weighted with the *state's* weights, so the
        previous and candidate objectives are directly comparable.
        """
        nonlinear_cost, nonlinear_penalty, nodal_penalty = self.calculate_nonlinear_penalty(
            x_prop,
            x,
            u,
            state.lam_vc,
            state.lam_vb_nodal,
            state.lam_vb_cross,
            state.lam_cost,
            nodal_constraints,
            params,
            settings,
        )
        return nonlinear_cost + nonlinear_penalty + nodal_penalty

    def update_weights(
        self,
        state: "AlgorithmState",
        candidate: "CandidateIterate",
        nodal_constraints: "LoweredJaxConstraints",
        settings: Config,
        params: dict,
        weights: "Weights",
    ) -> str:
        """Update SCP proximal weight based on acceptance ratio; keep VC/VB fixed.

        Sets ``candidate.J_nonlin``, ``candidate.lam_cost`` and
        ``candidate.lam_prox``. It then hands the candidate to either
        ``state.accept_solution`` or ``state.reject_solution``.

        Returns:
            One of ``"Initial"``, ``"Reject Higher"``, ``"Accept Higher"``,
            ``"Accept Constant"`` or ``"Accept Lower"``.

        Raises:
            ValueError: If the predicted reduction is exactly zero.
        """
        candidate_x_prop = (
            candidate.x_prop_plus[1:] if candidate.x_prop_plus is not None else candidate.x_prop
        )
        candidate.J_nonlin = self._evaluate_nonlinear_objective(
            candidate_x_prop,
            candidate.x,
            candidate.u,
            state,
            nodal_constraints,
            params,
            settings,
        )

        if state.k > self.lam_cost_drop:
            candidate.lam_cost = state.lam_cost * self.lam_cost_relax
        else:
            candidate.lam_cost = weights.lam_cost

        lam_prox_k = deepcopy(state.lam_prox)

        if state.k <= 1:
            # No previous iterate to compare against: accept unconditionally.
            candidate.lam_prox = lam_prox_k
            self._copy_virtual_weights(candidate, state)
            state.accept_solution(candidate)
            return "Initial"

        state_x_prop_plus = state.x_prop_plus()
        state_x_prop = state_x_prop_plus[1:] if state_x_prop_plus is not None else state.x_prop()
        J_nonlin_prev = self._evaluate_nonlinear_objective(
            state_x_prop,
            state.x,
            state.u,
            state,
            nodal_constraints,
            params,
            settings,
        )

        actual_reduction = J_nonlin_prev - candidate.J_nonlin
        predicted_reduction = J_nonlin_prev - candidate.J_lin

        if predicted_reduction == 0:
            raise ValueError("Predicted reduction is 0.")

        rho = actual_reduction / predicted_reduction

        state.pred_reduction_history.append(predicted_reduction)
        state.actual_reduction_history.append(actual_reduction)
        state.acceptance_ratio_history.append(rho)

        # Virtual weights are held at the state's values on every path,
        # including rejection, as the class contract promises.
        self._copy_virtual_weights(candidate, state)

        # A NaN ratio fails every `<` test below and would otherwise fall
        # through to "Accept Lower"; treat it as a failed step instead.
        if np.isnan(rho) or rho < self.eta_0:
            candidate.lam_prox = np.minimum(self.lam_prox_max, self.gamma_1 * lam_prox_k)
            state.reject_solution(candidate)
            return "Reject Higher"

        if rho < self.eta_1:
            candidate.lam_prox = np.minimum(self.lam_prox_max, self.gamma_1 * lam_prox_k)
            adaptive_state = "Accept Higher"
        elif rho < self.eta_2:
            candidate.lam_prox = lam_prox_k
            adaptive_state = "Accept Constant"
        else:
            candidate.lam_prox = np.maximum(self.lam_prox_min, self.gamma_2 * lam_prox_k)
            adaptive_state = "Accept Lower"

        state.accept_solution(candidate)
        return adaptive_state
166
+
167
+
168
+ # =============================================================================
169
+ # Pydantic spec for dict / YAML validation
170
+ # =============================================================================
171
+
172
+
173
class AdaptiveProximalWeightSpec(BaseModel):
    """Dict/YAML schema for :class:`AdaptiveProximalWeight`.

    The field defaults mirror the constructor's keyword defaults, and unknown
    keys are rejected (``extra="forbid"``). The ``type`` literal identifies
    this spec among the autotuner configurations.
    """

    model_config = ConfigDict(extra="forbid")

    type: Literal["AdaptiveProximalWeight"] = "AdaptiveProximalWeight"
    gamma_1: float = 2.0
    gamma_2: float = 0.5
    eta_0: float = 1e-2
    eta_1: float = 1e-1
    eta_2: float = 0.8
    lam_prox_min: float = 1e-3
    lam_prox_max: float = 1e4
    lam_cost_drop: int = -1
    lam_cost_relax: float = 1.0

    def build(self) -> AdaptiveProximalWeight:
        """Instantiate the autotuner, forwarding only the explicitly set fields."""
        overrides = self.model_dump(exclude={"type"}, exclude_unset=True)
        return AdaptiveProximalWeight(**overrides)
@@ -14,13 +14,13 @@ from openscvx.utils.printing import (
14
14
  color_adaptive_state,
15
15
  )
16
16
 
17
- from .base import AutotuningBase
17
+ from ..base import AutotuningBase
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from openscvx.lowered import LoweredJaxConstraints
21
21
 
22
- from .base import AlgorithmState, CandidateIterate
23
- from .weights import Weights
22
+ from ..base import AlgorithmState, CandidateIterate
23
+ from ..weights import Weights
24
24
 
25
25
 
26
26
  class AugmentedLagrangian(AutotuningBase):
@@ -86,8 +86,12 @@ class AugmentedLagrangian(AutotuningBase):
86
86
  eta_2: Threshold above which solution is accepted with lower weight.
87
87
  Defaults to 0.8.
88
88
  ep: Threshold for virtual control weight update (nu > ep vs nu <= ep).
89
- Defaults to 0.5.
90
- eta_lambda: Step size for virtual control weight update. Defaults to 1e0.
89
+ Must lie in (0, 1). Defaults to 0.99; when tuning, try 1e-1 if needed.
90
+ Typically tuned together with ``eta_lambda`` (often the first
91
+ parameters adjusted).
92
+ eta_lambda: Step size for virtual control weight update. Defaults to 1e1.
93
+ Typically tuned together with ``ep`` (often the first parameters
94
+ adjusted).
91
95
  lam_vc_max: Maximum virtual control penalty weight. Defaults to 1e5.
92
96
  lam_prox_min: Minimum trust region (proximal) weight. Defaults to 1e-3.
93
97
  lam_prox_max: Maximum trust region (proximal) weight. Defaults to 2e5.
@@ -6,13 +6,13 @@ from pydantic import BaseModel, ConfigDict
6
6
 
7
7
  from openscvx.config import Config
8
8
 
9
- from .base import AutotuningBase
9
+ from ..base import AutotuningBase
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from openscvx.lowered import LoweredJaxConstraints
13
13
 
14
- from .base import AlgorithmState, CandidateIterate
15
- from .weights import Weights
14
+ from ..base import AlgorithmState, CandidateIterate
15
+ from ..weights import Weights
16
16
 
17
17
 
18
18
  class ConstantProximalWeight(AutotuningBase):
@@ -7,13 +7,13 @@ from pydantic import BaseModel, ConfigDict
7
7
 
8
8
  from openscvx.config import Config
9
9
 
10
- from .base import AutotuningBase
10
+ from ..base import AutotuningBase
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from openscvx.lowered import LoweredJaxConstraints
14
14
 
15
- from .base import AlgorithmState, CandidateIterate
16
- from .weights import Weights
15
+ from ..base import AlgorithmState, CandidateIterate
16
+ from ..weights import Weights
17
17
 
18
18
 
19
19
  class RampProximalWeight(AutotuningBase):
@@ -0,0 +1,5 @@
1
+ """Successive convexification algorithm implementations."""
2
+
3
+ from .penalized_trust_region import PenalizedTrustRegion
4
+
5
# Public surface of the successive-convexification algorithm package.
__all__ = [
    "PenalizedTrustRegion",
]
@@ -22,11 +22,11 @@ from openscvx.utils.printing import (
22
22
  color_prob_stat,
23
23
  )
24
24
 
25
- from .augmented_lagrangian import AugmentedLagrangian
26
- from .base import Algorithm, AlgorithmState, CandidateIterate
27
- from .constant_proximal_weight import ConstantProximalWeight
28
- from .ramp_proximal_weight import RampProximalWeight
29
- from .weights import Weights
25
+ from ..autotuner.augmented_lagrangian import AugmentedLagrangian
26
+ from ..autotuner.constant_proximal_weight import ConstantProximalWeight
27
+ from ..autotuner.ramp_proximal_weight import RampProximalWeight
28
+ from ..base import Algorithm, AlgorithmState, CandidateIterate
29
+ from ..weights import Weights
30
30
 
31
31
  if TYPE_CHECKING:
32
32
  from openscvx.lowered import LoweredJaxConstraints
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
34
34
  from openscvx.symbolic.expr.control import Control
35
35
  from openscvx.symbolic.expr.state import State
36
36
 
37
- from .base import AutotuningBase
37
+ from ..base import AutotuningBase
38
38
 
39
39
  warnings.filterwarnings("ignore")
40
40
 
openscvx/loader.py CHANGED
@@ -40,7 +40,7 @@ from pydantic import BaseModel, ConfigDict
40
40
  from openscvx.algorithms import PenalizedTrustRegionConfig
41
41
  from openscvx.config import SettingsSpec
42
42
  from openscvx.discretization import DiscretizerSpec
43
- from openscvx.solvers import SolverSpec
43
+ from openscvx.solvers import PTRSolverSpec
44
44
  from openscvx.symbolic.expr.control import ControlSpec
45
45
  from openscvx.symbolic.expr.expr import Expr
46
46
  from openscvx.symbolic.expr.parameter import ParameterSpec
@@ -66,7 +66,7 @@ class ProblemSpec(BaseModel):
66
66
  constraints: List[str] = []
67
67
  algorithm: Optional[PenalizedTrustRegionConfig] = None
68
68
  discretizer: Optional[DiscretizerSpec] = None
69
- solver: Optional[SolverSpec] = None
69
+ solver: Optional[PTRSolverSpec] = None
70
70
  settings: Optional[SettingsSpec] = None
71
71
  states_prop: Optional[List[StateSpec]] = None
72
72
  dynamics_prop: Optional[Dict[str, Any]] = None
@@ -17,7 +17,10 @@ class LoweredCvxpyConstraints:
17
17
 
18
18
  Attributes:
19
19
  constraints: List of CVXPy constraint objects (cp.Constraint).
20
- Includes both nodal and cross-node convex constraints.
20
+ Includes both nodal and cross-node convex constraints. Empty
21
+ for backends that don't accept ``.convex()`` constraints — the
22
+ refusal happens earlier, in
23
+ :meth:`openscvx.solvers.base.ConvexSolver.lower_convex_constraints`.
21
24
  """
22
25
 
23
26
  constraints: list["cp.Constraint"] = field(default_factory=list)
openscvx/problem.py CHANGED
@@ -201,9 +201,11 @@ class Problem:
201
201
  discretizer=ox.LinearizeDiscretize(dis_type="ZOH", ode_solver="Dopri8")
202
202
  solver: Convex subproblem solver configuration. Accepts:
203
203
 
204
- - ``None`` — uses ``PTRSolver()`` with defaults (QOCO backend).
204
+ - ``None`` — uses ``CVXPyPTRSolver()`` with defaults (QOCO backend).
205
205
  - A ``ConvexSolver`` instance — used directly.
206
- - A ``dict`` — passed as kwargs to ``PTRSolver()``.
206
+ - A ``dict`` — validated as ``PTRSolverSpec``; the ``backend``
207
+ field (``"cvxpy"`` or ``"qpax"``) selects the concrete
208
+ backend.
207
209
 
208
210
  Examples::
209
211
 
@@ -216,8 +218,12 @@ class Problem:
216
218
  # Enable cvxpygen code generation
217
219
  solver={"cvxpygen": True}
218
220
 
221
+ # JAX-native QPAX backend (no cvx_solver / cvxpygen fields)
222
+ solver={"backend": "qpax"}
223
+
219
224
  # Instance
220
- solver=ox.PTRSolver(cvx_solver="CLARABEL")
225
+ solver=ox.CVXPyPTRSolver(cvx_solver="CLARABEL")
226
+ solver=ox.QPAXPTRSolver()
221
227
  byof (ByofSpec, optional): Expert mode only. Raw JAX functions to
222
228
  bypass symbolic layer. See :class:`openscvx.expert.ByofSpec` for
223
229
  detailed documentation.
@@ -393,9 +399,9 @@ class Problem:
393
399
  def solver(self) -> ConvexSolver:
394
400
  """Access the convex subproblem solver instance.
395
401
 
396
- Attributes such as ``cvx_solver``, ``solver_args``, ``cvxpygen``, and
397
- ``cvxpygen_override`` can be modified freely before ``initialize``
398
- is called::
402
+ Backend-specific attributes (e.g. ``cvx_solver``, ``solver_args``,
403
+ ``cvxpygen``, ``cvxpygen_override`` on :class:`CVXPyPTRSolver`) can
404
+ be modified freely before ``initialize`` is called::
399
405
 
400
406
  problem.solver.solver_args = {"abstol": 1e-6, "reltol": 1e-9}
401
407
  problem.solver.cvxpygen = True
@@ -407,7 +413,7 @@ class Problem:
407
413
  will have no effect on subsequent solves.
408
414
 
409
415
  Returns:
410
- The solver instance (e.g., PTRSolver).
416
+ The solver instance — a concrete :class:`PTRSolver` subclass.
411
417
  """
412
418
  return self._solver
413
419
 
@@ -510,12 +516,12 @@ class Problem:
510
516
  self._lowered.x_prop_unified.final[state._slice] = state.final
511
517
  self._lowered.x_prop_unified.final_type[state._slice] = state.final_type
512
518
 
513
- # Update CVXPy solver parameters (only if solver is initialized)
514
- if self._solver._problem is not None:
515
- self._solver.update_boundary_conditions(
516
- x_init=self._lowered.x_unified.initial,
517
- x_term=self._lowered.x_unified.final,
518
- )
519
+ # Push to the solver — both backends short-circuit on a pre-initialize
520
+ # call, so this is safe to invoke from any lifecycle point.
521
+ self._solver.update_boundary_conditions(
522
+ x_init=self._lowered.x_unified.initial,
523
+ x_term=self._lowered.x_unified.final,
524
+ )
519
525
 
520
526
  def _sync_guesses(self):
521
527
  """Sync trajectory guesses from State/Control objects to lowered representation.
@@ -26,13 +26,16 @@ class ConvexSolver(ABC):
26
26
  ...
27
27
  ```
28
28
 
29
- This architecture enables users to implement custom solver backends such as:
29
+ The Penalized Trust-Region (PTR) subproblem ships with two concrete backends:
30
30
 
31
- - Direct Clarabel solver (Rust-based, GPU-capable)
32
- - QPAX (JAX-based QP solver for end-to-end differentiability)
33
- - OSQP direct interface (specialized for QP structure)
34
- - Custom embedded solvers for real-time applications
35
- - Research solvers with specialized structure exploitation
31
+ - :class:`CVXPyPTRSolver` — DCP graph via CVXPy, dispatched to any of its
32
+ supported conic solvers (QOCO, CLARABEL, ...). Optional code generation
33
+ via cvxpygen for improved per-iteration performance.
34
+ - :class:`QPAXPTRSolver` — flat ``(Q, q, A, b, G, h)`` assembled as JAX
35
+ arrays and solved with ``qpax.solve_qp``. Aimed at end-to-end JAX
36
+ differentiability of the SCP loop (follow-up work).
37
+
38
+ Both share the abstract :class:`PTRSolver` contract.
36
39
 
37
40
  Note:
38
41
  Solvers own their optimization variables (e.g., ``CVXPySolver.ocp_vars``).
@@ -41,36 +44,50 @@ Note:
41
44
  for the interface details.
42
45
  """
43
46
 
47
+ import warnings
44
48
  from typing import Any
45
49
 
46
- from .base import _SOLVER_MAP, ConvexSolver, SolverSpec
50
+ from .base import ConvexSolver, PTRSolverSpec
51
+ from .cvxpy_ptr_solver import CVXPyPTRSolver
47
52
  from .ptr_solver import PTRSolver, PTRSolveResult
48
53
 
49
- # ---------------------------------------------------------------------------
50
- # Populate the solver class map now that all classes are imported
51
- # ---------------------------------------------------------------------------
52
54
 
53
- _SOLVER_MAP.update(
54
- {
55
- "PTRSolver": PTRSolver,
56
- }
57
- )
55
def resolve_solver_config(val: Any) -> PTRSolverSpec:
    """Coerce *val* into a :class:`PTRSolverSpec`.

    An existing spec instance is returned unchanged. Anything else (typically
    a ``dict``) is validated with ``PTRSolverSpec.model_validate``.
    """
    if not isinstance(val, PTRSolverSpec):
        val = PTRSolverSpec.model_validate(val)
    return val
60
+
58
61
 
62
def __getattr__(name: str):
    """Module-level attribute hook (PEP 562) for deprecated and lazy exports.

    * ``SolverSpec`` is a deprecated alias of :class:`PTRSolverSpec`.
      Accessing it emits a ``DeprecationWarning`` attributed to the caller.
    * ``QPAXPTRSolver`` is imported from ``.qpax_ptr_solver`` only on first
      request, so a plain ``import openscvx.solvers`` never loads it.

    Raises:
        AttributeError: For any other name.
    """
    if name == "QPAXPTRSolver":
        # NOTE(review): whether a missing qpax install surfaces here or only
        # when the solver is instantiated depends on how qpax_ptr_solver
        # imports qpax; confirm against that module.
        from .qpax_ptr_solver import QPAXPTRSolver as qpax_solver_cls

        return qpax_solver_cls

    if name != "SolverSpec":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    warnings.warn(
        "openscvx.solvers.SolverSpec is deprecated; use PTRSolverSpec.",
        DeprecationWarning,
        stacklevel=2,
    )
    return PTRSolverSpec
65
80
 
66
81
 
67
82
# NOTE(review): ``QPAXPTRSolver`` is resolved lazily by this module's
# ``__getattr__``. Because it is listed here, ``from openscvx.solvers import *``
# performs that lazy import as well. Confirm that ``qpax_ptr_solver`` imports
# cleanly without the optional qpax extra; otherwise star-imports may fail
# for those users.
__all__ = [
    # Base classes
    "ConvexSolver",
    "PTRSolver",
    "PTRSolveResult",
    # PTR backends
    "CVXPyPTRSolver",
    "QPAXPTRSolver",
    # Config
    "PTRSolverSpec",
    "resolve_solver_config",
]