imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
| 
         @@ -0,0 +1,403 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from typing import Callable, Optional
         
     | 
| 
      
 2 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 5 
     | 
    
         
            +
            from jax import random
         
     | 
| 
      
 6 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 7 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            Float = jax.Array
         
     | 
| 
      
 10 
     | 
    
         
            +
            TimeDependentFloat = Callable[[Float], Float]
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            def _to_float(scalar: Float | TimeDependentFloat, t: Float) -> Float:
         
     | 
| 
      
 14 
     | 
    
         
            +
                if isinstance(scalar, Callable):
         
     | 
| 
      
 15 
     | 
    
         
            +
                    return scalar(t)
         
     | 
| 
      
 16 
     | 
    
         
            +
                return scalar
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            # APPROVED
            def random_angle_over_time(
                key_t: random.PRNGKey,
                key_ang: random.PRNGKey,
                ANG_0: float,
                dang_min: float | TimeDependentFloat,
                dang_max: float | TimeDependentFloat,
                delta_ang_min: float | TimeDependentFloat,
                delta_ang_max: float | TimeDependentFloat,
                t_min: float,
                t_max: float | TimeDependentFloat,
                T: float,
                Ts: float,
                max_iter: int = 5,
                randomized_interpolation: bool = False,
                range_of_motion: bool = False,
                range_of_motion_method: str = "uniform",
                cdf_bins_min: int = 5,
                cdf_bins_max: Optional[int] = None,
                interpolation_method: str = "cosine",
            ) -> jax.Array:
                """Draw a random angle trajectory sampled on ``jnp.arange(T, step=Ts)``.

                A sequence of knots ``(t_k, phi_k)`` is generated by a random walk that
                starts at ``(0, ANG_0)``. Each step draws the gap to the next knot uniformly
                from ``[t_min, t_max)`` (using ``key_t``) and proposes the next angle with
                ``_resolve_range_of_motion`` (using ``key_ang``). Time-dependent bounds are
                evaluated at the time of the previous knot. The walk stops once the knot
                time exceeds ``T``; the knots are then interpolated onto the output grid.

                Args:
                    key_t: PRNG key used for the knot spacing.
                    key_ang: PRNG key used for the knot angles; its final state is also
                        used by ``interpolate`` when ``randomized_interpolation`` is set.
                    ANG_0: Angle of the first knot at ``t = 0``.
                    dang_min: Forwarded (resolved at the previous knot time) to
                        ``_resolve_range_of_motion``.
                    dang_max: Same as ``dang_min``.
                    delta_ang_min: Same as ``dang_min``.
                    delta_ang_max: Same as ``dang_min``.
                    t_min: Minimum spacing between knots; also determines the size of the
                        preallocated knot buffer.
                    t_max: Upper bound of the knot spacing; constant or function of time.
                    T: Duration of the trajectory.
                    Ts: Sampling interval of the returned trajectory.
                    max_iter: Forwarded to ``_resolve_range_of_motion``.
                    randomized_interpolation: If set, interpolate with ``interpolate`` using
                        ``cdf_bins_min``, ``cdf_bins_max`` and ``interpolation_method``;
                        otherwise ``cosInterpolate`` is used.
                    range_of_motion: Forwarded to ``_resolve_range_of_motion``; when set,
                        the final ``maths.wrap_to_pi`` is skipped.
                    range_of_motion_method: Forwarded to ``_resolve_range_of_motion``.
                    cdf_bins_min: Forwarded to ``interpolate``.
                    cdf_bins_max: Forwarded to ``interpolate``.
                    interpolation_method: Forwarded to ``interpolate``. Any value other
                        than ``"cosine"`` without ``randomized_interpolation`` only emits a
                        warning.

                Returns:
                    The angle at every time in ``jnp.arange(T, step=Ts)``.
                """

                def body_fn_outer(val):
                    i, t, phi, key_t, key_ang, ANG = val

                    # gap to the next knot; bounds resolved at the previous knot time `t`
                    key_t, consume = random.split(key_t)
                    dt = random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t))

                    key_ang, consume = random.split(key_ang)
                    phi = _resolve_range_of_motion(
                        range_of_motion,
                        range_of_motion_method,
                        _to_float(dang_min, t),
                        _to_float(dang_max, t),
                        _to_float(delta_ang_min, t),
                        _to_float(delta_ang_max, t),
                        dt,
                        phi,
                        consume,
                        max_iter,
                    )
                    t += dt

                    # TODO do we really need the `jnp.floor(t / Ts) * Ts` since we resample later
                    # anyways
                    ANG_i = jnp.array([[jnp.floor(t / Ts) * Ts, phi]])
                    ANG = jax.lax.dynamic_update_slice_in_dim(ANG, ANG_i, start_index=i, axis=0)

                    return i + 1, t, phi, key_t, key_ang, ANG

                def cond_fn_outer(val):
                    i, t, phi, key_t, key_ang, ANG = val
                    return t <= T

                # preallocate ANG array
                _warn_huge_preallocation(t_min, T)
                # Every step advances `t` by at least `t_min` and the loop keeps stepping
                # while `t <= T`, so knots are written at indices 1 .. floor(T / t_min) + 1.
                # The buffer therefore needs floor(T / t_min) + 2 rows; with one row fewer
                # the last write is clamped by `dynamic_update_slice` onto the previous knot,
                # silently dropping it. `/` is used instead of `//` because float floor
                # division can underestimate (e.g. `1.0 // 0.1 == 9.0`).
                n_rows = int(T / t_min) + 2
                ANG = jnp.zeros((n_rows, 2))
                ANG = ANG.at[0, 1].set(ANG_0)

                val_outer = (1, 0.0, ANG_0, key_t, key_ang, ANG)
                end, *_, consume, ANG = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
                # rows that were never written (index >= end) repeat the last written knot
                ANG = jnp.where(
                    (jnp.arange(len(ANG)) < end)[:, None],
                    ANG,
                    jax.lax.dynamic_index_in_dim(ANG, end - 1),
                )

                # resample
                t = jnp.arange(T, step=Ts)
                if randomized_interpolation:
                    q = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
                        t, ANG[:, 0], ANG[:, 1], consume
                    )
                else:
                    if interpolation_method != "cosine":
                        warnings.warn(
                            f"You have select interpolation method {interpolation_method}. "
                            "Differnt choices of interpolation method are only available if "
                            "`randomized_interpolation` is set."
                        )
                    q = cosInterpolate(t, ANG[:, 0], ANG[:, 1])

                # if range_of_motion is true, then it is wrapped already
                if not range_of_motion:
                    q = maths.wrap_to_pi(q)

                return q
         
     | 
