imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
    
        ring/base.py
    ADDED
    
    | 
         @@ -0,0 +1,1046 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from pathlib import Path
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import Any, Callable, Optional, Sequence, Union
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            from flax import struct
         
     | 
| 
      
 5 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 6 
     | 
    
         
            +
            from jax.core import Tracer
         
     | 
| 
      
 7 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 8 
     | 
    
         
            +
            from jax.tree_util import tree_map
         
     | 
| 
      
 9 
     | 
    
         
            +
            import numpy as np
         
     | 
| 
      
 10 
     | 
    
         
            +
            import tree_utils as tu
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 13 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 14 
     | 
    
         
            +
            from ring import spatial
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            Scalar = jax.Array
         
     | 
| 
      
 17 
     | 
    
         
            +
            Vector = jax.Array
         
     | 
| 
      
 18 
     | 
    
         
            +
            Quaternion = jax.Array
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            Color = Optional[str | tuple[float, float, float] | tuple[float, float, float, float]]
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
            class _Base:
                """Base functionality of all spatial datatypes.

                Every operation is applied leaf-by-leaf with ``tree_map`` (or delegated to
                ``tree_utils``), so subclasses are expected to be registered pytrees; the
                ones in this module use ``flax.struct.dataclass``.

                Copied and modified from https://github.com/google/brax/blob/main/brax/v2/base.py
                """

                # ---- arithmetic -------------------------------------------------------

                def __add__(self, o: Any) -> Any:
                    """Leafwise sum with another pytree of identical structure."""
                    return tree_map(lambda lhs, rhs: lhs + rhs, self, o)

                def __sub__(self, o: Any) -> Any:
                    """Leafwise difference with another pytree of identical structure."""
                    return tree_map(lambda lhs, rhs: lhs - rhs, self, o)

                def __mul__(self, o: Any) -> Any:
                    """Multiply every leaf by the same factor ``o`` (not leafwise)."""
                    return tree_map(lambda leaf: leaf * o, self)

                def __neg__(self) -> Any:
                    return tree_map(lambda leaf: -leaf, self)

                def __truediv__(self, o: Any) -> Any:
                    """Divide every leaf by the same divisor ``o``."""
                    return tree_map(lambda leaf: leaf / o, self)

                # ---- indexing / reshaping ---------------------------------------------

                def __getitem__(self, i: int) -> Any:
                    """Index along the first axis of every leaf (delegates to ``take``)."""
                    return self.take(i)

                def reshape(self, shape: Sequence[int]) -> Any:
                    return tree_map(lambda leaf: leaf.reshape(shape), self)

                def slice(self, beg: int, end: int) -> Any:
                    """Slice ``[beg:end]`` along the first axis of every leaf."""
                    return tree_map(lambda leaf: leaf[beg:end], self)

                def take(self, i, axis=0) -> Any:
                    return tree_map(lambda leaf: jnp.take(leaf, i, axis=axis), self)

                # ---- joining ----------------------------------------------------------

                def hstack(self, *others: Any) -> Any:
                    return tree_map(lambda *leaves: jnp.hstack(leaves), self, *others)

                def vstack(self, *others: Any) -> Any:
                    return tree_map(lambda *leaves: jnp.vstack(leaves), self, *others)

                def concatenate(self, *others: Any, axis: int = 0) -> Any:
                    return tree_map(
                        lambda *leaves: jnp.concatenate(leaves, axis=axis), self, *others
                    )

                def batch(self, *others, along_existing_first_axis: bool = False) -> Any:
                    """Stack ``self`` and ``others`` into one batched pytree via ``tree_utils``."""
                    return tu.tree_batch((self, *others), along_existing_first_axis, "jax")

                # ---- scatter updates --------------------------------------------------

                def index_set(self, idx: Union[jnp.ndarray, Sequence[jnp.ndarray]], o: Any) -> Any:
                    """Functional ``leaf[idx] = o_leaf`` for every leaf pair."""
                    return tree_map(lambda leaf, new: leaf.at[idx].set(new), self, o)

                def index_sum(self, idx: Union[jnp.ndarray, Sequence[jnp.ndarray]], o: Any) -> Any:
                    """Functional ``leaf[idx] += o_leaf`` for every leaf pair."""
                    return tree_map(lambda leaf, inc: leaf.at[idx].add(inc), self, o)

                # ---- layout -----------------------------------------------------------

                @property
                def T(self):
                    return tree_map(lambda leaf: leaf.T, self)

                def flatten(self, num_batch_dims: int = 0) -> jax.Array:
                    """Concatenate all leaves into one array, keeping ``num_batch_dims`` leading dims."""
                    return tu.batch_concat(self, num_batch_dims)

                def squeeze(self):
                    return tree_map(jnp.squeeze, self)

                def squeeze_1d(self):
                    """Squeeze every leaf, but never below one dimension."""
                    return tree_map(lambda leaf: jnp.atleast_1d(jnp.squeeze(leaf)), self)

                def batch_dim(self) -> int:
                    return tu.tree_shape(self)

                def transpose(self, axes: Sequence[int]) -> Any:
                    return tree_map(lambda leaf: jnp.transpose(leaf, axes), self)

                def __iter__(self):
                    # Make iteration fail loudly; presumably meant to stop Python's
                    # implicit fallback to ``__getitem__`` for iteration.
                    raise NotImplementedError

                def repeat(self, repeats, axis=0):
                    return tree_map(lambda leaf: jnp.repeat(leaf, repeats, axis=axis), self)

                def ndim(self):
                    return tu.tree_ndim(self)

                def shape(self, axis=0) -> int:
                    return tu.tree_shape(self, axis)

                def __len__(self) -> int:
                    """Common size of the first axis of all leaves (asserted to agree)."""
                    sizes = {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(self)}
                    assert len(sizes) == 1
                    return next(iter(sizes))
         
     | 
| 
      
 110 
     | 
    
         
            +
             
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
            @struct.dataclass
            class Transform(_Base):
                """Represents the Transformation from Plücker A to Plücker B,
                where B is located relative to A at `pos` in frame A and `rot` is the
                relative quaternion from A to B."""

                pos: Vector
                rot: Quaternion

                @classmethod
                def create(cls, pos=None, rot=None):
                    """Build a transform, filling in whichever part is missing.

                    Args:
                        pos: Translation (trailing dimension 3), or None for zero
                            translation.
                        rot: Quaternion (trailing dimension 4), or None for the identity
                            quaternion ``[1, 0, 0, 0]``.

                    The filled-in part takes the leading (batch) shape of the part that
                    was given. At least one of ``pos`` / ``rot`` must be provided.
                    """
                    assert not (pos is None and rot is None), "One must be given."
                    shape_rot = rot.shape[:-1] if rot is not None else ()
                    shape_pos = pos.shape[:-1] if pos is not None else ()

                    if pos is None:
                        pos = jnp.zeros(shape_rot + (3,))
                    if rot is None:
                        # Identity quaternion, tiled over the batch dims of `pos`.
                        rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape_pos + (1,))

                    # Both parts must share the same batch shape.
                    assert pos.shape[:-1] == rot.shape[:-1]

                    return cls(pos, rot)

                @classmethod
                def zero(cls, shape=()) -> "Transform":
                    """Returns an identity transform (zero translation, identity quaternion
                    ``[1, 0, 0, 0]``) with batch shape ``shape``."""
                    pos = jnp.zeros(shape + (3,))
                    rot = jnp.tile(jnp.array([1.0, 0.0, 0.0, 0.0]), shape + (1,))
                    return cls(pos, rot)

                def as_matrix(self) -> jax.Array:
                    """Return the spatial (Plücker) transform matrix
                    ``quadrants(E, E) @ xlt(pos)``, where ``E`` is the 3x3 matrix of ``rot``."""
                    E = maths.quat_to_3x3(self.rot)
                    return spatial.quadrants(aa=E, bb=E) @ spatial.xlt(self.pos)
         
     | 
| 
      
 147 
     | 
    
         
            +
             
     | 
| 
      
 148 
     | 
    
         
            +
             
     | 
