imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
| 
         @@ -0,0 +1,222 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 4 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 5 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 6 
     | 
    
         
            +
            from ring import io
         
     | 
| 
      
 7 
     | 
    
         
            +
            import tree_utils
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
      
 10 
     | 
    
         
            +
            def imu_reference_link_name(imu_link_name: str) -> str:
                """Return the name of the reference link that belongs to an IMU link.

                The reference link name is the IMU link name prefixed with an underscore.
                """
                return "".join(("_", imu_link_name))
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            def unactuated_subsystem(sys) -> list[str]:
                """List the reference link names (`_<imu>`) of all IMUs found in `sys`."""
                return list(map(imu_reference_link_name, sys.findall_imus()))
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
            def _subsystem_factory(imu_name: str, pos_min_max: float) -> base.System:
                """Build a one-body system that is injected for a single IMU.

                The body is called `imu_name`, uses a `p3d` joint and has fixed spring
                stiffness (50) and damping (5) on all three axes.

                Args:
                    imu_name: Name given to the single body of the system.
                    pos_min_max: Non-negative bound. If non-zero, the body additionally gets
                        `pos_min` / `pos_max` attributes of `-pos_min_max` / `+pos_min_max`
                        on every axis; if zero, no position attributes are written.

                Returns:
                    The system parsed from the generated XML string.
                """
                assert pos_min_max >= 0

                # only emit position limits when a non-zero range was requested
                if pos_min_max != 0.0:
                    limits = f'pos_min="-{pos_min_max} -{pos_min_max} -{pos_min_max}" pos_max="{pos_min_max} {pos_min_max} {pos_min_max}"'  # noqa: E501
                else:
                    limits = ""
                spring = 'spring_stiff="50 50 50"'
                damp = 'damping="5 5 5"'

                xml = f"""
                    <x_xy>
                    <worldbody>
                    <body name="{imu_name}" joint="p3d" {limits} {spring} {damp}/>
                    </worldbody>
                    </x_xy>
                    """  # noqa: E501
                return io.load_sys_from_str(xml)
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
      
 33 
     | 
    
         
            +
             
     | 
| 
      
 34 
     | 
    
         
            +
            def inject_subsystems(
                sys: base.System,
                pos_min_max: float = 0.0,
                **kwargs,
            ) -> base.System:
                """Insert a flexible attachment for every IMU of `sys`.

                For each IMU link the following happens, in this order:
                  1. the link is unfrozen to a `spherical` joint and receives a default
                     stiffness (0.3) and damping (0.1 * stiffness),
                  2. the link is renamed to its reference name (`_<imu>`),
                  3. a one-body `p3d` system carrying the original IMU name is injected
                     at the renamed link.
                Afterwards geoms that were attached to the original IMU links are
                re-attached to the injected links, the system is re-parsed, all
                `joint_params` are zeroed (with a warning) and the system is round-tripped
                through its string representation.

                Args:
                    sys: System containing the IMU links.
                    pos_min_max: Forwarded to `_subsystem_factory` for every injected body.
                    **kwargs: Accepted but not used by this function.

                Returns:
                    The modified system.
                """
                imu_names = sys.findall_imus()
                # link indices of the IMUs *before* any structural change; used further
                # below to find the geoms that have to move onto the injected links
                original_imu_idx = {sys.name_to_idx(name): name for name in imu_names}

                spher_stiffness = jnp.ones((3,)) * 0.3
                spher_damping = spher_stiffness * 0.1

                for imu_name in imu_names:
                    sys = sys.unfreeze(imu_name, "spherical")
                    # set default stiffness and damping of the spherical joint; this
                    # won't override anything because the frozen joint can not have
                    # any values
                    dof = sys.idx_map("d")[imu_name]
                    sys = sys.replace(
                        link_spring_stiffness=sys.link_spring_stiffness.at[dof].set(
                            spher_stiffness
                        ),
                        link_damping=sys.link_damping.at[dof].set(spher_damping),
                    )

                    ref_name = imu_reference_link_name(imu_name)
                    sys = sys.change_link_name(imu_name, ref_name)
                    sys = sys.inject_system(
                        _subsystem_factory(imu_name, pos_min_max), ref_name
                    )

                # attach geoms to the newly injected links (looked up by the IMU name)
                reattached = [
                    (
                        geom.replace(
                            link_idx=sys.name_to_idx(original_imu_idx[geom.link_idx])
                        )
                        if geom.link_idx in original_imu_idx
                        else geom
                    )
                    for geom in sys.geoms
                ]
                sys = sys.replace(geoms=reattached)

                # TODO investigate whether this parse is needed; I don't think so
                # re-calculate the inertia matrices because the geoms have been re-attached
                sys = sys.parse()

                # TODO set all joint_params to zeros; they can not be preserved anyways and
                # otherwise many warnings will be raised
                # instead warn explicitly once now and move on
                warnings.warn(
                    "`sys.links.joint_params` has been set to zero, this might lead to "
                    "unexpected behaviour unless you use `randomize_joint_params`"
                )
                zeroed_params = tree_utils.tree_zeros_like(sys.links.joint_params)
                sys = sys.replace(links=sys.links.replace(joint_params=zeroed_params))

                # double load; this fixes the issue that injected links got appended at
                # the end
                return io.load_sys_from_str(io.save_sys_to_str(sys))
         
     | 
