imt-ring 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imt_ring-1.2.1.dist-info/METADATA +91 -0
 - imt_ring-1.2.1.dist-info/RECORD +83 -0
 - imt_ring-1.2.1.dist-info/WHEEL +5 -0
 - imt_ring-1.2.1.dist-info/top_level.txt +1 -0
 - ring/__init__.py +63 -0
 - ring/algebra.py +100 -0
 - ring/algorithms/__init__.py +45 -0
 - ring/algorithms/_random.py +403 -0
 - ring/algorithms/custom_joints/__init__.py +6 -0
 - ring/algorithms/custom_joints/rr_imp_joint.py +69 -0
 - ring/algorithms/custom_joints/rr_joint.py +33 -0
 - ring/algorithms/custom_joints/suntay.py +424 -0
 - ring/algorithms/dynamics.py +345 -0
 - ring/algorithms/generator/__init__.py +25 -0
 - ring/algorithms/generator/base.py +414 -0
 - ring/algorithms/generator/batch.py +282 -0
 - ring/algorithms/generator/motion_artifacts.py +222 -0
 - ring/algorithms/generator/pd_control.py +182 -0
 - ring/algorithms/generator/randomize.py +119 -0
 - ring/algorithms/generator/transforms.py +410 -0
 - ring/algorithms/generator/types.py +36 -0
 - ring/algorithms/jcalc.py +840 -0
 - ring/algorithms/kinematics.py +202 -0
 - ring/algorithms/sensors.py +582 -0
 - ring/base.py +1046 -0
 - ring/io/__init__.py +9 -0
 - ring/io/examples/branched.xml +24 -0
 - ring/io/examples/exclude/knee_trans_dof.xml +26 -0
 - ring/io/examples/exclude/standard_sys.xml +106 -0
 - ring/io/examples/exclude/standard_sys_rr_imp.xml +106 -0
 - ring/io/examples/inv_pendulum.xml +14 -0
 - ring/io/examples/knee_flexible_imus.xml +22 -0
 - ring/io/examples/spherical_stiff.xml +11 -0
 - ring/io/examples/symmetric.xml +12 -0
 - ring/io/examples/test_all_1.xml +39 -0
 - ring/io/examples/test_all_2.xml +39 -0
 - ring/io/examples/test_ang0_pos0.xml +9 -0
 - ring/io/examples/test_control.xml +16 -0
 - ring/io/examples/test_double_pendulum.xml +14 -0
 - ring/io/examples/test_free.xml +11 -0
 - ring/io/examples/test_kinematics.xml +23 -0
 - ring/io/examples/test_morph_system/four_seg_seg1.xml +26 -0
 - ring/io/examples/test_morph_system/four_seg_seg3.xml +26 -0
 - ring/io/examples/test_randomize_position.xml +26 -0
 - ring/io/examples/test_sensors.xml +13 -0
 - ring/io/examples/test_three_seg_seg2.xml +23 -0
 - ring/io/examples.py +42 -0
 - ring/io/test_examples.py +6 -0
 - ring/io/xml/__init__.py +6 -0
 - ring/io/xml/abstract.py +300 -0
 - ring/io/xml/from_xml.py +299 -0
 - ring/io/xml/test_from_xml.py +56 -0
 - ring/io/xml/test_to_xml.py +31 -0
 - ring/io/xml/to_xml.py +94 -0
 - ring/maths.py +397 -0
 - ring/ml/__init__.py +33 -0
 - ring/ml/base.py +292 -0
 - ring/ml/callbacks.py +434 -0
 - ring/ml/ml_utils.py +272 -0
 - ring/ml/optimizer.py +149 -0
 - ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
 - ring/ml/ringnet.py +279 -0
 - ring/ml/train.py +318 -0
 - ring/ml/training_loop.py +131 -0
 - ring/rendering/__init__.py +2 -0
 - ring/rendering/base_render.py +271 -0
 - ring/rendering/mujoco_render.py +222 -0
 - ring/rendering/vispy_render.py +340 -0
 - ring/rendering/vispy_visuals.py +290 -0
 - ring/sim2real/__init__.py +7 -0
 - ring/sim2real/sim2real.py +288 -0
 - ring/spatial.py +126 -0
 - ring/sys_composer/__init__.py +5 -0
 - ring/sys_composer/delete_sys.py +114 -0
 - ring/sys_composer/inject_sys.py +110 -0
 - ring/sys_composer/morph_sys.py +361 -0
 - ring/utils/__init__.py +21 -0
 - ring/utils/batchsize.py +51 -0
 - ring/utils/colab.py +48 -0
 - ring/utils/hdf5.py +198 -0
 - ring/utils/normalizer.py +56 -0
 - ring/utils/path.py +44 -0
 - ring/utils/utils.py +161 -0
 
    
        ring/ml/base.py
    ADDED
    
    | 
         @@ -0,0 +1,292 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from abc import ABC
         
     | 
