openscvx 0.5.2.dev2__py3-none-any.whl → 0.5.2.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openscvx/__init__.py +27 -1
- openscvx/_version.py +2 -2
- openscvx/algorithms/__init__.py +13 -4
- openscvx/algorithms/autotuner/__init__.py +17 -0
- openscvx/algorithms/autotuner/adaptive_proximal_weight.py +190 -0
- openscvx/algorithms/{augmented_lagrangian.py → autotuner/augmented_lagrangian.py} +9 -5
- openscvx/algorithms/{constant_proximal_weight.py → autotuner/constant_proximal_weight.py} +3 -3
- openscvx/algorithms/{ramp_proximal_weight.py → autotuner/ramp_proximal_weight.py} +3 -3
- openscvx/algorithms/scvx/__init__.py +5 -0
- openscvx/algorithms/{penalized_trust_region.py → scvx/penalized_trust_region.py} +38 -9
- openscvx/discretization/discretize_linearize.py +6 -1
- openscvx/discretization/sparse_utils/sparse_jacobian.py +6 -1
- openscvx/integrations/__init__.py +34 -30
- openscvx/integrations/base.py +89 -0
- openscvx/integrations/mjx.py +247 -74
- openscvx/loader.py +2 -2
- openscvx/lowered/cvxpy_constraints.py +4 -1
- openscvx/plotting/viser/server.py +2 -2
- openscvx/problem.py +36 -17
- openscvx/solvers/__init__.py +51 -24
- openscvx/solvers/base.py +124 -36
- openscvx/solvers/cvxpy_ptr_solver.py +908 -0
- openscvx/solvers/moreau_ptr_solver.py +1125 -0
- openscvx/solvers/ptr_solver.py +99 -846
- openscvx/solvers/qpax_ptr_solver.py +776 -0
- openscvx/symbolic/hashing.py +16 -3
- openscvx/symbolic/lower.py +6 -7
- openscvx/utils/printing.py +9 -1
- {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev11.dist-info}/METADATA +14 -16
- {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev11.dist-info}/RECORD +34 -27
- {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev11.dist-info}/WHEEL +0 -0
- {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev11.dist-info}/entry_points.txt +0 -0
- {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev11.dist-info}/licenses/LICENSE +0 -0
- {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev11.dist-info}/top_level.txt +0 -0
openscvx/__init__.py
CHANGED
|
@@ -12,6 +12,7 @@ import openscvx.symbolic.expr.spatial as spatial
|
|
|
12
12
|
import openscvx.symbolic.expr.stl as stl
|
|
13
13
|
import openscvx.symbolic.expr.stljax as stljax
|
|
14
14
|
from openscvx.algorithms import (
|
|
15
|
+
AdaptiveProximalWeight,
|
|
15
16
|
AugmentedLagrangian,
|
|
16
17
|
ConstantProximalWeight,
|
|
17
18
|
PenalizedTrustRegion,
|
|
@@ -25,9 +26,13 @@ from openscvx.discretization import (
|
|
|
25
26
|
VectorizeDiscretizeLinearize,
|
|
26
27
|
)
|
|
27
28
|
from openscvx.expert import ByofSpec
|
|
29
|
+
from openscvx.integrations import DynamicsAdapter, MjxDynamics
|
|
28
30
|
from openscvx.loader import load_dict, load_json, load_yaml
|
|
29
31
|
from openscvx.problem import Problem
|
|
30
|
-
from openscvx.solvers import PTRSolver
|
|
32
|
+
from openscvx.solvers import CVXPyPTRSolver, PTRSolver
|
|
33
|
+
|
|
34
|
+
# QPAXPTRSolver is exposed lazily via __getattr__ below to keep `import qpax`
|
|
35
|
+
# off the hot import path for users who don't install the optional extra.
|
|
31
36
|
from openscvx.symbolic.expr import (
|
|
32
37
|
CTCS,
|
|
33
38
|
Abs,
|
|
@@ -89,6 +94,20 @@ from openscvx.utils.cache import clear_cache, get_cache_dir, get_cache_size
|
|
|
89
94
|
|
|
90
95
|
load_results = OptimizationResults.load
|
|
91
96
|
|
|
97
|
+
|
|
98
|
+
def __getattr__(name: str):
|
|
99
|
+
"""Lazy export for backends that depend on optional packages."""
|
|
100
|
+
if name == "QPAXPTRSolver":
|
|
101
|
+
from openscvx.solvers.qpax_ptr_solver import QPAXPTRSolver
|
|
102
|
+
|
|
103
|
+
return QPAXPTRSolver
|
|
104
|
+
if name == "MoreauPTRSolver":
|
|
105
|
+
from openscvx.solvers.moreau_ptr_solver import MoreauPTRSolver
|
|
106
|
+
|
|
107
|
+
return MoreauPTRSolver
|
|
108
|
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
109
|
+
|
|
110
|
+
|
|
92
111
|
__all__ = [
|
|
93
112
|
# Main Trajectory Optimization Entrypoint
|
|
94
113
|
"Problem",
|
|
@@ -176,6 +195,9 @@ __all__ = [
|
|
|
176
195
|
"lie",
|
|
177
196
|
# Expert mode types
|
|
178
197
|
"ByofSpec",
|
|
198
|
+
# External-backend dynamics adapters
|
|
199
|
+
"DynamicsAdapter",
|
|
200
|
+
"MjxDynamics",
|
|
179
201
|
# Discretization
|
|
180
202
|
"DiscretizeLinearizeVectorize",
|
|
181
203
|
"LinearizeDiscretize",
|
|
@@ -183,9 +205,13 @@ __all__ = [
|
|
|
183
205
|
"VectorizeDiscretizeLinearize",
|
|
184
206
|
# Convex Solver
|
|
185
207
|
"PTRSolver",
|
|
208
|
+
"CVXPyPTRSolver",
|
|
209
|
+
"QPAXPTRSolver",
|
|
210
|
+
"MoreauPTRSolver",
|
|
186
211
|
# Algorithm & Autotuning
|
|
187
212
|
"PenalizedTrustRegion",
|
|
188
213
|
"AugmentedLagrangian",
|
|
214
|
+
"AdaptiveProximalWeight",
|
|
189
215
|
"ConstantProximalWeight",
|
|
190
216
|
"RampProximalWeight",
|
|
191
217
|
]
|
openscvx/_version.py
CHANGED
|
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
|
|
|
18
18
|
commit_id: str | None
|
|
19
19
|
__commit_id__: str | None
|
|
20
20
|
|
|
21
|
-
__version__ = version = '0.5.2.
|
|
22
|
-
__version_tuple__ = version_tuple = (0, 5, 2, '
|
|
21
|
+
__version__ = version = '0.5.2.dev11'
|
|
22
|
+
__version_tuple__ = version_tuple = (0, 5, 2, 'dev11')
|
|
23
23
|
|
|
24
24
|
__commit_id__ = commit_id = None
|
openscvx/algorithms/__init__.py
CHANGED
|
@@ -82,12 +82,19 @@ from typing import Annotated, Any, Dict, List, Optional, Union
|
|
|
82
82
|
|
|
83
83
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
84
84
|
|
|
85
|
-
from .
|
|
85
|
+
from .autotuner import (
|
|
86
|
+
AdaptiveProximalWeight,
|
|
87
|
+
AdaptiveProximalWeightSpec,
|
|
88
|
+
AugmentedLagrangian,
|
|
89
|
+
AugmentedLagrangianSpec,
|
|
90
|
+
ConstantProximalWeight,
|
|
91
|
+
ConstantProximalWeightSpec,
|
|
92
|
+
RampProximalWeight,
|
|
93
|
+
RampProximalWeightSpec,
|
|
94
|
+
)
|
|
86
95
|
from .base import Algorithm, AlgorithmState, AutotuningBase, DiscretizationResult
|
|
87
|
-
from .constant_proximal_weight import ConstantProximalWeight, ConstantProximalWeightSpec
|
|
88
96
|
from .optimization_results import OptimizationResults
|
|
89
|
-
from .
|
|
90
|
-
from .ramp_proximal_weight import RampProximalWeight, RampProximalWeightSpec
|
|
97
|
+
from .scvx import PenalizedTrustRegion
|
|
91
98
|
from .weights import Weights
|
|
92
99
|
|
|
93
100
|
# ---------------------------------------------------------------------------
|
|
@@ -97,6 +104,7 @@ from .weights import Weights
|
|
|
97
104
|
AutotunerConfig = Annotated[
|
|
98
105
|
Union[
|
|
99
106
|
AugmentedLagrangianSpec,
|
|
107
|
+
AdaptiveProximalWeightSpec,
|
|
100
108
|
RampProximalWeightSpec,
|
|
101
109
|
ConstantProximalWeightSpec,
|
|
102
110
|
],
|
|
@@ -173,6 +181,7 @@ __all__ = [
|
|
|
173
181
|
"PenalizedTrustRegion",
|
|
174
182
|
"AutotuningBase",
|
|
175
183
|
"AugmentedLagrangian",
|
|
184
|
+
"AdaptiveProximalWeight",
|
|
176
185
|
"ConstantProximalWeight",
|
|
177
186
|
"RampProximalWeight",
|
|
178
187
|
# Config models
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""SCP weight autotuning strategies."""
|
|
2
|
+
|
|
3
|
+
from .adaptive_proximal_weight import AdaptiveProximalWeight, AdaptiveProximalWeightSpec
|
|
4
|
+
from .augmented_lagrangian import AugmentedLagrangian, AugmentedLagrangianSpec
|
|
5
|
+
from .constant_proximal_weight import ConstantProximalWeight, ConstantProximalWeightSpec
|
|
6
|
+
from .ramp_proximal_weight import RampProximalWeight, RampProximalWeightSpec
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"AdaptiveProximalWeight",
|
|
10
|
+
"AdaptiveProximalWeightSpec",
|
|
11
|
+
"AugmentedLagrangian",
|
|
12
|
+
"AugmentedLagrangianSpec",
|
|
13
|
+
"ConstantProximalWeight",
|
|
14
|
+
"ConstantProximalWeightSpec",
|
|
15
|
+
"RampProximalWeight",
|
|
16
|
+
"RampProximalWeightSpec",
|
|
17
|
+
]
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""Autotuning functions for SCP (Successive Convex Programming) parameters."""
|
|
2
|
+
|
|
3
|
+
from copy import deepcopy
|
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
|
8
|
+
|
|
9
|
+
from openscvx.config import Config
|
|
10
|
+
|
|
11
|
+
from ..base import AutotuningBase
|
|
12
|
+
from .augmented_lagrangian import AugmentedLagrangian
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from openscvx.lowered import LoweredJaxConstraints
|
|
16
|
+
|
|
17
|
+
from ..base import AlgorithmState, CandidateIterate
|
|
18
|
+
from ..weights import Weights
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class AdaptiveProximalWeight(AutotuningBase):
|
|
22
|
+
"""PTR-style proximal adaptation with fixed virtual penalty weights.
|
|
23
|
+
|
|
24
|
+
Same acceptance-ratio logic as :class:`AugmentedLagrangian` for ``lam_prox``,
|
|
25
|
+
but ``lam_vc`` and ``lam_vb_*`` are held constant at their current state values.
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
COLUMNS = AugmentedLagrangian.COLUMNS
|
|
29
|
+
|
|
30
|
+
def __init__(
|
|
31
|
+
self,
|
|
32
|
+
gamma_1: float = 2.0,
|
|
33
|
+
gamma_2: float = 0.5,
|
|
34
|
+
eta_0: float = 1e-2,
|
|
35
|
+
eta_1: float = 1e-1,
|
|
36
|
+
eta_2: float = 0.8,
|
|
37
|
+
lam_prox_min: float = 1e-3,
|
|
38
|
+
lam_prox_max: float = 1e4,
|
|
39
|
+
lam_cost_drop: int = -1,
|
|
40
|
+
lam_cost_relax: float = 1.0,
|
|
41
|
+
):
|
|
42
|
+
self.gamma_1 = gamma_1
|
|
43
|
+
self.gamma_2 = gamma_2
|
|
44
|
+
self.eta_0 = eta_0
|
|
45
|
+
self.eta_1 = eta_1
|
|
46
|
+
self.eta_2 = eta_2
|
|
47
|
+
self.lam_prox_min = lam_prox_min
|
|
48
|
+
self.lam_prox_max = lam_prox_max
|
|
49
|
+
self.lam_cost_drop = lam_cost_drop
|
|
50
|
+
self.lam_cost_relax = lam_cost_relax
|
|
51
|
+
|
|
52
|
+
@staticmethod
|
|
53
|
+
def _copy_virtual_weights(
|
|
54
|
+
candidate: "CandidateIterate",
|
|
55
|
+
state: "AlgorithmState",
|
|
56
|
+
) -> None:
|
|
57
|
+
candidate.lam_vc = state.lam_vc
|
|
58
|
+
candidate.lam_vb_nodal = state.lam_vb_nodal
|
|
59
|
+
candidate.lam_vb_cross = state.lam_vb_cross
|
|
60
|
+
|
|
61
|
+
def update_weights(
|
|
62
|
+
self,
|
|
63
|
+
state: "AlgorithmState",
|
|
64
|
+
candidate: "CandidateIterate",
|
|
65
|
+
nodal_constraints: "LoweredJaxConstraints",
|
|
66
|
+
settings: Config,
|
|
67
|
+
params: dict,
|
|
68
|
+
weights: "Weights",
|
|
69
|
+
) -> str:
|
|
70
|
+
"""Update SCP proximal weight based on acceptance ratio; keep VC/VB fixed."""
|
|
71
|
+
candidate_x_prop = (
|
|
72
|
+
candidate.x_prop_plus[1:] if candidate.x_prop_plus is not None else candidate.x_prop
|
|
73
|
+
)
|
|
74
|
+
(
|
|
75
|
+
nonlinear_cost,
|
|
76
|
+
nonlinear_penalty,
|
|
77
|
+
nodal_penalty,
|
|
78
|
+
) = self.calculate_nonlinear_penalty(
|
|
79
|
+
candidate_x_prop,
|
|
80
|
+
candidate.x,
|
|
81
|
+
candidate.u,
|
|
82
|
+
state.lam_vc,
|
|
83
|
+
state.lam_vb_nodal,
|
|
84
|
+
state.lam_vb_cross,
|
|
85
|
+
state.lam_cost,
|
|
86
|
+
nodal_constraints,
|
|
87
|
+
params,
|
|
88
|
+
settings,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
candidate.J_nonlin = nonlinear_cost + nonlinear_penalty + nodal_penalty
|
|
92
|
+
|
|
93
|
+
if state.k > self.lam_cost_drop:
|
|
94
|
+
candidate.lam_cost = state.lam_cost * self.lam_cost_relax
|
|
95
|
+
else:
|
|
96
|
+
candidate.lam_cost = weights.lam_cost
|
|
97
|
+
|
|
98
|
+
lam_prox_k = deepcopy(state.lam_prox)
|
|
99
|
+
|
|
100
|
+
if state.k > 1:
|
|
101
|
+
state_x_prop_plus = state.x_prop_plus()
|
|
102
|
+
state_x_prop = (
|
|
103
|
+
state_x_prop_plus[1:] if state_x_prop_plus is not None else state.x_prop()
|
|
104
|
+
)
|
|
105
|
+
(
|
|
106
|
+
prev_nonlinear_cost,
|
|
107
|
+
prev_nonlinear_penalty,
|
|
108
|
+
prev_nodal_penalty,
|
|
109
|
+
) = self.calculate_nonlinear_penalty(
|
|
110
|
+
state_x_prop,
|
|
111
|
+
state.x,
|
|
112
|
+
state.u,
|
|
113
|
+
state.lam_vc,
|
|
114
|
+
state.lam_vb_nodal,
|
|
115
|
+
state.lam_vb_cross,
|
|
116
|
+
state.lam_cost,
|
|
117
|
+
nodal_constraints,
|
|
118
|
+
params,
|
|
119
|
+
settings,
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
J_nonlin_prev = prev_nonlinear_cost + prev_nonlinear_penalty + prev_nodal_penalty
|
|
123
|
+
|
|
124
|
+
actual_reduction = J_nonlin_prev - candidate.J_nonlin
|
|
125
|
+
predicted_reduction = J_nonlin_prev - candidate.J_lin
|
|
126
|
+
|
|
127
|
+
if predicted_reduction == 0:
|
|
128
|
+
raise ValueError("Predicted reduction is 0.")
|
|
129
|
+
|
|
130
|
+
rho = actual_reduction / predicted_reduction
|
|
131
|
+
|
|
132
|
+
state.pred_reduction_history.append(predicted_reduction)
|
|
133
|
+
state.actual_reduction_history.append(actual_reduction)
|
|
134
|
+
state.acceptance_ratio_history.append(rho)
|
|
135
|
+
|
|
136
|
+
if rho < self.eta_0:
|
|
137
|
+
lam_prox_k1 = np.minimum(self.lam_prox_max, self.gamma_1 * lam_prox_k)
|
|
138
|
+
candidate.lam_prox = lam_prox_k1
|
|
139
|
+
state.reject_solution(candidate)
|
|
140
|
+
adaptive_state = "Reject Higher"
|
|
141
|
+
elif rho >= self.eta_0 and rho < self.eta_1:
|
|
142
|
+
lam_prox_k1 = np.minimum(self.lam_prox_max, self.gamma_1 * lam_prox_k)
|
|
143
|
+
candidate.lam_prox = lam_prox_k1
|
|
144
|
+
self._copy_virtual_weights(candidate, state)
|
|
145
|
+
state.accept_solution(candidate)
|
|
146
|
+
adaptive_state = "Accept Higher"
|
|
147
|
+
elif rho >= self.eta_1 and rho < self.eta_2:
|
|
148
|
+
candidate.lam_prox = lam_prox_k
|
|
149
|
+
self._copy_virtual_weights(candidate, state)
|
|
150
|
+
state.accept_solution(candidate)
|
|
151
|
+
adaptive_state = "Accept Constant"
|
|
152
|
+
else:
|
|
153
|
+
lam_prox_k1 = np.maximum(self.lam_prox_min, self.gamma_2 * lam_prox_k)
|
|
154
|
+
candidate.lam_prox = lam_prox_k1
|
|
155
|
+
self._copy_virtual_weights(candidate, state)
|
|
156
|
+
state.accept_solution(candidate)
|
|
157
|
+
adaptive_state = "Accept Lower"
|
|
158
|
+
|
|
159
|
+
else:
|
|
160
|
+
candidate.lam_prox = lam_prox_k
|
|
161
|
+
self._copy_virtual_weights(candidate, state)
|
|
162
|
+
state.accept_solution(candidate)
|
|
163
|
+
adaptive_state = "Initial"
|
|
164
|
+
|
|
165
|
+
return adaptive_state
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# =============================================================================
|
|
169
|
+
# Pydantic spec for dict / YAML validation
|
|
170
|
+
# =============================================================================
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class AdaptiveProximalWeightSpec(BaseModel):
|
|
174
|
+
"""Validates AdaptiveProximalWeight configuration from dict/YAML input."""
|
|
175
|
+
|
|
176
|
+
type: Literal["AdaptiveProximalWeight"] = "AdaptiveProximalWeight"
|
|
177
|
+
gamma_1: float = 2.0
|
|
178
|
+
gamma_2: float = 0.5
|
|
179
|
+
eta_0: float = 1e-2
|
|
180
|
+
eta_1: float = 1e-1
|
|
181
|
+
eta_2: float = 0.8
|
|
182
|
+
lam_prox_min: float = 1e-3
|
|
183
|
+
lam_prox_max: float = 1e4
|
|
184
|
+
lam_cost_drop: int = -1
|
|
185
|
+
lam_cost_relax: float = 1.0
|
|
186
|
+
|
|
187
|
+
model_config = ConfigDict(extra="forbid")
|
|
188
|
+
|
|
189
|
+
def build(self) -> AdaptiveProximalWeight:
|
|
190
|
+
return AdaptiveProximalWeight(**self.model_dump(exclude={"type"}, exclude_unset=True))
|
|
@@ -14,13 +14,13 @@ from openscvx.utils.printing import (
|
|
|
14
14
|
color_adaptive_state,
|
|
15
15
|
)
|
|
16
16
|
|
|
17
|
-
from
|
|
17
|
+
from ..base import AutotuningBase
|
|
18
18
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from openscvx.lowered import LoweredJaxConstraints
|
|
21
21
|
|
|
22
|
-
from
|
|
23
|
-
from
|
|
22
|
+
from ..base import AlgorithmState, CandidateIterate
|
|
23
|
+
from ..weights import Weights
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class AugmentedLagrangian(AutotuningBase):
|
|
@@ -86,8 +86,12 @@ class AugmentedLagrangian(AutotuningBase):
|
|
|
86
86
|
eta_2: Threshold above which solution is accepted with lower weight.
|
|
87
87
|
Defaults to 0.8.
|
|
88
88
|
ep: Threshold for virtual control weight update (nu > ep vs nu <= ep).
|
|
89
|
-
Defaults to 0.
|
|
90
|
-
|
|
89
|
+
Must lie in (0, 1). Defaults to 0.99; when tuning, try 1e-1 if needed.
|
|
90
|
+
Typically tuned together with ``eta_lambda`` (often the first
|
|
91
|
+
parameters adjusted).
|
|
92
|
+
eta_lambda: Step size for virtual control weight update. Defaults to 1e1.
|
|
93
|
+
Typically tuned together with ``ep`` (often the first parameters
|
|
94
|
+
adjusted).
|
|
91
95
|
lam_vc_max: Maximum virtual control penalty weight. Defaults to 1e5.
|
|
92
96
|
lam_prox_min: Minimum trust region (proximal) weight. Defaults to 1e-3.
|
|
93
97
|
lam_prox_max: Maximum trust region (proximal) weight. Defaults to 2e5.
|
|
@@ -6,13 +6,13 @@ from pydantic import BaseModel, ConfigDict
|
|
|
6
6
|
|
|
7
7
|
from openscvx.config import Config
|
|
8
8
|
|
|
9
|
-
from
|
|
9
|
+
from ..base import AutotuningBase
|
|
10
10
|
|
|
11
11
|
if TYPE_CHECKING:
|
|
12
12
|
from openscvx.lowered import LoweredJaxConstraints
|
|
13
13
|
|
|
14
|
-
from
|
|
15
|
-
from
|
|
14
|
+
from ..base import AlgorithmState, CandidateIterate
|
|
15
|
+
from ..weights import Weights
|
|
16
16
|
|
|
17
17
|
|
|
18
18
|
class ConstantProximalWeight(AutotuningBase):
|
|
@@ -7,13 +7,13 @@ from pydantic import BaseModel, ConfigDict
|
|
|
7
7
|
|
|
8
8
|
from openscvx.config import Config
|
|
9
9
|
|
|
10
|
-
from
|
|
10
|
+
from ..base import AutotuningBase
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from openscvx.lowered import LoweredJaxConstraints
|
|
14
14
|
|
|
15
|
-
from
|
|
16
|
-
from
|
|
15
|
+
from ..base import AlgorithmState, CandidateIterate
|
|
16
|
+
from ..weights import Weights
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
class RampProximalWeight(AutotuningBase):
|
|
@@ -6,8 +6,9 @@ optimization problems through iterative convex approximation.
|
|
|
6
6
|
|
|
7
7
|
import time
|
|
8
8
|
import warnings
|
|
9
|
-
from typing import TYPE_CHECKING, Callable, Dict, List, Union
|
|
9
|
+
from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
|
|
10
10
|
|
|
11
|
+
import jax
|
|
11
12
|
import numpy as np
|
|
12
13
|
import numpy.linalg as la
|
|
13
14
|
|
|
@@ -21,11 +22,11 @@ from openscvx.utils.printing import (
|
|
|
21
22
|
color_prob_stat,
|
|
22
23
|
)
|
|
23
24
|
|
|
24
|
-
from .augmented_lagrangian import AugmentedLagrangian
|
|
25
|
-
from .
|
|
26
|
-
from .
|
|
27
|
-
from
|
|
28
|
-
from
|
|
25
|
+
from ..autotuner.augmented_lagrangian import AugmentedLagrangian
|
|
26
|
+
from ..autotuner.constant_proximal_weight import ConstantProximalWeight
|
|
27
|
+
from ..autotuner.ramp_proximal_weight import RampProximalWeight
|
|
28
|
+
from ..base import Algorithm, AlgorithmState, CandidateIterate
|
|
29
|
+
from ..weights import Weights
|
|
29
30
|
|
|
30
31
|
if TYPE_CHECKING:
|
|
31
32
|
from openscvx.lowered import LoweredJaxConstraints
|
|
@@ -33,7 +34,7 @@ if TYPE_CHECKING:
|
|
|
33
34
|
from openscvx.symbolic.expr.control import Control
|
|
34
35
|
from openscvx.symbolic.expr.state import State
|
|
35
36
|
|
|
36
|
-
from
|
|
37
|
+
from ..base import AutotuningBase
|
|
37
38
|
|
|
38
39
|
warnings.filterwarnings("ignore")
|
|
39
40
|
|
|
@@ -188,6 +189,11 @@ class PenalizedTrustRegion(Algorithm):
|
|
|
188
189
|
return solver.call(*args)
|
|
189
190
|
return solver(*args)
|
|
190
191
|
|
|
192
|
+
@staticmethod
|
|
193
|
+
def _block_until_ready_outputs(outputs: Tuple[object, ...]) -> None:
|
|
194
|
+
"""Finish any pending XLA work from discretization exports (warm-up helper)."""
|
|
195
|
+
jax.block_until_ready(outputs)
|
|
196
|
+
|
|
191
197
|
def _recover_prior_node_from_initial(
|
|
192
198
|
self,
|
|
193
199
|
settings: Config,
|
|
@@ -286,7 +292,8 @@ class PenalizedTrustRegion(Algorithm):
|
|
|
286
292
|
"""Initialize PTR algorithm.
|
|
287
293
|
|
|
288
294
|
Stores compiled infrastructure and performs a warm-start solve to
|
|
289
|
-
initialize DPP and JAX jacobians.
|
|
295
|
+
initialize DPP and JAX jacobians. Also runs post-subproblem discretization
|
|
296
|
+
on the throwaway CVX solution so XLA/export caches match the first ``step()``.
|
|
290
297
|
|
|
291
298
|
Args:
|
|
292
299
|
solver: Convex subproblem solver (e.g., CVXPySolver)
|
|
@@ -333,7 +340,29 @@ class PenalizedTrustRegion(Algorithm):
|
|
|
333
340
|
params,
|
|
334
341
|
)
|
|
335
342
|
init_state.add_impulsive_discretization(W_multi_shoot.__array__())
|
|
336
|
-
_ = self._subproblem(params, init_state, settings)
|
|
343
|
+
(x_sol, u_sol, *_) = self._subproblem(params, init_state, settings)
|
|
344
|
+
|
|
345
|
+
# Prime the same exported discretization calls used after every subproblem in
|
|
346
|
+
# step() (candidate trajectory). initialize() previously only discretized the
|
|
347
|
+
# initial guess, so the first step() in solve() could still hit an XLA cache_miss
|
|
348
|
+
# on post-CVX (x_sol, u_sol). Running that path here moves compilation into init.
|
|
349
|
+
cont_out = self._invoke_solver(
|
|
350
|
+
self._discretization_solver, x_sol, u_sol.astype(float), params
|
|
351
|
+
)
|
|
352
|
+
x_prop_c = cont_out[3]
|
|
353
|
+
u_candidate = u_sol.astype(float)
|
|
354
|
+
x0_prior_c = self._recover_prior_node_from_initial(settings, x_sol[0])
|
|
355
|
+
x_nodes_prior_c = np.vstack((x0_prior_c, np.asarray(x_prop_c)))
|
|
356
|
+
if self._discretization_solver_impulsive is not None:
|
|
357
|
+
imp_out = self._invoke_solver(
|
|
358
|
+
self._discretization_solver_impulsive,
|
|
359
|
+
x_nodes_prior_c,
|
|
360
|
+
u_candidate,
|
|
361
|
+
params,
|
|
362
|
+
)
|
|
363
|
+
self._block_until_ready_outputs(cont_out + imp_out)
|
|
364
|
+
else:
|
|
365
|
+
self._block_until_ready_outputs(cont_out)
|
|
337
366
|
|
|
338
367
|
def step(
|
|
339
368
|
self,
|
|
@@ -188,7 +188,12 @@ class VectorizeDiscretizeLinearize(Discretizer):
|
|
|
188
188
|
in_axes=-1,
|
|
189
189
|
out_axes=-1,
|
|
190
190
|
)
|
|
191
|
-
|
|
191
|
+
# Cast to match primal's dtype. standard_basis is built eagerly
|
|
192
|
+
# from jnp.eye (float32 when x64 is disabled), but jax.export
|
|
193
|
+
# traces with the literal dtype of the dummy arrays (float64 from
|
|
194
|
+
# np.ones), so primal may be float64 while the closure-captured
|
|
195
|
+
# standard_basis is float32. jax.jvp requires identical dtypes.
|
|
196
|
+
jacobians = pushforward(standard_basis.astype(primal.dtype))
|
|
192
197
|
|
|
193
198
|
A_d = jacobians[:, :, :, i0:i1]
|
|
194
199
|
B_d = jacobians[:, :, :, i1:i2]
|
|
@@ -125,7 +125,12 @@ def _sparse_jacobian_fn(
|
|
|
125
125
|
_, jvp_out = jax.jvp(f_of_target, (primals[argnums],), (tangent,))
|
|
126
126
|
return jvp_out
|
|
127
127
|
|
|
128
|
-
|
|
128
|
+
# Cast seeds to match the primal dtype. seeds is built eagerly as
|
|
129
|
+
# float32 (x64 disabled), but jax.export traces with the literal
|
|
130
|
+
# dtype of the dummy arrays (float64 from np.ones), so primals may
|
|
131
|
+
# be float64. jax.jvp requires matching dtypes.
|
|
132
|
+
typed_seeds = seeds.astype(primals[argnums].dtype)
|
|
133
|
+
compressed = jax.vmap(single_jvp)(typed_seeds) # (n_colors, out_dim)
|
|
129
134
|
|
|
130
135
|
values = compressed[scatter_color, scatter_row]
|
|
131
136
|
jac = jnp.zeros((out_dim, in_dim))
|
|
@@ -1,47 +1,46 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""External-backend dynamics adapters for OpenSCvx.
|
|
2
2
|
|
|
3
|
-
The recommended entry-point is
|
|
4
|
-
``
|
|
5
|
-
|
|
3
|
+
The recommended entry-point is `MjxDynamics`, which goes directly into the
|
|
4
|
+
``dynamics=`` slot of `Problem` and constructs the matching State/Control
|
|
5
|
+
objects for the user. Free-joint quaternion kinematics for floating-base
|
|
6
|
+
models (drones, humanoids) are detected and handled automatically::
|
|
6
7
|
|
|
7
|
-
from openscvx.integrations import
|
|
8
|
+
from openscvx.integrations import MjxDynamics
|
|
8
9
|
|
|
9
|
-
|
|
10
|
+
dyn = MjxDynamics(mjx_model)
|
|
11
|
+
problem = ox.Problem(
|
|
12
|
+
dynamics=dyn,
|
|
13
|
+
states=dyn.states,
|
|
14
|
+
controls=dyn.controls,
|
|
15
|
+
...
|
|
16
|
+
)
|
|
10
17
|
|
|
11
|
-
For
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
(``nq > nv``) ``"qpos"`` is included automatically — no extra imports needed.
|
|
18
|
+
For advanced users who need custom State/Control names (or to interleave
|
|
19
|
+
them with extra custom states), `mjx_dynamics` is exposed as the underlying
|
|
20
|
+
BYOF callable factory — assemble your own ``byof["dynamics"]`` dict from it.
|
|
15
21
|
|
|
16
|
-
|
|
17
|
-
|
|
22
|
+
All MJX symbols delegate lazily so ``mujoco.mjx`` is only imported when
|
|
23
|
+
actually used. The ``menagerie`` submodule is also loaded lazily.
|
|
18
24
|
|
|
19
|
-
|
|
20
|
-
The :mod:`menagerie` submodule is loaded lazily via attribute access.
|
|
25
|
+
Example — cartpole (``nq == nv``)::
|
|
21
26
|
|
|
22
|
-
|
|
27
|
+
from openscvx.integrations import MjxDynamics
|
|
23
28
|
|
|
24
|
-
|
|
29
|
+
dyn = MjxDynamics(mjx_model)
|
|
30
|
+
problem = ox.Problem(dynamics=dyn, states=dyn.states, controls=dyn.controls, ...)
|
|
25
31
|
|
|
26
|
-
|
|
27
|
-
problem = ox.Problem(dynamics={"qpos": qvel}, byof=byof, ...)
|
|
32
|
+
Example — quadrotor with free joint (``nq=7``, ``nv=6``)::
|
|
28
33
|
|
|
29
|
-
|
|
34
|
+
from openscvx.integrations import MjxDynamics
|
|
30
35
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
|
|
34
|
-
problem = ox.Problem(dynamics={}, byof=byof, ...)
|
|
36
|
+
dyn = MjxDynamics(mjx_model)
|
|
37
|
+
problem = ox.Problem(dynamics=dyn, states=dyn.states, controls=dyn.controls, ...)
|
|
35
38
|
"""
|
|
36
39
|
|
|
37
40
|
from typing import Any
|
|
38
41
|
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
"""Lazy delegate; imports ``mujoco.mjx`` on first call."""
|
|
42
|
-
from .mjx import mjx_byof as _mjx_byof
|
|
43
|
-
|
|
44
|
-
return _mjx_byof(*args, **kwargs)
|
|
42
|
+
from .base import DynamicsAdapter
|
|
43
|
+
from .mjx import MjxDynamics
|
|
45
44
|
|
|
46
45
|
|
|
47
46
|
def mjx_dynamics(*args: Any, **kwargs: Any) -> Any:
|
|
@@ -59,4 +58,9 @@ def __getattr__(name: str) -> Any:
|
|
|
59
58
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
60
59
|
|
|
61
60
|
|
|
62
|
-
__all__ = [
|
|
61
|
+
__all__ = [
|
|
62
|
+
"DynamicsAdapter",
|
|
63
|
+
"MjxDynamics",
|
|
64
|
+
"mjx_dynamics",
|
|
65
|
+
"menagerie",
|
|
66
|
+
]
|