imt-ring 1.3.13__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.13
3
+ Version: 1.4.0
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -1,4 +1,4 @@
1
- ring/__init__.py,sha256=iNvbAZi7Qfa69IbL1z4lB7zHL8WusV5fBrKah2la-Gc,1566
1
+ ring/__init__.py,sha256=-70A1E5LQQyGdOA_u9PlYZb1d5Fz0yXnvRoZOrbxq6o,3781
2
2
  ring/algebra.py,sha256=F0GwbP8LQP5qGVkoMUYJmkp9Hn2nKAVIkCVYDEjNjGU,3128
3
3
  ring/base.py,sha256=YFPrUWelWswEhq8x8Byv-5pK64mipiGW6x5IlMr4we4,33803
4
4
  ring/maths.py,sha256=jJr_kr78-XDce8B4tXQ2Li-jBntVQhaS8csxglCsj8A,12193
@@ -50,8 +50,8 @@ ring/io/xml/from_xml.py,sha256=8b44sPVWgoY8JGJZLpJ8M_eLfcfu3IsMtBzSytPTPmw,9234
50
50
  ring/io/xml/test_from_xml.py,sha256=bckVrVVmEhCwujd_OF9FGYnX3zU3BgztpqGxxmd0htM,1562
51
51
  ring/io/xml/test_to_xml.py,sha256=NGn4VSiFdwhYN5YTBduWMiY9B5dwtxZhCQAR_PXeqKU,946
52
52
  ring/io/xml/to_xml.py,sha256=fohb-jWMf2cxVdT5dmknsGyrNMseICSbKEz_urbaWbQ,3407
53
- ring/ml/__init__.py,sha256=52LpEjni5lG-ov5-3ocodH-vKZxNcFMU7W9XfjDicp0,2113
54
- ring/ml/base.py,sha256=PQ72VasEqlecBZgWP5HE5rWYyLiLq7nCVLymXo9f0dw,8959
53
+ ring/ml/__init__.py,sha256=8SZTCs9rJ1kzR0Psh7lUzFhIMhKRPIK41mVfxJAGyMo,1471
54
+ ring/ml/base.py,sha256=16zdF72XLljeFRWaEr1O-9M5Sw2ppk5yvAoAPxA5EJU,9693
55
55
  ring/ml/callbacks.py,sha256=W19QF6_uvaNCjs8ObsjNXD7mv9gFgJBixdRSbB_BynE,13301
56
56
  ring/ml/ml_utils.py,sha256=hQEmeZoahdJyFrz0NZXYi1Yijl7GvPBdqwzZBzlUIUM,7638
57
57
  ring/ml/optimizer.py,sha256=fWyF__ezUltrA16SLfOC1jvS3zBh9NJsMYa6-V0frhs,4709
@@ -79,7 +79,7 @@ ring/utils/hdf5.py,sha256=BzXwVypZmEZeHVgeGZ78YYdi10NEQtnPhdrb8dQAXo0,5856
79
79
  ring/utils/normalizer.py,sha256=67L2BU1MRsMT4pD41ta3JJMppLN0ozFmnwrmXDtnqrQ,1698
80
80
  ring/utils/path.py,sha256=hAfSlqRi-ew536RnjDDM7IKapdMJc-EvhrR0Y-BCFWc,1265
81
81
  ring/utils/utils.py,sha256=mIcKNv5v2de8HrG7bAhl2bNfmwkMZyIIwFkJq2XWMOI,5357
82
- imt_ring-1.3.13.dist-info/METADATA,sha256=29lL2WaS8JaByX0qyjVm-OMC916Xs8U1A7pRyybHSrA,3105
83
- imt_ring-1.3.13.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
84
- imt_ring-1.3.13.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
85
- imt_ring-1.3.13.dist-info/RECORD,,
82
+ imt_ring-1.4.0.dist-info/METADATA,sha256=4o8Vsz-U7Ekp_4ipyhcCy1vlhCgVZwEZYbwLsfmzy34,3104
83
+ imt_ring-1.4.0.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
84
+ imt_ring-1.4.0.dist-info/top_level.txt,sha256=EiT790-lAyi8iwTzJArH3f2k77rwhDn00q-4PlmvDQo,5
85
+ imt_ring-1.4.0.dist-info/RECORD,,
ring/__init__.py CHANGED
@@ -18,7 +18,74 @@ from .algorithms import step
18
18
  from .base import State
19
19
  from .base import System
20
20
  from .base import Transform