| 
      
 2 
     | 
    
         
            +
            from abc import abstractmethod
         
     | 
| 
      
 3 
     | 
    
         
            +
             
     | 
| 
      
 4 
     | 
    
         
            +
            import jax
         
     | 
| 
      
 5 
     | 
    
         
            +
            import jax.numpy as jnp
         
     | 
| 
      
 6 
     | 
    
         
            +
            import tree_utils
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            import ring
         
     | 
| 
      
 9 
     | 
    
         
            +
            from ring.utils import pickle_load
         
     | 
| 
      
 10 
     | 
    
         
            +
            from ring.utils import pickle_save
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            def _to_3d(tree):
                """Add a leading axis of size one to every leaf of ``tree``.

                Used to turn a single (unbatched) sequence into a batch of one.

                Args:
                    tree: A pytree of arrays, or ``None``.

                Returns:
                    A pytree with the same structure where each leaf is ``arr[None]``,
                    or ``None`` when ``tree`` is ``None``.
                """
                if tree is None:
                    return None
                # `jax.tree_map` is deprecated (and removed in recent JAX releases);
                # `jax.tree_util.tree_map` is the stable equivalent.
                return jax.tree_util.tree_map(lambda arr: arr[None], tree)
         
     | 
| 
      
 17 
     | 
    
         
            +
             
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            def _to_2d(tree, i: int = 0):
                """Select entry ``i`` along the leading axis of every leaf of ``tree``.

                Inverse of adding a batch axis: picks one batch element out of a
                batched pytree.

                Args:
                    tree: A pytree of arrays, or ``None``.
                    i: Index along the leading axis to select (default ``0``).

                Returns:
                    A pytree with the same structure where each leaf is ``arr[i]``,
                    or ``None`` when ``tree`` is ``None``.
                """
                if tree is None:
                    return None
                # `jax.tree_map` is deprecated (and removed in recent JAX releases);
                # `jax.tree_util.tree_map` is the stable equivalent.
                return jax.tree_util.tree_map(lambda arr: arr[i], tree)
         
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
            class AbstractFilter(ABC):
                """Base interface for filters applied to input sequences ``X``.

                Subclasses implement :meth:`_apply_batched` and :meth:`init`.
                :meth:`apply` dispatches on the rank of ``X``: a rank-4 input
                ``(B, T, N, F)`` goes to :meth:`_apply_batched`, a rank-3 input
                ``(T, N, F)`` goes to :meth:`_apply_unbatched`.
                """

                def _apply_unbatched(self, X, params, state, y, lam):
                    # Default unbatched path: add a batch axis of size one to `X`, `state`
                    # and `y`, run the batched implementation, then strip that axis again
                    # from every leaf of the result. `params` and `lam` are passed through.
                    return _to_2d(
                        self._apply_batched(
                            X=_to_3d(X), params=params, state=_to_3d(state), y=_to_3d(y), lam=lam
                        )
                    )

                @abstractmethod
                def _apply_batched(self, X, params, state, y, lam):
                    """Apply the filter to a batch of sequences (``X.ndim == 4``)."""
                    pass

                @abstractmethod
                def init(self, bs, X, lam, seed: int):
                    """Initialise the filter; the return value is defined by subclasses."""
                    pass

                def apply(self, X, params=None, state=None, y=None, lam=None):
                    """Run the filter on ``X``.

                    ``X.shape = (B, T, N, F)`` (batched) or ``(T, N, F)`` (unbatched).
                    The result is whatever the implementation returns; the wrappers in
                    this module unpack it as ``(yhat, state)``.
                    """
                    assert X.ndim in [3, 4]
                    if X.ndim == 4:
                        return self._apply_batched(X, params, state, y, lam)
                    else:
                        return self._apply_unbatched(X, params, state, y, lam)

                @property
                def name(self) -> str:
                    """Display name stored in ``self._name``.

                    Raises:
                        NotImplementedError: if the instance has no ``_name`` attribute.
                        RuntimeError: if ``_name`` is ``None``.
                    """
                    if not hasattr(self, "_name"):
                        raise NotImplementedError

                    if self._name is None:
                        raise RuntimeError("No `name` was given.")
                    return self._name

                def nojit(self) -> "AbstractFilter":
                    """Return the object that :meth:`save` pickles; ``self`` by default."""
                    return self

                def _pre_save(self, *args, **kwargs) -> None:
                    """Hook run by :meth:`save` before pickling; no-op by default."""
                    pass

                def save(self, path: str, *args, **kwargs):
                    """Run :meth:`_pre_save`, then pickle ``self.nojit()`` to ``path``.

                    An existing file at ``path`` is overwritten.
                    """
                    self._pre_save(*args, **kwargs)
                    pickle_save(self.nojit(), path, overwrite=True)

                @staticmethod
                def _post_load(filter: "AbstractFilter", *args, **kwargs) -> "AbstractFilter":
                    """Hook run on a freshly loaded filter; must return the filter.

                    :meth:`load` returns whatever this hook returns, so the default
                    must hand back ``filter`` (a bare ``pass`` would yield ``None``).
                    """
                    return filter

                @classmethod
                def load(cls, path: str, *args, **kwargs) -> "AbstractFilter":
                    """Load a filter saved with :meth:`save` and run :meth:`_post_load`."""
                    # NOTE(review): `pickle_load` presumably unpickles the file, which can
                    # execute arbitrary code — only load files from trusted sources.
                    filter = pickle_load(path)
                    return cls._post_load(filter, *args, **kwargs)

                def search_attr(self, attr: str):
                    """Return ``getattr(self, attr)``; wrappers also search wrapped filters."""
                    return getattr(self, attr)
         
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
             
     | 