| 
      
 149 
     | 
    
         
            +
            @struct.dataclass
            class Motion(_Base):
                "Coordinate vector that represents a spatial motion vector in Plücker Coordinates."
                ang: Vector
                vel: Vector

                @classmethod
                def create(cls, ang=None, vel=None):
                    """Build a motion vector, zero-filling whichever part is missing.

                    Args:
                        ang: Angular part (trailing dimension 3), or None for zeros.
                        vel: Linear part (trailing dimension 3), or None for zeros.

                    The zero-filled part copies the leading (batch) shape of the given
                    part, so every leaf shares the same batch dims (as in
                    ``Transform.create``). For unbatched input it has shape ``(3,)``.
                    At least one of ``ang`` / ``vel`` must be provided.
                    """
                    assert not (ang is None and vel is None), "One must be given."
                    if ang is None:
                        # jnp.shape also accepts lists/tuples, not only arrays.
                        ang = jnp.zeros(jnp.shape(vel)[:-1] + (3,))
                    if vel is None:
                        vel = jnp.zeros(jnp.shape(ang)[:-1] + (3,))
                    return cls(ang, vel)

                @classmethod
                def zero(cls, shape=()) -> "Motion":
                    """Returns an all-zero motion vector with batch shape ``shape``."""
                    ang = jnp.zeros(shape + (3,))
                    vel = jnp.zeros(shape + (3,))
                    return cls(ang, vel)

                def as_matrix(self):
                    """Return the coordinate vector as one array via ``flatten()``
                    (presumably ``[ang, vel]`` concatenated)."""
                    return self.flatten()
         
     | 
| 
      
 172 
     | 
    
         
            +
             
     | 
| 
      
 173 
     | 
    
         
            +
             
     | 
| 
      
 174 
     | 
    
         
            +
            @struct.dataclass
            class Force(_Base):
                "Coordinate vector that represents a spatial force vector in Plücker Coordinates."
                # Field names mirror `Motion`; for a force, `ang` is presumably the moment
                # part and `vel` the linear-force part — TODO confirm.
                ang: Vector
                vel: Vector

                @classmethod
                def create(cls, ang=None, vel=None):
                    """Build a force vector, zero-filling whichever part is missing.

                    Args:
                        ang: Angular part (trailing dimension 3), or None for zeros.
                        vel: Linear part (trailing dimension 3), or None for zeros.

                    The zero-filled part copies the leading (batch) shape of the given
                    part, so every leaf shares the same batch dims (as in
                    ``Transform.create``). For unbatched input it has shape ``(3,)``.
                    At least one of ``ang`` / ``vel`` must be provided.
                    """
                    assert not (ang is None and vel is None), "One must be given."
                    if ang is None:
                        # jnp.shape also accepts lists/tuples, not only arrays.
                        ang = jnp.zeros(jnp.shape(vel)[:-1] + (3,))
                    if vel is None:
                        vel = jnp.zeros(jnp.shape(ang)[:-1] + (3,))
                    return cls(ang, vel)

                @classmethod
                def zero(cls, shape=()) -> "Force":
                    """Returns an all-zero force vector with batch shape ``shape``."""
                    ang = jnp.zeros(shape + (3,))
                    vel = jnp.zeros(shape + (3,))
                    return cls(ang, vel)

                def as_matrix(self):
                    """Return the coordinate vector as one array via ``flatten()``
                    (presumably ``[ang, vel]`` concatenated)."""
                    return self.flatten()
         
     | 
| 
      
 197 
     | 
    
         
            +
             
     | 
| 
      
 198 
     | 
    
         
            +
             
     | 
| 
      
 199 
     | 
    
         
            +
            @struct.dataclass
            class Inertia(_Base):
                """Spatial Inertia Matrix in Plücker Coordinates.

                Note that `h` is *not* the center of mass: ``create`` stores it as
                ``mass * transform.pos``, i.e. the first moment of mass."""

                it_3x3: jax.Array
                h: Vector
                mass: Vector

                @classmethod
                def create(cls, mass: Vector, transform: Transform, it_3x3: jnp.ndarray):
                    """Construct spatial inertia of an object with mass `mass` located and aligned
                    with a coordinate system that is given by `transform` where `transform` is from
                    parent to local geometry coordinates.

                    ``it_3x3`` is first rotated by the inverse of ``transform.rot``; the
                    stored 3x3 block is then the top-left quadrant of
                    ``spatial.mcI(mass, transform.pos, <rotated it_3x3>)``.
                    """
                    inverse_rot = maths.quat_inv(transform.rot)
                    it_rotated = maths.rotate_matrix(it_3x3, inverse_rot)
                    # Top-left 3x3 quadrant of the 6x6 spatial inertia built by mcI.
                    it_block = spatial.mcI(mass, transform.pos, it_rotated)[:3, :3]
                    first_moment = mass * transform.pos
                    return cls(it_block, first_moment, mass)

                @classmethod
                def zero(cls, shape=()) -> "Inertia":
                    """Returns an all-zero inertia with batch shape ``shape``."""
                    return cls(
                        jnp.zeros(shape + (3, 3)),
                        jnp.zeros(shape + (3,)),
                        jnp.zeros(shape + (1,)),
                    )

                def as_matrix(self):
                    """Assemble the quadrants ``[[it_3x3, h×], [-h×, mass * I3]]``."""
                    h_cross = spatial.cross(self.h)
                    mass_block = self.mass * jnp.eye(3)
                    return spatial.quadrants(self.it_3x3, h_cross, -h_cross, mass_block)
         
     | 
| 
      
 229 
     | 
    
         
            +
             
     | 
| 
      
 230 
     | 
    
         
            +
             
     | 
| 
      
 231 
     | 
    
         
            +
            @struct.dataclass
            class Geometry(_Base):
                """Common fields of every geometry attached to a link.

                Subclasses append their own shape parameters after these fields, so
                positional construction follows the order
                ``(mass, transform, link_idx, color, edge_color, <subclass fields>)``.
                """

                # Pytree children (mapped over by the `_Base` helpers).
                mass: jax.Array
                transform: Transform

                # Static metadata: `pytree_node=False` keeps these out of the pytree leaves.
                link_idx: int = struct.field(pytree_node=False)
                color: Color = struct.field(pytree_node=False)
                edge_color: Color = struct.field(pytree_node=False)
         
     | 
| 
      
 239 
     | 
    
         
            +
             
     | 
| 
      
 240 
     | 
    
         
            +
             
     | 
| 
      
 241 
     | 
    
         
            +
            @struct.dataclass
            class XYZ(Geometry):
                """Massless geometry with a single ``size`` parameter.

                Presumably rendered as a coordinate-axes marker (inferred from the
                name) — TODO confirm against the renderers.
                """

                # TODO: possibly subclass this of _Base? does this need a mass, transform, and
                # link_idx? maybe just transform?
                size: float

                @classmethod
                def create(cls, link_idx: int, size: float):
                    """Create a zero-mass instance at the identity transform with no colours."""
                    return cls(
                        mass=0.0,
                        transform=Transform.zero(),
                        link_idx=link_idx,
                        color=None,
                        edge_color=None,
                        size=size,
                    )

                def get_it_3x3(self) -> jax.Array:
                    """Always a zero 3x3 inertia tensor."""
                    return jnp.zeros((3, 3))
         
     | 
| 
      
 253 
     | 
    
         
            +
             
     | 
| 
      
 254 
     | 
    
         
            +
             
     | 
| 
      
 255 
     | 
    
         
            +
            @struct.dataclass
            class Sphere(Geometry):
                """Sphere geometry parameterised by its ``radius``."""

                radius: float

                def get_it_3x3(self) -> jax.Array:
                    """Inertia tensor of a uniform solid sphere: ``2/5 * m * r**2 * I3``."""
                    coefficient = 2 / 5 * self.mass * self.radius**2
                    return coefficient * jnp.eye(3)
         
     | 
| 
      
 262 
     | 
    
         
            +
             
     | 
| 
      
 263 
     | 
    
         
            +
             
     | 
| 
      
 264 
     | 
    
         
            +
            @struct.dataclass
            class Box(Geometry):
                """Solid cuboid geometry with edge lengths `dim_x`, `dim_y`, `dim_z`."""

                dim_x: float
                dim_y: float
                dim_z: float

                def get_it_3x3(self) -> jax.Array:
                    """Return the diagonal solid-cuboid inertia tensor.

                    I_xx = m/12 (y^2 + z^2), I_yy = m/12 (x^2 + z^2), I_zz = m/12 (x^2 + y^2).
                    """
                    x_sq, y_sq, z_sq = self.dim_x**2, self.dim_y**2, self.dim_z**2
                    principal = jnp.array([y_sq + z_sq, x_sq + z_sq, x_sq + y_sq])
                    return 1 / 12 * self.mass * jnp.diag(principal)
         
     | 
