openscvx 0.5.2.dev2__py3-none-any.whl → 0.5.2.dev12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. openscvx/__init__.py +27 -1
  2. openscvx/_version.py +2 -2
  3. openscvx/algorithms/__init__.py +13 -4
  4. openscvx/algorithms/autotuner/__init__.py +17 -0
  5. openscvx/algorithms/autotuner/adaptive_proximal_weight.py +190 -0
  6. openscvx/algorithms/{augmented_lagrangian.py → autotuner/augmented_lagrangian.py} +9 -5
  7. openscvx/algorithms/{constant_proximal_weight.py → autotuner/constant_proximal_weight.py} +3 -3
  8. openscvx/algorithms/{ramp_proximal_weight.py → autotuner/ramp_proximal_weight.py} +3 -3
  9. openscvx/algorithms/scvx/__init__.py +5 -0
  10. openscvx/algorithms/{penalized_trust_region.py → scvx/penalized_trust_region.py} +38 -9
  11. openscvx/discretization/discretize_linearize.py +6 -1
  12. openscvx/discretization/sparse_utils/sparse_jacobian.py +6 -1
  13. openscvx/integrations/__init__.py +34 -30
  14. openscvx/integrations/base.py +89 -0
  15. openscvx/integrations/mjx.py +247 -74
  16. openscvx/loader.py +2 -2
  17. openscvx/lowered/cvxpy_constraints.py +4 -1
  18. openscvx/plotting/viser/server.py +2 -2
  19. openscvx/problem.py +36 -17
  20. openscvx/solvers/__init__.py +51 -24
  21. openscvx/solvers/base.py +124 -36
  22. openscvx/solvers/cvxpy_ptr_solver.py +908 -0
  23. openscvx/solvers/moreau_ptr_solver.py +1125 -0
  24. openscvx/solvers/ptr_solver.py +99 -846
  25. openscvx/solvers/qpax_ptr_solver.py +791 -0
  26. openscvx/symbolic/hashing.py +16 -3
  27. openscvx/symbolic/lower.py +6 -7
  28. openscvx/utils/printing.py +9 -1
  29. {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev12.dist-info}/METADATA +14 -16
  30. {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev12.dist-info}/RECORD +34 -27
  31. {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev12.dist-info}/WHEEL +0 -0
  32. {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev12.dist-info}/entry_points.txt +0 -0
  33. {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev12.dist-info}/licenses/LICENSE +0 -0
  34. {openscvx-0.5.2.dev2.dist-info → openscvx-0.5.2.dev12.dist-info}/top_level.txt +0 -0
openscvx/__init__.py CHANGED
@@ -12,6 +12,7 @@ import openscvx.symbolic.expr.spatial as spatial
12
12
  import openscvx.symbolic.expr.stl as stl
13
13
  import openscvx.symbolic.expr.stljax as stljax
