imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
| 
         @@ -0,0 +1,345 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from typing import Optional, Tuple
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 4 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 5 
     | 
    
         
            +
            from ring import algebra
         
     | 
| 
      
 6 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 7 
     | 
    
         
            +
            from ring import maths
         
     | 
| 
      
 8 
     | 
    
         
            +
            from ring.algorithms import jcalc
         
     | 
| 
      
 9 
     | 
    
         
            +
            from ring.algorithms import kinematics
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            def inverse_dynamics(sys: base.System, qd: jax.Array, qdd: jax.Array) -> jax.Array:
                """Performs inverse dynamics in the system. Calculates "tau".

                Two-pass (recursive Newton-Euler style) scheme:

                1. Forward pass over the links in index order: propagate spatial
                   velocity and acceleration outwards from the root and compute each
                   link's net spatial force ``I a + v x* (I v)``. The root's
                   acceleration is seeded with ``sys.gravity`` (wrapped as the linear
                   ``vel`` part of a spatial motion) re-expressed in the root's frame.
                2. Backward pass over the links in reverse index order: project each
                   link's force onto its joint via ``jcalc.jcalc_tau`` and add that
                   force, re-expressed in the parent frame, to the parent's force.

                NOTE: Expects `sys` to have updated `transform` and `inertia`.

                Args:
                    sys: System providing link transforms, inertias, joint params,
                        parents and joint types.
                    qd: Generalized velocities.
                    qdd: Generalized accelerations.

                Returns:
                    The per-link joint forces, concatenated in link-index order.
                """
                gravity_motion = base.Motion.create(vel=sys.gravity)

                link_vel, link_acc, link_force = {}, {}, {}

                def _forward_pass(_, __, link_idx, parent_idx, link_type, qd_link, qdd_link, link):
                    params = link.joint_params
                    inertia = link.inertia

                    def to_link_frame(motion):
                        # `link.transform` maps parent-frame quantities into this link's frame
                        return algebra.transform_motion(link.transform, motion)

                    v_joint = jcalc.jcalc_motion(link_type, qd_link, params)
                    a_joint = jcalc.jcalc_motion(link_type, qdd_link, params)

                    if parent_idx != -1:
                        # the parent's entries must already exist, i.e. the scan is
                        # expected to visit every parent before its children
                        v = v_joint + to_link_frame(link_vel[parent_idx])
                        a = (
                            to_link_frame(link_acc[parent_idx])
                            + a_joint
                            + algebra.motion_cross(v, v_joint)
                        )
                    else:
                        v = v_joint
                        a = to_link_frame(gravity_motion) + a_joint

                    link_vel[link_idx] = v
                    link_acc[link_idx] = a

                    momentum = algebra.inertia_mul_motion(inertia, v)
                    link_force[link_idx] = algebra.inertia_mul_motion(
                        inertia, a
                    ) + algebra.motion_cross_star(v, momentum)

                sys.scan(
                    _forward_pass,
                    "lllddl",
                    list(range(sys.num_links())),
                    sys.link_parents,
                    sys.link_types,
                    qd,
                    qdd,
                    sys.links,
                )

                # filled while walking the links backwards; reversed before returning
                taus_backwards = []

                def _backward_pass(_, __, link_idx, parent_idx, link_type, l_to_p_trafo, link):
                    own_force = link_force[link_idx]
                    taus_backwards.append(
                        jcalc.jcalc_tau(link_type, own_force, link.joint_params)
                    )
                    if parent_idx == -1:
                        return
                    link_force[parent_idx] = link_force[parent_idx] + algebra.transform_force(
                        l_to_p_trafo, own_force
                    )

                sys.scan(
                    _backward_pass,
                    "lllll",
                    list(range(sys.num_links())),
                    sys.link_parents,
                    sys.link_types,
                    jax.vmap(algebra.transform_inv)(sys.links.transform),
                    sys.links,
                    reverse=True,
                )

                return jnp.concatenate(taus_backwards[::-1])
         
     | 
| 
      
 74 
     | 
    
         
            +
             
     | 
| 
      
 75 
     | 
    
         
            +
             
     | 