| 
      
 286 
     | 
    
         
            +
             
     | 
| 
      
 287 
     | 
    
         
            +
             
     | 
| 
      
 288 
     | 
    
         
            +
            @struct.dataclass
            class Cylinder(Geometry):
                """Length is along x-axis."""

                radius: float
                length: float

                def get_it_3x3(self) -> jax.Array:
                    """Return the diagonal solid-cylinder inertia tensor (axis = x).

                    I_xx = m/12 * 6 r^2 (= m r^2 / 2) and
                    I_yy = I_zz = m/12 * (3 r^2 + l^2).
                    """
                    r_sq = self.radius**2
                    axial = 6 * r_sq
                    transverse = 3 * r_sq + self.length**2
                    principal = jnp.array([axial, transverse, transverse])
                    return 1 / 12 * self.mass * jnp.diag(principal)
         
     | 
| 
      
 304 
     | 
    
         
            +
             
     | 
| 
      
 305 
     | 
    
         
            +
             
     | 
| 
      
 306 
     | 
    
         
            +
            @struct.dataclass
            class Capsule(Geometry):
                """Length is along x-axis."""

                radius: float
                length: float

                def get_it_3x3(self) -> jax.Array:
                    """Return the diagonal inertia tensor of a solid capsule (axis = x).

                    The total `mass` is split between the cylindrical shaft and the two
                    hemispherical end caps in proportion to their volumes. The formulas
                    follow https://github.com/thomasmarsh/ODE/blob/master/ode/src/mass.cpp#L141
                    """
                    rad, lng = self.radius, self.length

                    vol_shaft = jnp.pi * rad**2 * lng
                    # the two hemispherical caps together form one full sphere
                    vol_caps = 4 / 3 * jnp.pi * rad**3
                    vol_total = vol_shaft + vol_caps

                    mass_shaft = self.mass * vol_shaft / vol_total
                    mass_caps = self.mass * vol_caps / vol_total

                    # moment about the two axes perpendicular to the length (y and z)
                    transverse = mass_shaft * (0.25 * rad**2 + 1 / 12 * lng**2) + mass_caps * (
                        0.4 * rad**2 + 0.375 * rad * lng + 0.25 * lng**2
                    )
                    # moment about the length axis (x)
                    axial = (0.5 * mass_shaft + 0.4 * mass_caps) * rad**2

                    return jnp.diag(jnp.array([axial, transverse, transverse]))
         
     | 
| 
      
 332 
     | 
    
         
            +
             
     | 
| 
      
 333 
     | 
    
         
            +
             
     | 
| 
      
 334 
     | 
    
         
            +
            # Fallback joint parameters: a single "default" entry holding an empty array.
            # `Link.joint_params` is initialised from this dict.
            _DEFAULT_JOINT_PARAMS_DICT: dict[str, tu.PyTree] = dict(default=jnp.array([]))
         
     | 
| 
      
 335 
     | 
    
         
            +
             
     | 
| 
      
 336 
     | 
    
         
            +
             
     | 
| 
      
 337 
     | 
    
         
            +
            @struct.dataclass
            class Link(_Base):
                """Per-link parameters of a `System`, stored in its `links` field.

                NOTE(review): `System.links` is annotated as a single `Link`, which suggests
                the fields of all links are stacked along a leading axis; confirm.
                """

                transform1: Transform

                # only used by `setup_fn_randomize_positions`
                pos_min: jax.Array = struct.field(default_factory=lambda: jnp.zeros((3,)))
                pos_max: jax.Array = struct.field(default_factory=lambda: jnp.zeros((3,)))

                # these parameters can be used to model joints that have parameters
                # they are directly fed into the `jcalc` routines
                joint_params: dict[str, tu.PyTree] = struct.field(
                    # Hand out a fresh (shallow) copy per instance. Returning the
                    # module-level dict itself would make every Link share one mutable
                    # object, so an in-place edit of one link's `joint_params` would
                    # silently change the default seen by all other links.
                    default_factory=lambda: dict(_DEFAULT_JOINT_PARAMS_DICT)
                )

                # internal usage
                # gets populated by `parse_system`
                inertia: Inertia = Inertia.zero()
                # gets populated by `forward_kinematics`
                transform2: Transform = Transform.zero()
                transform: Transform = Transform.zero()
         
     | 
| 
      
 357 
     | 
    
         
            +
             
     | 
| 
      
 358 
     | 
    
         
            +
             
     | 
| 
      
 359 
     | 
    
         
            +
            @struct.dataclass
            class MaxCoordOMC(_Base):
                """Marker description for one link; `System.omc` holds a list of these (or None)."""

                # static metadata: excluded from the pytree leaves
                coordinate_system_name: str = struct.field(pytree_node=False)
                pos_marker_number: int = struct.field(pytree_node=False)
                # regular pytree leaf
                pos_marker_constant_offset: jax.Array
         
     | 
| 
      
 364 
     | 
    
         
            +
             
     | 
| 
      
 365 
     | 
    
         
            +
             
     | 
| 
      
 366 
     | 
    
         
            +
            # Number of `q` entries each joint type contributes (summed by `System.q_size`).
            Q_WIDTHS = dict(
                free=7,
                free_2d=3,
                frozen=0,
                spherical=4,
                p3d=3,
                # center of rotation, a `free` joint and then a `p3d` joint with custom
                # parameter fields in `RMCG_Config`
                cor=10,
                px=1,
                py=1,
                pz=1,
                rx=1,
                ry=1,
                rz=1,
                saddle=2,
            )
            # Number of `qd` entries each joint type contributes (summed by `System.qd_size`).
            QD_WIDTHS = dict(
                free=6,
                free_2d=3,
                frozen=0,
                spherical=3,
                p3d=3,
                cor=9,
                px=1,
                py=1,
                pz=1,
                rx=1,
                ry=1,
                rz=1,
                saddle=2,
            )
         
     | 
| 
      
 398 
     | 
    
         
            +
             
     | 
| 
      
 399 
     | 
    
         
            +
             
     | 
| 
      
 400 
     | 
    
         
            +
            @struct.dataclass
         
     | 
| 
      
 401 
     | 
    
         
            +
            class System(_Base):
         
     | 
| 
      
 402 
     | 
    
         
            +
                link_parents: list[int] = struct.field(False)
         
     | 
| 
      
 403 
     | 
    
         
            +
                links: Link
         
     | 
| 
      
 404 
     | 
    
         
            +
                link_types: list[str] = struct.field(False)
         
     | 
| 
      
 405 
     | 
    
         
            +
                link_damping: jax.Array
         
     | 
| 
      
 406 
     | 
    
         
            +
                link_armature: jax.Array
         
     | 
| 
      
 407 
     | 
    
         
            +
                link_spring_stiffness: jax.Array
         
     | 
| 
      
 408 
     | 
    
         
            +
                link_spring_zeropoint: jax.Array
         
     | 
| 
      
 409 
     | 
    
         
            +
                # simulation timestep size
         
     | 
| 
      
 410 
     | 
    
         
            +
                dt: float = struct.field(False)
         
     | 
| 
      
 411 
     | 
    
         
            +
                # geometries in the system
         
     | 
| 
      
 412 
     | 
    
         
            +
                geoms: list[Geometry]
         
     | 
| 
      
 413 
     | 
    
         
            +
                # root / base acceleration offset
         
     | 
| 
      
 414 
     | 
    
         
            +
                gravity: jax.Array = struct.field(default_factory=lambda: jnp.array([0, 0, -9.81]))
         
     | 
| 
      
 415 
     | 
    
         
            +
             
     | 
| 
      
 416 
     | 
    
         
            +
                integration_method: str = struct.field(
         
     | 
| 
      
 417 
     | 
    
         
            +
                    False, default_factory=lambda: "semi_implicit_euler"
         
     | 
| 
      
 418 
     | 
    
         
            +
                )
         
     | 
| 
      
 419 
     | 
    
         
            +
                mass_mat_iters: int = struct.field(False, default_factory=lambda: 0)
         
     | 
