openscvx 2.dev4__py3-none-any.whl → 2.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
openscvx/__init__.py CHANGED
@@ -25,6 +25,7 @@ from openscvx.discretization import (
25
25
  VectorizeDiscretizeLinearize,
26
26
  )
27
27
  from openscvx.expert import ByofSpec
28
+ from openscvx.integrations import DynamicsAdapter, MjxDynamics
28
29
  from openscvx.loader import load_dict, load_json, load_yaml
29
30
  from openscvx.problem import Problem
30
31
  from openscvx.solvers import PTRSolver
@@ -176,6 +177,9 @@ __all__ = [
176
177
  "lie",
177
178
  # Expert mode types
178
179
  "ByofSpec",
180
+ # External-backend dynamics adapters
181
+ "DynamicsAdapter",
182
+ "MjxDynamics",
179
183
  # Discretization
180
184
  "DiscretizeLinearizeVectorize",
181
185
  "LinearizeDiscretize",
openscvx/_version.py CHANGED
@@ -18,7 +18,7 @@ version_tuple: tuple[int | str, ...]
18
18
  commit_id: str | None
19
19
  __commit_id__: str | None
20
20
 
21
- __version__ = version = '2.dev4'
22
- __version_tuple__ = version_tuple = (2, 'dev4')
21
+ __version__ = version = '2.dev5'
22
+ __version_tuple__ = version_tuple = (2, 'dev5')
23
23
 
24
24
  __commit_id__ = commit_id = None
@@ -1,47 +1,46 @@
1
- """Adapters for MuJoCo MJX dynamics in OpenSCvx BYOF.
1
+ """External-backend dynamics adapters for OpenSCvx.
2
2
 
3
- The recommended entry-point is :func:`mjx_byof`, which returns a complete
4
- ``byof["dynamics"]`` dict and automatically handles free-joint quaternion
5
- kinematics for floating-base models (drones, humanoids, etc.):
3
+ The recommended entry-point is `MjxDynamics`, which goes directly into the
4
+ ``dynamics=`` slot of `Problem` and constructs the matching State/Control
5
+ objects for the user. Free-joint quaternion kinematics for floating-base
6
+ models (drones, humanoids) are detected and handled automatically::
6
7
 
7
- from openscvx.integrations import mjx_byof
8
+ from openscvx.integrations import MjxDynamics
8
9
 
9
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
10
+ dyn = MjxDynamics(mjx_model)
11
+ problem = ox.Problem(
12
+ dynamics=dyn,
13
+ states=dyn.states,
14
+ controls=dyn.controls,
15
+ ...
16
+ )
10
17
 
11
- For models without free joints (cartpoles, manipulators) the returned dict
12
- contains only ``"qvel"`` and ``dynamics={"qpos": qvel}`` should still be
13
- provided to :class:`~openscvx.Problem`. For models with free joints
14
- (``nq > nv``) ``"qpos"`` is included automatically — no extra imports needed.
18
+ For advanced users who need custom State/Control names (or to interleave
19
+ them with extra custom states), `mjx_dynamics` is exposed as the underlying
20
+ BYOF callable factory — assemble your own ``byof["dynamics"]`` dict from it.
15
21
 
16
- :func:`mjx_dynamics` is also available for advanced users who need direct
17
- access to the BYOF callable for the ``qvel`` (acceleration) derivative.
22
+ All MJX symbols delegate lazily so ``mujoco.mjx`` is only imported when
23
+ actually used. The ``menagerie`` submodule is also loaded lazily.
18
24
 
19
- All symbols delegate lazily so ``mujoco.mjx`` is only imported when used.
20
- The :mod:`menagerie` submodule is loaded lazily via attribute access.
25
+ Example cartpole (``nq == nv``)::
21
26
 
22
- Example cartpole (nq == nv)::
27
+ from openscvx.integrations import MjxDynamics
23
28
 
24
- from openscvx.integrations import mjx_byof
29
+ dyn = MjxDynamics(mjx_model)
30
+ problem = ox.Problem(dynamics=dyn, states=dyn.states, controls=dyn.controls, ...)
25
31
 
26
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
27
- problem = ox.Problem(dynamics={"qpos": qvel}, byof=byof, ...)
32
+ Example quadrotor with free joint (``nq=7``, ``nv=6``)::
28
33
 
29
- Example quadrotor with free joint (nq=7, nv=6)::
34
+ from openscvx.integrations import MjxDynamics
30
35
 
31
- from openscvx.integrations import mjx_byof
32
-
33
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
34
- problem = ox.Problem(dynamics={}, byof=byof, ...)
36
+ dyn = MjxDynamics(mjx_model)
37
+ problem = ox.Problem(dynamics=dyn, states=dyn.states, controls=dyn.controls, ...)
35
38
  """
36
39
 
37
40
  from typing import Any
38
41
 
39
-
40
- def mjx_byof(*args: Any, **kwargs: Any) -> Any:
41
- """Lazy delegate; imports ``mujoco.mjx`` on first call."""
42
- from .mjx import mjx_byof as _mjx_byof
43
-
44
- return _mjx_byof(*args, **kwargs)
42
+ from .base import DynamicsAdapter
43
+ from .mjx import MjxDynamics
45
44
 
46
45
 
47
46
  def mjx_dynamics(*args: Any, **kwargs: Any) -> Any:
@@ -59,4 +58,9 @@ def __getattr__(name: str) -> Any:
59
58
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
60
59
 
61
60
 
62
- __all__ = ["mjx_byof", "mjx_dynamics", "menagerie"]
61
+ __all__ = [
62
+ "DynamicsAdapter",
63
+ "MjxDynamics",
64
+ "mjx_dynamics",
65
+ "menagerie",
66
+ ]
@@ -0,0 +1,89 @@
1
+ """Base class for external-backend dynamics adapters.
2
+
3
+ A `DynamicsAdapter` is the easy on-ramp for users who want to plug an
4
+ external physics backend (MuJoCo MJX, Brax, Drake, ...) into OpenSCvx without
5
+ manually constructing State/Control objects with matching shapes or routing
6
+ raw JAX callables through the expert ``byof`` channel.
7
+
8
+ The intended call site is::
9
+
10
+ dyn = ox.MjxDynamics(mjx_model)
11
+ problem = ox.Problem(
12
+ dynamics=dyn,
13
+ states=dyn.states,
14
+ controls=dyn.controls,
15
+ ...
16
+ )
17
+
18
+ Internally, `Problem` detects the adapter, calls `DynamicsAdapter.expand`,
19
+ and merges the resulting BYOF callables into the user's ``byof`` dict (if
20
+ any). Everything downstream sees ordinary ``dynamics`` and ``byof`` dicts.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import copy
26
+ from abc import ABC, abstractmethod
27
+ from typing import TYPE_CHECKING, Tuple
28
+
29
+ if TYPE_CHECKING:
30
+ from openscvx.symbolic.expr.control import Control
31
+ from openscvx.symbolic.expr.state import State
32
+
33
+
34
+ class DynamicsAdapter(ABC):
35
+ """Abstract base class for external-backend dynamics adapters.
36
+
37
+ Subclasses describe the State/Control objects they synthesize on
38
+ ``.states`` / ``.controls`` and implement `expand` to return the
39
+ two-channel ``(dynamics_dict, byof_dict)`` representation consumed by
40
+ `Problem`.
41
+
42
+ The split mirrors the existing two-channel API: ``dynamics_dict`` carries
43
+ symbolic Expr entries (e.g. ``{"qpos": qvel}``) while ``byof_dict`` carries
44
+ raw JAX callables under the ``"dynamics"`` key. Either or both may be
45
+ empty, but ``expand()`` should never silently produce overlapping keys.
46
+ """
47
+
48
+ states: list["State"]
49
+ controls: list["Control"]
50
+
51
+ @abstractmethod
52
+ def expand(self) -> Tuple[dict, dict]:
53
+ """Return ``(dynamics_dict, byof_dict)`` in OpenSCvx's internal form.
54
+
55
+ ``dynamics_dict`` maps state names to symbolic ``Expr`` derivatives
56
+ (the same shape as the ``dynamics=`` argument to ``Problem``).
57
+ ``byof_dict`` has the same shape as the ``byof=`` argument: its
58
+ ``"dynamics"`` key (if present) maps state names to raw JAX callables.
59
+ """
60
+
61
+
62
+ def _merge_byof(user_byof: dict | None, extra_byof: dict) -> dict:
63
+ """Merge an adapter-synthesized BYOF dict into a user-provided one.
64
+
65
+ Only the ``"dynamics"`` sub-dict is deep-merged; other keys are taken
66
+ verbatim from whichever side provides them. Raises ``ValueError`` on any
67
+ key collision under ``"dynamics"`` — a user passing both
68
+ ``dynamics=ox.MjxDynamics(...)`` and ``byof={"dynamics": {"qvel": ...}}``
69
+ almost certainly has a bug, and silent override would mask it.
70
+ """
71
+ if not user_byof:
72
+ return copy.copy(extra_byof)
73
+
74
+ merged = dict(user_byof)
75
+ extra_dyn = extra_byof.get("dynamics", {})
76
+ user_dyn = user_byof.get("dynamics", {})
77
+
78
+ if extra_dyn:
79
+ collisions = set(user_dyn) & set(extra_dyn)
80
+ if collisions:
81
+ raise ValueError(
82
+ "DynamicsAdapter produced byof['dynamics'] entries that "
83
+ f"collide with user-provided byof['dynamics']: {sorted(collisions)}. "
84
+ "Drop the duplicate keys from your byof dict, or drop the adapter "
85
+ "and assemble byof['dynamics'] manually for full control."
86
+ )
87
+ merged["dynamics"] = {**user_dyn, **extra_dyn}
88
+
89
+ return merged
@@ -1,29 +1,35 @@
1
- """MuJoCo MJX dynamics adapters for OpenSCvx BYOF.
1
+ """MuJoCo MJX dynamics adapters for OpenSCvx.
2
2
 
