imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
    
        ring/algorithms/jcalc.py
    ADDED
    
    | 
         @@ -0,0 +1,840 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from dataclasses import asdict
         
     | 
| 
      
 2 
     | 
    
         
            +
            from dataclasses import dataclass
         
     | 
| 
      
 3 
     | 
    
         
            +
            from dataclasses import field
         
     | 
| 
      
 4 
     | 
    
         
            +
            from dataclasses import replace
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import Any, Callable, get_type_hints, Optional
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 8 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 9 
     | 
    
         
            +
            import tree_utils
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            from ring import algebra
         
     | 
| 
      
 12 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 13 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 14 
     | 
    
         
            +
            from ring.algorithms import _random
         
     | 
| 
      
 15 
     | 
    
         
            +
            from ring.algorithms._random import _to_float
         
     | 
| 
      
 16 
     | 
    
         
            +
            from ring.algorithms._random import TimeDependentFloat
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
@dataclass
class MotionConfig:
    """Parameters controlling the randomly generated joint motion.

    Fields annotated as ``float | TimeDependentFloat`` accept either a constant
    or a time-dependent value; `join_motionconfigs` produces such values by
    switching between several configs over time.
    """

    # total duration of the random motion
    T: float = 60.0
    # lower / upper bound on the time between two generated angles
    t_min: float = 0.05
    t_max: float | TimeDependentFloat = 0.30

    # lower / upper bound on the angular velocity, in rad/s
    dang_min: float | TimeDependentFloat = 0.1
    dang_max: float | TimeDependentFloat = 3.0

    # same bounds, but for the euler angles of free and spherical joints
    dang_min_free_spherical: float | TimeDependentFloat = 0.1
    dang_max_free_spherical: float | TimeDependentFloat = 3.0

    # lower / upper bound on the actual delta values, in radians
    delta_ang_min: float | TimeDependentFloat = 0.0
    delta_ang_max: float | TimeDependentFloat = 2 * jnp.pi
    delta_ang_min_free_spherical: float | TimeDependentFloat = 0.0
    delta_ang_max_free_spherical: float | TimeDependentFloat = 2 * jnp.pi

    # translation: speed bounds (dpos_*) and position bounds (pos_*)
    dpos_min: float | TimeDependentFloat = 0.001
    dpos_max: float | TimeDependentFloat = 0.7
    pos_min: float | TimeDependentFloat = -2.5
    pos_max: float | TimeDependentFloat = +2.5

    # cdf bins for `random_angle_*` and `random_pos_*`; they only matter when
    # randomized interpolation is enabled. `cdf_bins_max` falls back to
    # `cdf_bins_min` by default.
    cdf_bins_min: int = 5
    cdf_bins_max: Optional[int] = None

    # feature switches and method selectors
    randomized_interpolation_angle: bool = False
    randomized_interpolation_position: bool = False
    interpolation_method: str = "cosine"
    range_of_motion_hinge: bool = True
    range_of_motion_hinge_method: str = "uniform"

    # ranges for the initial joint values
    ang0_min: float = -jnp.pi
    ang0_max: float = jnp.pi
    pos0_min: float = 0.0
    pos0_max: float = 0.0

    # custom fields for the cor (center of rotation)
    cor: bool = False
    cor_t_min: float = 0.2
    cor_t_max: float | TimeDependentFloat = 2.0
    cor_dpos_min: float | TimeDependentFloat = 0.00001
    cor_dpos_max: float | TimeDependentFloat = 0.5
    cor_pos_min: float | TimeDependentFloat = -0.4
    cor_pos_max: float | TimeDependentFloat = 0.4

    def is_feasible(self) -> bool:
        """Return whether the bounds are consistent, per `_is_feasible_config1`."""
        return _is_feasible_config1(self)

    def to_nomotion_config(self) -> "MotionConfig":
        """Return a new config identical to this one except for zeroed speeds.

        The angular-velocity bounds, the lower delta-angle bounds and the
        translation-speed bounds are set to ``0.0``; every other field is
        copied. The resulting config is asserted to be feasible.
        """
        zeroed_fields = (
            "dang_min",
            "dang_max",
            "delta_ang_min",
            "dang_min_free_spherical",
            "dang_max_free_spherical",
            "delta_ang_min_free_spherical",
            "dpos_min",
            "dpos_max",
        )
        overrides = dict.fromkeys(zeroed_fields, 0.0)
        still_config = MotionConfig(**{**asdict(self), **overrides})
        assert still_config.is_feasible()
        return still_config
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
def _is_feasible_config1(c: MotionConfig) -> bool:
    """Check that the angular velocity and delta-angle bounds of `c` fit together.

    The check runs once for the regular joints (`dang_*`, `delta_ang_*`) and once
    for free/spherical joints (`*_free_spherical`). With `t_min` / `t_max` the
    bounds on the time between two generated angles, both of these must hold:

    - covering the largest delta in the shortest time reaches the minimum
      velocity: ``delta_max / t_min >= dx_min``;
    - covering the smallest delta in the longest time does not exceed the
      maximum velocity: ``delta_min / t_max <= dx_max``.

    Possibly time-dependent values are converted with ``_to_float(v, 0.0)``
    (presumably their value at t=0 -- confirm against `_random._to_float`).

    Returns:
        True if both conditions hold for both joint groups.
    """
    t_min, t_max = c.t_min, _to_float(c.t_max, 0.0)

    def dx_deltax_check(dx_min, dx_max, deltax_min, deltax_max) -> bool:
        dx_min, dx_max, deltax_min, deltax_max = (
            _to_float(v, 0.0) for v in (dx_min, dx_max, deltax_min, deltax_max)
        )
        # Compared in multiplied-out form. For positive times this is identical
        # to `deltax_max / t_min < dx_min` and `deltax_min / t_max > dx_max`, but
        # a zero time bound no longer raises ZeroDivisionError. t_min == 0 means
        # an unbounded max speed. t_max == 0 requires deltax_min == 0.
        if deltax_max < dx_min * t_min:
            return False
        if deltax_min > dx_max * t_max:
            return False
        return True

    return all(
        dx_deltax_check(*args)
        for args in zip(
            [c.dang_min, c.dang_min_free_spherical],
            [c.dang_max, c.dang_max_free_spherical],
            [c.delta_ang_min, c.delta_ang_min_free_spherical],
            [c.delta_ang_max, c.delta_ang_max_free_spherical],
        )
    )
| 
      
 116 
     | 
    
         
            +
             
     | 
| 
      
 117 
     | 
    
         
            +
             
     | 