| 
      
 420 
     | 
    
         
            +
             
     | 
| 
      
 421 
     | 
    
         
            +
                link_names: list[str] = struct.field(False, default_factory=lambda: [])
         
     | 
| 
      
 422 
     | 
    
         
            +
             
     | 
| 
      
 423 
     | 
    
         
            +
                model_name: Optional[str] = struct.field(False, default_factory=lambda: None)
         
     | 
| 
      
 424 
     | 
    
         
            +
             
     | 
| 
      
 425 
     | 
    
         
            +
                omc: list[MaxCoordOMC | None] = struct.field(True, default_factory=lambda: [])
         
     | 
| 
      
 426 
     | 
    
         
            +
             
     | 
| 
      
 427 
     | 
    
         
            +
                def num_links(self) -> int:
                    """Return the number of links, i.e. the length of `link_parents`."""
                    parents = self.link_parents
                    return len(parents)
         
     | 
| 
      
 429 
     | 
    
         
            +
             
     | 
| 
      
 430 
     | 
    
         
            +
                def q_size(self) -> int:
                    """Return the total `q` length: the sum of `Q_WIDTHS` over `link_types`.

                    Raises:
                        KeyError: if a link type is missing from `Q_WIDTHS`.
                    """
                    total = 0
                    for joint_type in self.link_types:
                        total += Q_WIDTHS[joint_type]
                    return total
         
     | 
| 
      
 432 
     | 
    
         
            +
             
     | 
| 
      
 433 
     | 
    
         
            +
                def qd_size(self) -> int:
                    """Return the total `qd` length: the sum of `QD_WIDTHS` over `link_types`.

                    Raises:
                        KeyError: if a link type is missing from `QD_WIDTHS`.
                    """
                    total = 0
                    for joint_type in self.link_types:
                        total += QD_WIDTHS[joint_type]
                    return total
         
     | 
| 
      
 435 
     | 
    
         
            +
             
     | 
| 
      
 436 
     | 
    
         
            +
                def name_to_idx(self, name: str) -> int:
                    """Return the index of the (first) link called `name`.

                    Raises:
                        ValueError: if no link has that name (raised by `list.index`).
                    """
                    names = self.link_names
                    return names.index(name)
         
     | 
| 
      
 438 
     | 
    
         
            +
             
     | 
| 
      
 439 
     | 
    
         
            +
                def idx_to_name(self, idx: int, allow_world: bool = False) -> str:
                    """Return the name of the link at index `idx`.

                    Args:
                        idx: Link index; `-1` denotes the world body.
                        allow_world: If True, `idx == -1` is mapped to ``"world"``.

                    Raises:
                        AssertionError: If `idx` is negative and not an allowed world index.
                    """
                    if allow_world and idx == -1:
                        return "world"
                    # Explicit check rather than `assert`: asserts are stripped under
                    # `python -O`, and a negative index would then silently wrap around to
                    # the last link's name. The exception type is kept for existing callers.
                    if idx < 0:
                        raise AssertionError("Worldbody index has no name.")
                    return self.link_names[idx]
         
     | 
| 
      
 444 
     | 
    
         
            +
             
     | 
| 
      
 445 
     | 
    
         
            +
                def idx_map(self, type: str) -> dict:
                    """Map every link name to `idx_map[type](link_idx)` as provided by `self.scan`.

                    `type` is either `l` or `q` or `d`; it selects which index helper supplied
                    by `scan` is used (presumably link / q / qd slicing; confirm in `scan`).
                    """
                    result = {}

                    def _record(_, idx_map, name: str, link_idx: int):
                        result[name] = idx_map[type](link_idx)

                    link_indices = list(range(self.num_links()))
                    self.scan(_record, "ll", self.link_names, link_indices)
                    return result
         
     | 
| 
      
 455 
     | 
    
         
            +
             
     | 
| 
      
 456 
     | 
    
         
            +
                def parent_name(self, name: str) -> str:
                    """Return the name of the parent of link `name`.

                    Root links (parent index -1) are not mapped to "world" here, because
                    `idx_to_name` is called without `allow_world` and rejects negative indices.
                    """
                    child_idx = self.name_to_idx(name)
                    parent_idx = self.link_parents[child_idx]
                    return self.idx_to_name(parent_idx)
         
     | 
| 
      
 458 
     | 
    
         
            +
             
     | 
| 
      
 459 
     | 
    
         
            +
                def add_prefix(self, prefix: str = "") -> "System":
                    """Return a copy in which every link name is prepended with `prefix`."""
                    renamed = []
                    for link_name in self.link_names:
                        renamed.append(prefix + link_name)
                    return self.replace(link_names=renamed)
         
     | 
| 
      
 461 
     | 
    
         
            +
             
     | 
| 
      
 462 
     | 
    
         
            +
                def change_model_name(
                    self,
                    new_name: Optional[str] = None,
                    prefix: Optional[str] = None,
                    suffix: Optional[str] = None,
                ) -> "System":
                    """Return a copy whose `model_name` is ``prefix + name + suffix``.

                    `name` is `new_name` if given, otherwise the current `model_name`.
                    A missing `prefix`/`suffix` counts as the empty string.
                    NOTE(review): if both `new_name` and `model_name` are None, the string
                    concatenation raises TypeError.
                    """
                    head = "" if prefix is None else prefix
                    tail = "" if suffix is None else suffix
                    core = self.model_name if new_name is None else new_name
                    return self.replace(model_name=head + core + tail)
         
     | 
| 
      
 476 
     | 
    
         
            +
             
     | 
| 
      
 477 
     | 
    
         
            +
                def change_link_name(self, old_name: str, new_name: str) -> "System":
                    """Return a copy in which the link `old_name` is renamed to `new_name`.

                    Raises:
                        ValueError: if no link is called `old_name`.
                    """
                    target = self.name_to_idx(old_name)
                    renamed = [
                        new_name if position == target else current
                        for position, current in enumerate(self.link_names)
                    ]
                    return self.replace(link_names=renamed)
         
     | 
| 
      
 482 
     | 
    
         
            +
             
     | 
| 
      
 483 
     | 
    
         
            +
                def add_prefix_suffix(
                    self, prefix: Optional[str] = None, suffix: Optional[str] = None
                ) -> "System":
                    """Return a copy in which every link name becomes ``prefix + name + suffix``.

                    A missing `prefix`/`suffix` counts as the empty string.
                    """
                    head = prefix if prefix is not None else ""
                    tail = suffix if suffix is not None else ""
                    renamed = []
                    for link_name in self.link_names:
                        renamed.append(head + link_name + tail)
                    return self.replace(link_names=renamed)
         
     | 
| 
      
 492 
     | 
    
         
            +
             
     | 
| 
      
 493 
     | 
    
         
            +
                @staticmethod
                def deep_equal(a, b):
                    """Recursively compare two possibly nested values.

                    Values of different exact types are never equal. `_Base` instances are
                    compared via their `__dict__`. Dicts must have the same keys, and lists or
                    tuples the same length, with element-wise deep equality. Arrays use
                    `jnp.array_equal`. Anything else falls back to `==`.
                    """
                    if type(a) is not type(b):
                        return False
                    if isinstance(a, _Base):
                        return System.deep_equal(a.__dict__, b.__dict__)
                    if isinstance(a, dict):
                        return a.keys() == b.keys() and all(
                            System.deep_equal(a[key], b[key]) for key in a
                        )
                    if isinstance(a, (list, tuple)):
                        return len(a) == len(b) and all(
                            System.deep_equal(left, right) for left, right in zip(a, b)
                        )
                    if isinstance(a, (np.ndarray, jnp.ndarray, jax.Array)):
                        return jnp.array_equal(a, b)
                    return a == b
         
     | 
| 
      
 510 
     | 
    
         
            +
             
     | 
