openscvx 2.dev5__py3-none-any.whl → 2.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openscvx/solvers/base.py CHANGED
@@ -5,40 +5,29 @@ must follow for use within successive convexification algorithms.
5
5
 
6
6
  !!! note
7
7
 
8
- Solvers own their optimization variables via ``create_variables()``.
9
- Convex constraint lowering remains in ``lower.py`` but uses the solver's
10
- variables.
11
-
12
- When adding non-CVXPy backends, there are two approaches:
13
-
14
- 1. **Solver owns the lowerer**: The solver implements a
15
- ``lower_convex_constraints()`` method containing the lowering logic.
16
-
17
- 2. **Solver determines the lowerer**: The solver references which lowerer
18
- to use, but the lowering logic stays in ``lower.py``. Example:
19
-
20
- ```python
21
- # In solver
22
- @property
23
- def lowerer(self):
24
- from openscvx.symbolic.lower import lower_cvxpy_constraints
25
- return lower_cvxpy_constraints
26
-
27
- # In lower_symbolic_problem()
28
- lowered_constraints = solver.lowerer(constraints, solver.variables, parameters)
29
- ```
8
+ Solvers own both their optimization variables (``create_variables()``) and
9
+ the lowering of any user ``.convex()`` constraints
10
+ (``lower_convex_constraints()``). The default ``lower_convex_constraints``
11
+ refuses user ``.convex()`` constraints with a clear error — backends that
12
+ accept them override it. This keeps ``openscvx.symbolic.lower``
13
+ backend-agnostic: it never branches on solver type, it just delegates.
14
+
15
+ See :class:`openscvx.solvers.ptr_solver.PTRSolver` for the PTR-specific
16
+ interface every PTR backend implements.
30
17
  """
31
18
 
19
+ import warnings
32
20
  from abc import ABC, abstractmethod
33
- from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
21
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
34
22
 
35
- from pydantic import BaseModel, ConfigDict
23
+ from pydantic import BaseModel, ConfigDict, model_validator
36
24
 
37
25
  if TYPE_CHECKING:
38
26
  from openscvx.config import Config
39
27
  from openscvx.lowered import LoweredProblem
40
28
  from openscvx.lowered.jax_constraints import LoweredJaxConstraints
41
29
  from openscvx.lowered.unified import UnifiedControl, UnifiedState
30
+ from openscvx.symbolic.constraint_set import ConstraintSet
42
31
 
43
32
 
44
33
  class ConvexSolver(ABC):
@@ -91,10 +80,6 @@ class ConvexSolver(ABC):
91
80
  return MyResult(...)
92
81
  """
93
82
 
94
- #: Backend solver name (e.g., ``"QOCO"``, ``"CLARABEL"``). Subclasses
95
- #: must set this in ``__init__``.
96
- cvx_solver: str
97
-
98
83
  @abstractmethod