def _find_interval(t: jax.Array, boundaries: jax.Array):
    """Return the index of the interval of `boundaries` in which `t` lies.

    The index is the count of boundaries that are less than or equal to `t`.
    Values before the first boundary therefore map to 0, and values at or after
    the last boundary map to ``len(boundaries)``. The count is only a
    meaningful interval index when `boundaries` is sorted in ascending order.

    Args:
        t: Scalar float (e.g. time)
        boundaries: Array of floats

    Example: (from `test_jcalc.py`)
        >> _find_interval(1.5, jnp.array([0.0, 1.0, 2.0])) -> 2
        >> _find_interval(0.5, jnp.array([0.0])) -> 1
        >> _find_interval(-0.5, jnp.array([0.0])) -> 0
    """
    assert boundaries.ndim == 1

    # 1 for every boundary already passed by `t`, 0 otherwise
    passed = jax.vmap(lambda boundary: jnp.where(t >= boundary, 1, 0))(boundaries)
    return jnp.sum(passed)
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
             
     | 
def join_motionconfigs(
    configs: list[MotionConfig], boundaries: list[float]
) -> MotionConfig:
    """Join several configs into one that switches between them over time.

    `configs[i]` is active for times `t` with
    ``boundaries[i-1] <= t < boundaries[i]``. The first config extends to
    -inf and the last one to +inf.

    Args:
        configs: Configs to join. Every field that is not annotated as
            ``float | TimeDependentFloat`` must have the same value in all of
            them. The time-dependent fields are stacked with `jnp.array`, so
            here they are expected to hold plain scalars.
        boundaries: Switching times in ascending order; exactly one fewer
            than `configs`.

    Returns:
        A copy of `configs[0]` in which every ``float | TimeDependentFloat``
        field is replaced by a function ``t -> value of the active config``.

    Raises:
        AssertionError: If the lengths do not match, `boundaries` is not
            sorted, or a time-independent field differs between `configs`.
    """
    assert len(configs) == (
        len(boundaries) + 1
    ), "length of `boundaries` should be one less than length of `configs`"
    # `_find_interval` counts the boundaries <= t. That count only selects the
    # right config when the boundaries are ascending; otherwise the wrong
    # config would be picked silently.
    assert all(
        b0 <= b1 for b0, b1 in zip(boundaries, boundaries[1:])
    ), "`boundaries` should be sorted in ascending order"
    boundaries = jnp.array(boundaries, dtype=float)

    def new_value(field_name: str):
        scalar_options = jnp.array([getattr(c, field_name) for c in configs])

        def scalar(t):
            return jax.lax.dynamic_index_in_dim(
                scalar_options, _find_interval(t, boundaries), keepdims=False
            )

        return scalar

    # Resolve the annotations on the class itself rather than on a throwaway
    # instance; the keys are exactly the dataclass fields in definition order.
    hints = get_type_hints(MotionConfig)
    time_dependent_type = float | TimeDependentFloat
    time_dependent_fields = [
        name for name, hint in hints.items() if hint == time_dependent_type
    ]
    time_independent_fields = [
        name for name, hint in hints.items() if hint != time_dependent_type
    ]

    for name in time_independent_fields:
        field_values = {getattr(config, name) for config in configs}
        assert (
            len(field_values) == 1
        ), f"MotionConfig.{name}={field_values}. Should be one unique value.."

    changes = {name: new_value(name) for name in time_dependent_fields}
    return replace(configs[0], **changes)
| 
      
 171 
     | 
    
         
            +
             
     | 
| 
      
 172 
     | 
    
         
            +
             
     | 
DRAW_FN = Callable[
    # config, key_t, key_value, dt, params
    # The PRNG keys are annotated as `jax.Array`. `jax.random.PRNGKey` is a
    # function that creates keys, not a type, so it does not belong here.
    [MotionConfig, jax.Array, jax.Array, float, jax.Array],
    jax.Array,
]
P_CONTROL_TERM = Callable[
    # q, q_ref -> qdd
    # (q_size,), (q_size,) -> (qd_size,)
    [jax.Array, jax.Array],
    jax.Array,
]
# Generates the velocity reference trajectory from the reference trajectory q.
# Both trajectories are needed by the pd control, which is required when the
# simulation is dynamic rather than kinematic.
QD_FROM_Q = Callable[
    # qs, dt -> dqs
    # (N, q_size), (1,) -> (N, qd_size)
    [jax.Array, jax.Array],
    jax.Array,
]
# Used by ring.algorithms.inverse_kinematics_endeffector to map from
# [-inf, inf] to the feasible joint value range. Defaults to {}.
# For example, a hinge joint uses `maths.wrap_to_pi` by default, and a
# spherical joint would normalize to a unit quaternion.
COORDINATE_VECTOR_TO_Q = Callable[
    # (q_size,) -> (q_size,)
    [jax.Array],
    jax.Array,
]

# Used only by `sim2real.project_xs`. It receives a transform object, projects
# it onto the feasible subspace defined by the joint, and returns the new
# transform object.
PROJECT_TRANSFORM_TO_FEASIBLE = Callable[
    # base.Transform, Pytree (joint_params)
    [base.Transform, tree_utils.PyTree],
    base.Transform,
]

# Used by ring.System.from_xml and by ring.RCMG.
# (key) -> Pytree
# If it is not given (None), the custom joint gets no specific joint
# parameters and simply receives the default parameters, i.e.
# joint_params['default'].
INIT_JOINT_PARAMS = Callable[[jax.Array], tree_utils.PyTree]

# (transform2_p_to_i, joint_params) -> (q_size,)
INV_KIN = Callable[[base.Transform, tree_utils.PyTree], jax.Array]
| 
      
 221 
     | 
    
         
            +
             
     | 
| 
      
 222 
     | 
    
         
            +
             
     | 
@dataclass
class JointModel:
    """Collection of the callables that define a joint type.

    Only `transform` is required. Every other entry is optional, and the
    comment next to each one names the part of the library that uses it.
    """

    # (q, params) -> Transform
    transform: Callable[[jax.Array, jax.Array], base.Transform]
    # len(motion) == len(qd)
    # an entry may also be a callable: joint_params -> base.Motion
    motion: list[base.Motion | Callable[[jax.Array], base.Motion]] = field(
        default_factory=list
    )
    # (config, key_t, key_value, dt, params) -> jax.Array, same signature as `DRAW_FN`
    rcmg_draw_fn: Optional[DRAW_FN] = None

    # only used by `pd_control`
    p_control_term: Optional[P_CONTROL_TERM] = None
    qd_from_q: Optional[QD_FROM_Q] = None

    # used by
    # - `inverse_kinematics_endeffector`
    # - System.coordinate_vector_to_q
    coordinate_vector_to_q: Optional[COORDINATE_VECTOR_TO_Q] = None

    # only used by `inverse_kinematics`
    inv_kin: Optional[INV_KIN] = None

    # (key) -> Pytree of joint parameters, see `INIT_JOINT_PARAMS`
    init_joint_params: Optional[INIT_JOINT_PARAMS] = None

    # free-form extra entries; a fresh dict per instance
    utilities: Optional[dict[str, Any]] = field(default_factory=dict)
