imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
| 
         @@ -0,0 +1,56 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 2 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            # XML fixture for `test_from_xml`: a single "rx" link with non-default pose,
            # damping/armature/spring values and one box geom. Every attribute here has
            # a hand-written counterpart in the `ring.System` built by the test.
            sys_str = """
            <x_xy model="model">
            <options gravity=".1 2 3" dt=".03"/>
                <worldbody>
                    <body name="name" joint="rx" pos="1 2 3" euler="30 30 30" damping=".7" armature=".8" spring_stiff="1" spring_zero=".9">
                        <geom type="box" mass="2.7" dim="0.2 0.3 0.4" color="black" edge_color="pink"/>
                    </body>
                </worldbody>
            </x_xy>
            """  # noqa: E501
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            def test_from_xml():
                """Parsing `sys_str` must give the same system as building it by hand."""
                pos = jnp.array([1.0, 2, 3])
                # euler="30 30 30" in the XML is given in degrees.
                angle = jnp.deg2rad(30)
                rot = ring.maths.quat_euler(jnp.array([angle, angle, angle]))

                link = ring.base.Link(
                    ring.base.Transform(pos=pos, rot=rot),
                    pos_min=pos,
                    pos_max=pos,
                ).batch()

                # Positional Box arguments mirror the XML geom attributes.
                box = ring.base.Box(
                    jnp.array(2.7),  # mass
                    ring.Transform.zero(),
                    0,  # presumably the owning link index (the only link is 0)
                    "black",  # color
                    "pink",  # edge_color
                    jnp.array(0.2),
                    jnp.array(0.3),
                    jnp.array(0.4),
                )

                expected = ring.System(
                    [-1],
                    link,
                    ["rx"],
                    link_damping=jnp.array([0.7]),
                    link_armature=jnp.array([0.8]),
                    link_spring_zeropoint=jnp.array([0.9]),
                    link_spring_stiffness=jnp.array([1.0]),
                    dt=0.03,
                    geoms=[box],
                    gravity=jnp.array([0.1, 2, 3.0]),
                    link_names=["name"],
                    model_name="model",
                    omc=[None],
                ).parse()
                parsed = ring.io.load_sys_from_str(sys_str)

                assert ring.utils.sys_compare(expected, parsed)
         
     | 
| 
         @@ -0,0 +1,31 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import logging
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 4 
     | 
    
         
            +
            from ring.base import System
         
     | 
| 
      
 5 
     | 
    
         
            +
            from ring.utils import sys_compare
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            def test_save_sys_to_str():
                """Round-trip every bundled example through an XML string.

                Each example must compare equal to itself after export + reload, and two
                different examples must still compare unequal after the same round trip.
                """
                for example_sys in ring.io.list_load_examples():
                    xml_str = ring.io.save_sys_to_str(example_sys)
                    logging.debug(xml_str)
                    reloaded = ring.io.load_sys_from_str(xml_str)

                    model_file = f"{example_sys.model_name}.xml"
                    assert sys_compare(example_sys, reloaded), f"Failed {model_file}"
                    print(f"Passed {model_file}")

                def _export_and_reload(example: str) -> System:
                    # example file -> System -> XML string -> System
                    return ring.io.load_sys_from_str(
                        ring.io.save_sys_to_str(ring.io.load_example(example))
                    )

                first, second = (
                    _export_and_reload(name) for name in ("test_all_1.xml", "test_all_2.xml")
                )
                assert not sys_compare(first, second)
         
     | 
    
        ring/io/xml/to_xml.py
    ADDED
    
    | 
         @@ -0,0 +1,94 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            import warnings
         
     | 
| 
      
 2 
     | 
    
         
            +
            from xml.dom.minidom import parseString
         
     | 
| 
      
 3 
     | 
    
         
            +
            from xml.etree.ElementTree import Element
         
     | 
| 
      
 4 
     | 
    
         
            +
            from xml.etree.ElementTree import SubElement
         
     | 
| 
      
 5 
     | 
    
         
            +
            from xml.etree.ElementTree import tostring
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 8 
     | 
    
         
            +
            from ring import base
         
     | 
| 
      
 9 
     | 
    
         
            +
            from tree_utils import batch_concat
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            from . import abstract
         
     | 
| 
      
 12 
     | 
    
         
            +
            from .abstract import _to_str
         
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
            def save_sys_to_str(sys: base.System) -> str:
                """Serialize a system into a pretty-printed x_xy XML string.

                The document has an ``<x_xy model=...>`` root containing an ``<options>``
                element (``dt`` and ``gravity``) and a ``<worldbody>``. Each link becomes a
                ``<body>`` element nested under its parent's body (roots, i.e. links with
                parent ``-1``, go directly under ``<worldbody>``). A body holds its geoms
                and, if the link has one, an ``<omc>`` element.

                ``sys.links.joint_params`` are not written. A ``UserWarning`` is emitted for
                every (joint type, link) pair whose parameters are not all zero.

                Args:
                    sys: The system to serialize.

                Returns:
                    The XML document, pretty-printed by ``xml.dom.minidom`` with a
                    two-space indent.
                """
                # `joint_params` are not representable in the XML; warn about every link
                # whose values differ from the all-zero default. `stacklevel=2` makes the
                # warning point at the caller instead of this module.
                for joint_type in sys.links.joint_params:
                    for i, link_name in enumerate(sys.link_names):
                        joint_params_flat = batch_concat(
                            (sys.links[i]).joint_params[joint_type], 0
                        )
                        if not jnp.all(joint_params_flat == 0.0):
                            warnings.warn(
                                "The system has `sys.links.joint_params` unequal to the 'default'"
                                f" value (of zeros). In particular the link `{link_name}` has for"
                                f" the jointtype `{joint_type}` the values {joint_params_flat}. "
                                "This will not be preserved in the xml.",
                                stacklevel=2,
                            )

                # Lookups from link name to that link's entries in the q-indexed
                # (`link_spring_zeropoint`) and d-indexed (damping, armature, stiffness)
                # arrays used below.
                global_index_map = {qd: sys.idx_map(qd) for qd in ["q", "d"]}

                # Group links by parent once, preserving index order, instead of rescanning
                # `sys.link_parents` for every link (which was quadratic in the link count).
                root_idxs: list[int] = []
                children: dict[int, list[int]] = {}
                for child_idx, parent_idx in enumerate(sys.link_parents):
                    parent_idx = int(parent_idx)
                    if parent_idx == -1:
                        root_idxs.append(child_idx)
                    else:
                        children.setdefault(parent_idx, []).append(child_idx)

                # Create root element
                x_xy = Element("x_xy")
                x_xy.set("model", sys.model_name)

                options = SubElement(x_xy, "options")
                options.set("dt", str(sys.dt))
                options.set("gravity", _to_str(sys.gravity))

                # Create worldbody
                worldbody = SubElement(x_xy, "worldbody")

                def process_link(link_idx: int, parent_elem: Element) -> None:
                    """Append the <body> subtree of link `link_idx` to `parent_elem`."""
                    link = sys.links[link_idx]
                    link_typ = sys.link_types[link_idx]
                    link_name = sys.link_names[link_idx]

                    # Create body element
                    body = SubElement(parent_elem, "body")
                    body.set("joint", link_typ)
                    body.set("name", link_name)

                    # Set attributes
                    abstract.AbsTrans.to_xml(body, link.transform1)
                    abstract.AbsPosMinMax.to_xml(body, link.pos_min, link.pos_max)
                    d_idx = global_index_map["d"][link_name]
                    q_idx = global_index_map["q"][link_name]
                    abstract.AbsDampArmaStiffZero.to_xml(
                        body,
                        sys.link_damping[d_idx],
                        sys.link_armature[d_idx],
                        sys.link_spring_stiffness[d_idx],
                        sys.link_spring_zeropoint[q_idx],
                        base.Q_WIDTHS[link_typ],
                        base.QD_WIDTHS[link_typ],
                        link_typ,
                    )

                    # Add geometry elements belonging to this link (in `sys.geoms` order)
                    for geom in sys.geoms:
                        if geom.link_idx == link_idx:
                            geom_elem = SubElement(body, "geom")
                            abstract_class = abstract.geometry_to_abstract[type(geom)]
                            abstract_class.to_xml(geom_elem, geom)

                    # Maybe add omc element
                    omc_link = sys.omc[link_idx]
                    if omc_link is not None:
                        omc_elem = SubElement(body, "omc")
                        abstract.AbsMaxCoordOMC.to_xml(omc_elem, omc_link)

                    # Recursively process child links (in increasing index order)
                    for child_idx in children.get(link_idx, []):
                        process_link(child_idx, body)

                for root_link_idx in root_idxs:
                    process_link(root_link_idx, worldbody)

                # Pretty print xml
                xml_str = parseString(tostring(x_xy)).toprettyxml(indent="  ")
                return xml_str
         
     | 
| 
      
 89 
     | 
    
         
            +
             
     | 
| 
      
 90 
     | 
    
         
            +
             
     | 
| 
      
 91 
     | 
    
         
            +
            def save_sys_to_xml(sys: base.System, xml_path: str) -> None:
                """Serialize `sys` with `save_sys_to_str` and write it to `xml_path`.

                The file is written as UTF-8 explicitly. The minidom-generated XML
                declaration names no encoding, so XML readers will assume UTF-8. Relying
                on the platform default encoding (e.g. cp1252 on Windows) would corrupt
                non-ASCII model or link names.

                Args:
                    sys: The system to serialize.
                    xml_path: Destination path; an existing file is overwritten.
                """
                xml_str = save_sys_to_str(sys)
                with open(xml_path, "w", encoding="utf-8") as f:
                    f.write(xml_str)
         
     | 
    
        ring/maths.py
    ADDED
    
    | 
         @@ -0,0 +1,397 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from functools import partial
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 4 
     | 
    
         
            +
            from jax import custom_jvp
         
     | 
| 
      
 5 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 6 
     | 
    
         
            +
            import jax.random as jrand
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            def wrap_to_pi(phi):
                """Map angle(s) `phi` in radians into the half-open interval [-pi, pi).

                Works element-wise on arrays and on plain Python floats (which stay Python
                floats). Up to floating-point rounding, pi itself is mapped to -pi.
                """
                period = 2 * jnp.pi
                shifted = phi + jnp.pi
                return shifted % period - jnp.pi
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
             
     | 
| 
      
 14 
     | 
    
         
            +
            x_unit_vector = jnp.array([1.0, 0, 0])
         
     | 
| 
      
 15 
     | 
    
         
            +
            y_unit_vector = jnp.array([0.0, 1, 0])
         
     | 
| 
      
 16 
     | 
    
         
            +
            z_unit_vector = jnp.array([0.0, 0, 1])
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            def unit_vectors(xyz: int | str):
         
     | 
| 
      
 20 
     | 
    
         
            +
                if isinstance(xyz, str):
         
     | 
| 
      
 21 
     | 
    
         
            +
                    xyz = {"x": 0, "y": 1, "z": 2}[xyz]
         
     | 
| 
      
 22 
     | 
    
         
            +
                return [x_unit_vector, y_unit_vector, z_unit_vector][xyz]
         
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(k)->(1)")
            def safe_norm(x):
                """Euclidean norm along the last axis that is differentiable at x = 0.

                If every entry of the vector is `jnp.isclose` to 0, a constant branch
                returns 0, so the gradient there is 0 instead of NaN. Leading axes are
                batched by `jnp.vectorize`; the result keeps a trailing axis of size 1.
                """
                assert x.ndim == 1

                def _zero_norm(v):
                    return jnp.array([0.0], dtype=v.dtype)

                def _norm(v):
                    return jnp.linalg.norm(v, keepdims=True)

                all_close_to_zero = jnp.isclose(x, 0.0).all(axis=-1)
                return jax.lax.cond(all_close_to_zero, _zero_norm, _norm, x)
         
     | 
| 
      
 37 
     | 
    
         
            +
             
     | 
| 
      
 38 
     | 
    
         
            +
             
     | 
| 
      
 39 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(k)->(k)")
            def safe_normalize(x):
                """Scale `x` to unit length along its last axis; safe for (near-)zero input.

                If `jnp.allclose(x, 0.0)`, a zero vector is returned instead of dividing by
                a ~0 norm, which keeps both the value and its gradient finite. Leading axes
                are batched by `jnp.vectorize`.
                """
                assert x.ndim == 1

                is_zero = jnp.allclose(x, 0.0)

                def _zeros(v):
                    return jnp.zeros_like(v)

                def _unit(v):
                    # `lax.cond` may evaluate both branches (e.g. when the predicate is
                    # batched under vmap), so the divisor is guarded here as well.
                    return v / jnp.where(is_zero, 1.0, safe_norm(v))

                return jax.lax.cond(is_zero, _zeros, _unit, x)
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
             
     | 
| 
      
 53 
     | 
    
         
            +
            @custom_jvp
            def safe_arccos(x: jnp.ndarray) -> jnp.ndarray:
                """Element-wise inverse cosine whose derivative stays finite at |x| = 1.

                The value is plain `jnp.arccos`. The custom JVP rule evaluates
                -1 / sqrt(1 - x**2) with `x` clipped to [-1 + 1e-7, 1 - 1e-7].
                """
                return jnp.arccos(x)


            @safe_arccos.defjvp
            def _safe_arccos_jvp(primal, tangent):
                # d/dx arccos(x) = -1 / sqrt(1 - x^2), evaluated at a clipped x so the
                # derivative is finite at the interval endpoints.
                (x,) = primal
                (x_dot,) = tangent
                x_inside = jnp.clip(x, -1 + 1e-7, 1 - 1e-7)
                denominator = jnp.sqrt(1.0 - x_inside**2.0)
                return safe_arccos(x), -x_dot / denominator
         
     | 
| 
      
 66 
     | 
    
         
            +
             
     | 
| 
      
 67 
     | 
    
         
            +
             
     | 
| 
      
 68 
     | 
    
         
            +
            @custom_jvp
            def safe_arcsin(x: jnp.ndarray) -> jnp.ndarray:
                """Element-wise inverse sine whose derivative stays finite at |x| = 1.

                The value is plain `jnp.arcsin`. The custom JVP rule evaluates
                1 / sqrt(1 - x**2) with `x` clipped to [-1 + 1e-7, 1 - 1e-7].
                """
                return jnp.arcsin(x)


            @safe_arcsin.defjvp
            def _safe_arcsin_jvp(primal, tangent):
                # d/dx arcsin(x) = 1 / sqrt(1 - x^2), evaluated at a clipped x so the
                # derivative is finite at the interval endpoints.
                (x,) = primal
                (x_dot,) = tangent
                # Must be arcsin itself: under jax.jvp / jax.grad the forward value of
                # `safe_arcsin` is taken from this rule, so returning arccos here (as
                # before) silently corrupted every differentiated computation.
                primal_out = safe_arcsin(x)
                tangent_out = x_dot / jnp.sqrt(1.0 - jnp.clip(x, -1 + 1e-7, 1 - 1e-7) ** 2.0)
                return primal_out, tangent_out
         
     | 
| 
      
 81 
     | 
    
         
            +
             
     | 
| 
      
 82 
     | 
    
         
            +
             
     | 
| 
      
 83 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4)->(4)")
            def ensure_positive_w(q):
                """Flip the sign of `q` when its real part `q[0]` is negative.

                `q` and `-q` describe the same rotation; this selects the representative
                with a non-negative real part. Broadcasts over leading axes.
                """
                is_negative = jnp.less(q[0], 0)
                return jnp.where(is_negative, jnp.negative(q), q)
         
     | 
| 
      
 86 
     | 
    
         
            +
             
     | 
| 
      
 87 
     | 
    
         
            +
             
     | 
| 
      
 88 
     | 
    
         
            +
            def angle_error(q, qhat):
                """Absolute angle (radians) of the relative rotation between `q` and `qhat`.

                The relative quaternion is `q^-1 * qhat`; its angle from `quat_angle` is
                returned as a non-negative value.
                """
                relative = quat_mul(quat_inv(q), qhat)
                return jnp.abs(quat_angle(relative))
         
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
| 
      
 92 
     | 
    
         
            +
             
     | 
| 
      
 93 
     | 
    
         
            +
            def unit_quats_like(array):
                """Array of *unit* (identity) quaternions with the same shape as `array`.

                Args:
                    array: array-like whose last axis has size 4.

                Returns:
                    Array of shape `array.shape` in which every 4-vector is [1, 0, 0, 0].

                Raises:
                    ValueError: if `array` is 0-d or its last axis does not have size 4.
                """
                shape = jnp.shape(array)
                if len(shape) == 0 or shape[-1] != 4:
                    raise ValueError(
                        f"Expected an array whose last axis has size 4, got shape {shape}."
                    )

                return jnp.ones(shape[:-1])[..., None] * jnp.array([1.0, 0, 0, 0])
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
      
 100 
     | 
    
         
            +
             
     | 
| 
      
 101 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4),(4)->(4)")
            def quat_mul(u: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
                """Hamilton product `u * v` of two quaternions laid out as [w, x, y, z].

                Broadcasts over leading batch axes via `jnp.vectorize`.
                """
                a0, a1, a2, a3 = u[0], u[1], u[2], u[3]
                b0, b1, b2, b3 = v[0], v[1], v[2], v[3]
                real = a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3
                i_part = a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2
                j_part = a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1
                k_part = a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0
                return jnp.array([real, i_part, j_part, k_part])
         
     | 
| 
      
 113 
     | 
    
         
            +
             
     | 
| 
      
 114 
     | 
    
         
            +
             
     | 
| 
      
 115 
     | 
    
         
            +
            def quat_inv(q: jnp.ndarray) -> jnp.ndarray:
                """Return the conjugate [w, -x, -y, -z] of quaternion `q`.

                For *unit* quaternions the conjugate is the inverse. Leading batch axes
                broadcast against the last (size-4) axis.
                """
                conjugation_mask = jnp.array([1.0, -1, -1, -1])
                return conjugation_mask * q
         
     | 
| 
      
 118 
     | 
    
         
            +
             
     | 
| 
      
 119 
     | 
    
         
            +
             
     | 
| 
      
 120 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(3),(4)->(3)")
            def rotate(vector: jnp.ndarray, quat: jnp.ndarray):
                """Rotate the 3-vector `vector` by the *unit* quaternion `quat`.

                The vector is embedded as the pure quaternion [0, vector], rotated with
                `rotate_quat`, and the vector part of the result is returned.
                """
                pure = jnp.array([0, *vector])
                rotated = rotate_quat(pure, quat)
                return rotated[1:4]
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
      
 126 
     | 
    
         
            +
             
     | 
| 
      
 127 
     | 
    
         
            +
            def rotate_matrix(matrix: jax.Array, quat: jax.Array):
                """Rotate matrix `matrix` by a *unit* quaternion `quat`.

                Computes `E @ matrix @ E^T` with `E = quat_to_3x3(quat)`. `quat` may carry
                leading batch axes; `E^T` transposes only the last two axes, so a batch of
                quaternions yields a batch of rotated matrices.
                """
                E = quat_to_3x3(quat)
                # `E.T` would reverse *all* axes of a batched (..., 3, 3) array and then
                # mis-broadcast in the matmul; swap only the matrix axes.
                E_T = jnp.swapaxes(E, -1, -2)
                return E @ matrix @ E_T
         
     | 
| 
      
 131 
     | 
    
         
            +
             
     | 
| 
      
 132 
     | 
    
         
            +
             
     | 
| 
      
 133 
     | 
    
         
            +
            def rotate_quat(q: jax.Array, quat: jax.Array):
                """Rotate quaternion `q` by `quat` using the sandwich product quat * q * quat^-1."""
                right = quat_mul(q, quat_inv(quat))
                return quat_mul(quat, right)
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
      
 137 
     | 
    
         
            +
             
     | 
| 
      
 138 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(3),()->(4)")
            def quat_rot_axis(axis: jnp.ndarray, angle: jnp.ndarray) -> jnp.ndarray:
                """Construct a *unit* quaternion for a rotation about `axis` by `angle` (radians).

                `axis` is normalised with `safe_normalize`. Because of the sign flip below
                (NOTE: CONVENTION), the result is [cos(-angle/2), axis * sin(-angle/2)].
                Used with `rotate` (q * v * q^-1), it therefore turns a vector by `-angle`
                about `axis`. Equivalently, it re-expresses the vector in a frame rotated
                by `+angle`. That frame interpretation is usually the one wanted, since
                quaternions here re-express fixed vectors in a common frame.
                """
                assert axis.shape == (3,)
                assert angle.shape == ()

                unit_axis = safe_normalize(axis)
                # NOTE: CONVENTION (23.04.23)
                # Negating the angle fixes prismatic joints being inverted w.r.t. the
                # gravity vector: it inverts how revolute joints behave, such that
                # prismatic joints work by inverting gravity.
                half_angle = (-1.0 * angle) / 2
                s, c = jnp.sin(half_angle), jnp.cos(half_angle)
                return jnp.array([c, *(unit_axis * s)])
         
     | 
| 
      
 164 
     | 
    
         
            +
             
     | 
| 
      
 165 
     | 
    
         
            +
             
     | 
| 
      
 166 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(3,3)->(4)")
            def quat_from_3x3(m: jnp.ndarray) -> jnp.ndarray:
                """Converts 3x3 rotation matrix to *unit* quaternion [w, x, y, z].

                Uses Shepperd's method. The component among w, x, y, z with the largest
                magnitude is recovered from a square root of a diagonal combination, and
                the other three are divided by it. A trace-only formula divides by `w`,
                which is ~0 for rotations near 180 degrees (e.g. diag(1, -1, -1) used to
                give NaN). When the trace branch is selected the result is identical to
                the classic formula.

                Returns:
                    Quaternion with non-negative real part `w`.
                """
                m00, m11, m22 = m[0, 0], m[1, 1], m[2, 2]
                # Floor for the square-root arguments. For a rotation matrix the selected
                # branch always has argument >= 1, so this only affects *unselected*
                # branches, where it keeps values (and gradients through the selection
                # below) finite instead of NaN.
                eps = 1e-12

                def _half_sqrt(arg):
                    return jnp.sqrt(jnp.maximum(arg, eps)) / 2.0

                # Branch 0: |w| is largest (classic trace formula).
                w = _half_sqrt(1 + m00 + m11 + m22)
                q_w = jnp.array(
                    [
                        w,
                        (m[2][1] - m[1][2]) / (w * 4),
                        (m[0][2] - m[2][0]) / (w * 4),
                        (m[1][0] - m[0][1]) / (w * 4),
                    ]
                )
                # Branch 1: |x| is largest.
                x = _half_sqrt(1 + m00 - m11 - m22)
                q_x = jnp.array(
                    [
                        (m[2][1] - m[1][2]) / (x * 4),
                        x,
                        (m[0][1] + m[1][0]) / (x * 4),
                        (m[0][2] + m[2][0]) / (x * 4),
                    ]
                )
                # Branch 2: |y| is largest.
                y = _half_sqrt(1 - m00 + m11 - m22)
                q_y = jnp.array(
                    [
                        (m[0][2] - m[2][0]) / (y * 4),
                        (m[0][1] + m[1][0]) / (y * 4),
                        y,
                        (m[1][2] + m[2][1]) / (y * 4),
                    ]
                )
                # Branch 3: |z| is largest.
                z = _half_sqrt(1 - m00 - m11 + m22)
                q_z = jnp.array(
                    [
                        (m[1][0] - m[0][1]) / (z * 4),
                        (m[0][2] + m[2][0]) / (z * 4),
                        (m[1][2] + m[2][1]) / (z * 4),
                        z,
                    ]
                )

                # 4w^2 = 1 + tr and 4x^2 = 1 + 2*m00 - tr (likewise for y, z), so the
                # largest of (tr, m00, m11, m22) marks the largest-magnitude component.
                # argmax resolves ties in favour of the trace branch.
                branch = jnp.argmax(jnp.array([m00 + m11 + m22, m00, m11, m22]))
                q = jnp.stack([q_w, q_x, q_y, q_z])[branch]
                # Keep the convention of a non-negative real part.
                return jnp.where(q[0] < 0, -q, q)
         
     | 
| 
      
 174 
     | 
    
         
            +
             
     | 
| 
      
 175 
     | 
    
         
            +
             
     | 
| 
      
 176 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4)->(3,3)")
            def quat_to_3x3(q: jnp.ndarray) -> jnp.ndarray:
                """Converts a quaternion [w, x, y, z] to a 3x3 rotation matrix.

                Intended for *unit* quaternions. The products are scaled by 2 / |q|^2, so a
                non-unit quaternion is effectively normalised first.
                """
                w, x, y, z = q
                scale = 2 / jnp.dot(q, q)

                # Components pre-multiplied by the scale factor.
                xs = x * scale
                ys = y * scale
                zs = z * scale
                wx, wy, wz = w * xs, w * ys, w * zs
                xx, xy, xz = x * xs, x * ys, x * zs
                yy, yz, zz = y * ys, y * zs, z * zs

                first = jnp.array([1 - (yy + zz), xy - wz, xz + wy])
                second = jnp.array([xy + wz, 1 - (xx + zz), yz - wx])
                third = jnp.array([xz - wy, yz + wx, 1 - (xx + yy)])
                return jnp.stack([first, second, third])
         
     | 