| 
      
 76 
     | 
    
         
            +
            def compute_mass_matrix(sys: base.System) -> jax.Array:
                """Computes the mass matrix of the system using the `composite-rigid-body`
                algorithm.

                Step 1 (reverse link order, spatial objects): add each link's inertia,
                transformed into its parent's frame, onto the parent's inertia to form
                composite inertias ``Ic``.

                Step 2 (reverse link order, plain matrices): for every non-frozen link
                ``i`` with motion-subspace rows ``S_i``, form ``F = (Ic_i S_i^T)^T``.
                Write the lower triangle of ``F S_i^T`` into the diagonal block. Then
                walk up the ancestor chain, re-expressing ``F`` in each parent frame, and
                write ``F S_j^T`` into block ``(i, j)`` for every non-frozen ancestor
                ``j``. Finally, mirror the strict lower triangle into the upper one and
                add ``sys.link_armature`` on the diagonal.

                NOTE(review): Only the strict lower triangle is mirrored. This appears
                to assume that an ancestor's qd indices precede its descendants' qd
                indices. Confirm this ordering is guaranteed by `base.System`.

                Returns:
                    Symmetric array of shape ``(sys.qd_size(), sys.qd_size())``.
                """
                n_links = sys.num_links()
                # child -> parent spatial transform of every link
                child_to_parent = jax.vmap(algebra.transform_inv)(sys.links.transform)

                # STEP 1: composite inertias, accumulated leaf-to-root (spatial form)
                composite = [sys.links.inertia[k] for k in range(n_links)]

                def _accumulate(_, __, k, parent_k):
                    if parent_k != -1:
                        composite[parent_k] += algebra.transform_inertia(
                            child_to_parent[k], composite[k]
                        )
                    return composite[k]

                stacked_composite = sys.scan(
                    _accumulate,
                    "ll",
                    list(range(n_links)),
                    sys.link_parents,
                    reverse=True,
                )

                @jax.vmap
                def _as_matrix(obj):
                    # batched spatial object -> stack of plain matrices
                    return obj.as_matrix()

                inertia_mats = _as_matrix(stacked_composite)

                # STEP 2: populate the mass matrix, now in plain matrix form
                def _motion_subspace(k: int):
                    all_params = sys.links[k].joint_params
                    link_type = sys.link_types[k]
                    # only hand the joint model the parameters of its own type
                    if link_type in all_params:
                        params = all_params[link_type]
                    else:
                        params = all_params["default"]

                    motions = [
                        m if isinstance(m, base.Motion) else m(params)
                        for m in jcalc.get_joint_model(link_type).motion
                    ]
                    if not motions:
                        # frozen joint: no degrees of freedom
                        return None
                    return _as_matrix(motions[0].batch(*motions[1:]))

                S = [_motion_subspace(k) for k in range(n_links)]

                H = jnp.zeros((sys.qd_size(), sys.qd_size()))

                def _to_parent_frame(trafo, F):
                    # re-express every row of F (one spatial force per row) via `trafo`
                    def _row(f_row):
                        spatial_f = base.Force(f_row[:3], f_row[3:])
                        return algebra.transform_force(trafo, spatial_f).as_matrix()

                    return jax.vmap(_row)(F)

                def _populate(_, idx_map, k):
                    nonlocal H

                    if S[k] is None:
                        return

                    rows = idx_map["d"](k)
                    F = (inertia_mats[k] @ (S[k].T)).T
                    # keep only the lower triangle; the upper one is mirrored at the end
                    H = H.at[rows, rows].set(jnp.tril(F @ (S[k].T)))

                    j = k
                    while sys.link_parents[j] != -1:
                        F = _to_parent_frame(child_to_parent[j], F)
                        j = sys.link_parents[j]
                        if S[j] is not None:
                            H = H.at[rows, idx_map["d"](j)].set(F @ (S[j].T))

                sys.scan(_populate, "l", list(range(n_links)), reverse=True)

                H = H + jnp.tril(H, -1).T

                H += jnp.diag(sys.link_armature)

                return H
         
     | 
| 
      
 177 
     | 
    
         
            +
             
     | 
| 
      
 178 
     | 
    
         
            +
             
     | 