| 
      
 511 
     | 
    
         
            +
                def _replace_free_with_cor(self) -> "System":
                    """Return the system with every `free` joint replaced by a `cor` joint.

                    The per-link arrays are widened to the `cor` layout, which is a `free`
                    joint followed by a `p3d` joint.

                    Raises:
                        InvalidSystemError: if a link attached to the world (-1) is not
                            `free`, or a `free` link has a parent other than the world.
                    """
                    # precondition: "attached to world" and "free joint" must coincide
                    for i, p in enumerate(self.link_parents):
                        link_type = self.link_types[i]
                        attached_to_world = p == -1
                        is_free = link_type == "free"
                        if attached_to_world != is_free:
                            raise InvalidSystemError(
                                f"link={self.idx_to_name(i)}, parent="
                                f"{self.idx_to_name(p, allow_world=True)},"
                                f" joint={link_type}. Hint: Try setting `config.cor` to false."
                            )

                    def logic_replace_free_with_cor(name, olt, ola, old, ols, olz):
                        # non-free joints pass through unchanged
                        if olt != "free":
                            return olt, ola, old, ols, olz

                        def _append_tail(arr, start):
                            # duplicate the translational tail for the extra p3d joint
                            return jnp.concatenate((arr, arr[start:]))

                        # `ola` (armature) and, presumably, `old`/`ols` (damping, stiffness)
                        # are qd-sized: 3 rotational then 3 translational entries for `free`,
                        # so the p3d part copies entries [3:]. `olz` (presumably the spring
                        # zero-point) is q-sized, with 4 rotational entries first, so it
                        # copies entries [4:].
                        return (
                            "cor",
                            _append_tail(ola, 3),
                            _append_tail(old, 3),
                            _append_tail(ols, 3),
                            _append_tail(olz, 4),
                        )

                    return _update_sys_if_replace_joint_type(self, logic_replace_free_with_cor)
         
     | 
| 
      
 541 
     | 
    
         
            +
             
     | 
| 
      
 542 
     | 
    
         
            +
                def freeze(self, name: str | list[str]):
                    """Return a system in which link `name` (or each link in a list) is `frozen`.

                    A frozen link gets joint type "frozen" and empty per-link arrays. Names
                    that match no link are ignored silently.
                    """
                    if isinstance(name, list):
                        result = self
                        for single_name in name:
                            result = result.freeze(single_name)
                        return result

                    def logic_freeze(link_name, olt, ola, old, ols, olz):
                        # every other link keeps its joint unchanged
                        if link_name != name:
                            return olt, ola, old, ols, olz
                        empty = jnp.array([])
                        return "frozen", empty, empty, empty, empty

                    return _update_sys_if_replace_joint_type(self, logic_freeze)
         
     | 
| 
      
 559 
     | 
    
         
            +
             
     | 
| 
      
 560 
     | 
    
         
            +
                def unfreeze(self, name: str, new_joint_type: str):
                    """Turn the frozen link `name` back into a joint of type `new_joint_type`.

                    Asserts that the link is currently "frozen" and that the requested
                    type is not "frozen" itself. The replacement is delegated to
                    `change_joint_type` using its default armature/damping/stiffness/zero.
                    """
                    current_type = self.link_types[self.name_to_idx(name)]
                    assert current_type == "frozen"
                    assert new_joint_type != "frozen"
                    return self.change_joint_type(name, new_joint_type)
         
     | 
| 
      
 565 
     | 
    
         
            +
             
     | 
| 
      
 566 
     | 
    
         
            +
                def change_joint_type(
                    self,
                    name: str,
                    new_joint_type: str,
                    new_arma: Optional[jax.Array] = None,
                    new_damp: Optional[jax.Array] = None,
                    new_stif: Optional[jax.Array] = None,
                    new_zero: Optional[jax.Array] = None,
                ):
                    """Return a copy of the system where link `name` has a new joint type.

                    All other links are left unchanged.

                    Args:
                        name: Name of the link whose joint type is replaced.
                        new_joint_type: New joint type; must be a key of `Q_WIDTHS` and
                            `QD_WIDTHS` (a `KeyError` is raised otherwise).
                        new_arma: Armature of the link. Defaults to zeros of length
                            `QD_WIDTHS[new_joint_type]`.
                        new_damp: Damping of the link. Defaults to zeros of length
                            `QD_WIDTHS[new_joint_type]`.
                        new_stif: Spring stiffness of the link. Defaults to zeros of length
                            `QD_WIDTHS[new_joint_type]`.
                        new_zero: Spring zeropoint of the link. Defaults to zeros of length
                            `Q_WIDTHS[new_joint_type]`; for "spherical", "free" and "cor"
                            the first entry is set to 1.0 so that the default is a unit
                            quaternion.

                    Returns:
                        The updated system. User-supplied arrays are used as-is; their
                        sizes are not validated here.
                    """
                    q_size, qd_size = Q_WIDTHS[new_joint_type], QD_WIDTHS[new_joint_type]

                    def logic_change_joint_type(link_name, olt, ola, old, ols, olz):
                        if link_name != name:
                            return olt, ola, old, ols, olz

                        q_zeros = jnp.zeros((q_size,))
                        qd_zeros = jnp.zeros((qd_size,))

                        nla = qd_zeros if new_arma is None else new_arma
                        nld = qd_zeros if new_damp is None else new_damp
                        nls = qd_zeros if new_stif is None else new_stif
                        nlz = q_zeros if new_zero is None else new_zero

                        # default zeropoint of quaternion-based joints: unit quaternion
                        if new_joint_type in ("spherical", "free", "cor") and new_zero is None:
                            nlz = nlz.at[0].set(1.0)

                        return new_joint_type, nla, nld, nls, nlz

                    return _update_sys_if_replace_joint_type(self, logic_change_joint_type)
         
     | 
| 
      
 598 
     | 
    
         
            +
             
     | 
| 
      
 599 
     | 
    
         
            +
                def findall_imus(self) -> list[str]:
                    """Return the names of all links whose name starts with "imu"."""
                    return [link for link in self.link_names if link.startswith("imu")]
         
     | 
| 
      
 601 
     | 
    
         
            +
             
     | 
| 
      
 602 
     | 
    
         
            +
                def findall_segments(self) -> list[str]:
                    """Return the names of all links that are not IMUs (see `findall_imus`)."""
                    imu_names = set(self.findall_imus())
                    return [link for link in self.link_names if link not in imu_names]
         
     | 
| 
      
 605 
     | 
    
         
            +
             
     | 
| 
      
 606 
     | 
    
         
            +
                def _bodies_indices_to_bodies_name(self, bodies: list[int]) -> list[str]:
                    """Translate a list of link indices into link names via `idx_to_name`."""
                    return list(map(self.idx_to_name, bodies))
         
     | 
| 
      
 608 
     | 
    
         
            +
             
     | 
| 
      
 609 
     | 
    
         
            +
                def findall_bodies_to_world(self, names: bool = False) -> list[int] | list[str]:
         
     | 
| 
      
 610 
     | 
    
         
            +
                    bodies = [i for i, p in enumerate(self.link_parents) if p == -1]
         
     | 
| 
      
 611 
     | 
    
         
            +
                    return self._bodies_indices_to_bodies_name(bodies) if names else bodies
         
     | 
| 
      
 612 
     | 
    
         
            +
             
     | 
| 
      
 613 
     | 
    
         
            +
                def find_body_to_world(self, name: bool = False) -> int | str:
         
     | 
| 
      
 614 
     | 
    
         
            +
                    bodies = self.findall_bodies_to_world(names=name)
         
     | 
| 
      
 615 
     | 
    
         
            +
                    assert len(bodies) == 1
         
     | 
| 
      
 616 
     | 
    
         
            +
                    return bodies[0]
         
     | 
| 
      
 617 
     | 
    
         
            +
             
     | 
| 
      
 618 
     | 
    
         
            +
                def findall_bodies_with_jointtype(
         
     | 
| 
      
 619 
     | 
    
         
            +
                    self, typ: str, names: bool = False
         
     | 
| 
      
 620 
     | 
    
         
            +
                ) -> list[int] | list[str]:
         
     | 
| 
      
 621 
     | 
    
         
            +
                    bodies = [i for i, _typ in enumerate(self.link_types) if _typ == typ]
         
     | 
| 
      
 622 
     | 
    
         
            +
                    return self._bodies_indices_to_bodies_name(bodies) if names else bodies
         
     | 
| 
      
 623 
     | 
    
         
            +
             
     | 