| 
      
 194 
     | 
    
         
            +
             
     | 
| 
      
 195 
     | 
    
         
            +
             
     | 
| 
      
 196 
     | 
    
         
            +
            def quat_random(
                key: jrand.PRNGKey, batch_shape: tuple = (), maxval: float = jnp.pi
            ) -> jax.Array:
                """Provides a random *unit* quaternion, sampled uniformly

                Standard-normal 4-vectors are normalised to the unit sphere. If `maxval`
                differs from pi, each sample keeps its axis while its rotation angle is
                rescaled linearly by `maxval / pi`.
                """
                assert key.shape == (2,), f"{key.shape}"
                samples = safe_normalize(jrand.normal(key, batch_shape + (4,)))

                def _keep():
                    return samples

                def _shrink():
                    axis, angle = quat_to_rot_axis(samples)
                    return quat_rot_axis(axis, angle * maxval / jnp.pi)

                return jax.lax.cond(maxval == jnp.pi, _keep, _shrink)
         
     | 
| 
      
 210 
     | 
    
         
            +
             
     | 
| 
      
 211 
     | 
    
         
            +
             
     | 
| 
      
 212 
     | 
    
         
            +
            def quat_euler(angles, intrinsic=True, convention="zyx"):
                """Construct a *unit* quaternion from Euler angles (radians).

                Args:
                    angles: array whose last axis holds three angles; leading axes are
                        batched via `jnp.vectorize`.
                    intrinsic: if True the result is q3 * (q2 * q1), otherwise
                        q1 * (q2 * q3), where qi rotates about axis `convention[i-1]` by
                        `angles[i-1]` (built with `quat_rot_axis`).
                    convention: three characters from "x", "y", "z".
                """

                @partial(jnp.vectorize, signature="(3)->(4)")
                def _quat_euler(angles):
                    basis = jnp.eye(3)
                    unit_vectors = {"x": basis[0], "y": basis[1], "z": basis[2]}

                    q1, q2, q3 = [
                        quat_rot_axis(unit_vectors[convention[idx]], angles[idx])
                        for idx in range(3)
                    ]

                    if not intrinsic:
                        return quat_mul(q1, quat_mul(q2, q3))
                    return quat_mul(q3, quat_mul(q2, q1))

                return _quat_euler(angles)
         
     | 
| 
      
 237 
     | 
    
         
            +
             
     | 