| 
      
 179 
     | 
    
         
            +
            def _quaternion_spring_force(q_zeropoint, q) -> jax.Array:
                """Computes the angular velocity direction from q to q_zeropoint.

                The relative rotation ``q_zeropoint * q^-1`` is converted to
                axis-angle form, and the axis is returned scaled by the angle.
                """
                rotation_to_zeropoint = maths.quat_mul(q_zeropoint, maths.quat_inv(q))
                unit_axis, rot_angle = maths.quat_to_rot_axis(rotation_to_zeropoint)
                return unit_axis * rot_angle
         
     | 
| 
      
 184 
     | 
    
         
            +
             
     | 
| 
      
 185 
     | 
    
         
            +
             
     | 
| 
      
 186 
     | 
    
         
            +
            def _spring_force(sys: base.System, q: jax.Array):
                """Displacement of every joint from its spring zero point.

                The result is not scaled by stiffness. Handling per joint type:

                * ``"free"`` / ``"cor"``: the rotational part comes from the leading
                  quaternion (``[:4]``) via `_quaternion_spring_force`. It is followed
                  by the plain difference ``zeropoint[4:] - q[4:]`` of the remaining
                  coordinates.
                * ``"spherical"``: `_quaternion_spring_force` on the quaternion.
                * any other type: ``zeropoint - q``.

                Args:
                    sys: System providing ``link_spring_zeropoint`` and ``link_types``.
                    q: Generalized positions.

                Returns:
                    The per-link displacements, concatenated in link-index order.
                """
                displacements = []

                def _per_link(_, __, q_link, zero_link, link_type):
                    # cor is (free, p3d) stacked; free is (spherical, p3d) stacked,
                    # so both start with a quaternion followed by positional coordinates
                    if link_type == "spherical":
                        disp = _quaternion_spring_force(zero_link, q_link)
                    elif link_type in ("free", "cor"):
                        rot_part = _quaternion_spring_force(zero_link[:4], q_link[:4])
                        disp = jnp.concatenate((rot_part, zero_link[4:] - q_link[4:]))
                    else:
                        disp = zero_link - q_link
                    displacements.append(disp)

                sys.scan(
                    _per_link,
                    "qql",
                    q,
                    sys.link_spring_zeropoint,
                    sys.link_types,
                )
                return jnp.concatenate(displacements)
         
     | 
| 
      
 209 
     | 
    
         
            +
             
     | 
| 
      
 210 
     | 
    
         
            +
             
     | 
| 
      
 211 
     | 
    
         
            +
            def forward_dynamics(
                sys: base.System,
                q: jax.Array,
                qd: jax.Array,
                tau: jax.Array,
                mass_mat_inv: jax.Array,
            ) -> Tuple[jax.Array, jax.Array]:
                """Computes generalized accelerations from applied generalized forces.

                Evaluates ``qdd = M^-1 (tau - C + f_passive)``, where:

                * ``C`` is `inverse_dynamics` evaluated at zero acceleration.
                * ``f_passive = -link_damping * qd + link_spring_stiffness * spring``.
                * ``M`` is the mass matrix from `compute_mass_matrix`.

                Args:
                    sys: The system.
                    q: Generalized positions (used for the spring term).
                    qd: Generalized velocities.
                    tau: Applied generalized forces.
                    mass_mat_inv: Previous inverse mass matrix. It is only passed to
                        `_inv_approximate` when ``sys.mass_mat_iters != 0``
                        (presumably as the starting guess) and is ignored otherwise.

                Returns:
                    Tuple ``(qdd, mass_mat_inv)`` with the newly computed inverse.
                """
                # forces present at zero acceleration (gravity, velocity-dependent terms)
                bias = inverse_dynamics(sys, qd, jnp.zeros_like(qd))
                M = compute_mass_matrix(sys)

                passive = -sys.link_damping * qd + sys.link_spring_stiffness * _spring_force(
                    sys, q
                )
                generalized_force = tau - bias + passive

                # NOTE(review): The implicit-damping term and the 1e-6 regularization
                # below are only applied when solving exactly, not on the
                # `_inv_approximate` path. Confirm this asymmetry is intended.
                if sys.mass_mat_iters == 0:
                    identity = jnp.eye(sys.qd_size())
                    # trick from brax / mujoco aka "integrate joint damping implicitly"
                    M = M + jnp.diag(sys.link_damping) * sys.dt
                    # make cholesky decomposition not sometimes fail
                    # see: https://github.com/google/jax/issues/16149
                    M = M + identity * 1e-6
                    M_inv = jax.scipy.linalg.solve(M, identity, assume_a="pos")
                else:
                    M_inv = _inv_approximate(M, mass_mat_inv, sys.mass_mat_iters)

                return M_inv @ generalized_force, M_inv
         
     | 