| 
      
 250 
     | 
    
         
            +
             
     | 
| 
      
 251 
     | 
    
         
            +
             
     | 
| 
      
 252 
     | 
    
         
            +
            def _free_transform(q, _):
                """Build the transform of a free (6-DoF) joint.

                The first four entries of `q` are used as the rotation, the remaining
                entries as the position. The second argument is unused.
                """
                quat = q[:4]
                translation = q[4:]
                return base.Transform(translation, quat)
         
     | 
| 
      
 255 
     | 
    
         
            +
             
     | 
| 
      
 256 
     | 
    
         
            +
             
     | 
| 
      
 257 
     | 
    
         
            +
            def _free_2d_transform(q, _):
                """Build the transform of a planar free joint.

                `q[0]` is a rotation angle about the x unit vector. `q[1:]` supplies
                the y/z components of the position; the x component is fixed to zero.
                The second argument is unused.
                """
                x_angle = q[0]
                rotation = maths.quat_rot_axis(maths.x_unit_vector, x_angle)
                position = jnp.concatenate((jnp.array([0.0]), q[1:]))
                return base.Transform(position, rotation)
         
     | 
| 
      
 262 
     | 
    
         
            +
             
     | 
| 
      
 263 
     | 
    
         
            +
             
     | 
| 
      
 264 
     | 
    
         
            +
            def _rxyz_transform(q, _, axis):
                """Transform of a revolute joint: rotation by angle `q` about `axis`.

                `q` is squeezed so a single-element array acts as a scalar angle.
                """
                angle = jnp.squeeze(q)
                return base.Transform.create(rot=maths.quat_rot_axis(axis, angle))
         
     | 
| 
      
 268 
     | 
    
         
            +
             
     | 
| 
      
 269 
     | 
    
         
            +
             
     | 
| 
      
 270 
     | 
    
         
            +
            def _pxyz_transform(q, _, direction):
                """Transform of a prismatic joint: translation `q` along `direction`."""
                return base.Transform.create(pos=direction * q)
         
     | 
| 
      
 273 
     | 
    
         
            +
             
     | 
| 
      
 274 
     | 
    
         
            +
             
     | 
| 
      
 275 
     | 
    
         
            +
            def _frozen_transform(_, __):
                """Transform of a frozen joint; both arguments are ignored.

                Always returns `base.Transform.zero()`.
                """
                zero_transform = base.Transform.zero()
                return zero_transform
         
     | 
| 
      
 277 
     | 
    
         
            +
             
     | 
| 
      
 278 
     | 
    
         
            +
             
     | 
| 
      
 279 
     | 
    
         
            +
            def _spherical_transform(q, _):
                """Transform of a spherical joint: `q` is used directly as the rotation."""
                rotation_only = base.Transform.create(rot=q)
                return rotation_only
         
     | 
| 
      
 281 
     | 
    
         
            +
             
     | 
| 
      
 282 
     | 
    
         
            +
             
     | 
| 
      
 283 
     | 
    
         
            +
            def _saddle_transform(q, _):
                """Transform of a two-DoF saddle joint.

                `q[0]` and `q[1]` become the second and third Euler angles passed to
                `maths.euler_to_quat`; the first Euler angle is fixed to zero.
                """
                euler = jnp.array([0.0, q[0], q[1]])
                return base.Transform.create(rot=maths.euler_to_quat(euler))
         
     | 
| 
      
 286 
     | 
    
         
            +
             
     | 
| 
      
 287 
     | 
    
         
            +
             
     | 
| 
      
 288 
     | 
    
         
            +
            def _p3d_transform(q, _):
                """Transform of a 3-DoF translational joint: `q` is the position."""
                translation = q
                return base.Transform.create(pos=translation)
         
     | 
| 
      
 290 
     | 
    
         
            +
             
     | 
| 
      
 291 
     | 
    
         
            +
             
     | 
| 
      
 292 
     | 
    
         
            +
            def _cor_transform(q, _):
                """Transform of the `cor` joint.

                `q[:7]` is interpreted as a free joint (see `_free_transform`) and
                `q[7:]` as an extra translation (see `_p3d_transform`). The result is
                `algebra.transform_mul(translation, free)`.
                """
                free_tf = _free_transform(q[:7], _)
                offset_tf = _p3d_transform(q[7:], _)
                return algebra.transform_mul(offset_tf, free_tf)
         
     | 
| 
      
 296 
     | 
    
         
            +
             
     | 
| 
      
 297 
     | 
    
         
            +
             
     | 
| 
      
 298 
     | 
    
         
            +
            # Unit motion vectors for the single-axis joints. The `mr*` vectors set the
            # `ang` component to the x/y/z unit vector; the `mp*` vectors set the `vel`
            # component instead.
            mrx = base.Motion.create(ang=jnp.array([1.0, 0.0, 0.0]))
            mry = base.Motion.create(ang=jnp.array([0.0, 1.0, 0.0]))
            mrz = base.Motion.create(ang=jnp.array([0.0, 0.0, 1.0]))
            mpx = base.Motion.create(vel=jnp.array([1.0, 0.0, 0.0]))
            mpy = base.Motion.create(vel=jnp.array([0.0, 1.0, 0.0]))
            mpz = base.Motion.create(vel=jnp.array([0.0, 0.0, 1.0]))
         
     | 
| 
      
 304 
     | 
    
         
            +
             
     | 
| 
      
 305 
     | 
    
         
            +
             
     | 