| 
      
 81 
     | 
    
         
            +
            class AbstractFilterUnbatched(AbstractFilter):
                """Filter implemented for one sequence at a time.

                Subclasses implement :meth:`_apply_unbatched`. The batched path runs it
                on each element of the leading axis in turn and batches the results.
                """

                @abstractmethod
                def _apply_unbatched(self, X, params, state, y, lam):
                    """Apply the filter to a single sequence (``X.ndim == 3``)."""

                def _apply_batched(self, X, params, state, y, lam):
                    # Python-level loop over the leading (batch) axis. `params` and `lam`
                    # are shared by all elements; `X`, `state` and `y` are sliced per element.
                    per_element = [
                        self._apply_unbatched(
                            _to_2d(X, b), params, _to_2d(state, b), _to_2d(y, b), lam
                        )
                        for b in range(X.shape[0])
                    ]
                    # NOTE(review): `tree_batch` presumably stacks the list of pytrees
                    # along a new leading axis — confirm in `tree_utils`.
                    return tree_utils.tree_batch(per_element)
         
     | 
| 
      
 96 
     | 
    
         
            +
             
     | 
| 
      
 97 
     | 
    
         
            +
             
     | 
| 
      
 98 
     | 
    
         
            +
            class AbstractFilterWrapper(AbstractFilter):
                """Filter that delegates to a wrapped filter.

                Every call is forwarded to ``self.unwrapped`` by default; subclasses
                override :meth:`apply` to pre-/post-process inputs or outputs.

                Args:
                    filter: The filter being wrapped (may itself be a wrapper).
                    name: Display name of this wrapper layer.
                """

                def __init__(self, filter: AbstractFilter, name=None) -> None:
                    self._filter = filter
                    self._name = name

                def _apply_batched(self, X, params, state, y, lam):
                    # Not used by this class's `apply`, which delegates to the wrapped
                    # filter; defined to satisfy the abstract interface.
                    raise NotImplementedError

                @property
                def unwrapped(self) -> AbstractFilter:
                    """The directly wrapped filter (one level down)."""
                    return self._filter

                def apply(self, X, params=None, state=None, y=None, lam=None):
                    """Forward to the wrapped filter's ``apply``."""
                    return self.unwrapped.apply(X=X, params=params, state=state, y=y, lam=lam)

                def init(self, bs=None, X=None, lam=None, seed: int = 1):
                    """Forward to the wrapped filter's ``init``."""
                    return self.unwrapped.init(bs=bs, X=X, lam=lam, seed=seed)

                def nojit(self) -> "AbstractFilterWrapper":
                    """Replace the wrapped filter by its ``nojit()`` version; returns ``self``."""
                    self._filter = self.unwrapped.nojit()
                    return self

                def search_attr(self, attr: str):
                    """Look up ``attr`` on this wrapper, else recursively on the wrapped filter."""
                    if hasattr(self, attr):
                        return super().search_attr(attr)
                    return self.unwrapped.search_attr(attr)

                def _pre_save(self, *args, **kwargs):
                    """Forward the pre-save hook to the wrapped filter."""
                    self.unwrapped._pre_save(*args, **kwargs)

                @staticmethod
                def _post_load(
                    wrapper: "AbstractFilterWrapper", *args, **kwargs
                ) -> "AbstractFilterWrapper":
                    """Run the wrapped filter's post-load hook and return ``wrapper``."""
                    inner = wrapper._filter
                    loaded = inner._post_load(inner, *args, **kwargs)
                    # A hook that returns `None` (e.g. one that only mutates in place)
                    # must not replace the inner filter with `None`.
                    if loaded is not None:
                        wrapper._filter = loaded
                    return wrapper

                @property
                def name(self):
                    """Names of the wrapped chain, innermost first, joined by ``" ->\\n"``."""
                    return self.unwrapped.name + " ->\n" + super().name
         
     | 
