imt-ring 1.3.13__tar.gz → 1.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. {imt_ring-1.3.13 → imt_ring-1.4.1}/PKG-INFO +1 -1
  2. {imt_ring-1.3.13 → imt_ring-1.4.1}/pyproject.toml +1 -1
  3. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/PKG-INFO +1 -1
  4. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/SOURCES.txt +2 -0
  5. imt_ring-1.4.1/src/ring/__init__.py +143 -0
  6. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/__init__.py +2 -23
  7. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/base.py +26 -1
  8. imt_ring-1.4.1/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
  9. imt_ring-1.4.1/tests/test_quickstart_example.py +22 -0
  10. imt_ring-1.3.13/src/ring/__init__.py +0 -63
  11. {imt_ring-1.3.13 → imt_ring-1.4.1}/readme.md +0 -0
  12. {imt_ring-1.3.13 → imt_ring-1.4.1}/setup.cfg +0 -0
  13. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/dependency_links.txt +0 -0
  14. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/requires.txt +0 -0
  15. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/top_level.txt +0 -0
  16. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algebra.py +0 -0
  17. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/__init__.py +0 -0
  18. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/_random.py +0 -0
  19. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/__init__.py +0 -0
  20. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
  21. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
  22. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/suntay.py +0 -0
  23. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/dynamics.py +0 -0
  24. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/__init__.py +0 -0
  25. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/base.py +0 -0
  26. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/batch.py +0 -0
  27. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
  28. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/pd_control.py +0 -0
  29. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/randomize.py +0 -0
  30. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/transforms.py +0 -0
  31. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/types.py +0 -0
  32. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/jcalc.py +0 -0
  33. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/kinematics.py +0 -0
  34. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/sensors.py +0 -0
  35. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/base.py +0 -0
  36. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/__init__.py +0 -0
  37. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/branched.xml +0 -0
  38. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
  39. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
  40. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
  41. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/inv_pendulum.xml +0 -0
  42. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
  43. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/spherical_stiff.xml +0 -0
  44. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/symmetric.xml +0 -0
  45. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_all_1.xml +0 -0
  46. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_all_2.xml +0 -0
  47. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
  48. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_control.xml +0 -0
  49. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_double_pendulum.xml +0 -0
  50. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_free.xml +0 -0
  51. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_kinematics.xml +0 -0
  52. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
  53. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
  54. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_randomize_position.xml +0 -0
  55. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_sensors.xml +0 -0
  56. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
  57. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples.py +0 -0
  58. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/test_examples.py +0 -0
  59. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/__init__.py +0 -0
  60. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/abstract.py +0 -0
  61. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/from_xml.py +0 -0
  62. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/test_from_xml.py +0 -0
  63. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/test_to_xml.py +0 -0
  64. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/to_xml.py +0 -0
  65. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/maths.py +0 -0
  66. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/callbacks.py +0 -0
  67. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/ml_utils.py +0 -0
  68. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/optimizer.py +0 -0
  69. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
  70. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/ringnet.py +0 -0
  71. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/rnno_v1.py +0 -0
  72. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/train.py +0 -0
  73. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/training_loop.py +0 -0
  74. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/__init__.py +0 -0
  75. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/base_render.py +0 -0
  76. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/mujoco_render.py +0 -0
  77. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/vispy_render.py +0 -0
  78. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/vispy_visuals.py +0 -0
  79. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sim2real/__init__.py +0 -0
  80. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sim2real/sim2real.py +0 -0
  81. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/spatial.py +0 -0
  82. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/__init__.py +0 -0
  83. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/delete_sys.py +0 -0
  84. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/inject_sys.py +0 -0
  85. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/morph_sys.py +0 -0
  86. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/__init__.py +0 -0
  87. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/backend.py +0 -0
  88. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/batchsize.py +0 -0
  89. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/colab.py +0 -0
  90. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/hdf5.py +0 -0
  91. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/normalizer.py +0 -0
  92. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/path.py +0 -0
  93. {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/utils.py +0 -0
  94. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_algebra.py +0 -0
  95. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_base.py +0 -0
  96. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_custom_joints.py +0 -0
  97. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_dynamics.py +0 -0
  98. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_generator.py +0 -0
  99. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_jcalc.py +0 -0
  100. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_jit.py +0 -0
  101. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_kinematics.py +0 -0
  102. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_maths.py +0 -0
  103. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_ml_utils.py +0 -0
  104. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_motion_artifacts.py +0 -0
  105. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_pd_control.py +0 -0
  106. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_random.py +0 -0
  107. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_randomize.py +0 -0
  108. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_rcmg.py +0 -0
  109. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_render.py +0 -0
  110. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_sensors.py +0 -0
  111. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_sim2real.py +0 -0
  112. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_sys_composer.py +0 -0
  113. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_train.py +0 -0
  114. {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.13
3
+ Version: 1.4.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "imt-ring"
7
- version = "1.3.13"
7
+ version = "1.4.1"
8
8
  authors = [
9
9
  { name="Simon Bachhuber", email="simon.bachhuber@fau.de" },
10
10
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: imt-ring
3
- Version: 1.3.13
3
+ Version: 1.4.1
4
4
  Summary: RING: Recurrent Inertial Graph-based Estimator
5
5
  Author-email: Simon Bachhuber <simon.bachhuber@fau.de>
6
6
  Project-URL: Homepage, https://github.com/SimiPixel/ring
@@ -68,6 +68,7 @@ src/ring/ml/rnno_v1.py
68
68
  src/ring/ml/train.py
69
69
  src/ring/ml/training_loop.py
70
70
  src/ring/ml/params/0x13e3518065c21cd8.pickle
71
+ src/ring/ml/params/0x1d76628065a71e0f.pickle
71
72
  src/ring/rendering/__init__.py
72
73
  src/ring/rendering/base_render.py
73
74
  src/ring/rendering/mujoco_render.py
@@ -99,6 +100,7 @@ tests/test_maths.py
99
100
  tests/test_ml_utils.py
100
101
  tests/test_motion_artifacts.py
101
102
  tests/test_pd_control.py
103
+ tests/test_quickstart_example.py
102
104
  tests/test_random.py
103
105
  tests/test_randomize.py
104
106
  tests/test_rcmg.py
@@ -0,0 +1,143 @@
1
+ from . import algebra
2
+ from . import algorithms
3
+ from . import base
4
+ from . import io
5
+ from . import maths
6
+ from . import ml
7
+ from . import rendering
8
+ from . import sim2real
9
+ from . import spatial
10
+ from . import sys_composer
11
+ from . import utils
12
+ from .algorithms import join_motionconfigs
13
+ from .algorithms import JointModel
14
+ from .algorithms import MotionConfig
15
+ from .algorithms import RCMG
16
+ from .algorithms import register_new_joint_type
17
+ from .algorithms import step
18
+ from .base import State
19
+ from .base import System
20
+ from .base import Transform
21
+
22
+
23
def RING(lam: list[int], Ts: float | None):
    """Creates the RING network.

    Params:
        lam: parent array
        Ts : sampling interval of IMU data; time delta in seconds. If `None`,
            no sampling interval is fixed here and every input `X` must
            instead carry it as a 10th feature, i.e. `X[..., 9] = Ts`.

    Returns:
        The wrapped RING filter; run it with `.apply(X)` (see Usage).

    Raises:
        ValueError: if `Ts` is not `None` and not positive.

    Warns:
        UserWarning: if `Ts` lies outside the 40 to 200 Hz range that RING
            was trained on.

    Usage:
        >>> import jax
        >>> import ring
        >>> import numpy as np
        >>>
        >>> T  : int       = 30          # sequence length [s]
        >>> Ts : float     = 0.01        # sampling interval [s]
        >>> B  : int       = 1           # batch size
        >>> lam: list[int] = [0, 1, 2]   # parent array
        >>> N  : int       = len(lam)    # number of bodies
        >>> T_i: int       = int(T/Ts)   # number of timesteps
        >>>
        >>> X = np.zeros((B, T_i, N, 9))
        >>> # where X is structured as follows:
        >>> # X[..., :3]  = acc
        >>> # X[..., 3:6] = gyr
        >>> # X[..., 6:9] = jointaxis
        >>>
        >>> # let's assume we have an IMU on each outer segment of the
        >>> # three-segment kinematic chain; `acc_segment1` etc. stand for
        >>> # your own measurements of shape (B, T_i, 3)
        >>> X[:, :, 0, :3]  = acc_segment1
        >>> X[:, :, 2, :3]  = acc_segment3
        >>> X[:, :, 0, 3:6] = gyr_segment1
        >>> X[:, :, 2, 3:6] = gyr_segment3
        >>>
        >>> ringnet = ring.RING(lam, Ts)
        >>>
        >>> yhat, _ = ringnet.apply(X)
        >>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
        >>>
        >>> # use `jax.jit` to compile the forward pass
        >>> jit_apply = jax.jit(ringnet.apply)
        >>> yhat, _ = jit_apply(X)
        >>>
        >>> # manually pass in and out the hidden state like so
        >>> initial_state = None
        >>> yhat, state = ringnet.apply(X, state=initial_state)
        >>> # state: final hidden state, e.g.
        >>> # state["~"]["inner_cell_state"] has shape (B, N, 2, H)

    """
    import math
    from pathlib import Path
    import warnings

    if Ts is not None and Ts <= 0:
        # a non-positive interval has no meaningful sampling rate (and the
        # `1 / Ts` below would raise ZeroDivisionError for Ts == 0)
        raise ValueError(f"`Ts` must be positive, but found {Ts}")

    if Ts is not None and (Ts > (1 / 40) or Ts < (1 / 200)):
        warnings.warn(
            "RING was only trained on sampling rates between 40 to 200 Hz "
            f"but found {1 / Ts}Hz",
            stacklevel=2,
        )

    params_dir = Path(__file__).parent.joinpath("ml/params")
    # compare with a tolerance: an interval derived from timestamps (e.g.
    # `0.03 - 0.02`) means 100 Hz but is not bit-equal to 0.01
    if Ts is not None and math.isclose(Ts, 0.01, rel_tol=1e-6):
        # this set of parameters was trained exclusively on 100Hz data; it also
        # expects F=9 features per node and not F=10 where the last feature is
        # the sampling interval Ts
        params = params_dir.joinpath("0x1d76628065a71e0f.pickle")
        add_Ts = False
    else:
        # this set of parameters was trained on sampling rates from 40 to 200 Hz
        params = params_dir.joinpath("0x13e3518065c21cd8.pickle")
        add_Ts = True

    ringnet = ml.RING(params=params, lam=tuple(lam), jit=False, name="RING")
    ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
    ringnet = ml.base.LPF_FilterWrapper(
        ringnet,
        ml._LPF_CUTOFF_FREQ,
        # None -> the wrapper reads the interval from the last feature of `X`
        samp_freq=None if Ts is None else 1 / Ts,
        quiet=True,
    )
    ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
    if add_Ts:
        # appends `Ts` as the 10th feature of `X` (for Ts=None it only checks
        # that `X` already has 10 features)
        ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
    return ringnet
101
+
102
+
103
_TRAIN_TIMING_START = None
_UNIQUE_ID = None


def setup(
    rr_joint_kwargs: None | dict = dict(),
    rr_imp_joint_kwargs: None | dict = dict(),
    suntay_joint_kwargs: None | dict = None,
    train_timing_start: None | float = None,
    unique_id: None | str = None,
):
    """Register custom joint types and initialise package-level globals.

    Each `*_joint_kwargs` argument that is not `None` is forwarded as keyword
    arguments to the matching `register_*` function of
    `ring.algorithms.custom_joints`; passing `None` skips that registration.

    `_TRAIN_TIMING_START` becomes `train_timing_start` when one is given;
    otherwise it keeps its current value, or is set to the current time if it
    is still unset. `_UNIQUE_ID` is handled the same way with `unique_id`,
    falling back to a hex string derived from the current time.
    """
    import time

    from ring.algorithms import custom_joints

    global _TRAIN_TIMING_START
    global _UNIQUE_ID

    # registrars are looked up lazily, only for the joints actually requested
    for joint_kwargs, registrar_name in (
        (rr_joint_kwargs, "register_rr_joint"),
        (rr_imp_joint_kwargs, "register_rr_imp_joint"),
        (suntay_joint_kwargs, "register_suntay"),
    ):
        if joint_kwargs is None:
            continue
        getattr(custom_joints, registrar_name)(**joint_kwargs)

    if train_timing_start is not None:
        _TRAIN_TIMING_START = train_timing_start
    elif _TRAIN_TIMING_START is None:
        _TRAIN_TIMING_START = time.time()

    if unique_id is not None:
        _UNIQUE_ID = unique_id
    elif _UNIQUE_ID is None:
        _UNIQUE_ID = hex(hash(time.time()))


setup()
@@ -13,28 +13,7 @@ from .optimizer import make_optimizer
13
13
  from .ringnet import RING
14
14
  from .train import train_fn
15
15
 
16
- _lpf_cutoff_freq = 10.0
17
-
18
-
19
- def RING_ICML24(params=None, eval: bool = True, **kwargs):
20
- """Create the RING network used in the icml24 paper.
21
-
22
- X[..., :3] = acc
23
- X[..., 3:6] = gyr
24
- X[..., 6:9] = jointaxis
25
- X[..., 9:] = dt
26
- """
27
- from pathlib import Path
28
-
29
- if params is None:
30
- params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")
31
-
32
- ringnet = RING(params=params, **kwargs) # noqa: F811
33
- ringnet = base.ScaleX_FilterWrapper(ringnet)
34
- if eval:
35
- ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
36
- ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
37
- return ringnet
16
+ _LPF_CUTOFF_FREQ = 10.0
38
17
 
39
18
 
40
19
  def RNNO(
@@ -70,7 +49,7 @@ def RNNO(
70
49
  ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
71
50
  ringnet = base.ScaleX_FilterWrapper(ringnet)
72
51
  if eval and return_quats:
73
- ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=samp_freq)
52
+ ringnet = base.LPF_FilterWrapper(ringnet, _LPF_CUTOFF_FREQ, samp_freq=samp_freq)
74
53
  if return_quats:
75
54
  ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
76
55
  return ringnet
@@ -144,11 +144,13 @@ class LPF_FilterWrapper(AbstractFilterWrapper):
144
144
  cutoff_freq: float,
145
145
  samp_freq: float | None,
146
146
  filtfilt: bool = True,
147
+ quiet: bool = False,
147
148
  name="LPF_FilterWrapper",
148
149
  ) -> None:
149
150
  super().__init__(filter, name)
150
151
  self.samp_freq = samp_freq
151
152
  self._kwargs = dict(cutoff_freq=cutoff_freq, filtfilt=filtfilt)
153
+ self.quiet = quiet
152
154
 
153
155
  def apply(self, X, params=None, state=None, y=None, lam=None):
154
156
  if X.ndim == 4:
@@ -166,7 +168,7 @@ class LPF_FilterWrapper(AbstractFilterWrapper):
166
168
  dt = X[0, 0, -1]
167
169
  samp_freq = 1 / dt
168
170
 
169
- if self.samp_freq is None:
171
+ if self.samp_freq is None and not self.quiet:
170
172
  print(f"Detected the following sampling rates from `X`: {samp_freq}")
171
173
 
172
174
  yhat, state = super().apply(X, params, state, y, lam)
@@ -290,3 +292,26 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
290
292
  yhat = ring.maths.safe_normalize(yhat)
291
293
 
292
294
  return yhat, state
295
+
296
+
297
class AddTs_FilterWrapper(AbstractFilterWrapper):
    """Append the sampling interval `Ts` as the last feature of the input `X`.

    If `Ts` is given, `X` must have 9 features on its last axis; a constant
    10th feature equal to `Ts` is appended before the wrapped filter sees it.
    If `Ts` is `None`, `X` must already have 10 features and is passed through
    unchanged.
    """

    def __init__(
        self, filter: AbstractFilter, Ts: float | None, name="AddTs_FilterWrapper"
    ) -> None:
        super().__init__(filter, name)
        self.Ts = Ts

    def _add_Ts(self, X):
        """Return `X` with `Ts` appended as last feature (see class docstring).

        Raises:
            ValueError: if the last axis of `X` does not have the expected
                number of features (10 when `Ts` is None, else 9).
        """
        # explicit exceptions instead of `assert`, which `python -O` strips
        if self.Ts is None:
            if X.shape[-1] != 10:
                raise ValueError(
                    "Expected `X` with 10 features (the last one being Ts) "
                    f"since `Ts` is None, but got {X.shape[-1]}"
                )
            return X
        if X.shape[-1] != 9:
            raise ValueError(f"Expected `X` with 9 features, but got {X.shape[-1]}")
        X_Ts = jnp.ones(X.shape[:-1] + (1,)) * self.Ts
        return jnp.concatenate((X, X_Ts), axis=-1)

    def init(self, bs=None, X=None, lam=None, seed: int = 1):
        # `X` is optional here; only transform it when it is actually given
        # (otherwise `_add_Ts` would fail on `None.shape`)
        if X is not None:
            X = self._add_Ts(X)
        return super().init(bs, X, lam, seed)

    def apply(self, X, params=None, state=None, y=None, lam=None):
        return super().apply(self._add_Ts(X), params, state, y, lam)
@@ -0,0 +1,22 @@
1
+ import jax
2
+ import numpy as np
3
+
4
+ import ring
5
+
6
+
7
def test_quickstart_example():
    """Run the quickstart from `ring.RING`'s docstring on dummy zero input.

    Checks the output and hidden-state shapes for the 100 Hz configuration,
    and that the jitted forward pass accepts the returned hidden state and
    produces outputs of the same shape.
    """
    T: int = 30  # sequence length [s]
    Ts: float = 0.01  # sampling interval [s]
    B: int = 1  # batch size
    lam: list[int] = [0, 1, 2]  # parent array
    N: int = len(lam)  # number of bodies
    T_i: int = int(T / Ts)  # number of timesteps

    X = np.zeros((B, T_i, N, 9))

    ringnet = ring.RING(lam, Ts)
    yhat, state = ringnet.apply(X)
    assert yhat.shape == (B, T_i, N, 4)
    assert state["~"]["inner_cell_state"].shape == (B, N, 2, 400)

    # resume from the returned hidden state under `jax.jit`
    yhat_jit, _ = jax.jit(ringnet.apply)(X, state=state)
    assert yhat_jit.shape == yhat.shape
@@ -1,63 +0,0 @@
1
- from . import algebra
2
- from . import algorithms
3
- from . import base
4
- from . import io
5
- from . import maths
6
- from . import ml
7
- from . import rendering
8
- from . import sim2real
9
- from . import spatial
10
- from . import sys_composer
11
- from . import utils
12
- from .algorithms import join_motionconfigs
13
- from .algorithms import JointModel
14
- from .algorithms import MotionConfig
15
- from .algorithms import RCMG
16
- from .algorithms import register_new_joint_type
17
- from .algorithms import step
18
- from .base import State
19
- from .base import System
20
- from .base import Transform
21
- from .ml import RING
22
-
23
- _TRAIN_TIMING_START = None
24
- _UNIQUE_ID = None
25
-
26
-
27
- def setup(
28
- rr_joint_kwargs: None | dict = dict(),
29
- rr_imp_joint_kwargs: None | dict = dict(),
30
- suntay_joint_kwargs: None | dict = None,
31
- train_timing_start: None | float = None,
32
- unique_id: None | str = None,
33
- ):
34
- import time
35
-
36
- from ring.algorithms import custom_joints
37
-
38
- global _TRAIN_TIMING_START
39
- global _UNIQUE_ID
40
-
41
- if rr_joint_kwargs is not None:
42
- custom_joints.register_rr_joint(**rr_joint_kwargs)
43
-
44
- if rr_imp_joint_kwargs is not None:
45
- custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
46
-
47
- if suntay_joint_kwargs is not None:
48
- custom_joints.register_suntay(**suntay_joint_kwargs)
49
-
50
- if _TRAIN_TIMING_START is None:
51
- _TRAIN_TIMING_START = time.time()
52
-
53
- if train_timing_start is not None:
54
- _TRAIN_TIMING_START = train_timing_start
55
-
56
- if _UNIQUE_ID is None:
57
- _UNIQUE_ID = hex(hash(time.time()))
58
-
59
- if unique_id is not None:
60
- _UNIQUE_ID = unique_id
61
-
62
-
63
- setup()
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes