imt-ring 1.3.13__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.13.dist-info → imt_ring-1.4.0.dist-info}/METADATA +1 -1
- {imt_ring-1.3.13.dist-info → imt_ring-1.4.0.dist-info}/RECORD +7 -7
- ring/__init__.py +68 -1
- ring/ml/__init__.py +2 -23
- ring/ml/base.py +21 -0
- {imt_ring-1.3.13.dist-info → imt_ring-1.4.0.dist-info}/WHEEL +0 -0
- {imt_ring-1.3.13.dist-info → imt_ring-1.4.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
ring/__init__.py,sha256
|
1
|
+
ring/__init__.py,sha256=-70A1E5LQQyGdOA_u9PlYZb1d5Fz0yXnvRoZOrbxq6o,3781
|
2
2
|
ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
|
3
3
|
ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
|
4
4
|
ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
|
@@ -50,8 +50,8 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
|
|
50
50
|
ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
|
51
51
|
ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
|
52
52
|
ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
|
53
|
-
ring/ml/__init__.py,sha256=
|
54
|
-
ring/ml/base.py,sha256=
|
53
|
+
ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
|
54
|
+
ring/ml/base.py,sha256=16zdF72XLljeFRWaEr1O-9M5Sw2ppk5yvAoAPxA5EJU,9693
|
55
55
|
ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
|
56
56
|
ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
|
57
57
|
ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
|
@@ -79,7 +79,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
|
|
79
79
|
ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
|
80
80
|
ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
|
81
81
|
ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
|
82
|
-
imt_ring-1.
|
83
|
-
imt_ring-1.
|
84
|
-
imt_ring-1.
|
85
|
-
imt_ring-1.
|
82
|
+
imt_ring-1.4.0.dist-info/METADATA,sha256=4o8Vsz-U7Ekp_4ipyhcCy1vlhCgVZwEZYbwLsfmzy34,3104
|
83
|
+
imt_ring-1.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
84
|
+
imt_ring-1.4.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
|
85
|
+
imt_ring-1.4.0.dist-info/RECORD,,
|
ring/__init__.py
CHANGED
@@ -18,7 +18,74 @@ from .algorithms import step
|
|
18
18
|
from .base import State
|
19
19
|
from .base import System
|
20
20
|
from .base import Transform
|
21
|
-
|
21
|
+
|
22
|
+
|
23
|
+
def RING(lam: list[int], Ts: float | None):
    """Creates the RING network.

    Params:
        lam: parent array
        Ts : sampling interval of IMU data; time delta in seconds.
             May be `None`: the sampling-rate range check is then skipped,
             the low-pass filter wrapper is built with `samp_freq=None`, and
             `None` is forwarded to `ml.base.AddTs_FilterWrapper`.
             NOTE(review): with `Ts=None` the input presumably has to carry
             the time delta itself as a 10th feature -- confirm against
             `ml.base.AddTs_FilterWrapper`.

    Returns:
        The RING network wrapped (innermost to outermost) in
        `ScaleX_FilterWrapper`, `LPF_FilterWrapper`,
        `GroundTruthHeading_FilterWrapper` and `AddTs_FilterWrapper`.

    Warns:
        UserWarning: if `Ts` is given and lies outside of [1/200 s, 1/40 s].

    Usage:
        >>> import ring
        >>> import jax
        >>> import numpy as np
        >>>
        >>> T  : int       = 30        # sequence length     [s]
        >>> Ts : float     = 0.01      # sampling interval   [s]
        >>> B  : int       = 1         # batch size
        >>> lam: list[int] = [0, 1, 2] # parent array
        >>> N  : int       = len(lam)  # number of bodies
        >>> T_i: int       = int(T/Ts) # number of timesteps
        >>>
        >>> X = np.zeros((B, T_i, N, 9))
        >>> # where X is structured as follows:
        >>> # X[..., :3]  = acc
        >>> # X[..., 3:6] = gyr
        >>> # X[..., 6:9] = jointaxis
        >>>
        >>> # let's assume we have an IMU on each outer segment of the
        >>> # three-segment kinematic chain
        >>> X[:, :, 0, :3]  = acc_segment1
        >>> X[:, :, 2, :3]  = acc_segment3
        >>> X[:, :, 0, 3:6] = gyr_segment1
        >>> X[:, :, 2, 3:6] = gyr_segment3
        >>>
        >>> ringnet = ring.RING(lam, Ts)
        >>>
        >>> yhat, _ = ringnet.apply(X)
        >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
        >>>
        >>> # use `jax.jit` to compile the forward pass
        >>> jit_apply = jax.jit(ringnet.apply)
        >>> yhat, _ = jit_apply(X)
        >>>
        >>> # manually pass in and out the hidden state like so
        >>> initial_state = None
        >>> yhat, state = ringnet.apply(X, state=initial_state)
        >>> # state: final hidden state, shape = (B, N, 2*H)

    """
    from pathlib import Path
    import warnings

    # `Ts` is allowed to be `None` (see signature and the `samp_freq`
    # handling below), so only range-check a concrete sampling interval;
    # comparing `None` with a float would raise a `TypeError`.
    if Ts is not None and (Ts > (1 / 40) or Ts < (1 / 200)):
        warnings.warn(
            "RING was only trained on sampling rates between 40 to 200 Hz "
            f"but found {1 / Ts}Hz",
            # attribute the warning to the caller of `RING`, not to this line
            stacklevel=2,
        )

    # pretrained parameters that ship inside the package
    params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")

    # build the bare network, then stack the filter wrappers around it
    ringnet = ml.RING(params=params, lam=tuple(lam), jit=False)
    ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
    ringnet = ml.base.LPF_FilterWrapper(
        ringnet, ml._LPF_CUTOFF_FREQ, samp_freq=None if Ts is None else 1 / Ts
    )
    ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
    ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
    return ringnet
|
88
|
+
|
22
89
|
|
23
90
|
_TRAIN_TIMING_START = None
|
24
91
|
_UNIQUE_ID = None
|
ring/ml/__init__.py
CHANGED
@@ -13,28 +13,7 @@ from .optimizer import make_optimizer
|
|
13
13
|
from .ringnet import RING
|
14
14
|
from .train import train_fn
|
15
15
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
def RING_ICML24(params=None, eval: bool = True, **kwargs):
|
20
|
-
"""Create the RING network used in the icml24 paper.
|
21
|
-
|
22
|
-
X[..., :3] = acc
|
23
|
-
X[..., 3:6] = gyr
|
24
|
-
X[..., 6:9] = jointaxis
|
25
|
-
X[..., 9:] = dt
|
26
|
-
"""
|
27
|
-
from pathlib import Path
|
28
|
-
|
29
|
-
if params is None:
|
30
|
-
params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")
|
31
|
-
|
32
|
-
ringnet = RING(params=params, **kwargs) # noqa: F811
|
33
|
-
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
34
|
-
if eval:
|
35
|
-
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
|
36
|
-
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
37
|
-
return ringnet
|
16
|
+
_LPF_CUTOFF_FREQ = 10.0
|
38
17
|
|
39
18
|
|
40
19
|
def RNNO(
|
@@ -70,7 +49,7 @@ def RNNO(
|
|
70
49
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
71
50
|
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
72
51
|
if eval and return_quats:
|
73
|
-
ringnet = base.LPF_FilterWrapper(ringnet,
|
52
|
+
ringnet = base.LPF_FilterWrapper(ringnet, _LPF_CUTOFF_FREQ, samp_freq=samp_freq)
|
74
53
|
if return_quats:
|
75
54
|
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
76
55
|
return ringnet
|
ring/ml/base.py
CHANGED
@@ -290,3 +290,24 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
|
|
290
290
|
yhat = ring.maths.safe_normalize(yhat)
|
291
291
|
|
292
292
|
return yhat, state
|
293
|
+
|
294
|
+
|
295
|
+
class AddTs_FilterWrapper(AbstractFilterWrapper):
    """Filter wrapper that appends a constant sampling interval to the input.

    If `Ts` is a number, inputs must have 9 features on their last axis and a
    10th feature filled with `Ts` is concatenated before the wrapped filter is
    called. If `Ts` is `None`, inputs must already have 10 features on their
    last axis and are forwarded unchanged.
    """

    def __init__(self, filter: AbstractFilter, Ts: float | None, name=None) -> None:
        """Store the wrapped `filter` and the sampling interval `Ts`."""
        super().__init__(filter, name)
        self.Ts = Ts

    def _add_Ts(self, X):
        """Return `X` with the `Ts` feature appended on the last axis.

        `None` is passed through unchanged so that `init` can be called
        without example data (its `X` parameter defaults to `None`).
        """
        if X is None:
            return None

        if self.Ts is None:
            # no constant time delta configured: the caller has to provide
            # the 10th feature itself
            assert X.shape[-1] == 10, (
                f"expected 10 features on the last axis when `Ts` is None, "
                f"got {X.shape[-1]}"
            )
            return X

        assert X.shape[-1] == 9, (
            f"expected 9 features on the last axis when `Ts` is given, "
            f"got {X.shape[-1]}"
        )
        # one extra feature column, same leading shape as `X`, filled with Ts
        X_Ts = jnp.ones(X.shape[:-1] + (1,)) * self.Ts
        return jnp.concatenate((X, X_Ts), axis=-1)

    def init(self, bs=None, X=None, lam=None, seed: int = 1):
        """Forward to the wrapped filter's `init` with `Ts` appended to `X`."""
        return super().init(bs, self._add_Ts(X), lam, seed)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        """Forward to the wrapped filter's `apply` with `Ts` appended to `X`."""
        return super().apply(self._add_Ts(X), params, state, y, lam)
|
File without changes
|
File without changes
|