| 
      
 105 
     | 
    
         
            +
             
     | 
| 
      
 106 
     | 
    
         
            +
             
     | 
| 
      
 107 
     | 
    
         
            +
            # APPROVED
            def random_position_over_time(
                key: random.PRNGKey,
                POS_0: float,
                pos_min: float | TimeDependentFloat,
                pos_max: float | TimeDependentFloat,
                dpos_min: float | TimeDependentFloat,
                dpos_max: float | TimeDependentFloat,
                t_min: float,
                t_max: float | TimeDependentFloat,
                T: float,
                Ts: float,
                max_it: int,
                randomized_interpolation: bool = False,
                cdf_bins_min: int = 5,
                cdf_bins_max: Optional[int] = None,
                interpolation_method: str = "cosine",
            ) -> jax.Array:
                """Draw a random 1D position trajectory sampled on ``jnp.arange(T, step=Ts)``.

                Knots ``(t_k, x_k)`` are generated by a random walk starting at
                ``(0, POS_0)``. The gap to the next knot is drawn uniformly from
                ``[t_min, t_max)``. The next position is proposed as
                ``x_pre + sign * uniform(dpos_min, dpos_max) * dt`` with a random sign.
                A proposal is accepted when ``dpos_min < |x - x_pre| / dt < dpos_max`` and
                ``pos_min <= x <= pos_max``. At most ``max_it + 1`` proposals are made per
                knot; if none is accepted, the last proposal is kept. Time-dependent bounds
                are evaluated at the time of the previous knot. The walk stops once the
                knot time exceeds ``T``; the knots are then interpolated onto the output
                grid.

                Args:
                    key: PRNG key for knot spacing and proposals; its final state is also
                        used by ``interpolate`` when ``randomized_interpolation`` is set.
                    POS_0: Position at ``t = 0``; the walk starts from here.
                    pos_min: Lower bound of accepted knot positions; constant or function
                        of time.
                    pos_max: Upper bound of accepted knot positions; constant or function
                        of time.
                    dpos_min: Lower bound of ``|x - x_pre| / dt`` for proposals.
                    dpos_max: Upper bound of ``|x - x_pre| / dt`` for proposals.
                    t_min: Minimum spacing between knots; also determines the size of the
                        preallocated knot buffer.
                    t_max: Upper bound of the knot spacing; constant or function of time.
                    T: Duration of the trajectory.
                    Ts: Sampling interval of the returned trajectory.
                    max_it: Caps the number of proposals per knot at ``max_it + 1``.
                    randomized_interpolation: If set, interpolate with ``interpolate``;
                        otherwise ``cosInterpolate`` is used.
                    cdf_bins_min: Forwarded to ``interpolate``.
                    cdf_bins_max: Forwarded to ``interpolate``.
                    interpolation_method: Forwarded to ``interpolate``; ignored (without a
                        warning) when ``randomized_interpolation`` is not set.

                Returns:
                    The position at every time in ``jnp.arange(T, step=Ts)``.
                """

                def body_fn_inner(val):
                    # Propose a new knot value `x` relative to the previous knot `x_pre`.
                    # `cond_fn_inner` stops once `i > max_it`, so this body always samples
                    # (the former `lax.cond(i > max_it, ...)` branch was unreachable).
                    i, t, t_pre, x, x_pre, key = val
                    dt = t - t_pre

                    key, consume1, consume2 = random.split(key, 3)
                    sign = random.choice(consume1, jnp.array([-1.0, 1.0]))
                    speed = random.uniform(
                        consume2,
                        minval=_to_float(dpos_min, t_pre),
                        maxval=_to_float(dpos_max, t_pre),
                    )
                    x = x_pre + sign * speed * dt

                    return i + 1, t, t_pre, x, x_pre, key

                def cond_fn_inner(val):
                    i, t, t_pre, x, x_pre, key = val
                    dpos = jnp.abs((x - x_pre) / (t - t_pre))
                    break_if_true1 = (
                        (dpos < _to_float(dpos_max, t_pre))
                        & (dpos > _to_float(dpos_min, t_pre))
                        & (x >= _to_float(pos_min, t_pre))
                        & (x <= _to_float(pos_max, t_pre))
                    )
                    break_if_true2 = i > max_it
                    return jnp.logical_not(break_if_true1 | break_if_true2)

                def body_fn_outer(val):
                    i, t, t_pre, x, x_pre, key, POS = val
                    key, consume = random.split(key)
                    t += random.uniform(consume, minval=t_min, maxval=_to_float(t_max, t_pre))

                    # that zero resets the max_it count
                    val_inner = (0, t, t_pre, x, x_pre, key)
                    _, t, t_pre, x, x_pre, key = jax.lax.while_loop(
                        cond_fn_inner, body_fn_inner, val_inner
                    )

                    POS_i = jnp.array([[jnp.floor(t / Ts) * Ts, x]])
                    POS = jax.lax.dynamic_update_slice_in_dim(POS, POS_i, start_index=i, axis=0)
                    t_pre = t
                    x_pre = x
                    return i + 1, t, t_pre, x, x_pre, key, POS

                def cond_fn_outer(val):
                    i, t, t_pre, x, x_pre, key, POS = val
                    return t <= T

                # preallocate POS array
                _warn_huge_preallocation(t_min, T)
                # Every step advances `t` by at least `t_min` and the loop keeps stepping
                # while `t <= T`, so knots are written at indices 1 .. floor(T / t_min) + 1.
                # The buffer therefore needs floor(T / t_min) + 2 rows; with one row fewer
                # the last write is clamped by `dynamic_update_slice` onto the previous knot,
                # silently dropping it. `/` is used instead of `//` because float floor
                # division can underestimate (e.g. `1.0 // 0.1 == 9.0`).
                n_rows = int(T / t_min) + 2
                POS = jnp.zeros((n_rows, 2))
                POS = POS.at[0, 1].set(POS_0)

                # Start the walk at `POS_0`, the value stored in the first knot (previously
                # it started at 0.0, so for POS_0 != 0 the first proposal was made relative
                # to the wrong position).
                x_0 = jnp.asarray(POS_0, dtype=float)
                val_outer = (1, 0.0, 0.0, x_0, x_0, key, POS)
                end, *_, consume, POS = jax.lax.while_loop(cond_fn_outer, body_fn_outer, val_outer)
                # rows that were never written (index >= end) repeat the last written knot
                POS = jnp.where(
                    (jnp.arange(len(POS)) < end)[:, None],
                    POS,
                    jax.lax.dynamic_index_in_dim(POS, end - 1),
                )

                # resample
                t = jnp.arange(T, step=Ts)
                if randomized_interpolation:
                    r = interpolate(cdf_bins_min, cdf_bins_max, method=interpolation_method)(
                        t, POS[:, 0], POS[:, 1], consume
                    )
                else:
                    # Unlike the angle generator, no warning is emitted here for a
                    # non-cosine `interpolation_method`; it is simply ignored.
                    r = cosInterpolate(t, POS[:, 0], POS[:, 1])
                return r
         
     | 