| 
      
 306 
     | 
    
         
            +
            def _draw_rxyz(
                config: MotionConfig,
                key_t: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                _: jax.Array,
                # TODO: delete these args and pass a modified `config` with `replace` instead
                enable_range_of_motion: bool = True,
                free_spherical: bool = False,
            ) -> jax.Array:
                """Draw a random angle trajectory for a single-axis revolute joint.

                The initial angle is sampled uniformly from
                [`config.ang0_min`, `config.ang0_max`] and wrapped with
                `maths.wrap_to_pi`. The trajectory itself is produced by
                `_random.random_angle_over_time`.

                Args:
                    config: Motion limits and interpolation settings.
                    key_t: PRNG key forwarded to `random_angle_over_time`.
                    key_value: PRNG key. One subkey is split off for the initial angle
                        and the remainder is forwarded.
                    dt: Timestep size, forwarded to `random_angle_over_time`.
                    _: Unused.
                    enable_range_of_motion: If False, `False` is passed in place of
                        `config.range_of_motion_hinge`.
                    free_spherical: If True, the `*_free_spherical` variants of the
                        `dang_*` and `delta_ang_*` limits are used.
                """
                key_value, key_ang0 = jax.random.split(key_value)
                ang0 = jax.random.uniform(
                    key_ang0, minval=config.ang0_min, maxval=config.ang0_max
                )
                # `random_angle_over_time` always returns wrapped angles, so allowing an
                # unwrapped initial value would be inconsistent.
                ang0 = maths.wrap_to_pi(ang0)

                if free_spherical:
                    dang_min = config.dang_min_free_spherical
                    dang_max = config.dang_max_free_spherical
                    delta_ang_min = config.delta_ang_min_free_spherical
                    delta_ang_max = config.delta_ang_max_free_spherical
                else:
                    dang_min = config.dang_min
                    dang_max = config.dang_max
                    delta_ang_min = config.delta_ang_min
                    delta_ang_max = config.delta_ang_max

                range_of_motion = (
                    config.range_of_motion_hinge if enable_range_of_motion else False
                )
                # only used for the `delta_ang_min_max` logic
                max_iter = 5

                return _random.random_angle_over_time(
                    key_t,
                    key_value,
                    ang0,
                    dang_min,
                    dang_max,
                    delta_ang_min,
                    delta_ang_max,
                    config.t_min,
                    config.t_max,
                    config.T,
                    dt,
                    max_iter,
                    config.randomized_interpolation_angle,
                    range_of_motion,
                    config.range_of_motion_hinge_method,
                    config.cdf_bins_min,
                    config.cdf_bins_max,
                    config.interpolation_method,
                )
         
     | 
| 
      
 343 
     | 
    
         
            +
             
     | 
| 
      
 344 
     | 
    
         
            +
             
     | 
| 
      
 345 
     | 
    
         
            +
            def _draw_pxyz(
                config: MotionConfig,
                _: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                __: jax.Array,
                cor: bool = False,
            ) -> jax.Array:
                """Draw a random position trajectory for a single-axis prismatic joint.

                The initial position is sampled uniformly from
                [`config.pos0_min`, `config.pos0_max`]. The trajectory itself is
                produced by `_random.random_position_over_time`. With `cor=True`, the
                `cor_*` position, dpos and time limits of `config` replace the regular
                ones. The time key (second argument) and the last positional argument
                are unused.
                """
                key_value, key_pos0 = jax.random.split(key_value)
                pos0 = jax.random.uniform(
                    key_pos0, minval=config.pos0_min, maxval=config.pos0_max
                )

                if cor:
                    limits = (
                        config.cor_pos_min,
                        config.cor_pos_max,
                        config.cor_dpos_min,
                        config.cor_dpos_max,
                        config.cor_t_min,
                        config.cor_t_max,
                    )
                else:
                    limits = (
                        config.pos_min,
                        config.pos_max,
                        config.dpos_min,
                        config.dpos_max,
                        config.t_min,
                        config.t_max,
                    )
                max_iter = 100

                return _random.random_position_over_time(
                    key_value,
                    pos0,
                    *limits,
                    config.T,
                    dt,
                    max_iter,
                    config.randomized_interpolation_position,
                    config.cdf_bins_min,
                    config.cdf_bins_max,
                    config.interpolation_method,
                )
         
     | 
| 
      
 373 
     | 
    
         
            +
             
     | 
| 
      
 374 
     | 
    
         
            +
             
     | 
| 
      
 375 
     | 
    
         
            +
            def _draw_spherical(
                config: MotionConfig,
                key_t: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                _: jax.Array,
            ) -> jax.Array:
                """Draw a random quaternion trajectory for a spherical joint.

                NOTE: Three Euler-angle trajectories are drawn independently with
                `_draw_rxyz` (using the `*_free_spherical` limits and no
                range-of-motion constraint) and then converted with
                `maths.quat_euler`. This is not ideal, but no better approach is
                known.
                """

                def one_angle(k_t, k_value):
                    return _draw_rxyz(
                        config,
                        k_t,
                        k_value,
                        dt,
                        None,
                        enable_range_of_motion=False,
                        free_spherical=True,
                    )

                keys_t = jax.random.split(key_t, 3)
                keys_value = jax.random.split(key_value, 3)
                # vmap stacks the three angles along the leading axis; transpose so the
                # angles form the last axis.
                euler_angles = jax.vmap(one_angle)(keys_t, keys_value).T
                return maths.quat_euler(euler_angles)
         
     | 
| 
      
 400 
     | 
    
         
            +
             
     | 
| 
      
 401 
     | 
    
         
            +
             
     | 
| 
      
 402 
     | 
    
         
            +
            def _draw_saddle(
                config: MotionConfig,
                key_t: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                _: jax.Array,
            ) -> jax.Array:
                """Draw the two Euler-angle trajectories of a saddle joint.

                Each angle is drawn by `_draw_rxyz` with the regular (non
                free-spherical) limits and without range-of-motion constraint. The
                stacked result is transposed so the two angles form the last axis.
                """

                def one_angle(k_t, k_value):
                    return _draw_rxyz(
                        config,
                        k_t,
                        k_value,
                        dt,
                        None,
                        enable_range_of_motion=False,
                        free_spherical=False,
                    )

                keys_t = jax.random.split(key_t)
                keys_value = jax.random.split(key_value)
                return jax.vmap(one_angle)(keys_t, keys_value).T
         
     | 
| 
      
 424 
     | 
    
         
            +
             
     | 
| 
      
 425 
     | 
    
         
            +
             
     | 
| 
      
 426 
     | 
    
         
            +
            def _draw_p3d_and_cor(
                config: MotionConfig,
                _: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                __: jax.Array,
                cor: bool,
            ) -> jax.Array:
                """Draw three independent position trajectories.

                Each of the three coordinates is drawn by `_draw_pxyz` from its own
                subkey of `key_value`. `cor` selects the `cor_*` limits. The stacked
                result is transposed so the coordinates form the last axis. The time
                key (second argument) is unused.
                """

                def one_axis(key):
                    return _draw_pxyz(config, None, key, dt, None, cor)

                subkeys = jax.random.split(key_value, 3)
                positions = jax.vmap(one_axis)(subkeys)
                return positions.T
         
     | 
| 
      
 438 
     | 
    
         
            +
             
     | 
| 
      
 439 
     | 
    
         
            +
             
     | 