| 
      
 86 
     | 
    
         
            +
             
     | 
| 
      
 87 
     | 
    
         
            +
             
     | 
| 
      
 88 
     | 
    
         
            +
            # Sampling ranges used when drawing stiffness / damping values with
            # `_log_uniform` for non-rigid IMU attachments.
            # stiffness range of the spherical (rotational) joint
            _STIF_MIN_SPH = 0.2
            _STIF_MAX_SPH = 10.0
            # stiffness range of the p3d (translational) joint
            _STIF_MIN_P3D = 25.0
            _STIF_MAX_P3D = 1e3
            # damping = factor * stiffness
            # range of that factor
            _DAMP_MIN = 0.05
            _DAMP_MAX = 0.5
         
     | 
| 
      
 95 
     | 
    
         
            +
             
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
            def _log_uniform(key, shape, minval, maxval):
                """Draw samples that are uniformly distributed in log-space.

                The bounds are mapped through `log`, a uniform sample is drawn between the
                mapped bounds and the result is mapped back through `exp`.

                Args:
                    key: JAX PRNG key.
                    shape: Shape of the returned array.
                    minval: Lower bound; must be strictly positive because `log(0)` is
                        `-inf`, which would turn every sample into NaN.
                    maxval: Upper bound; must not be smaller than `minval`.

                Returns:
                    Array of shape `shape` with values between `minval` and `maxval`.

                Raises:
                    AssertionError: If not `0 < minval <= maxval`.
                """
                assert 0 < minval <= maxval, "`_log_uniform` requires 0 < minval <= maxval"
                log_min, log_max = jnp.log(minval), jnp.log(maxval)
                return jnp.exp(jax.random.uniform(key, shape, minval=log_min, maxval=log_max))
         
     | 
| 
      
 101 
     | 
    
         
            +
             
     | 
| 
      
 102 
     | 
    
         
            +
             
     | 