| 
      
 624 
     | 
    
         
            +
                def scan(self, f: Callable, in_types: str, *args, reverse: bool = False):
                    """Scan `f` along each link in the system whilst carrying along state.

                    Thin wrapper around the module-level `_scan_sys`.

                    Args:
                        f (Callable[..., Y]): Called once per link. Call sites in this file
                            define it as `f(first, second, *link_args)` and ignore the first
                            two positional arguments; `link_args` are the entries of `args`
                            split for the current link. NOTE(review): the meaning of the two
                            leading arguments is defined by `_scan_sys` (presumably the
                            carried state `y` and a link index) — confirm there.
                        in_types: One character per entry of `args`, specifying how it is
                            split:
                            'l' is an input to be split according to link ranges
                            'q' is an input to be split according to q ranges
                            'd' is an input to be split according to qd ranges
                        *args: Arguments passed to `f`, split to match the current link.
                        reverse (bool, optional): If `True`, scan from the leaves to the
                            root. Defaults to False.

                    Returns:
                        ys: Stacked output y of f.
                    """
                    return _scan_sys(self, f, in_types, *args, reverse=reverse)
         
     | 
| 
      
 640 
     | 
    
         
            +
             
     | 
| 
      
 641 
     | 
    
         
            +
                def parse(self) -> "System":
                    """Run the initial setup of the system; an unparsed System does not work.

                    Delegates to `_parse_system`, which currently:
                    - performs some consistency checks
                    - populates the spatial inertia tensors
                    - checks that all link names are unique and are strings
                    - checks that all pos_min <= pos_max (unless traced)
                    - orders geoms in ascending order of their parent link index
                    - checks that every link has correctly sized damping, armature,
                      stiffness and zeropoint
                    - checks that n_links == len(sys.omc)
                    """
                    parsed = _parse_system(self)
                    return parsed
         
     | 
| 
      
 658 
     | 
    
         
            +
             
     | 
| 
      
 659 
     | 
    
         
            +
                def render(
                    self,
                    xs: Optional[Transform | list[Transform]] = None,
                    camera: Optional[str] = None,
                    show_pbar: bool = True,
                    backend: str = "mujoco",
                    render_every_nth: int = 1,
                    **scene_kwargs,
                ) -> list[np.ndarray]:
                    """Render frames from this system and a trajectory of maximal coordinates.

                    All arguments are forwarded unchanged to `ring.rendering.render`.

                    Args:
                        xs (Transform | list[Transform], optional): Single or time-series
                            of maximal coordinates `xs`.
                        camera (str, optional): Camera selection, forwarded to the
                            renderer.
                        show_pbar (bool, optional): Whether or not to show a progress bar.
                            Defaults to True.
                        backend (str, optional): Rendering backend. Defaults to "mujoco".
                        render_every_nth (int, optional): Forwarded to the renderer.
                            NOTE(review): presumably only every n-th entry of `xs` is
                            rendered — confirm in `ring.rendering.render`.
                        **scene_kwargs: Extra keyword arguments forwarded to
                            `ring.rendering.render`.

                    Returns:
                        list[np.ndarray]: Rendered frames. NOTE(review): with
                        `render_every_nth > 1` this is presumably shorter than `xs`.
                    """
                    return ring.rendering.render(
                        self, xs, camera, show_pbar, backend, render_every_nth, **scene_kwargs
                    )
         
     | 
| 
      
 683 
     | 
    
         
            +
             
     | 
| 
      
 684 
     | 
    
         
            +
                def render_prediction(
                    self,
                    xs: Transform | list[Transform],
                    yhat: dict | jax.Array | np.ndarray,
                    stepframe: int = 1,
                    # by default we don't predict the global rotation
                    transparent_segment_to_root: bool = True,
                    **kwargs,
                ):
                    """Render a prediction alongside the trajectory.

                    `xs` matches `sys` (this system), `yhat` matches `sys_noimu`, and the
                    orientations in `yhat` are child-to-parent. All arguments are
                    forwarded to `ring.rendering.render_prediction`.
                    """
                    renderer = ring.rendering.render_prediction
                    return renderer(
                        self, xs, yhat, stepframe, transparent_segment_to_root, **kwargs
                    )
         
     | 
| 
      
 697 
     | 
    
         
            +
             
     | 
| 
      
 698 
     | 
    
         
            +
                def delete_system(self, link_name: str | list[str], strict: bool = True):
                    """Cut the subsystem starting at `link_name` (inclusive) from the tree.

                    Delegates to `ring.sys_composer.delete_subsystem`; `strict` is forwarded
                    unchanged.
                    """
                    composer = ring.sys_composer
                    return composer.delete_subsystem(self, link_name, strict)
         
     | 
| 
      
 701 
     | 
    
         
            +
             
     | 
| 
      
 702 
     | 
    
         
            +
                def make_sys_noimu(self, imu_link_names: Optional[list[str]] = None):
                    """Delegate to `ring.sys_composer.make_sys_noimu`.

                    NOTE(review): judging by the name, this builds the system without its
                    IMU links; per the original note the result includes a mapping such as
                    imu_attachment = {'imu1': 'seg1', 'imu2': 'seg3'} — confirm the exact
                    return structure in `ring.sys_composer`.
                    """
                    composer = ring.sys_composer
                    return composer.make_sys_noimu(self, imu_link_names)
         
     | 
| 
      
 705 
     | 
    
         
            +
             
     | 
| 
      
 706 
     | 
    
         
            +
                def inject_system(self, other_system: "System", at_body: Optional[str] = None):
                    """Combine two systems into one.

                    Delegates to `ring.sys_composer.inject_system`.

                    Args:
                        other_system (System): Small system that will be included into
                            this (larger) system.
                        at_body (Optional[str], optional): Into which body of this system
                            the small system will be included. Defaults to `worldbody`.

                    Returns:
                        System: The combined system returned by
                        `ring.sys_composer.inject_system`.
                    """
                    return ring.sys_composer.inject_system(self, other_system, at_body)
         
     | 
| 
      
 720 
     | 
    
         
            +
             
     | 
| 
      
 721 
     | 
    
         
            +
                def morph_system(
                    self,
                    new_parents: Optional[list[int | str]] = None,
                    new_anchor: Optional[int | str] = None,
                ):
                    """Re-order the graph underlying the system. Returns a new system.

                    Delegates to `ring.sys_composer.morph_system`.

                    Args:
                        new_parents (list[int | str], optional): Let the i-th entry have
                            value j. Then, after morphing, the link corresponding to the
                            i-th link in the old system will have as parent the link
                            corresponding to the j-th link in the old system. Per the type
                            annotation, entries may also be given as link names.
                        new_anchor (int | str, optional): Forwarded unchanged.
                            NOTE(review): presumably the link that becomes attached to the
                            world after morphing — confirm in `ring.sys_composer`.

                    Returns:
                        System: Modified system.
                    """
                    return ring.sys_composer.morph_system(self, new_parents, new_anchor)
         
     | 
| 
      
 739 
     | 
    
         
            +
             
     | 
| 
      
 740 
     | 
    
         
            +
                @staticmethod
                def from_xml(path: str, seed: int = 1):
                    """Build a system from the XML file at `path` (see `ring.io.load_sys_from_xml`)."""
                    io_module = ring.io
                    return io_module.load_sys_from_xml(path, seed)
         
     | 
| 
      
 743 
     | 
    
         
            +
             
     | 
| 
      
 744 
     | 
    
         
            +
                @staticmethod
                def from_str(xml: str, seed: int = 1):
                    """Build a system from an XML string (see `ring.io.load_sys_from_str`)."""
                    io_module = ring.io
                    return io_module.load_sys_from_str(xml, seed)
         
     | 
| 
      
 747 
     | 
    
         
            +
             
     | 
| 
      
 748 
     | 
    
         
            +
                def to_str(self) -> str:
                    """Serialize this system to an XML string via `ring.io.save_sys_to_str`."""
                    io_module = ring.io
                    return io_module.save_sys_to_str(self)
         
     | 
| 
      
 750 
     | 
    
         
            +
             
     | 
| 
      
 751 
     | 
    
         
            +
                def to_xml(self, path: str) -> None:
                    """Write this system as XML to `path` via `ring.io.save_sys_to_xml`."""
                    io_module = ring.io
                    io_module.save_sys_to_xml(self, path)
                    return None
         
     | 
| 
      
 753 
     | 
    
         
            +
             
     | 