14
14
  from openscvx.algorithms import (
15
+ AdaptiveProximalWeight,
15
16
  AugmentedLagrangian,
16
17
  ConstantProximalWeight,
17
18
  PenalizedTrustRegion,
@@ -25,9 +26,13 @@ from openscvx.discretization import (
25
26
  VectorizeDiscretizeLinearize,
26
27
  )
27
28
  from openscvx.expert import ByofSpec
29
+ from openscvx.integrations import DynamicsAdapter, MjxDynamics
28
30
  from openscvx.loader import load_dict, load_json, load_yaml
29
31
  from openscvx.problem import Problem
30
- from openscvx.solvers import PTRSolver
32
+ from openscvx.solvers import CVXPyPTRSolver, PTRSolver
33
+
34
+ # QPAXPTRSolver is exposed lazily via __getattr__ below to keep `import qpax`
35
+ # off the hot import path for users who don't install the optional extra.
31
36
  from openscvx.symbolic.expr import (
32
37
  CTCS,
33
38
  Abs,
@@ -89,6 +94,20 @@ from openscvx.utils.cache import clear_cache, get_cache_dir, get_cache_size
89
94
 
90
95
  load_results = OptimizationResults.load
91
96
 
97
+
98
+ def __getattr__(name: str):
99
+ """Lazy export for backends that depend on optional packages."""
100
+ if name == "QPAXPTRSolver":
101
+ from openscvx.solvers.qpax_ptr_solver import QPAXPTRSolver
102
+
103
+ return QPAXPTRSolver
104
+ if name == "MoreauPTRSolver":
105
+ from openscvx.solvers.moreau_ptr_solver import MoreauPTRSolver
106
+
107
+ return MoreauPTRSolver
108
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
109
+
110
+
92
111
  __all__ = [
93
112
  # Main Trajectory Optimization Entrypoint
94
113
  "Problem",
@@ -176,6 +195,9 @@ __all__ = [
176
195
  "lie",
177
196
  # Expert mode types
178
197
  "ByofSpec",
198
+ # External-backend dynamics adapters
199
+ "DynamicsAdapter",
200
+ "MjxDynamics",
179
201
  # Discretization
180
202
  "DiscretizeLinearizeVectorize",
181
203
  "LinearizeDiscretize",
@@ -183,9 +205,13 @@ __all__ = [
183
205
  "VectorizeDiscretizeLinearize",
184
206
  # Convex Solver
185
207
  "PTRSolver",
208
+ "CVXPyPTRSolver",
209
+ "QPAXPTRSolver",
210
+ "MoreauPTRSolver",
186
211
  # Algorithm & Autotuning
187
212
  "PenalizedTrustRegion",
188
213
  "AugmentedLagrangian",
214
+ "AdaptiveProximalWeight",
189
215
  "ConstantProximalWeight",
190
216
  "RampProximalWeight",
191
217
  ]
openscvx/_version.py CHANGED
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '0.5.2.dev2'
22
- __version_tuple__ = version_tuple = (0, 5, 2, 'dev2')
21
+ __version__ = version = '0.5.2.dev12'
22
+ __version_tuple__ = version_tuple = (0, 5, 2, 'dev12')
23
23
 
24
24
  __commit_id__ = commit_id = None
@@ -82,12 +82,19 @@ from typing import Annotated, Any, Dict, List, Optional, Union
82
82
 
83
83
  from pydantic import BaseModel, ConfigDict, Field, field_validator
84
84
 
85
- from .augmented_lagrangian import AugmentedLagrangian, AugmentedLagrangianSpec
85
+ from .autotuner import (
86
+ AdaptiveProximalWeight,
87
+ AdaptiveProximalWeightSpec,
88
+ AugmentedLagrangian,
89
+ AugmentedLagrangianSpec,
90
+ ConstantProximalWeight,
91
+ ConstantProximalWeightSpec,
92
+ RampProximalWeight,
93
+ RampProximalWeightSpec,
94
+ )
86
95
  from .base import Algorithm, AlgorithmState, AutotuningBase, DiscretizationResult
87
- from .constant_proximal_weight import ConstantProximalWeight, ConstantProximalWeightSpec
88
96
  from .optimization_results import OptimizationResults
89
- from .penalized_trust_region import PenalizedTrustRegion
90
- from .ramp_proximal_weight import RampProximalWeight, RampProximalWeightSpec
97
+ from .scvx import PenalizedTrustRegion
91
98
  from .weights import Weights
92
99
 
93
100
  # ---------------------------------------------------------------------------
@@ -97,6 +104,7 @@ from .weights import Weights
97
104
  AutotunerConfig = Annotated[
98
105
  Union[
99
106
  AugmentedLagrangianSpec,
107
+ AdaptiveProximalWeightSpec,
100
108
  RampProximalWeightSpec,
101
109
  ConstantProximalWeightSpec,
102
110
  ],
@@ -173,6 +181,7 @@ __all__ = [
173
181
  "PenalizedTrustRegion",
174
182
  "AutotuningBase",
175
183
  "AugmentedLagrangian",
184
+ "AdaptiveProximalWeight",
176
185
  "ConstantProximalWeight",
177
186
  "RampProximalWeight",
178
187
  # Config models
@@ -0,0 +1,17 @@
1
+ """SCP weight autotuning strategies."""
2
+
3
+ from .adaptive_proximal_weight import AdaptiveProximalWeight, AdaptiveProximalWeightSpec
4
+ from .augmented_lagrangian import AugmentedLagrangian, AugmentedLagrangianSpec
5
+ from .constant_proximal_weight import ConstantProximalWeight, ConstantProximalWeightSpec
6
+ from .ramp_proximal_weight import RampProximalWeight, RampProximalWeightSpec
7
+
8
+ __all__ = [
9
+ "AdaptiveProximalWeight",
10
+ "AdaptiveProximalWeightSpec",
11
+ "AugmentedLagrangian",
12
+ "AugmentedLagrangianSpec",
13
+ "ConstantProximalWeight",
14
+ "ConstantProximalWeightSpec",
15
+ "RampProximalWeight",
16
+ "RampProximalWeightSpec",
17
+ ]
@@ -0,0 +1,190 @@
1
+ """Autotuning functions for SCP (Successive Convex Programming) parameters."""
2
+
3
+ from copy import deepcopy
4
+ from typing import TYPE_CHECKING, Literal
5
+
6
+ import numpy as np
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from openscvx.config import Config
10
+
11
+ from ..base import AutotuningBase
12
+ from .augmented_lagrangian import AugmentedLagrangian
13
+
14
+ if TYPE_CHECKING:
15
+ from openscvx.lowered import LoweredJaxConstraints
16
+
17
+ from ..base import AlgorithmState, CandidateIterate
18
+ from ..weights import Weights
19
+
20
+
21
+ class AdaptiveProximalWeight(AutotuningBase):
22
+ """PTR-style proximal adaptation with fixed virtual penalty weights.
23
+
24
+ Same acceptance-ratio logic as :class:`AugmentedLagrangian` for ``lam_prox``,
25
+ but ``lam_vc`` and ``lam_vb_*`` are held constant at their current state values.
26
+ """
27
+
28
+ COLUMNS = AugmentedLagrangian.COLUMNS
29
+
30
+ def __init__(
31
+ self,
32
+ gamma_1: float = 2.0,
33
+ gamma_2: float = 0.5,
34
+ eta_0: float = 1e-2,
35
+ eta_1: float = 1e-1,
36
+ eta_2: float = 0.8,
37
+ lam_prox_min: float = 1e-3,
38
+ lam_prox_max: float = 1e4,
39
+ lam_cost_drop: int = -1,
40
+ lam_cost_relax: float = 1.0,
41
+ ):
42
+ self.gamma_1 = gamma_1
43
+ self.gamma_2 = gamma_2
44
+ self.eta_0 = eta_0
45
+ self.eta_1 = eta_1
46
+ self.eta_2 = eta_2
47
+ self.lam_prox_min = lam_prox_min
48
+ self.lam_prox_max = lam_prox_max
49
+ self.lam_cost_drop = lam_cost_drop
50
+ self.lam_cost_relax = lam_cost_relax
51
+
52
+ @staticmethod
53
+ def _copy_virtual_weights(
54
+ candidate: "CandidateIterate",
55
+ state: "AlgorithmState",
56
+ ) -> None:
57
+ candidate.lam_vc = state.lam_vc
58
+ candidate.lam_vb_nodal = state.lam_vb_nodal
59
+ candidate.lam_vb_cross = state.lam_vb_cross
60
+
61
+ def update_weights(
62
+ self,
63
+ state: "AlgorithmState",
64
+ candidate: "CandidateIterate",
65
+ nodal_constraints: "LoweredJaxConstraints",
66
+ settings: Config,
67
+ params: dict,
68
+ weights: "Weights",
69
+ ) -> str:
70
+ """Update SCP proximal weight based on acceptance ratio; keep VC/VB fixed."""
71
+ candidate_x_prop = (
72
+ candidate.x_prop_plus[1:] if candidate.x_prop_plus is not None else candidate.x_prop
73
+ )
74
+ (
75
+ nonlinear_cost,
76
+ nonlinear_penalty,
77
+ nodal_penalty,
78
+ ) = self.calculate_nonlinear_penalty(
79
+ candidate_x_prop,
80
+ candidate.x,
81
+ candidate.u,
82
+ state.lam_vc,
83
+ state.lam_vb_nodal,
84
+ state.lam_vb_cross,
85
+ state.lam_cost,
86
+ nodal_constraints,
87
+ params,
88
+ settings,
89
+ )
90
+
91
+ candidate.J_nonlin = nonlinear_cost + nonlinear_penalty + nodal_penalty
92
+
93
+ if state.k > self.lam_cost_drop:
94
+ candidate.lam_cost = state.lam_cost * self.lam_cost_relax
95
+ else:
96
+ candidate.lam_cost = weights.lam_cost
97
+
98
+ lam_prox_k = deepcopy(state.lam_prox)
99
+
100
+ if state.k > 1:
101
+ state_x_prop_plus = state.x_prop_plus()
102
+ state_x_prop = (
103
+ state_x_prop_plus[1:] if state_x_prop_plus is not None else state.x_prop()
104
+ )
105
+ (
106
+ prev_nonlinear_cost,
107
+ prev_nonlinear_penalty,
108
+ prev_nodal_penalty,
109
+ ) = self.calculate_nonlinear_penalty(
110
+ state_x_prop,
111
+ state.x,
112
+ state.u,
113
+ state.lam_vc,
114
+ state.lam_vb_nodal,
115
+ state.lam_vb_cross,
116
+ state.lam_cost,
117
+ nodal_constraints,
118
+ params,
119
+ settings,
120
+ )
121
+
122
+ J_nonlin_prev = prev_nonlinear_cost + prev_nonlinear_penalty + prev_nodal_penalty
123
+
124
+ actual_reduction = J_nonlin_prev - candidate.J_nonlin
125
+ predicted_reduction = J_nonlin_prev - candidate.J_lin
126
+
127
+ if predicted_reduction == 0:
128
+ raise ValueError("Predicted reduction is 0.")
129
+
130
+ rho = actual_reduction / predicted_reduction
131
+
132
+ state.pred_reduction_history.append(predicted_reduction)
133
+ state.actual_reduction_history.append(actual_reduction)
134
+ state.acceptance_ratio_history.append(rho)
135
+
136
+ if rho < self.eta_0:
137
+ lam_prox_k1 = np.minimum(self.lam_prox_max, self.gamma_1 * lam_prox_k)
138
+ candidate.lam_prox = lam_prox_k1
139
+ state.reject_solution(candidate)
140
+ adaptive_state = "Reject Higher"
141
+ elif rho >= self.eta_0 and rho < self.eta_1:
142
+ lam_prox_k1 = np.minimum(self.lam_prox_max, self.gamma_1 * lam_prox_k)
143
+ candidate.lam_prox = lam_prox_k1
144
+ self._copy_virtual_weights(candidate, state)
145
+ state.accept_solution(candidate)
146
+ adaptive_state = "Accept Higher"
147
+ elif rho >= self.eta_1 and rho < self.eta_2:
148
+ candidate.lam_prox = lam_prox_k
149
+ self._copy_virtual_weights(candidate, state)
150
+ state.accept_solution(candidate)
151
+ adaptive_state = "Accept Constant"
152
+ else:
153
+ lam_prox_k1 = np.maximum(self.lam_prox_min, self.gamma_2 * lam_prox_k)
154
+ candidate.lam_prox = lam_prox_k1
155
+ self._copy_virtual_weights(candidate, state)
156
+ state.accept_solution(candidate)
157
+ adaptive_state = "Accept Lower"
158
+
159
+ else:
160
+ candidate.lam_prox = lam_prox_k
161
+ self._copy_virtual_weights(candidate, state)
162
+ state.accept_solution(candidate)
163
+ adaptive_state = "Initial"
164
+
165
+ return adaptive_state
166
+
167
+
168
+ # =============================================================================
169
+ # Pydantic spec for dict / YAML validation
170
+ # =============================================================================
171
+
172
+
173
+ class AdaptiveProximalWeightSpec(BaseModel):
174
+ """Validates AdaptiveProximalWeight configuration from dict/YAML input."""
175
+
176
+ type: Literal["AdaptiveProximalWeight"] = "AdaptiveProximalWeight"
177
+ gamma_1: float = 2.0
178
+ gamma_2: float = 0.5
179
+ eta_0: float = 1e-2
180
+ eta_1: float = 1e-1
181
+ eta_2: float = 0.8
182
+ lam_prox_min: float = 1e-3
183
+ lam_prox_max: float = 1e4
184
+ lam_cost_drop: int = -1
185
+ lam_cost_relax: float = 1.0
186
+
187
+ model_config = ConfigDict(extra="forbid")
188
+
189
+ def build(self) -> AdaptiveProximalWeight:
190
+ return AdaptiveProximalWeight(**self.model_dump(exclude={"type"}, exclude_unset=True))
@@ -14,13 +14,13 @@ from openscvx.utils.printing import (
14
14
  color_adaptive_state,
15
15
  )
16
16
 
17
- from .base import AutotuningBase
17
+ from ..base import AutotuningBase
18
18
 
19
19
  if TYPE_CHECKING:
20
20
  from openscvx.lowered import LoweredJaxConstraints
21
21
 
22
- from .base import AlgorithmState, CandidateIterate
23
- from .weights import Weights
22
+ from ..base import AlgorithmState, CandidateIterate
23
+ from ..weights import Weights
24
24
 
25
25
 
26
26
  class AugmentedLagrangian(AutotuningBase):
@@ -86,8 +86,12 @@ class AugmentedLagrangian(AutotuningBase):
86
86
  eta_2: Threshold above which solution is accepted with lower weight.
87
87
  Defaults to 0.8.
88
88
  ep: Threshold for virtual control weight update (nu > ep vs nu <= ep).
89
- Defaults to 0.5.
90
- eta_lambda: Step size for virtual control weight update. Defaults to 1e0.
89
+ Must lie in (0, 1). Defaults to 0.99; when tuning, try 1e-1 if needed.
90
+ Typically tuned together with ``eta_lambda`` (often the first
91
+ parameters adjusted).
92
+ eta_lambda: Step size for virtual control weight update. Defaults to 1e1.
93
+ Typically tuned together with ``ep`` (often the first parameters
94
+ adjusted).
91
95
  lam_vc_max: Maximum virtual control penalty weight. Defaults to 1e5.
92
96
  lam_prox_min: Minimum trust region (proximal) weight. Defaults to 1e-3.
93
97
  lam_prox_max: Maximum trust region (proximal) weight. Defaults to 2e5.
@@ -6,13 +6,13 @@ from pydantic import BaseModel, ConfigDict
6
6
 
7
7
  from openscvx.config import Config
8
8
 
9
- from .base import AutotuningBase
9
+ from ..base import AutotuningBase
10
10
 
11
11
  if TYPE_CHECKING:
12
12
  from openscvx.lowered import LoweredJaxConstraints
13
13
 
14
- from .base import AlgorithmState, CandidateIterate
15
- from .weights import Weights
14
+ from ..base import AlgorithmState, CandidateIterate
15
+ from ..weights import Weights
16
16
 
17
17
 
18
18
  class ConstantProximalWeight(AutotuningBase):
@@ -7,13 +7,13 @@ from pydantic import BaseModel, ConfigDict
7
7
 
8
8
  from openscvx.config import Config
9
9
 
10
- from .base import AutotuningBase
10
+ from ..base import AutotuningBase
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from openscvx.lowered import LoweredJaxConstraints
14
14
 
15
- from .base import AlgorithmState, CandidateIterate
16
- from .weights import Weights
15
+ from ..base import AlgorithmState, CandidateIterate
16
+ from ..weights import Weights
17
17
 
18
18
 
19
19
  class RampProximalWeight(AutotuningBase):
@@ -0,0 +1,5 @@
1
+ """Successive convexification algorithm implementations."""
2
+
3
+ from .penalized_trust_region import PenalizedTrustRegion
4
+
5
+ __all__ = ["PenalizedTrustRegion"]
@@ -6,8 +6,9 @@ optimization problems through iterative convex approximation.
6
6
 
7
7
  import time
8
8
  import warnings
9
- from typing import TYPE_CHECKING, Callable, Dict, List, Union
9
+ from typing import TYPE_CHECKING, Callable, Dict, List, Tuple, Union
10
10
 
11
+ import jax
11
12
  import numpy as np
12
13
  import numpy.linalg as la
13
14
 
@@ -21,11 +22,11 @@ from openscvx.utils.printing import (
21
22
  color_prob_stat,
22
23
  )
23
24
 
24
- from .augmented_lagrangian import AugmentedLagrangian
25
- from .base import Algorithm, AlgorithmState, CandidateIterate
26
- from .constant_proximal_weight import ConstantProximalWeight
27
- from .ramp_proximal_weight import RampProximalWeight
28
- from .weights import Weights
25
+ from ..autotuner.augmented_lagrangian import AugmentedLagrangian
26
+ from ..autotuner.constant_proximal_weight import ConstantProximalWeight
27
+ from ..autotuner.ramp_proximal_weight import RampProximalWeight
28
+ from ..base import Algorithm, AlgorithmState, CandidateIterate
29
+ from ..weights import Weights
29
30
 
30
31
  if TYPE_CHECKING:
31
32
  from openscvx.lowered import LoweredJaxConstraints
@@ -33,7 +34,7 @@ if TYPE_CHECKING:
33
34
  from openscvx.symbolic.expr.control import Control
34
35
  from openscvx.symbolic.expr.state import State
35
36
 
36
- from .base import AutotuningBase
37
+ from ..base import AutotuningBase
37
38
 
38
39
  warnings.filterwarnings("ignore")
39
40
 
@@ -188,6 +189,11 @@ class PenalizedTrustRegion(Algorithm):
188
189
  return solver.call(*args)
189
190
  return solver(*args)
190
191
 
192
+ @staticmethod
193
+ def _block_until_ready_outputs(outputs: Tuple[object, ...]) -> None:
194
+ """Finish any pending XLA work from discretization exports (warm-up helper)."""
195
+ jax.block_until_ready(outputs)
196
+
191
197
  def _recover_prior_node_from_initial(
192
198
  self,
193
199
  settings: Config,
@@ -286,7 +292,8 @@ class PenalizedTrustRegion(Algorithm):
286
292
  """Initialize PTR algorithm.
287
293
 
288
294
  Stores compiled infrastructure and performs a warm-start solve to
289
- initialize DPP and JAX jacobians.
295
+ initialize DPP and JAX jacobians. Also runs post-subproblem discretization
296
+ on the throwaway CVX solution so XLA/export caches match the first ``step()``.
290
297
 
291
298
  Args:
292
299
  solver: Convex subproblem solver (e.g., CVXPySolver)
@@ -333,7 +340,29 @@ class PenalizedTrustRegion(Algorithm):
333
340
  params,
334
341
  )
335
342
  init_state.add_impulsive_discretization(W_multi_shoot.__array__())
336
- _ = self._subproblem(params, init_state, settings)
343
+ (x_sol, u_sol, *_) = self._subproblem(params, init_state, settings)
344
+
345
+ # Prime the same exported discretization calls used after every subproblem in
346
+ # step() (candidate trajectory). initialize() previously only discretized the
347
+ # initial guess, so the first step() in solve() could still hit an XLA cache_miss
348
+ # on post-CVX (x_sol, u_sol). Running that path here moves compilation into init.
349
+ cont_out = self._invoke_solver(
350
+ self._discretization_solver, x_sol, u_sol.astype(float), params
351
+ )
352
+ x_prop_c = cont_out[3]
353
+ u_candidate = u_sol.astype(float)
354
+ x0_prior_c = self._recover_prior_node_from_initial(settings, x_sol[0])
355
+ x_nodes_prior_c = np.vstack((x0_prior_c, np.asarray(x_prop_c)))
356
+ if self._discretization_solver_impulsive is not None:
357
+ imp_out = self._invoke_solver(
358
+ self._discretization_solver_impulsive,
359
+ x_nodes_prior_c,
360
+ u_candidate,
361
+ params,
362
+ )
363
+ self._block_until_ready_outputs(cont_out + imp_out)
364
+ else:
365
+ self._block_until_ready_outputs(cont_out)
337
366
 
338
367
  def step(
339
368
  self,
@@ -188,7 +188,12 @@ class VectorizeDiscretizeLinearize(Discretizer):
188
188
  in_axes=-1,
189
189
  out_axes=-1,
190
190
  )
191
- jacobians = pushforward(standard_basis)
191
+ # Cast to match primal's dtype. standard_basis is built eagerly
192
+ # from jnp.eye (float32 when x64 is disabled), but jax.export
193
+ # traces with the literal dtype of the dummy arrays (float64 from
194
+ # np.ones), so primal may be float64 while the closure-captured
195
+ # standard_basis is float32. jax.jvp requires identical dtypes.
196
+ jacobians = pushforward(standard_basis.astype(primal.dtype))
192
197
 
193
198
  A_d = jacobians[:, :, :, i0:i1]
194
199
  B_d = jacobians[:, :, :, i1:i2]
@@ -125,7 +125,12 @@ def _sparse_jacobian_fn(
125
125
  _, jvp_out = jax.jvp(f_of_target, (primals[argnums],), (tangent,))
126
126
  return jvp_out
127
127
 
128
- compressed = jax.vmap(single_jvp)(seeds) # (n_colors, out_dim)
128
+ # Cast seeds to match the primal dtype. seeds is built eagerly as
129
+ # float32 (x64 disabled), but jax.export traces with the literal
130
+ # dtype of the dummy arrays (float64 from np.ones), so primals may
131
+ # be float64. jax.jvp requires matching dtypes.
132
+ typed_seeds = seeds.astype(primals[argnums].dtype)
133
+ compressed = jax.vmap(single_jvp)(typed_seeds) # (n_colors, out_dim)
129
134
 
130
135
  values = compressed[scatter_color, scatter_row]
131
136
  jac = jnp.zeros((out_dim, in_dim))
@@ -1,47 +1,46 @@
1
- """Adapters for MuJoCo MJX dynamics in OpenSCvx BYOF.
1
+ """External-backend dynamics adapters for OpenSCvx.
2
2
 
3
- The recommended entry-point is :func:`mjx_byof`, which returns a complete
4
- ``byof["dynamics"]`` dict and automatically handles free-joint quaternion
5
- kinematics for floating-base models (drones, humanoids, etc.):
3
+ The recommended entry-point is `MjxDynamics`, which goes directly into the
4
+ ``dynamics=`` slot of `Problem` and constructs the matching State/Control
5
+ objects for the user. Free-joint quaternion kinematics for floating-base
6
+ models (drones, humanoids) are detected and handled automatically::
6
7
 
7
- from openscvx.integrations import mjx_byof
8
+ from openscvx.integrations import MjxDynamics
8
9
 
9
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
10
+ dyn = MjxDynamics(mjx_model)
11
+ problem = ox.Problem(
12
+ dynamics=dyn,
13
+ states=dyn.states,
14
+ controls=dyn.controls,
15
+ ...
16
+ )
10
17
 
11
- For models without free joints (cartpoles, manipulators) the returned dict
12
- contains only ``"qvel"`` and ``dynamics={"qpos": qvel}`` should still be
13
- provided to :class:`~openscvx.Problem`. For models with free joints
14
- (``nq > nv``) ``"qpos"`` is included automatically — no extra imports needed.
18
+ For advanced users who need custom State/Control names (or to interleave
19
+ them with extra custom states), `mjx_dynamics` is exposed as the underlying
20
+ BYOF callable factory assemble your own ``byof["dynamics"]`` dict from it.
15
21
 
16
- :func:`mjx_dynamics` is also available for advanced users who need direct
17
- access to the BYOF callable for the ``qvel`` (acceleration) derivative.
22
+ All MJX symbols delegate lazily so ``mujoco.mjx`` is only imported when
23
+ actually used. The ``menagerie`` submodule is also loaded lazily.
18
24
 
19
- All symbols delegate lazily so ``mujoco.mjx`` is only imported when used.
20
- The :mod:`menagerie` submodule is loaded lazily via attribute access.
25
+ Example cartpole (``nq == nv``)::
21
26
 
22
- Example cartpole (nq == nv)::
27
+ from openscvx.integrations import MjxDynamics
23
28
 
24
- from openscvx.integrations import mjx_byof
29
+ dyn = MjxDynamics(mjx_model)
30
+ problem = ox.Problem(dynamics=dyn, states=dyn.states, controls=dyn.controls, ...)
25
31
 
26
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
27
- problem = ox.Problem(dynamics={"qpos": qvel}, byof=byof, ...)
32
+ Example quadrotor with free joint (``nq=7``, ``nv=6``)::
28
33
 
29
- Example quadrotor with free joint (nq=7, nv=6)::
34
+ from openscvx.integrations import MjxDynamics
30
35
 
31
- from openscvx.integrations import mjx_byof
32
-
33
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
34
- problem = ox.Problem(dynamics={}, byof=byof, ...)
36
+ dyn = MjxDynamics(mjx_model)
37
+ problem = ox.Problem(dynamics=dyn, states=dyn.states, controls=dyn.controls, ...)
35
38
  """
36
39
 
37
40
  from typing import Any
38
41
 
39
-
40
- def mjx_byof(*args: Any, **kwargs: Any) -> Any:
41
- """Lazy delegate; imports ``mujoco.mjx`` on first call."""
42
- from .mjx import mjx_byof as _mjx_byof
43
-
44
- return _mjx_byof(*args, **kwargs)
42
+ from .base import DynamicsAdapter
43
+ from .mjx import MjxDynamics
45
44
 
46
45
 
47
46
  def mjx_dynamics(*args: Any, **kwargs: Any) -> Any:
@@ -59,4 +58,9 @@ def __getattr__(name: str) -> Any:
59
58
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
60
59
 
61
60
 
62
- __all__ = ["mjx_byof", "mjx_dynamics", "menagerie"]
61
+ __all__ = [
62
+ "DynamicsAdapter",
63
+ "MjxDynamics",
64
+ "mjx_dynamics",
65
+ "menagerie",
66
+ ]