| 
      
 103 
     | 
    
         
            +
            def setup_fn_randomize_damping_stiffness_factory(
                prob_rigid: float,
                all_imus_either_rigid_or_flex: bool,
                imus_surely_rigid: list[str],
            ):
                """Create a setup function that randomizes IMU attachment stiffness/damping.

                Args:
                    prob_rigid: Probability in [0, 1) that an IMU attachment is rigid.
                    all_imus_either_rigid_or_flex: If set, a single draw decides for all
                        IMUs (except the surely rigid ones) whether they are rigid.
                        Otherwise the decision is drawn per IMU.
                    imus_surely_rigid: IMU names that always get the rigid values. Must be
                        empty if `prob_rigid` is zero.

                Returns:
                    A function `(key, sys) -> sys` that overwrites `link_spring_stiffness`
                    and `link_damping` for the degrees of freedom of every IMU of `sys`.
                """
                assert 0 <= prob_rigid <= 1
                assert prob_rigid != 1, "Use `imu_motion_artifacts`=False instead."
                if prob_rigid == 0.0:
                    assert len(imus_surely_rigid) == 0

                def stif_damp_rigid(key):
                    # fixed, large values; `key` is unused but keeps the signature
                    # identical to `stif_damp_nonrigid` so both fit into `lax.cond`
                    rot = 200.0 * jnp.ones((3,))
                    trans = 2e4 * jnp.ones((3,))
                    stif = jnp.concatenate((rot, trans))
                    return stif, stif * 0.2

                def stif_damp_nonrigid(key):
                    k_sph, k_p3d, k_damp = jax.random.split(key, 3)
                    stif = jnp.concatenate(
                        (
                            _log_uniform(k_sph, (3,), _STIF_MIN_SPH, _STIF_MAX_SPH),
                            _log_uniform(k_p3d, (3,), _STIF_MIN_P3D, _STIF_MAX_P3D),
                        )
                    )
                    # damping = factor * stiffness
                    factor = _log_uniform(k_damp, (6,), _DAMP_MIN, _DAMP_MAX)
                    return stif, stif * factor

                def setup_fn_randomize_damping_stiffness(key, sys: base.System) -> base.System:
                    damping_all = sys.link_damping
                    stiffness_all = sys.link_spring_stiffness
                    dof_slices = sys.idx_map("d")

                    # initialize this RV because it might not get redrawn if
                    # `all_imus_either_rigid_or_flex` is set
                    key, k_shared = jax.random.split(key)
                    is_rigid = jax.random.bernoulli(k_shared, prob_rigid)

                    # bookkeeping that is only needed for the assertions at the end
                    forced_rigid = []

                    for imu in sys.findall_imus():
                        # _imu has spherical joint and imu has p3d joint
                        dofs = jnp.r_[
                            dof_slices[imu_reference_link_name(imu)], dof_slices[imu]
                        ]
                        key, k_draw, k_vals = jax.random.split(key, 3)

                        if prob_rigid > 0 and imu in imus_surely_rigid:
                            forced_rigid.append(imu)
                            stif, damp = stif_damp_rigid(k_vals)
                        elif prob_rigid > 0:
                            if not all_imus_either_rigid_or_flex:
                                is_rigid = jax.random.bernoulli(k_draw, prob_rigid)
                            stif, damp = jax.lax.cond(
                                is_rigid, stif_damp_rigid, stif_damp_nonrigid, k_vals
                            )
                        else:
                            stif, damp = stif_damp_nonrigid(k_vals)

                        stiffness_all = stiffness_all.at[dofs].set(stif)
                        damping_all = damping_all.at[dofs].set(damp)

                    # every IMU that was declared as surely rigid must exist in `sys`
                    assert len(imus_surely_rigid) == len(forced_rigid)
                    assert all(name in forced_rigid for name in imus_surely_rigid)

                    return sys.replace(
                        link_damping=damping_all, link_spring_stiffness=stiffness_all
                    )

                return setup_fn_randomize_damping_stiffness
         
     | 
| 
      
 172 
     | 
    
         
            +
             
     | 
| 
      
 173 
     | 
    
         
            +
             
     | 