| 
      
 138 
     | 
    
         
            +
             
     | 
| 
      
 139 
     | 
    
         
            +
             
     | 
| 
      
 140 
     | 
    
         
            +
            class LPF_FilterWrapper(AbstractFilterWrapper):
                """Low-pass filter the quaternion output ``yhat`` of the wrapped filter.

                Args:
                    filter: Filter whose ``apply`` output is smoothed.
                    cutoff_freq: Forwarded to ``ring.maths.quat_lowpassfilter``.
                    samp_freq: Sampling frequency of ``X``. If ``None`` it is inferred as
                        ``1 / dt`` from the last feature of ``X`` (which must then have
                        exactly 10 features).
                    filtfilt: Forwarded to ``ring.maths.quat_lowpassfilter``.
                    name: Display name of this wrapper.
                """

                def __init__(
                    self,
                    filter: AbstractFilter,
                    cutoff_freq: float,
                    samp_freq: float | None,
                    filtfilt: bool = True,
                    name="LPF_FilterWrapper",
                ) -> None:
                    super().__init__(filter, name)
                    self.samp_freq = samp_freq
                    self._kwargs = {"cutoff_freq": cutoff_freq, "filtfilt": filtfilt}

                def apply(self, X, params=None, state=None, y=None, lam=None):
                    batched = X.ndim == 4
                    fixed_rate = self.samp_freq

                    if fixed_rate is None:
                        # Infer the rate from the `dt` channel (last feature) at the first
                        # timestep of the first node; one value per sequence when batched.
                        assert X.shape[-1] == 10
                        dt = X[:, 0, 0, -1] if batched else X[0, 0, -1]
                        samp_freq = 1 / dt
                        print(f"Detected the following sampling rates from `X`: {samp_freq}")
                    elif batched:
                        # One (identical) rate per batch element, so the outer vmap below
                        # can map over it alongside `yhat`.
                        samp_freq = jnp.repeat(jnp.array(fixed_rate), X.shape[0])
                    else:
                        samp_freq = jnp.array(fixed_rate)

                    yhat, state = super().apply(X, params, state, y, lam)

                    # Filter along axis 1 of a single sequence (the node axis if `yhat`
                    # is laid out like `X`), all slices sharing one sampling rate.
                    smooth = jax.vmap(
                        lambda q, fs: ring.maths.quat_lowpassfilter(
                            q, samp_freq=fs, **self._kwargs
                        ),
                        in_axes=(1, None),
                        out_axes=1,
                    )
                    if yhat.ndim == 4:
                        # Additionally map over the leading batch axis of `yhat` and rates.
                        smooth = jax.vmap(smooth)
                    return smooth(yhat, samp_freq), state
         
     | 
| 
      
 193 
     | 
    
         
            +
             
     | 
| 
      
 194 
     | 
    
         
            +
             
     | 