| 
      
 238 
     | 
    
         
            +
             
     | 
| 
      
 239 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4)->()")
            def quat_angle(q):
                """Extract rotation angle (radians) of quaternion `q`.

                Computed as 2 * atan2(|vector part|, w) and then passed through `wrap_to_pi`.
                """
                vector_norm = safe_norm(q[1:])[0]
                return wrap_to_pi(2 * jnp.arctan2(vector_norm, q[0]))
         
     | 
| 
      
 244 
     | 
    
         
            +
             
     | 
| 
      
 245 
     | 
    
         
            +
             
     | 
| 
      
 246 
     | 
    
         
            +
            def quat_angle_constantAxisOverTime(qs):
                """Signed rotation angles for a (T, 4) sequence of quaternions.

                The normalised axis of the first quaternion is the reference direction.
                Whenever a later axis is closer to the negated reference than to the
                reference itself, that sample's angle is negated, so that all angles refer
                to the same axis direction.
                """
                assert qs.ndim == 2
                assert qs.shape[-1] == 4

                def _l2norm(x):
                    return jnp.sqrt(jnp.sum(x**2, axis=-1))

                axes = safe_normalize(qs[:, 1:])
                angles = quat_angle(qs)[:, None]
                reference = axes[0]
                flip = (_l2norm(reference - axes) > _l2norm(reference + axes))[..., None]
                return jnp.where(flip, -angles, angles)[:, 0]
         
     | 
| 
      
 257 
     | 
    
         
            +
             
     | 
| 
      
 258 
     | 
    
         
            +
             
     | 
| 
      
 259 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4)->(3),()")
            def quat_to_rot_axis(q):
                """Extract unit-axis and angle from quaternion `q`.

                The angle is `quat_angle(q)` negated (NOTE: CONVENTION), matching the sign
                flip applied in `quat_rot_axis`. The axis is the vector part normalised
                with `safe_normalize`.
                """
                # NOTE: CONVENTION
                angle = -1.0 * quat_angle(q)
                axis = safe_normalize(q[1:])
                return axis, angle
         
     | 
| 
      
 267 
     | 
    
         
            +
             
     | 
| 
      
 268 
     | 
    
         
            +
             
     | 
| 
      
 269 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(3)->(4)")
            def euler_to_quat(angles: jnp.ndarray) -> jnp.ndarray:
                """Converts euler rotations in radians to quaternion.

                Follows the Tait-Bryan intrinsic rotation formalism x-y'-z''. The resulting
                quaternion is conjugated with `quat_inv` before returning
                (NOTE: CONVENTION).
                """
                half = angles / 2
                c1, c2, c3 = jnp.cos(half)
                s1, s2, s3 = jnp.sin(half)
                quat = jnp.array(
                    [
                        c1 * c2 * c3 - s1 * s2 * s3,
                        s1 * c2 * c3 + c1 * s2 * s3,
                        c1 * s2 * c3 - s1 * c2 * s3,
                        c1 * c2 * s3 + s1 * s2 * c3,
                    ]
                )
                # NOTE: CONVENTION
                return quat_inv(quat)
         
     | 
| 
      
 281 
     | 
    
         
            +
             
     | 
| 
      
 282 
     | 
    
         
            +
             
     | 
| 
      
 283 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4)->(3)")
            def quat_to_euler(q: jnp.ndarray) -> jnp.ndarray:
                """Converts quaternions to euler rotations in radians.

                Follows the Tait-Bryan intrinsic rotation formalism x-y'-z''. `q` is first
                conjugated with `quat_inv` (NOTE: CONVENTION). Returns the angles about x,
                y and z in that order.
                """
                # NOTE: CONVENTION
                w, x, y, z = quat_inv(q)

                angle_z = jnp.arctan2(
                    -2 * x * y + 2 * w * z,
                    x * x + w * w - z * z - y * y,
                )
                # TODO: Investigate why quaternions go so big we need to clip.
                angle_y = safe_arcsin(jnp.clip(2 * x * z + 2 * w * y, -1.0, 1.0))
                angle_x = jnp.arctan2(
                    -2 * y * z + 2 * w * x,
                    z * z - y * y - x * x + w * w,
                )

                return jnp.array([angle_x, angle_y, angle_z])
         
     | 
| 
      
 303 
     | 
    
         
            +
             
     | 
| 
      
 304 
     | 
    
         
            +
             
     | 
| 
      
 305 
     | 
    
         
            +
            @partial(jnp.vectorize, signature="(4),(3)->(4),(4)")
            def quat_project(q: jax.Array, k: jax.Array) -> tuple[jax.Array, jax.Array]:
                """Decompose quaternion into a primary rotation around axis `k` such that
                the residual rotation's angle is minimized.

                The residual is defined as `q * q_pri^-1`, so `q == q_res * q_pri` for a
                unit `q_pri`.

                Args:
                    q (jax.Array): Quaternion to decompose.
                    k (jax.Array): Primary axis direction.

                Returns:
                    tuple[jax.Array, jax.Array]: Primary quaternion, residual quaternion
                """
                along_axis = q[1:] @ k
                phi_primary = 2 * jnp.arctan2(along_axis, q[0])
                # NOTE: CONVENTION
                q_primary = quat_rot_axis(k, -phi_primary)
                q_residual = quat_mul(q, quat_inv(q_primary))
                return q_primary, q_residual
         
     | 
| 
      
 322 
     | 
    
         
            +
             
     | 
| 
      
 323 
     | 
    
         
            +
             
     | 