| 
      
 174 
     | 
    
         
            +
            def _match_q_x_between_sys(
                sys_small: base.System,
                q_large: jax.Array,
                x_large: base.Transform,
                sys_large: base.System,
                q_large_skip: list[str],
            ) -> tree_utils.PyTree:
                """Select from `q_large` / `x_large` the entries of the links in `sys_small`.

                Args:
                    sys_small: System whose links determine what is kept, and in which order.
                    q_large: Two-dimensional array; its second axis has `sys_large.q_size()`
                        entries.
                    x_large: Transforms; its second axis has `sys_large.num_links()` entries.
                    sys_large: System that `q_large` and `x_large` belong to.
                    q_large_skip: Link names whose `q` entries are left out (their
                        transforms are still kept).

                Returns:
                    Tuple `(q_small, x_small)`.
                """
                assert q_large.ndim == 2
                assert q_large.shape[1] == sys_large.q_size()
                assert x_large.shape(1) == sys_large.num_links()

                large_q_slices = sys_large.idx_map("q")
                link_indices, q_parts = [], []

                def _collect(_, __, name: str):
                    link_indices.append(sys_large.name_to_idx(name))
                    # for the imu links the joint type was changed from spherical to
                    # frozen; thus `large_q_slices` has slices of length 4 for them while
                    # in `sys_small` they are frozen and have slices of length 0, so
                    # their `q` entries are skipped
                    if name not in q_large_skip:
                        q_parts.append(q_large[:, large_q_slices[name]])

                sys_small.scan(_collect, "l", sys_small.link_names)

                x_small = tree_utils.tree_indices(x_large, jnp.array(link_indices), axis=1)
                return jnp.concatenate(q_parts, axis=1), x_small
         
     | 
| 
      
 203 
     | 
    
         
            +
             
     | 
| 
      
 204 
     | 
    
         
            +
             
     | 
| 
      
 205 
     | 
    
         
            +
            class GeneratorTrafoHideInjectedBodies:
                """Generator transform that removes the injected IMU bodies again.

                The wrapped generator is expected to return
                `(X, y), (key, q, x, sys_x)`. The returned generator yields the same
                structure but with the IMU links deleted from the system, the reference
                links (`_<imu>`) renamed back to `<imu>` and frozen, and `q` / `x` reduced
                to match that system.
                """

                def __call__(self, gen):
                    def _gen(*args):
                        (X, y), (key, q, x, sys_x) = gen(*args)

                        imu_names = sys_x.findall_imus()
                        renames = {imu_reference_link_name(name): name for name in imu_names}

                        # delete injected frames; then rename from `_imu` back to `imu`
                        sys_hidden = sys_x.delete_system(imu_names)
                        for ref_name, imu_name in renames.items():
                            sys_hidden = sys_hidden.change_link_name(ref_name, imu_name)
                            sys_hidden = sys_hidden.change_joint_type(imu_name, "frozen")

                        # match q and x to `sys_hidden`; second axis is link axis
                        q_hidden, x_hidden = _match_q_x_between_sys(
                            sys_hidden, q, x, sys_x, q_large_skip=imu_names
                        )

                        return (X, y), (key, q_hidden, x_hidden, sys_hidden)

                    return _gen
         
     | 
| 
         @@ -0,0 +1,182 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from types import SimpleNamespace
         
     | 
| 
      
 2 
     | 
    
         
            +
            from typing import Optional
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            from flax import struct
         
     | 
| 
      
 5 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 6 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 7 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 8 
     | 
    
         
            +
            from ring.algorithms import dynamics
         
     | 
