imt-ring 1.3.12__tar.gz → 1.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {imt_ring-1.3.12 → imt_ring-1.4.0}/PKG-INFO +1 -1
- {imt_ring-1.3.12 → imt_ring-1.4.0}/pyproject.toml +1 -1
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/imt_ring.egg-info/PKG-INFO +1 -1
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/imt_ring.egg-info/SOURCES.txt +1 -0
- imt_ring-1.4.0/src/ring/__init__.py +130 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/batch.py +8 -11
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/__init__.py +2 -23
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/base.py +21 -0
- imt_ring-1.4.0/tests/test_quickstart_example.py +22 -0
- imt_ring-1.3.12/src/ring/__init__.py +0 -63
- {imt_ring-1.3.12 → imt_ring-1.4.0}/readme.md +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/setup.cfg +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/imt_ring.egg-info/dependency_links.txt +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/imt_ring.egg-info/requires.txt +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/imt_ring.egg-info/top_level.txt +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algebra.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/_random.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/custom_joints/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/custom_joints/rr_imp_joint.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/custom_joints/rr_joint.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/custom_joints/suntay.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/dynamics.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/base.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/motion_artifacts.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/pd_control.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/randomize.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/transforms.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/generator/types.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/jcalc.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/kinematics.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/algorithms/sensors.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/base.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/branched.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/exclude/knee_trans_dof.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/exclude/standard_sys.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/exclude/standard_sys_rr_imp.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/inv_pendulum.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/knee_flexible_imus.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/spherical_stiff.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/symmetric.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_all_1.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_all_2.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_ang0_pos0.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_control.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_double_pendulum.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_free.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_kinematics.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_morph_system/four_seg_seg1.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_morph_system/four_seg_seg3.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_randomize_position.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_sensors.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples/test_three_seg_seg2.xml +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/examples.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/test_examples.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/xml/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/xml/abstract.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/xml/from_xml.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/xml/test_from_xml.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/xml/test_to_xml.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/io/xml/to_xml.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/maths.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/callbacks.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/ml_utils.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/optimizer.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/params/0x13e3518065c21cd8.pickle +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/ringnet.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/rnno_v1.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/train.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/ml/training_loop.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/rendering/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/rendering/base_render.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/rendering/mujoco_render.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/rendering/vispy_render.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/rendering/vispy_visuals.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/sim2real/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/sim2real/sim2real.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/spatial.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/sys_composer/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/sys_composer/delete_sys.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/sys_composer/inject_sys.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/sys_composer/morph_sys.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/__init__.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/backend.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/batchsize.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/colab.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/hdf5.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/normalizer.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/path.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/src/ring/utils/utils.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_algebra.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_base.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_custom_joints.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_dynamics.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_generator.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_jcalc.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_jit.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_kinematics.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_maths.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_ml_utils.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_motion_artifacts.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_pd_control.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_random.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_randomize.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_rcmg.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_render.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_sensors.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_sim2real.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_sys_composer.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_train.py +0 -0
- {imt_ring-1.3.12 → imt_ring-1.4.0}/tests/test_utils.py +0 -0
@@ -0,0 +1,130 @@
|
|
1
|
+
from . import algebra
|
2
|
+
from . import algorithms
|
3
|
+
from . import base
|
4
|
+
from . import io
|
5
|
+
from . import maths
|
6
|
+
from . import ml
|
7
|
+
from . import rendering
|
8
|
+
from . import sim2real
|
9
|
+
from . import spatial
|
10
|
+
from . import sys_composer
|
11
|
+
from . import utils
|
12
|
+
from .algorithms import join_motionconfigs
|
13
|
+
from .algorithms import JointModel
|
14
|
+
from .algorithms import MotionConfig
|
15
|
+
from .algorithms import RCMG
|
16
|
+
from .algorithms import register_new_joint_type
|
17
|
+
from .algorithms import step
|
18
|
+
from .base import State
|
19
|
+
from .base import System
|
20
|
+
from .base import Transform
|
21
|
+
|
22
|
+
|
23
|
+
def RING(lam: list[int], Ts: float | None):
|
24
|
+
"""Creates the RING network.
|
25
|
+
|
26
|
+
Params:
|
27
|
+
lam: parent array
|
28
|
+
Ts : sampling interval of IMU data; time delta in seconds
|
29
|
+
|
30
|
+
Usage:
|
31
|
+
>>> import ring
|
32
|
+
>>> import numpy as np
|
33
|
+
>>>
|
34
|
+
>>> T : int = 30 # sequence length [s]
|
35
|
+
>>> Ts : float = 0.01 # sampling interval [s]
|
36
|
+
>>> B : int = 1 # batch size
|
37
|
+
>>> lam: list[int] = [0, 1, 2] # parent array
|
38
|
+
>>> N : int = len(lam) # number of bodies
|
39
|
+
>>> T_i: int = int(T/Ts) # number of timesteps
|
40
|
+
>>>
|
41
|
+
>>> X = np.zeros((B, T_i, N, 9))
|
42
|
+
>>> # where X is structured as follows:
|
43
|
+
>>> # X[..., :3] = acc
|
44
|
+
>>> # X[..., 3:6] = gyr
|
45
|
+
>>> # X[..., 6:9] = jointaxis
|
46
|
+
>>>
|
47
|
+
>>> # let's assume we have an IMU on each outer segment of the
|
48
|
+
>>> # three-segment kinematic chain
|
49
|
+
>>> X[:, :, 0, :3] = acc_segment1
|
50
|
+
>>> X[:, :, 2, :3] = acc_segment3
|
51
|
+
>>> X[:, :, 0, 3:6] = gyr_segment1
|
52
|
+
>>> X[:, :, 2, 3:6] = gyr_segment3
|
53
|
+
>>>
|
54
|
+
>>> ringnet = ring.RING(lam, Ts)
|
55
|
+
>>>
|
56
|
+
>>> yhat, _ = ringnet.apply(X)
|
57
|
+
>>> # yhat : unit quaternions, shape = (B, T_i, N, 4)
|
58
|
+
>>>
|
59
|
+
>>> # use `jax.jit` to compile the forward pass
|
60
|
+
>>> jit_apply = jax.jit(ringnet.apply)
|
61
|
+
>>> yhat, _ = jit_apply(X)
|
62
|
+
>>>
|
63
|
+
>>> # manually pass in and out the hidden state like so
|
64
|
+
>>> initial_state = None
|
65
|
+
>>> yhat, state = ringnet.apply(X, state=initial_state)
|
66
|
+
>>> # state: final hidden state, shape = (B, N, 2*H)
|
67
|
+
|
68
|
+
"""
|
69
|
+
from pathlib import Path
|
70
|
+
import warnings
|
71
|
+
|
72
|
+
if Ts > (1 / 40) or Ts < (1 / 200):
|
73
|
+
warnings.warn(
|
74
|
+
"RING was only trained on sampling rates between 40 to 200 Hz "
|
75
|
+
f"but found {1 / Ts}Hz"
|
76
|
+
)
|
77
|
+
|
78
|
+
params = Path(__file__).parent.joinpath("ml/params/0x13e3518065c21cd8.pickle")
|
79
|
+
|
80
|
+
ringnet = ml.RING(params=params, lam=tuple(lam), jit=False)
|
81
|
+
ringnet = ml.base.ScaleX_FilterWrapper(ringnet)
|
82
|
+
ringnet = ml.base.LPF_FilterWrapper(
|
83
|
+
ringnet, ml._LPF_CUTOFF_FREQ, samp_freq=None if Ts is None else 1 / Ts
|
84
|
+
)
|
85
|
+
ringnet = ml.base.GroundTruthHeading_FilterWrapper(ringnet)
|
86
|
+
ringnet = ml.base.AddTs_FilterWrapper(ringnet, Ts)
|
87
|
+
return ringnet
|
88
|
+
|
89
|
+
|
90
|
+
_TRAIN_TIMING_START = None
|
91
|
+
_UNIQUE_ID = None
|
92
|
+
|
93
|
+
|
94
|
+
def setup(
|
95
|
+
rr_joint_kwargs: None | dict = dict(),
|
96
|
+
rr_imp_joint_kwargs: None | dict = dict(),
|
97
|
+
suntay_joint_kwargs: None | dict = None,
|
98
|
+
train_timing_start: None | float = None,
|
99
|
+
unique_id: None | str = None,
|
100
|
+
):
|
101
|
+
import time
|
102
|
+
|
103
|
+
from ring.algorithms import custom_joints
|
104
|
+
|
105
|
+
global _TRAIN_TIMING_START
|
106
|
+
global _UNIQUE_ID
|
107
|
+
|
108
|
+
if rr_joint_kwargs is not None:
|
109
|
+
custom_joints.register_rr_joint(**rr_joint_kwargs)
|
110
|
+
|
111
|
+
if rr_imp_joint_kwargs is not None:
|
112
|
+
custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
|
113
|
+
|
114
|
+
if suntay_joint_kwargs is not None:
|
115
|
+
custom_joints.register_suntay(**suntay_joint_kwargs)
|
116
|
+
|
117
|
+
if _TRAIN_TIMING_START is None:
|
118
|
+
_TRAIN_TIMING_START = time.time()
|
119
|
+
|
120
|
+
if train_timing_start is not None:
|
121
|
+
_TRAIN_TIMING_START = train_timing_start
|
122
|
+
|
123
|
+
if _UNIQUE_ID is None:
|
124
|
+
_UNIQUE_ID = hex(hash(time.time()))
|
125
|
+
|
126
|
+
if unique_id is not None:
|
127
|
+
_UNIQUE_ID = unique_id
|
128
|
+
|
129
|
+
|
130
|
+
setup()
|
@@ -154,25 +154,20 @@ def _data_fn_from_paths(
|
|
154
154
|
|
155
155
|
# expanduser
|
156
156
|
paths = [utils.parse_path(p, mkdir=False) for p in paths]
|
157
|
-
|
158
157
|
extensions = list(set([Path(p).suffix for p in paths]))
|
159
158
|
assert len(extensions) == 1, f"{extensions}"
|
159
|
+
h5 = extensions[0] == ".h5"
|
160
160
|
|
161
|
-
if
|
162
|
-
N = sum([utils.hdf5_load_length(p) for p in paths])
|
163
|
-
|
164
|
-
if extensions[0] == ".h5" and not load_all_into_memory:
|
161
|
+
if h5 and not load_all_into_memory:
|
165
162
|
|
166
163
|
def data_fn(indices: list[int]):
|
167
164
|
tree = utils.hdf5_load_from_multiple(paths, indices)
|
168
165
|
return tree if tree_transform is None else tree_transform(tree)
|
169
166
|
|
167
|
+
N = sum([utils.hdf5_load_length(p) for p in paths])
|
170
168
|
else:
|
171
169
|
|
172
|
-
if
|
173
|
-
load_from_path = utils.hdf5_load
|
174
|
-
else:
|
175
|
-
load_from_path = utils.pickle_load
|
170
|
+
load_from_path = utils.hdf5_load if h5 else utils.pickle_load
|
176
171
|
|
177
172
|
def load_fn(path):
|
178
173
|
tree = load_from_path(path)
|
@@ -190,8 +185,10 @@ def _data_fn_from_paths(
|
|
190
185
|
_list_of_data += load_fn(p)
|
191
186
|
|
192
187
|
N = len(_list_of_data)
|
193
|
-
|
194
|
-
|
188
|
+
list_of_data = _replace_elements_w_nans(
|
189
|
+
_list_of_data,
|
190
|
+
include_samples if include_samples is not None else list(range(N)),
|
191
|
+
)
|
195
192
|
|
196
193
|
if include_samples is not None:
|
197
194
|
list_of_data = [
|
@@ -13,28 +13,7 @@ from .optimizer import make_optimizer
|
|
13
13
|
from .ringnet import RING
|
14
14
|
from .train import train_fn
|
15
15
|
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
def RING_ICML24(params=None, eval: bool = True, **kwargs):
|
20
|
-
"""Create the RING network used in the icml24 paper.
|
21
|
-
|
22
|
-
X[..., :3] = acc
|
23
|
-
X[..., 3:6] = gyr
|
24
|
-
X[..., 6:9] = jointaxis
|
25
|
-
X[..., 9:] = dt
|
26
|
-
"""
|
27
|
-
from pathlib import Path
|
28
|
-
|
29
|
-
if params is None:
|
30
|
-
params = Path(__file__).parent.joinpath("params/0x13e3518065c21cd8.pickle")
|
31
|
-
|
32
|
-
ringnet = RING(params=params, **kwargs) # noqa: F811
|
33
|
-
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
34
|
-
if eval:
|
35
|
-
ringnet = base.LPF_FilterWrapper(ringnet, _lpf_cutoff_freq, samp_freq=None)
|
36
|
-
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
37
|
-
return ringnet
|
16
|
+
_LPF_CUTOFF_FREQ = 10.0
|
38
17
|
|
39
18
|
|
40
19
|
def RNNO(
|
@@ -70,7 +49,7 @@ def RNNO(
|
|
70
49
|
ringnet = base.NoGraph_FilterWrapper(ringnet, quat_normalize=return_quats)
|
71
50
|
ringnet = base.ScaleX_FilterWrapper(ringnet)
|
72
51
|
if eval and return_quats:
|
73
|
-
ringnet = base.LPF_FilterWrapper(ringnet,
|
52
|
+
ringnet = base.LPF_FilterWrapper(ringnet, _LPF_CUTOFF_FREQ, samp_freq=samp_freq)
|
74
53
|
if return_quats:
|
75
54
|
ringnet = base.GroundTruthHeading_FilterWrapper(ringnet)
|
76
55
|
return ringnet
|
@@ -290,3 +290,24 @@ class NoGraph_FilterWrapper(AbstractFilterWrapper):
|
|
290
290
|
yhat = ring.maths.safe_normalize(yhat)
|
291
291
|
|
292
292
|
return yhat, state
|
293
|
+
|
294
|
+
|
295
|
+
class AddTs_FilterWrapper(AbstractFilterWrapper):
|
296
|
+
def __init__(self, filter: AbstractFilter, Ts: float | None, name=None) -> None:
|
297
|
+
super().__init__(filter, name)
|
298
|
+
self.Ts = Ts
|
299
|
+
|
300
|
+
def _add_Ts(self, X):
|
301
|
+
if self.Ts is None:
|
302
|
+
assert X.shape[-1] == 10
|
303
|
+
return X
|
304
|
+
else:
|
305
|
+
assert X.shape[-1] == 9
|
306
|
+
X_Ts = jnp.ones(X.shape[:-1] + (1,)) * self.Ts
|
307
|
+
return jnp.concatenate((X, X_Ts), axis=-1)
|
308
|
+
|
309
|
+
def init(self, bs=None, X=None, lam=None, seed: int = 1):
|
310
|
+
return super().init(bs, self._add_Ts(X), lam, seed)
|
311
|
+
|
312
|
+
def apply(self, X, params=None, state=None, y=None, lam=None):
|
313
|
+
return super().apply(self._add_Ts(X), params, state, y, lam)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import jax
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
import ring
|
5
|
+
|
6
|
+
|
7
|
+
def test_quickstart_exampe():
|
8
|
+
T: int = 30 # sequence length [s]
|
9
|
+
Ts: float = 0.01 # sampling interval [s]
|
10
|
+
B: int = 1 # batch size
|
11
|
+
lam: list[int] = [0, 1, 2] # parent array
|
12
|
+
N: int = len(lam) # number of bodies
|
13
|
+
T_i: int = int(T / Ts) # number of timesteps
|
14
|
+
|
15
|
+
X = np.zeros((B, T_i, N, 9))
|
16
|
+
|
17
|
+
ringnet = ring.RING(lam, Ts)
|
18
|
+
yhat, state = ringnet.apply(X)
|
19
|
+
assert yhat.shape == (B, T_i, N, 4)
|
20
|
+
assert state["~"]["inner_cell_state"].shape == (B, N, 2, 400)
|
21
|
+
|
22
|
+
_ = jax.jit(ringnet.apply)(X, state=state)
|
@@ -1,63 +0,0 @@
|
|
1
|
-
from . import algebra
|
2
|
-
from . import algorithms
|
3
|
-
from . import base
|
4
|
-
from . import io
|
5
|
-
from . import maths
|
6
|
-
from . import ml
|
7
|
-
from . import rendering
|
8
|
-
from . import sim2real
|
9
|
-
from . import spatial
|
10
|
-
from . import sys_composer
|
11
|
-
from . import utils
|
12
|
-
from .algorithms import join_motionconfigs
|
13
|
-
from .algorithms import JointModel
|
14
|
-
from .algorithms import MotionConfig
|
15
|
-
from .algorithms import RCMG
|
16
|
-
from .algorithms import register_new_joint_type
|
17
|
-
from .algorithms import step
|
18
|
-
from .base import State
|
19
|
-
from .base import System
|
20
|
-
from .base import Transform
|
21
|
-
from .ml import RING
|
22
|
-
|
23
|
-
_TRAIN_TIMING_START = None
|
24
|
-
_UNIQUE_ID = None
|
25
|
-
|
26
|
-
|
27
|
-
def setup(
|
28
|
-
rr_joint_kwargs: None | dict = dict(),
|
29
|
-
rr_imp_joint_kwargs: None | dict = dict(),
|
30
|
-
suntay_joint_kwargs: None | dict = None,
|
31
|
-
train_timing_start: None | float = None,
|
32
|
-
unique_id: None | str = None,
|
33
|
-
):
|
34
|
-
import time
|
35
|
-
|
36
|
-
from ring.algorithms import custom_joints
|
37
|
-
|
38
|
-
global _TRAIN_TIMING_START
|
39
|
-
global _UNIQUE_ID
|
40
|
-
|
41
|
-
if rr_joint_kwargs is not None:
|
42
|
-
custom_joints.register_rr_joint(**rr_joint_kwargs)
|
43
|
-
|
44
|
-
if rr_imp_joint_kwargs is not None:
|
45
|
-
custom_joints.register_rr_imp_joint(**rr_imp_joint_kwargs)
|
46
|
-
|
47
|
-
if suntay_joint_kwargs is not None:
|
48
|
-
custom_joints.register_suntay(**suntay_joint_kwargs)
|
49
|
-
|
50
|
-
if _TRAIN_TIMING_START is None:
|
51
|
-
_TRAIN_TIMING_START = time.time()
|
52
|
-
|
53
|
-
if train_timing_start is not None:
|
54
|
-
_TRAIN_TIMING_START = train_timing_start
|
55
|
-
|
56
|
-
if _UNIQUE_ID is None:
|
57
|
-
_UNIQUE_ID = hex(hash(time.time()))
|
58
|
-
|
59
|
-
if unique_id is not None:
|
60
|
-
_UNIQUE_ID = unique_id
|
61
|
-
|
62
|
-
|
63
|
-
setup()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|