| 
      
 440 
     | 
    
         
            +
            def _draw_p3d(
                config: MotionConfig,
                _: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                __: jax.Array,
            ) -> jax.Array:
                """Draw a 3D position trajectory using the regular (non-`cor`) limits."""
                use_cor_limits = False
                return _draw_p3d_and_cor(
                    config, _, key_value, dt, None, cor=use_cor_limits
                )
         
     | 
| 
      
 448 
     | 
    
         
            +
             
     | 
| 
      
 449 
     | 
    
         
            +
             
     | 
| 
      
 450 
     | 
    
         
            +
            def _draw_cor(
                config: MotionConfig,
                _: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                __: jax.Array,
            ) -> jax.Array:
                """Draw a trajectory for the `cor` joint.

                The result concatenates, along axis 1, a free-joint trajectory
                (`_draw_free`) and a position trajectory drawn with the `cor_*`
                limits. Despite its placeholder name, the second argument is used:
                it is forwarded as the time key to `_draw_free`.
                """
                key_free, key_offset = jax.random.split(key_value)
                q_free = _draw_free(config, _, key_free, dt, None)
                q_offset = _draw_p3d_and_cor(config, _, key_offset, dt, None, cor=True)
                return jnp.concatenate([q_free, q_offset], axis=1)
         
     | 
| 
      
 461 
     | 
    
         
            +
             
     | 
| 
      
 462 
     | 
    
         
            +
             
     | 
| 
      
 463 
     | 
    
         
            +
            def _draw_free(
                config: MotionConfig,
                key_t: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                __: jax.Array,
            ) -> jax.Array:
                """Draw a free-joint trajectory.

                The result is a quaternion trajectory (`_draw_spherical`) concatenated
                along axis 1 with a 3D position trajectory (`_draw_p3d`).
                """
                key_rot, key_pos = jax.random.split(key_value)
                quats = _draw_spherical(config, key_t, key_rot, dt, None)
                positions = _draw_p3d(config, None, key_pos, dt, None)
                return jnp.concatenate([quats, positions], axis=1)
         
     | 
| 
      
 474 
     | 
    
         
            +
             
     | 
| 
      
 475 
     | 
    
         
            +
             
     | 
| 
      
 476 
     | 
    
         
            +
            def _draw_free_2d(
                config: MotionConfig,
                key_t: jax.random.PRNGKey,
                key_value: jax.random.PRNGKey,
                dt: float,
                __: jax.Array,
            ) -> jax.Array:
                """Draw a planar free-joint trajectory: one angle followed by two positions.

                The angle is drawn by `_draw_rxyz` with the `*_free_spherical` limits and
                without range-of-motion constraint. The positions are the first two
                columns of a full `_draw_p3d` trajectory.
                """
                key_angle, key_pos = jax.random.split(key_value)
                angle_x = _draw_rxyz(
                    config,
                    key_t,
                    key_angle,
                    dt,
                    None,
                    enable_range_of_motion=False,
                    free_spherical=True,
                )
                pos_yz = _draw_p3d(config, None, key_pos, dt, None)[:, :2]
                return jnp.concatenate([angle_x[:, None], pos_yz], axis=1)
         
     | 
| 
      
 495 
     | 
    
         
            +
             
     | 
| 
      
 496 
     | 
    
         
            +
             
     | 
| 
      
 497 
     | 
    
         
            +
            def _draw_frozen(config: MotionConfig, _, __, dt: float, ___) -> jax.Array:
                """Return an empty trajectory for a frozen joint.

                The result has shape `(int(config.T / dt), 0)`: one row per timestep and
                no columns.
                """
                n_timesteps = int(config.T / dt)
                empty_shape = (n_timesteps, 0)
                return jnp.zeros(empty_shape)
         
     | 
| 
      
 500 
     | 
    
         
            +
             
     | 
| 
      
 501 
     | 
    
         
            +
             
     | 
| 
      
 502 
     | 
    
         
            +
            def qrel(q1, q2):
                """Return `maths.quat_mul(q1, maths.quat_inv(q2))`.

                This is the quaternion product of `q1` with the inverse of `q2`.
                """
                q2_inverse = maths.quat_inv(q2)
                return maths.quat_mul(q1, q2_inverse)
         
     | 
| 
      
 503 
     | 
    
         
            +
             
     | 
| 
      
 504 
     | 
    
         
            +
             
     | 
| 
      
 505 
     | 
    
         
            +
            def _qd_from_q_quaternion(qs, dt):
                """Central-difference `qd` for a quaternion trajectory.

                For each interior timestep, `qrel(qs[t + 1], qs[t - 1])` is converted
                to an (axis, angle) pair, and `axis * angle / (2 * dt)` is returned.
                The first and last rows are zero.
                """
                axis, angle = maths.quat_to_rot_axis(qrel(qs[2:], qs[:-2]))
                # axis.shape = (n_timesteps, 3); angle.shape = (n_timesteps,).
                # Add a singleton dimension to `angle`, otherwise broadcasting fails.
                interior = axis * angle[:, None] / (2 * dt)
                zero_row = jnp.zeros((3,))
                return jnp.vstack((zero_row, interior, zero_row))
         
     | 
| 
      
 512 
     | 
    
         
            +
             
     | 
| 
      
 513 
     | 
    
         
            +
             
     | 
| 
      
 514 
     | 
    
         
            +
            def _qd_from_q_cartesian(qs, dt):
         
     | 
| 
      
 515 
     | 
    
         
            +
                dq = jnp.vstack(
         
     | 
| 
      
 516 
     | 
    
         
            +
                    (jnp.zeros_like(qs[0]), (qs[2:] - qs[:-2]) / (2 * dt), jnp.zeros_like(qs[0]))
         
     | 
| 
      
 517 
     | 
    
         
            +
                )
         
     | 
| 
      
 518 
     | 
    
         
            +
                return dq
         
     | 
| 
      
 519 
     | 
    
         
            +
             
     | 
| 
      
 520 
     | 
    
         
            +
             
     | 
| 
      
 521 
     | 
    
         
            +
            def _p_control_quaternion(q, q_ref):
                """Proportional-control error between quaternions `q` and `q_ref`.

                Returns `axis * angle` of `qrel(q_ref, q)`.
                """
                rotation_error = qrel(q_ref, q)
                axis, angle = maths.quat_to_rot_axis(rotation_error)
                return axis * angle
         
     | 
| 
      
 524 
     | 
    
         
            +
             
     | 
| 
      
 525 
     | 
    
         
            +
             
     | 
| 
      
 526 
     | 
    
         
            +
            def _p_control_term_rxyz(q, q_ref):
                """Proportional-control error for a revolute joint.

                `q_ref` comes from rcmg and is therefore already wrapped. `q` is wrapped
                here before subtracting, and the difference is wrapped again.
                TODO: Currently state.q is not wrapped. Change that?
                """
                q_wrapped = maths.wrap_to_pi(q)
                return maths.wrap_to_pi(q_ref - q_wrapped)
         
     | 