| 
      
 9 
     | 
    
         
            +
            from ring.algorithms import jcalc
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            @struct.dataclass
            class PDControllerState:
                """State container of the PD controller, declared as a `flax.struct` dataclass.

                NOTE(review): the field descriptions below are inferred from the field
                names only; confirm against the controller's init/apply functions.
                """

                # presumably the index of the current timestep in the reference - TODO confirm
                i: int
                # presumably reference joint positions, keyed per link - TODO confirm
                q_ref_as_dict: dict
                # presumably reference joint velocities, keyed per link - TODO confirm
                qd_ref_as_dict: dict
                # presumably proportional gains, keyed per link - TODO confirm
                P_gains: dict
                # presumably derivative gains, keyed per link - TODO confirm
                D_gains: dict
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            def _pd_control(P: jax.Array, D: Optional[jax.Array] = None):
                """Computes tau using a PD controller. Returns a pair of (init, apply) functions.

                NOTE: Gains around ~10_000 are good for spherical joints, everything else ~250-300
                works just fine. Damping should be about 2500 for spherical joints, and
                about 25 for everything else.

                Args:
                    P: jax.Array of P gains. Shape: (sys_init.qd_size())
                    D: jax.Array of D gains. Shape: (sys_init.qd_size()) where `sys_init` is the
                        system that recorded the reference trajectory `q_ref`
                        If not given, then no D control is applied.

                Returns: Namespace with the attributes (init, apply)
                    init: (sys, q_ref) -> controller_state
                    apply: (controller_state, sys, state) -> controller_state, tau

                Example:
                    >>> gains = jnp.array([250.0] * sys1.qd_size())
                    >>> controller = _pd_control(gains, gains)
                    >>> q_ref = rcmg(sys1)
                    >>> cs = controller.init(sys1, q_ref)
                    >>> for t in range(1000):
                    ...     cs, tau = controller.apply(cs, sys2, state)
                    ...     state = dynamics.step(sys2, state, tau)
                """

                def init(sys: base.System, q_ref: jax.Array) -> PDControllerState:
                    """Split `q_ref` and the gains per link and build the initial state.

                    `q_ref` is indexed as (time, sys.q_size()); the gains must have
                    `sys.qd_size()` entries.
                    """
                    assert sys.q_size() == q_ref.shape[1], f"q_ref.shape = {q_ref.shape}"
                    assert sys.qd_size() == P.size, f"P.size = {P.size}"
                    if D is not None:
                        # this comparison used to be a bare expression and never checked anything
                        assert sys.qd_size() == D.size, f"D.size = {D.size}"

                    q_ref_as_dict = {}
                    qd_ref_as_dict = {}
                    P_as_dict = {}
                    D_as_dict = {}

                    def f(_, __, q_ref_link, name, typ, P_link, D_link):
                        P_as_dict[name] = P_link
                        # `scan` receives `q_ref.T`; transpose the per-link slice back so that
                        # time is the leading axis again
                        q_ref_link = q_ref_link.T
                        q_ref_as_dict[name] = q_ref_link

                        if D is not None:
                            qd_from_q = jcalc.get_joint_model(typ).qd_from_q
                            if qd_from_q is None:
                                raise NotImplementedError(
                                    f"Please specify `JointModel.qd_from_q` for joint type `{typ}`"
                                )
                            qd_ref_as_dict[name] = qd_from_q(q_ref_link, sys.dt)
                            D_as_dict[name] = D_link

                    sys.scan(
                        f,
                        "qlldd",
                        q_ref.T,
                        sys.link_names,
                        sys.link_types,
                        P,
                        # placeholder so that `scan` always gets five arguments; `f` ignores
                        # `D_link` when `D` is None
                        D if D is not None else jnp.zeros((sys.qd_size(),)),
                    )
                    return PDControllerState(0, q_ref_as_dict, qd_ref_as_dict, P_as_dict, D_as_dict)

                def apply(
                    controller_state: PDControllerState, sys: base.System, state: base.State
                ) -> tuple[PDControllerState, jax.Array]:
                    """Compute the torques for the current reference sample.

                    Returns the controller state advanced by one sample together with an
                    array of `sys.qd_size()` torques. Links of `sys` that have no entry in the
                    reference keep a torque of zero.
                    """
                    taus = jnp.zeros((sys.qd_size()))
                    # pick sample `i` of every per-link reference array;
                    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias
                    q_ref, qd_ref = jax.tree_util.tree_map(
                        lambda arr: jax.lax.dynamic_index_in_dim(
                            arr, controller_state.i, keepdims=False
                        ),
                        (controller_state.q_ref_as_dict, controller_state.qd_ref_as_dict),
                    )

                    def f(_, idx_map, idx, name, typ, q_curr, qd_curr):
                        nonlocal taus

                        if name not in controller_state.q_ref_as_dict:
                            return

                        p_control_term = jcalc.get_joint_model(typ).p_control_term
                        if p_control_term is None:
                            raise NotImplementedError(
                                f"Please specify `JointModel.p_control_term` for joint type `{typ}`"
                            )
                        P_term = p_control_term(q_curr, q_ref[name])
                        tau = P_term * controller_state.P_gains[name]

                        # the D term only exists if `init` was run with D gains
                        if name in controller_state.qd_ref_as_dict:
                            D_term = (qd_ref[name] - qd_curr) * controller_state.D_gains[name]
                            tau += D_term

                        taus = taus.at[idx_map["d"](idx)].set(tau)

                    sys.scan(
                        f,
                        "lllqd",
                        list(range(sys.num_links())),
                        sys.link_names,
                        sys.link_types,
                        state.q,
                        state.qd,
                    )

                    return controller_state.replace(i=controller_state.i + 1), taus

                return SimpleNamespace(init=init, apply=apply)
         
     | 