| 
      
 195 
     | 
    
         
            +
            class GroundTruthHeading_FilterWrapper(AbstractFilterWrapper):
                """Overwrite the heading of selected node estimates with the ground truth.

                After the wrapped filter runs, each node ``i`` with ``lam[i] == -1`` gets
                the heading of ``y[..., i, :]`` transferred onto ``yhat[..., i, :]`` via
                ``ring.maths.quat_transfer_heading``. Without ``y`` the output of the
                wrapped filter is returned unchanged.
                """

                def __init__(
                    self, filter: AbstractFilter, name="GroundTruthHeading_FilterWrapper"
                ) -> None:
                    super().__init__(filter, name)

                def apply(self, X, params=None, state=None, y=None, lam=None):
                    yhat, state = super().apply(X, params, state, y, lam)
                    # `lam` is only needed when there is ground truth to copy from; don't
                    # require a `lam` attribute in the wrapper chain when `y` is None.
                    if y is not None and lam is None:
                        lam = self.search_attr("lam")
                    yhat = self.transfer_ground_truth_heading(lam, y, yhat)
                    return yhat, state

                @staticmethod
                def transfer_ground_truth_heading(lam, y, yhat):
                    """Copy the ground-truth heading onto the nodes with ``lam[i] == -1``.

                    Args:
                        lam: One entry per node; entries equal to ``-1`` are updated.
                            (Presumably a parent array where ``-1`` marks a root — confirm.)
                        y: Ground-truth quaternions indexed as ``y[..., node, :]``, or
                            ``None``.
                        yhat: Estimated quaternions with the same layout as ``y``.

                    Returns:
                        ``yhat`` itself if ``y`` is ``None``; otherwise a JAX array copy of
                        ``yhat`` with the selected nodes' heading replaced.
                    """
                    if y is None:
                        return yhat

                    assert lam is not None
                    yhat = jnp.array(yhat)
                    for i, p in enumerate(lam):
                        if p == -1:
                            yhat = yhat.at[..., i, :].set(
                                ring.maths.quat_transfer_heading(y[..., i, :], yhat[..., i, :])
                            )
                    return yhat
         
     | 
| 
      
 222 
     | 
    
         
            +
             
     | 
| 
      
 223 
     | 
    
         
            +
             
     | 
| 
      
 224 
     | 
    
         
            +
            # Default per-channel multipliers applied by `ScaleX_FilterWrapper` before the
            # wrapped filter sees `X`. NOTE(review): presumably chosen to bring each channel
            # to roughly unit scale (acc divided by g = 9.81) — confirm against the
            # training pipeline before changing.
            _default_factors = {
                "gyr": 1 / 2.2,
                "acc": 1 / 9.81,
                "joint_axes": 1 / 0.57,
                "dt": 10.0,
            }
         
     | 
| 
      
 225 
     | 
    
         
            +
             
     | 
| 
      
 226 
     | 
    
         
            +
             
     | 
| 
      
 227 
     | 
    
         
            +
            class ScaleX_FilterWrapper(AbstractFilterWrapper):
                """Rescale the feature channels of ``X`` before applying the wrapped filter.

                The last axis of ``X`` is split into named groups depending on its size:

                - ``F == 6``: ``acc`` (0:3), ``gyr`` (3:6)
                - ``F == 9``: additionally ``joint_axes`` (6:9)
                - ``F == 10``: additionally ``dt`` (9:10)

                Each group is multiplied by ``factors[group]`` and the groups are
                concatenated again with ``tree_utils.batch_concat_acme``.

                Args:
                    filter: The wrapped filter.
                    factors: Multiplier per group name; copied on construction so the
                        shared module-level default is never mutated through an instance.
                    name: Display name of this wrapper.
                """

                def __init__(
                    self,
                    filter: AbstractFilter,
                    factors: dict[str, float] = _default_factors,
                    name="ScaleX_FilterWrapper",
                ) -> None:
                    super().__init__(filter, name)
                    # Copy: the default is a module-level dict shared by all instances.
                    self._factors = dict(factors)

                def apply(self, X, params=None, state=None, y=None, lam=None):
                    """Scale ``X`` channel-group-wise, then call the wrapped filter.

                    Raises:
                        ValueError: if the last axis of ``X`` is not 6, 9 or 10.
                        KeyError: if a present group has no entry in ``factors``.
                    """
                    F = X.shape[-1]
                    num_batch_dims = X.ndim - 1

                    if F == 6:
                        groups = dict(acc=X[..., :3], gyr=X[..., 3:])
                    elif F == 9:
                        groups = dict(acc=X[..., :3], gyr=X[..., 3:6], joint_axes=X[..., 6:])
                    elif F == 10:
                        groups = dict(
                            acc=X[..., :3], gyr=X[..., 3:6], joint_axes=X[..., 6:9], dt=X[..., 9:10]
                        )
                    else:
                        # ValueError subclasses Exception, so existing handlers still catch it.
                        raise ValueError(f"X.shape={X.shape}")
                    scaled = {key: val * self._factors[key] for key, val in groups.items()}
                    # NOTE(review): if `batch_concat_acme` flattens the dict with JAX's pytree
                    # utilities, keys are visited in sorted order, so the output channel order
                    # may differ from the input order — confirm before relying on it.
                    X = tree_utils.batch_concat_acme(scaled, num_batch_dims=num_batch_dims)
                    return super().apply(X, params, state, y, lam)
         
     | 