| 
      
 241 
     | 
    
         
            +
             
     | 
| 
      
 242 
     | 
    
         
            +
             
     | 
| 
      
 243 
     | 
    
         
            +
            def _strapdown_integration(
                q: base.Quaternion, dang: jax.Array, dt: float
            ) -> base.Quaternion:
                """Advances quaternion `q` by angular velocity `dang` over one step `dt`.

                The step rotation, about ``dang / |dang|`` by ``|dang| * dt``, is
                left-multiplied onto `q`, and the result is re-normalized.

                NOTE: The 1e-8 added to the norm guards the axis division. It also
                enters the angle (a bias of 1e-8 * dt rad per step), which is
                negligible in practice.
                """
                guarded_norm = jnp.linalg.norm(dang) + 1e-8
                step_rotation = maths.quat_rot_axis(dang / guarded_norm, guarded_norm * dt)
                q_next = maths.quat_mul(step_rotation, q)
                # Roy book says that one should re-normalize after every quaternion step
                return q_next / jnp.linalg.norm(q_next)
         
     | 
| 
      
 252 
     | 
    
         
            +
             
     | 
| 
      
 253 
     | 
    
         
            +
             
     | 
| 
      
 254 
     | 
    
         
            +
            def _semi_implicit_euler_integration(
                sys: base.System, state: base.State, taus: jax.Array
            ) -> base.State:
                """Advance `state` by one step of length `sys.dt` (semi-implicit Euler).

                Velocities are updated first, from the accelerations returned by
                `forward_dynamics`. The *updated* velocities then advance the positions
                link by link:

                - "free" / "cor": the first 4 entries of the link's `q` are a quaternion
                  integrated by `_strapdown_integration` with the first 3 entries of
                  `qd`. The remaining entries use explicit Euler.
                - "spherical": the whole link `q` is a quaternion integrated by
                  `_strapdown_integration`.
                - any other type: explicit Euler `q + dt * qd`, which requires the
                  per-link `q` and `qd` slices to have compatible shapes.

                Returns:
                  A copy of `state` with `q`, `qd` and `mass_mat_inv` replaced.
                """
                dt = sys.dt
                qdd, mass_mat_inv = forward_dynamics(
                    sys, state.q, state.qd, taus, state.mass_mat_inv
                )
                qd_new = state.qd + dt * qdd

                per_link_q = []

                def _advance_link(_, __, q_link, qd_link, link_type):
                    # collects the per-link result as a side effect; returns nothing
                    if link_type == "spherical":
                        per_link_q.append(_strapdown_integration(q_link, qd_link, dt))
                    elif link_type in ("free", "cor"):
                        rotation = _strapdown_integration(q_link[:4], qd_link[:3], dt)
                        translation = q_link[4:] + qd_link[3:] * dt
                        per_link_q.append(jnp.concatenate((rotation, translation)))
                    else:
                        per_link_q.append(q_link + dt * qd_link)

                # Semi-implicit: positions are advanced with the already updated `qd_new`.
                # NOTE(review): "qdl" presumably selects q-, qd- and link-level slices.
                sys.scan(_advance_link, "qdl", state.q, qd_new, sys.link_types)

                return state.replace(
                    q=jnp.concatenate(per_link_q), qd=qd_new, mass_mat_inv=mass_mat_inv
                )
         
     | 
| 
      
 282 
     | 
    
         
            +
             
     | 
| 
      
 283 
     | 
    
         
            +
             
     | 
| 
      
 284 
     | 
    
         
            +
            # Registry of available integrators, keyed by the lower-cased value of
            # `sys.integration_method`. Each entry maps (sys, state, taus) -> new state.
            _integration_methods = dict(
                semi_implicit_euler=_semi_implicit_euler_integration,
            )
         
     | 
| 
      
 287 
     | 
    
         
            +
             
     | 
| 
      
 288 
     | 
    
         
            +
             
     | 