| 
      
 128 
     | 
    
         
            +
             
     | 
| 
      
 129 
     | 
    
         
            +
             
     | 
| 
      
 130 
     | 
    
         
            +
            def _unroll_dynamics_pd_control(
                sys: base.System,
                q_ref: jax.Array,
                P: jax.Array,
                D: Optional[jax.Array] = None,
                nograv: bool = False,
                sys_q_ref: Optional[base.System] = None,
                initial_sim_state_is_zeros: bool = False,
                clip_taus: Optional[float] = None,
            ):
                """Simulate `sys` while a PD controller tracks the reference trajectory `q_ref`.

                Args:
                    sys: System that is simulated.
                    q_ref: Two-dimensional reference trajectory; one simulation step is taken
                        per entry along the first axis.
                    P: P gains forwarded to `_pd_control`.
                    D: Optional D gains forwarded to `_pd_control`.
                    nograv: If set, the simulated system's gravity is multiplied by zero.
                    sys_q_ref: System used to initialise the controller from `q_ref`.
                        Defaults to `sys` (as passed in, i.e. before gravity is zeroed).
                    initial_sim_state_is_zeros: If set, start from `base.State.create(sys)`;
                        otherwise start from the first reference sample `q_ref[0]`.
                    clip_taus: If given, must be positive; torques are clipped to
                        `[-clip_taus, clip_taus]` before every dynamics step.

                Returns:
                    The per-step states collected by `jax.lax.scan`.
                """
                assert q_ref.ndim == 2

                # resolve the default before `sys` is possibly replaced below
                sys_q_ref = sys if sys_q_ref is None else sys_q_ref

                if nograv:
                    sys = sys.replace(gravity=sys.gravity * 0.0)

                state = (
                    base.State.create(sys)
                    if initial_sim_state_is_zeros
                    else _initial_q_is_q_ref(sys, sys_q_ref, q_ref[0])
                )

                controller = _pd_control(P, D)
                controller_state = controller.init(sys_q_ref, q_ref)

                def step(carry, _):
                    sim_state, ctrl_state = carry
                    ctrl_state, torques = controller.apply(ctrl_state, sys, sim_state)
                    if clip_taus is not None:
                        assert clip_taus > 0.0
                        torques = jnp.clip(torques, -clip_taus, clip_taus)
                    sim_state = dynamics.step(sys, sim_state, torques)
                    return (sim_state, ctrl_state), sim_state

                _, states = jax.lax.scan(
                    step, (state, controller_state), None, length=q_ref.shape[0]
                )
                return states
         
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
             
     | 
| 
      
 170 
     | 
    
         
            +
            def _initial_q_is_q_ref(sys: base.System, sys_q_ref: base.System, q_ref):
                """Create a state of `sys` whose `q` is taken from a reference sample.

                For every link of `sys_q_ref`, the link's slice of `q_ref` is written into the
                slice of `sys`'s `q` vector that `sys.idx_map("q")` stores under the same
                link name.
                """
                # start from the default `q` instead of zeros because of quaternions
                q_init = base.State.create(sys).q
                q_slices = sys.idx_map("q")

                def _copy_link(_, __, link_name, q_link):
                    nonlocal q_init
                    q_init = q_init.at[q_slices[link_name]].set(q_link)

                sys_q_ref.scan(_copy_link, "lq", sys_q_ref.link_names, q_ref)

                return base.State.create(sys, q=q_init)
         
     | 