| 
      
 255 
     | 
    
         
            +
             
     | 
| 
      
 256 
     | 
    
         
            +
             
     | 
| 
      
 257 
     | 
    
         
            +
            class NoGraph_FilterWrapper(AbstractFilterWrapper):
                """Run a filter on all nodes jointly by merging them into a single node.

                An input ``X`` of shape ``(T, N, F)`` or ``(B, T, N, F)`` is reshaped to
                ``(T, 1, N * F)`` / ``(B, T, 1, N * F)``. The ``N`` per-node feature vectors
                thereby become one node. The wrapped filter is called with ``lam=(-1,)``
                instead of the caller's ``lam``. In ``apply``, the filter's output is
                reshaped back to ``(..., N, -1)``.
                """

                def __init__(
                    self, filter: AbstractFilter, quat_normalize: bool = False, name=None
                ) -> None:
                    """Wrap ``filter``.

                    Args:
                        filter: The filter to wrap.
                        quat_normalize: If True, ``apply`` passes the per-node output through
                            ``ring.maths.safe_normalize``. The output's last axis must then
                            have size 4.
                        name: Forwarded to ``AbstractFilterWrapper.__init__``.
                    """
                    super().__init__(filter, name)
                    self._quat_normalize = quat_normalize

                @staticmethod
                def _merge_nodes(X):
                    """Reshape ``(..., N, F)`` to ``(..., 1, N * F)``; return it, the leading shape and ``N``.

                    Raises:
                        ValueError: if ``X`` is None or is not of rank 3 or 4.
                    """
                    if X is None:
                        raise ValueError("NoGraph_FilterWrapper requires `X` to infer shapes")
                    if X.ndim not in (3, 4):
                        raise ValueError(
                            f"Expected X of shape (T, N, F) or (B, T, N, F), got X.shape={X.shape}"
                        )
                    *lead, N, F = X.shape
                    return X.reshape((*lead, 1, N * F)), tuple(lead), N

                def init(self, bs=None, X=None, lam=None, seed: int = 1):
                    """Initialise the wrapped filter on the node-merged version of ``X``.

                    ``lam`` is ignored. After merging there is exactly one node, so ``(-1,)``
                    is passed to the wrapped filter instead (presumably a single node with no
                    parent).
                    """
                    X_merged, _, _ = self._merge_nodes(X)
                    return super().init(bs, X_merged, (-1,), seed)

                def apply(self, X: jax.Array, params=None, state=None, y=None, lam=None):
                    """Apply the wrapped filter to the node-merged ``X``, then split the output per node.

                    ``lam`` is ignored and ``(-1,)`` is passed instead, matching ``init``.

                    Returns:
                        ``(yhat, state)``, where ``yhat`` has shape ``(T, N, -1)`` or
                        ``(B, T, N, -1)``, following the leading axes of ``X``.
                    """
                    X_merged, lead, N = self._merge_nodes(X)
                    yhat, state = super().apply(X_merged, params, state, y, (-1,))
                    # Undo the node merge: split the single node's output back into N rows.
                    yhat = yhat.reshape((*lead, N, -1))

                    if self._quat_normalize:
                        assert yhat.shape[-1] == 4, (
                            "quat_normalize requires 4 output features per node, "
                            f"got yhat.shape={yhat.shape}"
                        )
                        yhat = ring.maths.safe_normalize(yhat)

                    return yhat, state
         
     |