imt-ring 1.3.13__tar.gz → 1.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.13 → imt_ring-1.4.1}/PKG-INFO +1 -1
- {imt_ring-1.3.13 → imt_ring-1.4.1}/pyproject.toml +1 -1
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/SOURCES.txt +2 -0
- imt_ring-1.4.1/src/ring/__init__.py +143 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/__init__.py +2 -23
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/base.py +26 -1
- imt_ring-1.4.1/src/ring/ml/params/0x1d76628065a71e0f.pickle +0 -0
- imt_ring-1.4.1/tests/test_quickstart_example.py +22 -0
- imt_ring-1.3.13/src/ring/__init__.py +0 -63
- {imt_ring-1.3.13 → imt_ring-1.4.1}/readme.md +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/setup.cfg +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algebra.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/custom_joints/suntay.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/base.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/batch.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/pd_control.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/randomize.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/transforms.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/jcalc.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/algorithms/sensors.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/base.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/examples.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/maths.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/ml_utils.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/ringnet.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/train.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/mujoco_render.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/spatial.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/__init__.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/batchsize.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/hdf5.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/normalizer.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/path.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/src/ring/utils/utils.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_algebra.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_base.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_dynamics.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_generator.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_jcalc.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_jit.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_kinematics.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_maths.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_ml_utils.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_pd_control.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_random.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_randomize.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_rcmg.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_render.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_sensors.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_sim2real.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_train.py +0 -0
- {imt_ring-1.3.13 → imt_ring-1.4.1}/tests/test_utils.py +0 -0
@@ -68,6 +68,7 @@ src/ring/ml/rnno_v1.py
|
|
68
68
|
src/ring/ml/train.py
|
69
69
|
src/ring/ml/training_loop.py
|
70
70
|
src/ring/ml/params/0x13e3518065c21cd8.pickle
|
71
|
+
src/ring/ml/params/0x1d76628065a71e0f.pickle
|
71
72
|
src/ring/rendering/__init__.py
|
72
73
|
src/ring/rendering/base_render.py
|
73
74
|
src/ring/rendering/mujoco_render.py
|
@@ -99,6 +100,7 @@ tests/test_maths.py
|
|
99
100
|
tests/test_ml_utils.py
|
100
101
|
tests/test_motion_artifacts.py
|
101
102
|
tests/test_pd_control.py
|
103
|
+
tests/test_quickstart_example.py
|
102
104
|
tests/test_random.py
|
103
105
|
tests/test_randomize.py
|
104
106
|
tests/test_rcmg.py
|
@@ -0,0 +1,143 @@
|
|
1
|
+
from . import algebra
|
2
|
+
from . import algorithms
|
3
|
+
from . import base
|
4
|
+
from . import io
|
5
|
+
from . import maths
|
6
|
+
from . import ml
|
7
|
+
from . import rendering
|
8
|
+
from . import sim2real
|
9
|
+
from . import spatial
|
10
|
+
from . import sys_composer
|
11
|
+
from . import utils
|
12
|
+
from .algorithms import join_motionconfigs
|
13
|
+
from .algorithms import JointModel
|
14
|
+
from .algorithms import MotionConfig
|
15
|
+
from .algorithms import RCMG
|
16
|
+
from .algorithms import register_new_joint_type
|
17
|
+
from .algorithms import step
|
18
|
+
from .base import State
|
19
|
+
from .base import System
|
20
|
+
from .base import Transform
|
21
|
+
|
22
|
+
|
23
|
+
def RING(lam: list[int], Ts: float | None):
|
24
|
+
"""Creates the RING network.
|
25
|
+
|
26
|
+
Params:
|
27
|
+
lam: parent array
|
28
|
+
Ts : sampling interval of IMU data; time delta in seconds
|
29
|
+
|
30
|
+
Usage:
|
31
|
+
>>> import ring
|
32
|
+
>>> import numpy as np
|
33
|
+
>>>
|
34
|
+
>>> T : int = 30 # sequence length [s]
|
35
|
+
>>> Ts : float = 0.01 # sampling interval [s]
|
36
|
+
>>> B : int = 1 # batch size
|
37
|
+
>>> lam: list[int] = [0, 1, 2] # parent array
|
38
|
+
>>> N : int = len(lam) # number of bodies
|
39
|
+
>>> T_i: int = int(T/Ts) # number of timesteps
|
40
|
+
>>>
|
41
|
+
>>> X = np.zeros((B, T_i, N, 9))
|
42
|
+
>>> # where X is structured as follows:
|
43
|
+
>>> # X[..., :3] = acc
|
44
|
+
>>> # X[..., 3:6] = gyr
|
45
|
+
>>> # X[..., 6:9] = jointaxis
|
46
|
+
>>>
|
47
|
+
>>> # let's assume we have an IMU on each outer segment of the
|
48
|
+
>>> # three-segment kinematic chain
|
49
|
+
>>> X[:, :, 0, :3] = acc_segment1
|
50
|
+
>>> X[:, :, 2, :3] = acc_segment3
|
51
|
+
>>> X[:, :, 0, 3:6] = gyr_segment1
|
52
|
+
>>> X[:, :, 2, 3:6] = gyr_segment3
|
53
|
+
>>>
|
54
|
+
>>> ringnet = ring.RING(lam, Ts)
|
55
|
+
>>>
|
56
|
+
>>> yhat, _ = ringnet.apply(X)
|
57
|
+
>>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
|
58
|
+
>>>
|
59
|
+
>>> # use `jax.jit` to compile the forward pass
|
60
|
+
>>> jit_apply = jax.jit(ringnet.apply)
|
61
|
+
>>> yhat, _ = jit_apply(X)
|
62
|
+
>>>
|
63
|
+
>>> # manually pass in and out the hidden state like so
|
64
|
+
>>> initial_state = None
|
65
|
+
>>> yhat, state = ringnet.apply(X, state=initial_state)
|
66
|
+
>>> # state: final hidden state, shape = (B, N, 2*H)
|
67
|
+
|
68
|
+
"""
|
69
|
+
from pathlib import Path
|
70
|
+
import warnings
|
71
|
+
|
72
|
+
if Ts is not None and (Ts > (1 / 40) or Ts < (1 / 200)):
|
73
|
+
warnings.warn(
|
74
|
+
"RING was only trained on sampling rates between 40 to 200 Hz "
|
75
|
+
f"but found {1 / Ts}Hz"
|
76
|
+
)
|
77
|
+
|
78
|
+
if Ts is not None and Ts == 0.01:
|
79
|
+
# this set of parameters was trained exclusively on 100Hz data; it also
|
80
|
+
# expects F=9 features per node and not F=10 where the last features is
|
81
|
+
# the sampling interval Ts
|
82
|
+
params = Path(__file__).parent.joinpath("ml/params/0x1d76628065a71e0f.pickle")
|
83
|
+
add_Ts = False
|
84
|
+
else:
|
85
|
+
# this set of parameters was trained on sampling rates from 40 to 200 Hz
|
86
|
+
params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")
|
87
|
+
add_Ts = True
|
88
|
+
|
89
|
+
ringnet = ml.RING(params=params, lam=tuple(lam), jit=False, name="RING")
|
90
|
+
ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
|
91
|
+
ringnet = ml.base.LPF_FilterWrapper(
|
92
|
+
ringnet,
|
93
|
+
ml._LPF_CUTOFF_FREQ,
|
94
|
+
samp_freq=None if Ts is None else 1 / Ts,
|
95
|
+
quiet=True,
|
96
|
+
)
|
97
|
+
ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
|
98
|
+
if add_Ts:
|
99
|
+
ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
|
100
|
+
return ringnet
|
101
|
+
|
102
|
+
|
103
|
+
_TRAIN_TIMING_START = None
|
104
|
+
_UNIQUE_ID = None
|
105
|
+
|
106
|
+
|
107
|
+
def setup(
|
108
|
+
rr_joint_kwargs: None | dict = dict(),
|
109
|
+
rr_imp_joint_kwargs: None | dict = dict(),
|
110
|
+
suntay_joint_kwargs: None | dict = None,
|
111
|
+
train_timing_start: None | float = None,
|
112
|
+
unique_id: None | str = None,
|
113
|
+
):
|
114
|
+
import time
|
115
|
+
|
116
|
+
from ring.algorithms import custom_joints
|
117
|
+
|
118
|
+
global _TRAIN_TIMING_START
|
119
|
+
global _UNIQUE_ID
|
120
|
+
|
121
|
+
if rr_joint_kwargs is not None:
|
122
|
+
custom_joints.register_rr_joint(**rr_joint_kwargs)
|
123
|
+
|
124
|
+
if rr_imp_joint_kwargs is not None:
|
125
|
+
custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
|
126
|
+
|
127
|
+
if suntay_joint_kwargs is not None:
|
128
|
+
custom_joints.register_suntay(**suntay_joint_kwargs)
|
129
|
+
|
130
|
+
if _TRAIN_TIMING_START is None:
|
131
|
+
_TRAIN_TIMING_START = time.time()
|
132
|
+
|
133
|
+
if train_timing_start is not None:
|
134
|
+
_TRAIN_TIMING_START = train_timing_start
|
135
|
+
|
136
|
+
if _UNIQUE_ID is None:
|
137
|
+
_UNIQUE_ID = hex(hash(time.time()))
|
138
|
+
|
139
|
+
if unique_id is not None:
|
140
|
+
_UNIQUE_ID = unique_id
|
141
|
+
|
142
|
+
|
143
|
+
setup()
|
@@ -13,28 +13,7 @@ from .optimizer import make_optimizer
|
|
13
13
|
from .ringnet import RING
|
14
14
|
from .train import train_fn
|
15
15
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
def RING_ICML24(params=None, eval: bool = True, **kwargs):
|
20
|
-
"""Create the RING network used in the icml24 paper.
|
21
|
-
|
22
|
-
X[..., :3] = acc
|
23
|
-
X[..., 3:6] = gyr
|
24
|
-
X[..., 6:9] = jointaxis
|
25
|
-
X[..., 9:] = dt
|
26
|
-
"""
|
27
|
-
from pathlib import Path
|
28
|
-
|
29
|
-
if params is None:
|
30
|
-
params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")
|
31
|
-
|
32
|
-
ringnet = RING(params=params, **kwargs) # noqa: F811
|
33
|
-
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
34
|
-
if eval:
|
35
|
-
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
|
36
|
-
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
37
|
-
return ringnet
|
16
|
+
_LPF_CUTOFF_FREQ = 10.0
|
38
17
|
|
39
18
|
|
40
19
|
def RNNO(
|
@@ -70,7 +49,7 @@ def RNNO(
|
|
70
49
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
71
50
|
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
72
51
|
if eval and return_quats:
|
73
|
-
ringnet = base.LPF_FilterWrapper(ringnet,
|
52
|
+
ringnet = base.LPF_FilterWrapper(ringnet, _LPF_CUTOFF_FREQ, samp_freq=samp_freq)
|
74
53
|
if return_quats:
|
75
54
|
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
76
55
|
return ringnet
|
@@ -144,11 +144,13 @@ class LPF_FilterWrapper(AbstractFilterWrapper):
|
|
144
144
|
cutoff_freq: float,
|
145
145
|
samp_freq: float | None,
|
146
146
|
filtfilt: bool = True,
|
147
|
+
quiet: bool = False,
|
147
148
|
name="LPF_FilterWrapper",
|
148
149
|
) -> None:
|
149
150
|
super().__init__(filter, name)
|
150
151
|
self.samp_freq = samp_freq
|
151
152
|
self._kwargs = dict(cutoff_freq=cutoff_freq, filtfilt=filtfilt)
|
153
|
+
self.quiet = quiet
|
152
154
|
|
153
155
|
def apply(self, X, params=None, state=None, y=None, lam=None):
|
154
156
|
if X.ndim == 4:
|
@@ -166,7 +168,7 @@ class LPF_FilterWrapper(AbstractFilterWrapper):
|
|
166
168
|
dt = X[0, 0, -1]
|
167
169
|
samp_freq = 1 / dt
|
168
170
|
|
169
|
-
if self.samp_freq is None:
|
171
|
+
if self.samp_freq is None and not self.quiet:
|
170
172
|
print(f"Detected the following sampling rates from `X`: {samp_freq}")
|
171
173
|
|
172
174
|
yhat, state = super().apply(X, params, state, y, lam)
|
@@ -290,3 +292,26 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
|
|
290
292
|
yhat = ring.maths.safe_normalize(yhat)
|
291
293
|
|
292
294
|
return yhat, state
|
295
|
+
|
296
|
+
|
297
|
+
class AddTs_FilterWrapper(AbstractFilterWrapper):
|
298
|
+
def __init__(
|
299
|
+
self, filter: AbstractFilter, Ts: float | None, name="AddTs_FilterWrapper"
|
300
|
+
) -> None:
|
301
|
+
super().__init__(filter, name)
|
302
|
+
self.Ts = Ts
|
303
|
+
|
304
|
+
def _add_Ts(self, X):
|
305
|
+
if self.Ts is None:
|
306
|
+
assert X.shape[-1] == 10
|
307
|
+
return X
|
308
|
+
else:
|
309
|
+
assert X.shape[-1] == 9
|
310
|
+
X_Ts = jnp.ones(X.shape[:-1] + (1,)) * self.Ts
|
311
|
+
return jnp.concatenate((X, X_Ts), axis=-1)
|
312
|
+
|
313
|
+
def init(self, bs=None, X=None, lam=None, seed: int = 1):
|
314
|
+
return super().init(bs, self._add_Ts(X), lam, seed)
|
315
|
+
|
316
|
+
def apply(self, X, params=None, state=None, y=None, lam=None):
|
317
|
+
return super().apply(self._add_Ts(X), params, state, y, lam)
|
Binary file
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import jax
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
import ring
|
5
|
+
|
6
|
+
|
7
|
+
def test_quickstart_exampe():
|
8
|
+
T: int = 30 # sequence length [s]
|
9
|
+
Ts: float = 0.01 # sampling interval [s]
|
10
|
+
B: int = 1 # batch size
|
11
|
+
lam: list[int] = [0, 1, 2] # parent array
|
12
|
+
N: int = len(lam) # number of bodies
|
13
|
+
T_i: int = int(T / Ts) # number of timesteps
|
14
|
+
|
15
|
+
X = np.zeros((B, T_i, N, 9))
|
16
|
+
|
17
|
+
ringnet = ring.RING(lam, Ts)
|
18
|
+
yhat, state = ringnet.apply(X)
|
19
|
+
assert yhat.shape == (B, T_i, N, 4)
|
20
|
+
assert state["~"]["inner_cell_state"].shape == (B, N, 2, 400)
|
21
|
+
|
22
|
+
_ = jax.jit(ringnet.apply)(X, state=state)
|
@@ -1,63 +0,0 @@
|
|
1
|
-
from . import algebra
|
2
|
-
from . import algorithms
|
3
|
-
from . import base
|
4
|
-
from . import io
|
5
|
-
from . import maths
|
6
|
-
from . import ml
|
7
|
-
from . import rendering
|
8
|
-
from . import sim2real
|
9
|
-
from . import spatial
|
10
|
-
from . import sys_composer
|
11
|
-
from . import utils
|
12
|
-
from .algorithms import join_motionconfigs
|
13
|
-
from .algorithms import JointModel
|
14
|
-
from .algorithms import MotionConfig
|
15
|
-
from .algorithms import RCMG
|
16
|
-
from .algorithms import register_new_joint_type
|
17
|
-
from .algorithms import step
|
18
|
-
from .base import State
|
19
|
-
from .base import System
|
20
|
-
from .base import Transform
|
21
|
-
from .ml import RING
|
22
|
-
|
23
|
-
_TRAIN_TIMING_START = None
|
24
|
-
_UNIQUE_ID = None
|
25
|
-
|
26
|
-
|
27
|
-
def setup(
|
28
|
-
rr_joint_kwargs: None | dict = dict(),
|
29
|
-
rr_imp_joint_kwargs: None | dict = dict(),
|
30
|
-
suntay_joint_kwargs: None | dict = None,
|
31
|
-
train_timing_start: None | float = None,
|
32
|
-
unique_id: None | str = None,
|
33
|
-
):
|
34
|
-
import time
|
35
|
-
|
36
|
-
from ring.algorithms import custom_joints
|
37
|
-
|
38
|
-
global _TRAIN_TIMING_START
|
39
|
-
global _UNIQUE_ID
|
40
|
-
|
41
|
-
if rr_joint_kwargs is not None:
|
42
|
-
custom_joints.register_rr_joint(**rr_joint_kwargs)
|
43
|
-
|
44
|
-
if rr_imp_joint_kwargs is not None:
|
45
|
-
custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
|
46
|
-
|
47
|
-
if suntay_joint_kwargs is not None:
|
48
|
-
custom_joints.register_suntay(**suntay_joint_kwargs)
|
49
|
-
|
50
|
-
if _TRAIN_TIMING_START is None:
|
51
|
-
_TRAIN_TIMING_START = time.time()
|
52
|
-
|
53
|
-
if train_timing_start is not None:
|
54
|
-
_TRAIN_TIMING_START = train_timing_start
|
55
|
-
|
56
|
-
if _UNIQUE_ID is None:
|
57
|
-
_UNIQUE_ID = hex(hash(time.time()))
|
58
|
-
|
59
|
-
if unique_id is not None:
|
60
|
-
_UNIQUE_ID = unique_id
|
61
|
-
|
62
|
-
|
63
|
-
setup()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|