| 
      
 530 
     | 
    
         
            +
             
     | 
| 
      
 531 
     | 
    
         
            +
             
     | 
| 
      
 532 
     | 
    
         
            +
            def _p_control_term_pxyz_p3d(q, q_ref):
         
     | 
| 
      
 533 
     | 
    
         
            +
                return q_ref - q
         
     | 
| 
      
 534 
     | 
    
         
            +
             
     | 
| 
      
 535 
     | 
    
         
            +
             
     | 
| 
      
 536 
     | 
    
         
            +
            def _p_control_term_frozen(q, q_ref):
                """P-control term of a frozen joint.

                A frozen joint has no motion subspaces, so the error is always an empty
                array; `q` and `q_ref` are ignored.
                """
                no_dofs = jnp.array([])
                return no_dofs
         
     | 
| 
      
 538 
     | 
    
         
            +
             
     | 
| 
      
 539 
     | 
    
         
            +
             
     | 
| 
      
 540 
     | 
    
         
            +
            def _p_control_term_spherical(q, q_ref):
                """P-control term of a spherical joint: the quaternion error of `q` w.r.t. `q_ref`."""
                rotation_error = _p_control_quaternion(q, q_ref)
                return rotation_error
         
     | 
| 
      
 542 
     | 
    
         
            +
             
     | 
| 
      
 543 
     | 
    
         
            +
             
     | 
| 
      
 544 
     | 
    
         
            +
            def _p_control_term_free(q, q_ref):
                """P-control term of a free joint.

                The first four coordinates are treated as a quaternion (quaternion error),
                the remaining ones as positions (plain difference). Both parts are
                concatenated in that order.
                """
                rotation_error = _p_control_quaternion(q[:4], q_ref[:4])
                translation_error = q_ref[4:] - q[4:]
                return jnp.concatenate((rotation_error, translation_error))
         
     | 
| 
      
 551 
     | 
    
         
            +
             
     | 
| 
      
 552 
     | 
    
         
            +
             
     | 
| 
      
 553 
     | 
    
         
            +
            def _p_control_term_free_2d(q, q_ref):
                """P-control term of a `free_2d` joint.

                Entry 0 is a revolute angle (wrapped error via the rxyz term); the
                remaining entries are positions (plain difference).
                """
                angle_error = _p_control_term_rxyz(q[:1], q_ref[:1])
                translation_error = q_ref[1:] - q[1:]
                return jnp.concatenate((angle_error, translation_error))
         
     | 
| 
      
 560 
     | 
    
         
            +
             
     | 
| 
      
 561 
     | 
    
         
            +
             
     | 
| 
      
 562 
     | 
    
         
            +
            def _p_control_term_cor(q, q_ref):
                """P-control term of a `cor` joint; identical to the free-joint term."""
                control_error = _p_control_term_free(q, q_ref)
                return control_error
         
     | 
| 
      
 564 
     | 
    
         
            +
             
     | 
| 
      
 565 
     | 
    
         
            +
             
     | 
| 
      
 566 
     | 
    
         
            +
            def _qd_from_q_free(qs, dt):
                """Convert a trajectory `qs` of free-joint coordinates (time step `dt`) into `qd`.

                Columns `[:4]` are handled by the quaternion helper and columns `[4:]`
                by the cartesian helper; the two results are joined with `jnp.hstack`.
                """
                quat_columns, position_columns = qs[:, :4], qs[:, 4:]
                return jnp.hstack(
                    (
                        _qd_from_q_quaternion(quat_columns, dt),
                        _qd_from_q_cartesian(position_columns, dt),
                    )
                )
         
     | 
| 
      
 570 
     | 
    
         
            +
             
     | 
| 
      
 571 
     | 
    
         
            +
             
     | 
| 
      
 572 
     | 
    
         
            +
            def _coordinate_vector_to_q_free_spherical_cor(q):
                """Return a copy of `q` whose leading quaternion `q[:4]` is re-normalized.

                All other entries are left unchanged.
                """
                unit_quat = maths.safe_normalize(q[:4])
                return q.at[:4].set(unit_quat)
         
     | 
| 
      
 574 
     | 
    
         
            +
             
     | 
| 
      
 575 
     | 
    
         
            +
             
     | 
| 
      
 576 
     | 
    
         
            +
            def _coordinate_vector_to_q_free_2d(q):
                """Return a copy of `q` whose angle entry `q[0]` is passed through `maths.wrap_to_pi`."""
                wrapped_angle = maths.wrap_to_pi(q[0])
                return q.at[0].set(wrapped_angle)
         
     | 
| 
      
 578 
     | 
    
         
            +
             
     | 
| 
      
 579 
     | 
    
         
            +
             
     | 
| 
      
 580 
     | 
    
         
            +
            # Axis letter -> one-element slice selecting that component of a 3-vector.
            _str2idx = {axis: slice(i, i + 1) for i, axis in enumerate("xyz")}
         
     | 
| 
      
 581 
     | 
    
         
            +
             
     | 
| 
      
 582 
     | 
    
         
            +
             
     | 
| 
      
 583 
     | 
    
         
            +
            def _inv_kin_rxyz_factory(xyz: str):
                """Build the inverse-kinematics function of a revolute joint about axis `xyz`.

                The returned callable takes `(transform, joint_params)` (the second
                argument is unused) and returns a one-element array holding the joint
                angle extracted from `transform.rot`.
                """
                k = maths.unit_vectors(xyz)

                def _inv_kin_rxyz(x: base.Transform, _) -> jax.Array:
                    # NOTE: CONVENTION
                    # Fast path: angle = 2 * atan2(<q_vec, k>, q_w), negated to match the
                    # quaternion convention of this library. This version is known to
                    # suffer from a convention issue (presumably the q / -q double cover,
                    # which can push the result outside [-pi, pi] -- TODO confirm).
                    # An equivalent formulation that avoids that issue, but is much
                    # slower, is:
                    #     axis, angle = maths.quat_to_rot_axis(maths.quat_project(q, k)[0])
                    #     return jnp.where((k @ axis) > 0, angle, -angle)[None]
                    q = x.rot
                    angle = 2 * jnp.arctan2(q[1:] @ k, q[0])
                    return -angle[None]

                return _inv_kin_rxyz
         
     | 
| 
      
 599 
     | 
    
         
            +
             
     | 
| 
      
 600 
     | 
    
         
            +
             
     | 