| 
      
 289 
     | 
    
         
            +
            def kinetic_energy(sys: base.System, qd: jax.Array):
                """Return the kinetic energy ``0.5 * qd^T H qd``.

                ``H`` is the mass matrix returned by `compute_mass_matrix(sys)`.
                NOTE(review): `sys` presumably needs up-to-date kinematics (e.g. as
                returned by `kinematics.forward_kinematics`); confirm with callers.

                Args:
                  sys: system used to build the mass matrix
                  qd: generalized velocities

                Returns:
                  Scalar kinetic energy.
                """
                mass_matrix = compute_mass_matrix(sys)
                half_qd = 0.5 * qd
                return half_qd @ mass_matrix @ qd
         
     | 
| 
      
 292 
     | 
    
         
            +
             
     | 
| 
      
 293 
     | 
    
         
            +
             
     | 
| 
      
 294 
     | 
    
         
            +
            def step(
                sys: base.System,
                state: base.State,
                taus: Optional[jax.Array] = None,
                n_substeps: int = 1,
            ) -> base.State:
                """Advance `state` by one time step of length `sys.dt`.

                The step is split into `n_substeps` sub-steps of length
                `sys.dt / n_substeps`. Before every sub-step the forward kinematics are
                updated. The integrator registered in `_integration_methods` under
                `sys.integration_method.lower()` is then applied.

                Args:
                  sys: system to simulate
                  state: current state; `state.q.size` must equal `sys.q_size()` and
                    `state.qd.size` must equal `sys.qd_size()`
                  taus: generalized forces; defaults to zeros shaped like `state.qd`.
                    The same `taus` is applied in every sub-step.
                  n_substeps: number of sub-steps, must be >= 1

                Returns:
                  The new state. Its `x` lags one (sub-)step behind `q`, see below.

                Raises:
                  AssertionError: on size mismatches, an unknown integration method or
                    `n_substeps < 1`.
                """
                assert sys.q_size() == state.q.size
                if taus is None:
                    taus = jnp.zeros_like(state.qd)
                assert sys.qd_size() == state.qd.size == taus.size

                # validate against the registry so that newly added methods just work
                method = sys.integration_method.lower()
                assert method in _integration_methods, (
                    f"Unknown integration method `{sys.integration_method}`; available "
                    f"methods are {list(_integration_methods)}."
                )
                # guards against a division by zero / silent no-op below
                assert n_substeps >= 1, f"`n_substeps` must be >= 1, got {n_substeps}"

                sys = sys.replace(dt=sys.dt / n_substeps)
                integrate = _integration_methods[method]

                for _ in range(n_substeps):
                    # update kinematics before stepping; this means that the `x` in `state`
                    # will lag one step behind but otherwise we would have to return
                    # the system object which would be awkward
                    sys, state = kinematics.forward_kinematics(sys, state)
                    state = integrate(sys, state, taus)

                return state
         
     | 
| 
      
 318 
     | 
    
         
            +
             
     | 
| 
      
 319 
     | 
    
         
            +
             
     | 
| 
      
 320 
     | 
    
         
            +
            def _inv_approximate(a: jax.Array, a_inv: jax.Array, num_iter: int = 10) -> jax.Array:
         
     | 
| 
      
 321 
     | 
    
         
            +
                """Use Newton-Schulz iteration to solve ``A^-1``.
         
     | 
| 
      
 322 
     | 
    
         
            +
             
     | 
| 
      
 323 
     | 
    
         
            +
                Args:
         
     | 
| 
      
 324 
     | 
    
         
            +
                  a: 2D array to invert
         
     | 
| 
      
 325 
     | 
    
         
            +
                  a_inv: approximate solution to A^-1
         
     | 
| 
      
 326 
     | 
    
         
            +
                  num_iter: number of iterations
         
     | 
| 
      
 327 
     | 
    
         
            +
             
     | 
| 
      
 328 
     | 
    
         
            +
                Returns:
         
     | 
| 
      
 329 
     | 
    
         
            +
                  A^-1 inverted matrix
         
     | 
| 
      
 330 
     | 
    
         
            +
                """
         
     | 
| 
      
 331 
     | 
    
         
            +
             
     | 
| 
      
 332 
     | 
    
         
            +
                def body_fn(carry, _):
         
     | 