| 
      
 754 
     | 
    
         
            +
                @classmethod
                def create(cls, path_or_str: str, seed: int = 1) -> "System":
                    """Create a system from either an XML file path or an XML string.

                    `path_or_str` is first interpreted as a file path whose suffix is
                    forced to ".xml". If that file exists, it is loaded with `from_xml`.
                    Otherwise `path_or_str` itself is parsed as XML text with `from_str`.

                    Args:
                        path_or_str: Path to an XML file (suffix optional) or an XML string.
                        seed: Seed forwarded to the loader in both cases.

                    Returns:
                        The loaded system.
                    """
                    exists = False
                    try:
                        path = Path(path_or_str).with_suffix(".xml")
                        exists = path.exists()
                    except (OSError, ValueError):
                        # An XML string is usually not a valid path: it may be too long for
                        # the OS (OSError), or have no file-name component such as "", or
                        # contain a null byte (ValueError). Treat these as "not a file".
                        pass

                    if exists:
                        return cls.from_xml(path, seed=seed)
                    # previously `seed` was silently dropped on this branch
                    return cls.from_str(path_or_str, seed=seed)
         
     | 
| 
      
 769 
     | 
    
         
            +
             
     | 
| 
      
 770 
     | 
    
         
            +
                def coordinate_vector_to_q(
                    self,
                    q: jax.Array,
                    custom_joints: Optional[dict[str, Callable]] = None,
                ) -> jax.Array:
                    """Map a coordinate vector `q` to the minimal coordinates vector of the sys.

                    Each link's slice of `q` is passed through a per-joint-type mapping,
                    e.g. normalizing quaternions or wrapping hinge joints into [-pi, pi].

                    Args:
                        q: Coordinate vector, split per link according to q ranges.
                        custom_joints: Optional mapping from joint type to a callable
                            `q_link -> q_link`. An entry here takes priority over the joint
                            model's `coordinate_vector_to_q`. Defaults to no overrides.

                    Returns:
                        The mapped per-link slices, concatenated in scan order.

                    Raises:
                        NotImplementedError: If no mapping is available for a joint type.
                    """
                    # avoid a shared mutable default argument
                    if custom_joints is None:
                        custom_joints = {}
                    q_preproc = []

                    def preprocess(_, __, link_type, q):
                        if link_type in custom_joints:
                            # function in custom_joints has priority over JointModel; the
                            # joint-model lookup is skipped so unregistered custom joint
                            # types work as well
                            to_q = custom_joints[link_type]
                        else:
                            to_q = ring.algorithms.jcalc.get_joint_model(
                                link_type
                            ).coordinate_vector_to_q
                        if to_q is None:
                            raise NotImplementedError(
                                f"Please specify the custom joint `{link_type}`"
                                " either using the `custom_joints` arguments or using the"
                                " JointModel.coordinate_vector_to_q field."
                            )
                        q_preproc.append(to_q(q))

                    self.scan(preprocess, "lq", self.link_types, q)
                    return jnp.concatenate(q_preproc)
         
     | 
| 
      
 799 
     | 
    
         
            +
             
     | 
| 
      
 800 
     | 
    
         
            +
             
     | 
| 
      
 801 
     | 
    
         
            +
            def _update_sys_if_replace_joint_type(sys: System, logic) -> System:
                """Rebuild the per-link joint parameters of `sys` through `logic`.

                `logic` is invoked once per link, in link order, as
                `logic(name, link_type, armature, damping, stiffness, zeropoint)` and must
                return the five replacement values in that same order. The new joint
                types are stored as a plain list of strings; the four array-valued
                parameters are concatenated back into flat arrays. The updated system is
                then passed through `.parse()` so that the dimensionality of armature /
                damping / stiffness / zeropoint is checked against the new joint types.
                """
                rows = []

                def _collect(_, __, link_name, typ, arma, damp, stiff, zero):
                    # unpack right here so a wrong-length return of `logic` fails immediately
                    new_typ, new_arma, new_damp, new_stiff, new_zero = logic(
                        link_name, typ, arma, damp, stiff, zero
                    )
                    rows.append((new_typ, new_arma, new_damp, new_stiff, new_zero))

                sys.scan(
                    _collect,
                    "lldddq",
                    sys.link_names,
                    sys.link_types,
                    sys.link_armature,
                    sys.link_damping,
                    sys.link_spring_stiffness,
                    sys.link_spring_zeropoint,
                )

                # joint types stay a list of strings; everything else is re-flattened
                new_types = [row[0] for row in rows]
                new_arma, new_damp, new_stiff, new_zero = (
                    jnp.concatenate([row[col] for row in rows]) for col in range(1, 5)
                )

                updated = sys.replace(
                    link_types=new_types,
                    link_armature=new_arma,
                    link_damping=new_damp,
                    link_spring_stiffness=new_stiff,
                    link_spring_zeropoint=new_zero,
                )
                return updated.parse()
         
     | 
| 
      
 838 
     | 
    
         
            +
             
     | 
| 
      
 839 
     | 
    
         
            +
             
     | 
| 
      
 840 
     | 
    
         
            +
            class InvalidSystemError(Exception):
                """Exception type for signalling an invalid `System` (no raise site is visible in this section)."""
         
     | 
| 
      
 842 
     | 
    
         
            +
             
     | 
| 
      
 843 
     | 
    
         
            +
             
     | 
| 
      
 844 
     | 
    
         
            +
            def _parse_system(sys: System) -> System:
                """Validate `sys` and return a normalized copy of it.

                Checks (all via `assert`):
                  * `link_parents`, `link_types` and `links` have the same length and
                    `omc` has one entry per link;
                  * link names are unique strings;
                  * `pos_max >= pos_min` element-wise. This is skipped if the comparison
                    raises `TracerBoolConversionError`, or if that error class cannot be
                    imported from the installed JAX;
                  * every geometry is attached to a valid link index or to -1;
                  * damping / armature / stiffness have the joint's qd-width and the
                    zeropoint has the joint's q-width;
                  * for `spherical`, `free` and `cor` joints with a concrete (non-`Tracer`)
                    zeropoint, its first four entries have unit norm.

                Normalization: link inertias are recomputed from the geometries, and the
                geometries are sorted by ascending `link_idx`.

                Args:
                    sys: The system to validate.

                Returns:
                    The system with recomputed inertias and sorted geometries.
                """
                assert len(sys.link_parents) == len(sys.link_types) == sys.links.batch_dim()
                assert len(sys.omc) == sys.num_links()

                for name in sys.link_names:
                    assert sys.link_names.count(name) == 1, f"Duplicated name=`{name}` in system"
                    assert isinstance(name, str)

                pos_min, pos_max = sys.links.pos_min, sys.links.pos_max

                try:
                    from jax.errors import TracerBoolConversionError

                    try:
                        assert jnp.all(pos_max >= pos_min), f"min={pos_min}, max={pos_max}"
                    except TracerBoolConversionError:
                        # the comparison cannot be turned into a Python bool; skip the check
                        pass
                # on older versions of jax this import is not possible
                except ImportError:
                    pass

                # build the set of admissible indices once instead of once per geometry
                valid_link_idxs = list(range(sys.num_links())) + [-1]
                for geom in sys.geoms:
                    assert geom.link_idx in valid_link_idxs

                inertia = _parse_system_calculate_inertia(sys)
                sys = sys.replace(links=sys.links.replace(inertia=inertia))

                # sort geoms in ascending order
                geoms = sys.geoms.copy()
                geoms.sort(key=lambda geom: geom.link_idx)
                sys = sys.replace(geoms=geoms)

                # round dt
                # sys = sys.replace(dt=round(sys.dt, 8))

                # check sizes of damping / arma / stiff / zeropoint
                def check_dasz_unitq(_, __, name, typ, d, a, s, z):
                    q_size, qd_size = Q_WIDTHS[typ], QD_WIDTHS[typ]

                    error_msg = (
                        f"wrong size for link `{name}` of typ `{typ}` in model {sys.model_name}"
                    )

                    assert d.size == a.size == s.size == qd_size, error_msg
                    assert z.size == q_size, error_msg

                    if typ in ["spherical", "free", "cor"] and not isinstance(z, Tracer):
                        # the parentheses are required: without them the second f-string was
                        # a separate, no-op statement and the model name never made it into
                        # the assertion message
                        assert jnp.allclose(jnp.linalg.norm(z[:4]), 1.0), (
                            f"not unit quat for link `{name}` of typ `{typ}` in model"
                            f" {sys.model_name}"
                        )

                sys.scan(
                    check_dasz_unitq,
                    "lldddq",
                    sys.link_names,
                    sys.link_types,
                    sys.link_damping,
                    sys.link_armature,
                    sys.link_spring_stiffness,
                    sys.link_spring_zeropoint,
                )

                return sys
         
     | 