3
- The recommended entry-point is :func:`mjx_byof`, which returns a complete
4
- ``byof["dynamics"]`` dict and automatically handles free-joint quaternion
5
- kinematics no separate imports required:
3
+ The recommended entry-point is `MjxDynamics`, a `DynamicsAdapter` that goes
4
+ directly into the ``dynamics=`` slot of `Problem` and exposes the synthesized
5
+ State/Control objects on ``.states`` / ``.controls``::
6
6
 
7
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
7
+ dyn = ox.MjxDynamics(mjx_model)
8
+ problem = ox.Problem(
9
+ dynamics=dyn,
10
+ states=dyn.states,
11
+ controls=dyn.controls,
12
+ ...
13
+ )
8
14
 
9
- For models **without** free joints (cartpoles, manipulators, etc.) the
10
- returned dict contains only ``"qvel"``, and qpos kinematics must still be
11
- specified symbolically via ``dynamics={"qpos": qvel}``. For models **with**
12
- free joints (drones, humanoids) ``"qpos"`` is included automatically and no
13
- symbolic dynamics entry is needed.
15
+ Free-joint quaternion kinematics (``nq > nv`` models such as drones or
16
+ humanoids) are detected and handled automatically.
14
17
 
15
- The lower-level :func:`mjx_dynamics` is also public for advanced users who
16
- need direct access to the BYOF callable for the ``qvel`` derivative.
18
+ The lower-level `mjx_dynamics` callable factory is also public for advanced
19
+ users who need to assemble their own BYOF dynamics dict (e.g. with custom
20
+ State/Control names or interleaved with other states).
17
21
 