21
- from .ml import RING
21
+
22
+
23
+ def RING(lam: list[int], Ts: float | None):
24
+ """Creates the RING network.
25
+
26
+ Params:
27
+ lam: parent array
28
+ Ts : sampling interval of IMU data; time delta in seconds
29
+
30
+ Usage:
31
+ >>> import ring
32
+ >>> import numpy as np
33
+ >>>
34
+ >>> T : int = 30 # sequence length [s]
35
+ >>> Ts : float = 0.01 # sampling interval [s]
36
+ >>> B : int = 1 # batch size
37
+ >>> lam: list[int] = [0, 1, 2] # parent array
38
+ >>> N : int = len(lam) # number of bodies
39
+ >>> T_i: int = int(T/Ts) # number of timesteps
40
+ >>>
41
+ >>> X = np.zeros((B, T_i, N, 9))
42
+ >>> # where X is structured as follows:
43
+ >>> # X[..., :3] = acc
44
+ >>> # X[..., 3:6] = gyr
45
+ >>> # X[..., 6:9] = jointaxis
46
+ >>>
47
+ >>> # let's assume we have an IMU on each outer segment of the
48
+ >>> # three-segment kinematic chain
49
+ >>> X[:, :, 0, :3] = acc_segment1
50
+ >>> X[:, :, 2, :3] = acc_segment3
51
+ >>> X[:, :, 0, 3:6] = gyr_segment1
52
+ >>> X[:, :, 2, 3:6] = gyr_segment3
53
+ >>>
54
+ >>> ringnet = ring.RING(lam, Ts)
55
+ >>>
56
+ >>> yhat, _ = ringnet.apply(X)
57
+ >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
58
+ >>>
59
+ >>> # use `jax.jit` to compile the forward pass
60
+ >>> jit_apply = jax.jit(ringnet.apply)
61
+ >>> yhat, _ = jit_apply(X)
62
+ >>>
63
+ >>> # manually pass in and out the hidden state like so
64
+ >>> initial_state = None
65
+ >>> yhat, state = ringnet.apply(X, state=initial_state)
66
+ >>> # state: final hidden state, shape = (B, N, 2*H)
67
+
68
+ """
69
+ from pathlib import Path
70
+ import warnings
71
+
72
+ if Ts > (1 / 40) or Ts < (1 / 200):
73
+ warnings.warn(
74
+ "RING was only trained on sampling rates between 40 to 200 Hz "
75
+ f"but found {1 / Ts}Hz"
76
+ )
77
+
78
+ params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")
79
+
80
+ ringnet = ml.RING(params=params, lam=tuple(lam), jit=False)
81
+ ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
82
+ ringnet = ml.base.LPF_FilterWrapper(
83
+ ringnet, ml._LPF_CUTOFF_FREQ, samp_freq=None if Ts is None else 1 / Ts
84
+ )
85
+ ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
86
+ ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
87
+ return ringnet
88
+
22
89
 
23
90
  _TRAIN_TIMING_START = None
24
91
  _UNIQUE_ID = None
ring/ml/__init__.py CHANGED
@@ -13,28 +13,7 @@ from .optimizer import make_optimizer
13
13
  from .ringnet import RING
14
14
  from .train import train_fn
15
15
 
16
- _lpf_cutoff_freq = 10.0
17
-
18
-
19
- def RING_ICML24(params=None, eval: bool = True, **kwargs):
20
- """Create the RING network used in the icml24 paper.
21
-
22
- X[..., :3] = acc
23
- X[..., 3:6] = gyr
24
- X[..., 6:9] = jointaxis
25
- X[..., 9:] = dt
26
- """
27
- from pathlib import Path
28
-
29
- if params is None:
30
- params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")
31
-
32
- ringnet = RING(params=params, **kwargs) # noqa: F811
33
- ringnet = base.ScaleX_FilterWrapper(ringnet)
34
- if eval:
35
- ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
36
- ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
37
- return ringnet
16
+ _LPF_CUTOFF_FREQ = 10.0
38
17
 
39
18
 
40
19
  def RNNO(
@@ -70,7 +49,7 @@ def RNNO(
70
49
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
71
50
  ringnet = base.ScaleX_FilterWrapper(ringnet)
72
51
  if eval and return_quats:
73
- ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=samp_freq)
52
+ ringnet = base.LPF_FilterWrapper(ringnet, _LPF_CUTOFF_FREQ, samp_freq=samp_freq)
74
53
  if return_quats:
75
54
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
76
55
  return ringnet
ring/ml/base.py CHANGED
@@ -290,3 +290,24 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
290
290
  yhat = ring.maths.safe_normalize(yhat)
291
291
 
292
292
  return yhat, state
293
+
294
+
295
+ class AddTs_FilterWrapper(AbstractFilterWrapper):
296
+ def __init__(self, filter: AbstractFilter, Ts: float | None, name=None) -> None:
297
+ super().__init__(filter, name)
298
+ self.Ts = Ts
299
+
300
+ def _add_Ts(self, X):
301
+ if self.Ts is None:
302
+ assert X.shape[-1] == 10
303
+ return X
304
+ else:
305
+ assert X.shape[-1] == 9
306
+ X_Ts = jnp.ones(X.shape[:-1] + (1,)) * self.Ts
307
+ return jnp.concatenate((X, X_Ts), axis=-1)
308
+
309
+ def init(self, bs=None, X=None, lam=None, seed: int = 1):
310
+ return super().init(bs, self._add_Ts(X), lam, seed)
311
+
312
+ def apply(self, X, params=None, state=None, y=None, lam=None):
313
+ return super().apply(self._add_Ts(X), params, state, y, lam)