99
84
  def create_variables(
100
85
  self,
@@ -128,6 +113,50 @@ class ConvexSolver(ABC):
128
113
  """
129
114
  raise NotImplementedError
130
115
 
116
+ def lower_convex_constraints(
117
+ self,
118
+ constraints: "ConstraintSet",
119
+ parameters: Optional[Dict[str, Any]] = None,
120
+ ) -> Tuple[List[Any], Dict[str, Any]]:
121
+ """Lower user ``.convex()`` constraints into this backend's form.
122
+
123
+ Called once by :func:`openscvx.symbolic.lower.lower_symbolic_problem`
124
+ after ``create_variables()`` and before ``initialize()``.
125
+
126
+ The default implementation refuses any user ``.convex()``
127
+ constraints — appropriate for backends like
128
+ :class:`openscvx.solvers.qpax_ptr_solver.QPAXPTRSolver` that don't
129
+ accept second-order-cone constraints. Backends that do accept them
130
+ (e.g. :class:`openscvx.solvers.cvxpy_ptr_solver.CVXPyPTRSolver`)
131
+ override this to invoke their backend-specific lowerer.
132
+
133
+ Args:
134
+ constraints: Categorized symbolic constraints. Only the
135
+ ``nodal_convex`` / ``cross_node_convex`` lists matter here;
136
+ non-convex constraints go through the JAX lowering pipeline.
137
+ parameters: Optional dict of symbolic ``Parameter`` objects
138
+ referenced by the constraints. May be ``None``.
139
+
140
+ Returns:
141
+ ``(lowered_list, parameter_map)``. The first is a list of
142
+ backend-specific constraint objects (e.g. ``cp.Constraint``);
143
+ the second maps parameter names to backend-specific parameter
144
+ objects. Both are empty for the default refusal path.
145
+
146
+ Raises:
147
+ NotImplementedError: if the user defined any ``.convex()``
148
+ constraints and this backend doesn't override.
149
+ """
150
+ n = len(constraints.nodal_convex) + len(constraints.cross_node_convex)
151
+ if n:
152
+ raise NotImplementedError(
153
+ f"{type(self).__name__} does not support user-defined "
154
+ f".convex() constraints ({n} defined). Drop the .convex() "
155
+ "constraint or switch to a backend that supports them "
156
+ "(e.g. openscvx.CVXPyPTRSolver)."
157
+ )
158
+ return [], {}
159
+
131
160
  @abstractmethod
132
161
  def initialize(
133
162
  self,
@@ -241,11 +270,17 @@ class ConvexSolver(ABC):
241
270
  # Pydantic spec for dict / YAML validation
242
271
  # =============================================================================
243
272
 
244
- _SOLVER_MAP: Dict[str, type] = {} # populated by __init__.py after all classes are imported
245
273
 
274
+ class PTRSolverSpec(BaseModel):
275
+ """Validates PTR solver configuration from dict/YAML input.
246
276
 
247
- class SolverSpec(BaseModel):
248
- """Validates solver configuration from dict/YAML input.
277
+ The ``backend`` discriminator selects which concrete PTR backend to build:
278
+ ``"cvxpy"`` (the default,
279
+ :class:`openscvx.solvers.cvxpy_ptr_solver.CVXPyPTRSolver`) or ``"qpax"``
280
+ (:class:`openscvx.solvers.qpax_ptr_solver.QPAXPTRSolver`).
281
+
282
+ ``cvx_solver``, ``cvxpygen``, and ``cvxpygen_override`` are CVXPy-only;
283
+ setting them under ``backend="qpax"`` is a configuration error.
249
284
 
250
285
  !!! warning
251
286
  Enabling ``cvxpygen`` currently disables sparse parameter declarations.
@@ -255,15 +290,61 @@ class SolverSpec(BaseModel):
255
290
  """
256
291
 
257
292
  type: Literal["PTRSolver"] = "PTRSolver"
258
- cvx_solver: str = "QOCO"
293
+ backend: Literal["cvxpy", "qpax"] = "cvxpy"
294
+ cvx_solver: Optional[str] = None
259
295
  solver_args: Optional[Dict[str, Any]] = None
260
296
  cvxpygen: bool = False
261
297
  cvxpygen_override: bool = False
262
298
 
263
299
  model_config = ConfigDict(extra="forbid")
264
300
 
301
+ @model_validator(mode="after")
302
+ def _check_backend_fields(self):
303
+ if self.backend == "qpax":
304
+ offenders = [
305
+ name
306
+ for name, value in (
307
+ ("cvx_solver", self.cvx_solver),
308
+ ("cvxpygen", self.cvxpygen),
309
+ ("cvxpygen_override", self.cvxpygen_override),
310
+ )
311
+ if value
312
+ ]
313
+ if offenders:
314
+ raise ValueError(
315
+ f"{offenders} only valid for backend='cvxpy'; "
316
+ "remove these fields or set backend='cvxpy'."
317
+ )
318
+ return self
319
+
265
320
  def build(self) -> ConvexSolver:
266
- cls = _SOLVER_MAP.get(self.type)
267
- if cls is None:
268
- raise ValueError(f"Unknown solver {self.type!r}; expected one of {sorted(_SOLVER_MAP)}")
269
- return cls(**self.model_dump(exclude={"type"}, exclude_unset=True))
321
+ # Local imports keep CVXPy / qpax out of the import path until the
322
+ # corresponding backend is actually requested.
323
+ if self.backend == "cvxpy":
324
+ from .cvxpy_ptr_solver import CVXPyPTRSolver
325
+
326
+ return CVXPyPTRSolver(
327
+ cvx_solver=self.cvx_solver or "QOCO",
328
+ solver_args=self.solver_args,
329
+ cvxpygen=self.cvxpygen,
330
+ cvxpygen_override=self.cvxpygen_override,
331
+ )
332
+ from .qpax_ptr_solver import QPAXPTRSolver
333
+
334
+ return QPAXPTRSolver(solver_args=self.solver_args)
335
+
336
+
337
+ def __getattr__(name: str):
338
+ """Deprecated alias: ``SolverSpec`` → :class:`PTRSolverSpec`.
339
+
340
+ Kept for one release so existing dict/YAML configs and tests that import
341
+ ``SolverSpec`` continue to work. Emit a ``DeprecationWarning`` on access.
342
+ """
343
+ if name == "SolverSpec":
344
+ warnings.warn(
345
+ "openscvx.solvers.base.SolverSpec is deprecated; use PTRSolverSpec.",
346
+ DeprecationWarning,
347
+ stacklevel=2,
348
+ )
349
+ return PTRSolverSpec
350
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")