| 
      
 908 
     | 
    
         
            +
             
     | 
| 
      
 909 
     | 
    
         
            +
             
     | 
| 
      
 910 
     | 
    
         
            +
            def _inertia_from_geometries(geometries: list[Geometry]) -> Inertia:
                """Accumulate the inertia contributions of `geometries`.

                Each geometry contributes `Inertia.create(mass, transform, get_it_3x3())`.
                The contributions are added onto `Inertia.zero()` in list order, so an
                empty list yields `Inertia.zero()`.
                """
                total = Inertia.zero()
                contributions = (
                    Inertia.create(g.mass, g.transform, g.get_it_3x3()) for g in geometries
                )
                for contribution in contributions:
                    total += contribution
                return total
         
     | 
| 
      
 915 
     | 
    
         
            +
             
     | 
| 
      
 916 
     | 
    
         
            +
             
     | 
| 
      
 917 
     | 
    
         
            +
            def _parse_system_calculate_inertia(sys: System):
                """Compute one `Inertia` per link from the geometries attached to it.

                Geometries whose `link_idx` equals no link index (e.g. -1) contribute to
                no link. The per-link results are gathered by `sys.scan` over the link
                indices `0 .. sys.num_links() - 1`.
                """

                def _link_inertia(_, __, link_idx: int):
                    # keep the geometries in their original order for this link
                    attached = [geom for geom in sys.geoms if geom.link_idx == link_idx]
                    return _inertia_from_geometries(attached)

                link_indices = list(range(sys.num_links()))
                return sys.scan(_link_inertia, "l", link_indices)
         
     | 
| 
      
 928 
     | 
    
         
            +
             
     | 
| 
      
 929 
     | 
    
         
            +
             
     | 
| 
      
 930 
     | 
    
         
            +
            def _scan_sys(sys: System, f: Callable, in_types: str, *args, reverse: bool = False):
                """Apply `f` to every link of `sys` in sequence, threading a carry value.

                For each link, `f` is called as `f(carry, idx_map, *args_link)`, where:
                  * `carry` is the return value of the previous call (`None` at first);
                  * `idx_map` maps each in-type character to a function
                    `link_idx -> index or slice`;
                  * `args_link` are the per-link pieces of `args`.

                Args:
                    sys: System whose links are iterated.
                    f: Callback with the signature described above.
                    in_types: One character per entry of `args`:
                        `l`: the arg has one entry per link and is indexed by link index;
                        `q`: the arg has length `sys.q_size()` and is sliced by the
                        link's q-range;
                        `d`: the arg has length `sys.qd_size()` and is sliced by the
                        link's qd-range.
                    *args: Sequences to split per link.
                    reverse: Visit links from last to first. The returned outputs are
                        still ordered by link index.

                Returns:
                    The per-link outputs of `f` (in link-index order), batched with
                    `tu.tree_batch(..., backend="jax")`.

                Raises:
                    ValueError: If `in_types` contains a character other than `l`, `q`
                        or `d`.
                """
                assert len(args) == len(in_types)
                for in_type, arg in zip(in_types, args):
                    B = len(arg)
                    if in_type == "l":
                        assert B == sys.num_links()
                    elif in_type == "q":
                        assert B == sys.q_size()
                    elif in_type == "d":
                        assert B == sys.qd_size()
                    else:
                        # ValueError (a subclass of Exception) instead of a bare Exception
                        raise ValueError("`in_types` must be one of `l` or `q` or `d`")

                # build map from
                # link-idx -> q_idx
                # link-idx -> qd_idx
                # the q / qd entries of the links are laid out back-to-back in link order
                q_idxs, qd_idxs = {}, {}
                q_idx, qd_idx = 0, 0
                for link_idx, link_type in zip(range(sys.num_links()), sys.link_types):
                    q_width, qd_width = Q_WIDTHS[link_type], QD_WIDTHS[link_type]
                    q_idxs[link_idx] = slice(q_idx, q_idx + q_width)
                    qd_idxs[link_idx] = slice(qd_idx, qd_idx + qd_width)
                    q_idx += q_width
                    qd_idx += qd_width

                idx_map = {
                    "l": lambda link_idx: link_idx,
                    "q": lambda link_idx: q_idxs[link_idx],
                    "d": lambda link_idx: qd_idxs[link_idx],
                }

                if reverse:
                    order = range(sys.num_links() - 1, -1, -1)
                else:
                    order = range(sys.num_links())

                y, ys = None, []
                for link_idx in order:
                    args_link = [arg[idx_map[t](link_idx)] for arg, t in zip(args, in_types)]
                    y = f(y, idx_map, *args_link)
                    ys.append(y)

                if reverse:
                    # restore link-index order of the outputs
                    ys.reverse()

                ys = tu.tree_batch(ys, backend="jax")
                return ys
         
     | 
| 
      
 975 
     | 
    
         
            +
             
     | 
| 
      
 976 
     | 
    
         
            +
             
     | 
| 
      
 977 
     | 
    
         
            +
            @struct.dataclass
            class State(_Base):
                """The static and dynamic state of a system in minimal and maximal coordinates.

                Use `.create()` to create this object.

                Args:
                    q (jax.Array): System state in minimal coordinates (equals `sys.q_size()`)
                    qd (jax.Array): System velocity in minimal coordinates (equals `sys.qd_size()`)
                    x (Transform): Maximal coordinates of all links. From epsilon-to-link.
                    mass_mat_inv (jax.Array): Inverse of the mass matrix. Internal usage.
                """

                q: jax.Array
                qd: jax.Array
                x: Transform
                mass_mat_inv: jax.Array

                @classmethod
                def create(
                    cls,
                    sys: System,
                    q: Optional[jax.Array] = None,
                    qd: Optional[jax.Array] = None,
                    x: Optional[Transform] = None,
                    key: Optional[jax.Array] = None,
                    custom_joints: Optional[dict[str, Callable]] = None,
                ):
                    """Create state of system.

                    Args:
                        sys (System): The system for which to create a state.
                        q (jax.Array, optional): The joint values of the system. Defaults to
                            None, in which case `q` is zeros except that the first entry of
                            the q-slice of every `free`, `cor` and `spherical` joint is set
                            to 1.0 (a unit quaternion). Must be None if `key` is given.
                        qd (jax.Array, optional): The joint velocities of the system.
                            Defaults to None. Which then defaults to zeros.
                        x (Transform, optional): Maximal coordinates of all links. Defaults
                            to `Transform.zero((sys.num_links(),))`.
                        key (jax.Array, optional): If given, `q` is sampled from a standard
                            normal of size `sys.q_size()` and mapped through
                            `sys.coordinate_vector_to_q(q, custom_joints)`.
                        custom_joints (dict[str, Callable], optional): Forwarded to
                            `sys.coordinate_vector_to_q`; only used when `key` is given.
                            Defaults to an empty dict.

                    Returns:
                        (State): Create State object. `mass_mat_inv` is initialized to the
                        identity matrix of size `sys.qd_size()`.
                    """
                    # avoid a shared mutable default argument
                    if custom_joints is None:
                        custom_joints = {}

                    if key is not None:
                        assert q is None
                        q = jax.random.normal(key, shape=(sys.q_size(),))
                        q = sys.coordinate_vector_to_q(q, custom_joints)
                    elif q is None:
                        q = jnp.zeros((sys.q_size(),))

                        # free, cor, spherical joints are not zeros but have unit quaternions
                        def replace_by_unit_quat(_, idx_map, link_typ, link_idx):
                            nonlocal q

                            if link_typ in ["free", "cor", "spherical"]:
                                q_idxs_link = idx_map["q"](link_idx)
                                q = q.at[q_idxs_link.start].set(1.0)

                        sys.scan(
                            replace_by_unit_quat,
                            "ll",
                            sys.link_types,
                            list(range(sys.num_links())),
                        )

                    if qd is None:
                        qd = jnp.zeros((sys.qd_size(),))

                    if x is None:
                        x = Transform.zero((sys.num_links(),))

                    return cls(q, qd, x, jnp.diag(jnp.ones((sys.qd_size(),))))
         
     |