| 
      
 222 
     | 
    
         
            +
             
     | 
| 
      
 223 
     | 
    
         
            +
             
     | 
| 
      
 224 
     | 
    
         
            +
            _PREALLOCATION_WARN_LIMIT = 6000


            def _warn_huge_preallocation(t_min, T):
                """Emit a warning if the knot buffer sized from ``T`` and ``t_min`` is big.

                The buffer length estimate is ``int(T // t_min) + 1``. A ``UserWarning`` is
                raised through ``warnings.warn`` only when that estimate exceeds
                ``_PREALLOCATION_WARN_LIMIT``; otherwise the function does nothing.
                """
                n_rows = int(T // t_min) + 1
                if n_rows <= _PREALLOCATION_WARN_LIMIT:
                    return

                message = (
                    f"The combination of `T`={T} and `t_min`={t_min} requires preallocating an "
                    f"array with axis-length of {n_rows} which is larger than the warn limit of "
                    f"{_PREALLOCATION_WARN_LIMIT}. This might lead to large memory requirements"
                    " and/or large jit-times, consider reducing `t_min`."
                )
                warnings.warn(message)
         
     | 
| 
      
 236 
     | 
    
         
            +
             
     | 
| 
      
 237 
     | 
    
         
            +
             
     | 
| 
      
 238 
     | 
    
         
            +
            def _clip_to_pi(phi):
                """Saturate ``phi`` into the closed interval ``[-pi, pi]``.

                Values outside the interval are clamped to the nearest bound (not wrapped).
                """
                lower, upper = -jnp.pi, jnp.pi
                return jnp.minimum(jnp.maximum(phi, lower), upper)
         
     | 
| 
      
 240 
     | 
    
         
            +
             
     | 
| 
      
 241 
     | 
    
         
            +
             
     | 
| 
      
 242 
     | 
    
         
            +
            def _resolve_range_of_motion(
                range_of_motion,
                range_of_motion_method,
                dang_min,
                dang_max,
                delta_ang_min,
                delta_ang_max,
                dt,
                prev_phi,
                key,
                max_iter,
            ):
                """Draw the angle that follows ``prev_phi``.

                A candidate is obtained by stepping away from ``prev_phi`` by a random
                amount drawn from ``[dang_min, dang_max] * dt`` in a random direction.
                Candidates are redrawn inside a ``jax.lax.while_loop`` until
                ``delta_ang_min <= |candidate - prev_phi| <= delta_ang_max`` holds or more
                than ``max_iter`` candidates have been drawn; the last candidate is returned.

                Args:
                    range_of_motion: Python bool (evaluated in a Python ``if``). When truthy,
                        both step end points are clipped to [-pi, pi], the candidate is drawn
                        uniformly between them, and the step direction is biased according
                        to ``range_of_motion_method``.
                    range_of_motion_method: ``"coinflip"`` (unbiased direction),
                        ``"uniform"`` (P(+) = 0.5 * (1 - prev_phi / pi)), ``"sigmoid"`` or
                        ``"sigmoid-<scale>"`` (P(+) = sigmoid(-scale * prev_phi), default
                        scale 1.5, direction forced towards zero within 0.01 of +-pi).
                        Ignored when ``range_of_motion`` is falsy.
                    dang_min, dang_max: bounds for the step rate; the step magnitude is
                        drawn from ``[dang_min, dang_max] * dt``.
                    delta_ang_min, delta_ang_max: accepted bounds on ``|next - prev_phi|``.
                    dt: time step multiplying the step rate.
                    prev_phi: previous angle.
                    key: JAX PRNG key.
                    max_iter: rejection-sampling budget (at most ``max_iter + 1`` draws).

                Returns:
                    The next angle.

                Raises:
                    NotImplementedError: if ``range_of_motion`` is truthy and
                        ``range_of_motion_method`` is not one of the methods above.
                """
                is_sigmoid = isinstance(range_of_motion_method, str) and (
                    range_of_motion_method.startswith("sigmoid")
                )
                # Validate eagerly (previously this only surfaced while tracing the loop
                # body, and without a message).
                if range_of_motion and not (
                    is_sigmoid or range_of_motion_method in ("coinflip", "uniform")
                ):
                    raise NotImplementedError(
                        f"Unknown `range_of_motion_method`={range_of_motion_method!r}; expected "
                        "'coinflip', 'uniform', 'sigmoid' or 'sigmoid-<scale>'."
                    )

                def _next_phi(key):
                    key, consume = random.split(key)

                    if range_of_motion:
                        if range_of_motion_method == "coinflip":
                            probs = jnp.array([0.5, 0.5])
                        elif range_of_motion_method == "uniform":
                            p = 0.5 * (1 - prev_phi / jnp.pi)
                            probs = jnp.array([p, (1 - p)])
                        else:  # "sigmoid" or "sigmoid-<scale>" (validated above)
                            # `partition` keeps the whole suffix, so e.g. "sigmoid-1e-3"
                            # is parsed instead of silently falling back to the default.
                            _, sep, scale_str = range_of_motion_method.partition("-")
                            scale = float(scale_str) if sep else 1.5
                            hardcut = jnp.pi - 0.01
                            p = jnp.where(
                                prev_phi > hardcut,
                                0.0,
                                jnp.where(
                                    prev_phi < -hardcut, 1.0, jax.nn.sigmoid(-scale * prev_phi)
                                ),
                            )
                            probs = jnp.array([p, (1 - p)])

                        # probs[0] is the probability of stepping in the + direction
                        sign = random.choice(consume, jnp.array([1.0, -1.0]), p=probs)
                        lower = _clip_to_pi(prev_phi + sign * dang_min * dt)
                        upper = _clip_to_pi(prev_phi + sign * dang_max * dt)

                        # swap if lower > upper (e.g. for sign == -1)
                        lower, upper = jnp.sort(jnp.hstack((lower, upper)))

                        key, consume = random.split(key)
                        return random.uniform(consume, minval=lower, maxval=upper)

                    # unconstrained: random step magnitude, then random direction
                    dphi = random.uniform(consume, minval=dang_min, maxval=dang_max) * dt
                    key, consume = random.split(key)
                    sign = random.choice(consume, jnp.array([1.0, -1.0]))
                    return prev_phi + sign * dphi

                def body_fn(val):
                    key, _, i = val
                    key, consume = jax.random.split(key)
                    next_phi = _next_phi(consume)
                    return key, next_phi, i + 1

                def cond_fn(val):
                    _, next_phi, i = val
                    delta_phi = jnp.abs(next_phi - prev_phi)
                    # delta is in bounds
                    break_if_true1 = (delta_phi >= delta_ang_min) & (delta_phi <= delta_ang_max)
                    break_if_true2 = i > max_iter
                    # always run at least one iteration (i == 0 carries the unused init value)
                    return (i == 0) | (jnp.logical_not(break_if_true1 | break_if_true2))

                # the `prev_phi` here is unused
                return jax.lax.while_loop(cond_fn, body_fn, (key, prev_phi, 0))[1]
         
     | 
| 
      
 312 
     | 
    
         
            +
             
     | 
| 
      
 313 
     | 
    
         
            +
             
     | 
| 
      
 314 
     | 
    
         
            +
            def cosInterpolate(x, xp, fp):
                """Cosine interpolation of ``fp`` (sampled at ``xp``) at the points ``x``.

                Between two neighbouring samples the value follows half a cosine period,
                so the interpolant has zero slope at every sample point. Like
                ``jnp.interp``, values outside ``[xp[0], xp[-1]]`` are held constant at
                ``fp[0]`` / ``fp[-1]``.

                Args:
                    x: 1-D array of query points (mapped over with ``jax.vmap``).
                    xp: sorted sample locations (required by ``jnp.searchsorted``).
                    fp: sample values, indexed alongside ``xp``.

                Returns:
                    The interpolated values, one per entry of ``x``.
                """
                i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
                dx = xp[i] - xp[i - 1]
                # Guard the division so a zero-width interval does not create inf/NaN
                # intermediates; those entries are replaced by `fp[i]` below anyway.
                alpha = (x - xp[i - 1]) / jnp.where(dx == 0, 1, dx)

                def cos_interpolate(x1, x2, alpha):
                    """Blend from x1 (alpha=0) to x2 (alpha=1) along half a cosine."""
                    return (x1 + x2) / 2 + (x1 - x2) / 2 * jnp.cos(alpha * jnp.pi)

                f = jnp.where((dx == 0), fp[i], jax.vmap(cos_interpolate)(fp[i - 1], fp[i], alpha))
                # Hold the boundary values outside the sampled range. Without the left
                # clamp, the even cosine would mirror the first segment for x < xp[0].
                f = jnp.where(x < xp[0], fp[0], f)
                f = jnp.where(x > xp[-1], fp[-1], f)
                return f
         
     | 
| 
      
 326 
     | 
    
         
            +
             
     | 
| 
      
 327 
     | 
    
         
            +
             
     | 
| 
      
 328 
     | 
    
         
            +
            def _biject_alpha(alpha, cdf):
                """Warp an interpolation weight ``alpha`` through a piecewise-linear CDF.

                ``cdf`` holds values at ``len(cdf)`` equidistant knots on [0, 1]; the
                result is the linear interpolation of those values at position ``alpha``.
                With ``cdf == [0, 1]`` this is the identity.

                Args:
                    alpha: JAX array scalar, normally in [0, 1].
                    cdf: 1-D array with at least two knot values.

                Returns:
                    The warped weight.
                """
                n_segments = len(cdf) - 1
                cdf_dx = 1 / n_segments
                # Keep both gathered indices in bounds. Otherwise alpha == 1 may index
                # cdf[len(cdf)] (relying on JAX clamping out-of-bounds gathers), and
                # alpha < 0 would wrap around to cdf[-1].
                left_idx = jnp.clip((alpha // cdf_dx).astype(int), 0, n_segments - 1)
                a = (alpha - left_idx * cdf_dx) / cdf_dx
                return (1 - a) * cdf[left_idx] + a * cdf[left_idx + 1]
         
     | 
| 
      
 333 
     | 
    
         
            +
             
     | 
| 
      
 334 
     | 
    
         
            +
             
     | 
| 
      
 335 
     | 
    
         
            +
            def _generate_cdf(cdf_bins_min, cdf_bins_max=None):
                """Build a sampler ``key -> cdf`` of random monotonic CDFs on a uniform grid.

                Every sampled ``cdf`` starts at exactly 0, ends at exactly 1 and is
                non-decreasing. It is consumed as the knot values of ``_biject_alpha``.

                Args:
                    cdf_bins_min: if ``cdf_bins_max`` is None, the number of random
                        increments; the CDF then has ``cdf_bins_min + 1`` knots and every
                        increment is drawn independently. Otherwise, the lower bound for
                        the random number of positions where the increment value changes.
                    cdf_bins_max: optional inclusive upper bound for that number. The CDF
                        then always has ``cdf_bins_max + 1`` knots, with piecewise-constant
                        increments.

                Returns:
                    A function mapping a PRNG key to a 1-D array of CDF values.

                Raises:
                    ValueError: if ``cdf_bins_min < 1`` (fixed size), or if
                        ``cdf_bins_max < cdf_bins_min`` or ``cdf_bins_max < 1`` (random
                        size). These settings would give a degenerate 0/0 (NaN) CDF or an
                        empty sampling range.
                """

                def _generate_cdf_min_eq_max(cdf_bins):
                    def _sample_cdf(key):
                        # strictly positive increments -> strictly increasing CDF
                        samples = random.uniform(key, (cdf_bins,), minval=1e-6, maxval=1.0)
                        samples = jnp.hstack((jnp.array([0.0]), samples))
                        monotonic = jnp.cumsum(samples)
                        return monotonic / monotonic[-1]

                    return _sample_cdf

                def _generate_cdf_min_uneq_max(dy_min, dy_max):
                    def _sample_cdf(key):
                        key, consume = random.split(key)
                        cdf_bins = random.randint(consume, (), dy_min, dy_max + 1)
                        # `cdf_bins` ones at random positions mark where a new increment
                        # value starts
                        mask = jnp.where(jnp.arange(dy_max) < cdf_bins, 1, 0)
                        key, consume = random.split(key)
                        mask = random.permutation(consume, mask)
                        dy = random.uniform(key, (dy_max,), minval=1e-6, maxval=1.0)
                        # positions before the first marked one index dy[-1]
                        dy = dy[jnp.cumsum(mask) - 1]
                        y = jnp.hstack((jnp.array([0.0]), dy))
                        monotonic = jnp.cumsum(y)
                        return monotonic / monotonic[-1]

                    return _sample_cdf

                if cdf_bins_max is None:
                    if cdf_bins_min < 1:
                        raise ValueError(f"`cdf_bins_min` must be >= 1, got {cdf_bins_min}")
                    return _generate_cdf_min_eq_max(cdf_bins=cdf_bins_min)

                # explicit checks instead of `assert`, which is stripped under `python -O`
                if cdf_bins_max < cdf_bins_min:
                    raise ValueError(
                        f"`cdf_bins_max` ({cdf_bins_max}) must be >= `cdf_bins_min` "
                        f"({cdf_bins_min})"
                    )
                if cdf_bins_max < 1:
                    raise ValueError(f"`cdf_bins_max` must be >= 1, got {cdf_bins_max}")
                return _generate_cdf_min_uneq_max(cdf_bins_min, cdf_bins_max)
         
     | 
| 
      
 369 
     | 
    
         
            +
             
     | 
| 
      
 370 
     | 
    
         
            +
             
     | 
| 
      
 371 
     | 
    
         
            +
            def interpolate(
                cdf_bins_min: int = 1, cdf_bins_max: Optional[int] = None, method: str = "cosine"
            ):
                """Interpolation with random alpha projection (disabled by default).

                Returns ``_interpolate(x, xp, fp, key)``, which interpolates ``fp`` (sampled
                at sorted ``xp``) onto the 1-D query points ``x``. Within each interval the
                interpolation weight alpha is first warped through a randomly drawn
                monotonic CDF (one CDF per interval, see ``_generate_cdf`` and
                ``_biject_alpha``). With the default ``cdf_bins_min=1`` that CDF is
                ``[0, 1]``, so the warp is the identity. Outside ``[xp[0], xp[-1]]`` the
                boundary values ``fp[0]`` / ``fp[-1]`` are held.

                Args:
                    cdf_bins_min: forwarded to ``_generate_cdf``.
                    cdf_bins_max: forwarded to ``_generate_cdf``.
                    method: ``"cosine"`` or ``"linear"`` two-point interpolation.

                Raises:
                    NotImplementedError: if ``method`` is unknown (raised immediately,
                        not at trace time).
                """
                if method not in ("cosine", "linear"):
                    raise NotImplementedError(
                        f"Unknown interpolation method {method!r}; expected 'cosine' or 'linear'."
                    )
                generate_cdf = _generate_cdf(cdf_bins_min, cdf_bins_max)

                def _interpolate(x, xp, fp, key):
                    i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1)
                    dx = xp[i] - xp[i - 1]
                    # Guard the division for zero-width intervals; those entries are
                    # replaced by `fp[i]` below.
                    alpha = (x - xp[i - 1]) / jnp.where(dx == 0, 1, dx)

                    # One key per interval; every query point uses the key (and therefore
                    # the random CDF) of the interval it falls into. Slicing the split
                    # array works for both raw uint32 keys and typed keys.
                    consume = random.split(key, len(xp) + 1)[1:]
                    consume = consume[i - 1]
                    cdfs = jax.vmap(generate_cdf)(consume)
                    alpha = jax.vmap(_biject_alpha)(alpha, cdfs)

                    def two_point_interp(x1, x2, alpha):
                        """Blend from x1 (alpha=0) to x2 (alpha=1)."""
                        if method == "cosine":
                            return (x1 + x2) / 2 + (x1 - x2) / 2 * jnp.cos(alpha * jnp.pi)
                        return (1 - alpha) * x1 + alpha * x2

                    f = jnp.where(
                        (dx == 0), fp[i], jax.vmap(two_point_interp)(fp[i - 1], fp[i], alpha)
                    )
                    # hold the boundary values outside the sampled range on both sides
                    f = jnp.where(x < xp[0], fp[0], f)
                    f = jnp.where(x > xp[-1], fp[-1], f)
                    return f

                return _interpolate
         
     | 
| 
         @@ -0,0 +1,69 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from dataclasses import replace
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 4 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 5 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 6 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 7 
     | 
    
         
            +
            from ring.algorithms.jcalc import _draw_rxyz
         
     | 
| 
      
 8 
     | 
    
         
            +
            from ring.algorithms.jcalc import _p_control_term_rxyz
         
     | 
| 
      
 9 
     | 
    
         
            +
            from ring.algorithms.jcalc import _qd_from_q_cartesian
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            def register_rr_imp_joint(
                config_res=ring.MotionConfig(dang_max=5.0, t_max=0.4),
                ang_max_deg: float = 7.5,
                name: str = "rr_imp",
            ):
                """Register a two-DoF joint type under ``name`` (overwriting any existing one).

                ``q[0]`` rotates about the primary axis ``params["joint_axes"]`` and ``q[1]``
                about the residual axis ``params["residual"]``. The joint rotation is
                ``quat_mul(rot_residual, rot_primary)``. The residual trajectory is drawn
                with ``config_res`` (its ``T`` replaced by the primary config's), multiplied
                by ``deg2rad(ang_max_deg) / pi`` and then mean-centred.
                """

                def _rr_imp_transform(q, params):
                    rot_primary = maths.quat_rot_axis(params["joint_axes"], q[0])
                    rot_residual = maths.quat_rot_axis(params["residual"], q[1])
                    return ring.Transform.create(
                        rot=ring.maths.quat_mul(rot_residual, rot_primary)
                    )

                def _draw_rr_imp(config, key_t, key_value, dt, _):
                    key_t_pri, key_t_res = jax.random.split(key_t)
                    key_value_pri, key_value_res = jax.random.split(key_value)

                    q_primary = _draw_rxyz(config, key_t_pri, key_value_pri, dt, _)
                    residual_config = replace(config_res, T=config.T)
                    q_residual = _draw_rxyz(residual_config, key_t_res, key_value_res, dt, _)

                    # rescale by deg2rad(ang_max_deg) / pi, then remove the mean
                    q_residual = q_residual * (jnp.deg2rad(ang_max_deg) / jnp.pi)
                    q_residual = q_residual - jnp.mean(q_residual)
                    # one column per DoF
                    return jnp.stack((q_primary, q_residual), axis=1)

                def _motion_fn_factory(whichone: str):
                    def _motion_fn(params):
                        return ring.base.Motion.create(ang=params[whichone])

                    return _motion_fn

                joint_model = ring.JointModel(
                    _rr_imp_transform,
                    motion=[
                        _motion_fn_factory(axis_key) for axis_key in ("joint_axes", "residual")
                    ],
                    rcmg_draw_fn=_draw_rr_imp,
                    p_control_term=_p_control_term_rxyz,
                    qd_from_q=_qd_from_q_cartesian,
                    init_joint_params=_draw_random_joint_axes,
                )
                ring.register_new_joint_type(name, joint_model, 2, 2, overwrite=True)
         
     | 
| 
      
 59 
     | 
    
         
            +
             
     | 
| 
      
 60 
     | 
    
         
            +
             
     | 
| 
      
 61 
     | 
    
         
            +
            def _draw_random_joint_axes(key):
                """Draw the primary and residual axes of the rr_imp joint.

                The primary axis starts as +z and the residual axis as a random direction
                in the xy-plane (so the two are orthogonal). Both are then passed through
                ``maths.rotate`` with the same ``maths.quat_random`` sample.
                """
                key_angle, key_rotation = jax.random.split(key)
                phi = jax.random.uniform(key_angle, maxval=2 * jnp.pi)
                residual_axis = jnp.array([jnp.cos(phi), jnp.sin(phi), 0.0])
                shared_rotation = maths.quat_random(key_rotation)
                return dict(
                    joint_axes=maths.rotate(jnp.array([0, 0, 1.0]), shared_rotation),
                    residual=maths.rotate(residual_axis, shared_rotation),
                )
         
     | 
| 
         @@ -0,0 +1,33 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 2 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 3 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 4 
     | 
    
         
            +
            from ring.algorithms.jcalc import _draw_rxyz
         
     | 
| 
      
 5 
     | 
    
         
            +
            from ring.algorithms.jcalc import _p_control_term_rxyz
         
     | 
| 
      
 6 
     | 
    
         
            +
            from ring.algorithms.jcalc import _qd_from_q_cartesian
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            def register_rr_joint():
                """Register the single-DoF "rr" joint type (overwriting any existing one).

                The joint rotates by the scalar ``q`` about ``params["joint_axes"]``.
                """

                def _rr_transform(q, params):
                    angle = jnp.squeeze(q)
                    return ring.Transform.create(
                        rot=ring.maths.quat_rot_axis(params["joint_axes"], angle)
                    )

                def _motion_fn(params):
                    return ring.base.Motion.create(ang=params["joint_axes"])

                joint_model = ring.JointModel(
                    _rr_transform,
                    motion=[_motion_fn],
                    rcmg_draw_fn=_draw_rxyz,
                    p_control_term=_p_control_term_rxyz,
                    qd_from_q=_qd_from_q_cartesian,
                    init_joint_params=_draw_random_joint_axis,
                )
                ring.register_new_joint_type("rr", joint_model, 1, overwrite=True)
         
     | 
| 
      
 30 
     | 
    
         
            +
             
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
            def _draw_random_joint_axis(key):
                """Return ``{"joint_axes": ...}``: the x unit vector passed through
                ``maths.rotate`` with a ``maths.quat_random`` sample drawn from ``key``."""
                random_quat = maths.quat_random(key)
                x_axis = jnp.array([1.0, 0, 0])
                return {"joint_axes": maths.rotate(x_axis, random_quat)}
         
     |