18
22
  Note:
19
23
  Time dilation is handled automatically by the BYOF lowering pipeline; all
20
24
  functions return physical (un-dilated) quantities.
21
25
  """
22
26
 
23
- from typing import TYPE_CHECKING, Any, Callable, Optional
27
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
24
28
 
25
29
  import jax.numpy as jnp
26
30
 
31
+ from openscvx.integrations.base import DynamicsAdapter
32
+
27
33
  if TYPE_CHECKING:
28
34
  from openscvx.symbolic.expr.control import Control
29
35
  from openscvx.symbolic.expr.state import State
@@ -164,7 +170,7 @@ def _free_joint_qpos_dynamics(
164
170
  ) -> Callable:
165
171
  """BYOF callable for ``qpos`` when the model has quaternion free joints.
166
172
 
167
- Used internally by :func:`mjx_byof`. When a MuJoCo model has a
173
+ Used internally by `MjxDynamics`. When a MuJoCo model has a
168
174
  floating-base free joint, ``nq > nv`` because each quaternion orientation
169
175
  contributes 4 position DOF but only 3 angular velocity DOF. The simple
170
176
  symbolic shorthand ``"qpos": qvel`` therefore fails a shape check. This
@@ -246,78 +252,245 @@ def _free_joint_qpos_dynamics(
246
252
  return f
247
253
 
248
254
 
249
- def mjx_byof(
250
- mjx_model: Any,
251
- *,
252
- qpos: "State | slice",
253
- qvel: "State | slice",
254
- ctrl: "Control | slice",
255
- return_component: str = "qacc",
256
- extra_postprocess: Optional[Callable[[Any], Any]] = None,
257
- ) -> dict:
258
- """Return a complete ``byof["dynamics"]`` dict for a MuJoCo MJX model.
255
+ # MuJoCo joint type enum (matches mujoco.mjtJoint): 0=free, 1=ball, 2=slide, 3=hinge.
256
+ # Inlined here so we don't require `mujoco` to be importable just for type validation —
257
+ # the user already needs mujoco to have constructed the mjx_model, but keeping the
258
+ # numeric constants local makes this file self-contained.
259
+ _MJ_JNT_FREE = 0
260
+ _MJ_JNT_BALL = 1
261
+ _MJ_JNT_SLIDE = 2
262
+ _MJ_JNT_HINGE = 3
259
263
 
260
- This is the recommended high-level entry-point. It inspects the model's
261
- ``nq`` and ``nv`` to detect free joints and automatically includes the
262
- quaternion kinematics callable for ``qpos`` when needed.
263
264
 
264
- Args:
265
- mjx_model: A model produced by :func:`mujoco.mjx.put_model`.
266
- qpos: Position state (or slice). Length must equal ``mjx_model.nq``.
267
- qvel: Velocity state (or slice). Length must equal ``mjx_model.nv``.
268
- ctrl: Control variable (or slice). Length must equal ``mjx_model.nu``.
269
- return_component: Passed to :func:`mjx_dynamics`. ``"qacc"``
270
- (default) uses the generalized acceleration as the ``qvel``
271
- derivative; ``"qvel"`` returns qvel directly (rarely needed).
272
- extra_postprocess: Optional callable applied to the MJX ``data``
273
- object after ``mjx.forward``. Passed through to
274
- :func:`mjx_dynamics`.
265
+ def _initial_bounds_from_model(
266
+ mjx_model: Any, nq: int, nv: int, nu: int
267
+ ) -> Tuple[Any, Any, Any, Any, Any, Any]:
268
+ """Pull qpos / ctrl bounds out of the MJX model.
275
269
 
276
- Returns:
277
- A dict suitable for use as ``byof["dynamics"]``.
278
- For models **without** free joints (``nq == nv``) only ``"qvel"`` is
279
- included; position kinematics should still be provided symbolically
280
- via ``dynamics={"qpos": qvel}``.
281
- For models **with** free joints (``nq > nv``) both ``"qpos"`` and
282
- ``"qvel"`` are included and no symbolic ``dynamics`` entry is needed.
270
+ Returns ``(qpos_min, qpos_max, qvel_min, qvel_max, ctrl_min, ctrl_max)``.
271
+
272
+ - ``qpos`` bounds come from ``mjx_model.jnt_range`` for slide/hinge joints
273
+ flagged ``jnt_limited=True``. All other qpos slots free-joint
274
+ translations and quaternion components, unlimited slide/hinge joints —
275
+ get ``±inf``.
276
+ - ``ctrl`` bounds come from ``mjx_model.actuator_ctrlrange`` for actuators
277
+ flagged ``actuator_ctrllimited=True``. Unlimited actuators get ``±inf``.
278
+ - ``qvel`` bounds are always ``±inf`` because MuJoCo has no per-joint
279
+ velocity-limit concept; users override as needed.
280
+ """
281
+ import numpy as _np
282
+
283
+ qpos_min = _np.full(nq, -_np.inf)
284
+ qpos_max = _np.full(nq, _np.inf)
285
+ qvel_min = _np.full(nv, -_np.inf)
286
+ qvel_max = _np.full(nv, _np.inf)
287
+ ctrl_min = _np.full(nu, -_np.inf)
288
+ ctrl_max = _np.full(nu, _np.inf)
289
+
290
+ # Per-joint qpos bounds — only slide/hinge can be range-limited; free
291
+ # joints always have jnt_limited=False so we skip them safely.
292
+ jnt_type = _np.asarray(mjx_model.jnt_type).astype(int)
293
+ jnt_qposadr = _np.asarray(mjx_model.jnt_qposadr).astype(int)
294
+ jnt_limited = _np.asarray(mjx_model.jnt_limited).astype(bool)
295
+ jnt_range = _np.asarray(mjx_model.jnt_range).astype(float)
296
+ for i, jtype in enumerate(jnt_type):
297
+ if jtype in (_MJ_JNT_SLIDE, _MJ_JNT_HINGE) and jnt_limited[i]:
298
+ adr = int(jnt_qposadr[i])
299
+ qpos_min[adr] = jnt_range[i, 0]
300
+ qpos_max[adr] = jnt_range[i, 1]
301
+
302
+ if nu > 0:
303
+ act_limited = _np.asarray(mjx_model.actuator_ctrllimited).astype(bool)
304
+ act_range = _np.asarray(mjx_model.actuator_ctrlrange).astype(float)
305
+ for i in range(nu):
306
+ if act_limited[i]:
307
+ ctrl_min[i] = act_range[i, 0]
308
+ ctrl_max[i] = act_range[i, 1]
309
+
310
+ return qpos_min, qpos_max, qvel_min, qvel_max, ctrl_min, ctrl_max
311
+
312
+
313
+ def _validate_supported_joints(mjx_model: Any) -> None:
314
+ """Refuse models whose joint layout the adapter cannot correctly handle.
315
+
316
+ `MjxDynamics` only supports models composed of free / slide / hinge joints
317
+ where all free joints precede the others in the state layout. Anything else
318
+ (ball joints, custom joint orderings) silently breaks the
319
+ `_free_joint_qpos_dynamics` arithmetic, so we refuse with a clear error
320
+ rather than producing wrong dynamics.
321
+ """
322
+ import numpy as _np
323
+
324
+ jnt_type = _np.asarray(mjx_model.jnt_type).astype(int)
325
+ supported = {_MJ_JNT_FREE, _MJ_JNT_SLIDE, _MJ_JNT_HINGE}
326
+ bad = sorted(set(jnt_type.tolist()) - supported)
327
+ if bad:
328
+ if _MJ_JNT_BALL in bad:
329
+ raise NotImplementedError(
330
+ "MjxDynamics does not support ball joints (mjJNT_BALL): they "
331
+ "share nq=4, nv=3 with free joints but use different "
332
+ "kinematics, and the current quaternion-kinematics callable "
333
+ "would silently produce wrong dynamics. Use `mjx_dynamics` "
334
+ "directly and assemble byof['dynamics'] manually."
335
+ )
336
+ raise NotImplementedError(
337
+ f"MjxDynamics only supports free, slide, and hinge joints; "
338
+ f"model contains unsupported joint types {bad}. Use "
339
+ "`mjx_dynamics` directly and assemble byof['dynamics'] manually."
340
+ )
341
+
342
+ # _free_joint_qpos_dynamics assumes all free joints come first in the
343
+ # state vector. If a slide/hinge precedes a free joint, the quaternion
344
+ # offsets would be off.
345
+ free_mask = jnt_type == _MJ_JNT_FREE
346
+ n_free = int(free_mask.sum())
347
+ if n_free and not free_mask[:n_free].all():
348
+ raise NotImplementedError(
349
+ "MjxDynamics requires all free joints to come before any "
350
+ "slide/hinge joints in the MuJoCo model. Reorder the joints in "
351
+ "your MJCF/URDF, or use `mjx_dynamics` directly to assemble "
352
+ "byof['dynamics'] yourself."
353
+ )
354
+
355
+
356
+ class MjxDynamics(DynamicsAdapter):
357
+ """First-class MJX dynamics adapter for `Problem`.
358
+
359
+ Wraps a ``mujoco.mjx`` model so it can be passed directly to the
360
+ ``dynamics=`` argument of `Problem`. The adapter
361
+ constructs default ``qpos`` / ``qvel`` State objects and a ``ctrl``
362
+ Control matching the model's ``nq`` / ``nv`` / ``nu``, exposes them via
363
+ ``.states`` / ``.controls``, and routes the MJX forward dynamics through
364
+ the BYOF channel internally — without requiring the user to know about
365
+ BYOF at all.
283
366
 
284
367
  Example:
285
- Cartpole (nq == nv, no free joint)::
368
+ Cartpole (``nq == nv``)::
369
+
370
+ mj_model = mujoco.MjModel.from_xml_path("cartpole.xml")
371
+ mj_model.opt.disableflags |= mujoco.mjtDisableBit.mjDSBL_CONTACT
372
+ mjx_model = mjx.put_model(mj_model)
286
373
 
287
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
374
+ dyn = ox.MjxDynamics(mjx_model)
288
375
  problem = ox.Problem(
289
- dynamics={"qpos": qvel}, # still required for non-free models
290
- byof=byof, ...
376
+ dynamics=dyn,
377
+ states=dyn.states,
378
+ controls=dyn.controls,
379
+ ...
291
380
  )
292
381
 
293
- Quadrotor / drone (nq > nv, one free joint)::
382
+ Quadrotor with a free joint (``nq > nv``) — quaternion kinematics
383
+ are inserted automatically::
294
384
 
295
- byof = {"dynamics": mjx_byof(mjx_model, qpos=qpos, qvel=qvel, ctrl=ctrl)}
385
+ dyn = ox.MjxDynamics(mjx_model) # nq=7, nv=6
296
386
  problem = ox.Problem(
297
- dynamics={}, # qpos handled automatically
298
- byof=byof, ...
387
+ dynamics=dyn, states=dyn.states, controls=dyn.controls, ...
299
388
  )
389
+
390
+ Custom State/Control names or shapes are *not* supported here on purpose
391
+ — the whole point of the adapter is "I don't want to think about names."
392
+ Drop to the lower-level `mjx_dynamics` helper if you need that control —
393
+ construct your own State/Control objects, pass them in, and assemble the
394
+ BYOF dict yourself.
395
+
396
+ Supported joint structure:
397
+ * Free (``mjJNT_FREE``), slide (``mjJNT_SLIDE``), and hinge
398
+ (``mjJNT_HINGE``) joints only.
399
+ * If the model contains any free joints, they must all come
400
+ *before* any slide/hinge joints in the MuJoCo layout.
401
+ * Ball joints (``mjJNT_BALL``) are explicitly refused — they share
402
+ ``nq=4, nv=3`` with free joints but use different kinematics, and
403
+ would silently produce wrong dynamics.
404
+
405
+ Construction raises ``NotImplementedError`` if any of these
406
+ conditions are violated; fall back to `mjx_dynamics` for those
407
+ cases.
408
+
409
+ Auto-populated bounds:
410
+ * ``qpos.min`` / ``qpos.max`` are read from ``mjx_model.jnt_range``
411
+ for slide / hinge joints flagged ``jnt_limited=True``; free-joint
412
+ slots and unlimited joints default to ``±inf``.
413
+ * ``ctrl.min`` / ``ctrl.max`` are read from ``actuator_ctrlrange``
414
+ for actuators flagged ``actuator_ctrllimited=True``; otherwise
415
+ ``±inf``.
416
+ * ``qvel`` bounds default to ``±inf`` (MuJoCo has no per-joint
417
+ velocity-limit concept).
418
+
419
+ Override any of these after construction if you want tighter
420
+ problem-specific bounds.
300
421
  """
301
- nq = int(mjx_model.nq)
302
- nv = int(mjx_model.nv)
303
-
304
- result: dict = {
305
- "qvel": mjx_dynamics(
306
- mjx_model,
307
- qpos=qpos,
308
- qvel=qvel,
309
- ctrl=ctrl,
310
- return_component=return_component,
311
- extra_postprocess=extra_postprocess,
312
- ),
313
- }
314
-
315
- n_free = nq - nv # each free joint contributes exactly 1 extra position DOF
316
- if n_free > 0:
317
- result["qpos"] = _free_joint_qpos_dynamics(
318
- qpos=qpos,
319
- qvel=qvel,
320
- n_free_joints=n_free,
422
+
423
+ def __init__(
424
+ self,
425
+ mjx_model: Any,
426
+ *,
427
+ return_component: str = "qacc",
428
+ extra_postprocess: Optional[Callable[[Any], Any]] = None,
429
+ ) -> None:
430
+ from openscvx.symbolic.expr.control import Control
431
+ from openscvx.symbolic.expr.state import State
432
+
433
+ _validate_supported_joints(mjx_model)
434
+
435
+ self.mjx_model = mjx_model
436
+ self.return_component = return_component
437
+ self.extra_postprocess = extra_postprocess
438
+
439
+ nq = int(mjx_model.nq)
440
+ nv = int(mjx_model.nv)
441
+ nu = int(mjx_model.nu)
442
+
443
+ self._qpos = State("qpos", shape=(nq,))
444
+ self._qvel = State("qvel", shape=(nv,))
445
+ self._ctrl = Control("ctrl", shape=(nu,))
446
+
447
+ # Auto-populate bounds from the model so the user doesn't have to
448
+ # re-type joint / actuator limits already declared in MJCF. Users
449
+ # can still override any of these after construction.
450
+ qpos_min, qpos_max, qvel_min, qvel_max, ctrl_min, ctrl_max = _initial_bounds_from_model(
451
+ mjx_model, nq, nv, nu
321
452
  )
453
+ self._qpos.min = qpos_min
454
+ self._qpos.max = qpos_max
455
+ self._qvel.min = qvel_min
456
+ self._qvel.max = qvel_max
457
+ self._ctrl.min = ctrl_min
458
+ self._ctrl.max = ctrl_max
459
+
460
+ self.states: list[State] = [self._qpos, self._qvel]
461
+ self.controls: list[Control] = [self._ctrl]
462
+
463
+ def expand(self) -> Tuple[dict, dict]:
464
+ """Return ``(dynamics_dict, byof_dict)`` for this MJX model.
465
+
466
+ - ``nq == nv``: ``dynamics_dict = {"qpos": qvel}`` (symbolic
467
+ kinematic identity), ``byof_dict["dynamics"] = {"qvel": ...}``.
468
+ - ``nq > nv``: ``dynamics_dict = {}``, ``byof_dict["dynamics"]``
469
+ contains both ``"qpos"`` (quaternion kinematics) and ``"qvel"``.
470
+ """
471
+ nq = int(self.mjx_model.nq)
472
+ nv = int(self.mjx_model.nv)
473
+
474
+ byof_dynamics: dict = {
475
+ "qvel": mjx_dynamics(
476
+ self.mjx_model,
477
+ qpos=self._qpos,
478
+ qvel=self._qvel,
479
+ ctrl=self._ctrl,
480
+ return_component=self.return_component,
481
+ extra_postprocess=self.extra_postprocess,
482
+ ),
483
+ }
484
+
485
+ n_free = nq - nv
486
+ if n_free > 0:
487
+ byof_dynamics["qpos"] = _free_joint_qpos_dynamics(
488
+ qpos=self._qpos,
489
+ qvel=self._qvel,
490
+ n_free_joints=n_free,
491
+ )
492
+ dynamics_dict: dict = {}
493
+ else:
494
+ dynamics_dict = {"qpos": self._qvel}
322
495
 
323
- return result
496
+ return dynamics_dict, {"dynamics": byof_dynamics}
openscvx/problem.py CHANGED
@@ -42,6 +42,7 @@ from openscvx.discretization import (
42
42
  resolve_discretizer_config,
43
43
  )
44
44
  from openscvx.expert import ByofSpec
45
+ from openscvx.integrations.base import DynamicsAdapter, _merge_byof
45
46
  from openscvx.lowered import LoweredProblem, ParameterDict
46
47
  from openscvx.lowered.dynamics import Dynamics
47
48
  from openscvx.lowered.jax_constraints import (
@@ -70,7 +71,7 @@ from openscvx.utils.caching import (
70
71
  class Problem:
71
72
  def __init__(
72
73
  self,
73
- dynamics: dict,
74
+ dynamics: Union[dict, DynamicsAdapter],
74
75
  constraints: List[Union[Constraint, CTCS]],
75
76
  states: List[State],
76
77
  controls: List[Control],
@@ -92,9 +93,11 @@ class Problem:
92
93
  """The primary class in charge of compiling and exporting the solvers.
93
94
 
94
95
  Args:
95
- dynamics (dict): Dictionary mapping state names to their dynamics expressions.
96
- Each key should be a state name, and each value should be an Expr
97
- representing the derivative of that state.
96
+ dynamics: Dictionary mapping state names to their dynamics
97
+ expressions. Each key should be a state name, and each value
98
+ should be an ``Expr`` representing the derivative of that
99
+ state. A ``DynamicsAdapter`` may also be passed here in
100
+ place of the dict.
98
101
  constraints (List[Union[CTCSConstraint, NodalConstraint]]):
99
102
  List of constraints decorated with @ctcs or @nodal
100
103
  states (List[State]): List of State objects representing the state variables.
@@ -250,6 +253,12 @@ class Problem:
250
253
  self._float_dtype: str = float_dtype
251
254
 
252
255
  # Symbolic Preprocessing & Augmentation
256
+ # If `dynamics` is a DynamicsAdapter, expand it into the standard
257
+ # (dynamics_dict, byof_dict) representation and merge into user byof.
258
+ if isinstance(dynamics, DynamicsAdapter):
259
+ dynamics, adapter_byof = dynamics.expand()
260
+ byof = _merge_byof(byof, adapter_byof)
261
+
253
262
  # Resolve byof: dict → ByofSpec (validates keys and nested specs)
254
263
  if byof is not None:
255
264
  byof = ByofSpec.model_validate(byof)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openscvx
3
- Version: 2.dev4
3
+ Version: 2.dev5
4
4
  Summary: A general Python-based successive convexification implementation which uses a JAX backend.
5
5
  Author-email: Chris Hayner and Griffin Norris <haynec@uw.edu>
6
6
  License: Apache Software License
@@ -1,9 +1,9 @@
1
- openscvx/__init__.py,sha256=S-pt3p_BI5aOdOYPLrxKOshOeLEBXmqPoqZfROqq730,3542
1
+ openscvx/__init__.py,sha256=xmpOTuaULy4YJFhqUpXGSSMu-vjTBJnFIhohXdi8ttc,3688
2
2
  openscvx/__main__.py,sha256=Hwm7mtVg3tLdvoUPkpcQv8KF3wxl72PNLBp9axFu8GY,2991
3
- openscvx/_version.py,sha256=bX2A8XFTVwyHBgOdSJG5lLzm6PiKfBL6d4fCZyhQGaU,523
3
+ openscvx/_version.py,sha256=weTnVwzOsaNm5p3YrseycYmlTphq5sEq-JrqG3LXouM,523
4
4
  openscvx/config.py,sha256=qfDDYoCe6WqJglKsx5b2W48YOglXenKr-PVRRdCFhYE,9898
5
5
  openscvx/loader.py,sha256=FvKLkkXd4ihd5FqLFF8Cd9VnPbPwTV_azBRnEipi28c,7654
6
- openscvx/problem.py,sha256=g-DKsdpZ5rOxEoMVOUj_QwB1L1edN0rgtRn8Z-TjgTI,47177
6
+ openscvx/problem.py,sha256=GlCxYFhpacIt6Sxte3-K-d_GbGxVL9F9Qzy4y8-hrGc,47675
7
7
  openscvx/algorithms/__init__.py,sha256=f5VhjFb40JyVPgJJh6fUxRbmyQIglUyiNYL22nOMrgs,5632
8
8
  openscvx/algorithms/augmented_lagrangian.py,sha256=liEHtqONbpmw7CZJJCkAluPOEbFslFsvCElWHusHpn4,15807
9
9
  openscvx/algorithms/base.py,sha256=JsaVfS-hyHeGU1GUuQV9i-EpKQ1HBE0RHABIBGrEonM,30205
@@ -27,9 +27,10 @@ openscvx/expert/validation.py,sha256=ofOwg6t3LcrE-xaefBnEf0EAkUd9b5EwMQOjJYEYKLA
27
27
  openscvx/init/__init__.py,sha256=1nOhjqVgZRDTsHfozWPagPcyp99hskI3u31PF5kBzvw,893
28
28
  openscvx/init/interpolation.py,sha256=khypZhwADcYIhudn6EnLWMka9dmjY44aRpnGXhULQ3k,11165
29
29
  openscvx/init/inverse_kinematics.py,sha256=9mdBADa2TWvxi5z_wK_rmUdjk_K_3KJc5w2O9YfACDg,9084
30
- openscvx/integrations/__init__.py,sha256=sIvGvKbIrvDenhW31RohtiPzWtbFGN4vMgGTVXFuYG0,2132
30
+ openscvx/integrations/__init__.py,sha256=UWMr7SHotwhUa9a636rxWPvpZJpeQOinHkLpQ9Apz9M,1959
31
+ openscvx/integrations/base.py,sha256=oHDOPB-hTa9GFZ1tcmkdmzxG7pOkz4R5dT0XVGe5rKA,3377
31
32
  openscvx/integrations/menagerie.py,sha256=Zm2aGwwkqJPGQXjnHFoKI3dkXHhVzUB1eUIZUywmT28,6431
32
- openscvx/integrations/mjx.py,sha256=QQqygZMMzsuBexk8H93b8NWZy9fEDSR8ga6anhqecH0,11937
33
+ openscvx/integrations/mjx.py,sha256=E71AOor70hR_wC7XGyACOGpSBaQklYVc6-RCqZghT0U,19112
33
34
  openscvx/integrators/__init__.py,sha256=easV2-ruQLif5e8UBqE-xP5mpslYCN84gzT3OfZGQgo,1616
34
35
  openscvx/integrators/diffrax.py,sha256=5RWNtaAUSGAWeHvX897fBxWjjJLVM4ORhqC12gdFVIA,4215
35
36
  openscvx/integrators/runge_kutta.py,sha256=yMf_JLVToPgdwayFP7ZZ-w-grf_ch9UgyPsZu3mYRb0,2972
@@ -137,9 +138,9 @@ openscvx/utils/caching.py,sha256=BPkT_IbmYT-i-BZ-himdWUc_4oBcwXWJxeUMwQWnSNc,934
137
138
  openscvx/utils/printing.py,sha256=zl3IxnKhwITqB5dK0Ru2IlORPqB61Y3DgsTryqMNu9M,13360
138
139
  openscvx/utils/profiling.py,sha256=k2x-i0CpG_kRe6dNcNBGu-ylrOtQw4B4C1UaOTjUMfU,1678
139
140
  openscvx/utils/utils.py,sha256=M25RHE_7DSr3Reaca0xCXnDSY9KHuqYvXdh5m1ZotEc,3047
140
- openscvx-2.dev4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
141
- openscvx-2.dev4.dist-info/METADATA,sha256=F8CCiQ4cANbUXKGjTLNpVS3kKbVMVL2miUWTzf4P9vw,10662
142
- openscvx-2.dev4.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
143
- openscvx-2.dev4.dist-info/entry_points.txt,sha256=1Oqek8Sy28hmAZFgZXDxFXYVf56YLYWlHjhh9RYJ7wE,52
144
- openscvx-2.dev4.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
145
- openscvx-2.dev4.dist-info/RECORD,,
141
+ openscvx-2.dev5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
142
+ openscvx-2.dev5.dist-info/METADATA,sha256=-i1S5BQH7E22WfWRdoGoe1dUnRoWiwt9DK8YfVVPAp4,10662
143
+ openscvx-2.dev5.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
144
+ openscvx-2.dev5.dist-info/entry_points.txt,sha256=1Oqek8Sy28hmAZFgZXDxFXYVf56YLYWlHjhh9RYJ7wE,52
145
+ openscvx-2.dev5.dist-info/top_level.txt,sha256=nUT4Ybefzh40H8tVXqc1RzKESy_MAowElb-CIvAbd4Q,9
146
+ openscvx-2.dev5.dist-info/RECORD,,