| 
      
 601 
     | 
    
         
            +
            def _inv_kin_pxyz_factory(xyz: str):
                """Build the inverse-kinematics function of a prismatic joint along axis `xyz`.

                The returned callable takes `(transform, joint_params)` (the second
                argument is unused) and returns the one-element slice of
                `transform.pos` for that axis. An unknown `xyz` fails here, at
                construction time, with a KeyError.
                """
                axis_slice = _str2idx[xyz]

                def _inv_kin_pxyz(x: base.Transform, _) -> jax.Array:
                    return x.pos[axis_slice]

                return _inv_kin_pxyz
         
     | 
| 
      
 608 
     | 
    
         
            +
             
     | 
| 
      
 609 
     | 
    
         
            +
             
     | 
| 
      
 610 
     | 
    
         
            +
            def _inv_kin_free_2d(x: base.Transform, _) -> jax.Array:
                """Inverse kinematics of the `free_2d` joint.

                Returns `[angle_x, pos_y, pos_z]`: the rotation angle about the x-axis
                (via the `rx` inverse kinematics), followed by `x.pos[1:]`.
                """
                # The rx inverse-kinematics function has signature
                # `(transform, joint_params)`; calling it with `x` alone raised a
                # TypeError for the missing positional argument.
                angle_x = _inv_kin_rxyz_factory("x")(x, _)
                return jnp.concatenate((angle_x, x.pos[1:]))
         
     | 
| 
      
 613 
     | 
    
         
            +
             
     | 
| 
      
 614 
     | 
    
         
            +
             
     | 
| 
      
 615 
     | 
    
         
            +
            # Registry of built-in joint types, keyed by joint-type name.
            # Positional JointModel arguments, in order: transform function, list of
            # motion subspaces (one per entry of `qd`), the `_draw_*` function, the
            # P-control term and the `qd`-from-`q` function. Some entries also pass a
            # coordinate-vector-to-`q` mapping and an inverse-kinematics function.
            _joint_types = dict(
                free=JointModel(
                    _free_transform,
                    [mrx, mry, mrz, mpx, mpy, mpz],
                    _draw_free,
                    _p_control_term_free,
                    _qd_from_q_free,
                    coordinate_vector_to_q=_coordinate_vector_to_q_free_spherical_cor,
                    inv_kin=lambda x, _: jnp.concatenate((x.rot, x.pos)),
                ),
                free_2d=JointModel(
                    _free_2d_transform,
                    [mrx, mpy, mpz],
                    _draw_free_2d,
                    _p_control_term_free_2d,
                    _qd_from_q_cartesian,
                    coordinate_vector_to_q=_coordinate_vector_to_q_free_2d,
                    inv_kin=_inv_kin_free_2d,
                ),
                frozen=JointModel(
                    _frozen_transform,
                    [],
                    _draw_frozen,
                    _p_control_term_frozen,
                    _qd_from_q_cartesian,
                    lambda q: q,
                    lambda x, _: jnp.array([]),
                ),
                spherical=JointModel(
                    _spherical_transform,
                    [mrx, mry, mrz],
                    _draw_spherical,
                    _p_control_term_spherical,
                    _qd_from_q_quaternion,
                    _coordinate_vector_to_q_free_spherical_cor,
                    lambda x, _: x.rot,
                ),
                p3d=JointModel(
                    _p3d_transform,
                    [mpx, mpy, mpz],
                    _draw_p3d,
                    _p_control_term_pxyz_p3d,
                    _qd_from_q_cartesian,
                    lambda q: q,
                    lambda x, _: x.pos,
                ),
                cor=JointModel(
                    _cor_transform,
                    [mrx, mry, mrz, mpx, mpy, mpz, mpx, mpy, mpz],
                    _draw_cor,
                    _p_control_term_cor,
                    _qd_from_q_free,
                    _coordinate_vector_to_q_free_spherical_cor,
                ),
                # Single-axis revolute joints.
                rx=JointModel(
                    lambda q, _: _rxyz_transform(q, _, jnp.array([1.0, 0, 0])),
                    [mrx],
                    _draw_rxyz,
                    _p_control_term_rxyz,
                    _qd_from_q_cartesian,
                    maths.wrap_to_pi,
                    _inv_kin_rxyz_factory("x"),
                ),
                ry=JointModel(
                    lambda q, _: _rxyz_transform(q, _, jnp.array([0.0, 1, 0])),
                    [mry],
                    _draw_rxyz,
                    _p_control_term_rxyz,
                    _qd_from_q_cartesian,
                    maths.wrap_to_pi,
                    _inv_kin_rxyz_factory("y"),
                ),
                rz=JointModel(
                    lambda q, _: _rxyz_transform(q, _, jnp.array([0.0, 0, 1])),
                    [mrz],
                    _draw_rxyz,
                    _p_control_term_rxyz,
                    _qd_from_q_cartesian,
                    maths.wrap_to_pi,
                    _inv_kin_rxyz_factory("z"),
                ),
                # Single-axis prismatic joints.
                px=JointModel(
                    lambda q, _: _pxyz_transform(q, _, jnp.array([1.0, 0, 0])),
                    [mpx],
                    _draw_pxyz,
                    _p_control_term_pxyz_p3d,
                    _qd_from_q_cartesian,
                    lambda q: q,
                    _inv_kin_pxyz_factory("x"),
                ),
                py=JointModel(
                    lambda q, _: _pxyz_transform(q, _, jnp.array([0.0, 1, 0])),
                    [mpy],
                    _draw_pxyz,
                    _p_control_term_pxyz_p3d,
                    _qd_from_q_cartesian,
                    lambda q: q,
                    _inv_kin_pxyz_factory("y"),
                ),
                pz=JointModel(
                    lambda q, _: _pxyz_transform(q, _, jnp.array([0.0, 0, 1])),
                    [mpz],
                    _draw_pxyz,
                    _p_control_term_pxyz_p3d,
                    _qd_from_q_cartesian,
                    lambda q: q,
                    _inv_kin_pxyz_factory("z"),
                ),
                saddle=JointModel(
                    _saddle_transform,
                    [mry, mrz],
                    _draw_saddle,
                    _p_control_term_rxyz,
                    _qd_from_q_cartesian,
                    maths.wrap_to_pi,
                ),
            )
         
     | 
| 
      
 732 
     | 
    
         
            +
             
     | 
| 
      
 733 
     | 
    
         
            +
             
     | 
| 
      
 734 
     | 
    
         
            +
            def get_joint_model(joint_type: str) -> JointModel:
                """Return the registered JointModel for `joint_type`.

                Raises:
                    AssertionError: If `joint_type` has not been registered.
                """
                registry = _joint_types
                assert joint_type in registry, f"{joint_type} not in {list(registry.keys())}"
                return registry[joint_type]
         
     | 