| 
         @@ -0,0 +1,119 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            """Randomization by modifying System and MotionConfig objects before building
         
     | 
| 
      
 2 
     | 
    
         
            +
            generator."""
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            from dataclasses import replace
         
     | 
| 
      
 5 
     | 
    
         
            +
            import itertools
         
     | 
| 
      
 6 
     | 
    
         
            +
            from typing import Optional
         
     | 
| 
      
 7 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 10 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 11 
     | 
    
         
            +
            from ring.algorithms import jcalc
         
     | 
| 
      
 12 
     | 
    
         
            +
            from ring.algorithms.generator import types
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
            def _find_children(lam: list[int], body: int) -> list[int]:
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
                children = []
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
                def _children(body: int) -> None:
         
     | 
| 
      
 20 
     | 
    
         
            +
                    for i in range(len(lam)):
         
     | 
| 
      
 21 
     | 
    
         
            +
                        if lam[i] == body:
         
     | 
| 
      
 22 
     | 
    
         
            +
                            children.append(i)
         
     | 
| 
      
 23 
     | 
    
         
            +
                            _children(i)
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
                _children(body)
         
     | 
| 
      
 26 
     | 
    
         
            +
                return children
         
     | 
| 
      
 27 
     | 
    
         
            +
             
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
            def _find_root_of_subsys_that_contains_body(sys: base.System, body: str) -> str:
                """Return the name of the root link (parent `-1`) whose subtree holds `body`.

                Falls through and returns `None` if no root link matches.
                """
                target = sys.name_to_idx(body)
                for root_idx, parent in enumerate(sys.link_parents):
                    if parent != -1:
                        continue
                    if target == root_idx or target in _find_children(
                        sys.link_parents, root_idx
                    ):
                        return sys.idx_to_name(root_idx)
                return None
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
             
     | 
| 
      
 37 
     | 
    
         
            +
            def _assign_anchors_to_subsys(sys: base.System, anchors: list[str]) -> list[list[str]]:
                """Group `anchors` by the subsystem (tree below a root link) they belong to.

                Returns one list per root link, in the order of the root links. A subsystem
                that contains none of the given anchors gets its own root link as sole anchor.
                """
                grouped = []
                for root_idx, parent in enumerate(sys.link_parents):
                    if parent != -1:
                        continue
                    member_idxs = [root_idx, *_find_children(sys.link_parents, root_idx)]
                    member_names = [sys.idx_to_name(idx) for idx in member_idxs]
                    selected = [anchor for anchor in anchors if anchor in member_names]
                    # fall back to the root link itself when no anchor was given
                    grouped.append(selected or [sys.idx_to_name(root_idx)])
                return grouped
         
     | 
| 
      
 50 
     | 
    
         
            +
             
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
            def _morph_extract_subsys(sys: base.System, anchor: str):
                """Isolate the subsystem that contains `anchor` and morph it to that anchor.

                All other root links reported by `sys.findall_bodies_to_world` are removed via
                `delete_system` before `morph_system(new_anchor=anchor)` is applied.
                """
                own_root = _find_root_of_subsys_that_contains_body(sys, anchor)
                all_roots = sys.findall_bodies_to_world(names=True)
                foreign_roots = list(set(all_roots) - {own_root})
                return sys.delete_system(foreign_roots).morph_system(new_anchor=anchor)
         
     | 
| 
      
 57 
     | 
    
         
            +
             
     | 
| 
      
 58 
     | 
    
         
            +
             
     | 