| 
      
 324 
     | 
    
         
            +
            def quat_avg(qs: jax.Array, weights=None):
                """Average of unit quaternions (Tolga Birdal's algorithm).

                Builds M = sum_i w_i * q_i q_i^T and returns the eigenvector of M with the
                largest eigenvalue (Markley et al., "Averaging Quaternions", 2007). Since
                q and -q describe the same rotation, the sign of the result is arbitrary.

                Args:
                    qs: (N, 4) array of quaternions, or a single (4,) quaternion.
                    weights: optional (N,) array of non-negative weights. Defaults to all
                        ones, which reproduces the unweighted average.

                Returns:
                    (4,) unit-norm quaternion.

                Raises:
                    ValueError: if `weights` is given and does not have shape (N,).
                """
                if qs.ndim == 1:
                    qs = qs[None, :]
                assert qs.ndim == 2

                if weights is None:
                    weights = jnp.ones((qs.shape[0],))
                else:
                    weights = jnp.asarray(weights)
                    if weights.shape != (qs.shape[0],):
                        raise ValueError(
                            f"`weights` must have shape ({qs.shape[0]},), got {weights.shape}."
                        )

                # eigh returns eigenvalues in ascending order, so the last column is the
                # eigenvector belonging to the largest eigenvalue.
                return jnp.linalg.eigh(jnp.einsum("ij,ik,i->jk", qs, qs, weights))[1][:, -1]
         
     | 
| 
      
 332 
     | 
    
         
            +
             
     | 
| 
      
 333 
     | 
    
         
            +
             
     | 
| 
      
 334 
     | 
    
         
            +
            # Smoothing factor: alpha = w*Ts / (1 + w*Ts) with w*Ts = 2*pi*cutoff_freq/samp_freq
            # cutoff_freq=20.0; samp_freq=100.0
            # -> alpha = 0.55686
            # cutoff_freq=15.0 (samp_freq=100.0)
            # -> alpha = 0.48519
         
     | 
| 
      
 338 
     | 
    
         
            +
            def quat_lowpassfilter(
         
     | 
| 
      
 339 
     | 
    
         
            +
                qs: jax.Array,
         
     | 
| 
      
 340 
     | 
    
         
            +
                cutoff_freq: float = 20.0,
         
     | 
| 
      
 341 
     | 
    
         
            +
                samp_freq: float = 100.0,
         
     | 
| 
      
 342 
     | 
    
         
            +
                filtfilt: bool = False,
         
     | 
| 
      
 343 
     | 
    
         
            +
            ) -> jax.Array:
         
     | 
| 
      
 344 
     | 
    
         
            +
                assert qs.ndim == 2
         
     | 
| 
      
 345 
     | 
    
         
            +
                assert qs.shape[1] == 4
         
     | 
| 
      
 346 
     | 
    
         
            +
             
     | 
| 
      
 347 
     | 
    
         
            +
                if filtfilt:
         
     | 
| 
      
 348 
     | 
    
         
            +
                    qs = quat_lowpassfilter(qs, cutoff_freq, samp_freq, filtfilt=False)
         
     | 
| 
      
 349 
     | 
    
         
            +
                    qs = quat_lowpassfilter(jnp.flip(qs, 0), cutoff_freq, samp_freq, filtfilt=False)
         
     | 
| 
      
 350 
     | 
    
         
            +
                    return jnp.flip(qs, 0)
         
     | 
| 
      
 351 
     | 
    
         
            +
             
     | 
| 
      
 352 
     | 
    
         
            +
                omega_times_Ts = 2 * jnp.pi * cutoff_freq / samp_freq
         
     | 
| 
      
 353 
     | 
    
         
            +
                alpha = omega_times_Ts / (1 + omega_times_Ts)
         
     | 
| 
      
 354 
     | 
    
         
            +
             
     | 
| 
      
 355 
     | 
    
         
            +
                def f(y, x):
         
     | 
| 
      
 356 
     | 
    
         
            +
                    # error quaternion; current state -> target
         
     | 
| 
      
 357 
     | 
    
         
            +
                    q_err = quat_mul(x, quat_inv(y))
         
     | 
| 
      
 358 
     | 
    
         
            +
                    # scale down error quaternion
         
     | 
| 
      
 359 
     | 
    
         
            +
                    axis, angle = quat_to_rot_axis(q_err)
         
     | 
| 
      
 360 
     | 
    
         
            +
                    # ensure angle >= 0
         
     | 
| 
      
 361 
     | 
    
         
            +
                    axis, angle = jax.lax.cond(
         
     | 
| 
      
 362 
     | 
    
         
            +
                        angle < 0,
         
     | 
| 
      
 363 
     | 
    
         
            +
                        lambda axis, angle: (-axis, -angle),
         
     | 
| 
      
 364 
     | 
    
         
            +
                        lambda axis, angle: (axis, angle),
         
     | 
| 
      
 365 
     | 
    
         
            +
                        axis,
         
     | 
| 
      
 366 
     | 
    
         
            +
                        angle,
         
     | 
| 
      
 367 
     | 
    
         
            +
                    )
         
     | 
| 
      
 368 
     | 
    
         
            +
                    angle_scaled = angle * alpha
         
     | 
| 
      
 369 
     | 
    
         
            +
                    q_err_scaled = quat_rot_axis(axis, angle_scaled)
         
     | 
| 
      
 370 
     | 
    
         
            +
                    # move small step toward error quaternion
         
     | 
| 
      
 371 
     | 
    
         
            +
                    y = quat_mul(q_err_scaled, y)
         
     | 
| 
      
 372 
     | 
    
         
            +
                    return y, y
         
     | 
| 
      
 373 
     | 
    
         
            +
             
     | 