| 
      
 739 
     | 
    
         
            +
             
     | 
| 
      
 740 
     | 
    
         
            +
             
     | 
| 
      
 741 
     | 
    
         
            +
            def register_new_joint_type(
                joint_type: str,
                joint_model: JointModel,
                q_width: int,
                qd_width: Optional[int] = None,
                overwrite: bool = False,
            ):
                """Register a custom joint type so it can be used like a built-in one.

                Stores `joint_model` in the joint-type registry and records the widths
                of the joint's `q` and `qd` vectors in `base.Q_WIDTHS` and
                `base.QD_WIDTHS`.

                Args:
                    joint_type: Name of the joint type. "default" is reserved.
                    joint_model: Model of the joint; it must provide exactly `qd_width`
                        motion subspaces.
                    q_width: Number of entries of the joint's `q`.
                    qd_width: Number of entries of the joint's `qd`; defaults to `q_width`.
                    overwrite: If True, replace an already registered joint type of the
                        same name instead of failing.

                Raises:
                    AssertionError: If the name is reserved, already registered without
                        `overwrite=True`, or `len(joint_model.motion) != qd_width`.
                """
                # "default" is the fallback key of the joint-params dict.
                assert joint_type != "default", "Please use another name."

                exists = joint_type in _joint_types
                if exists:
                    assert (
                        overwrite
                    ), f"joint type `{joint_type}` already exists, use `overwrite=True`"

                if qd_width is None:
                    qd_width = q_width

                assert len(joint_model.motion) == qd_width

                # Mutate the registries only after every check has passed, so a failed
                # re-registration leaves the existing entry intact.
                if exists:
                    # Remove the old entry first so the re-registered type is appended at
                    # the end of each registry; tolerate registries lacking the key.
                    for registry in (base.Q_WIDTHS, base.QD_WIDTHS, _joint_types):
                        registry.pop(joint_type, None)

                _joint_types[joint_type] = joint_model
                base.Q_WIDTHS[joint_type] = q_width
                base.QD_WIDTHS[joint_type] = qd_width
         
     | 
| 
      
 772 
     | 
    
         
            +
             
     | 
| 
      
 773 
     | 
    
         
            +
             
     | 
| 
      
 774 
     | 
    
         
            +
            def _limit_scope_of_joint_params(
                joint_type: str, joint_params: dict[str, tree_utils.PyTree]
            ) -> tree_utils.PyTree:
                """Select the parameters of `joint_type` from `joint_params`.

                Falls back to the "default" entry when `joint_type` has no entry of its own.
                """
                if joint_type in joint_params:
                    return joint_params[joint_type]
                return joint_params["default"]
         
     | 
| 
      
 781 
     | 
    
         
            +
             
     | 
| 
      
 782 
     | 
    
         
            +
             
     | 
| 
      
 783 
     | 
    
         
            +
            def jcalc_transform(
                joint_type: str, q: jax.Array, joint_params: dict[str, tree_utils.PyTree]
            ) -> base.Transform:
                """Evaluate the transform of a `joint_type` joint at coordinates `q`."""
                scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
                model = _joint_types[joint_type]
                return model.transform(q, scoped_params)
         
     | 
| 
      
 788 
     | 
    
         
            +
             
     | 
| 
      
 789 
     | 
    
         
            +
             
     | 
| 
      
 790 
     | 
    
         
            +
            def _to_motion(
                m: base.Motion | Callable[[jax.Array], base.Motion], joint_params: tree_utils.PyTree
            ) -> base.Motion:
                """Return `m` unchanged if it is a Motion; otherwise evaluate `m(joint_params)`."""
                return m if isinstance(m, base.Motion) else m(joint_params)
         
     | 
| 
      
 796 
     | 
    
         
            +
             
     | 
| 
      
 797 
     | 
    
         
            +
             
     | 
| 
      
 798 
     | 
    
         
            +
            def jcalc_motion(
                joint_type: str, qd: jax.Array, joint_params: dict[str, tree_utils.PyTree]
            ) -> base.Motion:
                """Combine the joint's motion subspaces weighted by the entries of `qd`.

                Starts from `base.Motion.zero()` and adds `motion_i * qd[i]` for every
                motion subspace of `joint_type`.
                """
                scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
                subspaces = _joint_types[joint_type].motion
                total = base.Motion.zero()
                for i, subspace in enumerate(subspaces):
                    total += _to_motion(subspace, scoped_params) * qd[i]
                return total
         
     | 
| 
      
 807 
     | 
    
         
            +
             
     | 
| 
      
 808 
     | 
    
         
            +
             
     | 
| 
      
 809 
     | 
    
         
            +
            def jcalc_tau(
                joint_type: str, f: base.Force, joint_params: dict[str, tree_utils.PyTree]
            ) -> jax.Array:
                """Project force `f` onto each motion subspace of `joint_type`.

                Returns one entry per subspace, `algebra.motion_dot(subspace, f)`.
                """
                scoped_params = _limit_scope_of_joint_params(joint_type, joint_params)
                subspaces = _joint_types[joint_type].motion
                projections = []
                for subspace in subspaces:
                    projections.append(
                        algebra.motion_dot(_to_motion(subspace, scoped_params), f)
                    )
                return jnp.array(projections)
         
     | 
| 
      
 817 
     | 
    
         
            +
             
     | 
| 
      
 818 
     | 
    
         
            +
             
     | 
| 
      
 819 
     | 
    
         
            +
            def _init_joint_params(key: jax.Array, sys: base.System) -> base.System:
                """Initialize per-link joint parameters for joint types that define them.

                For every joint type in `sys.link_types` whose JointModel has a non-None
                `init_joint_params`, a fresh batch of `sys.num_links()` PRNG keys is
                split off `key` and the init function is vmapped over it. A zero-width
                "default" entry is always added. The system is returned with
                `links.joint_params` replaced.
                """
                # Unique link types in order of first appearance; the order determines
                # how `key` is split, so it must stay deterministic.
                init_fns = {
                    typ: _joint_types[typ].init_joint_params
                    for typ in dict.fromkeys(sys.link_types)
                    if _joint_types[typ].init_joint_params is not None
                }

                params_by_type: dict[str, tree_utils.PyTree] = {}
                n_links = sys.num_links()
                for typ, init_fn in init_fns.items():
                    split_keys = jax.random.split(key, num=n_links + 1)
                    key = split_keys[0]
                    params_by_type[typ] = jax.vmap(init_fn)(split_keys[1:])

                # Batched fallback used by joint types without their own parameters.
                params_by_type["default"] = jnp.zeros((n_links, 0))

                return sys.replace(links=sys.links.replace(joint_params=params_by_type))
         
     |