| 
      
 333 
     | 
    
         
            +
                    a_inv, r, err = carry
         
     | 
| 
      
 334 
     | 
    
         
            +
                    a_inv_next = a_inv @ (jnp.eye(a.shape[0]) + r)
         
     | 
| 
      
 335 
     | 
    
         
            +
                    r_next = jnp.eye(a.shape[0]) - a @ a_inv_next
         
     | 
| 
      
 336 
     | 
    
         
            +
                    err_next = jnp.linalg.norm(r_next)
         
     | 
| 
      
 337 
     | 
    
         
            +
                    a_inv_next = jnp.where(err_next < err, a_inv_next, a_inv)
         
     | 
| 
      
 338 
     | 
    
         
            +
                    return (a_inv_next, r_next, err_next), None
         
     | 
| 
      
 339 
     | 
    
         
            +
             
     | 
| 
      
 340 
     | 
    
         
            +
                # ensure ||I - X0 @ A|| < 1, in order to guarantee convergence
         
     | 
| 
      
 341 
     | 
    
         
            +
                r0 = jnp.eye(a.shape[0]) - a @ a_inv
         
     | 
| 
      
 342 
     | 
    
         
            +
                a_inv = jnp.where(jnp.linalg.norm(r0) > 1, 0.5 * a.T / jnp.trace(a @ a.T), a_inv)
         
     | 
| 
      
 343 
     | 
    
         
            +
                (a_inv, _, _), _ = jax.lax.scan(body_fn, (a_inv, r0, 1.0), None, num_iter)
         
     | 
| 
      
 344 
     | 
    
         
            +
             
     | 
| 
      
 345 
     | 
    
         
            +
                return a_inv
         
     | 
| 
         @@ -0,0 +1,25 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from . import base
         
     | 
| 
      
 2 
     | 
    
         
            +
            from . import batch
         
     | 
| 
      
 3 
     | 
    
         
            +
            from . import motion_artifacts
         
     | 
| 
      
 4 
     | 
    
         
            +
            from . import pd_control
         
     | 
| 
      
 5 
     | 
    
         
            +
            from . import randomize
         
     | 
| 
      
 6 
     | 
    
         
            +
            from . import transforms
         
     | 
| 
      
 7 
     | 
    
         
            +
            from . import types
         
     | 
| 
      
 8 
     | 
    
         
            +
            from .base import GeneratorPipe
         
     | 
| 
      
 9 
     | 
    
         
            +
            from .base import GeneratorTrafoRemoveInputExtras
         
     | 
| 
      
 10 
     | 
    
         
            +
            from .base import GeneratorTrafoRemoveOutputExtras
         
     | 
| 
      
 11 
     | 
    
         
            +
            from .base import RCMG
         
     | 
| 
      
 12 
     | 
    
         
            +
            from .batch import batch_generators_eager
         
     | 
| 
      
 13 
     | 
    
         
            +
            from .batch import batch_generators_eager_to_list
         
     | 
| 
      
 14 
     | 
    
         
            +
            from .batch import batch_generators_lazy
         
     | 
| 
      
 15 
     | 
    
         
            +
            from .batch import batched_generator_from_list
         
     | 
| 
      
 16 
     | 
    
         
            +
            from .batch import batched_generator_from_paths
         
     | 
| 
      
 17 
     | 
    
         
            +
            from .randomize import randomize_anchors
         
     | 
| 
      
 18 
     | 
    
         
            +
            from .randomize import randomize_hz
         
     | 
| 
      
 19 
     | 
    
         
            +
            from .randomize import randomize_hz_finalize_fn_factory
         
     | 
| 
      
 20 
     | 
    
         
            +
            from .transforms import GeneratorTrafoExpandFlatten
         
     | 
| 
      
 21 
     | 
    
         
            +
            from .transforms import GeneratorTrafoRandomizePositions
         
     | 
| 
      
 22 
     | 
    
         
            +
            from .types import FINALIZE_FN
         
     | 
| 
      
 23 
     | 
    
         
            +
            from .types import Generator
         
     | 
| 
      
 24 
     | 
    
         
            +
            from .types import GeneratorTrafo
         
     | 
| 
      
 25 
     | 
    
         
            +
            from .types import SETUP_FN
         
     |