| 
      
 374 
     | 
    
         
            +
                qs_filtered = jax.lax.scan(f, qs[0], qs[1:])[1]
         
     | 
| 
      
 375 
     | 
    
         
            +
             
     | 
| 
      
 376 
     | 
    
         
            +
                # padd with first value, such that length remains equal
         
     | 
| 
      
 377 
     | 
    
         
            +
                qs_filtered = jnp.vstack((qs[0:1], qs_filtered))
         
     | 
| 
      
 378 
     | 
    
         
            +
             
     | 
| 
      
 379 
     | 
    
         
            +
                # renormalize due to float32 numerical errors accumulating
         
     | 
| 
      
 380 
     | 
    
         
            +
                return qs_filtered / jnp.linalg.norm(qs_filtered, axis=-1, keepdims=True)
         
     | 
| 
      
 381 
     | 
    
         
            +
             
     | 
| 
      
 382 
     | 
    
         
            +
             
     | 
| 
      
 383 
     | 
    
         
            +
            def quat_inclinationAngle(q: jax.Array):
                """Return the angle of the inclination component of ``q``.

                ``q`` is split by ``quat_project`` about the z-axis ``[0, 0, 1]`` into
                a (heading, inclination) pair. The heading part is discarded, and
                ``quat_angle`` of the inclination part is returned.
                """
                _, inclination = quat_project(q, jnp.array([0.0, 0, 1]))
                return quat_angle(inclination)
         
     | 
| 
      
 386 
     | 
    
         
            +
             
     | 
| 
      
 387 
     | 
    
         
            +
             
     | 
| 
      
 388 
     | 
    
         
            +
            def quat_headingAngle(q: jax.Array):
                """Return the angle of the heading component of ``q``.

                ``q`` is split by ``quat_project`` about the z-axis ``[0, 0, 1]`` into
                a (heading, inclination) pair. The inclination part is discarded, and
                ``quat_angle`` of the heading part is returned.
                """
                heading, _ = quat_project(q, jnp.array([0.0, 0, 1]))
                return quat_angle(heading)
         
     | 
| 
      
 391 
     | 
    
         
            +
             
     | 
| 
      
 392 
     | 
    
         
            +
             
     | 
| 
      
 393 
     | 
    
         
            +
            def quat_transfer_heading(q_from: jax.Array, q_to: jax.Array):
                """Return ``q_to`` with its heading replaced by the heading of ``q_from``.

                Both inputs are decomposed by ``quat_project`` about the z-axis. The
                heading component of ``q_from`` (element ``[0]`` of the projection) is
                kept. For ``q_to``, only the heading-free component (element ``[1]``)
                is kept. The result is ``quat_mul(q_to_without_heading, heading_of_q_from)``.
                """
                z_axis = jnp.array([0.0, 0, 1])
                heading_of_source = quat_project(q_from, z_axis)[0]
                # strip the heading from `q_to` by keeping only the second projection part
                target_without_heading = quat_project(q_to, z_axis)[1]
                return quat_mul(target_without_heading, heading_of_source)
         
     | 
    
        ring/ml/__init__.py
    ADDED
    
    | 
         @@ -0,0 +1,33 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from . import base
         
     | 
| 
      
 2 
     | 
    
         
            +
            from . import callbacks
         
     | 
| 
      
 3 
     | 
    
         
            +
            from . import ml_utils
         
     | 
| 
      
 4 
     | 
    
         
            +
            from . import optimizer
         
     | 
| 
      
 5 
     | 
    
         
            +
            from . import ringnet
         
     | 
| 
      
 6 
     | 
    
         
            +
            from . import train
         
     | 
| 
      
 7 
     | 
    
         
            +
            from . import training_loop
         
     | 
| 
      
 8 
     | 
    
         
            +
            from .base import AbstractFilter
         
     | 
| 
      
 9 
     | 
    
         
            +
            from .ml_utils import on_cluster
         
     | 
| 
      
 10 
     | 
    
         
            +
            from .ml_utils import unique_id
         
     | 
| 
      
 11 
     | 
    
         
            +
            from .optimizer import make_optimizer
         
     | 
| 
      
 12 
     | 
    
         
            +
            from .ringnet import RING
         
     | 
| 
      
 13 
     | 
    
         
            +
            from .train import train_fn
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            def RING_ICML24(params=None, **kwargs):
                """Create the RING network used in the icml24 paper.

                Expected layout of the input ``X`` along its last axis:
                    X[..., :3]  = acc
                    X[..., 3:6] = gyr
                    X[..., 6:9] = jointaxis
                    X[..., 9:]  = dt

                Args:
                    params: Forwarded to ``RING(params=...)``. If ``None``, the file
                        ``params/0x13e3518065c21cd8.pickle`` in this module's directory
                        is used.
                    **kwargs: Extra keyword arguments forwarded unchanged to ``RING``.

                Returns:
                    The ``RING`` network wrapped (innermost first) by
                    ``base.ScaleX_FilterWrapper``, ``base.LPF_FilterWrapper`` (cutoff
                    argument ``10.0``, ``samp_freq=None``) and
                    ``base.GroundTruthHeading_FilterWrapper``.
                """
                from pathlib import Path

                if params is None:
                    params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")

                # distinct stage names avoid shadowing the imported `ringnet` module
                core = RING(params=params, **kwargs)
                scaled = base.ScaleX_FilterWrapper(core)
                # NOTE(review): samp_freq=None presumably defers the sampling rate to
                # call time (e.g. derived from dt) — confirm in base.LPF_FilterWrapper
                lowpassed = base.LPF_FilterWrapper(scaled, 10.0, samp_freq=None)
                return base.GroundTruthHeading_FilterWrapper(lowpassed)
         
     |