| 
      
 59 
     | 
    
         
            +
            def randomize_anchors(
                sys: base.System, anchors: Optional[list[str]] = None
            ) -> list[base.System]:
                """Build one system per combination of anchors across the subsystems of `sys`.

                Args:
                    sys: System to morph.
                    anchors: Candidate anchor link names. Defaults to
                        `sys.findall_segments()`.

                Returns:
                    A list with one system for every element of the Cartesian product of
                    the per-subsystem anchor lists. Each system is the first subsystem
                    morphed to its anchor, with the remaining morphed subsystems injected.
                """
                if anchors is None:
                    anchors = sys.findall_segments()

                anchors_per_subsys = _assign_anchors_to_subsys(sys, anchors)

                morphed_systems = []
                for combination in itertools.product(*anchors_per_subsys):
                    combined = _morph_extract_subsys(sys, combination[0])
                    for anchor in combination[1:]:
                        combined = combined.inject_system(_morph_extract_subsys(sys, anchor))
                    morphed_systems.append(combined)

                return morphed_systems
         
     | 
| 
      
 75 
     | 
    
         
            +
             
     | 
| 
      
 76 
     | 
    
         
            +
             
     | 
| 
      
 77 
     | 
    
         
            +
            # `randomize_hz` emits a warning for every requested sampling rate below this
            # value, because such rates might lead to NaNs.
            _WARN_HZ_Threshold: float = 40.0
         
     | 
| 
      
 78 
     | 
    
         
            +
             
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
            def randomize_hz(
                sys: list[base.System],
                configs: list[jcalc.MotionConfig],
                sampling_rates: list[float],
            ) -> tuple[list[base.System], list[jcalc.MotionConfig]]:
                """Expand systems and motion configs over a set of sampling rates.

                Args:
                    sys: Systems to expand.
                    configs: Motion configs to expand; all of them must share the same `T`.
                    sampling_rates: Sampling rates (1 / dt) to generate. A warning is
                        emitted for every rate below `_WARN_HZ_Threshold`.

                Returns:
                    Two lists of equal length with one entry per (system, config, rate)
                    combination. The system has its `dt` replaced by `1 / rate`; the config
                    has its `T` rescaled such that `T / dt` equals the shared `T` divided by
                    the original `dt` of the system.

                Raises:
                    AssertionError: If the configs do not share exactly one value of `T`.
                """
                Ts = [c.T for c in configs]
                # `len(set(Ts))` alone is truthy for any non-empty list and checked nothing
                assert len(set(Ts)) == 1, f"Time length between configs does not agree {Ts}"
                T_global = Ts[0]

                for hz in sampling_rates:
                    if hz < _WARN_HZ_Threshold:
                        warnings.warn(
                            f"The sampling rate {hz} is below the warning threshold of "
                            f"{_WARN_HZ_Threshold}. This might lead to NaNs."
                        )

                sys_out, configs_out = [], []
                for _sys, _config, hz in itertools.product(sys, configs, sampling_rates):
                    dt = 1 / hz
                    # keep the number of timesteps `T / dt` unchanged
                    T = (T_global / _sys.dt) * dt

                    sys_out.append(_sys.replace(dt=dt))
                    configs_out.append(replace(_config, T=T))
                return sys_out, configs_out
         
     | 
| 
      
 106 
     | 
    
         
            +
             
     | 
| 
      
 107 
     | 
    
         
            +
             
     | 
| 
      
 108 
     | 
    
         
            +
            def randomize_hz_finalize_fn_factory(finalize_fn_user: types.FINALIZE_FN | None):
                """Wrap an optional user finalize function so that `X` also carries `dt`.

                The returned function calls `finalize_fn_user` (if given) to obtain `(X, y)`,
                otherwise starts from two empty dicts, and then stores `sys.dt` as a
                one-element float32 array under `X["dt"]`. The key `"dt"` must not be set
                by the user function.
                """

                def finalize_fn(key, q, x, sys: base.System):
                    if finalize_fn_user is None:
                        X, y = {}, {}
                    else:
                        X, y = finalize_fn_user(key, q, x, sys)

                    assert "dt" not in X
                    X["dt"] = jnp.array([sys.dt], dtype=jnp.float32)

                    return X, y

                return finalize_fn
         
     |